use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::time::Duration; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct AppConfig { // Server configuration pub port: String, pub backend_url: String, pub request_timeout: Duration, pub max_body_size: usize, // Key management pub key_rotation_interval: Duration, pub key_retention_period: Duration, pub key_rotation_enabled: bool, // Security configuration pub allowed_target_origins: Option>, pub target_rewrites: Option, pub rate_limit: Option, // Operational configuration pub metrics_enabled: bool, pub debug_mode: bool, pub log_format: LogFormat, pub log_level: String, // OHTTP specific pub custom_request_type: Option, pub custom_response_type: Option, pub seed_secret_key: Option, } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct TargetRewriteConfig { pub rewrites: std::collections::HashMap, } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct TargetRewrite { pub scheme: String, pub host: String, } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct RateLimitConfig { pub requests_per_second: u32, pub burst_size: u32, pub by_ip: bool, } #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] pub enum LogFormat { Default, Json, } impl Default for AppConfig { fn default() -> Self { Self { port: "0.0.0.0:8000".to_string(), backend_url: "http://localhost:8080".to_string(), request_timeout: Duration::from_secs(30), max_body_size: 10 * 1024 * 1024, // 10MB key_rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days key_rotation_enabled: true, allowed_target_origins: None, target_rewrites: None, rate_limit: None, metrics_enabled: true, debug_mode: false, log_format: LogFormat::Default, log_level: "info".to_string(), custom_request_type: None, custom_response_type: None, seed_secret_key: None, } } } impl AppConfig { pub fn from_env() -> Result> { let mut config = Self::default(); // Basic configuration if let Ok(port) = std::env::var("PORT") { config.port = format!("0.0.0.0:{port}"); } if let Ok(url) = std::env::var("BACKEND_URL") { config.backend_url = url; } if let Ok(timeout) = std::env::var("REQUEST_TIMEOUT") { config.request_timeout = Duration::from_secs(timeout.parse()?); } if let Ok(size) = std::env::var("MAX_BODY_SIZE") { config.max_body_size = size.parse()?; } // Key management if let Ok(interval) = std::env::var("KEY_ROTATION_INTERVAL") { config.key_rotation_interval = Duration::from_secs(interval.parse()?); } if let Ok(period) = std::env::var("KEY_RETENTION_PERIOD") { config.key_retention_period = Duration::from_secs(period.parse()?); } if let Ok(enabled) = std::env::var("KEY_ROTATION_ENABLED") { config.key_rotation_enabled = enabled.parse()?; } // Security configuration if let Ok(origins) = std::env::var("ALLOWED_TARGET_ORIGINS") { let origins_set: HashSet = origins .split(',') .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect(); if !origins_set.is_empty() { config.allowed_target_origins = Some(origins_set); } } if let Ok(rewrites_json) = std::env::var("TARGET_REWRITES") { let rewrites: std::collections::HashMap = serde_json::from_str(&rewrites_json)?; config.target_rewrites = Some(TargetRewriteConfig { rewrites }); } // Rate limiting if let Ok(rps) = std::env::var("RATE_LIMIT_RPS") { let rate_limit = RateLimitConfig { requests_per_second: rps.parse()?, burst_size: std::env::var("RATE_LIMIT_BURST") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(100), by_ip: std::env::var("RATE_LIMIT_BY_IP") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(true), }; config.rate_limit = Some(rate_limit); } // Operational configuration if let Ok(enabled) = std::env::var("METRICS_ENABLED") { config.metrics_enabled = enabled.parse()?; } if let Ok(debug) = std::env::var("GATEWAY_DEBUG") { config.debug_mode = debug.parse()?; } if let Ok(format) = std::env::var("LOG_FORMAT") { config.log_format = match format.to_lowercase().as_str() { "json" => LogFormat::Json, _ => LogFormat::Default, }; } if let Ok(level) = std::env::var("LOG_LEVEL") { config.log_level = level; } // OHTTP specific if let Ok(req_type) = std::env::var("CUSTOM_REQUEST_TYPE") { config.custom_request_type = Some(req_type); } if let Ok(resp_type) = std::env::var("CUSTOM_RESPONSE_TYPE") { config.custom_response_type = Some(resp_type); } if let Ok(seed) = std::env::var("SEED_SECRET_KEY") { config.seed_secret_key = Some(seed); } // Validate configuration config.validate()?; Ok(config) } fn validate(&self) -> Result<(), Box> { // Validate key rotation settings if self.key_retention_period > self.key_rotation_interval { return Err("Key retention period cannot be longer than rotation interval".into()); } // Validate custom content types match (&self.custom_request_type, &self.custom_response_type) { (Some(req), Some(resp)) if req == resp => { return Err("Request and response content types must be different".into()); } (Some(_), None) | (None, Some(_)) => { return Err("Both custom request and response types must be specified".into()); } _ => {} } // Validate seed if provided if let Some(seed) = &self.seed_secret_key { let decoded = hex::decode(seed).map_err(|_| "SEED_SECRET_KEY must be a hex-encoded string")?; if decoded.len() < 32 { return Err("SEED_SECRET_KEY must be at least 32 bytes (64 hex characters)".into()); } } Ok(()) } /// Check if a target origin is allowed pub fn is_origin_allowed(&self, origin: &str) -> bool { match &self.allowed_target_origins { Some(allowed) => allowed.contains(origin), None => true, // No restrictions if not configured } } /// Get rewrite configuration for a host pub fn get_rewrite(&self, host: &str) -> Option<&TargetRewrite> { self.target_rewrites .as_ref() .and_then(|config| config.rewrites.get(host)) } }