Derive key from seed + tests + change port to 8000 by default

This commit is contained in:
Bastian Gruber 2025-07-22 12:57:19 +02:00
parent 7f8e78d831
commit 86a0f71690
No known key found for this signature in database
GPG key ID: D2DF996A188CFBA2
10 changed files with 870 additions and 122 deletions

101
Cargo.lock generated
View file

@ -254,10 +254,10 @@ dependencies = [
"axum-macros",
"bytes",
"futures-util",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper 1.6.0",
"hyper-util",
"itoa",
"matchit",
@ -287,8 +287,8 @@ dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
@ -319,8 +319,8 @@ dependencies = [
"axum",
"bytes",
"futures-core",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"matchit",
"metrics",
"metrics-exporter-prometheus",
@ -1056,7 +1056,7 @@ dependencies = [
"fnv",
"futures-core",
"futures-sink",
"http",
"http 1.3.1",
"indexmap",
"slab",
"tokio",
@ -1166,6 +1166,17 @@ dependencies = [
"zeroize",
]
[[package]]
name = "http"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http"
version = "1.3.1"
@ -1177,6 +1188,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http 0.2.12",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.1"
@ -1184,7 +1206,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
"http 1.3.1",
]
[[package]]
@ -1195,8 +1217,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"pin-project-lite",
]
@ -1212,6 +1234,28 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "0.14.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "hyper"
version = "1.6.0"
@ -1222,8 +1266,8 @@ dependencies = [
"futures-channel",
"futures-util",
"h2",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
@ -1239,8 +1283,8 @@ version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"http 1.3.1",
"hyper 1.6.0",
"hyper-util",
"rustls",
"rustls-pki-types",
@ -1257,7 +1301,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper 1.6.0",
"hyper-util",
"native-tls",
"tokio",
@ -1276,9 +1320,9 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"http",
"http-body",
"hyper",
"http 1.3.1",
"http-body 1.0.1",
"hyper 1.6.0",
"ipnet",
"libc",
"percent-encoding",
@ -1611,7 +1655,7 @@ checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6"
dependencies = [
"base64 0.22.1",
"http-body-util",
"hyper",
"hyper 1.6.0",
"hyper-util",
"indexmap",
"ipnet",
@ -1796,7 +1840,8 @@ dependencies = [
"config",
"futures",
"hex",
"hyper",
"hyper 0.14.32",
"hyper 1.6.0",
"hyper-util",
"jsonwebtoken",
"ohttp",
@ -2285,10 +2330,10 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper 1.6.0",
"hyper-rustls",
"hyper-tls",
"hyper-util",
@ -2976,8 +3021,8 @@ checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"bitflags",
"bytes",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"pin-project-lite",
"tower-layer",
@ -2995,8 +3040,8 @@ dependencies = [
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http 1.3.1",
"http-body 1.0.1",
"iri-string",
"pin-project-lite",
"tokio",

View file

@ -68,5 +68,10 @@ rand = "0.8"
# Configuration management
clap = { version = "4.5", features = ["derive", "env"] }
[dev-dependencies]
tokio = { version = "1", features = ["full"] }
hyper = "0.14"
rand = "0.8"
[profile.release]
lto = "fat"

View file

@ -44,7 +44,7 @@ USER ohttp
# Set default environment variables
ENV RUST_LOG=debug,ohttp_gateway=debug
ENV LISTEN_ADDR=0.0.0.0:8080
ENV PORT=8000
ENV BACKEND_URL=http://localhost:8000
ENV REQUEST_TIMEOUT=30
ENV KEY_ROTATION_ENABLED=false

View file

@ -5,7 +5,7 @@ use std::time::Duration;
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AppConfig {
// Server configuration
pub listen_addr: String,
pub port: String,
pub backend_url: String,
pub request_timeout: Duration,
pub max_body_size: usize,
@ -60,7 +60,7 @@ pub enum LogFormat {
impl Default for AppConfig {
fn default() -> Self {
Self {
listen_addr: "0.0.0.0:8080".to_string(),
port: "0.0.0.0:8000".to_string(),
backend_url: "http://localhost:8080".to_string(),
request_timeout: Duration::from_secs(30),
max_body_size: 10 * 1024 * 1024, // 10MB
@ -86,8 +86,8 @@ impl AppConfig {
let mut config = Self::default();
// Basic configuration
if let Ok(addr) = std::env::var("LISTEN_ADDR") {
config.listen_addr = addr;
if let Ok(port) = std::env::var("PORT") {
config.port = format!("0.0.0.0:{port}");
}
if let Ok(url) = std::env::var("BACKEND_URL") {
@ -234,37 +234,3 @@ impl AppConfig {
.and_then(|config| config.rewrites.get(host))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = AppConfig::default();
assert_eq!(config.listen_addr, "0.0.0.0:8080");
assert!(config.key_rotation_enabled);
}
#[test]
fn test_validation_key_periods() {
let mut config = AppConfig::default();
config.key_retention_period = Duration::from_secs(100);
config.key_rotation_interval = Duration::from_secs(50);
assert!(config.validate().is_err());
}
#[test]
fn test_origin_allowed() {
let mut config = AppConfig::default();
config.allowed_target_origins = Some(
vec!["example.com".to_string(), "test.com".to_string()]
.into_iter()
.collect(),
);
assert!(config.is_origin_allowed("example.com"));
assert!(!config.is_origin_allowed("forbidden.com"));
}
}

View file

@ -108,8 +108,25 @@ impl KeyManager {
return Err("Seed must be at least 32 bytes".into());
}
let mut manager = Self::new(config).await?;
manager.seed = Some(seed);
let manager = Self {
keys: Arc::new(RwLock::new(HashMap::new())),
active_key_id: Arc::new(RwLock::new(0)),
config,
next_key_id: Arc::new(RwLock::new(1)),
seed: Some(seed),
};
// Generate initial key (will now use the seed)
let initial_key = manager.generate_new_key().await?;
{
let mut keys = manager.keys.write().await;
let mut active_id = manager.active_key_id.write().await;
keys.insert(initial_key.id, initial_key.clone());
*active_id = initial_key.id;
}
info!("KeyManager initialized with key ID: {}", initial_key.id);
Ok(manager)
}
@ -142,6 +159,11 @@ impl KeyManager {
symmetric_suites.push(SymmetricSuite::new(kdf, aead));
}
// Validate that we have at least one cipher suite
if symmetric_suites.is_empty() {
return Err("No valid cipher suites configured".into());
}
// Determine KEM based on config - only X25519 is supported by ohttp crate
let kem = Kem::X25519Sha256;
@ -151,9 +173,7 @@ impl KeyManager {
let mut key_seed = seed.clone();
key_seed.push(key_id);
// The ohttp crate doesn't directly support deterministic key generation
// This would require extending the crate or using a custom implementation
KeyConfig::new(key_id, kem, symmetric_suites)?
KeyConfig::derive(key_id, kem, symmetric_suites, &key_seed)?
} else {
KeyConfig::new(key_id, kem, symmetric_suites)?
};
@ -270,8 +290,8 @@ impl KeyManager {
let manager = self;
tokio::spawn(async move {
// Check every hour
let mut interval = tokio::time::interval(Duration::from_secs(3600));
// Use the configured rotation interval for the scheduler
let mut interval = tokio::time::interval(manager.config.rotation_interval);
loop {
interval.tick().await;
@ -343,37 +363,3 @@ pub struct KeyManagerStats {
// Ensure thread safety
unsafe impl Send for KeyManager {}
unsafe impl Sync for KeyManager {}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_key_generation() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_keys, 1);
assert_eq!(stats.active_keys, 1);
}
#[tokio::test]
async fn test_key_rotation() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_secs(60),
key_retention_period: Duration::from_secs(30),
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
let initial_stats = manager.get_stats().await;
// Rotate keys
manager.rotate_keys().await.unwrap();
let new_stats = manager.get_stats().await;
assert_eq!(new_stats.total_keys, 2);
assert_ne!(new_stats.active_key_id, initial_stats.active_key_id);
}
}

View file

@ -49,7 +49,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let app = create_router(app_state.clone(), &config);
// Parse socket address
let addr: SocketAddr = config.listen_addr.parse()?;
let addr: SocketAddr = config.port.parse()?;
let listener = TcpListener::bind(addr).await?;
info!("OHTTP Gateway listening on {}", addr);
@ -188,16 +188,3 @@ async fn shutdown_signal() {
},
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_loading() {
// Test that default config loads successfully
let config = AppConfig::default();
assert!(!config.debug_mode);
assert!(config.key_rotation_enabled);
}
}

182
tests/common/mod.rs Normal file
View file

@ -0,0 +1,182 @@
//! Test utilities and common code for integration tests
#![cfg(test)]
#![allow(dead_code)]
use ohttp::{
KeyConfig, Server as OhttpServer, SymmetricSuite,
hpke::{Aead, Kdf, Kem},
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
// Test constants matching Go implementation
pub const LEGACY_KEY_ID: u8 = 0x00;
pub const CURRENT_KEY_ID: u8 = 0x01;
pub const FORBIDDEN_TARGET: &str = "forbidden.example";
pub const ALLOWED_TARGET: &str = "allowed.example";
pub const GATEWAY_DEBUG: bool = true;
pub const BINARY_HTTP_GATEWAY_ENDPOINT: &str = "/binary-http-gateway";
// Mock metrics for testing
#[derive(Debug, Clone, Default)]
pub struct MockMetrics {
pub event_name: String,
pub result_labels: Arc<RwLock<HashMap<String, bool>>>,
}
impl MockMetrics {
pub fn new(event_name: String) -> Self {
Self {
event_name,
result_labels: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn response_status(&self, prefix: &str, status: u16) {
self.fire(&format!("{}_response_status_{}", prefix, status))
.await;
}
pub async fn fire(&self, result: &str) {
let mut labels = self.result_labels.write().await;
if labels.contains_key(result) {
panic!("Metrics.fire called twice for the same result: {}", result);
}
labels.insert(result.to_string(), true);
}
pub async fn contains_result(&self, result: &str) -> bool {
let labels = self.result_labels.read().await;
labels.contains_key(result)
}
}
#[derive(Debug, Default)]
pub struct MockMetricsFactory {
pub metrics: Arc<RwLock<Vec<MockMetrics>>>,
}
impl MockMetricsFactory {
pub fn new() -> Self {
Self {
metrics: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn create(&self, event_name: String) -> MockMetrics {
let metrics = MockMetrics::new(event_name);
let mut metrics_vec = self.metrics.write().await;
metrics_vec.push(metrics.clone());
metrics
}
pub async fn get_metrics_for_event(&self, event_name: &str) -> Option<MockMetrics> {
let metrics_vec = self.metrics.read().await;
metrics_vec
.iter()
.find(|m| m.event_name == event_name)
.cloned()
}
}
// Test key configuration similar to Go's createGateway
pub fn create_test_key_configs() -> Result<(KeyConfig, KeyConfig), Box<dyn std::error::Error>> {
// Legacy configuration (X25519 only)
let legacy_config = KeyConfig::new(
LEGACY_KEY_ID,
Kem::X25519Sha256,
vec![SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm)],
)?;
// Current configuration (for testing - in real implementation would be post-quantum)
let current_config = KeyConfig::new(
CURRENT_KEY_ID,
Kem::X25519Sha256, // ohttp crate limitation - would be KEM_X25519_KYBER768_DRAFT00
vec![SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm)],
)?;
Ok((current_config, legacy_config))
}
// Create test servers from configs
pub fn create_test_servers() -> Result<(OhttpServer, OhttpServer), Box<dyn std::error::Error>> {
let (current_config, legacy_config) = create_test_key_configs()?;
let current_server = OhttpServer::new(current_config)?;
let legacy_server = OhttpServer::new(legacy_config)?;
Ok((current_server, legacy_server))
}
// Mock HTTP request handler for testing
#[derive(Debug, Clone)]
pub struct MockHTTPRequestHandler;
impl MockHTTPRequestHandler {
pub fn handle(&self, url: &str) -> Result<String, Box<dyn std::error::Error>> {
// Echo the URL back for testing
Ok(url.to_string())
}
}
// Helper function to create test binary HTTP messages
pub fn create_test_binary_http_message() -> Vec<u8> {
// Simple test message similar to Go's {0xCA, 0xFE}
vec![0xCA, 0xFE]
}
// Helper to validate cache control headers
pub fn validate_cache_control_header(header_value: &str) -> Result<(), String> {
if !header_value.starts_with("max-age=") || !header_value.ends_with(", private") {
return Err(format!("Invalid cache-control format: {}", header_value));
}
let max_age_str = header_value
.strip_prefix("max-age=")
.and_then(|s| s.strip_suffix(", private"))
.ok_or("Failed to parse max-age")?;
let max_age: u32 = max_age_str
.parse()
.map_err(|_| "max-age should be a number")?;
const TWELVE_HOURS: u32 = 12 * 3600;
const TWENTY_FOUR_HOURS: u32 = 24 * 3600;
if max_age < TWELVE_HOURS || max_age > TWELVE_HOURS + TWENTY_FOUR_HOURS {
return Err(format!(
"max-age {} should be between 12 and 36 hours",
max_age
));
}
Ok(())
}
// Test result assertion helpers
pub async fn assert_metrics_contains_result(
factory: &MockMetricsFactory,
event: &str,
result: &str,
) -> Result<(), String> {
if let Some(metrics) = factory.get_metrics_for_event(event).await {
if !metrics.contains_result(result).await {
return Err(format!("Expected event {}/{} was not fired", event, result));
}
Ok(())
} else {
Err(format!("No metrics found for event: {}", event))
}
}
pub fn assert_body_contains_error(body: &[u8], expected_text: &str) -> Result<(), String> {
let body_str = String::from_utf8_lossy(body);
if !body_str.contains(expected_text) {
return Err(format!(
"Failed to return expected text ({}) in response. Body text is: {}",
expected_text, body_str
));
}
Ok(())
}

View file

@ -0,0 +1,206 @@
use hyper::StatusCode;
use rand::Rng;
use ohttp_gateway::{key_manager::KeyManager, key_manager::KeyManagerConfig};
mod common;
use common::{LEGACY_KEY_ID, validate_cache_control_header};
// Mock HTTP response structure for testing
struct MockResponse {
status: StatusCode,
headers: std::collections::HashMap<String, String>,
body: Vec<u8>,
}
impl MockResponse {
fn new(status: StatusCode, body: Vec<u8>) -> Self {
Self {
status,
headers: std::collections::HashMap::new(),
body,
}
}
fn add_header(&mut self, name: &str, value: &str) {
self.headers.insert(name.to_string(), value.to_string());
}
fn get_header(&self, name: &str) -> Option<&String> {
self.headers.get(name)
}
}
// Mock config handler - adapt this to match your actual HTTP handler structure
async fn mock_config_handler(
manager: &KeyManager,
) -> Result<MockResponse, Box<dyn std::error::Error>> {
// Generate random cache age between 12-36 hours (mirroring Go implementation)
use rand::Rng;
let mut rng = rand::thread_rng();
let twelve_hours = 12 * 3600;
let twenty_four_hours = 24 * 3600;
let max_age = twelve_hours + rng.gen_range(0..twenty_four_hours);
let encoded_config = manager.get_encoded_config().await?;
let mut response = MockResponse::new(StatusCode::OK, encoded_config);
response.add_header("Cache-Control", &format!("max-age={}, private", max_age));
response.add_header("Content-Type", "application/ohttp-keys");
Ok(response)
}
async fn mock_legacy_config_handler(
manager: &KeyManager,
_key_id: u8,
) -> Result<MockResponse, Box<dyn std::error::Error>> {
// This would need to be implemented based on your legacy config support
// For now, return a simple implementation
let encoded_config = manager.get_encoded_config().await?;
let mut rng = rand::thread_rng();
let twelve_hours = 12 * 3600;
let twenty_four_hours = 24 * 3600;
let max_age = twelve_hours + rng.gen_range(0..twenty_four_hours);
let mut response = MockResponse::new(StatusCode::OK, encoded_config);
response.add_header("Cache-Control", &format!("max-age={}, private", max_age));
Ok(response)
}
#[tokio::test]
async fn test_config_handler() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let response = mock_config_handler(&manager).await.unwrap();
// Check status
assert_eq!(response.status, StatusCode::OK);
// Check headers
assert_eq!(
response.get_header("Content-Type").unwrap(),
"application/ohttp-keys"
);
let cache_control = response.get_header("Cache-Control").unwrap();
validate_cache_control_header(cache_control).unwrap();
// Check body is not empty and has expected structure
assert!(!response.body.is_empty());
assert!(response.body.len() >= 4); // At least length prefix + some config data
}
#[tokio::test]
async fn test_legacy_config_handler() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let response = mock_legacy_config_handler(&manager, LEGACY_KEY_ID)
.await
.unwrap();
// Check status
assert_eq!(response.status, StatusCode::OK);
// Check cache control header exists and is valid
let cache_control = response.get_header("Cache-Control").unwrap();
validate_cache_control_header(cache_control).unwrap();
// Check body
assert!(!response.body.is_empty());
}
#[tokio::test]
async fn test_config_handler_multiple_keys() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
// Add another key through rotation
manager.rotate_keys().await.unwrap();
let response = mock_config_handler(&manager).await.unwrap();
assert_eq!(response.status, StatusCode::OK);
// Body should be larger with multiple keys
assert!(response.body.len() >= 8); // At least 2 key configs
}
#[tokio::test]
async fn test_config_consistency() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
// Get config multiple times
let response1 = mock_config_handler(&manager).await.unwrap();
let response2 = mock_config_handler(&manager).await.unwrap();
// Both responses should be successful
assert_eq!(response1.status, StatusCode::OK);
assert_eq!(response2.status, StatusCode::OK);
// Key content should be the same (though cache headers may differ)
// Note: In a real implementation, you might want to test deterministic key generation
assert_eq!(response1.body.len(), response2.body.len());
}
#[tokio::test]
async fn test_config_with_deterministic_seed() {
let config = KeyManagerConfig::default();
let seed = vec![0x42u8; 32]; // Fixed seed for deterministic keys
let manager1 = KeyManager::new_with_seed(config.clone(), seed.clone())
.await
.unwrap();
let manager2 = KeyManager::new_with_seed(config, seed).await.unwrap();
let response1 = mock_config_handler(&manager1).await.unwrap();
let response2 = mock_config_handler(&manager2).await.unwrap();
// Both should succeed
assert_eq!(response1.status, StatusCode::OK);
assert_eq!(response2.status, StatusCode::OK);
// With the same seed, the key configurations should be identical
// This now works because we're using KeyConfig::derive() for deterministic generation
assert_eq!(response1.body, response2.body);
// Also verify the bodies are not empty and have valid structure
assert!(!response1.body.is_empty());
assert!(response1.body.len() >= 4);
}
#[tokio::test]
async fn test_cache_control_randomization() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let mut max_ages = std::collections::HashSet::new();
// Generate multiple responses and collect max-age values
for _ in 0..10 {
let response = mock_config_handler(&manager).await.unwrap();
let cache_control = response.get_header("Cache-Control").unwrap();
// Extract max-age value
let max_age_str = cache_control
.strip_prefix("max-age=")
.and_then(|s| s.strip_suffix(", private"))
.unwrap();
let max_age: u32 = max_age_str.parse().unwrap();
max_ages.insert(max_age);
}
// Should have some variation in max-age values (randomization)
// Note: This test might occasionally fail due to randomness, but should usually pass
assert!(
max_ages.len() > 1,
"Cache-Control max-age should be randomized"
);
}

199
tests/integration_tests.rs Normal file
View file

@ -0,0 +1,199 @@
use std::time::Duration;
use ohttp_gateway::{key_manager::KeyManager, key_manager::KeyManagerConfig};
mod common;
use common::*;
#[tokio::test]
async fn test_end_to_end_encryption_decryption() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
// Get the server for decryption
let _ = manager.get_current_server().await.unwrap();
// Get the key config for client encryption
let encoded_config = manager.get_encoded_config().await.unwrap();
// Parse the config (this would normally be done by a real OHTTP client)
// For now, create a client with the current key config
let stats = manager.get_stats().await;
let _ = manager.get_server_by_id(stats.active_key_id).await.unwrap();
// Test message
let test_message = create_test_binary_http_message();
// This test verifies that encryption/decryption round trip works
// In a real implementation, you'd use the ohttp client/server APIs
// For now, just verify we can get the components we need
assert!(!encoded_config.is_empty());
assert!(!test_message.is_empty());
}
#[tokio::test]
async fn test_key_rotation_during_requests() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_millis(100),
key_retention_period: Duration::from_millis(200),
auto_rotation_enabled: false, // Manual control
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
let initial_stats = manager.get_stats().await;
// Get server for old key
let old_server = manager.get_server_by_id(initial_stats.active_key_id).await;
assert!(old_server.is_some());
// Rotate keys
manager.rotate_keys().await.unwrap();
let new_stats = manager.get_stats().await;
// Old key should still be available for decryption
let old_server_after_rotation = manager.get_server_by_id(initial_stats.active_key_id).await;
assert!(old_server_after_rotation.is_some());
// New key should also be available
let new_server = manager.get_server_by_id(new_stats.active_key_id).await;
assert!(new_server.is_some());
// Active key should have changed
assert_ne!(initial_stats.active_key_id, new_stats.active_key_id);
assert_eq!(new_stats.total_keys, 2);
}
#[tokio::test]
async fn test_invalid_key_id_handling() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
let invalid_key_id = stats.active_key_id.wrapping_add(100);
// Should return None for invalid key ID
let server = manager.get_server_by_id(invalid_key_id).await;
assert!(server.is_none());
}
#[tokio::test]
async fn test_concurrent_key_operations() {
let config = KeyManagerConfig::default();
let manager = std::sync::Arc::new(KeyManager::new(config).await.unwrap());
let mut handles = vec![];
// Spawn multiple tasks that access keys concurrently
for i in 0..10 {
let manager_clone = manager.clone();
let handle = tokio::spawn(async move {
if i % 2 == 0 {
// Half the tasks get the current server
let _server = manager_clone.get_current_server().await.unwrap();
} else {
// Half get stats
let _stats = manager_clone.get_stats().await;
}
});
handles.push(handle);
}
// Wait for all tasks to complete
for handle in handles {
handle.await.unwrap();
}
// Manager should still be functional
let final_stats = manager.get_stats().await;
assert_eq!(final_stats.total_keys, 1);
}
#[tokio::test]
async fn test_automatic_rotation_scheduler() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_millis(100),
key_retention_period: Duration::from_millis(200),
auto_rotation_enabled: true,
..Default::default()
};
let manager = std::sync::Arc::new(KeyManager::new(config).await.unwrap());
let initial_stats = manager.get_stats().await;
// Start the rotation scheduler
let manager_clone = manager.clone();
manager_clone.start_rotation_scheduler().await;
// Wait for automatic rotation to occur
tokio::time::sleep(Duration::from_millis(300)).await;
let final_stats = manager.get_stats().await;
// Key should have rotated automatically
// Note: This test might be flaky depending on timing
assert!(final_stats.active_key_id != initial_stats.active_key_id || final_stats.total_keys > 1);
}
#[tokio::test]
async fn test_metrics_tracking() {
let factory = MockMetricsFactory::new();
// Simulate various operations and metric collection
let metrics = factory.create("test_event".to_string()).await;
metrics.fire("operation_success").await;
metrics.response_status("test", 200).await;
assert!(metrics.contains_result("operation_success").await);
assert!(metrics.contains_result("test_response_status_200").await);
// Test the helper function
assert_metrics_contains_result(&factory, "test_event", "operation_success")
.await
.unwrap();
}
#[tokio::test]
async fn test_config_serialization_format() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let encoded_config = manager.get_encoded_config().await.unwrap();
// Verify basic structure: length prefix + config data
assert!(encoded_config.len() >= 4);
let length = u16::from_be_bytes([encoded_config[0], encoded_config[1]]);
assert_eq!(length as usize, encoded_config.len() - 2);
// Verify it contains expected OHTTP key configuration elements
// The exact format would depend on your implementation
let config_data = &encoded_config[2..];
assert!(!config_data.is_empty());
}
#[tokio::test]
async fn test_error_conditions() {
// Test various error conditions
// Invalid seed length
let config = KeyManagerConfig::default();
let short_seed = vec![0u8; 16];
let result = KeyManager::new_with_seed(config.clone(), short_seed).await;
assert!(result.is_err());
// Test with empty cipher suites (if your implementation supports this validation)
let invalid_config = KeyManagerConfig {
cipher_suites: vec![], // Empty cipher suites
..Default::default()
};
// Should return an error for empty cipher suites
let result = KeyManager::new(invalid_config).await;
assert!(
result.is_err(),
"KeyManager should reject empty cipher suites"
);
}

172
tests/key_manager_tests.rs Normal file
View file

@ -0,0 +1,172 @@
use std::time::Duration;
use tokio;
// Your key manager module - adjust the import path as needed
use ohttp_gateway::key_manager::{CipherSuiteConfig, KeyManager, KeyManagerConfig};
#[tokio::test]
async fn test_key_generation() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_keys, 1);
assert_eq!(stats.active_keys, 1);
assert!(stats.active_key_id > 0); // Should have generated a key with ID > 0
}
#[tokio::test]
async fn test_key_generation_with_seed() {
let config = KeyManagerConfig::default();
let seed = vec![0u8; 32]; // 32 bytes of zeros for deterministic testing
let manager = KeyManager::new_with_seed(config, seed).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_keys, 1);
assert_eq!(stats.active_keys, 1);
}
#[tokio::test]
async fn test_key_generation_with_insufficient_seed() {
let config = KeyManagerConfig::default();
let short_seed = vec![0u8; 16]; // Only 16 bytes - should fail
let result = KeyManager::new_with_seed(config, short_seed).await;
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("Seed must be at least 32 bytes"));
}
}
#[tokio::test]
async fn test_key_rotation() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_secs(60),
key_retention_period: Duration::from_secs(30),
auto_rotation_enabled: true,
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
let initial_stats = manager.get_stats().await;
// Rotate keys
manager.rotate_keys().await.unwrap();
let new_stats = manager.get_stats().await;
assert_eq!(new_stats.total_keys, 2); // Old key + new key
assert_ne!(new_stats.active_key_id, initial_stats.active_key_id);
}
#[tokio::test]
async fn test_get_current_server() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let server = manager.get_current_server().await;
assert!(server.is_ok());
}
#[tokio::test]
async fn test_get_server_by_id() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
let active_id = stats.active_key_id;
// Should find the active key
let server = manager.get_server_by_id(active_id).await;
assert!(server.is_some());
// Should not find non-existent key
let non_existent = manager.get_server_by_id(active_id.wrapping_add(100)).await;
assert!(non_existent.is_none());
}
#[tokio::test]
async fn test_should_rotate() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_millis(100), // Very short for testing
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
// Should not need rotation immediately
assert!(!manager.should_rotate().await);
// Wait for the rotation interval to pass
tokio::time::sleep(Duration::from_millis(200)).await;
// Now should need rotation
assert!(manager.should_rotate().await);
}
#[tokio::test]
async fn test_get_encoded_config() {
let config = KeyManagerConfig::default();
let manager = KeyManager::new(config).await.unwrap();
let encoded_config = manager.get_encoded_config().await.unwrap();
// Should have at least 4 bytes (2 bytes length + some config data)
assert!(encoded_config.len() >= 4);
// First 2 bytes should be length in big endian
let length = u16::from_be_bytes([encoded_config[0], encoded_config[1]]);
assert_eq!(length as usize, encoded_config.len() - 2);
}
#[tokio::test]
async fn test_multiple_cipher_suites() {
let config = KeyManagerConfig {
cipher_suites: vec![
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "AES_128_GCM".to_string(),
},
CipherSuiteConfig {
kem: "X25519_SHA256".to_string(),
kdf: "HKDF_SHA256".to_string(),
aead: "CHACHA20_POLY1305".to_string(),
},
],
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
let stats = manager.get_stats().await;
assert_eq!(stats.total_keys, 1);
}
#[tokio::test]
async fn test_cleanup_expired_keys() {
let config = KeyManagerConfig {
rotation_interval: Duration::from_millis(50),
key_retention_period: Duration::from_millis(100),
auto_rotation_enabled: false, // Manual control for testing
..Default::default()
};
let manager = KeyManager::new(config).await.unwrap();
// Rotate to create an old key
manager.rotate_keys().await.unwrap();
let stats_after_rotation = manager.get_stats().await;
assert_eq!(stats_after_rotation.total_keys, 2);
// Wait for keys to expire and manually trigger cleanup
tokio::time::sleep(Duration::from_millis(200)).await;
// Trigger another rotation which should clean up expired keys
manager.rotate_keys().await.unwrap();
let final_stats = manager.get_stats().await;
// Should have cleaned up the expired key
assert!(final_stats.total_keys <= 2);
}