Initial commit

Bastian Gruber 2025-07-17 09:14:22 -03:00
commit 7e4429e9d3
No known key found for this signature in database
GPG key ID: D2DF996A188CFBA2
21 changed files with 6303 additions and 0 deletions

21
.gitignore vendored Normal file

@@ -0,0 +1,21 @@
# Generated by Cargo
# will have compiled files and executables
debug
target
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
# Generated by cargo mutants
# Contains mutation testing data
**/mutants.out*/
# RustRover
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. The more
# nuclear option of ignoring the entire .idea folder is used here.
.idea/

3776
Cargo.lock generated Normal file

File diff suppressed because it is too large

65
Cargo.toml Normal file

@@ -0,0 +1,65 @@
[package]
name = "ohttp-gateway"
authors = ["Bastian Gruber<foreach@me.com>"]
version = "0.1.0"
edition = "2024"
[dependencies]
# Web framework and async runtime
axum = { version = "0.7", features = ["macros"] }
tokio = { version = "1", features = ["full"] }
hyper = { version = "1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
# HTTP client for backend requests
reqwest = { version = "0.12", features = ["json", "stream"] }
# OHTTP implementation - Using the martinthomson/ohttp crate
ohttp = { version = "0.5", features = ["rust-hpke"] }
bhttp = "0.5"
# Middleware and utilities
tower = "0.4"
tower-http = { version = "0.6", features = [
"cors",
"trace",
"compression-br",
"timeout",
] }
# Serialization and configuration
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
config = "0.14"
# Logging and observability
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
chrono = "0.4"
# Error handling
thiserror = "1.0"
anyhow = "1.0"
# Metrics and monitoring
prometheus = "0.13"
axum-prometheus = "0.7"
# Security and validation
validator = { version = "0.18", features = ["derive"] }
jsonwebtoken = "9.0"
uuid = { version = "1.0", features = ["v4"] }
# Async utilities
tokio-util = "0.7"
futures = "0.3"
# Hex encoding and random number generation
hex = "0.4"
rand = "0.8"
# Configuration management
clap = { version = "4.0", features = ["derive", "env"] }
[profile.release]
lto = "fat"

54
Dockerfile Normal file

@@ -0,0 +1,54 @@
# Build stage
FROM rust:1.88-slim AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
# Copy Cargo files
COPY Cargo.toml Cargo.lock ./
# Create dummy main to cache dependencies
RUN mkdir src && echo "fn main() {}" > src/main.rs
# Build dependencies
RUN RUSTFLAGS="-C target-cpu=native" cargo build --release
RUN rm -rf src
# Copy source code
COPY src ./src
# Build the actual application
RUN touch src/main.rs && RUSTFLAGS="-C target-cpu=native" cargo build --release
# Runtime stage
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -y \
ca-certificates \
libssl3 \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy the binary from builder
COPY --from=builder /app/target/release/ohttp-gateway /usr/local/bin/ohttp-gateway
# Create non-root user
RUN useradd -m -u 1001 ohttp
USER ohttp
# Set default environment variables
ENV RUST_LOG=debug,ohttp_gateway=debug
ENV LISTEN_ADDR=0.0.0.0:8080
ENV BACKEND_URL=http://localhost:8000
ENV REQUEST_TIMEOUT=30
ENV KEY_ROTATION_ENABLED=false
EXPOSE 8080
CMD ["ohttp-gateway"]

373
LICENSE Normal file

@@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at https://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

2
README.md Normal file

@@ -0,0 +1,2 @@
# ohttp-gateway
An OHTTP Gateway written in Rust
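The gateway exposes the following routes (see `src/handlers/mod.rs`):
- `POST /gateway`: accepts encapsulated OHTTP requests (`message/ohttp-req`) and returns `message/ohttp-res`
- `GET /ohttp-keys`: serves the encoded key configuration (`application/ohttp-keys`)
- `GET /ohttp-configs`: legacy alias for `/ohttp-keys`
- `GET /health`, `GET /health/keys`, `GET /metrics`: health checks and Prometheus metrics
Configuration is read from environment variables; see `src/config.rs` and the `Dockerfile` for the available settings and their defaults.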

270
src/config.rs Normal file

@@ -0,0 +1,270 @@
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::time::Duration;
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AppConfig {
// Server configuration
pub listen_addr: String,
pub backend_url: String,
pub request_timeout: Duration,
pub max_body_size: usize,
// Key management
pub key_rotation_interval: Duration,
pub key_retention_period: Duration,
pub key_rotation_enabled: bool,
// Security configuration
pub allowed_target_origins: Option<HashSet<String>>,
pub target_rewrites: Option<TargetRewriteConfig>,
pub rate_limit: Option<RateLimitConfig>,
// Operational configuration
pub metrics_enabled: bool,
pub debug_mode: bool,
pub log_format: LogFormat,
pub log_level: String,
// OHTTP specific
pub custom_request_type: Option<String>,
pub custom_response_type: Option<String>,
pub seed_secret_key: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TargetRewriteConfig {
pub rewrites: std::collections::HashMap<String, TargetRewrite>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TargetRewrite {
pub scheme: String,
pub host: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RateLimitConfig {
pub requests_per_second: u32,
pub burst_size: u32,
pub by_ip: bool,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LogFormat {
Default,
Json,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
listen_addr: "0.0.0.0:8080".to_string(),
backend_url: "http://localhost:8080".to_string(),
request_timeout: Duration::from_secs(30),
max_body_size: 10 * 1024 * 1024, // 10MB
key_rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
key_rotation_enabled: true,
allowed_target_origins: None,
target_rewrites: None,
rate_limit: None,
metrics_enabled: true,
debug_mode: false,
log_format: LogFormat::Default,
log_level: "info".to_string(),
custom_request_type: None,
custom_response_type: None,
seed_secret_key: None,
}
}
}
impl AppConfig {
pub fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
let mut config = Self::default();
// Basic configuration
if let Ok(addr) = std::env::var("LISTEN_ADDR") {
config.listen_addr = addr;
}
if let Ok(url) = std::env::var("BACKEND_URL") {
config.backend_url = url;
}
if let Ok(timeout) = std::env::var("REQUEST_TIMEOUT") {
config.request_timeout = Duration::from_secs(timeout.parse()?);
}
if let Ok(size) = std::env::var("MAX_BODY_SIZE") {
config.max_body_size = size.parse()?;
}
// Key management
if let Ok(interval) = std::env::var("KEY_ROTATION_INTERVAL") {
config.key_rotation_interval = Duration::from_secs(interval.parse()?);
}
if let Ok(period) = std::env::var("KEY_RETENTION_PERIOD") {
config.key_retention_period = Duration::from_secs(period.parse()?);
}
if let Ok(enabled) = std::env::var("KEY_ROTATION_ENABLED") {
config.key_rotation_enabled = enabled.parse()?;
}
// Security configuration
if let Ok(origins) = std::env::var("ALLOWED_TARGET_ORIGINS") {
let origins_set: HashSet<String> = origins
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
if !origins_set.is_empty() {
config.allowed_target_origins = Some(origins_set);
}
}
if let Ok(rewrites_json) = std::env::var("TARGET_REWRITES") {
let rewrites: std::collections::HashMap<String, TargetRewrite> =
serde_json::from_str(&rewrites_json)?;
config.target_rewrites = Some(TargetRewriteConfig { rewrites });
}
// Rate limiting
if let Ok(rps) = std::env::var("RATE_LIMIT_RPS") {
let rate_limit = RateLimitConfig {
requests_per_second: rps.parse()?,
burst_size: std::env::var("RATE_LIMIT_BURST")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(100),
by_ip: std::env::var("RATE_LIMIT_BY_IP")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(true),
};
config.rate_limit = Some(rate_limit);
}
// Operational configuration
if let Ok(enabled) = std::env::var("METRICS_ENABLED") {
config.metrics_enabled = enabled.parse()?;
}
if let Ok(debug) = std::env::var("GATEWAY_DEBUG") {
config.debug_mode = debug.parse()?;
}
if let Ok(format) = std::env::var("LOG_FORMAT") {
config.log_format = match format.to_lowercase().as_str() {
"json" => LogFormat::Json,
_ => LogFormat::Default,
};
}
if let Ok(level) = std::env::var("LOG_LEVEL") {
config.log_level = level;
}
// OHTTP specific
if let Ok(req_type) = std::env::var("CUSTOM_REQUEST_TYPE") {
config.custom_request_type = Some(req_type);
}
if let Ok(resp_type) = std::env::var("CUSTOM_RESPONSE_TYPE") {
config.custom_response_type = Some(resp_type);
}
if let Ok(seed) = std::env::var("SEED_SECRET_KEY") {
config.seed_secret_key = Some(seed);
}
// Validate configuration
config.validate()?;
Ok(config)
}
fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
// Validate key rotation settings
if self.key_retention_period > self.key_rotation_interval {
return Err("Key retention period cannot be longer than rotation interval".into());
}
// Validate custom content types
match (&self.custom_request_type, &self.custom_response_type) {
(Some(req), Some(resp)) if req == resp => {
return Err("Request and response content types must be different".into());
}
(Some(_), None) | (None, Some(_)) => {
return Err("Both custom request and response types must be specified".into());
}
_ => {}
}
// Validate seed if provided
if let Some(seed) = &self.seed_secret_key {
let decoded =
hex::decode(seed).map_err(|_| "SEED_SECRET_KEY must be a hex-encoded string")?;
if decoded.len() < 32 {
return Err("SEED_SECRET_KEY must be at least 32 bytes (64 hex characters)".into());
}
}
Ok(())
}
/// Check if a target origin is allowed
pub fn is_origin_allowed(&self, origin: &str) -> bool {
match &self.allowed_target_origins {
Some(allowed) => allowed.contains(origin),
None => true, // No restrictions if not configured
}
}
/// Get rewrite configuration for a host
pub fn get_rewrite(&self, host: &str) -> Option<&TargetRewrite> {
self.target_rewrites
.as_ref()
.and_then(|config| config.rewrites.get(host))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = AppConfig::default();
assert_eq!(config.listen_addr, "0.0.0.0:8080");
assert!(config.key_rotation_enabled);
}
#[test]
fn test_validation_key_periods() {
let mut config = AppConfig::default();
config.key_retention_period = Duration::from_secs(100);
config.key_rotation_interval = Duration::from_secs(50);
assert!(config.validate().is_err());
}
#[test]
fn test_origin_allowed() {
let mut config = AppConfig::default();
config.allowed_target_origins = Some(
vec!["example.com".to_string(), "test.com".to_string()]
.into_iter()
.collect(),
);
assert!(config.is_origin_allowed("example.com"));
assert!(!config.is_origin_allowed("forbidden.com"));
}
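// Additional test: exercises the custom content-type rules in `validate()`.
// Both custom types must be set together, and they must differ.
#[test]
fn test_validation_custom_content_types() {
let mut config = AppConfig::default();
config.custom_request_type = Some("message/custom-req".to_string());
// Only one side configured: rejected.
assert!(config.validate().is_err());
// Identical request and response types: rejected.
config.custom_response_type = Some("message/custom-req".to_string());
assert!(config.validate().is_err());
// A distinct pair passes validation.
config.custom_response_type = Some("message/custom-res".to_string());
assert!(config.validate().is_ok());
}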
}

66
src/error.rs Normal file

@@ -0,0 +1,66 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum GatewayError {
#[error("Invalid request: {0}")]
InvalidRequest(String),
#[error("Decryption failed: {0}")]
DecryptionError(String),
#[error("Encryption failed: {0}")]
EncryptionError(String),
#[error("Backend error: {0}")]
BackendError(String),
#[error("Request too large: {0}")]
RequestTooLarge(String),
#[error("Configuration error: {0}")]
ConfigurationError(String),
#[error("Internal error: {0}")]
InternalError(String),
}
impl IntoResponse for GatewayError {
fn into_response(self) -> Response {
let (status, error_code, message) = match self {
GatewayError::InvalidRequest(msg) => (StatusCode::BAD_REQUEST, "invalid_request", msg),
GatewayError::DecryptionError(msg) => {
(StatusCode::BAD_REQUEST, "decryption_error", msg)
}
GatewayError::EncryptionError(msg) => {
(StatusCode::INTERNAL_SERVER_ERROR, "encryption_error", msg)
}
GatewayError::BackendError(msg) => (StatusCode::BAD_GATEWAY, "backend_error", msg),
GatewayError::RequestTooLarge(msg) => {
(StatusCode::PAYLOAD_TOO_LARGE, "request_too_large", msg)
}
GatewayError::ConfigurationError(msg) => (
StatusCode::INTERNAL_SERVER_ERROR,
"configuration_error",
msg,
),
GatewayError::InternalError(msg) => {
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", msg)
}
};
let body = Json(json!({
"error": {
"code": error_code,
"message": message
}
}));
(status, body).into_response()
}
}

77
src/handlers/health.rs Normal file

@@ -0,0 +1,77 @@
use crate::{error::GatewayError, state::AppState};
use axum::{extract::State, Json};
use serde_json::json;
use std::time::Duration;
pub async fn health_check(
State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, GatewayError> {
let mut health_checks = vec![];
// Check key manager health
let key_status = match state.key_manager.get_encoded_config().await {
// A single encoded key config is only a few dozen bytes, so treat any
// non-empty config list as healthy.
Ok(config) if !config.is_empty() => "healthy",
Ok(_) => "unhealthy",
Err(_) => "unhealthy",
};
health_checks.push(json!({
"component": "ohttp_keys",
"status": key_status
}));
// Check backend connectivity - use the correct health endpoint
let backend_health_url = format!("{}/health", state.config.backend_url);
let backend_status = match state
.http_client
.get(&backend_health_url)
.timeout(Duration::from_secs(5))
.send()
.await
{
Ok(resp) if resp.status().is_success() => "healthy",
Ok(resp) => {
tracing::warn!("Backend health check returned: {}", resp.status());
"unhealthy"
}
Err(e) => {
tracing::error!("Backend health check failed: {}", e);
"unhealthy"
}
};
health_checks.push(json!({
"component": "backend",
"status": backend_status,
"url": backend_health_url
}));
let overall_status = if health_checks.iter().all(|c| c["status"] == "healthy") {
"healthy"
} else {
"unhealthy"
};
Ok(Json(json!({
"status": overall_status,
"timestamp": chrono::Utc::now().to_rfc3339(),
"checks": health_checks,
"version": env!("CARGO_PKG_VERSION")
})))
}
pub async fn metrics_handler() -> Result<String, GatewayError> {
use prometheus::{Encoder, TextEncoder};
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
let mut buffer = Vec::new();
encoder
.encode(&metric_families, &mut buffer)
.map_err(|e| GatewayError::InternalError(format!("Failed to encode metrics: {e}")))?;
String::from_utf8(buffer).map_err(|e| {
GatewayError::InternalError(format!("Failed to convert metrics to string: {e}"))
})
}

83
src/handlers/keys.rs Normal file

@@ -0,0 +1,83 @@
use crate::AppState;
use axum::{
extract::State,
http::{header, HeaderName, StatusCode},
response::{IntoResponse, Response},
};
use chrono::Utc;
use tracing::info;
/// Handler for /ohttp-keys endpoint
/// Returns key configurations in the standard OHTTP format
pub async fn get_ohttp_keys(State(state): State<AppState>) -> Result<Response, StatusCode> {
state.metrics.key_requests_total.inc();
match state.key_manager.get_encoded_config().await {
Ok(config_bytes) => {
info!("Serving {} bytes of key configurations", config_bytes.len());
// Calculate cache duration based on rotation interval
let max_age = calculate_cache_max_age(&state);
Ok((
StatusCode::OK,
[
(header::CONTENT_TYPE, "application/ohttp-keys"),
(header::CACHE_CONTROL, &format!("public, max-age={max_age}")),
(HeaderName::from_static("x-content-type-options"), "nosniff"),
],
config_bytes,
)
.into_response())
}
Err(e) => {
tracing::error!("Failed to encode key config: {e}");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}
/// Legacy endpoint for backward compatibility
/// Some older clients may still use /ohttp-configs
pub async fn get_legacy_ohttp_configs(
State(state): State<AppState>,
) -> Result<Response, StatusCode> {
// Just forward to the main handler
get_ohttp_keys(State(state)).await
}
/// Calculate appropriate cache duration for key configurations
fn calculate_cache_max_age(state: &AppState) -> u64 {
// Cache for 10% of rotation interval, minimum 1 hour, maximum 24 hours
let ten_percent = state.config.key_rotation_interval.as_secs() / 10;
let one_hour = 3600;
let twenty_four_hours = 86400;
ten_percent.max(one_hour).min(twenty_four_hours)
}
/// Health check endpoint specifically for key management
pub async fn key_health_check(State(state): State<AppState>) -> impl IntoResponse {
let stats = state.key_manager.get_stats().await;
let health_status = if stats.active_keys > 0 && stats.expired_keys == 0 {
"healthy"
} else if stats.active_keys > 0 {
"degraded"
} else {
"unhealthy"
};
axum::Json(serde_json::json!({
"status": health_status,
"timestamp": Utc::now().to_rfc3339(),
"key_stats": {
"active_key_id": stats.active_key_id,
"total_keys": stats.total_keys,
"active_keys": stats.active_keys,
"expired_keys": stats.expired_keys,
"rotation_enabled": stats.auto_rotation_enabled,
"rotation_interval_hours": stats.rotation_interval.as_secs() / 3600,
}
}))
}

22
src/handlers/mod.rs Normal file

@@ -0,0 +1,22 @@
pub mod health;
pub mod keys;
pub mod ohttp;
use crate::state::AppState;
use axum::{
routing::{get, post},
Router,
};
pub fn routes() -> Router<AppState> {
Router::new()
// OHTTP endpoints
.route("/gateway", post(ohttp::handle_ohttp_request))
.route("/ohttp-keys", get(keys::get_ohttp_keys))
// Legacy endpoints for backward compatibility
.route("/ohttp-configs", get(keys::get_legacy_ohttp_configs))
// Health and monitoring
.route("/health", get(health::health_check))
.route("/health/keys", get(keys::key_health_check))
.route("/metrics", get(health::metrics_handler))
}

477
src/handlers/ohttp.rs Normal file

@@ -0,0 +1,477 @@
use crate::{error::GatewayError, state::AppState};
use axum::{
body::{Body, Bytes},
extract::State,
http::{header, HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use bhttp::{Message, Mode};
use tracing::{debug, error, info, warn};
const OHTTP_REQUEST_CONTENT_TYPE: &str = "message/ohttp-req";
const OHTTP_RESPONSE_CONTENT_TYPE: &str = "message/ohttp-res";
pub async fn handle_ohttp_request(
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let timer = state.metrics.request_duration.start_timer();
state.metrics.requests_total.inc();
// Extract key ID from the request if possible
let key_id = extract_key_id_from_request(&body);
let result = handle_ohttp_request_inner(state.clone(), headers, body, key_id).await;
timer.stop_and_record();
match result {
Ok(response) => response,
Err(e) => {
error!("OHTTP request failed: {:?}", e);
// Log metrics based on error type
match &e {
GatewayError::DecryptionError(_) => state.metrics.decryption_errors_total.inc(),
GatewayError::EncryptionError(_) => state.metrics.encryption_errors_total.inc(),
GatewayError::BackendError(_) => state.metrics.backend_errors_total.inc(),
_ => {}
}
e.into_response()
}
}
}
async fn handle_ohttp_request_inner(
state: AppState,
headers: HeaderMap,
body: Bytes,
key_id: Option<u8>,
) -> Result<Response, GatewayError> {
// Validate request
validate_ohttp_request(&headers, &body, &state)?;
debug!(
"Received OHTTP request with {} bytes, key_id: {:?}",
body.len(),
key_id
);
// Get the appropriate server based on key ID
let server = if let Some(id) = key_id {
// Try to get server for specific key ID
match state.key_manager.get_server_by_id(id).await {
Some(server) => {
debug!("Using server for key ID: {}", id);
server
}
None => {
warn!("Unknown key ID: {}, falling back to current server", id);
state
.key_manager
.get_current_server()
.await
.map_err(|e| GatewayError::ConfigurationError(e.to_string()))?
}
}
} else {
// Use current active server
state
.key_manager
.get_current_server()
.await
.map_err(|e| GatewayError::ConfigurationError(e.to_string()))?
};
// Decrypt the OHTTP request
let (bhttp_request, server_response) = server.decapsulate(&body).map_err(|e| {
error!("Failed to decapsulate OHTTP request: {e}");
GatewayError::DecryptionError(format!("Failed to decapsulate: {e}"))
})?;
debug!(
"Successfully decapsulated request, {} bytes",
bhttp_request.len()
);
// Parse binary HTTP message
let message = parse_bhttp_message(&bhttp_request)?;
// Validate and potentially transform the request
let message = validate_and_transform_request(message, &state)?;
// Forward request to backend
let backend_response = forward_to_backend(&state, message).await?;
// Convert response to binary HTTP format
let bhttp_response = convert_to_binary_http(backend_response).await?;
// Encrypt response back to client
let encrypted_response = server_response.encapsulate(&bhttp_response).map_err(|e| {
GatewayError::EncryptionError(format!("Failed to encapsulate response: {e}"))
})?;
state.metrics.successful_requests_total.inc();
info!("Successfully processed OHTTP request");
// Build response with appropriate headers
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, OHTTP_RESPONSE_CONTENT_TYPE)
.header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate")
.header("X-Content-Type-Options", "nosniff")
.header("X-Frame-Options", "DENY")
.body(Body::from(encrypted_response))
.map_err(|e| GatewayError::InternalError(format!("Response build error: {e}")))
}
/// Extract the key ID from an encapsulated OHTTP request (the first byte of the header)
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
// Encapsulated request header (RFC 9458): key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
body.first().copied()
}
/// Validate the incoming OHTTP request
fn validate_ohttp_request(
headers: &HeaderMap,
body: &Bytes,
state: &AppState,
) -> Result<(), GatewayError> {
// Check content type
let content_type = headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| GatewayError::InvalidRequest("Missing content-type header".to_string()))?;
if content_type != OHTTP_REQUEST_CONTENT_TYPE {
return Err(GatewayError::InvalidRequest(format!(
"Invalid content-type: expected '{OHTTP_REQUEST_CONTENT_TYPE}', got '{content_type}'"
)));
}
// Check body size
if body.is_empty() {
return Err(GatewayError::InvalidRequest(
"Empty request body".to_string(),
));
}
if body.len() > state.config.max_body_size {
return Err(GatewayError::RequestTooLarge(format!(
"Request body too large: {} bytes (max: {})",
body.len(),
state.config.max_body_size
)));
}
// Minimum OHTTP request size check
if body.len() < 10 {
return Err(GatewayError::InvalidRequest(
"Request too small to be valid OHTTP".to_string(),
));
}
Ok(())
}
/// Parse binary HTTP message with error handling
fn parse_bhttp_message(data: &[u8]) -> Result<Message, GatewayError> {
let mut cursor = std::io::Cursor::new(data);
Message::read_bhttp(&mut cursor)
.map_err(|e| GatewayError::InvalidRequest(format!("Failed to parse binary HTTP: {e}")))
}
/// Validate and transform the request based on security policies
fn validate_and_transform_request(
message: Message,
state: &AppState,
) -> Result<Message, GatewayError> {
let control = message.control();
// Extract host from authority or Host header
let host = control
.authority()
.map(|a| String::from_utf8_lossy(a).into_owned())
.or_else(|| {
message.header().fields().iter().find_map(|field| {
if field.name().eq_ignore_ascii_case(b"host") {
Some(String::from_utf8_lossy(field.value()).into_owned())
} else {
None
}
})
})
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
// Check if origin is allowed
if !state.config.is_origin_allowed(&host) {
warn!("Blocked request to forbidden origin: {host}");
return Err(GatewayError::InvalidRequest(format!(
"Target origin not allowed: {host}"
)));
}
// Apply any configured rewrites
if let Some(rewrite) = state.config.get_rewrite(&host) {
debug!(
"Applying rewrite for host {}: {} -> {}",
host, rewrite.scheme, rewrite.host
);
// Clone the message to modify it
let mut new_message = Message::request(
Vec::from(control.method().unwrap_or(b"GET")),
Vec::from(
format!(
"{}://{}{}",
rewrite.scheme,
rewrite.host,
control
.path()
.map(|p| String::from_utf8_lossy(p))
.unwrap_or("/".into())
)
.as_bytes(),
),
Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())),
Vec::from(rewrite.host.as_bytes()),
);
// Copy all headers except Host and Authority
for field in message.header().fields() {
let name = field.name();
if !name.eq_ignore_ascii_case(b"host") && !name.eq_ignore_ascii_case(b"authority") {
new_message.put_header(name, field.value());
}
}
// Add the new Host header
new_message.put_header(b"host", rewrite.host.as_bytes());
// Copy body content
if !message.content().is_empty() {
new_message.write_content(message.content());
}
return Ok(new_message);
}
Ok(message)
}
async fn forward_to_backend(
state: &AppState,
bhttp_message: Message,
) -> Result<reqwest::Response, GatewayError> {
let control = bhttp_message.control();
let method = control.method().unwrap_or(b"GET");
let path = control
.path()
.map(|p| String::from_utf8_lossy(p).into_owned())
.unwrap_or_else(|| "/".to_string());
// Extract host for URL construction
let host = control
.authority()
.map(|a| String::from_utf8_lossy(a).into_owned())
.or_else(|| {
bhttp_message.header().fields().iter().find_map(|field| {
if field.name().eq_ignore_ascii_case(b"host") {
Some(String::from_utf8_lossy(field.value()).into_owned())
} else {
None
}
})
});
// Build the backend URI
let uri = if let Some(host) = host {
// Check for rewrites
if let Some(rewrite) = state.config.get_rewrite(&host) {
format!("{}://{}{}", rewrite.scheme, rewrite.host, path)
} else {
build_backend_uri(&state.config.backend_url, &path)?
}
} else {
build_backend_uri(&state.config.backend_url, &path)?
};
info!(
"Forwarding {} request to {}",
String::from_utf8_lossy(method),
uri
);
let reqwest_method = convert_method_to_reqwest(method);
let mut request_builder = state.http_client.request(reqwest_method, &uri);
// Add headers from the binary HTTP message
for field in bhttp_message.header().fields() {
let name = String::from_utf8_lossy(field.name());
let value = String::from_utf8_lossy(field.value());
// Skip headers that should not be forwarded
if should_forward_header(&name) {
request_builder = request_builder.header(name.as_ref(), value.as_ref());
}
}
// Add body if present
let content = bhttp_message.content();
if !content.is_empty() {
request_builder = request_builder.body(content.to_vec());
}
// Send request with timeout
let response = request_builder.send().await.map_err(|e| {
error!("Backend request failed: {e}");
GatewayError::BackendError(format!("Backend request failed: {e}"))
})?;
// Check for backend errors
if response.status().is_server_error() {
return Err(GatewayError::BackendError(format!(
"Backend returned error: {}",
response.status()
)));
}
Ok(response)
}
fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method {
match method {
b"GET" => reqwest::Method::GET,
b"POST" => reqwest::Method::POST,
b"PUT" => reqwest::Method::PUT,
b"DELETE" => reqwest::Method::DELETE,
b"HEAD" => reqwest::Method::HEAD,
b"OPTIONS" => reqwest::Method::OPTIONS,
b"PATCH" => reqwest::Method::PATCH,
b"TRACE" => reqwest::Method::TRACE,
_ => reqwest::Method::GET,
}
}
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
let base_url = backend_url.trim_end_matches('/');
let path = if path.starts_with('/') {
path
} else {
&format!("/{path}")
};
// Validate path to prevent SSRF attacks
if path.contains("..") || path.contains("//") {
return Err(GatewayError::InvalidRequest(
"Invalid path detected".to_string(),
));
}
// Additional validation for suspicious patterns
if path.contains('\0') || path.contains('\r') || path.contains('\n') {
return Err(GatewayError::InvalidRequest(
"Invalid characters in path".to_string(),
));
}
Ok(format!("{base_url}{path}"))
}
fn should_forward_header(name: &str) -> bool {
const SKIP_HEADERS: &[&str] = &[
"host",
"connection",
"upgrade",
"proxy-authorization",
"proxy-authenticate",
"te",
"trailers",
"transfer-encoding",
"keep-alive",
"http2-settings",
"upgrade-insecure-requests",
];
!SKIP_HEADERS.contains(&name.to_lowercase().as_str())
}
async fn convert_to_binary_http(response: reqwest::Response) -> Result<Vec<u8>, GatewayError> {
let status = response.status();
let headers = response.headers().clone();
let body = response
.bytes()
.await
.map_err(|e| GatewayError::BackendError(format!("Failed to read response body: {e}")))?;
// Create a bhttp response message
let mut message = Message::response(
bhttp::StatusCode::try_from(status.as_u16())
.map_err(|_| GatewayError::InternalError("Invalid status code".to_string()))?,
);
// Add headers
for (name, value) in headers.iter() {
if should_forward_header(name.as_str()) {
message.put_header(name.as_str().as_bytes(), value.as_bytes());
}
}
// Add body
if !body.is_empty() {
message.write_content(&body);
}
// Serialize to binary HTTP using KnownLength mode for compatibility
let mut output = Vec::new();
message
.write_bhttp(Mode::KnownLength, &mut output)
.map_err(|e| GatewayError::InternalError(format!("Failed to write binary HTTP: {e}")))?;
debug!("Created BHTTP response of {} bytes", output.len());
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_key_id() {
let body = vec![0x7F, 0x00, 0x20, 0x00, 0x01]; // key_id, kem_id, kdf_id...
assert_eq!(extract_key_id_from_request(&body), Some(0x7F));
let empty = vec![];
assert_eq!(extract_key_id_from_request(&empty), None);
}
#[test]
fn test_should_forward_header() {
assert!(should_forward_header("content-type"));
assert!(should_forward_header("authorization"));
assert!(!should_forward_header("connection"));
assert!(!should_forward_header("Host"));
}
#[test]
fn test_build_backend_uri() {
assert_eq!(
build_backend_uri("https://backend.com", "/api/test").unwrap(),
"https://backend.com/api/test"
);
assert_eq!(
build_backend_uri("https://backend.com/", "/api/test").unwrap(),
"https://backend.com/api/test"
);
assert!(build_backend_uri("https://backend.com", "/../etc/passwd").is_err());
assert!(build_backend_uri("https://backend.com", "//evil.com").is_err());
}
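// Additional tests: relative paths are normalized with a leading slash, and
// unknown methods fall back to GET.
#[test]
fn test_build_backend_uri_relative_path() {
assert_eq!(
build_backend_uri("https://backend.com", "api/test").unwrap(),
"https://backend.com/api/test"
);
}
#[test]
fn test_convert_method_fallback() {
assert_eq!(convert_method_to_reqwest(b"POST"), reqwest::Method::POST);
assert_eq!(convert_method_to_reqwest(b"UNKNOWN"), reqwest::Method::GET);
}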
}

379
src/key_manager.rs Normal file

@@ -0,0 +1,379 @@
use chrono::{DateTime, Utc};
use ohttp::{
hpke::{Aead, Kdf, Kem},
KeyConfig, Server as OhttpServer, SymmetricSuite,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{error, info};
/// Represents a key with its metadata
#[derive(Clone, Debug)]
pub struct KeyInfo {
pub id: u8,
pub config: KeyConfig,
pub server: OhttpServer,
pub expires_at: DateTime<Utc>,
pub is_active: bool,
}
/// Configuration for key management
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KeyManagerConfig {
/// How often to rotate keys (default: 30 days)
pub rotation_interval: Duration,
/// How long to keep old keys for decryption (default: 7 days)
pub key_retention_period: Duration,
/// Whether to enable automatic rotation
pub auto_rotation_enabled: bool,
/// Supported cipher suites
pub cipher_suites: Vec<CipherSuiteConfig>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CipherSuiteConfig {
pub kem: String,
pub kdf: String,
pub aead: String,
}
impl Default for KeyManagerConfig {
fn default() -> Self {
Self {
rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
auto_rotation_enabled: true,
cipher_suites: vec![
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "AES_128_GCM".to_string(),
},
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "CHACHA20_POLY1305".to_string(),
},
],
}
}
}
pub struct KeyManager {
/// All keys indexed by ID
keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
/// Current active key ID
active_key_id: Arc<RwLock<u8>>,
/// Configuration
config: KeyManagerConfig,
/// Key ID counter (wraps around after 255)
next_key_id: Arc<RwLock<u8>>,
/// Seed for deterministic key generation (optional)
seed: Option<Vec<u8>>,
}
impl KeyManager {
pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
let manager = Self {
keys: Arc::new(RwLock::new(HashMap::new())),
active_key_id: Arc::new(RwLock::new(0)),
config,
next_key_id: Arc::new(RwLock::new(1)),
seed: None,
};
// Generate initial key
let initial_key = manager.generate_new_key().await?;
{
let mut keys = manager.keys.write().await;
let mut active_id = manager.active_key_id.write().await;
keys.insert(initial_key.id, initial_key.clone());
*active_id = initial_key.id;
}
info!("KeyManager initialized with key ID: {}", initial_key.id);
Ok(manager)
}
/// Create a key manager with a seed for deterministic key generation
pub async fn new_with_seed(
config: KeyManagerConfig,
seed: Vec<u8>,
) -> Result<Self, Box<dyn std::error::Error>> {
if seed.len() < 32 {
return Err("Seed must be at least 32 bytes".into());
}
let mut manager = Self::new(config).await?;
manager.seed = Some(seed);
Ok(manager)
}
/// Generate a new key configuration
async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
let key_id = {
let mut next_id = self.next_key_id.write().await;
let id = *next_id;
*next_id = next_id.wrapping_add(1);
id
};
// Parse cipher suites from config
let mut symmetric_suites = Vec::new();
for suite in &self.config.cipher_suites {
let kdf = match suite.kdf.as_str() {
"HKDF_SHA256" => Kdf::HkdfSha256,
"HKDF_SHA384" => Kdf::HkdfSha384,
"HKDF_SHA512" => Kdf::HkdfSha512,
_ => Kdf::HkdfSha256,
};
let aead = match suite.aead.as_str() {
"AES_128_GCM" => Aead::Aes128Gcm,
"AES_256_GCM" => Aead::Aes256Gcm,
"CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
_ => Aead::Aes128Gcm,
};
symmetric_suites.push(SymmetricSuite::new(kdf, aead));
}
// Determine KEM based on config - only X25519 is supported by ohttp crate
let kem = Kem::X25519Sha256;
// Generate key config
let key_config = if let Some(seed) = &self.seed {
// Deterministic generation using seed + key_id
let mut key_seed = seed.clone();
key_seed.push(key_id);
// The ohttp crate doesn't directly support deterministic key generation
// This would require extending the crate or using a custom implementation
KeyConfig::new(key_id, kem, symmetric_suites)?
} else {
KeyConfig::new(key_id, kem, symmetric_suites)?
};
let server = OhttpServer::new(key_config.clone())?;
let now = Utc::now();
Ok(KeyInfo {
id: key_id,
config: key_config,
server,
expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
is_active: true,
})
}
/// Get the current active server for decryption
pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
keys.get(&*active_id)
.map(|info| info.server.clone())
.ok_or_else(|| "No active key found".into())
}
/// Get a server by key ID (for handling requests with specific key IDs)
pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
let keys = self.keys.read().await;
keys.get(&key_id).map(|info| info.server.clone())
}
/// Get encoded config for backward compatibility
pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
let cfg_bytes = keys
.get(&*active_id)
.ok_or("no active key")?
.config
.encode()?;
let mut out = Vec::with_capacity(cfg_bytes.len() + 2);
out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes()); // 2-byte length
out.extend_from_slice(&cfg_bytes);
Ok(out)
}
/// Rotate keys by generating a new key and marking old ones for expiration
pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
info!("Starting key rotation");
// Generate new key
let new_key = self.generate_new_key().await?;
let new_key_id = new_key.id;
// Update key store
{
let mut keys = self.keys.write().await;
let mut active_id = self.active_key_id.write().await;
let now = Utc::now();
// Mark current active key for future expiration
if let Some(current_key) = keys.get_mut(&*active_id) {
current_key.is_active = false;
// Keep it around for the retention period
current_key.expires_at =
now + chrono::Duration::from_std(self.config.key_retention_period)?;
}
// Add new key
keys.insert(new_key_id, new_key);
// Update active key ID
*active_id = new_key_id;
// Clean up expired keys
keys.retain(|_, info| info.expires_at > now);
info!(
"Key rotation completed. New active key ID: {}, total keys: {}",
new_key_id,
keys.len()
);
}
Ok(())
}
/// Check if rotation is needed
pub async fn should_rotate(&self) -> bool {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
if let Some(active_key) = keys.get(&*active_id) {
let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
// Rotate if less than 10% of the rotation interval remains
let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
.unwrap_or_else(|_| chrono::Duration::days(3));
time_until_expiry < threshold
} else {
true // No active key, definitely need to rotate
}
}
/// Start automatic key rotation scheduler
pub async fn start_rotation_scheduler(self: Arc<Self>) {
if !self.config.auto_rotation_enabled {
info!("Automatic key rotation is disabled");
return;
}
let manager = self;
tokio::spawn(async move {
// Check every hour
let mut interval = tokio::time::interval(Duration::from_secs(3600));
loop {
interval.tick().await;
if manager.should_rotate().await {
if let Err(e) = manager.rotate_keys().await {
error!("Key rotation failed: {}", e);
}
}
// Also clean up expired keys
manager.cleanup_expired_keys().await;
}
});
}
/// Clean up expired keys
async fn cleanup_expired_keys(&self) {
let mut keys = self.keys.write().await;
let now = Utc::now();
let before_count = keys.len();
keys.retain(|id, info| {
if info.expires_at <= now {
info!("Removing expired key ID: {}", id);
false
} else {
true
}
});
let removed = before_count - keys.len();
if removed > 0 {
info!("Cleaned up {} expired keys", removed);
}
}
/// Get key manager statistics
pub async fn get_stats(&self) -> KeyManagerStats {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
let now = Utc::now();
let active_keys = keys.values().filter(|k| k.is_active).count();
let total_keys = keys.len();
let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
KeyManagerStats {
active_key_id: *active_id,
total_keys,
active_keys,
expired_keys,
rotation_interval: self.config.rotation_interval,
auto_rotation_enabled: self.config.auto_rotation_enabled,
}
}
}
#[derive(Debug, Serialize)]
pub struct KeyManagerStats {
pub active_key_id: u8,
pub total_keys: usize,
pub active_keys: usize,
pub expired_keys: usize,
pub rotation_interval: Duration,
pub auto_rotation_enabled: bool,
}
// SAFETY: all fields are either plain configuration data or wrapped in
// Arc<RwLock<...>>; these impls assert that the contained ohttp::Server
// values may be moved and shared across threads.
unsafe impl Send for KeyManager {}
unsafe impl Sync for KeyManager {}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_key_generation() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_keys, 1);
assert_eq!(stats.active_keys, 1);
}
#[tokio::test]
async fn test_key_rotation() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_secs(60),
key_retention_period: Duration::from_secs(30),
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
let initial_stats = manager.get_stats().await;
// Rotate keys
manager.rotate_keys().await.unwrap();
let new_stats = manager.get_stats().await;
assert_eq!(new_stats.total_keys, 2);
assert_ne!(new_stats.active_key_id, initial_stats.active_key_id);
}
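// Additional test: looking up a server by the active key ID succeeds, while
// an unknown ID yields None.
#[tokio::test]
async fn test_get_server_by_id() {
let manager = KeyManager::new(KeyManagerConfig::default()).await.unwrap();
let stats = manager.get_stats().await;
assert!(manager.get_server_by_id(stats.active_key_id).await.is_some());
assert!(manager.get_server_by_id(stats.active_key_id.wrapping_add(1)).await.is_none());
}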
}

11
src/lib.rs Normal file

@@ -0,0 +1,11 @@
pub mod config;
pub mod error;
pub mod handlers;
pub mod key_manager;
pub mod metrics;
pub mod middleware;
pub mod state;
pub use config::AppConfig;
pub use error::GatewayError;
pub use state::AppState;

203
src/main.rs Normal file

@@ -0,0 +1,203 @@
mod config;
mod error;
mod handlers;
mod key_manager;
mod metrics;
mod middleware;
mod state;
use crate::config::{AppConfig, LogFormat};
use crate::state::AppState;
use axum::{middleware as axum_middleware, Router};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tower_http::compression::CompressionLayer;
use tower_http::cors::{Any, CorsLayer};
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
use tracing::{info, warn};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Load configuration first
let config = AppConfig::from_env()?;
// Initialize tracing based on config
initialize_tracing(&config);
info!("Starting OHTTP Gateway v{}", env!("CARGO_PKG_VERSION"));
info!("Configuration loaded: {:?}", config);
// Initialize application state
let app_state = AppState::new(config.clone()).await?;
// Start key rotation scheduler
if config.key_rotation_enabled {
info!("Starting automatic key rotation scheduler");
app_state
.key_manager
.clone()
.start_rotation_scheduler()
.await;
} else {
warn!("Automatic key rotation is disabled");
}
// Create router
let app = create_router(app_state.clone(), &config);
// Parse socket address
let addr: SocketAddr = config.listen_addr.parse()?;
let listener = TcpListener::bind(addr).await?;
info!("OHTTP Gateway listening on {}", addr);
info!("Backend URL: {}", config.backend_url);
if let Some(allowed) = &config.allowed_target_origins {
info!("Allowed origins: {:?}", allowed);
} else {
warn!("No origin restrictions configured - all targets allowed");
}
// Start server with graceful shutdown
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await?;
info!("Server stopped gracefully");
Ok(())
}
fn initialize_tracing(config: &AppConfig) {
use tracing_subscriber::{fmt, EnvFilter};
let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));
match config.log_format {
LogFormat::Json => {
fmt()
.json()
.with_env_filter(env_filter)
.with_target(true)
.with_thread_ids(true)
.with_file(config.debug_mode)
.with_line_number(config.debug_mode)
.init();
}
LogFormat::Default => {
fmt()
.with_env_filter(env_filter)
.with_target(true)
.with_thread_ids(true)
.with_file(config.debug_mode)
.with_line_number(config.debug_mode)
.init();
}
}
}
fn create_router(app_state: AppState, config: &AppConfig) -> Router {
let mut app = Router::new();
// Add routes
app = app.merge(handlers::routes());
// Add middleware layers via a ServiceBuilder (order matters: the first layer
// added is the outermost and runs first on the request)
app = app.layer(
tower::ServiceBuilder::new()
// Outer layers (executed first on request, last on response)
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.layer(TimeoutLayer::new(Duration::from_secs(60)))
// Security middleware
.layer(axum_middleware::from_fn_with_state(
app_state.clone(),
middleware::security::security_middleware,
))
// Request validation
.layer(axum_middleware::from_fn(
middleware::security::request_validation_middleware,
))
// Logging middleware
.layer(axum_middleware::from_fn_with_state(
app_state.clone(),
middleware::logging::logging_middleware,
))
// Metrics middleware
.layer(axum_middleware::from_fn_with_state(
app_state.clone(),
middleware::metrics::metrics_middleware,
))
// CORS configuration
.layer(create_cors_layer(config)),
);
app.with_state(app_state)
}
fn create_cors_layer(config: &AppConfig) -> CorsLayer {
if config.debug_mode {
// Permissive CORS in debug mode
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
// Restrictive CORS in production
CorsLayer::new()
.allow_origin([
"https://example.com".parse().unwrap(),
// Add your allowed origins here
])
.allow_methods([axum::http::Method::GET, axum::http::Method::POST])
.allow_headers([axum::http::header::CONTENT_TYPE, axum::http::header::ACCEPT])
.max_age(Duration::from_secs(3600))
}
}
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
info!("Received Ctrl+C, starting graceful shutdown");
},
_ = terminate => {
info!("Received SIGTERM, starting graceful shutdown");
},
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_loading() {
// Test that the default configuration is constructed with the expected values
let config = AppConfig::default();
assert!(!config.debug_mode);
assert!(config.key_rotation_enabled);
}
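// Hedged sketch: assuming `AppConfig`'s fields remain settable within the
// crate, both CORS modes should build a layer without panicking.
#[test]
fn test_cors_layer_builds_for_both_modes() {
let mut config = AppConfig::default();
let _production_cors = create_cors_layer(&config);
config.debug_mode = true;
let _debug_cors = create_cors_layer(&config);
}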
}

66
src/metrics.rs Normal file
View file

@ -0,0 +1,66 @@
use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
#[derive(Clone)]
pub struct AppMetrics {
pub requests_total: Counter,
pub successful_requests_total: Counter,
pub decryption_errors_total: Counter,
pub encryption_errors_total: Counter,
pub backend_errors_total: Counter,
pub key_requests_total: Counter,
pub request_duration: Histogram,
pub active_connections: Gauge,
}
impl Default for AppMetrics {
fn default() -> Self {
AppMetrics::new()
}
}
impl AppMetrics {
fn new() -> Self {
Self {
requests_total: register_counter!(
"ohttp_requests_total",
"Total number of OHTTP requests"
)
.unwrap(),
successful_requests_total: register_counter!(
"ohttp_successful_requests_total",
"Total number of successful OHTTP requests"
)
.unwrap(),
decryption_errors_total: register_counter!(
"ohttp_decryption_errors_total",
"Total number of decryption errors"
)
.unwrap(),
encryption_errors_total: register_counter!(
"ohttp_encryption_errors_total",
"Total number of encryption errors"
)
.unwrap(),
backend_errors_total: register_counter!(
"ohttp_backend_errors_total",
"Total number of backend errors"
)
.unwrap(),
key_requests_total: register_counter!(
"ohttp_key_requests_total",
"Total number of key configuration requests"
)
.unwrap(),
request_duration: register_histogram!(
"ohttp_request_duration_seconds",
"Duration of OHTTP request processing"
)
.unwrap(),
active_connections: register_gauge!(
"ohttp_active_connections",
"Number of active connections"
)
.unwrap(),
}
}
}
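// Hedged sketch: the metrics above are registered in prometheus's default
// registry, so a plain-text exposition (e.g. for a `/metrics` route, assumed
// to be wired up in `handlers`) can be rendered like this.
pub fn render_metrics() -> String {
use prometheus::Encoder;
let encoder = prometheus::TextEncoder::new();
let mut buffer = Vec::new();
encoder
.encode(&prometheus::gather(), &mut buffer)
.expect("encoding gathered metrics should not fail");
String::from_utf8(buffer).unwrap_or_default()
}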

56
src/middleware/logging.rs Normal file
View file

@ -0,0 +1,56 @@
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
use std::time::Instant;
use tracing::{info, warn, Instrument};
use uuid::Uuid;
pub async fn logging_middleware(
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
let request_id = Uuid::new_v4();
let method = request.method().clone();
let uri = request.uri().clone();
let user_agent = request
.headers()
.get("user-agent")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
.to_string();
let span = tracing::info_span!(
"http_request",
request_id = %request_id,
method = %method,
uri = %uri,
user_agent = %user_agent
);
async move {
let start = Instant::now();
info!("Processing request");
let response = next.run(request).await;
let duration = start.elapsed();
let status = response.status();
if status.is_success() {
info!(
status = %status,
duration_ms = duration.as_millis(),
"Request completed successfully"
);
} else {
warn!(
status = %status,
duration_ms = duration.as_millis(),
"Request failed"
);
}
Ok(response)
}
.instrument(span)
.await
}

17
src/middleware/metrics.rs Normal file
View file

@ -0,0 +1,17 @@
// Connection-tracking middleware: increments the active-connections gauge for
// the lifetime of each in-flight request.
use crate::state::AppState;
use axum::{body::Body, extract::Request, extract::State, middleware::Next, response::Response};
pub async fn metrics_middleware(
State(state): State<AppState>,
request: Request<Body>,
next: Next,
) -> Response {
state.metrics.active_connections.inc();
let response = next.run(request).await;
state.metrics.active_connections.dec();
response
}

3
src/middleware/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod logging;
pub mod metrics;
pub mod security;

188
src/middleware/security.rs Normal file
View file

@ -0,0 +1,188 @@
use axum::{
body::Body,
extract::{ConnectInfo, Request, State},
http::{header, HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use tracing::{info, warn};
use uuid::Uuid;
use crate::{config::RateLimitConfig, state::AppState};
/// Rate limiter implementation
pub struct RateLimiter {
config: RateLimitConfig,
buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}
struct TokenBucket {
tokens: f64,
last_update: Instant,
}
impl RateLimiter {
pub fn new(config: RateLimitConfig) -> Self {
Self {
config,
buckets: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn check_rate_limit(&self, key: &str) -> bool {
let mut buckets = self.buckets.lock().await;
let now = Instant::now();
let bucket = buckets
.entry(key.to_string())
.or_insert_with(|| TokenBucket {
tokens: self.config.burst_size as f64,
last_update: now,
});
// Calculate tokens to add based on time elapsed
let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
let tokens_to_add = elapsed * (self.config.requests_per_second as f64);
bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.burst_size as f64);
bucket.last_update = now;
// Check if we have tokens available
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
true
} else {
false
}
}
}
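// Hedged sketch: assuming `RateLimitConfig` implements `Default` and exposes
// the integer fields used above, a burst of 2 tokens at 1 token/second admits
// two immediate requests for a key and rejects the third until refill.
#[cfg(test)]
mod rate_limiter_tests {
use super::*;
#[tokio::test]
async fn burst_is_exhausted_then_rejected() {
let limiter = RateLimiter::new(RateLimitConfig {
requests_per_second: 1,
burst_size: 2,
by_ip: false,
..Default::default()
});
assert!(limiter.check_rate_limit("client-a").await);
assert!(limiter.check_rate_limit("client-a").await);
assert!(!limiter.check_rate_limit("client-a").await);
}
}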
/// Security middleware that adds various security headers and checks
pub async fn security_middleware(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
// Generate request ID for tracing
let request_id = Uuid::new_v4();
// Add security headers to the request context
let mut request = request;
request
.headers_mut()
.insert("x-request-id", request_id.to_string().parse().unwrap());
// Note: server-side request URIs often carry no scheme, so this check is
// best-effort; behind a TLS-terminating proxy, a forwarded-protocol header
// would be more reliable.
let is_https = matches!(request.uri().scheme_str(), Some("https"));
// Apply rate limiting if configured.
// Note: the limiter is constructed per request here, so its token buckets do
// not persist across requests; for limiting to take effect, a single shared
// `RateLimiter` (e.g. held in `AppState`) is needed.
if let Some(rate_limit_config) = &state.config.rate_limit {
let rate_limiter = RateLimiter::new(rate_limit_config.clone());
let rate_limit_key = if rate_limit_config.by_ip {
addr.ip().to_string()
} else {
"global".to_string()
};
if !rate_limiter.check_rate_limit(&rate_limit_key).await {
warn!(
"Rate limit exceeded for key: {}, request_id: {}",
rate_limit_key, request_id
);
return Ok((
StatusCode::TOO_MANY_REQUESTS,
[
(
"X-RateLimit-Limit",
rate_limit_config.requests_per_second.to_string(),
),
("X-RateLimit-Remaining", "0".to_string()),
("Retry-After", "1".to_string()),
],
"Rate limit exceeded",
)
.into_response());
}
}
// Process the request
let mut response = next.run(request).await;
// Add security headers to the response
let headers = response.headers_mut();
// Security headers
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
headers.insert("X-Request-ID", request_id.to_string().parse().unwrap());
// HSTS header for HTTPS connections
if is_https {
headers.insert(
"Strict-Transport-Security",
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
}
// Content Security Policy
headers.insert(
"Content-Security-Policy",
"default-src 'none'; frame-ancestors 'none';"
.parse()
.unwrap(),
);
// Remove sensitive headers
headers.remove("Server");
headers.remove("X-Powered-By");
Ok(response)
}
/// Middleware to validate and sanitize incoming requests
pub async fn request_validation_middleware(
headers: HeaderMap,
request: Request<Body>,
next: Next,
) -> Result<Response, StatusCode> {
// Check for required headers only on requests with bodies
if matches!(
request.method(),
&axum::http::Method::POST | &axum::http::Method::PUT | &axum::http::Method::PATCH
) && !headers.contains_key(header::CONTENT_TYPE)
{
return Err(StatusCode::BAD_REQUEST);
}
// Validate User-Agent
if let Some(user_agent) = headers.get(header::USER_AGENT) {
if let Ok(ua_str) = user_agent.to_str() {
// Crude filter: reject empty user agents and obvious automation; the
// substring checks below also match legitimate clients, so tune as needed
if ua_str.is_empty() || ua_str.contains("bot") || ua_str.contains("crawler") {
info!("Blocked suspicious user agent: {}", ua_str);
return Err(StatusCode::FORBIDDEN);
}
}
}
// Check for suspicious headers that might indicate attacks
const SUSPICIOUS_HEADERS: &[&str] = &["x-forwarded-host", "x-original-url", "x-rewrite-url"];
for header_name in SUSPICIOUS_HEADERS {
if headers.contains_key(*header_name) {
warn!("Request contains suspicious header: {}", header_name);
return Err(StatusCode::BAD_REQUEST);
}
}
Ok(next.run(request).await)
}

94
src/state.rs Normal file
View file

@ -0,0 +1,94 @@
use crate::{
config::AppConfig,
key_manager::{CipherSuiteConfig, KeyManager, KeyManagerConfig},
metrics::AppMetrics,
};
use std::sync::Arc;
#[derive(Clone)]
pub struct AppState {
pub key_manager: Arc<KeyManager>,
pub http_client: reqwest::Client,
pub config: AppConfig,
pub metrics: AppMetrics,
}
impl AppState {
pub async fn new(config: AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
// Configure key manager based on app config
let key_manager_config = KeyManagerConfig {
rotation_interval: config.key_rotation_interval,
key_retention_period: config.key_retention_period,
auto_rotation_enabled: config.key_rotation_enabled,
cipher_suites: get_cipher_suites(&config),
};
// Initialize key manager with or without seed
let key_manager = if let Some(seed_hex) = &config.seed_secret_key {
let seed = hex::decode(seed_hex)?;
Arc::new(KeyManager::new_with_seed(key_manager_config, seed).await?)
} else {
Arc::new(KeyManager::new(key_manager_config).await?)
};
// Create optimized HTTP client for backend requests
let http_client = create_http_client(&config)?;
let metrics = AppMetrics::default();
Ok(Self {
key_manager,
http_client,
config,
metrics,
})
}
}
fn get_cipher_suites(config: &AppConfig) -> Vec<CipherSuiteConfig> {
// Default cipher suites matching the Go implementation
let mut suites = vec![
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "AES_128_GCM".to_string(),
},
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "CHACHA20_POLY1305".to_string(),
},
];
// Add high-security suite if in production mode
if !config.debug_mode {
suites.push(CipherSuiteConfig {
kem: "P256_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "AES_256_GCM".to_string(),
});
}
suites
}
fn create_http_client(config: &AppConfig) -> Result<reqwest::Client, Box<dyn std::error::Error>> {
let mut client_builder = reqwest::Client::builder()
.timeout(config.request_timeout)
.pool_max_idle_per_host(100)
.pool_idle_timeout(std::time::Duration::from_secs(30))
.tcp_keepalive(std::time::Duration::from_secs(60))
.tcp_nodelay(true)
.user_agent(concat!("ohttp-gateway/", env!("CARGO_PKG_VERSION")))
.danger_accept_invalid_certs(config.debug_mode); // Accept invalid certs only in debug mode
// Configure proxy if needed
if let Ok(proxy_url) = std::env::var("HTTP_PROXY") {
client_builder = client_builder.proxy(reqwest::Proxy::http(proxy_url)?);
}
if let Ok(proxy_url) = std::env::var("HTTPS_PROXY") {
client_builder = client_builder.proxy(reqwest::Proxy::https(proxy_url)?);
}
Ok(client_builder.build()?)
}
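// Hedged sketch: assuming `AppConfig::default()` disables debug mode (as the
// test in `main.rs` expects), production advertises the extra P-256/AES-256-GCM
// suite while debug mode keeps only the two X25519 suites.
#[cfg(test)]
mod cipher_suite_tests {
use super::*;
#[test]
fn production_adds_high_security_suite() {
let mut config = AppConfig::default();
assert_eq!(get_cipher_suites(&config).len(), 3);
config.debug_mode = true;
assert_eq!(get_cipher_suites(&config).len(), 2);
}
}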