diff --git a/problem_03/Cargo.toml b/problem_03/Cargo.toml index 9284f55..7d9c26b 100644 --- a/problem_03/Cargo.toml +++ b/problem_03/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [[bin]] name = "server" -path = "src/server.rs" +path = "bin/server.rs" [[bin]] name = "client" diff --git a/problem_03/bin/server.rs b/problem_03/bin/server.rs index 8035015..c5fdb7c 100644 --- a/problem_03/bin/server.rs +++ b/problem_03/bin/server.rs @@ -1,15 +1,15 @@ -// use problem_03::{server, DEFAULT_PORT}; +use problem_03::{server, DEFAULT_IP, DEFAULT_PORT}; -// use tokio::net::TcpListener; -// use tokio::signal; +use tokio::net::TcpListener; +use tokio::signal; -// #[tokio::main] -// pub async fn main() -> problem_03::Result<()> { -// tracing_subscriber::fmt::try_init()?; +#[tokio::main] +pub async fn main() -> problem_03::Result<()> { + tracing_subscriber::fmt::try_init()?; -// let listener = TcpListener::bind(&format!("0.0.0.0:{}", DEFAULT_PORT)).await?; + let listener = TcpListener::bind(&format!("{DEFAULT_IP}:{DEFAULT_PORT}")).await?; -// server::run(listener, signal::ctrl_c()).await?; + server::run(listener, signal::ctrl_c()).await?; -// Ok(()) -// } + Ok(()) +} diff --git a/problem_03/src/connection.rs b/problem_03/src/connection.rs index ea3723f..554fd2d 100644 --- a/problem_03/src/connection.rs +++ b/problem_03/src/connection.rs @@ -1,33 +1,46 @@ +use crate::{Message, Result, Username}; use futures::{SinkExt, StreamExt}; use tokio::net::TcpStream; +use tokio::sync::broadcast::Sender; use tokio_util::codec::{Framed, LinesCodec}; use tracing::{debug, error, info}; +#[derive(Clone, Debug, Default)] +pub struct BroadcastMessage { + pub(crate) from: Username, + pub(crate) message: Message, +} + +impl BroadcastMessage { + pub fn new(from: Username, message: Message) -> Self { + BroadcastMessage { from, message } + } +} + #[derive(Debug)] pub struct Connection { - stream: Framed, + pub stream: Framed, + pub broadcast: Sender, } impl Connection { - pub fn new(socket: TcpStream) -> Connection { + pub fn new(socket: TcpStream, sender: Sender) -> Connection { Connection { stream: Framed::new(socket, LinesCodec::new()), + broadcast: sender, } } - pub async fn read_frame(&mut self) -> crate::Result> { - loop { - info!("Read next frame"); - if let Some(Ok(frame)) = self.stream.next().await { - info!("Frame parsed"); - return Ok(Some(frame)); - } else { - return Err("connection reset by peer".into()); - } - } + pub async fn red_next_frame(&mut self) -> Result> { + return if let Some(Ok(frame)) = self.stream.next().await { + info!("Frame for parsing the username parsed"); + Ok(Some(frame)) + } else { + Err("connection reset by peer".into()) + }; } - pub async fn write_frame(&mut self, response: String) -> crate::Result<()> { + pub async fn write_frame(&mut self, response: String) -> Result<()> { debug!(?response); if let Err(e) = self.stream.send(response.clone()).await { error!("Could not write frame to stream"); @@ -36,4 +49,12 @@ impl Connection { info!("Wrote to frame: {}", response); Ok(()) } + + pub fn broadcast_message(&mut self, message: BroadcastMessage) -> Result<()> { + match self.broadcast.send(message.clone()) { + Ok(n) => info!("Sent broadcast: {n}"), + Err(e) => error!("Could not send broadcast: {e}"), + } + Ok(()) + } } diff --git a/problem_03/src/db.rs b/problem_03/src/db.rs new file mode 100644 index 0000000..54afec3 --- /dev/null +++ b/problem_03/src/db.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; +use tokio::sync::RwLock; + +use tracing::{debug, info}; + +use crate::Result; + +#[derive(Debug, Clone)] +pub(crate) struct Db { + users: Arc>>, +} + +impl Db { + pub fn new() -> Self { + Db { + users: Arc::new(RwLock::new(Vec::default())), + } + } + + pub async fn insert_user(&self, username: String) -> Result<()> { + if !username.is_empty() && username.chars().all(char::is_alphabetic) { + self.users.write().await.push(username); + Ok(()) + } else { + Err(format!("Cannot insert new user: {username}").into()) + } + } + + pub async fn get_room_members(&self, username: String) -> Vec { + info!("Get room members: {:?}", self.users); + self.users + .read() + .await + .clone() + .into_iter() + .filter(|n| { + info!("{n} is in the room"); + *n != username + + }) + .collect() + } + + pub async fn remove(&self, username: String) -> Result<()> { + self.users.write().await.retain(|n| *n != username); + Ok(()) + } +} diff --git a/problem_03/src/lib.rs b/problem_03/src/lib.rs index 941319a..b057001 100644 --- a/problem_03/src/lib.rs +++ b/problem_03/src/lib.rs @@ -1,12 +1,21 @@ mod connection; -pub use connection::Connection; + +pub use connection::{BroadcastMessage, Connection}; +use tokio::net::unix::SocketAddr; pub mod server; +mod db; mod shutdown; + use shutdown::Shutdown; pub const DEFAULT_PORT: u16 = 1222; +pub const DEFAULT_IP: &str = "0.0.0.0"; + +pub type Username = String; +pub type Message = String; +pub type Address = SocketAddr; pub type Error = Box; pub type Result = std::result::Result; diff --git a/problem_03/src/server.rs b/problem_03/src/server.rs index 0a9d008..13c7002 100644 --- a/problem_03/src/server.rs +++ b/problem_03/src/server.rs @@ -1,153 +1,185 @@ -use futures::{stream::StreamExt, SinkExt}; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; -use tokio::net::TcpListener; -use tokio::sync::{ - broadcast, - broadcast::{Receiver, Sender}, -}; -use tokio_util::codec::{Framed, LinesCodec}; -use tracing::{error, info}; +use crate::{BroadcastMessage, Connection, Shutdown}; -const IP: &str = "0.0.0.0"; -const PORT: u16 = 1222; +use crate::db::Db; +use futures::StreamExt; +use std::future::Future; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::time::{self, Duration}; +use tracing::{debug, error, info}; -type Error = Box; -type Result = std::result::Result; +struct Listener { + listener: TcpListener, + limit_connections: Arc, + notify_shutdown: broadcast::Sender<()>, + broadcast_message: broadcast::Sender, + shutdown_complete_rx: mpsc::Receiver<()>, + shutdown_complete_tx: mpsc::Sender<()>, +} -type Username = String; -type Message = String; -type Address = SocketAddr; +struct Handler { + connection: Connection, + db: Db, + shutdown: Shutdown, + _shutdown_complete: mpsc::Sender<()>, +} -#[derive(Clone, Debug, Default)] -struct BroadcastMessage(Username, Message); +const MAX_CONNECTIONS: usize = 100; -#[derive(Clone, Debug, Default)] -struct Users(Arc>>); +pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> { + let (notify_shutdown, _) = broadcast::channel(1); + let (broadcast_message, _) = broadcast::channel(100); + let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); -#[tokio::main] -async fn main() -> Result<()> { - tracing_subscriber::fmt::try_init().expect("Tracing was not setup"); + let mut server = Listener { + listener, + limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)), + notify_shutdown, + broadcast_message, + shutdown_complete_tx, + shutdown_complete_rx, + }; - let listener = TcpListener::bind(format!("{IP}:{PORT}")).await?; - info!("Listening on: {}", format!("{IP}:{PORT}")); - - let (tx, _) = broadcast::channel(256); - - let db = Users::default(); - - let mut valid_name = true; - - // Infinite loop to always listen to new connections on this IP/PORT - loop { - let (stream, address) = listener.accept().await?; - let (tx, rx) = (tx.clone(), tx.subscribe()); - let db = db.clone(); - - tokio::spawn(async move { - let mut framed = Framed::new(stream, LinesCodec::new()); - info!("New address connected: {address}"); - let _ = framed - .send("Welcome to budgetchat! What shall I call you?".to_string()) - .await; - - let mut name = String::default(); - - // We read exactly one line per loop. A line ends with \n. - // So if the client doesn't frame their package with \n at the end, - // we won't process until we find one. - match framed.next().await { - Some(Ok(username)) => { - if !username.is_empty() && username.chars().all(char::is_alphanumeric) { - name = username.clone(); - db.0.lock().unwrap().insert(username.clone(), address); - let message = compose_message(username.clone(), db.clone()); - info!("Adding username: {username} to db"); - let _ = framed.send(message).await; - } else { - valid_name = false; - return; - } - } - Some(Err(e)) => { - error!("Error parsing message: {e}"); - return; - } - None => { - info!("No frame"); - return; - } + tokio::select! { + res = server.run() => { + if let Err(err) = res { + error!(cause = %err, "failed to accept"); } - - if !valid_name { - return; - } - - let b = BroadcastMessage( - name.clone(), - format!("* {} has entered the room", name), - ); - let _ = tx.send(b); - let mut rx = rx.resubscribe(); - - loop { - tokio::select! { - n = framed.next() => { - match n { - Some(Ok(n)) => { - // broadcast message to all clients except the one who sent it - info!("Receiving new chat message: {n}"); - let b = - BroadcastMessage(name.clone(), format!("[{}] {}", name, n)); - let _ = tx.send(b); - } - Some(Err(e)) => { - error!("Error receiving chat message: {e}"); - } - None => { - // Connection dropped - // remove client from db etc. - // send leave message - info!("No next frame"); - let b = - BroadcastMessage(name.clone(), format!("* {} has left the room", name)); - db.0.lock().unwrap().remove(&name.clone()); - let _ = tx.send(b); - break; - } - } - } - message = rx.recv() => { - let broadcast = message.clone().unwrap(); - info!("Broadcast received: {:?}", message.clone().unwrap()); - if broadcast.0 != name { - info!("Broadcast sent to {}: {:?}", name, message.clone().unwrap()); - let _ = framed.send(message.unwrap().1).await; - } - - } - } - } - }); - - if !valid_name { - break; + } + _ = shutdown => { + info!("shutting down"); } } + let Listener { + mut shutdown_complete_rx, + shutdown_complete_tx, + notify_shutdown, + .. + } = server; + + drop(notify_shutdown); + drop(shutdown_complete_tx); + + let _ = shutdown_complete_rx.recv().await; + Ok(()) } -fn compose_message(name: String, db: Users) -> String { - format!( - "* The room contains: {}", - db.0.lock() - .unwrap() - .keys() - .filter(|n| n.as_str() != name) - .map(|n| n.to_string()) - .collect::>() - .join(", ") - ) +impl Listener { + async fn run(&mut self) -> crate::Result<()> { + info!("accepting inbound connections"); + let db = Db::new(); + loop { + let permit = self + .limit_connections + .clone() + .acquire_owned() + .await + .unwrap(); + + let socket = self.accept().await?; + let message_sender: + broadcast::Sender + = + self.broadcast_message.clone(); + + let mut handler = Handler { + connection: Connection::new(socket, message_sender), + db: db.clone(), + shutdown: Shutdown::new(self.notify_shutdown.subscribe()), + _shutdown_complete: self.shutdown_complete_tx.clone(), + }; + + info!("Created new handler"); + + tokio::spawn(async move { + if let Err(err) = handler.run().await { + error!(cause = ?err, "connection error"); + } + drop(permit); + }); + } + } + + async fn accept(&mut self) -> crate::Result { + let mut backoff = 1; + + loop { + match self.listener.accept().await { + Ok((socket, _)) => return Ok(socket), + Err(err) => { + if backoff > 64 { + return Err(err.into()); + } + } + } + + time::sleep(Duration::from_secs(backoff)).await; + + backoff *= 2; + } + } +} + +impl Handler { + async fn run(&mut self) -> crate::Result<()> { + let welcome = String::from("Welcome to budgetchat! What shall I call you?"); + let username; + + let _ = self.connection.write_frame(welcome).await; + + if let Some(Ok(name)) = self.connection.stream.next().await { + info!("Add {name} to db"); + self.db.insert_user(name.clone()).await?; + username = name; + } else { + return Ok(()); + } + + let joined_message = format!("* {username} has entered the room"); + let _ = self.connection + .broadcast_message(BroadcastMessage::new(username.clone(), joined_message)); + let room_contains_message = format!( + "* The room contains {}", + self.db.get_room_members(username.clone()).await.join(",") + ); + let _ = self.connection.write_frame(room_contains_message).await; + + let mut receiver = self.connection.broadcast.subscribe(); + + while !self.shutdown.is_shutdown() { + tokio::select! { + res = self.connection.stream.next() => match res { + Some(Ok(frame)) => { + let _ = self.connection + .broadcast_message(BroadcastMessage::new(username.clone(), frame)); + }, + Some(Err(_)) => { + error!("Could not parse frame"); + continue; + }, + None => { + let message = format!("* {username} has left the room"); + let _ = self.connection.broadcast_message(BroadcastMessage::new(username.clone(), message.clone())); + let _ = self.db.remove(username).await; + return Ok(()) + }, + }, + message = receiver.recv() => { + info!("Message received: {:?}", message.as_ref().unwrap()); + if message.as_ref().unwrap().from != username { + let _ = self.connection.write_frame(format!("[{}] {}", username, message.as_ref().unwrap().message.clone())).await; + } + } + _ = self.shutdown.recv() => { + debug!("Shutdown"); + return Ok(()); + } + }; + } + + Ok(()) + } }