diff --git a/03b-multi-node-broadcast/src/connection.rs b/03b-multi-node-broadcast/src/connection.rs deleted file mode 100644 index 983d24c..0000000 --- a/03b-multi-node-broadcast/src/connection.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::message::Message; -use std::io::{BufRead, Write}; - -#[derive(Debug)] -pub struct Connection<'a> { - reader: std::io::BufReader>, - writer: std::io::Stdout, -} - -impl<'a> Connection<'a> { - pub fn new(stdin: std::io::Stdin) -> Self { - Connection { - reader: std::io::BufReader::new(stdin.lock()), - writer: std::io::stdout(), - } - } - - pub fn read_one(&mut self) -> Option { - let mut buf = String::new(); - let _ = self.reader.read_line(&mut buf); - return Some(Message::parse_message(buf)); - } - - pub fn read(&mut self) -> Option { - let mut buffer = String::new(); - - match self.reader.read_line(&mut buffer) { - Ok(bytes_read) => { - if bytes_read > 0 { - serde_json::from_str(&buffer).ok() - } else { - None - } - } - Err(_) => None, - } - } - - pub fn write(&mut self, message: Message) { - let message = Message::format_message(message); - writeln!(self.writer, "{}", message).unwrap(); - self.writer.flush().unwrap(); - } -} diff --git a/03b-multi-node-broadcast/src/main.rs b/03b-multi-node-broadcast/src/main.rs index e3b662c..16059c8 100644 --- a/03b-multi-node-broadcast/src/main.rs +++ b/03b-multi-node-broadcast/src/main.rs @@ -1,115 +1,178 @@ -mod connection; mod message; mod node; mod storage; -use crate::connection::Connection; use crate::message::{Body, Message}; use crate::node::Node; -fn main() { - let stdin = std::io::stdin(); - let mut connection = Connection::new(stdin); +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::sync::Mutex; +use tokio::sync::{mpsc, mpsc::Receiver, mpsc::Sender}; +use tokio::time; - let mut node = init_node(&mut connection); +#[tokio::main] +async fn main() { + // let mut interval = time::interval(Duration::from_secs(5)); - while let Some(message) = connection.read() { - handle_message(&mut node, &mut connection, message); + let (reader_tx, mut reader_rx) = mpsc::channel(100); + let (writer_tx, mut writer_rx) = mpsc::channel(100); + + let node = Arc::new(Mutex::new(Node::default())); + // let writer_tx = Arc::new(Mutex::new(writer_tx)); + + let n1 = node.clone(); + let n2 = node.clone(); + + let w1_tx = writer_tx.clone(); + let w2_tx = writer_tx.clone(); + + let read = tokio::spawn(async move { + read_from_stdin(reader_tx).await; + }); + + let write = tokio::spawn(async move { + write_to_stdout(&mut writer_rx).await; + }); + + // tokio::spawn(async move { + // loop { + // interval.tick().await; + // gossip_messages(n1.clone(), w1_tx.clone()).await; + // } + // }); + + let handle = tokio::spawn(async move { + handle_messages(n2, &mut reader_rx, writer_tx).await; + }); + + let _ = tokio::join!(read, write, handle); +} + +async fn read_from_stdin(reader_tx: Sender) { + let stdin = tokio::io::stdin(); + let mut reader = BufReader::new(stdin).lines(); + eprintln!("Reading from stdin"); + while let Ok(Some(line)) = reader.next_line().await { + eprintln!("Reading next line {line:?}"); + let message = Message::parse_message(line); + reader_tx.send(message).await.unwrap(); } } -fn init_node(connection: &mut Connection) -> Node { - let input = connection.read_one().expect("Didn't get input"); +async fn write_to_stdout(writer_rx: &mut Receiver) { + let stdout = tokio::io::stdout(); + let mut writer = BufWriter::new(stdout); - let node; - match input.body { - Body::Init { msg_id, .. } => { - node = Node::init(input.clone()); - - let response = Message { - src: node.id.clone(), - dest: input.src, - body: Body::InitOk { - in_reply_to: msg_id, - }, - }; - - connection.write(response); + while let Some(message) = writer_rx.recv().await { + let message = Message::format_message(message); + if let Err(e) = writer.write_all(message.as_bytes()).await { + eprintln!("Failed to write to stdout: {}", e); } - _ => panic!("Node is not initalized yet"), - } - node + if let Err(e) = writer.flush().await { + eprintln!("Failed to flush stdout: {}", e); + } + } } -fn handle_message(node: &mut Node, connection: &mut Connection, input: Message) { - match input.body { - Body::Broadcast { msg_id, message } => { - node.storage.add_message(message); +async fn gossip_messages(node: Arc>, writer: Arc>>) { + let mut node = node.lock().await; + let writer = writer.lock().await; - let response = Message { - src: node.id.clone(), - dest: input.src, - body: Body::BroadcastOk { - msg_id, - in_reply_to: msg_id, - }, - }; + for n in node.storage.get_neighbours(&node.get_id()) { + let messages = node.storage.get_messages_for_node(n.clone()); + let message = Message { + src: node.id.clone(), + dest: n.clone(), + body: Body::Gossip { + messages: messages.clone(), + }, + }; - connection.write(response); + let _ = writer.send(message).await.unwrap(); + node.storage + .add_to_sent_messages(messages.into_iter().collect(), n); + } +} - let nodes = node.storage.get_neighbours(node.id.clone()); - - for n in nodes { - let output = Message { - src: node.id.clone(), - dest: n, - body: Body::Broadcast { - msg_id, - message: node.storage.get_messages().last().unwrap().clone(), +async fn handle_messages( + node: Arc>, + input: &mut Receiver, + writer: Sender, +) { + while let Some(input) = input.recv().await { + match input.body { + Body::Init { msg_id, .. } => { + node.lock().await.init(input.clone()); + let id = node.lock().await.get_id(); + let response = Message { + src: id, + dest: input.src, + body: Body::InitOk { + in_reply_to: msg_id, }, }; - connection.write(output); + let _ = writer.send(response).await; } - } - Body::Read { msg_id } => { - let output = Message { - src: node.id.clone(), - dest: input.src, - body: Body::ReadOk { - msg_id, - in_reply_to: msg_id, - messages: node.storage.get_messages(), - }, - }; + Body::Broadcast { msg_id, message } => { + let id = node.lock().await.get_id(); + node.lock().await.storage.add_message(message, id.clone()); - connection.write(output); - } - Body::Topology { msg_id, topology } => { - node.storage.init_topology(topology); + let response = Message { + src: id, + dest: input.src, + body: Body::BroadcastOk { + msg_id, + in_reply_to: msg_id, + }, + }; - let output = Message { - src: node.id.clone(), - dest: input.src, - body: Body::TopologyOk { - msg_id, - in_reply_to: msg_id, - }, - }; + let _ = writer.send(response).await; + } + Body::Read { msg_id } => { + let id = node.lock().await.get_id(); - connection.write(output); + let output = Message { + src: id, + dest: input.src, + body: Body::ReadOk { + msg_id, + in_reply_to: msg_id, + messages: node.lock().await.storage.get_messages(), + }, + }; + + let _ = writer.send(output).await; + } + Body::Topology { msg_id, topology } => { + let id = node.lock().await.get_id(); + node.lock().await.storage.init_topology(topology); + + let output = Message { + src: id, + dest: input.src, + body: Body::TopologyOk { + msg_id, + in_reply_to: msg_id, + }, + }; + + let _ = writer.send(output).await; + } + Body::Error { + in_reply_to, + code, + text, + } => { + eprintln!( + "Error received (in_reply_to: {}, code: {}, text: {})", + in_reply_to, code, text + ); + } + _ => (), } - Body::Error { - in_reply_to, - code, - text, - } => { - eprintln!( - "Error received (in_reply_to: {}, code: {}, text: {})", - in_reply_to, code, text - ); - } - _ => (), } } diff --git a/03b-multi-node-broadcast/src/message.rs b/03b-multi-node-broadcast/src/message.rs index 356fa6f..f25b996 100644 --- a/03b-multi-node-broadcast/src/message.rs +++ b/03b-multi-node-broadcast/src/message.rs @@ -49,6 +49,9 @@ pub enum Body { msg_id: u64, in_reply_to: u64, }, + Gossip { + messages: Vec, + }, } impl Message { diff --git a/03b-multi-node-broadcast/src/node.rs b/03b-multi-node-broadcast/src/node.rs index 03ae7c1..f48d4c2 100644 --- a/03b-multi-node-broadcast/src/node.rs +++ b/03b-multi-node-broadcast/src/node.rs @@ -11,18 +11,19 @@ pub(crate) struct Node { } impl Node { - pub(crate) fn init(message: Message) -> Node { + pub(crate) fn init(&mut self, message: Message) { match message.body { Body::Init { node_id, node_ids, .. } => { - return Node { - id: node_id, - availble_nodes: node_ids, - storage: Storage::new(), - } + self.id = node_id; + self.availble_nodes = node_ids; } _ => panic!("Invalid message type"), } } + + pub(crate) fn get_id(&self) -> String { + self.id.clone() + } } diff --git a/03b-multi-node-broadcast/src/storage.rs b/03b-multi-node-broadcast/src/storage.rs index 8f6ae0c..25ad5bc 100644 --- a/03b-multi-node-broadcast/src/storage.rs +++ b/03b-multi-node-broadcast/src/storage.rs @@ -1,36 +1,83 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; #[derive(Serialize, Deserialize, Debug, Default)] pub(crate) struct Topology(pub(crate) HashMap>); #[derive(Serialize, Deserialize, Debug, Default)] -pub(crate) struct Messages(pub(crate) Vec); +pub(crate) struct Messages(pub(crate) HashSet); #[derive(Serialize, Deserialize, Debug, Default)] pub(crate) struct Storage { pub(crate) messages: Messages, + pub(crate) received_messages: HashMap, + pub(crate) sent_messages: HashMap, pub(crate) topology: Topology, } impl Storage { - pub(crate) fn new() -> Storage { - Storage::default() - } + pub(crate) fn add_message(&mut self, message: u64, node: String) { + self.messages.0.insert(message); - pub(crate) fn add_message(&mut self, message: u64) { - self.messages.0.push(message); + if self.received_messages.contains_key(&node) { + self.received_messages + .get_mut(&node) + .unwrap() + .0 + .insert(message); + } else { + let mut v = Messages::default(); + v.0.insert(message); + self.received_messages.insert(node, v); + } } pub(crate) fn get_messages(&mut self) -> Vec { - self.messages.0.to_owned() + self.messages.0.clone().into_iter().collect() + } + + pub(crate) fn get_messages_for_node(&self, node: String) -> Vec { + let received: Vec = self + .received_messages + .iter() + .filter(|(key, _)| *key == &node) + .flat_map(|(_, Messages(value))| value) + .cloned() + .collect(); + + let sent: Vec = self + .sent_messages + .iter() + .filter(|(key, _)| *key == &node) + .flat_map(|(_, Messages(value))| value) + .cloned() + .collect(); + + self.messages + .0 + .iter() + .filter(|m| !received.contains(m) && !sent.contains(m)) + .cloned() + .collect() + } + + pub(crate) fn add_to_sent_messages(&mut self, messages: HashSet, node: String) { + if self.sent_messages.contains_key(&node) { + self.sent_messages + .get_mut(&node) + .unwrap() + .0 + .extend(messages); + } else { + self.sent_messages.insert(node, Messages(messages)); + } } pub(crate) fn init_topology(&mut self, topology: HashMap>) { self.topology.0 = topology; } - pub(crate) fn get_neighbours(&mut self, node_id: String) -> Vec { - self.topology.0.get(&node_id).unwrap().to_owned() + pub(crate) fn get_neighbours(&self, node_id: &str) -> Vec { + self.topology.0.get(node_id).unwrap().to_owned() } }