Copy/paste codec
This commit is contained in:
parent
d6118657a8
commit
20d9acf917
4 changed files with 223 additions and 5 deletions
|
|
@ -18,3 +18,4 @@ tokio = { version = "1.14.0", features = ["full"] }
|
||||||
tokio-util = { version = "0.7.4", features = ["codec"] }
|
tokio-util = { version = "0.7.4", features = ["codec"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = "0.3.17"
|
tracing-subscriber = "0.3.17"
|
||||||
|
bytes = "1.4.0"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
use fancy_regex::Regex;
|
use fancy_regex::Regex;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use problem_05::StrictLinesCodec;
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
|
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
const DEFAULT_IP: &str = "0.0.0.0";
|
const DEFAULT_IP: &str = "0.0.0.0";
|
||||||
|
|
@ -43,12 +44,12 @@ pub async fn main() -> Result<()> {
|
||||||
|
|
||||||
pub async fn handle_request(socket: TcpStream, upstream: TcpStream) -> Result<()> {
|
pub async fn handle_request(socket: TcpStream, upstream: TcpStream) -> Result<()> {
|
||||||
let (client_read, client_write) = socket.into_split();
|
let (client_read, client_write) = socket.into_split();
|
||||||
let mut framed_client_read = FramedRead::new(client_read, LinesCodec::new());
|
let mut framed_client_read = FramedRead::new(client_read, StrictLinesCodec::new());
|
||||||
let mut framed_client_write = FramedWrite::new(client_write, LinesCodec::new());
|
let mut framed_client_write = FramedWrite::new(client_write, StrictLinesCodec::new());
|
||||||
|
|
||||||
let (server_read, server_write) = upstream.into_split();
|
let (server_read, server_write) = upstream.into_split();
|
||||||
let mut framed_server_read = FramedRead::new(server_read, LinesCodec::new());
|
let mut framed_server_read = FramedRead::new(server_read, StrictLinesCodec::new());
|
||||||
let mut framed_server_write = FramedWrite::new(server_write, LinesCodec::new());
|
let mut framed_server_write = FramedWrite::new(server_write, StrictLinesCodec::new());
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
|
|
||||||
3
problem_05/src/lib.rs
Normal file
3
problem_05/src/lib.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
mod strict_lines_codec;
|
||||||
|
|
||||||
|
pub use strict_lines_codec::*;
|
||||||
213
problem_05/src/strict_lines_codec.rs
Normal file
213
problem_05/src/strict_lines_codec.rs
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
use bytes::{Buf, BufMut, BytesMut};
|
||||||
|
use std::{cmp, fmt, io, str, usize};
|
||||||
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|
||||||
|
/// A line-oriented [`Decoder`] and [`Encoder`] that frames data on `\n`.
///
/// Decoding yields each complete line as a `String` with the terminating
/// `\n` (and a `\r` directly in front of it, if any) removed. Encoding writes
/// the item followed by a single `\n`.
///
/// The intended difference from [`tokio_util::codec::LinesCodec`] is that a
/// trailing fragment with no terminating newline before EOF is never yielded
/// as a line.
///
/// [`Decoder`]: tokio_util::codec::Decoder
/// [`Encoder`]: tokio_util::codec::Encoder
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StrictLinesCodec {
    // Stored index of the next byte to examine for a `\n` character.
    // This is used to avoid rescanning bytes that were already searched.
    // For example, if `decode` was called with `abc`, it would hold `3`,
    // because that is the next index to examine.
    // The next time `decode` is called with `abcde\n`, the method will
    // only look at `de\n` before returning.
    next_index: usize,

    /// The maximum length for a given line. If `usize::MAX`, lines will be
    /// read until a `\n` character is reached.
    max_length: usize,

    /// Are we currently discarding the remainder of a line which was over
    /// the length limit?
    is_discarding: bool,
}
|
||||||
|
|
||||||
|
impl StrictLinesCodec {
|
||||||
|
/// Returns a `LinesCodec` for splitting up data into lines.
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
///
|
||||||
|
/// The returned `LinesCodec` will not have an upper bound on the length
|
||||||
|
/// of a buffered line. See the documentation for [`new_with_max_length`]
|
||||||
|
/// for information on why this could be a potential security risk.
|
||||||
|
///
|
||||||
|
/// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
|
||||||
|
pub fn new() -> StrictLinesCodec {
|
||||||
|
StrictLinesCodec {
|
||||||
|
next_index: 0,
|
||||||
|
max_length: usize::MAX,
|
||||||
|
is_discarding: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a `LinesCodec` with a maximum line length limit.
|
||||||
|
///
|
||||||
|
/// If this is set, calls to `LinesCodec::decode` will return a
|
||||||
|
/// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
|
||||||
|
/// will discard up to `limit` bytes from that line until a newline
|
||||||
|
/// character is reached, returning `None` until the line over the limit
|
||||||
|
/// has been fully discarded. After that point, calls to `decode` will
|
||||||
|
/// function as normal.
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
///
|
||||||
|
/// Setting a length limit is highly recommended for any `LinesCodec` which
|
||||||
|
/// will be exposed to untrusted input. Otherwise, the size of the buffer
|
||||||
|
/// that holds the line currently being read is unbounded. An attacker could
|
||||||
|
/// exploit this unbounded buffer by sending an unbounded amount of input
|
||||||
|
/// without any `\n` characters, causing unbounded memory consumption.
|
||||||
|
///
|
||||||
|
/// [`LinesCodecError`]: crate::codec::LinesCodecError
|
||||||
|
pub fn new_with_max_length(max_length: usize) -> Self {
|
||||||
|
StrictLinesCodec {
|
||||||
|
max_length,
|
||||||
|
..StrictLinesCodec::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the maximum line length when decoding.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use std::usize;
|
||||||
|
/// use tokio_util::codec::LinesCodec;
|
||||||
|
///
|
||||||
|
/// let codec = LinesCodec::new();
|
||||||
|
/// assert_eq!(codec.max_length(), usize::MAX);
|
||||||
|
/// ```
|
||||||
|
/// ```
|
||||||
|
/// use tokio_util::codec::LinesCodec;
|
||||||
|
///
|
||||||
|
/// let codec = LinesCodec::new_with_max_length(256);
|
||||||
|
/// assert_eq!(codec.max_length(), 256);
|
||||||
|
/// ```
|
||||||
|
pub fn max_length(&self) -> usize {
|
||||||
|
self.max_length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Interprets `buf` as UTF-8 text without copying it.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] when `buf`
/// is not valid UTF-8.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    str::from_utf8(buf)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Unable to decode input as UTF8"))
}
|
||||||
|
|
||||||
|
/// Returns `s` with one trailing `\r` removed, or `s` unchanged when it does
/// not end in `\r`.
///
/// Only a single byte is ever removed: `b"a\r\r"` becomes `b"a\r"`.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    s.strip_suffix(b"\r").unwrap_or(s)
}
|
||||||
|
|
||||||
|
impl Decoder for StrictLinesCodec {
|
||||||
|
type Item = String;
|
||||||
|
type Error = LinesCodecError;
|
||||||
|
|
||||||
|
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
|
||||||
|
loop {
|
||||||
|
// Determine how far into the buffer we'll search for a newline. If
|
||||||
|
// there's no max_length set, we'll read to the end of the buffer.
|
||||||
|
let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
|
||||||
|
|
||||||
|
let newline_offset = buf[self.next_index..read_to]
|
||||||
|
.iter()
|
||||||
|
.position(|b| *b == b'\n');
|
||||||
|
|
||||||
|
match (self.is_discarding, newline_offset) {
|
||||||
|
(true, Some(offset)) => {
|
||||||
|
// If we found a newline, discard up to that offset and
|
||||||
|
// then stop discarding. On the next iteration, we'll try
|
||||||
|
// to read a line normally.
|
||||||
|
buf.advance(offset + self.next_index + 1);
|
||||||
|
self.is_discarding = false;
|
||||||
|
self.next_index = 0;
|
||||||
|
}
|
||||||
|
(true, None) => {
|
||||||
|
// Otherwise, we didn't find a newline, so we'll discard
|
||||||
|
// everything we read. On the next iteration, we'll continue
|
||||||
|
// discarding up to max_len bytes unless we find a newline.
|
||||||
|
buf.advance(read_to);
|
||||||
|
self.next_index = 0;
|
||||||
|
if buf.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(false, Some(offset)) => {
|
||||||
|
// Found a line!
|
||||||
|
let newline_index = offset + self.next_index;
|
||||||
|
self.next_index = 0;
|
||||||
|
let line = buf.split_to(newline_index + 1);
|
||||||
|
let line = &line[..line.len() - 1];
|
||||||
|
let line = without_carriage_return(line);
|
||||||
|
let line = utf8(line)?;
|
||||||
|
return Ok(Some(line.to_string()));
|
||||||
|
}
|
||||||
|
(false, None) if buf.len() > self.max_length => {
|
||||||
|
// Reached the maximum length without finding a
|
||||||
|
// newline, return an error and start discarding on the
|
||||||
|
// next call.
|
||||||
|
self.is_discarding = true;
|
||||||
|
return Err(LinesCodecError::MaxLineLengthExceeded);
|
||||||
|
}
|
||||||
|
(false, None) => {
|
||||||
|
// We didn't find a line or reach the length limit, so the next
|
||||||
|
// call will resume searching at the current offset.
|
||||||
|
self.next_index = read_to;
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Encoder<T> for StrictLinesCodec
|
||||||
|
where
|
||||||
|
T: AsRef<str>,
|
||||||
|
{
|
||||||
|
type Error = LinesCodecError;
|
||||||
|
|
||||||
|
fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
|
||||||
|
let line = line.as_ref();
|
||||||
|
buf.reserve(line.len() + 1);
|
||||||
|
buf.put(line.as_bytes());
|
||||||
|
buf.put_u8(b'\n');
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for StrictLinesCodec {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error occurred while encoding or decoding a line.
#[derive(Debug)]
pub enum LinesCodecError {
    /// The maximum line length was exceeded.
    MaxLineLengthExceeded,
    /// An IO error occurred.
    Io(io::Error),
}

impl fmt::Display for LinesCodecError {
    /// Writes a fixed message for `MaxLineLengthExceeded` and the wrapped
    /// error's own message for `Io`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LinesCodecError::MaxLineLengthExceeded => write!(f, "max line length exceeded"),
            LinesCodecError::Io(e) => write!(f, "{}", e),
        }
    }
}

impl From<io::Error> for LinesCodecError {
    /// Wraps an I/O error so `?` can be used on `io::Result` values in
    /// functions that return `LinesCodecError`.
    fn from(e: io::Error) -> LinesCodecError {
        LinesCodecError::Io(e)
    }
}

impl std::error::Error for LinesCodecError {}
|
||||||
Loading…
Reference in a new issue