From 7e90fae353b6a3c2836843c751332f6eab4f03c1 Mon Sep 17 00:00:00 2001 From: Bastian Gruber Date: Sat, 29 Apr 2023 09:30:10 +0200 Subject: [PATCH] Extend the basic server, debug with broadcast --- problem_03/Cargo.toml | 3 +- problem_03/{src => bin}/client.rs | 2 +- problem_03/bin/server.rs | 15 +++++ problem_03/src/connection.rs | 39 +++++++++++ problem_03/src/lib.rs | 12 ++++ problem_03/src/server.rs | 108 ++++++++++++++++++++++++------ problem_03/src/shutdown.rs | 30 +++++++++ 7 files changed, 186 insertions(+), 23 deletions(-) rename problem_03/{src => bin}/client.rs (96%) create mode 100644 problem_03/bin/server.rs create mode 100644 problem_03/src/connection.rs create mode 100644 problem_03/src/lib.rs create mode 100644 problem_03/src/shutdown.rs diff --git a/problem_03/Cargo.toml b/problem_03/Cargo.toml index 3654eae..9284f55 100644 --- a/problem_03/Cargo.toml +++ b/problem_03/Cargo.toml @@ -9,9 +9,10 @@ path = "src/server.rs" [[bin]] name = "client" -path = "src/client.rs" +path = "bin/client.rs" [dependencies] +anyhow = "1.0.70" futures = "0.3.28" tokio = { version = "1.14.0", features = ["full"] } tokio-util = { version = "0.7.4", features = ["codec"] } diff --git a/problem_03/src/client.rs b/problem_03/bin/client.rs similarity index 96% rename from problem_03/src/client.rs rename to problem_03/bin/client.rs index c684eac..8af32a1 100644 --- a/problem_03/src/client.rs +++ b/problem_03/bin/client.rs @@ -22,7 +22,7 @@ async fn main() -> Result<(), Box> { if n > 0 { info!("Receivng from server: {}", buf.trim_end()); } else { - info!("Server is finished sending, break"); + info!("Server is finished sending: {}", n); return; } } else { diff --git a/problem_03/bin/server.rs b/problem_03/bin/server.rs new file mode 100644 index 0000000..8035015 --- /dev/null +++ b/problem_03/bin/server.rs @@ -0,0 +1,15 @@ +// use problem_03::{server, DEFAULT_PORT}; + +// use tokio::net::TcpListener; +// use tokio::signal; + +// #[tokio::main] +// pub async fn main() -> problem_03::Result<()> { +// tracing_subscriber::fmt::try_init()?; + +// let listener = TcpListener::bind(&format!("0.0.0.0:{}", DEFAULT_PORT)).await?; + +// server::run(listener, signal::ctrl_c()).await?; + +// Ok(()) +// } diff --git a/problem_03/src/connection.rs b/problem_03/src/connection.rs new file mode 100644 index 0000000..ea3723f --- /dev/null +++ b/problem_03/src/connection.rs @@ -0,0 +1,39 @@ +use futures::{SinkExt, StreamExt}; +use tokio::net::TcpStream; +use tokio_util::codec::{Framed, LinesCodec}; +use tracing::{debug, error, info}; + +#[derive(Debug)] +pub struct Connection { + stream: Framed, +} + +impl Connection { + pub fn new(socket: TcpStream) -> Connection { + Connection { + stream: Framed::new(socket, LinesCodec::new()), + } + } + + pub async fn read_frame(&mut self) -> crate::Result> { + loop { + info!("Read next frame"); + if let Some(Ok(frame)) = self.stream.next().await { + info!("Frame parsed"); + return Ok(Some(frame)); + } else { + return Err("connection reset by peer".into()); + } + } + } + + pub async fn write_frame(&mut self, response: String) -> crate::Result<()> { + debug!(?response); + if let Err(e) = self.stream.send(response.clone()).await { + error!("Could not write frame to stream"); + return Err(e.to_string().into()); + } + info!("Wrote to frame: {}", response); + Ok(()) + } +} diff --git a/problem_03/src/lib.rs b/problem_03/src/lib.rs new file mode 100644 index 0000000..941319a --- /dev/null +++ b/problem_03/src/lib.rs @@ -0,0 +1,12 @@ +mod connection; +pub use connection::Connection; + +pub mod server; + +mod shutdown; +use shutdown::Shutdown; + +pub const DEFAULT_PORT: u16 = 1222; + +pub type Error = Box; +pub type Result = std::result::Result; diff --git a/problem_03/src/server.rs b/problem_03/src/server.rs index 16d5aa6..64f51e0 100644 --- a/problem_03/src/server.rs +++ b/problem_03/src/server.rs @@ -1,5 +1,12 @@ -use futures::{SinkExt, StreamExt}; +use futures::{stream::StreamExt, SinkExt}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; use tokio::net::TcpListener; +use tokio::sync::{ + broadcast, + broadcast::{Receiver, Sender}, +}; use tokio_util::codec::{Framed, LinesCodec}; use tracing::{error, info}; @@ -9,41 +16,100 @@ const PORT: u16 = 1222; type Error = Box; type Result = std::result::Result; +type Username = String; +type Address = SocketAddr; + +#[derive(Clone, Debug, Default)] +struct Users(Arc>>); + #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::try_init().expect("Tracing was not setup"); let listener = TcpListener::bind(format!("{IP}:{PORT}")).await?; info!("Listening on: {}", format!("{IP}:{PORT}")); + + let (tx, _) = broadcast::channel(256); + + let db = Users::default(); + // Infinite loop to always listen to new connections on this IP/PORT loop { - // Get the TCP stream out of the new connection, and the address from which - // it is connected to - let (stream, address) = listener.accept().await?; - let mut framed = Framed::new(stream, LinesCodec::new()); - info!("New address connected: {}", address); - let _ = framed.send("You are connected!".to_string()).await; + let (tx, mut rx) = (tx.clone(), tx.subscribe()); + let db = db.clone(); - // We spawn a new task, so every incoming connection can be put on a thread - // and be worked on "in the background" - // This allows us to handle multiple connections "at the same time" tokio::spawn(async move { + let mut framed = Framed::new(stream, LinesCodec::new()); + info!("New address connected: {address}"); + let _ = framed + .send("Welcome to budgetchat! What shall I call you?".to_string()) + .await; + + let mut name = String::default(); + + // We read exactly one line per loop. A line ends with \n. + // So if the client doesn't frame their package with \n at the end, + // we won't process until we find one. + match framed.next().await { + Some(Ok(username)) => { + name = username.clone(); + db.0.lock().unwrap().insert(username.clone(), address); + let message = compose_message(username.clone(), db.clone()); + info!("Adding username: {username} to db"); + let _ = framed.send(message).await; + info!("Send message to client"); + let _ = tx.send(format!("* {} has entered the room", username)); + } + Some(Err(e)) => { + error!("Error parsing message: {e}"); + } + None => { + info!("No frame"); + } + } + loop { - // We read exactly one line per loop. A line ends with \n. - // So if the client doesn't frame their package with \n at the end, - // we won't process until we find one. - match framed.next().await { - Some(n) => { - if let Err(e) = n { - error!("Error parsing message: {}", e); - } else { - let _ = framed.send(n.unwrap()).await; + tokio::select! { + n = framed.next() => { + match n { + Some(Ok(n)) => { + // broadcast message to all clients except the one who sent it + info!("Receiving new chat message: {n}"); + let _ = tx.send(format!("[{}]: {}", name, n)); + } + Some(Err(e)) => { + error!("Error receiving chat message: {e}"); + } + None => { + // Connection dropped + // remove client from db etc. + // send leave message + info!("No next frame"); + let _ = tx.send(format!("* {} has left the room", name)); + break; + } } } - None => return, - }; + message = rx.recv() => { + info!("Broadcast received: {:?}", message.clone().unwrap()); + let _ = framed.send(message.unwrap()).await; + } + } } }); } } + +fn compose_message(name: String, db: Users) -> String { + format!( + "* The room contains: {}", + db.0.lock() + .unwrap() + .keys() + .filter(|n| n.as_str() != name) + .map(|n| n.to_string()) + .collect::>() + .join(", ") + ) +} diff --git a/problem_03/src/shutdown.rs b/problem_03/src/shutdown.rs new file mode 100644 index 0000000..1c86f83 --- /dev/null +++ b/problem_03/src/shutdown.rs @@ -0,0 +1,30 @@ +use tokio::sync::broadcast; + +#[derive(Debug)] +pub(crate) struct Shutdown { + shutdown: bool, + notify: broadcast::Receiver<()>, +} + +impl Shutdown { + pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { + Shutdown { + shutdown: false, + notify, + } + } + + pub(crate) fn is_shutdown(&self) -> bool { + self.shutdown + } + + pub(crate) async fn recv(&mut self) { + if self.shutdown { + return; + } + + let _ = self.notify.recv().await; + + self.shutdown = true; + } +}