fix message construction and path forwarding

This commit is contained in:
Bastian Gruber 2025-07-17 11:57:26 -03:00
parent b11ff4e598
commit 7f8e78d831
No known key found for this signature in database
GPG key ID: D2DF996A188CFBA2
12 changed files with 61 additions and 43 deletions

2
Cargo.lock generated
View file

@ -1785,7 +1785,7 @@ dependencies = [
[[package]]
name = "ohttp-gateway"
version = "0.1.1"
version = "0.1.2"
dependencies = [
"anyhow",
"axum",

View file

@ -1,6 +1,6 @@
[package]
authors = ["Bastian Gruber <foreach@me.com>"]
version = "0.1.1"
version = "0.1.2"
edition = "2024"
name = "ohttp-gateway"
categories = ["web-programming", "web-programming::http-server"]

View file

@ -1,7 +1,7 @@
use axum::{
Json,
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use thiserror::Error;

View file

@ -1,5 +1,5 @@
use crate::{error::GatewayError, state::AppState};
use axum::{extract::State, Json};
use axum::{Json, extract::State};
use serde_json::json;
use std::time::Duration;

View file

@ -1,7 +1,7 @@
use crate::AppState;
use axum::{
extract::State,
http::{header, HeaderName, StatusCode},
http::{HeaderName, StatusCode, header},
response::{IntoResponse, Response},
};
use chrono::Utc;

View file

@ -4,8 +4,8 @@ pub mod ohttp;
use crate::state::AppState;
use axum::{
routing::{get, post},
Router,
routing::{get, post},
};
pub fn routes() -> Router<AppState> {

View file

@ -2,7 +2,7 @@ use crate::{error::GatewayError, state::AppState};
use axum::{
body::{Body, Bytes},
extract::State,
http::{header, HeaderMap, StatusCode},
http::{HeaderMap, StatusCode, header},
response::{IntoResponse, Response},
};
use bhttp::{Message, Mode};
@ -129,11 +129,7 @@ async fn handle_ohttp_request_inner(
/// Extract key ID from OHTTP request (first byte after version)
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
// OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
if body.len() > 1 {
Some(body[1])
} else {
None
}
if body.len() > 1 { Some(body[1]) } else { None }
}
/// Validate the incoming OHTTP request
@ -208,6 +204,35 @@ fn validate_and_transform_request(
})
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
// Extract and clean the path
let raw_path = control.path().unwrap_or(b"/");
let path_str = String::from_utf8_lossy(raw_path);
// Clean up the path - remove any absolute URL components
let clean_path = if path_str.starts_with("http://") || path_str.starts_with("https://") {
// Extract just the path from absolute URL
if let Some(idx) = path_str
.find('/')
.and_then(|i| path_str[i + 2..].find('/').map(|j| i + 2 + j))
{
path_str[idx..].as_bytes()
} else {
b"/"
}
} else if path_str.contains(':') && !path_str.starts_with('/') {
// Path might contain host:port, clean it
b"/"
} else {
raw_path
};
debug!(
"Request details - host: {}, original_path: {}, clean_path: {}",
host,
path_str,
String::from_utf8_lossy(clean_path)
);
// Check if origin is allowed
if !state.config.is_origin_allowed(&host) {
warn!("Blocked request to forbidden origin: {host}");
@ -225,21 +250,10 @@ fn validate_and_transform_request(
// Clone the message to modify it
let mut new_message = Message::request(
Vec::from(control.method().unwrap_or(b"GET")),
Vec::from(
format!(
"{}://{}{}",
rewrite.scheme,
rewrite.host,
control
.path()
.map(|p| String::from_utf8_lossy(p))
.unwrap_or("/".into())
)
.as_bytes(),
),
Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())),
Vec::from(rewrite.host.as_bytes()),
Vec::from(control.method().unwrap_or(b"GET")), // method
Vec::from(rewrite.scheme.as_bytes()), // scheme
Vec::from(rewrite.host.as_bytes()), // authority
Vec::from(clean_path), // path
);
// Copy all headers except Host and Authority
@ -291,12 +305,12 @@ async fn forward_to_backend(
// Build the backend URI
let uri = if let Some(host) = host {
// Check for rewrites
if let Some(rewrite) = state.config.get_rewrite(&host) {
format!("{}://{}{}", rewrite.scheme, rewrite.host, path)
} else {
build_backend_uri(&state.config.backend_url, &path)?
}
// Extract scheme, handling various formats
let scheme = control
.scheme()
.map(|s| String::from_utf8_lossy(s).into_owned())
.unwrap_or_else(|| "http".to_string());
format!("{scheme}://{host}{path}")
} else {
build_backend_uri(&state.config.backend_url, &path)?
};
@ -360,27 +374,31 @@ fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method {
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
let base_url = backend_url.trim_end_matches('/');
let path = if path.starts_with('/') {
let clean_path = if path.starts_with('/') {
path
} else {
&format!("/{path}")
};
// Validate path to prevent SSRF attacks
if path.contains("..") || path.contains("//") {
if clean_path.contains("..") || clean_path.contains("//") {
return Err(GatewayError::InvalidRequest(
"Invalid path detected".to_string(),
));
}
// Additional validation for suspicious patterns
if path.contains('\0') || path.contains('\r') || path.contains('\n') {
if clean_path.contains('\0') || clean_path.contains('\r') || clean_path.contains('\n') {
return Err(GatewayError::InvalidRequest(
"Invalid characters in path".to_string(),
));
}
Ok(format!("{base_url}{path}"))
// Build the final URI with explicit formatting
let final_uri = format!("{base_url}{clean_path}");
debug!("build_backend_uri: final_uri = '{}'", final_uri);
Ok(final_uri)
}
fn should_forward_header(name: &str) -> bool {

View file

@ -1,7 +1,7 @@
use chrono::{DateTime, Utc};
use ohttp::{
hpke::{Aead, Kdf, Kem},
KeyConfig, Server as OhttpServer, SymmetricSuite,
hpke::{Aead, Kdf, Kem},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

View file

@ -8,7 +8,7 @@ mod state;
use crate::config::{AppConfig, LogFormat};
use crate::state::AppState;
use axum::{middleware as axum_middleware, Router};
use axum::{Router, middleware as axum_middleware};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpListener;
@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
fn initialize_tracing(config: &AppConfig) {
use tracing_subscriber::{fmt, EnvFilter};
use tracing_subscriber::{EnvFilter, fmt};
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));

View file

@ -1,4 +1,4 @@
use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
use prometheus::{Counter, Gauge, Histogram, register_counter, register_gauge, register_histogram};
#[derive(Clone)]
pub struct AppMetrics {

View file

@ -1,6 +1,6 @@
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
use std::time::Instant;
use tracing::{info, warn, Instrument};
use tracing::{Instrument, info, warn};
use uuid::Uuid;
pub async fn logging_middleware(

View file

@ -1,7 +1,7 @@
use axum::{
body::Body,
extract::{ConnectInfo, Request, State},
http::{header, HeaderMap, StatusCode},
http::{HeaderMap, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};