diff --git a/problem_02/.gitignore b/problem_02/.gitignore new file mode 100644 index 0000000..b83d222 --- /dev/null +++ b/problem_02/.gitignore @@ -0,0 +1 @@ +/target/ diff --git a/problem_02/Cargo.toml b/problem_02/Cargo.toml new file mode 100644 index 0000000..0f494ac --- /dev/null +++ b/problem_02/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "problem_02" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "server" +path = "bin/server.rs" + +[dependencies] +atoi = "0.3.2" +bytes = "1" +tokio = { version = "1", features = ["full"] } +tracing = "0.1.34" +tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } diff --git a/problem_02/bin/server.rs b/problem_02/bin/server.rs new file mode 100644 index 0000000..4fa84f6 --- /dev/null +++ b/problem_02/bin/server.rs @@ -0,0 +1,15 @@ +use problem_02::{server, DEFAULT_PORT}; + +use tokio::net::TcpListener; +use tokio::signal; + +#[tokio::main] +pub async fn main() -> problem_02::Result<()> { + tracing_subscriber::fmt::try_init()?; + + let listener = TcpListener::bind(&format!("127.0.0.1:{}", DEFAULT_PORT)).await?; + + server::run(listener, signal::ctrl_c()).await?; + + Ok(()) +} diff --git a/problem_02/src/connection.rs b/problem_02/src/connection.rs new file mode 100644 index 0000000..f040300 --- /dev/null +++ b/problem_02/src/connection.rs @@ -0,0 +1,63 @@ +use crate::frame::{self, Frame}; + +use bytes::{Buf, BytesMut}; +use std::io::Cursor; +use tokio::io::{AsyncReadExt, BufWriter}; +use tokio::net::TcpStream; +use tracing::{debug, info}; + +#[derive(Debug)] +pub struct Connection { + stream: BufWriter, + buffer: BytesMut, +} + +impl Connection { + pub fn new(socket: TcpStream) -> Connection { + Connection { + stream: BufWriter::new(socket), + buffer: BytesMut::with_capacity(4 * 1024), + } + } + + pub async fn read_frame(&mut self) -> crate::Result> { + loop { + info!("Loop read_frame"); + if let Some(frame) = self.parse_frame()? { + info!("Frame parsed"); + return Ok(Some(frame)); + } + + if 0 == self.stream.read_buf(&mut self.buffer).await? { + if self.buffer.is_empty() { + return Ok(None); + } else { + return Err("connection reset by peer".into()); + } + } + } + } + + fn parse_frame(&mut self) -> crate::Result> { + use frame::Error::Incomplete; + + let mut buf = Cursor::new(&self.buffer[..]); + debug!(?buf); + + match Frame::check(&mut buf) { + Ok(_) => { + info!("Frame::check succesful"); + let len = buf.position() as usize; + + buf.set_position(0); + + let frame = Frame::parse(&mut buf)?; + self.buffer.advance(len); + + Ok(Some(frame)) + } + Err(Incomplete) => Ok(None), + Err(e) => Err(e.into()), + } + } +} diff --git a/problem_02/src/frame.rs b/problem_02/src/frame.rs new file mode 100644 index 0000000..4eb8c37 --- /dev/null +++ b/problem_02/src/frame.rs @@ -0,0 +1,122 @@ +use bytes::Buf; +use std::fmt; +use std::io::Cursor; +use std::num::TryFromIntError; +use std::string::FromUtf8Error; +use tracing::{debug, error, info}; + +#[derive(Clone, Debug)] +pub enum Frame { + Insert { timestamp: i32, price: i32 }, + Query { mintime: i32, maxtime: i32 }, +} + +#[derive(Debug)] +pub enum Error { + Incomplete, + Other(crate::Error), +} + +impl Frame { + pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> { + info!("Check frame"); + match get_u8(src)? { + b'I' => { + get_line(src)?; + Ok(()) + } + b'Q' => { + get_line(src)?; + Ok(()) + } + actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()), + } + } + + pub fn parse(src: &mut Cursor<&[u8]>) -> Result { + info!("Parsing frame"); + match get_u8(src)? { + b'I' => { + info!("Insert message"); + let line = get_line(src)?; + debug!(?line); + Ok(Frame::Insert { + timestamp: get_decimal(&line[1..=4])?, + price: get_decimal(&line[5..=8])?, + }) + } + b'Q' => { + let line = get_line(src)?; + + Ok(Frame::Query { + mintime: get_decimal(&line[1..=4])?, + maxtime: get_decimal(&line[5..=8])?, + }) + } + _ => unimplemented!(), + } + } +} + +fn get_decimal(src: &[u8]) -> Result { + debug!(?src); + + if let Ok(number) = <[u8; 4]>::try_from(src) { + return Ok(i32::from_be_bytes(number)); + }; + + Err("protocol error; invalid frame format".into()) +} + +fn get_u8(src: &mut Cursor<&[u8]>) -> Result { + if !src.has_remaining() { + error!("Incomplete frame"); + return Err(Error::Incomplete); + } + + Ok(src.get_u8()) +} + +fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> { + if src.get_ref().len() == 9 { + src.set_position(9); + return Ok(&src.get_ref()[..]); + } + + Err(Error::Incomplete) +} + +impl From for Error { + fn from(src: String) -> Error { + Error::Other(src.into()) + } +} + +impl From<&str> for Error { + fn from(src: &str) -> Error { + src.to_string().into() + } +} + +impl From for Error { + fn from(_src: FromUtf8Error) -> Error { + "protocol error; invalid frame format".into() + } +} + +impl From for Error { + fn from(_src: TryFromIntError) -> Error { + "protocol error; invalid frame format".into() + } +} + +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Incomplete => "stream ended early".fmt(fmt), + Error::Other(err) => err.fmt(fmt), + } + } +} diff --git a/problem_02/src/lib.rs b/problem_02/src/lib.rs new file mode 100644 index 0000000..82f0795 --- /dev/null +++ b/problem_02/src/lib.rs @@ -0,0 +1,15 @@ +mod connection; +pub use connection::Connection; + +pub mod frame; +pub use frame::Frame; + +pub mod server; + +mod shutdown; +use shutdown::Shutdown; + +pub const DEFAULT_PORT: u16 = 6379; + +pub type Error = Box; +pub type Result = std::result::Result; diff --git a/problem_02/src/server.rs b/problem_02/src/server.rs new file mode 100644 index 0000000..e7a4dea --- /dev/null +++ b/problem_02/src/server.rs @@ -0,0 +1,135 @@ +use crate::{Connection, Shutdown}; + +use std::future::Future; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::time::{self, Duration}; +use tracing::{debug, error, info}; + +struct Listener { + listener: TcpListener, + limit_connections: Arc, + notify_shutdown: broadcast::Sender<()>, + shutdown_complete_rx: mpsc::Receiver<()>, + shutdown_complete_tx: mpsc::Sender<()>, +} + +struct Handler { + connection: Connection, + shutdown: Shutdown, + _shutdown_complete: mpsc::Sender<()>, +} + +const MAX_CONNECTIONS: usize = 5; + +pub async fn run(listener: TcpListener, shutdown: impl Future) -> crate::Result<()> { + let (notify_shutdown, _) = broadcast::channel(1); + let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); + + let mut server = Listener { + listener, + limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)), + notify_shutdown, + shutdown_complete_tx, + shutdown_complete_rx, + }; + + tokio::select! { + res = server.run() => { + if let Err(err) = res { + error!(cause = %err, "failed to accept"); + } + } + _ = shutdown => { + info!("shutting down"); + } + } + + let Listener { + mut shutdown_complete_rx, + shutdown_complete_tx, + notify_shutdown, + .. + } = server; + + drop(notify_shutdown); + drop(shutdown_complete_tx); + + let _ = shutdown_complete_rx.recv().await; + + Ok(()) +} + +impl Listener { + async fn run(&mut self) -> crate::Result<()> { + info!("accepting inbound connections"); + + loop { + let permit = self + .limit_connections + .clone() + .acquire_owned() + .await + .unwrap(); + + let socket = self.accept().await?; + + let mut handler = Handler { + connection: Connection::new(socket), + shutdown: Shutdown::new(self.notify_shutdown.subscribe()), + _shutdown_complete: self.shutdown_complete_tx.clone(), + }; + + tokio::spawn(async move { + if let Err(err) = handler.run().await { + error!(cause = ?err, "connection error"); + } + drop(permit); + }); + } + } + + async fn accept(&mut self) -> crate::Result { + let mut backoff = 1; + + loop { + match self.listener.accept().await { + Ok((socket, _)) => return Ok(socket), + Err(err) => { + if backoff > 64 { + return Err(err.into()); + } + } + } + + time::sleep(Duration::from_secs(backoff)).await; + + backoff *= 2; + } + } +} + +impl Handler { + async fn run(&mut self) -> crate::Result<()> { + while !self.shutdown.is_shutdown() { + let maybe_frame = tokio::select! { + res = self.connection.read_frame() => res?, + _ = self.shutdown.recv() => { + return Ok(()); + } + }; + + debug!(?maybe_frame); + + let frame = match maybe_frame { + Some(frame) => frame, + None => return Ok(()), + }; + + debug!(?frame); + } + + Ok(()) + } +} diff --git a/problem_02/src/shutdown.rs b/problem_02/src/shutdown.rs new file mode 100644 index 0000000..1c86f83 --- /dev/null +++ b/problem_02/src/shutdown.rs @@ -0,0 +1,30 @@ +use tokio::sync::broadcast; + +#[derive(Debug)] +pub(crate) struct Shutdown { + shutdown: bool, + notify: broadcast::Receiver<()>, +} + +impl Shutdown { + pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { + Shutdown { + shutdown: false, + notify, + } + } + + pub(crate) fn is_shutdown(&self) -> bool { + self.shutdown + } + + pub(crate) async fn recv(&mut self) { + if self.shutdown { + return; + } + + let _ = self.notify.recv().await; + + self.shutdown = true; + } +}