Slightly more sophisticated version of problem_03
This commit is contained in:
parent
8edf4cad62
commit
efcfff55e0
6 changed files with 273 additions and 163 deletions
|
|
@ -5,7 +5,7 @@ edition = "2021"
|
|||
|
||||
[[bin]]
|
||||
name = "server"
|
||||
path = "src/server.rs"
|
||||
path = "bin/server.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "client"
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
// use problem_03::{server, DEFAULT_PORT};
|
||||
use problem_03::{server, DEFAULT_IP, DEFAULT_PORT};
|
||||
|
||||
// use tokio::net::TcpListener;
|
||||
// use tokio::signal;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
|
||||
// #[tokio::main]
|
||||
// pub async fn main() -> problem_03::Result<()> {
|
||||
// tracing_subscriber::fmt::try_init()?;
|
||||
#[tokio::main]
|
||||
pub async fn main() -> problem_03::Result<()> {
|
||||
tracing_subscriber::fmt::try_init()?;
|
||||
|
||||
// let listener = TcpListener::bind(&format!("0.0.0.0:{}", DEFAULT_PORT)).await?;
|
||||
let listener = TcpListener::bind(&format!("{DEFAULT_IP}:{DEFAULT_PORT}")).await?;
|
||||
|
||||
// server::run(listener, signal::ctrl_c()).await?;
|
||||
server::run(listener, signal::ctrl_c()).await?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,33 +1,46 @@
|
|||
use crate::{Message, Result, Username};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::broadcast::Sender;
|
||||
use tokio_util::codec::{Framed, LinesCodec};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct BroadcastMessage {
|
||||
pub(crate) from: Username,
|
||||
pub(crate) message: Message,
|
||||
}
|
||||
|
||||
impl BroadcastMessage {
|
||||
pub fn new(from: Username, message: Message) -> Self {
|
||||
BroadcastMessage { from, message }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Connection {
|
||||
stream: Framed<TcpStream, LinesCodec>,
|
||||
pub stream: Framed<TcpStream, LinesCodec>,
|
||||
pub broadcast: Sender<BroadcastMessage>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(socket: TcpStream) -> Connection {
|
||||
pub fn new(socket: TcpStream, sender: Sender<BroadcastMessage>) -> Connection {
|
||||
Connection {
|
||||
stream: Framed::new(socket, LinesCodec::new()),
|
||||
broadcast: sender,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> crate::Result<Option<String>> {
|
||||
loop {
|
||||
info!("Read next frame");
|
||||
if let Some(Ok(frame)) = self.stream.next().await {
|
||||
info!("Frame parsed");
|
||||
return Ok(Some(frame));
|
||||
} else {
|
||||
return Err("connection reset by peer".into());
|
||||
}
|
||||
}
|
||||
pub async fn red_next_frame(&mut self) -> Result<Option<String>> {
|
||||
return if let Some(Ok(frame)) = self.stream.next().await {
|
||||
info!("Frame for parsing the username parsed");
|
||||
Ok(Some(frame))
|
||||
} else {
|
||||
Err("connection reset by peer".into())
|
||||
};
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, response: String) -> crate::Result<()> {
|
||||
pub async fn write_frame(&mut self, response: String) -> Result<()> {
|
||||
debug!(?response);
|
||||
if let Err(e) = self.stream.send(response.clone()).await {
|
||||
error!("Could not write frame to stream");
|
||||
|
|
@ -36,4 +49,12 @@ impl Connection {
|
|||
info!("Wrote to frame: {}", response);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn broadcast_message(&mut self, message: BroadcastMessage) -> Result<()> {
|
||||
match self.broadcast.send(message.clone()) {
|
||||
Ok(n) => info!("Sent broadcast: {n}"),
|
||||
Err(e) => error!("Could not send broadcast: {e}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
48
problem_03/src/db.rs
Normal file
48
problem_03/src/db.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Db {
|
||||
users: Arc<RwLock<Vec<String>>>,
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub fn new() -> Self {
|
||||
Db {
|
||||
users: Arc::new(RwLock::new(Vec::default())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert_user(&self, username: String) -> Result<()> {
|
||||
if !username.is_empty() && username.chars().all(char::is_alphabetic) {
|
||||
self.users.write().await.push(username);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Cannot insert new user: {username}").into())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_room_members(&self, username: String) -> Vec<String> {
|
||||
info!("Get room members: {:?}", self.users);
|
||||
self.users
|
||||
.read()
|
||||
.await
|
||||
.clone()
|
||||
.into_iter()
|
||||
.filter(|n| {
|
||||
info!("{n} is in the room");
|
||||
*n != username
|
||||
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn remove(&self, username: String) -> Result<()> {
|
||||
self.users.write().await.retain(|n| *n != username);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,21 @@
|
|||
mod connection;
|
||||
pub use connection::Connection;
|
||||
|
||||
pub use connection::{BroadcastMessage, Connection};
|
||||
use tokio::net::unix::SocketAddr;
|
||||
|
||||
pub mod server;
|
||||
|
||||
mod db;
|
||||
mod shutdown;
|
||||
|
||||
use shutdown::Shutdown;
|
||||
|
||||
pub const DEFAULT_PORT: u16 = 1222;
|
||||
pub const DEFAULT_IP: &str = "0.0.0.0";
|
||||
|
||||
pub type Username = String;
|
||||
pub type Message = String;
|
||||
pub type Address = SocketAddr;
|
||||
|
||||
pub type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
|
|
|||
|
|
@ -1,153 +1,185 @@
|
|||
use futures::{stream::StreamExt, SinkExt};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{
|
||||
broadcast,
|
||||
broadcast::{Receiver, Sender},
|
||||
};
|
||||
use tokio_util::codec::{Framed, LinesCodec};
|
||||
use tracing::{error, info};
|
||||
use crate::{BroadcastMessage, Connection, Shutdown};
|
||||
|
||||
const IP: &str = "0.0.0.0";
|
||||
const PORT: u16 = 1222;
|
||||
use crate::db::Db;
|
||||
use futures::StreamExt;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::{broadcast, mpsc, Semaphore};
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
struct Listener {
|
||||
listener: TcpListener,
|
||||
limit_connections: Arc<Semaphore>,
|
||||
notify_shutdown: broadcast::Sender<()>,
|
||||
broadcast_message: broadcast::Sender<BroadcastMessage>,
|
||||
shutdown_complete_rx: mpsc::Receiver<()>,
|
||||
shutdown_complete_tx: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
type Username = String;
|
||||
type Message = String;
|
||||
type Address = SocketAddr;
|
||||
struct Handler {
|
||||
connection: Connection,
|
||||
db: Db,
|
||||
shutdown: Shutdown,
|
||||
_shutdown_complete: mpsc::Sender<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct BroadcastMessage(Username, Message);
|
||||
const MAX_CONNECTIONS: usize = 100;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct Users(Arc<Mutex<HashMap<Username, Address>>>);
|
||||
pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> {
|
||||
let (notify_shutdown, _) = broadcast::channel(1);
|
||||
let (broadcast_message, _) = broadcast::channel(100);
|
||||
let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::try_init().expect("Tracing was not setup");
|
||||
let mut server = Listener {
|
||||
listener,
|
||||
limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
|
||||
notify_shutdown,
|
||||
broadcast_message,
|
||||
shutdown_complete_tx,
|
||||
shutdown_complete_rx,
|
||||
};
|
||||
|
||||
let listener = TcpListener::bind(format!("{IP}:{PORT}")).await?;
|
||||
info!("Listening on: {}", format!("{IP}:{PORT}"));
|
||||
|
||||
let (tx, _) = broadcast::channel(256);
|
||||
|
||||
let db = Users::default();
|
||||
|
||||
let mut valid_name = true;
|
||||
|
||||
// Infinite loop to always listen to new connections on this IP/PORT
|
||||
loop {
|
||||
let (stream, address) = listener.accept().await?;
|
||||
let (tx, rx) = (tx.clone(), tx.subscribe());
|
||||
let db = db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut framed = Framed::new(stream, LinesCodec::new());
|
||||
info!("New address connected: {address}");
|
||||
let _ = framed
|
||||
.send("Welcome to budgetchat! What shall I call you?".to_string())
|
||||
.await;
|
||||
|
||||
let mut name = String::default();
|
||||
|
||||
// We read exactly one line per loop. A line ends with \n.
|
||||
// So if the client doesn't frame their package with \n at the end,
|
||||
// we won't process until we find one.
|
||||
match framed.next().await {
|
||||
Some(Ok(username)) => {
|
||||
if !username.is_empty() && username.chars().all(char::is_alphanumeric) {
|
||||
name = username.clone();
|
||||
db.0.lock().unwrap().insert(username.clone(), address);
|
||||
let message = compose_message(username.clone(), db.clone());
|
||||
info!("Adding username: {username} to db");
|
||||
let _ = framed.send(message).await;
|
||||
} else {
|
||||
valid_name = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("Error parsing message: {e}");
|
||||
return;
|
||||
}
|
||||
None => {
|
||||
info!("No frame");
|
||||
return;
|
||||
}
|
||||
tokio::select! {
|
||||
res = server.run() => {
|
||||
if let Err(err) = res {
|
||||
error!(cause = %err, "failed to accept");
|
||||
}
|
||||
|
||||
if !valid_name {
|
||||
return;
|
||||
}
|
||||
|
||||
let b = BroadcastMessage(
|
||||
name.clone(),
|
||||
format!("* {} has entered the room", name),
|
||||
);
|
||||
let _ = tx.send(b);
|
||||
let mut rx = rx.resubscribe();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
n = framed.next() => {
|
||||
match n {
|
||||
Some(Ok(n)) => {
|
||||
// broadcast message to all clients except the one who sent it
|
||||
info!("Receiving new chat message: {n}");
|
||||
let b =
|
||||
BroadcastMessage(name.clone(), format!("[{}] {}", name, n));
|
||||
let _ = tx.send(b);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("Error receiving chat message: {e}");
|
||||
}
|
||||
None => {
|
||||
// Connection dropped
|
||||
// remove client from db etc.
|
||||
// send leave message
|
||||
info!("No next frame");
|
||||
let b =
|
||||
BroadcastMessage(name.clone(), format!("* {} has left the room", name));
|
||||
db.0.lock().unwrap().remove(&name.clone());
|
||||
let _ = tx.send(b);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
message = rx.recv() => {
|
||||
let broadcast = message.clone().unwrap();
|
||||
info!("Broadcast received: {:?}", message.clone().unwrap());
|
||||
if broadcast.0 != name {
|
||||
info!("Broadcast sent to {}: {:?}", name, message.clone().unwrap());
|
||||
let _ = framed.send(message.unwrap().1).await;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if !valid_name {
|
||||
break;
|
||||
}
|
||||
_ = shutdown => {
|
||||
info!("shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
let Listener {
|
||||
mut shutdown_complete_rx,
|
||||
shutdown_complete_tx,
|
||||
notify_shutdown,
|
||||
..
|
||||
} = server;
|
||||
|
||||
drop(notify_shutdown);
|
||||
drop(shutdown_complete_tx);
|
||||
|
||||
let _ = shutdown_complete_rx.recv().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compose_message(name: String, db: Users) -> String {
|
||||
format!(
|
||||
"* The room contains: {}",
|
||||
db.0.lock()
|
||||
.unwrap()
|
||||
.keys()
|
||||
.filter(|n| n.as_str() != name)
|
||||
.map(|n| n.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
impl Listener {
|
||||
async fn run(&mut self) -> crate::Result<()> {
|
||||
info!("accepting inbound connections");
|
||||
let db = Db::new();
|
||||
loop {
|
||||
let permit = self
|
||||
.limit_connections
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let socket = self.accept().await?;
|
||||
let message_sender:
|
||||
broadcast::Sender<BroadcastMessage>
|
||||
=
|
||||
self.broadcast_message.clone();
|
||||
|
||||
let mut handler = Handler {
|
||||
connection: Connection::new(socket, message_sender),
|
||||
db: db.clone(),
|
||||
shutdown: Shutdown::new(self.notify_shutdown.subscribe()),
|
||||
_shutdown_complete: self.shutdown_complete_tx.clone(),
|
||||
};
|
||||
|
||||
info!("Created new handler");
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = handler.run().await {
|
||||
error!(cause = ?err, "connection error");
|
||||
}
|
||||
drop(permit);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept(&mut self) -> crate::Result<TcpStream> {
|
||||
let mut backoff = 1;
|
||||
|
||||
loop {
|
||||
match self.listener.accept().await {
|
||||
Ok((socket, _)) => return Ok(socket),
|
||||
Err(err) => {
|
||||
if backoff > 64 {
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
time::sleep(Duration::from_secs(backoff)).await;
|
||||
|
||||
backoff *= 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Handler {
|
||||
async fn run(&mut self) -> crate::Result<()> {
|
||||
let welcome = String::from("Welcome to budgetchat! What shall I call you?");
|
||||
let username;
|
||||
|
||||
let _ = self.connection.write_frame(welcome).await;
|
||||
|
||||
if let Some(Ok(name)) = self.connection.stream.next().await {
|
||||
info!("Add {name} to db");
|
||||
self.db.insert_user(name.clone()).await?;
|
||||
username = name;
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let joined_message = format!("* {username} has entered the room");
|
||||
let _ = self.connection
|
||||
.broadcast_message(BroadcastMessage::new(username.clone(), joined_message));
|
||||
let room_contains_message = format!(
|
||||
"* The room contains {}",
|
||||
self.db.get_room_members(username.clone()).await.join(",")
|
||||
);
|
||||
let _ = self.connection.write_frame(room_contains_message).await;
|
||||
|
||||
let mut receiver = self.connection.broadcast.subscribe();
|
||||
|
||||
while !self.shutdown.is_shutdown() {
|
||||
tokio::select! {
|
||||
res = self.connection.stream.next() => match res {
|
||||
Some(Ok(frame)) => {
|
||||
let _ = self.connection
|
||||
.broadcast_message(BroadcastMessage::new(username.clone(), frame));
|
||||
},
|
||||
Some(Err(_)) => {
|
||||
error!("Could not parse frame");
|
||||
continue;
|
||||
},
|
||||
None => {
|
||||
let message = format!("* {username} has left the room");
|
||||
let _ = self.connection.broadcast_message(BroadcastMessage::new(username.clone(), message.clone()));
|
||||
let _ = self.db.remove(username).await;
|
||||
return Ok(())
|
||||
},
|
||||
},
|
||||
message = receiver.recv() => {
|
||||
info!("Message received: {:?}", message.as_ref().unwrap());
|
||||
if message.as_ref().unwrap().from != username {
|
||||
let _ = self.connection.write_frame(format!("[{}] {}", username, message.as_ref().unwrap().message.clone())).await;
|
||||
}
|
||||
}
|
||||
_ = self.shutdown.recv() => {
|
||||
debug!("Shutdown");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue