From 20d9acf9170ac83e20d9ccf6e2ec31b8636eb107 Mon Sep 17 00:00:00 2001
From: Bastian Gruber
Date: Wed, 3 May 2023 09:42:27 +0200
Subject: [PATCH] Copy/paste codec

---
 problem_05/Cargo.toml                |   1 +
 problem_05/bin/server.rs             |  11 +-
 problem_05/src/lib.rs                |   3 +
 problem_05/src/strict_lines_codec.rs | 213 +++++++++++++++++++++++++++
 4 files changed, 223 insertions(+), 5 deletions(-)
 create mode 100644 problem_05/src/lib.rs
 create mode 100644 problem_05/src/strict_lines_codec.rs

diff --git a/problem_05/Cargo.toml b/problem_05/Cargo.toml
index b8763d1..d1de01a 100644
--- a/problem_05/Cargo.toml
+++ b/problem_05/Cargo.toml
@@ -18,3 +18,4 @@ tokio = { version = "1.14.0", features = ["full"] }
 tokio-util = { version = "0.7.4", features = ["codec"] }
 tracing = "0.1.37"
 tracing-subscriber = "0.3.17"
+bytes = "1.4.0"
diff --git a/problem_05/bin/server.rs b/problem_05/bin/server.rs
index 4adffdb..f0f18b3 100644
--- a/problem_05/bin/server.rs
+++ b/problem_05/bin/server.rs
@@ -1,7 +1,8 @@
 use fancy_regex::Regex;
 use futures::{SinkExt, StreamExt};
+use problem_05::StrictLinesCodec;
 use tokio::net::{TcpListener, TcpStream};
-use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
+use tokio_util::codec::{FramedRead, FramedWrite};
 use tracing::{error, info};
 
 const DEFAULT_IP: &str = "0.0.0.0";
