Initial commit
This commit is contained in:
commit
7e4429e9d3
21 changed files with 6303 additions and 0 deletions
21
.gitignore
vendored
Normal file
21
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Generated by Cargo
|
||||||
|
# will have compiled files and executables
|
||||||
|
debug
|
||||||
|
target
|
||||||
|
|
||||||
|
# These are backup files generated by rustfmt
|
||||||
|
**/*.rs.bk
|
||||||
|
|
||||||
|
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||||
|
*.pdb
|
||||||
|
|
||||||
|
# Generated by cargo mutants
|
||||||
|
# Contains mutation testing data
|
||||||
|
**/mutants.out*/
|
||||||
|
|
||||||
|
# RustRover
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
.idea/
|
||||||
3776
Cargo.lock
generated
Normal file
3776
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
65
Cargo.toml
Normal file
65
Cargo.toml
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
[package]
|
||||||
|
name = "ohttp-gateway"
|
||||||
|
authors = ["Bastian Gruber<foreach@me.com>"]
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# Web framework and async runtime
|
||||||
|
axum = { version = "0.7", features = ["macros"] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
hyper = { version = "1", features = ["full"] }
|
||||||
|
hyper-util = { version = "0.1", features = ["full"] }
|
||||||
|
|
||||||
|
# HTTP client for backend requests
|
||||||
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||||
|
|
||||||
|
# OHTTP implementation - Using the martinthomson/ohttp crate
|
||||||
|
ohttp = { version = "0.5", features = ["rust-hpke"] }
|
||||||
|
bhttp = "0.5"
|
||||||
|
|
||||||
|
# Middleware and utilities
|
||||||
|
tower = "0.4"
|
||||||
|
tower-http = { version = "0.6", features = [
|
||||||
|
"cors",
|
||||||
|
"trace",
|
||||||
|
"compression-br",
|
||||||
|
"timeout",
|
||||||
|
] }
|
||||||
|
|
||||||
|
# Serialization and configuration
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
config = "0.14"
|
||||||
|
|
||||||
|
# Logging and observability
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||||
|
chrono = "0.4"
|
||||||
|
|
||||||
|
# Error handling
|
||||||
|
thiserror = "1.0"
|
||||||
|
anyhow = "1.0"
|
||||||
|
|
||||||
|
# Metrics and monitoring
|
||||||
|
prometheus = "0.13"
|
||||||
|
axum-prometheus = "0.7"
|
||||||
|
|
||||||
|
# Security and validation
|
||||||
|
validator = { version = "0.18", features = ["derive"] }
|
||||||
|
jsonwebtoken = "9.0"
|
||||||
|
uuid = { version = "1.0", features = ["v4"] }
|
||||||
|
|
||||||
|
# Async utilities
|
||||||
|
tokio-util = "0.7"
|
||||||
|
futures = "0.3"
|
||||||
|
|
||||||
|
# Random number generation
|
||||||
|
hex = "0.4"
|
||||||
|
rand = "0.8"
|
||||||
|
|
||||||
|
# Configuration management
|
||||||
|
clap = { version = "4.0", features = ["derive", "env"] }
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
lto = "fat"
|
||||||
54
Dockerfile
Normal file
54
Dockerfile
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Build stage
|
||||||
|
FROM rust:1.88-slim as builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
pkg-config \
|
||||||
|
libssl-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy Cargo files
|
||||||
|
COPY Cargo.toml ./
|
||||||
|
|
||||||
|
# Create dummy main to cache dependencies
|
||||||
|
RUN mkdir src && echo "fn main() {}" > src/main.rs
|
||||||
|
|
||||||
|
# Build dependencies
|
||||||
|
RUN RUSTFLAGS="-C target-cpu=native" cargo build --release
|
||||||
|
RUN rm -rf src
|
||||||
|
|
||||||
|
# Copy source code
|
||||||
|
COPY src ./src
|
||||||
|
|
||||||
|
# Build the actual application
|
||||||
|
RUN touch src/main.rs && RUSTFLAGS="-C target-cpu=native" cargo build --release
|
||||||
|
|
||||||
|
# Runtime stage
|
||||||
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
ca-certificates \
|
||||||
|
libssl3 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy the binary from builder
|
||||||
|
COPY --from=builder /app/target/release/ohttp-gateway /usr/local/bin/ohttp-gateway
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN useradd -m -u 1001 ohttp
|
||||||
|
USER ohttp
|
||||||
|
|
||||||
|
# Set default environment variables
|
||||||
|
ENV RUST_LOG=debug,ohttp_gateway=debug
|
||||||
|
ENV LISTEN_ADDR=0.0.0.0:8080
|
||||||
|
ENV BACKEND_URL=http://localhost:8000
|
||||||
|
ENV REQUEST_TIMEOUT=30
|
||||||
|
ENV KEY_ROTATION_ENABLED=false
|
||||||
|
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
CMD ["ohttp-gateway"]
|
||||||
373
LICENSE
Normal file
373
LICENSE
Normal file
|
|
@ -0,0 +1,373 @@
|
||||||
|
Mozilla Public License Version 2.0
|
||||||
|
==================================
|
||||||
|
|
||||||
|
1. Definitions
|
||||||
|
--------------
|
||||||
|
|
||||||
|
1.1. "Contributor"
|
||||||
|
means each individual or legal entity that creates, contributes to
|
||||||
|
the creation of, or owns Covered Software.
|
||||||
|
|
||||||
|
1.2. "Contributor Version"
|
||||||
|
means the combination of the Contributions of others (if any) used
|
||||||
|
by a Contributor and that particular Contributor's Contribution.
|
||||||
|
|
||||||
|
1.3. "Contribution"
|
||||||
|
means Covered Software of a particular Contributor.
|
||||||
|
|
||||||
|
1.4. "Covered Software"
|
||||||
|
means Source Code Form to which the initial Contributor has attached
|
||||||
|
the notice in Exhibit A, the Executable Form of such Source Code
|
||||||
|
Form, and Modifications of such Source Code Form, in each case
|
||||||
|
including portions thereof.
|
||||||
|
|
||||||
|
1.5. "Incompatible With Secondary Licenses"
|
||||||
|
means
|
||||||
|
|
||||||
|
(a) that the initial Contributor has attached the notice described
|
||||||
|
in Exhibit B to the Covered Software; or
|
||||||
|
|
||||||
|
(b) that the Covered Software was made available under the terms of
|
||||||
|
version 1.1 or earlier of the License, but not also under the
|
||||||
|
terms of a Secondary License.
|
||||||
|
|
||||||
|
1.6. "Executable Form"
|
||||||
|
means any form of the work other than Source Code Form.
|
||||||
|
|
||||||
|
1.7. "Larger Work"
|
||||||
|
means a work that combines Covered Software with other material, in
|
||||||
|
a separate file or files, that is not Covered Software.
|
||||||
|
|
||||||
|
1.8. "License"
|
||||||
|
means this document.
|
||||||
|
|
||||||
|
1.9. "Licensable"
|
||||||
|
means having the right to grant, to the maximum extent possible,
|
||||||
|
whether at the time of the initial grant or subsequently, any and
|
||||||
|
all of the rights conveyed by this License.
|
||||||
|
|
||||||
|
1.10. "Modifications"
|
||||||
|
means any of the following:
|
||||||
|
|
||||||
|
(a) any file in Source Code Form that results from an addition to,
|
||||||
|
deletion from, or modification of the contents of Covered
|
||||||
|
Software; or
|
||||||
|
|
||||||
|
(b) any new file in Source Code Form that contains any Covered
|
||||||
|
Software.
|
||||||
|
|
||||||
|
1.11. "Patent Claims" of a Contributor
|
||||||
|
means any patent claim(s), including without limitation, method,
|
||||||
|
process, and apparatus claims, in any patent Licensable by such
|
||||||
|
Contributor that would be infringed, but for the grant of the
|
||||||
|
License, by the making, using, selling, offering for sale, having
|
||||||
|
made, import, or transfer of either its Contributions or its
|
||||||
|
Contributor Version.
|
||||||
|
|
||||||
|
1.12. "Secondary License"
|
||||||
|
means either the GNU General Public License, Version 2.0, the GNU
|
||||||
|
Lesser General Public License, Version 2.1, the GNU Affero General
|
||||||
|
Public License, Version 3.0, or any later versions of those
|
||||||
|
licenses.
|
||||||
|
|
||||||
|
1.13. "Source Code Form"
|
||||||
|
means the form of the work preferred for making modifications.
|
||||||
|
|
||||||
|
1.14. "You" (or "Your")
|
||||||
|
means an individual or a legal entity exercising rights under this
|
||||||
|
License. For legal entities, "You" includes any entity that
|
||||||
|
controls, is controlled by, or is under common control with You. For
|
||||||
|
purposes of this definition, "control" means (a) the power, direct
|
||||||
|
or indirect, to cause the direction or management of such entity,
|
||||||
|
whether by contract or otherwise, or (b) ownership of more than
|
||||||
|
fifty percent (50%) of the outstanding shares or beneficial
|
||||||
|
ownership of such entity.
|
||||||
|
|
||||||
|
2. License Grants and Conditions
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
2.1. Grants
|
||||||
|
|
||||||
|
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||||
|
non-exclusive license:
|
||||||
|
|
||||||
|
(a) under intellectual property rights (other than patent or trademark)
|
||||||
|
Licensable by such Contributor to use, reproduce, make available,
|
||||||
|
modify, display, perform, distribute, and otherwise exploit its
|
||||||
|
Contributions, either on an unmodified basis, with Modifications, or
|
||||||
|
as part of a Larger Work; and
|
||||||
|
|
||||||
|
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
||||||
|
for sale, have made, import, and otherwise transfer either its
|
||||||
|
Contributions or its Contributor Version.
|
||||||
|
|
||||||
|
2.2. Effective Date
|
||||||
|
|
||||||
|
The licenses granted in Section 2.1 with respect to any Contribution
|
||||||
|
become effective for each Contribution on the date the Contributor first
|
||||||
|
distributes such Contribution.
|
||||||
|
|
||||||
|
2.3. Limitations on Grant Scope
|
||||||
|
|
||||||
|
The licenses granted in this Section 2 are the only rights granted under
|
||||||
|
this License. No additional rights or licenses will be implied from the
|
||||||
|
distribution or licensing of Covered Software under this License.
|
||||||
|
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||||
|
Contributor:
|
||||||
|
|
||||||
|
(a) for any code that a Contributor has removed from Covered Software;
|
||||||
|
or
|
||||||
|
|
||||||
|
(b) for infringements caused by: (i) Your and any other third party's
|
||||||
|
modifications of Covered Software, or (ii) the combination of its
|
||||||
|
Contributions with other software (except as part of its Contributor
|
||||||
|
Version); or
|
||||||
|
|
||||||
|
(c) under Patent Claims infringed by Covered Software in the absence of
|
||||||
|
its Contributions.
|
||||||
|
|
||||||
|
This License does not grant any rights in the trademarks, service marks,
|
||||||
|
or logos of any Contributor (except as may be necessary to comply with
|
||||||
|
the notice requirements in Section 3.4).
|
||||||
|
|
||||||
|
2.4. Subsequent Licenses
|
||||||
|
|
||||||
|
No Contributor makes additional grants as a result of Your choice to
|
||||||
|
distribute the Covered Software under a subsequent version of this
|
||||||
|
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||||
|
permitted under the terms of Section 3.3).
|
||||||
|
|
||||||
|
2.5. Representation
|
||||||
|
|
||||||
|
Each Contributor represents that the Contributor believes its
|
||||||
|
Contributions are its original creation(s) or it has sufficient rights
|
||||||
|
to grant the rights to its Contributions conveyed by this License.
|
||||||
|
|
||||||
|
2.6. Fair Use
|
||||||
|
|
||||||
|
This License is not intended to limit any rights You have under
|
||||||
|
applicable copyright doctrines of fair use, fair dealing, or other
|
||||||
|
equivalents.
|
||||||
|
|
||||||
|
2.7. Conditions
|
||||||
|
|
||||||
|
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
||||||
|
in Section 2.1.
|
||||||
|
|
||||||
|
3. Responsibilities
|
||||||
|
-------------------
|
||||||
|
|
||||||
|
3.1. Distribution of Source Form
|
||||||
|
|
||||||
|
All distribution of Covered Software in Source Code Form, including any
|
||||||
|
Modifications that You create or to which You contribute, must be under
|
||||||
|
the terms of this License. You must inform recipients that the Source
|
||||||
|
Code Form of the Covered Software is governed by the terms of this
|
||||||
|
License, and how they can obtain a copy of this License. You may not
|
||||||
|
attempt to alter or restrict the recipients' rights in the Source Code
|
||||||
|
Form.
|
||||||
|
|
||||||
|
3.2. Distribution of Executable Form
|
||||||
|
|
||||||
|
If You distribute Covered Software in Executable Form then:
|
||||||
|
|
||||||
|
(a) such Covered Software must also be made available in Source Code
|
||||||
|
Form, as described in Section 3.1, and You must inform recipients of
|
||||||
|
the Executable Form how they can obtain a copy of such Source Code
|
||||||
|
Form by reasonable means in a timely manner, at a charge no more
|
||||||
|
than the cost of distribution to the recipient; and
|
||||||
|
|
||||||
|
(b) You may distribute such Executable Form under the terms of this
|
||||||
|
License, or sublicense it under different terms, provided that the
|
||||||
|
license for the Executable Form does not attempt to limit or alter
|
||||||
|
the recipients' rights in the Source Code Form under this License.
|
||||||
|
|
||||||
|
3.3. Distribution of a Larger Work
|
||||||
|
|
||||||
|
You may create and distribute a Larger Work under terms of Your choice,
|
||||||
|
provided that You also comply with the requirements of this License for
|
||||||
|
the Covered Software. If the Larger Work is a combination of Covered
|
||||||
|
Software with a work governed by one or more Secondary Licenses, and the
|
||||||
|
Covered Software is not Incompatible With Secondary Licenses, this
|
||||||
|
License permits You to additionally distribute such Covered Software
|
||||||
|
under the terms of such Secondary License(s), so that the recipient of
|
||||||
|
the Larger Work may, at their option, further distribute the Covered
|
||||||
|
Software under the terms of either this License or such Secondary
|
||||||
|
License(s).
|
||||||
|
|
||||||
|
3.4. Notices
|
||||||
|
|
||||||
|
You may not remove or alter the substance of any license notices
|
||||||
|
(including copyright notices, patent notices, disclaimers of warranty,
|
||||||
|
or limitations of liability) contained within the Source Code Form of
|
||||||
|
the Covered Software, except that You may alter any license notices to
|
||||||
|
the extent required to remedy known factual inaccuracies.
|
||||||
|
|
||||||
|
3.5. Application of Additional Terms
|
||||||
|
|
||||||
|
You may choose to offer, and to charge a fee for, warranty, support,
|
||||||
|
indemnity or liability obligations to one or more recipients of Covered
|
||||||
|
Software. However, You may do so only on Your own behalf, and not on
|
||||||
|
behalf of any Contributor. You must make it absolutely clear that any
|
||||||
|
such warranty, support, indemnity, or liability obligation is offered by
|
||||||
|
You alone, and You hereby agree to indemnify every Contributor for any
|
||||||
|
liability incurred by such Contributor as a result of warranty, support,
|
||||||
|
indemnity or liability terms You offer. You may include additional
|
||||||
|
disclaimers of warranty and limitations of liability specific to any
|
||||||
|
jurisdiction.
|
||||||
|
|
||||||
|
4. Inability to Comply Due to Statute or Regulation
|
||||||
|
---------------------------------------------------
|
||||||
|
|
||||||
|
If it is impossible for You to comply with any of the terms of this
|
||||||
|
License with respect to some or all of the Covered Software due to
|
||||||
|
statute, judicial order, or regulation then You must: (a) comply with
|
||||||
|
the terms of this License to the maximum extent possible; and (b)
|
||||||
|
describe the limitations and the code they affect. Such description must
|
||||||
|
be placed in a text file included with all distributions of the Covered
|
||||||
|
Software under this License. Except to the extent prohibited by statute
|
||||||
|
or regulation, such description must be sufficiently detailed for a
|
||||||
|
recipient of ordinary skill to be able to understand it.
|
||||||
|
|
||||||
|
5. Termination
|
||||||
|
--------------
|
||||||
|
|
||||||
|
5.1. The rights granted under this License will terminate automatically
|
||||||
|
if You fail to comply with any of its terms. However, if You become
|
||||||
|
compliant, then the rights granted under this License from a particular
|
||||||
|
Contributor are reinstated (a) provisionally, unless and until such
|
||||||
|
Contributor explicitly and finally terminates Your grants, and (b) on an
|
||||||
|
ongoing basis, if such Contributor fails to notify You of the
|
||||||
|
non-compliance by some reasonable means prior to 60 days after You have
|
||||||
|
come back into compliance. Moreover, Your grants from a particular
|
||||||
|
Contributor are reinstated on an ongoing basis if such Contributor
|
||||||
|
notifies You of the non-compliance by some reasonable means, this is the
|
||||||
|
first time You have received notice of non-compliance with this License
|
||||||
|
from such Contributor, and You become compliant prior to 30 days after
|
||||||
|
Your receipt of the notice.
|
||||||
|
|
||||||
|
5.2. If You initiate litigation against any entity by asserting a patent
|
||||||
|
infringement claim (excluding declaratory judgment actions,
|
||||||
|
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||||
|
directly or indirectly infringes any patent, then the rights granted to
|
||||||
|
You by any and all Contributors for the Covered Software under Section
|
||||||
|
2.1 of this License shall terminate.
|
||||||
|
|
||||||
|
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
||||||
|
end user license agreements (excluding distributors and resellers) which
|
||||||
|
have been validly granted by You or Your distributors under this License
|
||||||
|
prior to termination shall survive termination.
|
||||||
|
|
||||||
|
************************************************************************
|
||||||
|
* *
|
||||||
|
* 6. Disclaimer of Warranty *
|
||||||
|
* ------------------------- *
|
||||||
|
* *
|
||||||
|
* Covered Software is provided under this License on an "as is" *
|
||||||
|
* basis, without warranty of any kind, either expressed, implied, or *
|
||||||
|
* statutory, including, without limitation, warranties that the *
|
||||||
|
* Covered Software is free of defects, merchantable, fit for a *
|
||||||
|
* particular purpose or non-infringing. The entire risk as to the *
|
||||||
|
* quality and performance of the Covered Software is with You. *
|
||||||
|
* Should any Covered Software prove defective in any respect, You *
|
||||||
|
* (not any Contributor) assume the cost of any necessary servicing, *
|
||||||
|
* repair, or correction. This disclaimer of warranty constitutes an *
|
||||||
|
* essential part of this License. No use of any Covered Software is *
|
||||||
|
* authorized under this License except under this disclaimer. *
|
||||||
|
* *
|
||||||
|
************************************************************************
|
||||||
|
|
||||||
|
************************************************************************
|
||||||
|
* *
|
||||||
|
* 7. Limitation of Liability *
|
||||||
|
* -------------------------- *
|
||||||
|
* *
|
||||||
|
* Under no circumstances and under no legal theory, whether tort *
|
||||||
|
* (including negligence), contract, or otherwise, shall any *
|
||||||
|
* Contributor, or anyone who distributes Covered Software as *
|
||||||
|
* permitted above, be liable to You for any direct, indirect, *
|
||||||
|
* special, incidental, or consequential damages of any character *
|
||||||
|
* including, without limitation, damages for lost profits, loss of *
|
||||||
|
* goodwill, work stoppage, computer failure or malfunction, or any *
|
||||||
|
* and all other commercial damages or losses, even if such party *
|
||||||
|
* shall have been informed of the possibility of such damages. This *
|
||||||
|
* limitation of liability shall not apply to liability for death or *
|
||||||
|
* personal injury resulting from such party's negligence to the *
|
||||||
|
* extent applicable law prohibits such limitation. Some *
|
||||||
|
* jurisdictions do not allow the exclusion or limitation of *
|
||||||
|
* incidental or consequential damages, so this exclusion and *
|
||||||
|
* limitation may not apply to You. *
|
||||||
|
* *
|
||||||
|
************************************************************************
|
||||||
|
|
||||||
|
8. Litigation
|
||||||
|
-------------
|
||||||
|
|
||||||
|
Any litigation relating to this License may be brought only in the
|
||||||
|
courts of a jurisdiction where the defendant maintains its principal
|
||||||
|
place of business and such litigation shall be governed by laws of that
|
||||||
|
jurisdiction, without reference to its conflict-of-law provisions.
|
||||||
|
Nothing in this Section shall prevent a party's ability to bring
|
||||||
|
cross-claims or counter-claims.
|
||||||
|
|
||||||
|
9. Miscellaneous
|
||||||
|
----------------
|
||||||
|
|
||||||
|
This License represents the complete agreement concerning the subject
|
||||||
|
matter hereof. If any provision of this License is held to be
|
||||||
|
unenforceable, such provision shall be reformed only to the extent
|
||||||
|
necessary to make it enforceable. Any law or regulation which provides
|
||||||
|
that the language of a contract shall be construed against the drafter
|
||||||
|
shall not be used to construe this License against a Contributor.
|
||||||
|
|
||||||
|
10. Versions of the License
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
10.1. New Versions
|
||||||
|
|
||||||
|
Mozilla Foundation is the license steward. Except as provided in Section
|
||||||
|
10.3, no one other than the license steward has the right to modify or
|
||||||
|
publish new versions of this License. Each version will be given a
|
||||||
|
distinguishing version number.
|
||||||
|
|
||||||
|
10.2. Effect of New Versions
|
||||||
|
|
||||||
|
You may distribute the Covered Software under the terms of the version
|
||||||
|
of the License under which You originally received the Covered Software,
|
||||||
|
or under the terms of any subsequent version published by the license
|
||||||
|
steward.
|
||||||
|
|
||||||
|
10.3. Modified Versions
|
||||||
|
|
||||||
|
If you create software not governed by this License, and you want to
|
||||||
|
create a new license for such software, you may create and use a
|
||||||
|
modified version of this License if you rename the license and remove
|
||||||
|
any references to the name of the license steward (except to note that
|
||||||
|
such modified license differs from this License).
|
||||||
|
|
||||||
|
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||||
|
Licenses
|
||||||
|
|
||||||
|
If You choose to distribute Source Code Form that is Incompatible With
|
||||||
|
Secondary Licenses under the terms of this version of the License, the
|
||||||
|
notice described in Exhibit B of this License must be attached.
|
||||||
|
|
||||||
|
Exhibit A - Source Code Form License Notice
|
||||||
|
-------------------------------------------
|
||||||
|
|
||||||
|
This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
If it is not possible or desirable to put the notice in a particular
|
||||||
|
file, then You may include the notice in a location (such as a LICENSE
|
||||||
|
file in a relevant directory) where a recipient would be likely to look
|
||||||
|
for such a notice.
|
||||||
|
|
||||||
|
You may add additional accurate notices of copyright ownership.
|
||||||
|
|
||||||
|
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||||
|
---------------------------------------------------------
|
||||||
|
|
||||||
|
This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||||
|
defined by the Mozilla Public License, v. 2.0.
|
||||||
2
README.md
Normal file
2
README.md
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
# ohttp-gateway
|
||||||
|
A OHTTP Gateway written in Rust
|
||||||
270
src/config.rs
Normal file
270
src/config.rs
Normal file
|
|
@ -0,0 +1,270 @@
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AppConfig {
|
||||||
|
// Server configuration
|
||||||
|
pub listen_addr: String,
|
||||||
|
pub backend_url: String,
|
||||||
|
pub request_timeout: Duration,
|
||||||
|
pub max_body_size: usize,
|
||||||
|
|
||||||
|
// Key management
|
||||||
|
pub key_rotation_interval: Duration,
|
||||||
|
pub key_retention_period: Duration,
|
||||||
|
pub key_rotation_enabled: bool,
|
||||||
|
|
||||||
|
// Security configuration
|
||||||
|
pub allowed_target_origins: Option<HashSet<String>>,
|
||||||
|
pub target_rewrites: Option<TargetRewriteConfig>,
|
||||||
|
pub rate_limit: Option<RateLimitConfig>,
|
||||||
|
|
||||||
|
// Operational configuration
|
||||||
|
pub metrics_enabled: bool,
|
||||||
|
pub debug_mode: bool,
|
||||||
|
pub log_format: LogFormat,
|
||||||
|
pub log_level: String,
|
||||||
|
|
||||||
|
// OHTTP specific
|
||||||
|
pub custom_request_type: Option<String>,
|
||||||
|
pub custom_response_type: Option<String>,
|
||||||
|
pub seed_secret_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TargetRewriteConfig {
|
||||||
|
pub rewrites: std::collections::HashMap<String, TargetRewrite>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TargetRewrite {
|
||||||
|
pub scheme: String,
|
||||||
|
pub host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct RateLimitConfig {
|
||||||
|
pub requests_per_second: u32,
|
||||||
|
pub burst_size: u32,
|
||||||
|
pub by_ip: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum LogFormat {
|
||||||
|
Default,
|
||||||
|
Json,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
listen_addr: "0.0.0.0:8080".to_string(),
|
||||||
|
backend_url: "http://localhost:8080".to_string(),
|
||||||
|
request_timeout: Duration::from_secs(30),
|
||||||
|
max_body_size: 10 * 1024 * 1024, // 10MB
|
||||||
|
key_rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
|
||||||
|
key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
|
||||||
|
key_rotation_enabled: true,
|
||||||
|
allowed_target_origins: None,
|
||||||
|
target_rewrites: None,
|
||||||
|
rate_limit: None,
|
||||||
|
metrics_enabled: true,
|
||||||
|
debug_mode: false,
|
||||||
|
log_format: LogFormat::Default,
|
||||||
|
log_level: "info".to_string(),
|
||||||
|
custom_request_type: None,
|
||||||
|
custom_response_type: None,
|
||||||
|
seed_secret_key: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppConfig {
|
||||||
|
pub fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let mut config = Self::default();
|
||||||
|
|
||||||
|
// Basic configuration
|
||||||
|
if let Ok(addr) = std::env::var("LISTEN_ADDR") {
|
||||||
|
config.listen_addr = addr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(url) = std::env::var("BACKEND_URL") {
|
||||||
|
config.backend_url = url;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(timeout) = std::env::var("REQUEST_TIMEOUT") {
|
||||||
|
config.request_timeout = Duration::from_secs(timeout.parse()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(size) = std::env::var("MAX_BODY_SIZE") {
|
||||||
|
config.max_body_size = size.parse()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key management
|
||||||
|
if let Ok(interval) = std::env::var("KEY_ROTATION_INTERVAL") {
|
||||||
|
config.key_rotation_interval = Duration::from_secs(interval.parse()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(period) = std::env::var("KEY_RETENTION_PERIOD") {
|
||||||
|
config.key_retention_period = Duration::from_secs(period.parse()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(enabled) = std::env::var("KEY_ROTATION_ENABLED") {
|
||||||
|
config.key_rotation_enabled = enabled.parse()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security configuration
|
||||||
|
if let Ok(origins) = std::env::var("ALLOWED_TARGET_ORIGINS") {
|
||||||
|
let origins_set: HashSet<String> = origins
|
||||||
|
.split(',')
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !origins_set.is_empty() {
|
||||||
|
config.allowed_target_origins = Some(origins_set);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(rewrites_json) = std::env::var("TARGET_REWRITES") {
|
||||||
|
let rewrites: std::collections::HashMap<String, TargetRewrite> =
|
||||||
|
serde_json::from_str(&rewrites_json)?;
|
||||||
|
config.target_rewrites = Some(TargetRewriteConfig { rewrites });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rate limiting
|
||||||
|
if let Ok(rps) = std::env::var("RATE_LIMIT_RPS") {
|
||||||
|
let rate_limit = RateLimitConfig {
|
||||||
|
requests_per_second: rps.parse()?,
|
||||||
|
burst_size: std::env::var("RATE_LIMIT_BURST")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(100),
|
||||||
|
by_ip: std::env::var("RATE_LIMIT_BY_IP")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(true),
|
||||||
|
};
|
||||||
|
config.rate_limit = Some(rate_limit);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operational configuration
|
||||||
|
if let Ok(enabled) = std::env::var("METRICS_ENABLED") {
|
||||||
|
config.metrics_enabled = enabled.parse()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(debug) = std::env::var("GATEWAY_DEBUG") {
|
||||||
|
config.debug_mode = debug.parse()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(format) = std::env::var("LOG_FORMAT") {
|
||||||
|
config.log_format = match format.to_lowercase().as_str() {
|
||||||
|
"json" => LogFormat::Json,
|
||||||
|
_ => LogFormat::Default,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(level) = std::env::var("LOG_LEVEL") {
|
||||||
|
config.log_level = level;
|
||||||
|
}
|
||||||
|
|
||||||
|
// OHTTP specific
|
||||||
|
if let Ok(req_type) = std::env::var("CUSTOM_REQUEST_TYPE") {
|
||||||
|
config.custom_request_type = Some(req_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(resp_type) = std::env::var("CUSTOM_RESPONSE_TYPE") {
|
||||||
|
config.custom_response_type = Some(resp_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(seed) = std::env::var("SEED_SECRET_KEY") {
|
||||||
|
config.seed_secret_key = Some(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate configuration
|
||||||
|
config.validate()?;
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Validate key rotation settings
|
||||||
|
if self.key_retention_period > self.key_rotation_interval {
|
||||||
|
return Err("Key retention period cannot be longer than rotation interval".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate custom content types
|
||||||
|
match (&self.custom_request_type, &self.custom_response_type) {
|
||||||
|
(Some(req), Some(resp)) if req == resp => {
|
||||||
|
return Err("Request and response content types must be different".into());
|
||||||
|
}
|
||||||
|
(Some(_), None) | (None, Some(_)) => {
|
||||||
|
return Err("Both custom request and response types must be specified".into());
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate seed if provided
|
||||||
|
if let Some(seed) = &self.seed_secret_key {
|
||||||
|
let decoded =
|
||||||
|
hex::decode(seed).map_err(|_| "SEED_SECRET_KEY must be a hex-encoded string")?;
|
||||||
|
|
||||||
|
if decoded.len() < 32 {
|
||||||
|
return Err("SEED_SECRET_KEY must be at least 32 bytes (64 hex characters)".into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a target origin is allowed
|
||||||
|
pub fn is_origin_allowed(&self, origin: &str) -> bool {
|
||||||
|
match &self.allowed_target_origins {
|
||||||
|
Some(allowed) => allowed.contains(origin),
|
||||||
|
None => true, // No restrictions if not configured
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get rewrite configuration for a host
|
||||||
|
pub fn get_rewrite(&self, host: &str) -> Option<&TargetRewrite> {
|
||||||
|
self.target_rewrites
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|config| config.rewrites.get(host))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_config() {
|
||||||
|
let config = AppConfig::default();
|
||||||
|
assert_eq!(config.listen_addr, "0.0.0.0:8080");
|
||||||
|
assert!(config.key_rotation_enabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validation_key_periods() {
|
||||||
|
let mut config = AppConfig::default();
|
||||||
|
config.key_retention_period = Duration::from_secs(100);
|
||||||
|
config.key_rotation_interval = Duration::from_secs(50);
|
||||||
|
|
||||||
|
assert!(config.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_origin_allowed() {
|
||||||
|
let mut config = AppConfig::default();
|
||||||
|
config.allowed_target_origins = Some(
|
||||||
|
vec!["example.com".to_string(), "test.com".to_string()]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(config.is_origin_allowed("example.com"));
|
||||||
|
assert!(!config.is_origin_allowed("forbidden.com"));
|
||||||
|
}
|
||||||
|
}
|
||||||
66
src/error.rs
Normal file
66
src/error.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum GatewayError {
|
||||||
|
#[error("Invalid request: {0}")]
|
||||||
|
InvalidRequest(String),
|
||||||
|
|
||||||
|
#[error("Decryption failed: {0}")]
|
||||||
|
DecryptionError(String),
|
||||||
|
|
||||||
|
#[error("Encryption failed: {0}")]
|
||||||
|
EncryptionError(String),
|
||||||
|
|
||||||
|
#[error("Backend error: {0}")]
|
||||||
|
BackendError(String),
|
||||||
|
|
||||||
|
#[error("Request too large: {0}")]
|
||||||
|
RequestTooLarge(String),
|
||||||
|
|
||||||
|
#[error("Configuration error: {0}")]
|
||||||
|
ConfigurationError(String),
|
||||||
|
|
||||||
|
#[error("Internal error: {0}")]
|
||||||
|
InternalError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for GatewayError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let (status, error_code, message) = match self {
|
||||||
|
GatewayError::InvalidRequest(msg) => (StatusCode::BAD_REQUEST, "invalid_request", msg),
|
||||||
|
GatewayError::DecryptionError(msg) => {
|
||||||
|
(StatusCode::BAD_REQUEST, "decryption_error", msg)
|
||||||
|
}
|
||||||
|
GatewayError::EncryptionError(msg) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "encryption_error", msg)
|
||||||
|
}
|
||||||
|
GatewayError::BackendError(msg) => (StatusCode::BAD_GATEWAY, "backend_error", msg),
|
||||||
|
GatewayError::RequestTooLarge(msg) => {
|
||||||
|
(StatusCode::PAYLOAD_TOO_LARGE, "request_too_large", msg)
|
||||||
|
}
|
||||||
|
GatewayError::ConfigurationError(msg) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"configuration_error",
|
||||||
|
msg,
|
||||||
|
),
|
||||||
|
GatewayError::InternalError(msg) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", msg)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = Json(json!({
|
||||||
|
"error": {
|
||||||
|
"code": error_code,
|
||||||
|
"message": message
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
(status, body).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
77
src/handlers/health.rs
Normal file
77
src/handlers/health.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
use crate::{error::GatewayError, state::AppState};
|
||||||
|
use axum::{extract::State, Json};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
pub async fn health_check(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Result<Json<serde_json::Value>, GatewayError> {
|
||||||
|
let mut health_checks = vec![];
|
||||||
|
|
||||||
|
// Check key manager health
|
||||||
|
let key_status = match state.key_manager.get_encoded_config().await {
|
||||||
|
Ok(config) if config.len() > 100 => "healthy",
|
||||||
|
Ok(_) => "unhealthy",
|
||||||
|
Err(_) => "unhealthy",
|
||||||
|
};
|
||||||
|
|
||||||
|
health_checks.push(json!({
|
||||||
|
"component": "ohttp_keys",
|
||||||
|
"status": key_status
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Check backend connectivity - use the correct health endpoint
|
||||||
|
let backend_health_url = format!("{}/health", state.config.backend_url);
|
||||||
|
let backend_status = match state
|
||||||
|
.http_client
|
||||||
|
.get(&backend_health_url)
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) if resp.status().is_success() => "healthy",
|
||||||
|
Ok(resp) => {
|
||||||
|
tracing::warn!("Backend health check returned: {}", resp.status());
|
||||||
|
"unhealthy"
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Backend health check failed: {}", e);
|
||||||
|
"unhealthy"
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
health_checks.push(json!({
|
||||||
|
"component": "backend",
|
||||||
|
"status": backend_status,
|
||||||
|
"url": backend_health_url
|
||||||
|
}));
|
||||||
|
|
||||||
|
let overall_status = if health_checks.iter().all(|c| c["status"] == "healthy") {
|
||||||
|
"healthy"
|
||||||
|
} else {
|
||||||
|
"unhealthy"
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(json!({
|
||||||
|
"status": overall_status,
|
||||||
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
|
"checks": health_checks,
|
||||||
|
"version": env!("CARGO_PKG_VERSION")
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn metrics_handler() -> Result<String, GatewayError> {
|
||||||
|
use prometheus::{Encoder, TextEncoder};
|
||||||
|
|
||||||
|
let encoder = TextEncoder::new();
|
||||||
|
let metric_families = prometheus::gather();
|
||||||
|
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
encoder
|
||||||
|
.encode(&metric_families, &mut buffer)
|
||||||
|
.map_err(|e| GatewayError::InternalError(format!("Failed to encode metrics: {e}")))?;
|
||||||
|
|
||||||
|
String::from_utf8(buffer).map_err(|e| {
|
||||||
|
GatewayError::InternalError(format!("Failed to convert metrics to string: {e}"))
|
||||||
|
})
|
||||||
|
}
|
||||||
83
src/handlers/keys.rs
Normal file
83
src/handlers/keys.rs
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
use crate::AppState;
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::{header, HeaderName, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use chrono::Utc;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
/// Handler for /ohttp-keys endpoint
|
||||||
|
/// Returns key configurations in the standard OHTTP format
|
||||||
|
pub async fn get_ohttp_keys(State(state): State<AppState>) -> Result<Response, StatusCode> {
|
||||||
|
state.metrics.key_requests_total.inc();
|
||||||
|
|
||||||
|
match state.key_manager.get_encoded_config().await {
|
||||||
|
Ok(config_bytes) => {
|
||||||
|
info!("Serving {} bytes of key configurations", config_bytes.len());
|
||||||
|
|
||||||
|
// Calculate cache duration based on rotation interval
|
||||||
|
let max_age = calculate_cache_max_age(&state);
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
StatusCode::OK,
|
||||||
|
[
|
||||||
|
(header::CONTENT_TYPE, "application/ohttp-keys"),
|
||||||
|
(header::CACHE_CONTROL, &format!("public, max-age={max_age}")),
|
||||||
|
(HeaderName::from_static("x-content-type-options"), "nosniff"),
|
||||||
|
],
|
||||||
|
config_bytes,
|
||||||
|
)
|
||||||
|
.into_response())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to encode key config: {e}");
|
||||||
|
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Legacy endpoint for backward compatibility
|
||||||
|
/// Some older clients may still use /ohttp-configs
|
||||||
|
pub async fn get_legacy_ohttp_configs(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
// Just forward to the main handler
|
||||||
|
get_ohttp_keys(State(state)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate appropriate cache duration for key configurations
|
||||||
|
fn calculate_cache_max_age(state: &AppState) -> u64 {
|
||||||
|
// Cache for 10% of rotation interval, minimum 1 hour, maximum 24 hours
|
||||||
|
let ten_percent = state.config.key_rotation_interval.as_secs() / 10;
|
||||||
|
let one_hour = 3600;
|
||||||
|
let twenty_four_hours = 86400;
|
||||||
|
|
||||||
|
ten_percent.max(one_hour).min(twenty_four_hours)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check endpoint specifically for key management
|
||||||
|
pub async fn key_health_check(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
|
let stats = state.key_manager.get_stats().await;
|
||||||
|
|
||||||
|
let health_status = if stats.active_keys > 0 && stats.expired_keys == 0 {
|
||||||
|
"healthy"
|
||||||
|
} else if stats.active_keys > 0 {
|
||||||
|
"degraded"
|
||||||
|
} else {
|
||||||
|
"unhealthy"
|
||||||
|
};
|
||||||
|
|
||||||
|
axum::Json(serde_json::json!({
|
||||||
|
"status": health_status,
|
||||||
|
"timestamp": Utc::now().to_rfc3339(),
|
||||||
|
"key_stats": {
|
||||||
|
"active_key_id": stats.active_key_id,
|
||||||
|
"total_keys": stats.total_keys,
|
||||||
|
"active_keys": stats.active_keys,
|
||||||
|
"expired_keys": stats.expired_keys,
|
||||||
|
"rotation_enabled": stats.auto_rotation_enabled,
|
||||||
|
"rotation_interval_hours": stats.rotation_interval.as_secs() / 3600,
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
22
src/handlers/mod.rs
Normal file
22
src/handlers/mod.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
pub mod health;
|
||||||
|
pub mod keys;
|
||||||
|
pub mod ohttp;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use axum::{
|
||||||
|
routing::{get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn routes() -> Router<AppState> {
|
||||||
|
Router::new()
|
||||||
|
// OHTTP endpoints
|
||||||
|
.route("/gateway", post(ohttp::handle_ohttp_request))
|
||||||
|
.route("/ohttp-keys", get(keys::get_ohttp_keys))
|
||||||
|
// Legacy endpoints for backward compatibility
|
||||||
|
.route("/ohttp-configs", get(keys::get_legacy_ohttp_configs))
|
||||||
|
// Health and monitoring
|
||||||
|
.route("/health", get(health::health_check))
|
||||||
|
.route("/health/keys", get(keys::key_health_check))
|
||||||
|
.route("/metrics", get(health::metrics_handler))
|
||||||
|
}
|
||||||
477
src/handlers/ohttp.rs
Normal file
477
src/handlers/ohttp.rs
Normal file
|
|
@ -0,0 +1,477 @@
|
||||||
|
use crate::{error::GatewayError, state::AppState};
|
||||||
|
use axum::{
|
||||||
|
body::{Body, Bytes},
|
||||||
|
extract::State,
|
||||||
|
http::{header, HeaderMap, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use bhttp::{Message, Mode};
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
const OHTTP_REQUEST_CONTENT_TYPE: &str = "message/ohttp-req";
|
||||||
|
const OHTTP_RESPONSE_CONTENT_TYPE: &str = "message/ohttp-res";
|
||||||
|
|
||||||
|
pub async fn handle_ohttp_request(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let timer = state.metrics.request_duration.start_timer();
|
||||||
|
state.metrics.requests_total.inc();
|
||||||
|
|
||||||
|
// Extract key ID from the request if possible
|
||||||
|
let key_id = extract_key_id_from_request(&body);
|
||||||
|
|
||||||
|
let result = handle_ohttp_request_inner(state.clone(), headers, body, key_id).await;
|
||||||
|
timer.stop_and_record();
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(e) => {
|
||||||
|
error!("OHTTP request failed: {:?}", e);
|
||||||
|
|
||||||
|
// Log metrics based on error type
|
||||||
|
match &e {
|
||||||
|
GatewayError::DecryptionError(_) => state.metrics.decryption_errors_total.inc(),
|
||||||
|
GatewayError::EncryptionError(_) => state.metrics.encryption_errors_total.inc(),
|
||||||
|
GatewayError::BackendError(_) => state.metrics.backend_errors_total.inc(),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_ohttp_request_inner(
|
||||||
|
state: AppState,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Bytes,
|
||||||
|
key_id: Option<u8>,
|
||||||
|
) -> Result<Response, GatewayError> {
|
||||||
|
// Validate request
|
||||||
|
validate_ohttp_request(&headers, &body, &state)?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Received OHTTP request with {} bytes, key_id: {:?}",
|
||||||
|
body.len(),
|
||||||
|
key_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get the appropriate server based on key ID
|
||||||
|
let server = if let Some(id) = key_id {
|
||||||
|
// Try to get server for specific key ID
|
||||||
|
match state.key_manager.get_server_by_id(id).await {
|
||||||
|
Some(server) => {
|
||||||
|
debug!("Using server for key ID: {}", id);
|
||||||
|
server
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!("Unknown key ID: {}, falling back to current server", id);
|
||||||
|
state
|
||||||
|
.key_manager
|
||||||
|
.get_current_server()
|
||||||
|
.await
|
||||||
|
.map_err(|e| GatewayError::ConfigurationError(e.to_string()))?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use current active server
|
||||||
|
state
|
||||||
|
.key_manager
|
||||||
|
.get_current_server()
|
||||||
|
.await
|
||||||
|
.map_err(|e| GatewayError::ConfigurationError(e.to_string()))?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decrypt the OHTTP request
|
||||||
|
let (bhttp_request, server_response) = server.decapsulate(&body).map_err(|e| {
|
||||||
|
error!("Failed to decapsulate OHTTP request: {e}");
|
||||||
|
GatewayError::DecryptionError(format!("Failed to decapsulate: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Successfully decapsulated request, {} bytes",
|
||||||
|
bhttp_request.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Parse binary HTTP message
|
||||||
|
let message = parse_bhttp_message(&bhttp_request)?;
|
||||||
|
|
||||||
|
// Validate and potentially transform the request
|
||||||
|
let message = validate_and_transform_request(message, &state)?;
|
||||||
|
|
||||||
|
// Forward request to backend
|
||||||
|
let backend_response = forward_to_backend(&state, message).await?;
|
||||||
|
|
||||||
|
// Convert response to binary HTTP format
|
||||||
|
let bhttp_response = convert_to_binary_http(backend_response).await?;
|
||||||
|
|
||||||
|
// Encrypt response back to client
|
||||||
|
let encrypted_response = server_response.encapsulate(&bhttp_response).map_err(|e| {
|
||||||
|
GatewayError::EncryptionError(format!("Failed to encapsulate response: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
state.metrics.successful_requests_total.inc();
|
||||||
|
info!("Successfully processed OHTTP request");
|
||||||
|
|
||||||
|
// Build response with appropriate headers
|
||||||
|
Response::builder()
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.header(header::CONTENT_TYPE, OHTTP_RESPONSE_CONTENT_TYPE)
|
||||||
|
.header(header::CACHE_CONTROL, "no-cache, no-store, must-revalidate")
|
||||||
|
.header("X-Content-Type-Options", "nosniff")
|
||||||
|
.header("X-Frame-Options", "DENY")
|
||||||
|
.body(Body::from(encrypted_response))
|
||||||
|
.map_err(|e| GatewayError::InternalError(format!("Response build error: {e}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract key ID from OHTTP request (first byte after version)
|
||||||
|
fn extract_key_id_from_request(body: &[u8]) -> Option<u8> {
|
||||||
|
// OHTTP request format: version(1) + key_id(1) + kem_id(2) + kdf_id(2) + aead_id(2) + enc + ciphertext
|
||||||
|
if body.len() > 1 {
|
||||||
|
Some(body[1])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate the incoming OHTTP request
|
||||||
|
fn validate_ohttp_request(
|
||||||
|
headers: &HeaderMap,
|
||||||
|
body: &Bytes,
|
||||||
|
state: &AppState,
|
||||||
|
) -> Result<(), GatewayError> {
|
||||||
|
// Check content type
|
||||||
|
let content_type = headers
|
||||||
|
.get(header::CONTENT_TYPE)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.ok_or_else(|| GatewayError::InvalidRequest("Missing content-type header".to_string()))?;
|
||||||
|
|
||||||
|
if content_type != OHTTP_REQUEST_CONTENT_TYPE {
|
||||||
|
return Err(GatewayError::InvalidRequest(format!(
|
||||||
|
"Invalid content-type: expected '{OHTTP_REQUEST_CONTENT_TYPE}', got '{content_type}'"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check body size
|
||||||
|
if body.is_empty() {
|
||||||
|
return Err(GatewayError::InvalidRequest(
|
||||||
|
"Empty request body".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if body.len() > state.config.max_body_size {
|
||||||
|
return Err(GatewayError::RequestTooLarge(format!(
|
||||||
|
"Request body too large: {} bytes (max: {})",
|
||||||
|
body.len(),
|
||||||
|
state.config.max_body_size
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Minimum OHTTP request size check
|
||||||
|
if body.len() < 10 {
|
||||||
|
return Err(GatewayError::InvalidRequest(
|
||||||
|
"Request too small to be valid OHTTP".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse binary HTTP message with error handling
|
||||||
|
fn parse_bhttp_message(data: &[u8]) -> Result<Message, GatewayError> {
|
||||||
|
let mut cursor = std::io::Cursor::new(data);
|
||||||
|
Message::read_bhttp(&mut cursor)
|
||||||
|
.map_err(|e| GatewayError::InvalidRequest(format!("Failed to parse binary HTTP: {e}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate and transform the request based on security policies
|
||||||
|
fn validate_and_transform_request(
|
||||||
|
message: Message,
|
||||||
|
state: &AppState,
|
||||||
|
) -> Result<Message, GatewayError> {
|
||||||
|
let control = message.control();
|
||||||
|
|
||||||
|
// Extract host from authority or Host header
|
||||||
|
let host = control
|
||||||
|
.authority()
|
||||||
|
.map(|a| String::from_utf8_lossy(a).into_owned())
|
||||||
|
.or_else(|| {
|
||||||
|
message.header().fields().iter().find_map(|field| {
|
||||||
|
if field.name().eq_ignore_ascii_case(b"host") {
|
||||||
|
Some(String::from_utf8_lossy(field.value()).into_owned())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.ok_or_else(|| GatewayError::InvalidRequest("Missing host/authority".to_string()))?;
|
||||||
|
|
||||||
|
// Check if origin is allowed
|
||||||
|
if !state.config.is_origin_allowed(&host) {
|
||||||
|
warn!("Blocked request to forbidden origin: {host}");
|
||||||
|
return Err(GatewayError::InvalidRequest(format!(
|
||||||
|
"Target origin not allowed: {host}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply any configured rewrites
|
||||||
|
if let Some(rewrite) = state.config.get_rewrite(&host) {
|
||||||
|
debug!(
|
||||||
|
"Applying rewrite for host {}: {} -> {}",
|
||||||
|
host, rewrite.scheme, rewrite.host
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clone the message to modify it
|
||||||
|
let mut new_message = Message::request(
|
||||||
|
Vec::from(control.method().unwrap_or(b"GET")),
|
||||||
|
Vec::from(
|
||||||
|
format!(
|
||||||
|
"{}://{}{}",
|
||||||
|
rewrite.scheme,
|
||||||
|
rewrite.host,
|
||||||
|
control
|
||||||
|
.path()
|
||||||
|
.map(|p| String::from_utf8_lossy(p))
|
||||||
|
.unwrap_or("/".into())
|
||||||
|
)
|
||||||
|
.as_bytes(),
|
||||||
|
),
|
||||||
|
Vec::from(control.scheme().unwrap_or(rewrite.scheme.as_bytes())),
|
||||||
|
Vec::from(rewrite.host.as_bytes()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Copy all headers except Host and Authority
|
||||||
|
for field in message.header().fields() {
|
||||||
|
let name = field.name();
|
||||||
|
if !name.eq_ignore_ascii_case(b"host") && !name.eq_ignore_ascii_case(b"authority") {
|
||||||
|
new_message.put_header(name, field.value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the new Host header
|
||||||
|
new_message.put_header(b"host", rewrite.host.as_bytes());
|
||||||
|
|
||||||
|
// Copy body content
|
||||||
|
if !message.content().is_empty() {
|
||||||
|
new_message.write_content(message.content());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(new_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forward_to_backend(
|
||||||
|
state: &AppState,
|
||||||
|
bhttp_message: Message,
|
||||||
|
) -> Result<reqwest::Response, GatewayError> {
|
||||||
|
let control = bhttp_message.control();
|
||||||
|
let method = control.method().unwrap_or(b"GET");
|
||||||
|
let path = control
|
||||||
|
.path()
|
||||||
|
.map(|p| String::from_utf8_lossy(p).into_owned())
|
||||||
|
.unwrap_or_else(|| "/".to_string());
|
||||||
|
|
||||||
|
// Extract host for URL construction
|
||||||
|
let host = control
|
||||||
|
.authority()
|
||||||
|
.map(|a| String::from_utf8_lossy(a).into_owned())
|
||||||
|
.or_else(|| {
|
||||||
|
bhttp_message.header().fields().iter().find_map(|field| {
|
||||||
|
if field.name().eq_ignore_ascii_case(b"host") {
|
||||||
|
Some(String::from_utf8_lossy(field.value()).into_owned())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
// Build the backend URI
|
||||||
|
let uri = if let Some(host) = host {
|
||||||
|
// Check for rewrites
|
||||||
|
if let Some(rewrite) = state.config.get_rewrite(&host) {
|
||||||
|
format!("{}://{}{}", rewrite.scheme, rewrite.host, path)
|
||||||
|
} else {
|
||||||
|
build_backend_uri(&state.config.backend_url, &path)?
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
build_backend_uri(&state.config.backend_url, &path)?
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Forwarding {} request to {}",
|
||||||
|
String::from_utf8_lossy(method),
|
||||||
|
uri
|
||||||
|
);
|
||||||
|
|
||||||
|
let reqwest_method = convert_method_to_reqwest(method);
|
||||||
|
let mut request_builder = state.http_client.request(reqwest_method, &uri);
|
||||||
|
|
||||||
|
// Add headers from the binary HTTP message
|
||||||
|
for field in bhttp_message.header().fields() {
|
||||||
|
let name = String::from_utf8_lossy(field.name());
|
||||||
|
let value = String::from_utf8_lossy(field.value());
|
||||||
|
|
||||||
|
// Skip headers that should not be forwarded
|
||||||
|
if should_forward_header(&name) {
|
||||||
|
request_builder = request_builder.header(name.as_ref(), value.as_ref());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add body if present
|
||||||
|
let content = bhttp_message.content();
|
||||||
|
if !content.is_empty() {
|
||||||
|
request_builder = request_builder.body(content.to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request with timeout
|
||||||
|
let response = request_builder.send().await.map_err(|e| {
|
||||||
|
error!("Backend request failed: {e}");
|
||||||
|
GatewayError::BackendError(format!("Backend request failed: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Check for backend errors
|
||||||
|
if response.status().is_server_error() {
|
||||||
|
return Err(GatewayError::BackendError(format!(
|
||||||
|
"Backend returned error: {}",
|
||||||
|
response.status()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_method_to_reqwest(method: &[u8]) -> reqwest::Method {
|
||||||
|
match method {
|
||||||
|
b"GET" => reqwest::Method::GET,
|
||||||
|
b"POST" => reqwest::Method::POST,
|
||||||
|
b"PUT" => reqwest::Method::PUT,
|
||||||
|
b"DELETE" => reqwest::Method::DELETE,
|
||||||
|
b"HEAD" => reqwest::Method::HEAD,
|
||||||
|
b"OPTIONS" => reqwest::Method::OPTIONS,
|
||||||
|
b"PATCH" => reqwest::Method::PATCH,
|
||||||
|
b"TRACE" => reqwest::Method::TRACE,
|
||||||
|
_ => reqwest::Method::GET,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_backend_uri(backend_url: &str, path: &str) -> Result<String, GatewayError> {
|
||||||
|
let base_url = backend_url.trim_end_matches('/');
|
||||||
|
let path = if path.starts_with('/') {
|
||||||
|
path
|
||||||
|
} else {
|
||||||
|
&format!("/{path}")
|
||||||
|
};
|
||||||
|
|
||||||
|
// Validate path to prevent SSRF attacks
|
||||||
|
if path.contains("..") || path.contains("//") {
|
||||||
|
return Err(GatewayError::InvalidRequest(
|
||||||
|
"Invalid path detected".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional validation for suspicious patterns
|
||||||
|
if path.contains('\0') || path.contains('\r') || path.contains('\n') {
|
||||||
|
return Err(GatewayError::InvalidRequest(
|
||||||
|
"Invalid characters in path".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(format!("{base_url}{path}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_forward_header(name: &str) -> bool {
|
||||||
|
const SKIP_HEADERS: &[&str] = &[
|
||||||
|
"host",
|
||||||
|
"connection",
|
||||||
|
"upgrade",
|
||||||
|
"proxy-authorization",
|
||||||
|
"proxy-authenticate",
|
||||||
|
"te",
|
||||||
|
"trailers",
|
||||||
|
"transfer-encoding",
|
||||||
|
"keep-alive",
|
||||||
|
"http2-settings",
|
||||||
|
"upgrade-insecure-requests",
|
||||||
|
];
|
||||||
|
|
||||||
|
!SKIP_HEADERS.contains(&name.to_lowercase().as_str())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn convert_to_binary_http(response: reqwest::Response) -> Result<Vec<u8>, GatewayError> {
|
||||||
|
let status = response.status();
|
||||||
|
let headers = response.headers().clone();
|
||||||
|
let body = response
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(|e| GatewayError::BackendError(format!("Failed to read response body: {e}")))?;
|
||||||
|
|
||||||
|
// Create a bhttp response message
|
||||||
|
let mut message = Message::response(
|
||||||
|
bhttp::StatusCode::try_from(status.as_u16())
|
||||||
|
.map_err(|_| GatewayError::InternalError("Invalid status code".to_string()))?,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add headers
|
||||||
|
for (name, value) in headers.iter() {
|
||||||
|
if should_forward_header(name.as_str()) {
|
||||||
|
message.put_header(name.as_str().as_bytes(), value.as_bytes());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add body
|
||||||
|
if !body.is_empty() {
|
||||||
|
message.write_content(&body);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize to binary HTTP using KnownLength mode for compatibility
|
||||||
|
let mut output = Vec::new();
|
||||||
|
message
|
||||||
|
.write_bhttp(Mode::KnownLength, &mut output)
|
||||||
|
.map_err(|e| GatewayError::InternalError(format!("Failed to write binary HTTP: {e}")))?;
|
||||||
|
|
||||||
|
debug!("Created BHTTP response of {} bytes", output.len());
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_key_id() {
|
||||||
|
let body = vec![0x00, 0x7F, 0x00, 0x01]; // version, key_id, kem_id...
|
||||||
|
assert_eq!(extract_key_id_from_request(&body), Some(0x7F));
|
||||||
|
|
||||||
|
let empty = vec![];
|
||||||
|
assert_eq!(extract_key_id_from_request(&empty), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_forward_header() {
|
||||||
|
assert!(should_forward_header("content-type"));
|
||||||
|
assert!(should_forward_header("authorization"));
|
||||||
|
assert!(!should_forward_header("connection"));
|
||||||
|
assert!(!should_forward_header("Host"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_backend_uri() {
|
||||||
|
assert_eq!(
|
||||||
|
build_backend_uri("https://backend.com", "/api/test").unwrap(),
|
||||||
|
"https://backend.com/api/test"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
build_backend_uri("https://backend.com/", "/api/test").unwrap(),
|
||||||
|
"https://backend.com/api/test"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(build_backend_uri("https://backend.com", "/../etc/passwd").is_err());
|
||||||
|
assert!(build_backend_uri("https://backend.com", "//evil.com").is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
379
src/key_manager.rs
Normal file
379
src/key_manager.rs
Normal file
|
|
@ -0,0 +1,379 @@
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use ohttp::{
|
||||||
|
hpke::{Aead, Kdf, Kem},
|
||||||
|
KeyConfig, Server as OhttpServer, SymmetricSuite,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
/// Represents a key with its metadata
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct KeyInfo {
|
||||||
|
pub id: u8,
|
||||||
|
pub config: KeyConfig,
|
||||||
|
pub server: OhttpServer,
|
||||||
|
pub expires_at: DateTime<Utc>,
|
||||||
|
pub is_active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Configuration for key management
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct KeyManagerConfig {
|
||||||
|
/// How often to rotate keys (default: 30 days)
|
||||||
|
pub rotation_interval: Duration,
|
||||||
|
/// How long to keep old keys for decryption (default: 7 days)
|
||||||
|
pub key_retention_period: Duration,
|
||||||
|
/// Whether to enable automatic rotation
|
||||||
|
pub auto_rotation_enabled: bool,
|
||||||
|
/// Supported cipher suites
|
||||||
|
pub cipher_suites: Vec<CipherSuiteConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct CipherSuiteConfig {
|
||||||
|
pub kem: String,
|
||||||
|
pub kdf: String,
|
||||||
|
pub aead: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for KeyManagerConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
|
||||||
|
key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
|
||||||
|
auto_rotation_enabled: true,
|
||||||
|
cipher_suites: vec![
|
||||||
|
CipherSuiteConfig {
|
||||||
|
kem: "X25519_SHA256".to_string(),
|
||||||
|
kdf: "HKDF_SHA256".to_string(),
|
||||||
|
aead: "AES_128_GCM".to_string(),
|
||||||
|
},
|
||||||
|
CipherSuiteConfig {
|
||||||
|
kem: "X25519_SHA256".to_string(),
|
||||||
|
kdf: "HKDF_SHA256".to_string(),
|
||||||
|
aead: "CHACHA20_POLY1305".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KeyManager {
|
||||||
|
/// All keys indexed by ID
|
||||||
|
keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
|
||||||
|
/// Current active key ID
|
||||||
|
active_key_id: Arc<RwLock<u8>>,
|
||||||
|
/// Configuration
|
||||||
|
config: KeyManagerConfig,
|
||||||
|
/// Key ID counter (wraps around after 255)
|
||||||
|
next_key_id: Arc<RwLock<u8>>,
|
||||||
|
/// Seed for deterministic key generation (optional)
|
||||||
|
seed: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KeyManager {
|
||||||
|
pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
let manager = Self {
|
||||||
|
keys: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
active_key_id: Arc::new(RwLock::new(0)),
|
||||||
|
config,
|
||||||
|
next_key_id: Arc::new(RwLock::new(1)),
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate initial key
|
||||||
|
let initial_key = manager.generate_new_key().await?;
|
||||||
|
{
|
||||||
|
let mut keys = manager.keys.write().await;
|
||||||
|
let mut active_id = manager.active_key_id.write().await;
|
||||||
|
|
||||||
|
keys.insert(initial_key.id, initial_key.clone());
|
||||||
|
*active_id = initial_key.id;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("KeyManager initialized with key ID: {}", initial_key.id);
|
||||||
|
Ok(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a key manager with a seed for deterministic key generation
|
||||||
|
pub async fn new_with_seed(
|
||||||
|
config: KeyManagerConfig,
|
||||||
|
seed: Vec<u8>,
|
||||||
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
if seed.len() < 32 {
|
||||||
|
return Err("Seed must be at least 32 bytes".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut manager = Self::new(config).await?;
|
||||||
|
manager.seed = Some(seed);
|
||||||
|
Ok(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a new key configuration
|
||||||
|
async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
|
||||||
|
let key_id = {
|
||||||
|
let mut next_id = self.next_key_id.write().await;
|
||||||
|
let id = *next_id;
|
||||||
|
*next_id = next_id.wrapping_add(1);
|
||||||
|
id
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse cipher suites from config
|
||||||
|
let mut symmetric_suites = Vec::new();
|
||||||
|
for suite in &self.config.cipher_suites {
|
||||||
|
let kdf = match suite.kdf.as_str() {
|
||||||
|
"HKDF_SHA256" => Kdf::HkdfSha256,
|
||||||
|
"HKDF_SHA384" => Kdf::HkdfSha384,
|
||||||
|
"HKDF_SHA512" => Kdf::HkdfSha512,
|
||||||
|
_ => Kdf::HkdfSha256,
|
||||||
|
};
|
||||||
|
|
||||||
|
let aead = match suite.aead.as_str() {
|
||||||
|
"AES_128_GCM" => Aead::Aes128Gcm,
|
||||||
|
"AES_256_GCM" => Aead::Aes256Gcm,
|
||||||
|
"CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
|
||||||
|
_ => Aead::Aes128Gcm,
|
||||||
|
};
|
||||||
|
|
||||||
|
symmetric_suites.push(SymmetricSuite::new(kdf, aead));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine KEM based on config - only X25519 is supported by ohttp crate
|
||||||
|
let kem = Kem::X25519Sha256;
|
||||||
|
|
||||||
|
// Generate key config
|
||||||
|
let key_config = if let Some(seed) = &self.seed {
|
||||||
|
// Deterministic generation using seed + key_id
|
||||||
|
let mut key_seed = seed.clone();
|
||||||
|
key_seed.push(key_id);
|
||||||
|
|
||||||
|
// The ohttp crate doesn't directly support deterministic key generation
|
||||||
|
// This would require extending the crate or using a custom implementation
|
||||||
|
KeyConfig::new(key_id, kem, symmetric_suites)?
|
||||||
|
} else {
|
||||||
|
KeyConfig::new(key_id, kem, symmetric_suites)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let server = OhttpServer::new(key_config.clone())?;
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
Ok(KeyInfo {
|
||||||
|
id: key_id,
|
||||||
|
config: key_config,
|
||||||
|
server,
|
||||||
|
expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
|
||||||
|
is_active: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current active server for decryption
|
||||||
|
pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
|
||||||
|
let keys = self.keys.read().await;
|
||||||
|
let active_id = self.active_key_id.read().await;
|
||||||
|
|
||||||
|
keys.get(&*active_id)
|
||||||
|
.map(|info| info.server.clone())
|
||||||
|
.ok_or_else(|| "No active key found".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a server by key ID (for handling requests with specific key IDs)
|
||||||
|
pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
|
||||||
|
let keys = self.keys.read().await;
|
||||||
|
keys.get(&key_id).map(|info| info.server.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get encoded config for backward compatibility
|
||||||
|
pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||||
|
let keys = self.keys.read().await;
|
||||||
|
let active_id = self.active_key_id.read().await;
|
||||||
|
let cfg_bytes = keys
|
||||||
|
.get(&*active_id)
|
||||||
|
.ok_or("no active key")?
|
||||||
|
.config
|
||||||
|
.encode()?;
|
||||||
|
|
||||||
|
let mut out = Vec::with_capacity(cfg_bytes.len() + 2);
|
||||||
|
out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes()); // 2-byte length
|
||||||
|
out.extend_from_slice(&cfg_bytes);
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rotate keys by generating a new key and marking old ones for expiration
|
||||||
|
pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
info!("Starting key rotation");
|
||||||
|
|
||||||
|
// Generate new key
|
||||||
|
let new_key = self.generate_new_key().await?;
|
||||||
|
let new_key_id = new_key.id;
|
||||||
|
|
||||||
|
// Update key store
|
||||||
|
{
|
||||||
|
let mut keys = self.keys.write().await;
|
||||||
|
let mut active_id = self.active_key_id.write().await;
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
// Mark current active key for future expiration
|
||||||
|
if let Some(current_key) = keys.get_mut(&*active_id) {
|
||||||
|
current_key.is_active = false;
|
||||||
|
// Keep it around for the retention period
|
||||||
|
current_key.expires_at =
|
||||||
|
now + chrono::Duration::from_std(self.config.key_retention_period)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new key
|
||||||
|
keys.insert(new_key_id, new_key);
|
||||||
|
|
||||||
|
// Update active key ID
|
||||||
|
*active_id = new_key_id;
|
||||||
|
|
||||||
|
// Clean up expired keys
|
||||||
|
keys.retain(|_, info| info.expires_at > now);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Key rotation completed. New active key ID: {}, total keys: {}",
|
||||||
|
new_key_id,
|
||||||
|
keys.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if rotation is needed
|
||||||
|
pub async fn should_rotate(&self) -> bool {
|
||||||
|
let keys = self.keys.read().await;
|
||||||
|
let active_id = self.active_key_id.read().await;
|
||||||
|
|
||||||
|
if let Some(active_key) = keys.get(&*active_id) {
|
||||||
|
let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
|
||||||
|
|
||||||
|
// Rotate if less than 10% of the rotation interval remains
|
||||||
|
let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
|
||||||
|
.unwrap_or_else(|_| chrono::Duration::days(3));
|
||||||
|
|
||||||
|
time_until_expiry < threshold
|
||||||
|
} else {
|
||||||
|
true // No active key, definitely need to rotate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start automatic key rotation scheduler
|
||||||
|
pub async fn start_rotation_scheduler(self: Arc<Self>) {
|
||||||
|
if !self.config.auto_rotation_enabled {
|
||||||
|
info!("Automatic key rotation is disabled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let manager = self;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Check every hour
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_secs(3600));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
|
||||||
|
if manager.should_rotate().await {
|
||||||
|
if let Err(e) = manager.rotate_keys().await {
|
||||||
|
error!("Key rotation failed: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also clean up expired keys
|
||||||
|
manager.cleanup_expired_keys().await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean up expired keys
|
||||||
|
async fn cleanup_expired_keys(&self) {
|
||||||
|
let mut keys = self.keys.write().await;
|
||||||
|
let now = Utc::now();
|
||||||
|
let before_count = keys.len();
|
||||||
|
|
||||||
|
keys.retain(|id, info| {
|
||||||
|
if info.expires_at <= now {
|
||||||
|
info!("Removing expired key ID: {}", id);
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let removed = before_count - keys.len();
|
||||||
|
if removed > 0 {
|
||||||
|
info!("Cleaned up {} expired keys", removed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get key manager statistics
|
||||||
|
pub async fn get_stats(&self) -> KeyManagerStats {
|
||||||
|
let keys = self.keys.read().await;
|
||||||
|
let active_id = self.active_key_id.read().await;
|
||||||
|
let now = Utc::now();
|
||||||
|
|
||||||
|
let active_keys = keys.values().filter(|k| k.is_active).count();
|
||||||
|
let total_keys = keys.len();
|
||||||
|
let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
|
||||||
|
|
||||||
|
KeyManagerStats {
|
||||||
|
active_key_id: *active_id,
|
||||||
|
total_keys,
|
||||||
|
active_keys,
|
||||||
|
expired_keys,
|
||||||
|
rotation_interval: self.config.rotation_interval,
|
||||||
|
auto_rotation_enabled: self.config.auto_rotation_enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct KeyManagerStats {
|
||||||
|
pub active_key_id: u8,
|
||||||
|
pub total_keys: usize,
|
||||||
|
pub active_keys: usize,
|
||||||
|
pub expired_keys: usize,
|
||||||
|
pub rotation_interval: Duration,
|
||||||
|
pub auto_rotation_enabled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure thread safety
|
||||||
|
unsafe impl Send for KeyManager {}
|
||||||
|
unsafe impl Sync for KeyManager {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_key_generation() {
|
||||||
|
let config = KeyManagerConfig::default();
|
||||||
|
let manager = KeyManager::new(config).await.unwrap();
|
||||||
|
|
||||||
|
let stats = manager.get_stats().await;
|
||||||
|
assert_eq!(stats.total_keys, 1);
|
||||||
|
assert_eq!(stats.active_keys, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_key_rotation() {
|
||||||
|
let config = KeyManagerConfig {
|
||||||
|
rotation_interval: Duration::from_secs(60),
|
||||||
|
key_retention_period: Duration::from_secs(30),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let manager = KeyManager::new(config).await.unwrap();
|
||||||
|
let initial_stats = manager.get_stats().await;
|
||||||
|
|
||||||
|
// Rotate keys
|
||||||
|
manager.rotate_keys().await.unwrap();
|
||||||
|
|
||||||
|
let new_stats = manager.get_stats().await;
|
||||||
|
assert_eq!(new_stats.total_keys, 2);
|
||||||
|
assert_ne!(new_stats.active_key_id, initial_stats.active_key_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
11
src/lib.rs
Normal file
11
src/lib.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
pub mod config;
|
||||||
|
pub mod error;
|
||||||
|
pub mod handlers;
|
||||||
|
pub mod key_manager;
|
||||||
|
pub mod metrics;
|
||||||
|
pub mod middleware;
|
||||||
|
pub mod state;
|
||||||
|
|
||||||
|
pub use config::AppConfig;
|
||||||
|
pub use error::GatewayError;
|
||||||
|
pub use state::AppState;
|
||||||
203
src/main.rs
Normal file
203
src/main.rs
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
mod config;
|
||||||
|
mod error;
|
||||||
|
mod handlers;
|
||||||
|
mod key_manager;
|
||||||
|
mod metrics;
|
||||||
|
mod middleware;
|
||||||
|
mod state;
|
||||||
|
|
||||||
|
use crate::config::{AppConfig, LogFormat};
|
||||||
|
use crate::state::AppState;
|
||||||
|
use axum::{middleware as axum_middleware, Router};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::signal;
|
||||||
|
use tower_http::compression::CompressionLayer;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
use tower_http::timeout::TimeoutLayer;
|
||||||
|
use tower_http::trace::TraceLayer;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Load configuration first
|
||||||
|
let config = AppConfig::from_env()?;
|
||||||
|
|
||||||
|
// Initialize tracing based on config
|
||||||
|
initialize_tracing(&config);
|
||||||
|
|
||||||
|
info!("Starting OHTTP Gateway v{}", env!("CARGO_PKG_VERSION"));
|
||||||
|
info!("Configuration loaded: {:?}", config);
|
||||||
|
|
||||||
|
// Initialize application state
|
||||||
|
let app_state = AppState::new(config.clone()).await?;
|
||||||
|
|
||||||
|
// Start key rotation scheduler
|
||||||
|
if config.key_rotation_enabled {
|
||||||
|
info!("Starting automatic key rotation scheduler");
|
||||||
|
app_state
|
||||||
|
.key_manager
|
||||||
|
.clone()
|
||||||
|
.start_rotation_scheduler()
|
||||||
|
.await;
|
||||||
|
} else {
|
||||||
|
warn!("Automatic key rotation is disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create router
|
||||||
|
let app = create_router(app_state.clone(), &config);
|
||||||
|
|
||||||
|
// Parse socket address
|
||||||
|
let addr: SocketAddr = config.listen_addr.parse()?;
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
|
||||||
|
info!("OHTTP Gateway listening on {}", addr);
|
||||||
|
info!("Backend URL: {}", config.backend_url);
|
||||||
|
|
||||||
|
if let Some(allowed) = &config.allowed_target_origins {
|
||||||
|
info!("Allowed origins: {:?}", allowed);
|
||||||
|
} else {
|
||||||
|
warn!("No origin restrictions configured - all targets allowed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server with graceful shutdown
|
||||||
|
axum::serve(
|
||||||
|
listener,
|
||||||
|
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||||
|
)
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
info!("Server stopped gracefully");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn initialize_tracing(config: &AppConfig) {
|
||||||
|
use tracing_subscriber::{fmt, EnvFilter};
|
||||||
|
|
||||||
|
let env_filter =
|
||||||
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));
|
||||||
|
|
||||||
|
match config.log_format {
|
||||||
|
LogFormat::Json => {
|
||||||
|
fmt()
|
||||||
|
.json()
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.with_target(true)
|
||||||
|
.with_thread_ids(true)
|
||||||
|
.with_file(config.debug_mode)
|
||||||
|
.with_line_number(config.debug_mode)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
LogFormat::Default => {
|
||||||
|
fmt()
|
||||||
|
.with_env_filter(env_filter)
|
||||||
|
.with_target(true)
|
||||||
|
.with_thread_ids(true)
|
||||||
|
.with_file(config.debug_mode)
|
||||||
|
.with_line_number(config.debug_mode)
|
||||||
|
.init();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_router(app_state: AppState, config: &AppConfig) -> Router {
|
||||||
|
let mut app = Router::new();
|
||||||
|
|
||||||
|
// Add routes
|
||||||
|
app = app.merge(handlers::routes());
|
||||||
|
|
||||||
|
// Add middleware layers (order matters - first added is executed last)
|
||||||
|
app = app.layer(
|
||||||
|
tower::ServiceBuilder::new()
|
||||||
|
// Outer layers (executed first on request, last on response)
|
||||||
|
.layer(TraceLayer::new_for_http())
|
||||||
|
.layer(CompressionLayer::new())
|
||||||
|
.layer(TimeoutLayer::new(Duration::from_secs(60)))
|
||||||
|
// Security middleware
|
||||||
|
.layer(axum_middleware::from_fn_with_state(
|
||||||
|
app_state.clone(),
|
||||||
|
middleware::security::security_middleware,
|
||||||
|
))
|
||||||
|
// Request validation
|
||||||
|
.layer(axum_middleware::from_fn(
|
||||||
|
middleware::security::request_validation_middleware,
|
||||||
|
))
|
||||||
|
// Logging middleware
|
||||||
|
.layer(axum_middleware::from_fn_with_state(
|
||||||
|
app_state.clone(),
|
||||||
|
middleware::logging::logging_middleware,
|
||||||
|
))
|
||||||
|
// Metrics middleware
|
||||||
|
.layer(axum_middleware::from_fn_with_state(
|
||||||
|
app_state.clone(),
|
||||||
|
middleware::metrics::metrics_middleware,
|
||||||
|
))
|
||||||
|
// CORS configuration
|
||||||
|
.layer(create_cors_layer(config)),
|
||||||
|
);
|
||||||
|
|
||||||
|
app.with_state(app_state)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_cors_layer(config: &AppConfig) -> CorsLayer {
|
||||||
|
if config.debug_mode {
|
||||||
|
// Permissive CORS in debug mode
|
||||||
|
CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_headers(Any)
|
||||||
|
} else {
|
||||||
|
// Restrictive CORS in production
|
||||||
|
CorsLayer::new()
|
||||||
|
.allow_origin([
|
||||||
|
"https://example.com".parse().unwrap(),
|
||||||
|
// Add your allowed origins here
|
||||||
|
])
|
||||||
|
.allow_methods([axum::http::Method::GET, axum::http::Method::POST])
|
||||||
|
.allow_headers([axum::http::header::CONTENT_TYPE, axum::http::header::ACCEPT])
|
||||||
|
.max_age(Duration::from_secs(3600))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c()
|
||||||
|
.await
|
||||||
|
.expect("failed to install Ctrl+C handler");
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("failed to install signal handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let terminate = std::future::pending::<()>();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => {
|
||||||
|
info!("Received Ctrl+C, starting graceful shutdown");
|
||||||
|
},
|
||||||
|
_ = terminate => {
|
||||||
|
info!("Received SIGTERM, starting graceful shutdown");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_loading() {
|
||||||
|
// Test that default config loads successfully
|
||||||
|
let config = AppConfig::default();
|
||||||
|
assert!(!config.debug_mode);
|
||||||
|
assert!(config.key_rotation_enabled);
|
||||||
|
}
|
||||||
|
}
|
||||||
66
src/metrics.rs
Normal file
66
src/metrics.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
use prometheus::{register_counter, register_gauge, register_histogram, Counter, Gauge, Histogram};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppMetrics {
|
||||||
|
pub requests_total: Counter,
|
||||||
|
pub successful_requests_total: Counter,
|
||||||
|
pub decryption_errors_total: Counter,
|
||||||
|
pub encryption_errors_total: Counter,
|
||||||
|
pub backend_errors_total: Counter,
|
||||||
|
pub key_requests_total: Counter,
|
||||||
|
pub request_duration: Histogram,
|
||||||
|
pub active_connections: Gauge,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppMetrics {
|
||||||
|
fn default() -> Self {
|
||||||
|
AppMetrics::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppMetrics {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
requests_total: register_counter!(
|
||||||
|
"ohttp_requests_total",
|
||||||
|
"Total number of OHTTP requests"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
successful_requests_total: register_counter!(
|
||||||
|
"ohttp_successful_requests_total",
|
||||||
|
"Total number of successful OHTTP requests"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
decryption_errors_total: register_counter!(
|
||||||
|
"ohttp_decryption_errors_total",
|
||||||
|
"Total number of decryption errors"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
encryption_errors_total: register_counter!(
|
||||||
|
"ohttp_encryption_errors_total",
|
||||||
|
"Total number of encryption errors"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
backend_errors_total: register_counter!(
|
||||||
|
"ohttp_backend_errors_total",
|
||||||
|
"Total number of backend errors"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
key_requests_total: register_counter!(
|
||||||
|
"ohttp_key_requests_total",
|
||||||
|
"Total number of key configuration requests"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
request_duration: register_histogram!(
|
||||||
|
"ohttp_request_duration_seconds",
|
||||||
|
"Duration of OHTTP request processing"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
active_connections: register_gauge!(
|
||||||
|
"ohttp_active_connections",
|
||||||
|
"Number of active connections"
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
56
src/middleware/logging.rs
Normal file
56
src/middleware/logging.rs
Normal file
|
|
@ -0,0 +1,56 @@
|
||||||
|
use axum::{body::Body, extract::Request, http::StatusCode, middleware::Next, response::Response};
|
||||||
|
use std::time::Instant;
|
||||||
|
use tracing::{info, warn, Instrument};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
pub async fn logging_middleware(
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
let request_id = Uuid::new_v4();
|
||||||
|
let method = request.method().clone();
|
||||||
|
let uri = request.uri().clone();
|
||||||
|
let user_agent = request
|
||||||
|
.headers()
|
||||||
|
.get("user-agent")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let span = tracing::info_span!(
|
||||||
|
"http_request",
|
||||||
|
request_id = %request_id,
|
||||||
|
method = %method,
|
||||||
|
uri = %uri,
|
||||||
|
user_agent = %user_agent
|
||||||
|
);
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
info!("Processing request");
|
||||||
|
|
||||||
|
let response = next.run(request).await;
|
||||||
|
|
||||||
|
let duration = start.elapsed();
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
info!(
|
||||||
|
status = %status,
|
||||||
|
duration_ms = duration.as_millis(),
|
||||||
|
"Request completed successfully"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
status = %status,
|
||||||
|
duration_ms = duration.as_millis(),
|
||||||
|
"Request failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
.instrument(span)
|
||||||
|
.await
|
||||||
|
}
|
||||||
17
src/middleware/metrics.rs
Normal file
17
src/middleware/metrics.rs
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
// Additional metrics middleware if needed
|
||||||
|
use crate::state::AppState;
|
||||||
|
use axum::{body::Body, extract::Request, extract::State, middleware::Next, response::Response};
|
||||||
|
|
||||||
|
pub async fn metrics_middleware(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response {
|
||||||
|
state.metrics.active_connections.inc();
|
||||||
|
|
||||||
|
let response = next.run(request).await;
|
||||||
|
|
||||||
|
state.metrics.active_connections.dec();
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
3
src/middleware/mod.rs
Normal file
3
src/middleware/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
pub mod logging;
|
||||||
|
pub mod metrics;
|
||||||
|
pub mod security;
|
||||||
188
src/middleware/security.rs
Normal file
188
src/middleware/security.rs
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::{ConnectInfo, Request, State},
|
||||||
|
http::{header, HeaderMap, StatusCode},
|
||||||
|
middleware::Next,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{config::RateLimitConfig, state::AppState};
|
||||||
|
|
||||||
|
/// Rate limiter implementation
|
||||||
|
pub struct RateLimiter {
|
||||||
|
config: RateLimitConfig,
|
||||||
|
buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TokenBucket {
|
||||||
|
tokens: f64,
|
||||||
|
last_update: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(config: RateLimitConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
buckets: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn check_rate_limit(&self, key: &str) -> bool {
|
||||||
|
let mut buckets = self.buckets.lock().await;
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let bucket = buckets
|
||||||
|
.entry(key.to_string())
|
||||||
|
.or_insert_with(|| TokenBucket {
|
||||||
|
tokens: self.config.burst_size as f64,
|
||||||
|
last_update: now,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Calculate tokens to add based on time elapsed
|
||||||
|
let elapsed = now.duration_since(bucket.last_update).as_secs_f64();
|
||||||
|
let tokens_to_add = elapsed * (self.config.requests_per_second as f64);
|
||||||
|
|
||||||
|
bucket.tokens = (bucket.tokens + tokens_to_add).min(self.config.burst_size as f64);
|
||||||
|
bucket.last_update = now;
|
||||||
|
|
||||||
|
// Check if we have tokens available
|
||||||
|
if bucket.tokens >= 1.0 {
|
||||||
|
bucket.tokens -= 1.0;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Security middleware that adds various security headers and checks
|
||||||
|
pub async fn security_middleware(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
// Generate request ID for tracing
|
||||||
|
let request_id = Uuid::new_v4();
|
||||||
|
|
||||||
|
// Add security headers to the request context
|
||||||
|
let mut request = request;
|
||||||
|
request
|
||||||
|
.headers_mut()
|
||||||
|
.insert("x-request-id", request_id.to_string().parse().unwrap());
|
||||||
|
|
||||||
|
let is_https = matches!(request.uri().scheme_str(), Some("https"));
|
||||||
|
|
||||||
|
// Apply rate limiting if configured
|
||||||
|
if let Some(rate_limit_config) = &state.config.rate_limit {
|
||||||
|
let rate_limiter = RateLimiter::new(rate_limit_config.clone());
|
||||||
|
|
||||||
|
let rate_limit_key = if rate_limit_config.by_ip {
|
||||||
|
addr.ip().to_string()
|
||||||
|
} else {
|
||||||
|
"global".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
if !rate_limiter.check_rate_limit(&rate_limit_key).await {
|
||||||
|
warn!(
|
||||||
|
"Rate limit exceeded for key: {}, request_id: {}",
|
||||||
|
rate_limit_key, request_id
|
||||||
|
);
|
||||||
|
|
||||||
|
return Ok((
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"X-RateLimit-Limit",
|
||||||
|
rate_limit_config.requests_per_second.to_string(),
|
||||||
|
),
|
||||||
|
("X-RateLimit-Remaining", "0".to_string()),
|
||||||
|
("Retry-After", "1".to_string()),
|
||||||
|
],
|
||||||
|
"Rate limit exceeded",
|
||||||
|
)
|
||||||
|
.into_response());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
let mut response = next.run(request).await;
|
||||||
|
|
||||||
|
// Add security headers to the response
|
||||||
|
let headers = response.headers_mut();
|
||||||
|
|
||||||
|
// Security headers
|
||||||
|
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
|
||||||
|
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
|
||||||
|
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
|
||||||
|
headers.insert("Referrer-Policy", "no-referrer".parse().unwrap());
|
||||||
|
headers.insert("X-Request-ID", request_id.to_string().parse().unwrap());
|
||||||
|
|
||||||
|
// HSTS header for HTTPS connections
|
||||||
|
if is_https {
|
||||||
|
headers.insert(
|
||||||
|
"Strict-Transport-Security",
|
||||||
|
"max-age=31536000; includeSubDomains".parse().unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content Security Policy
|
||||||
|
headers.insert(
|
||||||
|
"Content-Security-Policy",
|
||||||
|
"default-src 'none'; frame-ancestors 'none';"
|
||||||
|
.parse()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Remove sensitive headers
|
||||||
|
headers.remove("Server");
|
||||||
|
headers.remove("X-Powered-By");
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware to validate and sanitize incoming requests
|
||||||
|
pub async fn request_validation_middleware(
|
||||||
|
headers: HeaderMap,
|
||||||
|
request: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
// Check for required headers only on requests with bodies
|
||||||
|
if matches!(
|
||||||
|
request.method(),
|
||||||
|
&axum::http::Method::POST | &axum::http::Method::PUT | &axum::http::Method::PATCH
|
||||||
|
) && !headers.contains_key(header::CONTENT_TYPE)
|
||||||
|
{
|
||||||
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate User-Agent
|
||||||
|
if let Some(user_agent) = headers.get(header::USER_AGENT) {
|
||||||
|
if let Ok(ua_str) = user_agent.to_str() {
|
||||||
|
// Block known bad user agents
|
||||||
|
if ua_str.is_empty() || ua_str.contains("bot") || ua_str.contains("crawler") {
|
||||||
|
info!("Blocked suspicious user agent: {}", ua_str);
|
||||||
|
return Err(StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for suspicious headers that might indicate attacks
|
||||||
|
const SUSPICIOUS_HEADERS: &[&str] = &["x-forwarded-host", "x-original-url", "x-rewrite-url"];
|
||||||
|
|
||||||
|
for header_name in SUSPICIOUS_HEADERS {
|
||||||
|
if headers.contains_key(*header_name) {
|
||||||
|
warn!("Request contains suspicious header: {}", header_name);
|
||||||
|
return Err(StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
}
|
||||||
94
src/state.rs
Normal file
94
src/state.rs
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
use crate::{
|
||||||
|
config::AppConfig,
|
||||||
|
key_manager::{CipherSuiteConfig, KeyManager, KeyManagerConfig},
|
||||||
|
metrics::AppMetrics,
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub key_manager: Arc<KeyManager>,
|
||||||
|
pub http_client: reqwest::Client,
|
||||||
|
pub config: AppConfig,
|
||||||
|
pub metrics: AppMetrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub async fn new(config: AppConfig) -> Result<Self, Box<dyn std::error::Error>> {
|
||||||
|
// Configure key manager based on app config
|
||||||
|
let key_manager_config = KeyManagerConfig {
|
||||||
|
rotation_interval: config.key_rotation_interval,
|
||||||
|
key_retention_period: config.key_retention_period,
|
||||||
|
auto_rotation_enabled: config.key_rotation_enabled,
|
||||||
|
cipher_suites: get_cipher_suites(&config),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize key manager with or without seed
|
||||||
|
let key_manager = if let Some(seed_hex) = &config.seed_secret_key {
|
||||||
|
let seed = hex::decode(seed_hex)?;
|
||||||
|
Arc::new(KeyManager::new_with_seed(key_manager_config, seed).await?)
|
||||||
|
} else {
|
||||||
|
Arc::new(KeyManager::new(key_manager_config).await?)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create optimized HTTP client for backend requests
|
||||||
|
let http_client = create_http_client(&config)?;
|
||||||
|
|
||||||
|
let metrics = AppMetrics::default();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
key_manager,
|
||||||
|
http_client,
|
||||||
|
config,
|
||||||
|
metrics,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_cipher_suites(config: &AppConfig) -> Vec<CipherSuiteConfig> {
|
||||||
|
// Default cipher suites matching the Go implementation
|
||||||
|
let mut suites = vec![
|
||||||
|
CipherSuiteConfig {
|
||||||
|
kem: "X25519_SHA256".to_string(),
|
||||||
|
kdf: "HKDF_SHA256".to_string(),
|
||||||
|
aead: "AES_128_GCM".to_string(),
|
||||||
|
},
|
||||||
|
CipherSuiteConfig {
|
||||||
|
kem: "X25519_SHA256".to_string(),
|
||||||
|
kdf: "HKDF_SHA256".to_string(),
|
||||||
|
aead: "CHACHA20_POLY1305".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// Add high-security suite if in production mode
|
||||||
|
if !config.debug_mode {
|
||||||
|
suites.push(CipherSuiteConfig {
|
||||||
|
kem: "P256_SHA256".to_string(),
|
||||||
|
kdf: "HKDF_SHA256".to_string(),
|
||||||
|
aead: "AES_256_GCM".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
suites
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_http_client(config: &AppConfig) -> Result<reqwest::Client, Box<dyn std::error::Error>> {
|
||||||
|
let mut client_builder = reqwest::Client::builder()
|
||||||
|
.timeout(config.request_timeout)
|
||||||
|
.pool_max_idle_per_host(100)
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(30))
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(60))
|
||||||
|
.tcp_nodelay(true)
|
||||||
|
.user_agent("ohttp-gateway/1.0")
|
||||||
|
.danger_accept_invalid_certs(config.debug_mode); // Only in debug mode
|
||||||
|
|
||||||
|
// Configure proxy if needed
|
||||||
|
if let Ok(proxy_url) = std::env::var("HTTP_PROXY") {
|
||||||
|
client_builder = client_builder.proxy(reqwest::Proxy::http(proxy_url)?);
|
||||||
|
}
|
||||||
|
if let Ok(proxy_url) = std::env::var("HTTPS_PROXY") {
|
||||||
|
client_builder = client_builder.proxy(reqwest::Proxy::https(proxy_url)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(client_builder.build()?)
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue