diff --git a/eventstore-macros/src/lib.rs b/eventstore-macros/src/lib.rs index b85a3906..f3afd49c 100644 --- a/eventstore-macros/src/lib.rs +++ b/eventstore-macros/src/lib.rs @@ -92,9 +92,14 @@ pub fn options(input: TokenStream) -> TokenStream { } impl #name { - /// Performs the command with the given credentials. - pub fn authenticated(mut self, credentials: crate::types::Credentials) -> Self { - self.common_operation_options.credentials = Some(credentials); + /// Performs the command with the given authentication. + /// + /// Accepts `Credentials` or `Authentication` (the latter for Bearer tokens). + pub fn authenticated(mut self, authentication: A) -> Self + where + A: Into, + { + self.common_operation_options.authentication = Some(authentication.into()); self } diff --git a/kurrentdb/Cargo.toml b/kurrentdb/Cargo.toml index 24476b40..b8add70b 100755 --- a/kurrentdb/Cargo.toml +++ b/kurrentdb/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Yorick Laupa "] edition = "2024" name = "kurrentdb" -version = "1.0.0" +version = "1.1.0" # Uncomment if you want to update messages.rs code-gen. # We disabled codegen.rs because it requires having `protoc` installed on your machine diff --git a/kurrentdb/src/event_store/generated.rs b/kurrentdb/src/event_store/generated.rs index 6e4dfdb8..b3142d6d 100644 --- a/kurrentdb/src/event_store/generated.rs +++ b/kurrentdb/src/event_store/generated.rs @@ -8,6 +8,7 @@ use chrono::{DateTime, Utc}; use std::ops::Add; use std::time::{Duration, SystemTime}; +#[allow(dead_code)] pub mod common; pub mod google_rpc; pub mod gossip; diff --git a/kurrentdb/src/grpc.rs b/kurrentdb/src/grpc.rs index c6ddb02b..1016b503 100644 --- a/kurrentdb/src/grpc.rs +++ b/kurrentdb/src/grpc.rs @@ -949,10 +949,10 @@ impl NodeConnection { loop { if let Some(request) = request.take() { - if self.id != request.correlation { - if let Some(handle) = self.handle.clone() { - return Ok(handle); - } + if self.id != request.correlation + && let Some(handle) = self.handle.clone() + { + return Ok(handle); } failed_endpoint = self.handle.take().map(|h| h.endpoint); @@ -1439,7 +1439,7 @@ fn determine_best_node( let member_opt = members.min_by(|a, b| { if let NodePreference::Random = preference { - if rng.next_u32() % 2 == 0 { + if rng.next_u32().is_multiple_of(2) { return Ordering::Greater; } @@ -1447,7 +1447,7 @@ fn determine_best_node( } if preference.match_preference(&a.state) && preference.match_preference(&b.state) { - if rng.next_u32() % 2 == 0 { + if rng.next_u32().is_multiple_of(2) { return Ordering::Less; } else { return Ordering::Greater; diff --git a/kurrentdb/src/http/mod.rs b/kurrentdb/src/http/mod.rs index c50bd307..f06bfc54 100644 --- a/kurrentdb/src/http/mod.rs +++ b/kurrentdb/src/http/mod.rs @@ -1,17 +1,31 @@ use tracing::error; pub mod persistent_subscriptions; +pub(crate) fn resolve_authentication( + options: &crate::options::CommonOperationOptions, + settings: &crate::ClientSettings, +) -> Option { + options.authentication.clone().or_else(|| { + settings + .default_authenticated_user() + .as_ref() + .map(|c| crate::Authentication::Basic(c.clone())) + }) +} + pub fn http_configure_auth( builder: reqwest::RequestBuilder, - creds_opt: Option<&crate::Credentials>, + auth_opt: Option<&crate::Authentication>, ) -> reqwest::RequestBuilder { - if let Some(creds) = creds_opt { - builder.basic_auth( - unsafe { std::str::from_utf8_unchecked(creds.login.as_ref()) }, - unsafe { Some(std::str::from_utf8_unchecked(creds.password.as_ref())) }, - ) - } else { - builder + match auth_opt { + Some(crate::Authentication::Basic(creds)) => builder.basic_auth( + String::from_utf8_lossy(creds.login.as_ref()), + Some(String::from_utf8_lossy(creds.password.as_ref())), + ), + Some(crate::Authentication::Bearer(token)) => { + builder.bearer_auth(String::from_utf8_lossy(token.as_ref())) + } + None => builder, } } @@ -66,3 +80,81 @@ pub async fn http_execute_request( } } } + +#[cfg(test)] +mod auth_tests { + use super::*; + use crate::options::CommonOperationOptions; + use crate::{Authentication, ClientSettings, Credentials}; + + fn settings_from(connection_string: &str) -> ClientSettings { + connection_string + .parse::() + .expect("valid connection string") + } + + fn authorization_header(builder: reqwest::RequestBuilder) -> Option { + let request = builder.build().expect("buildable request"); + request + .headers() + .get(reqwest::header::AUTHORIZATION) + .map(|v| v.to_str().expect("ASCII header").to_owned()) + } + + fn fresh_builder() -> reqwest::RequestBuilder { + reqwest::Client::new().get("http://localhost/") + } + + #[test] + fn http_configure_auth_with_basic_sets_basic_authorization_header() { + let auth = Authentication::basic("admin", "changeit"); + let header = authorization_header(http_configure_auth(fresh_builder(), Some(&auth))) + .expect("authorization header present"); + assert_eq!(header, "Basic YWRtaW46Y2hhbmdlaXQ="); + } + + #[test] + fn http_configure_auth_with_bearer_sets_bearer_authorization_header() { + let auth = Authentication::bearer("abc.def.ghi"); + let header = authorization_header(http_configure_auth(fresh_builder(), Some(&auth))) + .expect("authorization header present"); + assert_eq!(header, "Bearer abc.def.ghi"); + } + + #[test] + fn http_configure_auth_with_none_leaves_authorization_unset() { + assert!(authorization_header(http_configure_auth(fresh_builder(), None)).is_none()); + } + + #[test] + fn resolve_authentication_prefers_per_call_over_default_user() { + let settings = settings_from("esdb://admin:changeit@localhost:2113?tls=false"); + let common = CommonOperationOptions { + authentication: Some(Authentication::bearer("call-token")), + ..Default::default() + }; + + let resolved = resolve_authentication(&common, &settings).expect("present"); + assert_eq!(resolved, Authentication::bearer("call-token")); + } + + #[test] + fn resolve_authentication_falls_back_to_default_user_as_basic() { + let settings = settings_from("esdb://admin:changeit@localhost:2113?tls=false"); + let common = CommonOperationOptions::default(); + + let resolved = resolve_authentication(&common, &settings).expect("present"); + assert_eq!( + resolved, + Authentication::Basic(Credentials::new("admin", "changeit")) + ); + } + + #[test] + fn resolve_authentication_returns_none_when_neither_configured() { + let settings = settings_from("esdb://localhost:2113?tls=false"); + let common = CommonOperationOptions::default(); + + assert!(resolve_authentication(&common, &settings).is_none()); + } +} diff --git a/kurrentdb/src/http/persistent_subscriptions.rs b/kurrentdb/src/http/persistent_subscriptions.rs index 6254e234..e0951317 100644 --- a/kurrentdb/src/http/persistent_subscriptions.rs +++ b/kurrentdb/src/http/persistent_subscriptions.rs @@ -33,14 +33,8 @@ pub(crate) async fn replay_parked_messages( builder = builder.query(&[("stopAt", stop_at.to_string().as_str())]) } - builder = super::http_configure_auth( - builder, - options - .common_operation_options - .credentials - .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()), - ); + let auth = super::resolve_authentication(&options.common_operation_options, settings); + builder = super::http_configure_auth(builder, auth.as_ref()); super::http_execute_request(builder).await?; @@ -132,14 +126,8 @@ pub(crate) async fn list_all_persistent_subscriptions( .get(format!("{}/subscriptions", handle.url())) .header("content-type", "application/json"); - builder = super::http_configure_auth( - builder, - options - .common_operation_options - .credentials - .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()), - ); + let auth = super::resolve_authentication(&options.common_operation_options, settings); + builder = super::http_configure_auth(builder, auth.as_ref()); let resp = super::http_execute_request(builder).await?; @@ -179,14 +167,8 @@ where )) .header("content-type", "application/json"); - builder = super::http_configure_auth( - builder, - options - .common_operation_options - .credentials - .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()), - ); + let auth = super::resolve_authentication(&options.common_operation_options, settings); + builder = super::http_configure_auth(builder, auth.as_ref()); let resp = super::http_execute_request(builder).await?; @@ -228,14 +210,8 @@ where )) .header("content-type", "application/json"); - builder = super::http_configure_auth( - builder, - options - .common_operation_options - .credentials - .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()), - ); + let auth = super::resolve_authentication(&options.common_operation_options, settings); + builder = super::http_configure_auth(builder, auth.as_ref()); let resp = super::http_execute_request(builder).await?; @@ -258,14 +234,8 @@ pub(crate) async fn restart_persistent_subscription_subsystem( .header("content-type", "application/json") .header("content-length", "0"); - builder = super::http_configure_auth( - builder, - options - .common_operation_options - .credentials - .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()), - ); + let auth = super::resolve_authentication(&options.common_operation_options, settings); + builder = super::http_configure_auth(builder, auth.as_ref()); super::http_execute_request(builder).await?; diff --git a/kurrentdb/src/operations/gossip.rs b/kurrentdb/src/operations/gossip.rs index 799b4b90..aeae0f37 100644 --- a/kurrentdb/src/operations/gossip.rs +++ b/kurrentdb/src/operations/gossip.rs @@ -78,9 +78,14 @@ pub(crate) async fn http_read( .danger_accept_invalid_certs(!setts.tls_verify_cert) .build()?; + let default_auth = setts + .default_user_name + .as_ref() + .map(|c| crate::Authentication::Basic(c.clone())); + let resp = http_configure_auth( client.get(format!("{}/gossip", handle.url())), - setts.default_user_name.as_ref(), + default_auth.as_ref(), ) .send() .await?; diff --git a/kurrentdb/src/options/mod.rs b/kurrentdb/src/options/mod.rs index 57693d88..6670cdcd 100644 --- a/kurrentdb/src/options/mod.rs +++ b/kurrentdb/src/options/mod.rs @@ -1,6 +1,6 @@ use std::time::Duration; -use crate::Credentials; +use crate::Authentication; pub mod append_to_stream; pub mod batch_append; @@ -21,7 +21,7 @@ pub(crate) trait Options { #[derive(Clone, Default)] pub(crate) struct CommonOperationOptions { - pub(crate) credentials: Option, + pub(crate) authentication: Option, pub(crate) requires_leader: bool, pub(crate) deadline: Option, } diff --git a/kurrentdb/src/request.rs b/kurrentdb/src/request.rs index a02a2f67..4bc51af0 100644 --- a/kurrentdb/src/request.rs +++ b/kurrentdb/src/request.rs @@ -1,6 +1,7 @@ use crate::options::CommonOperationOptions; -use crate::{ClientSettings, NodePreference}; +use crate::{Authentication, ClientSettings, Credentials, NodePreference}; use base64::Engine; +use std::borrow::Cow; pub(crate) fn build_request_metadata( settings: &ClientSettings, @@ -11,21 +12,21 @@ where use tonic::metadata::MetadataValue; let mut metadata = tonic::metadata::MetadataMap::new(); - let credentials = options - .credentials + let authentication: Option> = options + .authentication .as_ref() - .or_else(|| settings.default_authenticated_user().as_ref()); - - if let Some(creds) = credentials { - let login = String::from_utf8_lossy(&creds.login).into_owned(); - let password = String::from_utf8_lossy(&creds.password).into_owned(); - - let basic_auth_string = - base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", login, password)); - let basic_auth = format!("Basic {}", basic_auth_string); - let header_value = MetadataValue::try_from(basic_auth.as_str()) - .expect("Auth header value should be valid metadata header value"); + .map(Cow::Borrowed) + .or_else(|| { + settings + .default_authenticated_user() + .as_ref() + .map(|c| Cow::Owned(Authentication::Basic(c.clone()))) + }); + if let Some(header_value) = authentication + .as_deref() + .and_then(build_authorization_header) + { metadata.insert("authorization", header_value); } @@ -42,3 +43,127 @@ where metadata } + +fn build_authorization_header( + auth: &Authentication, +) -> Option> { + use tonic::metadata::MetadataValue; + + let header = match auth { + Authentication::Basic(Credentials { login, password }) => { + let login = String::from_utf8_lossy(login); + let password = String::from_utf8_lossy(password); + let encoded = + base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", login, password)); + format!("Basic {}", encoded) + } + Authentication::Bearer(token) => { + let token = String::from_utf8_lossy(token); + format!("Bearer {}", token) + } + }; + + match MetadataValue::try_from(header.as_str()) { + Ok(value) => Some(value), + Err(_) => { + tracing::warn!( + auth_kind = auth.kind(), + "authentication value contains characters that are not valid in a gRPC metadata header; the Authorization header will be omitted" + ); + None + } + } +} + +#[cfg(test)] +mod auth_tests { + use super::*; + use crate::AppendToStreamOptions; + use crate::options::Options; + + fn settings_from(connection_string: &str) -> ClientSettings { + connection_string + .parse::() + .expect("valid connection string") + } + + #[test] + fn basic_authentication_produces_base64_basic_header() { + let auth = Authentication::basic("admin", "changeit"); + let header = build_authorization_header(&auth).expect("ASCII header"); + assert_eq!(header.to_str().unwrap(), "Basic YWRtaW46Y2hhbmdlaXQ="); + } + + #[test] + fn bearer_authentication_produces_bearer_header_verbatim() { + let auth = Authentication::bearer("abc.def.ghi"); + let header = build_authorization_header(&auth).expect("ASCII header"); + assert_eq!(header.to_str().unwrap(), "Bearer abc.def.ghi"); + } + + #[test] + fn basic_authentication_with_special_chars_encodes_correctly() { + let auth = Authentication::basic("user@example.com", "p@ss:word"); + let header = build_authorization_header(&auth).expect("ASCII header"); + assert_eq!( + header.to_str().unwrap(), + "Basic dXNlckBleGFtcGxlLmNvbTpwQHNzOndvcmQ=" + ); + } + + #[test] + fn build_request_metadata_skips_bearer_token_with_invalid_chars() { + let settings = settings_from("esdb://localhost:2113?tls=false"); + let options = + AppendToStreamOptions::default().authenticated(Authentication::bearer("token\nleak")); + let metadata = build_request_metadata(&settings, options.common_operation_options()); + assert!(metadata.get("authorization").is_none()); + } + + #[test] + fn no_auth_anywhere_produces_no_authorization_header() { + let settings = settings_from("esdb://localhost:2113?tls=false"); + let options = AppendToStreamOptions::default(); + let metadata = build_request_metadata(&settings, options.common_operation_options()); + + assert!(metadata.get("authorization").is_none()); + } + + #[test] + fn default_user_from_connection_string_falls_through_as_basic() { + let settings = settings_from("esdb://admin:changeit@localhost:2113?tls=false"); + let options = AppendToStreamOptions::default(); + let metadata = build_request_metadata(&settings, options.common_operation_options()); + + assert_eq!( + metadata.get("authorization").unwrap().to_str().unwrap(), + "Basic YWRtaW46Y2hhbmdlaXQ=" + ); + } + + #[test] + fn per_call_bearer_overrides_default_user() { + let settings = settings_from("esdb://admin:changeit@localhost:2113?tls=false"); + let options = + AppendToStreamOptions::default().authenticated(Authentication::bearer("call-token")); + let metadata = build_request_metadata(&settings, options.common_operation_options()); + + assert_eq!( + metadata.get("authorization").unwrap().to_str().unwrap(), + "Bearer call-token" + ); + } + + #[test] + fn authenticated_builder_accepts_credentials_directly() { + let settings = settings_from("esdb://localhost:2113?tls=false"); + let options = + AppendToStreamOptions::default().authenticated(Credentials::new("alice", "secret")); + let metadata = build_request_metadata(&settings, options.common_operation_options()); + + assert_eq!( + metadata.get("authorization").unwrap().to_str().unwrap(), + "Basic YWxpY2U6c2VjcmV0" + ); + } +} diff --git a/kurrentdb/src/types.rs b/kurrentdb/src/types.rs index 097359cb..9bd54b73 100755 --- a/kurrentdb/src/types.rs +++ b/kurrentdb/src/types.rs @@ -53,6 +53,49 @@ impl Credentials { } } +/// Authentication mode used when sending a request to KurrentDB. +/// +/// Supports HTTP Basic auth (login + password) and Bearer token auth (e.g. an +/// OAuth/OIDC access token). `Authentication` implements `From`, +/// so any API that accepts `impl Into` continues to accept a +/// plain `Credentials` value unchanged. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Authentication { + Basic(Credentials), + + /// Sent verbatim in the `Authorization: Bearer ` header. + Bearer(Bytes), +} + +impl Authentication { + pub fn basic(login: S, password: S) -> Self + where + S: Into, + { + Authentication::Basic(Credentials::new(login, password)) + } + + pub fn bearer(token: S) -> Self + where + S: Into, + { + Authentication::Bearer(token.into()) + } + + pub(crate) fn kind(&self) -> &'static str { + match self { + Authentication::Basic(_) => "basic", + Authentication::Bearer(_) => "bearer", + } + } +} + +impl From for Authentication { + fn from(credentials: Credentials) -> Self { + Authentication::Basic(credentials) + } +} + struct CredsVisitor; impl<'de> Visitor<'de> for CredsVisitor {