fix message construction and path forwarding
This commit is contained in:
parent
b11ff4e598
commit
7f8e78d831
12 changed files with 61 additions and 43 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
|
@ -1785,7 +1785,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ohttp-gateway"
|
name = "ohttp-gateway"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"axum",
|
"axum",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
authors = ["Bastian Gruber <foreach@me.com>"]
|
authors = ["Bastian Gruber <foreach@me.com>"]
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
name = "ohttp-gateway"
|
name = "ohttp-gateway"
|
||||||
categories = ["web-programming", "web-programming::http-server"]
|
categories = ["web-programming", "web-programming::http-server"]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use axum::{
|
use axum::{
|
||||||
|
Json,
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::{error::GatewayError, state::AppState};
|
use crate::{error::GatewayError, state::AppState};
|
||||||
use axum::{extract::State, Json};
|
use axum::{Json, extract::State};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{header, HeaderName, StatusCode},
|
http::{HeaderName, StatusCode, header},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ pub mod ohttp;
|
||||||
|
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use axum::{
|
use axum::{
|
||||||
routing::{get, post},
|
|
||||||
Router,
|
Router,
|
||||||
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn routes() -> Router<AppState> {
|
pub fn routes() -> Router<AppState> {
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ use crate::{error::GatewayError, state::AppState};
|
||||||
use axum::{
|
use axum::{
|
||||||
body::{Body, Bytes},
|
body::{Body, Bytes},
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{header, HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode, header},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use bhttp::{Message, Mode};
|
use bhttp::{Message, Mode};
|
||||||
|
|
@ -129,11 +129,7 @@ async fn handle_ohttp_request_inner(
|
||||||
/// Extract key ID from OHTTP request (first byte after version)
|
/// Extract key ID from OHTTP request (first byte after version)
|
||||||
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
|
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
|
||||||
// OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
|
// OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
|
||||||
if body.len() > 1 {
|
if body.len() > 1 { Some(body[1]) } else { None }
|
||||||
Some(body[1])
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate the incoming OHTTP request
|
/// Validate the incoming OHTTP request
|
||||||
|
|
@ -208,6 +204,35 @@ fn validate_and_transform_request(
|
||||||
})
|
})
|
||||||
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
|
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
|
||||||
|
|
||||||
|
// Extract and clean the path
|
||||||
|
let raw_path = control.path().unwrap_or(b"/");
|
||||||
|
let path_str = String::from_utf8_lossy(raw_path);
|
||||||
|
|
||||||
|
// Clean up the path - remove any absolute URL components
|
||||||
|
let clean_path = if path_str.starts_with("http://") || path_str.starts_with("https://") {
|
||||||
|
// Extract just the path from absolute URL
|
||||||
|
if let Some(idx) = path_str
|
||||||
|
.find('/')
|
||||||
|
.and_then(|i| path_str[i + 2..].find('/').map(|j| i + 2 + j))
|
||||||
|
{
|
||||||
|
path_str[idx..].as_bytes()
|
||||||
|
} else {
|
||||||
|
b"/"
|
||||||
|
}
|
||||||
|
} else if path_str.contains(':') && !path_str.starts_with('/') {
|
||||||
|
// Path might contain host:port, clean it
|
||||||
|
b"/"
|
||||||
|
} else {
|
||||||
|
raw_path
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Request details - host: {}, original_path: {}, clean_path: {}",
|
||||||
|
host,
|
||||||
|
path_str,
|
||||||
|
String::from_utf8_lossy(clean_path)
|
||||||
|
);
|
||||||
|
|
||||||
// Check if origin is allowed
|
// Check if origin is allowed
|
||||||
if !state.config.is_origin_allowed(&host) {
|
if !state.config.is_origin_allowed(&host) {
|
||||||
warn!("Blocked request to forbidden origin: {host}");
|
warn!("Blocked request to forbidden origin: {host}");
|
||||||
|
|
@ -225,21 +250,10 @@ fn validate_and_transform_request(
|
||||||
|
|
||||||
// Clone the message to modify it
|
// Clone the message to modify it
|
||||||
let mut new_message = Message::request(
|
let mut new_message = Message::request(
|
||||||
Vec::from(control.method().unwrap_or(b"GET")),
|
Vec::from(control.method().unwrap_or(b"GET")), // method
|
||||||
Vec::from(
|
Vec::from(rewrite.scheme.as_bytes()), // scheme
|
||||||
format!(
|
Vec::from(rewrite.host.as_bytes()), // authority
|
||||||
"{}://{}{}",
|
Vec::from(clean_path), // path
|
||||||
rewrite.scheme,
|
|
||||||
rewrite.host,
|
|
||||||
control
|
|
||||||
.path()
|
|
||||||
.map(|p| String::from_utf8_lossy(p))
|
|
||||||
.unwrap_or("/".into())
|
|
||||||
)
|
|
||||||
.as_bytes(),
|
|
||||||
),
|
|
||||||
Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())),
|
|
||||||
Vec::from(rewrite.host.as_bytes()),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Copy all headers except Host and Authority
|
// Copy all headers except Host and Authority
|
||||||
|
|
@ -291,12 +305,12 @@ async fn forward_to_backend(
|
||||||
|
|
||||||
// Build the backend URI
|
// Build the backend URI
|
||||||
let uri = if let Some(host) = host {
|
let uri = if let Some(host) = host {
|
||||||
// Check for rewrites
|
// Extract scheme, handling various formats
|
||||||
if let Some(rewrite) = state.config.get_rewrite(&host) {
|
let scheme = control
|
||||||
format!("{}://{}{}", rewrite.scheme, rewrite.host, path)
|
.scheme()
|
||||||
} else {
|
.map(|s| String::from_utf8_lossy(s).into_owned())
|
||||||
build_backend_uri(&state.config.backend_url, &path)?
|
.unwrap_or_else(|| "http".to_string());
|
||||||
}
|
format!("{scheme}://{host}{path}")
|
||||||
} else {
|
} else {
|
||||||
build_backend_uri(&state.config.backend_url, &path)?
|
build_backend_uri(&state.config.backend_url, &path)?
|
||||||
};
|
};
|
||||||
|
|
@ -360,27 +374,31 @@ fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method {
|
||||||
|
|
||||||
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
|
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
|
||||||
let base_url = backend_url.trim_end_matches('/');
|
let base_url = backend_url.trim_end_matches('/');
|
||||||
let path = if path.starts_with('/') {
|
let clean_path = if path.starts_with('/') {
|
||||||
path
|
path
|
||||||
} else {
|
} else {
|
||||||
&format!("/{path}")
|
&format!("/{path}")
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate path to prevent SSRF attacks
|
// Validate path to prevent SSRF attacks
|
||||||
if path.contains("..") || path.contains("//") {
|
if clean_path.contains("..") || clean_path.contains("//") {
|
||||||
return Err(GatewayError::InvalidRequest(
|
return Err(GatewayError::InvalidRequest(
|
||||||
"Invalid path detected".to_string(),
|
"Invalid path detected".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Additional validation for suspicious patterns
|
// Additional validation for suspicious patterns
|
||||||
if path.contains('\0') || path.contains('\r') || path.contains('\n') {
|
if clean_path.contains('\0') || clean_path.contains('\r') || clean_path.contains('\n') {
|
||||||
return Err(GatewayError::InvalidRequest(
|
return Err(GatewayError::InvalidRequest(
|
||||||
"Invalid characters in path".to_string(),
|
"Invalid characters in path".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(format!("{base_url}{path}"))
|
// Build the final URI with explicit formatting
|
||||||
|
let final_uri = format!("{base_url}{clean_path}");
|
||||||
|
debug!("build_backend_uri: final_uri = '{}'", final_uri);
|
||||||
|
|
||||||
|
Ok(final_uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_forward_header(name: &str) -> bool {
|
fn should_forward_header(name: &str) -> bool {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use ohttp::{
|
use ohttp::{
|
||||||
hpke::{Aead, Kdf, Kem},
|
|
||||||
KeyConfig, Server as OhttpServer, SymmetricSuite,
|
KeyConfig, Server as OhttpServer, SymmetricSuite,
|
||||||
|
hpke::{Aead, Kdf, Kem},
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ mod state;
|
||||||
|
|
||||||
use crate::config::{AppConfig, LogFormat};
|
use crate::config::{AppConfig, LogFormat};
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use axum::{middleware as axum_middleware, Router};
|
use axum::{Router, middleware as axum_middleware};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize_tracing(config: &AppConfig) {
|
fn initialize_tracing(config: &AppConfig) {
|
||||||
use tracing_subscriber::{fmt, EnvFilter};
|
use tracing_subscriber::{EnvFilter, fmt};
|
||||||
|
|
||||||
let env_filter =
|
let env_filter =
|
||||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
|
use prometheus::{Counter, Gauge, Histogram, register_counter, register_gauge, register_histogram};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppMetrics {
|
pub struct AppMetrics {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
|
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tracing::{info, warn, Instrument};
|
use tracing::{Instrument, info, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub async fn logging_middleware(
|
pub async fn logging_middleware(
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::{ConnectInfo, Request, State},
|
extract::{ConnectInfo, Request, State},
|
||||||
http::{header, HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode, header},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue