fix message construction and path forwarding
This commit is contained in:
parent
b11ff4e598
commit
7f8e78d831
12 changed files with 61 additions and 43 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -1785,7 +1785,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "ohttp-gateway"
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
authors = ["Bastian Gruber <foreach@me.com>"]
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
edition = "2024"
|
||||
name = "ohttp-gateway"
|
||||
categories = ["web-programming", "web-programming::http-server"]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use axum::{
|
||||
Json,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{error::GatewayError, state::AppState};
|
||||
use axum::{extract::State, Json};
|
||||
use axum::{Json, extract::State};
|
||||
use serde_json::json;
|
||||
use std::time::Duration;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{header, HeaderName, StatusCode},
|
||||
http::{HeaderName, StatusCode, header},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use chrono::Utc;
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ pub mod ohttp;
|
|||
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
routing::{get, post},
|
||||
};
|
||||
|
||||
pub fn routes() -> Router<AppState> {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use crate::{error::GatewayError, state::AppState};
|
|||
use axum::{
|
||||
body::{Body, Bytes},
|
||||
extract::State,
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
http::{HeaderMap, StatusCode, header},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bhttp::{Message, Mode};
|
||||
|
|
@ -129,11 +129,7 @@ async fn handle_ohttp_request_inner(
|
|||
/// Extract key ID from OHTTP request (first byte after version)
|
||||
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
|
||||
// OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
|
||||
if body.len() > 1 {
|
||||
Some(body[1])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
if body.len() > 1 { Some(body[1]) } else { None }
|
||||
}
|
||||
|
||||
/// Validate the incoming OHTTP request
|
||||
|
|
@ -208,6 +204,35 @@ fn validate_and_transform_request(
|
|||
})
|
||||
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
|
||||
|
||||
// Extract and clean the path
|
||||
let raw_path = control.path().unwrap_or(b"/");
|
||||
let path_str = String::from_utf8_lossy(raw_path);
|
||||
|
||||
// Clean up the path - remove any absolute URL components
|
||||
let clean_path = if path_str.starts_with("http://") || path_str.starts_with("https://") {
|
||||
// Extract just the path from absolute URL
|
||||
if let Some(idx) = path_str
|
||||
.find('/')
|
||||
.and_then(|i| path_str[i + 2..].find('/').map(|j| i + 2 + j))
|
||||
{
|
||||
path_str[idx..].as_bytes()
|
||||
} else {
|
||||
b"/"
|
||||
}
|
||||
} else if path_str.contains(':') && !path_str.starts_with('/') {
|
||||
// Path might contain host:port, clean it
|
||||
b"/"
|
||||
} else {
|
||||
raw_path
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Request details - host: {}, original_path: {}, clean_path: {}",
|
||||
host,
|
||||
path_str,
|
||||
String::from_utf8_lossy(clean_path)
|
||||
);
|
||||
|
||||
// Check if origin is allowed
|
||||
if !state.config.is_origin_allowed(&host) {
|
||||
warn!("Blocked request to forbidden origin: {host}");
|
||||
|
|
@ -225,21 +250,10 @@ fn validate_and_transform_request(
|
|||
|
||||
// Clone the message to modify it
|
||||
let mut new_message = Message::request(
|
||||
Vec::from(control.method().unwrap_or(b"GET")),
|
||||
Vec::from(
|
||||
format!(
|
||||
"{}://{}{}",
|
||||
rewrite.scheme,
|
||||
rewrite.host,
|
||||
control
|
||||
.path()
|
||||
.map(|p| String::from_utf8_lossy(p))
|
||||
.unwrap_or("/".into())
|
||||
)
|
||||
.as_bytes(),
|
||||
),
|
||||
Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())),
|
||||
Vec::from(rewrite.host.as_bytes()),
|
||||
Vec::from(control.method().unwrap_or(b"GET")), // method
|
||||
Vec::from(rewrite.scheme.as_bytes()), // scheme
|
||||
Vec::from(rewrite.host.as_bytes()), // authority
|
||||
Vec::from(clean_path), // path
|
||||
);
|
||||
|
||||
// Copy all headers except Host and Authority
|
||||
|
|
@ -291,12 +305,12 @@ async fn forward_to_backend(
|
|||
|
||||
// Build the backend URI
|
||||
let uri = if let Some(host) = host {
|
||||
// Check for rewrites
|
||||
if let Some(rewrite) = state.config.get_rewrite(&host) {
|
||||
format!("{}://{}{}", rewrite.scheme, rewrite.host, path)
|
||||
} else {
|
||||
build_backend_uri(&state.config.backend_url, &path)?
|
||||
}
|
||||
// Extract scheme, handling various formats
|
||||
let scheme = control
|
||||
.scheme()
|
||||
.map(|s| String::from_utf8_lossy(s).into_owned())
|
||||
.unwrap_or_else(|| "http".to_string());
|
||||
format!("{scheme}://{host}{path}")
|
||||
} else {
|
||||
build_backend_uri(&state.config.backend_url, &path)?
|
||||
};
|
||||
|
|
@ -360,27 +374,31 @@ fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method {
|
|||
|
||||
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
|
||||
let base_url = backend_url.trim_end_matches('/');
|
||||
let path = if path.starts_with('/') {
|
||||
let clean_path = if path.starts_with('/') {
|
||||
path
|
||||
} else {
|
||||
&format!("/{path}")
|
||||
};
|
||||
|
||||
// Validate path to prevent SSRF attacks
|
||||
if path.contains("..") || path.contains("//") {
|
||||
if clean_path.contains("..") || clean_path.contains("//") {
|
||||
return Err(GatewayError::InvalidRequest(
|
||||
"Invalid path detected".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Additional validation for suspicious patterns
|
||||
if path.contains('\0') || path.contains('\r') || path.contains('\n') {
|
||||
if clean_path.contains('\0') || clean_path.contains('\r') || clean_path.contains('\n') {
|
||||
return Err(GatewayError::InvalidRequest(
|
||||
"Invalid characters in path".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(format!("{base_url}{path}"))
|
||||
// Build the final URI with explicit formatting
|
||||
let final_uri = format!("{base_url}{clean_path}");
|
||||
debug!("build_backend_uri: final_uri = '{}'", final_uri);
|
||||
|
||||
Ok(final_uri)
|
||||
}
|
||||
|
||||
fn should_forward_header(name: &str) -> bool {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use ohttp::{
|
||||
hpke::{Aead, Kdf, Kem},
|
||||
KeyConfig, Server as OhttpServer, SymmetricSuite,
|
||||
hpke::{Aead, Kdf, Kem},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ mod state;
|
|||
|
||||
use crate::config::{AppConfig, LogFormat};
|
||||
use crate::state::AppState;
|
||||
use axum::{middleware as axum_middleware, Router};
|
||||
use axum::{Router, middleware as axum_middleware};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
|
|
@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
}
|
||||
|
||||
fn initialize_tracing(config: &AppConfig) {
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
use tracing_subscriber::{EnvFilter, fmt};
|
||||
|
||||
let env_filter =
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
|
||||
use prometheus::{Counter, Gauge, Histogram, register_counter, register_gauge, register_histogram};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppMetrics {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||
use std::time::Instant;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use tracing::{Instrument, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn logging_middleware(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
extract::{ConnectInfo, Request, State},
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
http::{HeaderMap, StatusCode, header},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue