379 lines
12 KiB
Rust
379 lines
12 KiB
Rust
use chrono::{DateTime, Utc};
|
|
use ohttp::{
|
|
KeyConfig, Server as OhttpServer, SymmetricSuite,
|
|
hpke::{Aead, Kdf, Kem},
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use tokio::sync::RwLock;
|
|
use tracing::{error, info};
|
|
|
|
/// Represents a key with its metadata
|
|
#[derive(Clone, Debug)]
|
|
pub struct KeyInfo {
|
|
pub id: u8,
|
|
pub config: KeyConfig,
|
|
pub server: OhttpServer,
|
|
pub expires_at: DateTime<Utc>,
|
|
pub is_active: bool,
|
|
}
|
|
|
|
/// Configuration for key management
|
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
pub struct KeyManagerConfig {
|
|
/// How often to rotate keys (default: 30 days)
|
|
pub rotation_interval: Duration,
|
|
/// How long to keep old keys for decryption (default: 7 days)
|
|
pub key_retention_period: Duration,
|
|
/// Whether to enable automatic rotation
|
|
pub auto_rotation_enabled: bool,
|
|
/// Supported cipher suites
|
|
pub cipher_suites: Vec<CipherSuiteConfig>,
|
|
}
|
|
|
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
pub struct CipherSuiteConfig {
|
|
pub kem: String,
|
|
pub kdf: String,
|
|
pub aead: String,
|
|
}
|
|
|
|
impl Default for KeyManagerConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
|
|
key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
|
|
auto_rotation_enabled: true,
|
|
cipher_suites: vec![
|
|
CipherSuiteConfig {
|
|
kem: "X25519_SHA256".to_string(),
|
|
kdf: "HKDF_SHA256".to_string(),
|
|
aead: "AES_128_GCM".to_string(),
|
|
},
|
|
CipherSuiteConfig {
|
|
kem: "X25519_SHA256".to_string(),
|
|
kdf: "HKDF_SHA256".to_string(),
|
|
aead: "CHACHA20_POLY1305".to_string(),
|
|
},
|
|
],
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct KeyManager {
|
|
/// All keys indexed by ID
|
|
keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
|
|
/// Current active key ID
|
|
active_key_id: Arc<RwLock<u8>>,
|
|
/// Configuration
|
|
config: KeyManagerConfig,
|
|
/// Key ID counter (wraps around after 255)
|
|
next_key_id: Arc<RwLock<u8>>,
|
|
/// Seed for deterministic key generation (optional)
|
|
seed: Option<Vec<u8>>,
|
|
}
|
|
|
|
impl KeyManager {
|
|
pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
|
let manager = Self {
|
|
keys: Arc::new(RwLock::new(HashMap::new())),
|
|
active_key_id: Arc::new(RwLock::new(0)),
|
|
config,
|
|
next_key_id: Arc::new(RwLock::new(1)),
|
|
seed: None,
|
|
};
|
|
|
|
// Generate initial key
|
|
let initial_key = manager.generate_new_key().await?;
|
|
{
|
|
let mut keys = manager.keys.write().await;
|
|
let mut active_id = manager.active_key_id.write().await;
|
|
|
|
keys.insert(initial_key.id, initial_key.clone());
|
|
*active_id = initial_key.id;
|
|
}
|
|
|
|
info!("KeyManager initialized with key ID: {}", initial_key.id);
|
|
Ok(manager)
|
|
}
|
|
|
|
/// Create a key manager with a seed for deterministic key generation
|
|
pub async fn new_with_seed(
|
|
config: KeyManagerConfig,
|
|
seed: Vec<u8>,
|
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
|
if seed.len() < 32 {
|
|
return Err("Seed must be at least 32 bytes".into());
|
|
}
|
|
|
|
let mut manager = Self::new(config).await?;
|
|
manager.seed = Some(seed);
|
|
Ok(manager)
|
|
}
|
|
|
|
/// Generate a new key configuration
|
|
async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
|
|
let key_id = {
|
|
let mut next_id = self.next_key_id.write().await;
|
|
let id = *next_id;
|
|
*next_id = next_id.wrapping_add(1);
|
|
id
|
|
};
|
|
|
|
// Parse cipher suites from config
|
|
let mut symmetric_suites = Vec::new();
|
|
for suite in &self.config.cipher_suites {
|
|
let kdf = match suite.kdf.as_str() {
|
|
"HKDF_SHA256" => Kdf::HkdfSha256,
|
|
"HKDF_SHA384" => Kdf::HkdfSha384,
|
|
"HKDF_SHA512" => Kdf::HkdfSha512,
|
|
_ => Kdf::HkdfSha256,
|
|
};
|
|
|
|
let aead = match suite.aead.as_str() {
|
|
"AES_128_GCM" => Aead::Aes128Gcm,
|
|
"AES_256_GCM" => Aead::Aes256Gcm,
|
|
"CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
|
|
_ => Aead::Aes128Gcm,
|
|
};
|
|
|
|
symmetric_suites.push(SymmetricSuite::new(kdf, aead));
|
|
}
|
|
|
|
// Determine KEM based on config - only X25519 is supported by ohttp crate
|
|
let kem = Kem::X25519Sha256;
|
|
|
|
// Generate key config
|
|
let key_config = if let Some(seed) = &self.seed {
|
|
// Deterministic generation using seed + key_id
|
|
let mut key_seed = seed.clone();
|
|
key_seed.push(key_id);
|
|
|
|
// The ohttp crate doesn't directly support deterministic key generation
|
|
// This would require extending the crate or using a custom implementation
|
|
KeyConfig::new(key_id, kem, symmetric_suites)?
|
|
} else {
|
|
KeyConfig::new(key_id, kem, symmetric_suites)?
|
|
};
|
|
|
|
let server = OhttpServer::new(key_config.clone())?;
|
|
let now = Utc::now();
|
|
|
|
Ok(KeyInfo {
|
|
id: key_id,
|
|
config: key_config,
|
|
server,
|
|
expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
|
|
is_active: true,
|
|
})
|
|
}
|
|
|
|
/// Get the current active server for decryption
|
|
pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
|
|
let keys = self.keys.read().await;
|
|
let active_id = self.active_key_id.read().await;
|
|
|
|
keys.get(&*active_id)
|
|
.map(|info| info.server.clone())
|
|
.ok_or_else(|| "No active key found".into())
|
|
}
|
|
|
|
/// Get a server by key ID (for handling requests with specific key IDs)
|
|
pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
|
|
let keys = self.keys.read().await;
|
|
keys.get(&key_id).map(|info| info.server.clone())
|
|
}
|
|
|
|
/// Get encoded config for backward compatibility
|
|
pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
|
let keys = self.keys.read().await;
|
|
let active_id = self.active_key_id.read().await;
|
|
let cfg_bytes = keys
|
|
.get(&*active_id)
|
|
.ok_or("no active key")?
|
|
.config
|
|
.encode()?;
|
|
|
|
let mut out = Vec::with_capacity(cfg_bytes.len() + 2);
|
|
out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes()); // 2-byte length
|
|
out.extend_from_slice(&cfg_bytes);
|
|
Ok(out)
|
|
}
|
|
|
|
/// Rotate keys by generating a new key and marking old ones for expiration
|
|
pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
|
|
info!("Starting key rotation");
|
|
|
|
// Generate new key
|
|
let new_key = self.generate_new_key().await?;
|
|
let new_key_id = new_key.id;
|
|
|
|
// Update key store
|
|
{
|
|
let mut keys = self.keys.write().await;
|
|
let mut active_id = self.active_key_id.write().await;
|
|
let now = Utc::now();
|
|
|
|
// Mark current active key for future expiration
|
|
if let Some(current_key) = keys.get_mut(&*active_id) {
|
|
current_key.is_active = false;
|
|
// Keep it around for the retention period
|
|
current_key.expires_at =
|
|
now + chrono::Duration::from_std(self.config.key_retention_period)?;
|
|
}
|
|
|
|
// Add new key
|
|
keys.insert(new_key_id, new_key);
|
|
|
|
// Update active key ID
|
|
*active_id = new_key_id;
|
|
|
|
// Clean up expired keys
|
|
keys.retain(|_, info| info.expires_at > now);
|
|
|
|
info!(
|
|
"Key rotation completed. New active key ID: {}, total keys: {}",
|
|
new_key_id,
|
|
keys.len()
|
|
);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Check if rotation is needed
|
|
pub async fn should_rotate(&self) -> bool {
|
|
let keys = self.keys.read().await;
|
|
let active_id = self.active_key_id.read().await;
|
|
|
|
if let Some(active_key) = keys.get(&*active_id) {
|
|
let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
|
|
|
|
// Rotate if less than 10% of the rotation interval remains
|
|
let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
|
|
.unwrap_or_else(|_| chrono::Duration::days(3));
|
|
|
|
time_until_expiry < threshold
|
|
} else {
|
|
true // No active key, definitely need to rotate
|
|
}
|
|
}
|
|
|
|
/// Start automatic key rotation scheduler
|
|
pub async fn start_rotation_scheduler(self: Arc<Self>) {
|
|
if !self.config.auto_rotation_enabled {
|
|
info!("Automatic key rotation is disabled");
|
|
return;
|
|
}
|
|
|
|
let manager = self;
|
|
tokio::spawn(async move {
|
|
// Check every hour
|
|
let mut interval = tokio::time::interval(Duration::from_secs(3600));
|
|
|
|
loop {
|
|
interval.tick().await;
|
|
|
|
if manager.should_rotate().await {
|
|
if let Err(e) = manager.rotate_keys().await {
|
|
error!("Key rotation failed: {}", e);
|
|
}
|
|
}
|
|
|
|
// Also clean up expired keys
|
|
manager.cleanup_expired_keys().await;
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Clean up expired keys
|
|
async fn cleanup_expired_keys(&self) {
|
|
let mut keys = self.keys.write().await;
|
|
let now = Utc::now();
|
|
let before_count = keys.len();
|
|
|
|
keys.retain(|id, info| {
|
|
if info.expires_at <= now {
|
|
info!("Removing expired key ID: {}", id);
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
});
|
|
|
|
let removed = before_count - keys.len();
|
|
if removed > 0 {
|
|
info!("Cleaned up {} expired keys", removed);
|
|
}
|
|
}
|
|
|
|
/// Get key manager statistics
|
|
pub async fn get_stats(&self) -> KeyManagerStats {
|
|
let keys = self.keys.read().await;
|
|
let active_id = self.active_key_id.read().await;
|
|
let now = Utc::now();
|
|
|
|
let active_keys = keys.values().filter(|k| k.is_active).count();
|
|
let total_keys = keys.len();
|
|
let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
|
|
|
|
KeyManagerStats {
|
|
active_key_id: *active_id,
|
|
total_keys,
|
|
active_keys,
|
|
expired_keys,
|
|
rotation_interval: self.config.rotation_interval,
|
|
auto_rotation_enabled: self.config.auto_rotation_enabled,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct KeyManagerStats {
|
|
pub active_key_id: u8,
|
|
pub total_keys: usize,
|
|
pub active_keys: usize,
|
|
pub expired_keys: usize,
|
|
pub rotation_interval: Duration,
|
|
pub auto_rotation_enabled: bool,
|
|
}
|
|
|
|
// Ensure thread safety
|
|
unsafe impl Send for KeyManager {}
|
|
unsafe impl Sync for KeyManager {}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[tokio::test]
|
|
async fn test_key_generation() {
|
|
let config = KeyManagerConfig::default();
|
|
let manager = KeyManager::new(config).await.unwrap();
|
|
|
|
let stats = manager.get_stats().await;
|
|
assert_eq!(stats.total_keys, 1);
|
|
assert_eq!(stats.active_keys, 1);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_key_rotation() {
|
|
let config = KeyManagerConfig {
|
|
rotation_interval: Duration::from_secs(60),
|
|
key_retention_period: Duration::from_secs(30),
|
|
..Default::default()
|
|
};
|
|
|
|
let manager = KeyManager::new(config).await.unwrap();
|
|
let initial_stats = manager.get_stats().await;
|
|
|
|
// Rotate keys
|
|
manager.rotate_keys().await.unwrap();
|
|
|
|
let new_stats = manager.get_stats().await;
|
|
assert_eq!(new_stats.total_keys, 2);
|
|
assert_ne!(new_stats.active_key_id, initial_stats.active_key_id);
|
|
}
|
|
}
|