From 7f8e78d831832d64c7a4bea1930fcec6602fa802 Mon Sep 17 00:00:00 2001 From: Bastian Gruber Date: Thu, 17 Jul 2025 11:57:26 -0300 Subject: [PATCH] fix message construction and path forwarding --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/error.rs | 2 +- src/handlers/health.rs | 2 +- src/handlers/keys.rs | 2 +- src/handlers/mod.rs | 2 +- src/handlers/ohttp.rs | 80 +++++++++++++++++++++++--------------- src/key_manager.rs | 2 +- src/main.rs | 4 +- src/metrics.rs | 2 +- src/middleware/logging.rs | 2 +- src/middleware/security.rs | 2 +- 12 files changed, 61 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dbc4767..60f2c14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1785,7 +1785,7 @@ dependencies = [ [[package]] name = "ohttp-gateway" -version = "0.1.1" +version = "0.1.2" dependencies = [ "anyhow", "axum", diff --git a/Cargo.toml b/Cargo.toml index 5600513..2ec25f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] authors = ["Bastian Gruber "] -version = "0.1.1" +version = "0.1.2" edition = "2024" name = "ohttp-gateway" categories = ["web-programming", "web-programming::http-server"] diff --git a/src/error.rs b/src/error.rs index ec488ce..d473ce1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,7 @@ use axum::{ + Json, http::StatusCode, response::{IntoResponse, Response}, - Json, }; use serde_json::json; use thiserror::Error; diff --git a/src/handlers/health.rs b/src/handlers/health.rs index 6693ba1..6e99921 100644 --- a/src/handlers/health.rs +++ b/src/handlers/health.rs @@ -1,5 +1,5 @@ use crate::{error::GatewayError, state::AppState}; -use axum::{extract::State, Json}; +use axum::{Json, extract::State}; use serde_json::json; use std::time::Duration; diff --git a/src/handlers/keys.rs b/src/handlers/keys.rs index aee1668..8b58502 100644 --- a/src/handlers/keys.rs +++ b/src/handlers/keys.rs @@ -1,7 +1,7 @@ use crate::AppState; use axum::{ extract::State, - http::{header, HeaderName, StatusCode}, + http::{HeaderName, StatusCode, header}, response::{IntoResponse, Response}, }; use chrono::Utc; diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 21d81fa..e00a63f 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -4,8 +4,8 @@ pub mod ohttp; use crate::state::AppState; use axum::{ - routing::{get, post}, Router, + routing::{get, post}, }; pub fn routes() -> Router { diff --git a/src/handlers/ohttp.rs b/src/handlers/ohttp.rs index 49505d0..60256bd 100644 --- a/src/handlers/ohttp.rs +++ b/src/handlers/ohttp.rs @@ -2,7 +2,7 @@ use crate::{error::GatewayError, state::AppState}; use axum::{ body::{Body, Bytes}, extract::State, - http::{header, HeaderMap, StatusCode}, + http::{HeaderMap, StatusCode, header}, response::{IntoResponse, Response}, }; use bhttp::{Message, Mode}; @@ -129,11 +129,7 @@ async fn handle_ohttp_request_inner( /// Extract key ID from OHTTP request (first byte after version) fn extract_key_id_from_request(body: &[u8]) -> Option { // OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext - if body.len() > 1 { - Some(body[1]) - } else { - None - } + if body.len() > 1 { Some(body[1]) } else { None } } /// Validate the incoming OHTTP request @@ -208,6 +204,35 @@ fn validate_and_transform_request( }) .ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?; + // Extract and clean the path + let raw_path = control.path().unwrap_or(b"/"); + let path_str = String::from_utf8_lossy(raw_path); + + // Clean up the path - remove any absolute URL components + let clean_path = if path_str.starts_with("http://") || path_str.starts_with("https://") { + // Extract just the path from absolute URL + if let Some(idx) = path_str + .find('/') + .and_then(|i| path_str[i + 2..].find('/').map(|j| i + 2 + j)) + { + path_str[idx..].as_bytes() + } else { + b"/" + } + } else if path_str.contains(':') && !path_str.starts_with('/') { + // Path might contain host:port, clean it + b"/" + } else { + raw_path + }; + + debug!( + "Request details - host: {}, original_path: {}, clean_path: {}", + host, + path_str, + String::from_utf8_lossy(clean_path) + ); + // Check if origin is allowed if !state.config.is_origin_allowed(&host) { warn!("Blocked request to forbidden origin: {host}"); @@ -225,21 +250,10 @@ fn validate_and_transform_request( // Clone the message to modify it let mut new_message = Message::request( - Vec::from(control.method().unwrap_or(b"GET")), - Vec::from( - format!( - "{}://{}{}", - rewrite.scheme, - rewrite.host, - control - .path() - .map(|p| String::from_utf8_lossy(p)) - .unwrap_or("/".into()) - ) - .as_bytes(), - ), - Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())), - Vec::from(rewrite.host.as_bytes()), + Vec::from(control.method().unwrap_or(b"GET")), // method + Vec::from(rewrite.scheme.as_bytes()), // scheme + Vec::from(rewrite.host.as_bytes()), // authority + Vec::from(clean_path), // path ); // Copy all headers except Host and Authority @@ -291,12 +305,12 @@ async fn forward_to_backend( // Build the backend URI let uri = if let Some(host) = host { - // Check for rewrites - if let Some(rewrite) = state.config.get_rewrite(&host) { - format!("{}://{}{}", rewrite.scheme, rewrite.host, path) - } else { - build_backend_uri(&state.config.backend_url, &path)? - } + // Extract scheme, handling various formats + let scheme = control + .scheme() + .map(|s| String::from_utf8_lossy(s).into_owned()) + .unwrap_or_else(|| "http".to_string()); + format!("{scheme}://{host}{path}") } else { build_backend_uri(&state.config.backend_url, &path)? }; @@ -360,27 +374,31 @@ fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method { fn build_backend_uri(backend_url: &str, path: &str) -> Result { let base_url = backend_url.trim_end_matches('/'); - let path = if path.starts_with('/') { + let clean_path = if path.starts_with('/') { path } else { &format!("/{path}") }; // Validate path to prevent SSRF attacks - if path.contains("..") || path.contains("//") { + if clean_path.contains("..") || clean_path.contains("//") { return Err(GatewayError::InvalidRequest( "Invalid path detected".to_string(), )); } // Additional validation for suspicious patterns - if path.contains('\0') || path.contains('\r') || path.contains('\n') { + if clean_path.contains('\0') || clean_path.contains('\r') || clean_path.contains('\n') { return Err(GatewayError::InvalidRequest( "Invalid characters in path".to_string(), )); } - Ok(format!("{base_url}{path}")) + // Build the final URI with explicit formatting + let final_uri = format!("{base_url}{clean_path}"); + debug!("build_backend_uri: final_uri = '{}'", final_uri); + + Ok(final_uri) } fn should_forward_header(name: &str) -> bool { diff --git a/src/key_manager.rs b/src/key_manager.rs index 0eb28fe..cefb35b 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -1,7 +1,7 @@ use chrono::{DateTime, Utc}; use ohttp::{ - hpke::{Aead, Kdf, Kem}, KeyConfig, Server as OhttpServer, SymmetricSuite, + hpke::{Aead, Kdf, Kem}, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/src/main.rs b/src/main.rs index 122902a..2cb8dd9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ mod state; use crate::config::{AppConfig, LogFormat}; use crate::state::AppState; -use axum::{middleware as axum_middleware, Router}; +use axum::{Router, middleware as axum_middleware}; use std::net::SocketAddr; use std::time::Duration; use tokio::net::TcpListener; @@ -74,7 +74,7 @@ async fn main() -> Result<(), Box> { } fn initialize_tracing(config: &AppConfig) { - use tracing_subscriber::{fmt, EnvFilter}; + use tracing_subscriber::{EnvFilter, fmt}; let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)); diff --git a/src/metrics.rs b/src/metrics.rs index 15ea750..13816df 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,4 +1,4 @@ -use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram}; +use prometheus::{Counter, Gauge, Histogram, register_counter, register_gauge, register_histogram}; #[derive(Clone)] pub struct AppMetrics { diff --git a/src/middleware/logging.rs b/src/middleware/logging.rs index dad9974..07da62f 100644 --- a/src/middleware/logging.rs +++ b/src/middleware/logging.rs @@ -1,6 +1,6 @@ use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response}; use std::time::Instant; -use tracing::{info, warn, Instrument}; +use tracing::{Instrument, info, warn}; use uuid::Uuid; pub async fn logging_middleware( diff --git a/src/middleware/security.rs b/src/middleware/security.rs index 60fa017..f010d5f 100644 --- a/src/middleware/security.rs +++ b/src/middleware/security.rs @@ -1,7 +1,7 @@ use axum::{ body::Body, extract::{ConnectInfo, Request, State}, - http::{header, HeaderMap, StatusCode}, + http::{HeaderMap, StatusCode, header}, middleware::Next, response::{IntoResponse, Response}, };