From 88dcf42286213a6fff4f97b99fbd292b63745a17 Mon Sep 17 00:00:00 2001 From: LeNei Date: Fri, 12 May 2023 20:43:55 +0200 Subject: [PATCH] add base auth implementation --- Cargo.lock | 352 ++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 3 + src/auth/claims.rs | 112 +++++++++++++++ src/auth/mod.rs | 25 ++++ src/auth/token.rs | 61 ++++++++ src/config/jwks.rs | 158 ++++++++++++++++++++ src/config/mod.rs | 1 + src/lib.rs | 1 + src/main.rs | 2 +- src/routes/mod.rs | 19 ++- src/startup.rs | 3 + 11 files changed, 731 insertions(+), 6 deletions(-) create mode 100644 src/auth/claims.rs create mode 100644 src/auth/mod.rs create mode 100644 src/auth/token.rs create mode 100644 src/config/jwks.rs diff --git a/Cargo.lock b/Cargo.lock index 8bd2656..397efde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,12 +127,15 @@ dependencies = [ "config", "http", "hyper", + "jsonwebtoken", "log", + "reqwest", "secrecy", "serde", "serde-aux", "serde_json", "sqlx", + "thiserror", "tokio", "tower-http", "tracing", @@ -238,6 +241,16 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.4" @@ -387,18 +400,72 @@ dependencies = [ "serde", ] +[[package]] +name = "encoding_rs" +version = "0.8.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "event-listener" version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.1.0" @@ -546,6 +613,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "hex" version = "0.4.3" @@ -634,6 +707,19 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "iana-time-zone" version = "0.1.56" @@ -687,6 +773,23 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "io-lifetimes" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "ipnet" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" + [[package]] name = "itertools" version = "0.10.5" @@ -711,6 +814,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "8.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" +dependencies = [ + "base64 0.21.0", + "ring", + "serde", + "serde_json", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -738,6 +853,12 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "linux-raw-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece97ea872ece730aed82664c424eb4c8291e1ff2480247ccf7409044bc6479f" + [[package]] name = "lock_api" version = "0.4.9" @@ -811,6 +932,24 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nom" version = "7.1.3" @@ -856,7 +995,7 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" dependencies = [ - "hermit-abi", + "hermit-abi 0.2.6", "libc", ] @@ -866,6 +1005,50 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +[[package]] +name = "openssl" +version = "0.10.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e17f59264b2809d77ae94f0e1ebabc434773f370d6ca667bd223ea10e06cc7e" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "overload" version = "0.1.1" @@ -892,7 +1075,7 @@ dependencies = [ "cfg-if", "instant", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "smallvec", "winapi", ] @@ -947,6 +1130,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1010,6 +1199,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_users" version = "0.4.3" @@ -1017,7 +1215,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", - "redox_syscall", + "redox_syscall 0.2.16", "thiserror", ] @@ -1051,6 +1249,43 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +[[package]] +name = "reqwest" +version = "0.11.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13293b639a097af28fc8a90f22add145a9c954e49d77da06263d58cf44d5fb91" +dependencies = [ + "base64 0.21.0", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-tls", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "ring" version = "0.16.20" @@ -1066,6 +1301,20 @@ dependencies = [ "winapi", ] +[[package]] +name = "rustix" +version = "0.37.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 0.48.0", +] + [[package]] name = "rustls" version = "0.20.8" @@ -1099,6 +1348,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -1131,6 +1389,29 @@ dependencies = [ "zeroize", ] +[[package]] +name = "security-framework" +version = "2.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.162" @@ -1409,6 +1690,19 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "tempfile" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.45.0", +] + [[package]] name = "termcolor" version = "1.2.0" @@ -1529,6 +1823,16 @@ dependencies = [ "syn 2.0.15", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.23.4" @@ -1773,6 +2077,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" @@ -1826,6 +2136,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.84" @@ -1934,6 +2256,21 @@ dependencies = [ "windows-targets 0.48.0", ] +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -2066,6 +2403,15 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] + [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index 8999e4d..c789b04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,6 @@ tower-http = { version = "0.4.0", features = ["trace", "cors"] } http = "0.2" hyper = { version = "0.14", features = ["full"] } anyhow = "1.0" +jsonwebtoken = {version = "8", default-features = false } +thiserror = "1.0" +reqwest = { version = "0.11", features = ["json"] } diff --git a/src/auth/claims.rs b/src/auth/claims.rs new file mode 100644 index 0000000..8724161 --- /dev/null +++ b/src/auth/claims.rs @@ -0,0 +1,112 @@ +use crate::routes::ApiContext; + +use super::token::{Token, TokenError}; +use axum::{ + async_trait, + extract::{FromRef, FromRequestParts}, + http::request::Parts, + response::IntoResponse, +}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +pub(crate) struct Claims(pub C); + +/// Trait indicating that the type can be parsed from a request. +/// +/// Implementing this trait for your claims data means it can be used as an +/// [extractor][axum::extract] in your request handlers. Assuming you have a +/// struct `TokenClaims` with the attributes you want to parse from a JWT, +/// implementing [`ParseTokenClaims`] will allow you to write a request handler +/// like: +/// ```ignore +/// async fn my_request_handler(Claims(claims): Claims) -> Response { +/// todo!() +/// } +/// ``` +/// +/// The alternative to implementing this trait is to implement +/// [`FromRequestParts`][axum::extract::FromRequestParts] directly for your +/// token claims. +/// +/// # Example +/// ``` +/// use axum::{ +/// http::status::StatusCode, +/// response::{IntoResponse, Response}, +/// Json, +/// }; +/// use axum_jwks::{ParseTokenClaims, TokenError}; +/// use serde::Deserialize; +/// use serde_json::json; +/// +/// #[derive(Deserialize)] +/// struct TokenClaims { +/// sub: String, +/// } +/// +/// impl ParseTokenClaims for TokenClaims { +/// type Rejection = TokenClaimsError; +/// } +/// +/// enum TokenClaimsError { +/// Invalid, +/// Missing, +/// } +/// +/// impl From for TokenClaimsError { +/// fn from(error: TokenError) -> Self { +/// match error { +/// TokenError::Missing => Self::Missing, +/// other => Self::Invalid, +/// } +/// } +/// } +/// +/// impl IntoResponse for TokenClaimsError { +/// fn into_response(self) -> Response { +/// let body = match self { +/// Self::Invalid => json!({ "message": "Invalid token." }), +/// Self::Missing => json!({ "message": "No token provided." }), +/// }; +/// +/// (StatusCode::UNAUTHORIZED, Json(body)).into_response() +/// } +/// } +/// ``` +pub(crate) trait ParseTokenClaims { + /// The type of error returned if the token claims cannot be parsed and + /// validated from the request. + type Rejection: IntoResponse + From; +} + +#[async_trait] +impl FromRequestParts for Claims +where + C: DeserializeOwned + ParseTokenClaims, + ApiContext: FromRef, + S: Send + Sync, +{ + type Rejection = C::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let jwks = ApiContext::from_ref(state).jwks; + let token = Token::from_request_parts(parts)?; + + let token_data = jwks.validate_claims(token.value())?; + + Ok(Claims(token_data.claims)) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub(crate) struct KeycloakClaims { + pub sub: String, + exp: i64, + iat: i64, + email: String, +} + +impl ParseTokenClaims for KeycloakClaims { + type Rejection = TokenError; +} diff --git a/src/auth/mod.rs b/src/auth/mod.rs new file mode 100644 index 0000000..5a06084 --- /dev/null +++ b/src/auth/mod.rs @@ -0,0 +1,25 @@ +pub mod claims; +pub mod token; + +use crate::routes::ApiContext; +use axum::extract::State; +use axum::http::Request; +use axum::middleware::Next; +use axum::response::Response; +use claims::{Claims, KeycloakClaims}; +use hyper::StatusCode; + +pub(crate) async fn auth_user( + State(ctx): State, + Claims(claims): Claims, + mut req: Request, + next: Next, +) -> Result { + /* + let user = AuthUser::get_from_db(&claims.sub, &ctx.db, None) + .await + .map_err(|_| StatusCode::UNAUTHORIZED)?; + req.extensions_mut().insert(user); + */ + Ok(next.run(req).await) +} diff --git a/src/auth/token.rs b/src/auth/token.rs new file mode 100644 index 0000000..f9ad77b --- /dev/null +++ b/src/auth/token.rs @@ -0,0 +1,61 @@ +use axum::http::request::Parts; +use axum::response::IntoResponse; +use hyper::header::AUTHORIZATION; +use hyper::StatusCode; +use thiserror::Error; + +/// A JWT provided as a bearer token in an `Authorization` header. +#[derive(PartialEq)] +pub(crate) struct Token(String); + +impl Token { + pub fn value(&self) -> &str { + &self.0 + } + + pub fn from_request_parts(parts: &mut Parts) -> Result { + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(TokenError::Missing)? + .to_str() + .map_err(|_| TokenError::Missing)?; + + let token = auth_header + .strip_prefix("Bearer ") + .ok_or(TokenError::Missing)?; + + Ok(Token(token.to_string())) + } +} + +/// An error with a JWT. +#[derive(Debug, Error, PartialEq)] +pub(crate) enum TokenError { + /// The token is either malformed or did not pass validation. + #[error("the token is invalid or malformed: {0:?}")] + Invalid(jsonwebtoken::errors::Error), + + /// The token header could not be decoded because it was malformed. + #[error("the token header is malformed: {0:?}")] + InvalidHeader(jsonwebtoken::errors::Error), + + /// No bearer token found in the `Authorization` header. + #[error("no bearer token found")] + Missing, + + /// The token's header does not contain the `kid` attribute used to identify + /// which decoding key should be used. + #[error("the token header does not specify a `kid`")] + MissingKeyId, + + /// The token's `kid` attribute specifies a key that is unknown. + #[error("token uses the unknown key {0:?}")] + UnknownKeyId(String), +} + +impl IntoResponse for TokenError { + fn into_response(self) -> axum::response::Response { + StatusCode::UNAUTHORIZED.into_response() + } +} diff --git a/src/config/jwks.rs b/src/config/jwks.rs new file mode 100644 index 0000000..17caa81 --- /dev/null +++ b/src/config/jwks.rs @@ -0,0 +1,158 @@ +use crate::auth::token::TokenError; +use jsonwebtoken::{ + decode, decode_header, + jwk::{self, AlgorithmParameters}, + DecodingKey, TokenData, Validation, +}; +use serde::de::DeserializeOwned; +use std::collections::HashMap; +use thiserror::Error; +use tracing::{debug, info}; + +/// A container for a set of JWT decoding keys. +/// +/// The container can be used to validate any JWT that identifies a known key +/// through the `kid` attribute in the token's header. +#[derive(Clone)] +pub(crate) struct Jwks { + keys: HashMap, +} + +impl Jwks { + /// Pull a JSON Web Key Set from a specific authority. + pub async fn from_authority(authority: &str, audience: String) -> Result { + Self::from_authority_with_client(&reqwest::Client::default(), authority, audience).await + } + + /// A version of [`from_authority`][Self::from_authority] that allows for + /// passing in a custom [`Client`][reqwest::Client]. + pub async fn from_authority_with_client( + client: &reqwest::Client, + authority: &str, + audience: String, + ) -> Result { + let jwks_url = format!("{}/protocol/openid-connect/certs", authority); + debug!(%authority, %jwks_url, "Fetching JSON Web Key Set."); + let jwks: jwk::JwkSet = client.get(jwks_url).send().await?.json().await?; + + info!( + %authority, + count = jwks.keys.len(), + "Successfully pulled JSON Web Key Set." + ); + + let mut keys = HashMap::new(); + for jwk in jwks.keys { + let kid = jwk.common.key_id.ok_or(JwkError::MissingKeyId)?; + + match &jwk.algorithm { + jwk::AlgorithmParameters::RSA(rsa) => { + let decoding_key = + DecodingKey::from_rsa_components(&rsa.n, &rsa.e).map_err(|err| { + JwkError::DecodingError { + key_id: kid.clone(), + error: err, + } + })?; + let mut validation = Validation::new(jwk.common.algorithm.ok_or( + JwkError::MissingAlgorithm { + key_id: kid.clone(), + }, + )?); + validation.set_audience(&[audience.clone()]); + + keys.insert( + kid, + Jwk { + decoding: decoding_key, + validation, + }, + ); + } + _ => { + info!(%kid, "Ignoring unsupported key.") + } + } + } + + Ok(Self { keys }) + } + + pub fn validate_claims(&self, token: &str) -> Result, TokenError> + where + T: DeserializeOwned, + { + let header = decode_header(token).map_err(|error| { + debug!(?error, "Received token with invalid header."); + + TokenError::InvalidHeader(error) + })?; + let kid = header.kid.as_ref().ok_or_else(|| { + debug!(?header, "Header is missing the `kid` attribute."); + + TokenError::MissingKeyId + })?; + + let key = self.keys.get(kid).ok_or_else(|| { + debug!(%kid, "Token refers to an unknown key."); + + TokenError::UnknownKeyId(kid.to_owned()) + })?; + + let decoded_token: TokenData = + decode(token, &key.decoding, &key.validation).map_err(|error| { + println!("error: {:?}", error); + debug!(?error, "Token is malformed or does not pass validation."); + + TokenError::Invalid(error) + })?; + + Ok(decoded_token) + } +} + +#[derive(Clone)] +struct Jwk { + decoding: DecodingKey, + validation: Validation, +} + +/// An error with the overall set of JSON Web Keys. +#[derive(Debug, Error)] +pub(crate) enum JwksError { + /// There was an error fetching the JWKS from the specified authority. + #[error("could not fetch JWKS from authority: {0}")] + FetchError(#[from] reqwest::Error), + + /// An error with an individual key caused the processing of the JWKS to + /// fail. + #[error("there was an error with an individual key: {0}")] + KeyError(#[from] JwkError), +} + +/// An error with a specific key from a JWKS. +#[derive(Debug, Error)] +pub(crate) enum JwkError { + /// There was an error constructing the decoding key from the RSA components + /// provided by the key. + #[error("could not construct a decoding key for {key_id:?}: {error:?}")] + DecodingError { + key_id: String, + error: jsonwebtoken::errors::Error, + }, + + /// The key does not specify an algorithm to use. + #[error("the key {key_id:?} does not specify an algorithm")] + MissingAlgorithm { key_id: String }, + + /// The key is missing the `kid` attribute. + #[error("the key is missing the `kid` attribute")] + MissingKeyId, + + /// The key uses an unexpected algorithm type. + #[error("the key {key_id:?} uses a non-RSA algorithm {algorithm:?}")] + UnexpectedAlgorithm { + algorithm: AlgorithmParameters, + key_id: String, + }, +} diff --git a/src/config/mod.rs b/src/config/mod.rs index f709818..a32589a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,5 +1,6 @@ pub mod app; pub mod database; +pub mod jwks; pub mod logging; use app::ApplicationSettings; diff --git a/src/lib.rs b/src/lib.rs index f083a6a..6d8dae7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod config; pub mod routes; pub mod startup; diff --git a/src/main.rs b/src/main.rs index 0eca77b..dcdaa9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,7 @@ use axum_sqlx_template::startup::build; #[tokio::main] async fn main() -> anyhow::Result<()> { - let subscriber = get_subscriber("api".into(), "info".into(), std::io::stdout); + let subscriber = get_subscriber("api".into(), "debug".into(), std::io::stdout); init_subscriber(subscriber); let configuration = get_configuration().expect("Failed to read configuration."); diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 2c644c5..e536acc 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,3 +1,4 @@ +use axum::middleware; use axum::{routing::get, Router}; use http::Method; use hyper::StatusCode; @@ -5,17 +6,21 @@ use sqlx::PgPool; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; +use crate::auth::auth_user; +use crate::config::jwks::Jwks; + #[tracing::instrument(name = "Ping")] async fn ping() -> StatusCode { StatusCode::OK } #[derive(Clone)] -pub struct ApiContext { +pub(crate) struct ApiContext { pub db: PgPool, + pub jwks: Jwks, } -pub fn build_routes(api_context: ApiContext) -> Router { +pub(crate) fn build_routes(api_context: ApiContext) -> Router { let cors = CorsLayer::new() // allow `GET` and `POST` when accessing the resource .allow_methods([Method::GET, Method::POST]) @@ -23,7 +28,17 @@ pub fn build_routes(api_context: ApiContext) -> Router { .allow_origin(Any); Router::new() .route("/ping", get(ping)) + .nest("/auth", build_auth_routes(&api_context)) .layer(TraceLayer::new_for_http()) .layer(cors) .with_state(api_context) } + +fn build_auth_routes(api_context: &ApiContext) -> Router { + Router::new() + .route("/ping", get(ping)) + .route_layer(middleware::from_fn_with_state( + api_context.clone(), + auth_user, + )) +} diff --git a/src/startup.rs b/src/startup.rs index 7640893..eb9c5e9 100644 --- a/src/startup.rs +++ b/src/startup.rs @@ -1,3 +1,4 @@ +use crate::config::jwks::Jwks; use crate::config::Settings; use crate::routes::{build_routes, ApiContext}; use anyhow::Context; @@ -7,6 +8,8 @@ use std::net::TcpListener; pub async fn build(settings: Settings) -> anyhow::Result<()> { let api_context = ApiContext { db: settings.database.get_connection_pool(), + jwks: Jwks::from_authority("http://localhost:8088/realms/test", "account".to_string()) + .await?, }; let api_router = build_routes(api_context); let address = format!(