@@ -43,12 +44,12 @@ pub async fn main() -> Result<()> {
 
 pub async fn handle_request(socket: TcpStream, upstream: TcpStream) -> Result<()> {
     let (client_read, client_write) = socket.into_split();
-    let mut framed_client_read = FramedRead::new(client_read, LinesCodec::new());
-    let mut framed_client_write = FramedWrite::new(client_write, LinesCodec::new());
+    let mut framed_client_read = FramedRead::new(client_read, StrictLinesCodec::new());
+    let mut framed_client_write = FramedWrite::new(client_write, StrictLinesCodec::new());
 
     let (server_read, server_write) = upstream.into_split();
-    let mut framed_server_read = FramedRead::new(server_read, LinesCodec::new());
-    let mut framed_server_write = FramedWrite::new(server_write, LinesCodec::new());
+    let mut framed_server_read = FramedRead::new(server_read, StrictLinesCodec::new());
+    let mut framed_server_write = FramedWrite::new(server_write, StrictLinesCodec::new());
 
     loop {
         tokio::select! {
diff --git a/problem_05/src/lib.rs b/problem_05/src/lib.rs
new file mode 100644
index 0000000..034a521
--- /dev/null
+++ b/problem_05/src/lib.rs
@@ -0,0 +1,3 @@
+mod strict_lines_codec;
+
+pub use strict_lines_codec::*;
\ No newline at end of file
diff --git a/problem_05/src/strict_lines_codec.rs b/problem_05/src/strict_lines_codec.rs
new file mode 100644
index 0000000..2740232
--- /dev/null
+++ b/problem_05/src/strict_lines_codec.rs
@@ -0,0 +1,213 @@
+use bytes::{Buf, BufMut, BytesMut};
+use std::{cmp, fmt, io, str, usize};
+use tokio_util::codec::{Decoder, Encoder};
+
+/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
+///
+/// The difference with tokio_util::codec::LinesCodec is that this one will return `None`
+/// if there is a missing newline right before EOF.
+///
+/// [`Decoder`]: crate::codec::Decoder
+/// [`Encoder`]: crate::codec::Encoder
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct StrictLinesCodec {
+    // Stored index of the next index to examine for a `\n` character.
+    // This is used to optimize searching.
+    // For example, if `decode` was called with `abc`, it would hold `3`,
+    // because that is the next index to examine.
+    // The next time `decode` is called with `abcde\n`, the method will
+    // only look at `de\n` before returning.
+    next_index: usize,
+
+    /// The maximum length for a given line. If `usize::MAX`, lines will be
+    /// read until a `\n` character is reached.
+    max_length: usize,
+
+    /// Are we currently discarding the remainder of a line which was over
+    /// the length limit?
+    is_discarding: bool,
+}
+
+impl StrictLinesCodec {
+    /// Returns a `LinesCodec` for splitting up data into lines.
+    ///
+    /// # Note
+    ///
+    /// The returned `LinesCodec` will not have an upper bound on the length
+    /// of a buffered line. See the documentation for [`new_with_max_length`]
+    /// for information on why this could be a potential security risk.
+    ///
+    /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
+    pub fn new() -> StrictLinesCodec {
+        StrictLinesCodec {
+            next_index: 0,
+            max_length: usize::MAX,
+            is_discarding: false,
+        }
+    }
+
+    /// Returns a `LinesCodec` with a maximum line length limit.
+    ///
+    /// If this is set, calls to `LinesCodec::decode` will return a
+    /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
+    /// will discard up to `limit` bytes from that line until a newline
+    /// character is reached, returning `None` until the line over the limit
+    /// has been fully discarded. After that point, calls to `decode` will
+    /// function as normal.
+    ///
+    /// # Note
+    ///
+    /// Setting a length limit is highly recommended for any `LinesCodec` which
+    /// will be exposed to untrusted input. Otherwise, the size of the buffer
+    /// that holds the line currently being read is unbounded. An attacker could
+    /// exploit this unbounded buffer by sending an unbounded amount of input
+    /// without any `\n` characters, causing unbounded memory consumption.
+    ///
+    /// [`LinesCodecError`]: crate::codec::LinesCodecError
+    pub fn new_with_max_length(max_length: usize) -> Self {
+        StrictLinesCodec {
+            max_length,
+            ..StrictLinesCodec::new()
+        }
+    }
+
+    /// Returns the maximum line length when decoding.
+    ///
+    /// ```
+    /// use std::usize;
+    /// use tokio_util::codec::LinesCodec;
+    ///
+    /// let codec = LinesCodec::new();
+    /// assert_eq!(codec.max_length(), usize::MAX);
+    /// ```
+    /// ```
+    /// use tokio_util::codec::LinesCodec;
+    ///
+    /// let codec = LinesCodec::new_with_max_length(256);
+    /// assert_eq!(codec.max_length(), 256);
+    /// ```
+    pub fn max_length(&self) -> usize {
+        self.max_length
+    }
+}
+
+fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
+    str::from_utf8(buf)
+        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8"))
+}
+
+fn without_carriage_return(s: &[u8]) -> &[u8] {
+    if let Some(&b'\r') = s.last() {
+        &s[..s.len() - 1]
+    } else {
+        s
+    }
+}
+
+impl Decoder for StrictLinesCodec {
+    type Item = String;
+    type Error = LinesCodecError;
+
+    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
+        loop {
+            // Determine how far into the buffer we'll search for a newline. If
+            // there's no max_length set, we'll read to the end of the buffer.
+            let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
+
+            let newline_offset = buf[self.next_index..read_to]
+                .iter()
+                .position(|b| *b == b'\n');
+
+            match (self.is_discarding, newline_offset) {
+                (true, Some(offset)) => {
+                    // If we found a newline, discard up to that offset and
+                    // then stop discarding. On the next iteration, we'll try
+                    // to read a line normally.
+                    buf.advance(offset + self.next_index + 1);
+                    self.is_discarding = false;
+                    self.next_index = 0;
+                }
+                (true, None) => {
+                    // Otherwise, we didn't find a newline, so we'll discard
+                    // everything we read. On the next iteration, we'll continue
+                    // discarding up to max_len bytes unless we find a newline.
+                    buf.advance(read_to);
+                    self.next_index = 0;
+                    if buf.is_empty() {
+                        return Ok(None);
+                    }
+                }
+                (false, Some(offset)) => {
+                    // Found a line!
+                    let newline_index = offset + self.next_index;
+                    self.next_index = 0;
+                    let line = buf.split_to(newline_index + 1);
+                    let line = &line[..line.len() - 1];
+                    let line = without_carriage_return(line);
+                    let line = utf8(line)?;
+                    return Ok(Some(line.to_string()));
+                }
+                (false, None) if buf.len() > self.max_length => {
+                    // Reached the maximum length without finding a
+                    // newline, return an error and start discarding on the
+                    // next call.
+                    self.is_discarding = true;
+                    return Err(LinesCodecError::MaxLineLengthExceeded);
+                }
+                (false, None) => {
+                    // We didn't find a line or reach the length limit, so the next
+                    // call will resume searching at the current offset.
+                    self.next_index = read_to;
+                    return Ok(None);
+                }
+            }
+        }
+    }
+}
+
+impl<T> Encoder<T> for StrictLinesCodec
+where
+    T: AsRef<str>,
+{
+    type Error = LinesCodecError;
+
+    fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
+        let line = line.as_ref();
+        buf.reserve(line.len() + 1);
+        buf.put(line.as_bytes());
+        buf.put_u8(b'\n');
+        Ok(())
+    }
+}
+
+impl Default for StrictLinesCodec {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// An error occurred while encoding or decoding a line.
+#[derive(Debug)]
+pub enum LinesCodecError {
+    /// The maximum line length was exceeded.
+    MaxLineLengthExceeded,
+    /// An IO error occurred.
+    Io(io::Error),
+}
+
+impl fmt::Display for LinesCodecError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"),
+            LinesCodecError::Io(e) => write!(f, "{}", e),
+        }
+    }
+}
+
+impl From<io::Error> for LinesCodecError {
+    fn from(e: io::Error) -> LinesCodecError {
+        LinesCodecError::Io(e)
+    }
+}
+
+impl std::error::Error for LinesCodecError {}
\ No newline at end of file