diff --git a/chaperone.go b/chaperone.go index 357ed44..1fdedaa 100644 --- a/chaperone.go +++ b/chaperone.go @@ -243,6 +243,7 @@ func newProxyServer(plugin sdk.Plugin, rc *runConfig, cfg *config.Config, tracin IdleTimeout: *cfg.Upstream.Timeouts.Idle, TracingEnabled: tracingEnabled, LogTargetAddrMode: cfg.Observability.LogTargetAddr, + ForwardTargets: cfg.ForwardTargets, }) if err != nil { return nil, fmt.Errorf("creating proxy server: %w", err) diff --git a/chaperone_test.go b/chaperone_test.go index efb2ea1..7ba91ee 100644 --- a/chaperone_test.go +++ b/chaperone_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/cloudblue/chaperone/internal/config" "github.com/cloudblue/chaperone/internal/telemetry" "github.com/cloudblue/chaperone/sdk" ) @@ -590,3 +591,72 @@ func TestRun_ProxyPortInUse_CleansUpAdminServer(t *testing.T) { adminListener.Close() } } + +func TestNewProxyServer_PropagatesForwardTargetsFromConfig(t *testing.T) { + // Verify that ForwardTargets from config.Config are properly wired into + // proxy.Config during newProxyServer construction. This test validates that + // YAML forward_targets are propagated to the runtime proxy configuration. + if testing.Short() { + t.Skip("skipping test in short mode") + } + + serverAddr := freeAddr(t) + adminAddr := freeAddr(t) + configPath := t.TempDir() + "/config-with-forward.yaml" + + // Write config with one forward target + cfgContent := fmt.Sprintf(`server: + addr: "%s" + admin_addr: "%s" + tls: + enabled: false +upstream: + header_prefix: "X-Connect" + allow_list: + "example.com": + - "/api/**" +forward_targets: + my-target: + url: "https://my-target.example/api" + auth: + type: "bearer" + token: "secret-token" +`, serverAddr, adminAddr) + + if err := os.WriteFile(configPath, []byte(cfgContent), 0o600); err != nil { + t.Fatalf("failed to write test config: %v", err) + } + + // Load config using the same path as Run() would + cfg, err := config.Load(configPath) + if err != nil { + t.Fatalf("failed to load config: %v", err) + } + + // Verify the config loaded the forward target + if len(cfg.ForwardTargets) != 1 { + t.Fatalf("expected 1 forward target in loaded config, got %d", len(cfg.ForwardTargets)) + } + + // Call newProxyServer which should propagate ForwardTargets + runCfg := &runConfig{version: "test"} + proxySrv, err := newProxyServer(nil, runCfg, cfg, false) + if err != nil { + t.Fatalf("newProxyServer failed: %v", err) + } + + // Verify the forward target was propagated by checking the server's config + srvCfg := proxySrv.Config() + if len(srvCfg.ForwardTargets) != 1 { + t.Errorf("proxy.Config.ForwardTargets length = %d, want 1", len(srvCfg.ForwardTargets)) + } + + target, exists := srvCfg.ForwardTargets["my-target"] + if !exists { + t.Fatal("expected forward target 'my-target' in proxy.Config.ForwardTargets") + } + + if target.URL != "https://my-target.example/api" { + t.Errorf("forward target URL = %q, want %q", target.URL, "https://my-target.example/api") + } +} diff --git a/docs/guides/plugin-development.md b/docs/guides/plugin-development.md index 5321b35..a5b8a7b 100644 --- a/docs/guides/plugin-development.md +++ b/docs/guides/plugin-development.md @@ -837,6 +837,21 @@ func (p *MyPlugin) GetCredentials(ctx context.Context, tx sdk.TransactionContext > timeout and is cancelled if the client disconnects. This prevents your > plugin from leaking goroutines or holding connections to slow backends. +### Forwarding requests (optional) + +For some requests, the right answer is not to inject credentials at all but to forward the request as-is to another service — for example, a customer-side router that handles credential injection, response filtering, and policy enforcement on its own. Implement the optional [`sdk.RequestRouter`](../reference/sdk.md#requestrouter-optional) interface on your plugin to opt into this behavior. Returning a non-nil [`*sdk.RouteAction`](../reference/sdk.md#routeaction) with a `ForwardTo` that names a configured [`forward_target`](../reference/configuration.md#forward-targets) tells Chaperone to skip credential injection and `ModifyResponse` for that request; returning `nil` falls through to the normal credential-injection flow. + +```go +func (p *MyPlugin) RouteRequest(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.RouteAction, error) { + if v, ok, _ := tx.DataString("ResellerId"); ok && strings.HasPrefix(v, "migrated-") { + return &sdk.RouteAction{ForwardTo: "customer-router"}, nil + } + return nil, nil +} +``` + +Test routers with [`compliance.VerifyRouter`](../reference/sdk.md#verifyrouter). If you use the contrib [`Mux`](../reference/contrib-plugins.md#mux), prefer [`Mux.HandleForward`](../reference/contrib-plugins.md#handleforward) or the [`forward:`](../reference/contrib-plugins.md#muxconfig) field on a `MuxRouteConfig` — the mux implements `RequestRouter` for you. + --- ## Reference Plugin Walkthrough diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index caff8b8..a2b0afc 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -225,6 +225,46 @@ export OTEL_SDK_DISABLED=true | `OTEL_TRACES_SAMPLER_ARG` | Sampler argument (e.g., ratio) | | `OTEL_SDK_DISABLED` | Force-disable SDK (`true` always wins) | +### Forward Targets + +Named upstreams that Chaperone can forward requests to instead of calling the vendor directly. Targets are referenced by name from a [`sdk.RouteAction`](sdk.md#routeaction) returned by a [`RequestRouter`](sdk.md#requestrouter-optional). The contrib [`Mux`](contrib-plugins.md#handleforward) implements `RequestRouter` and exposes targets through the `forward:` field on route entries. + +When a router selects a forward target, the Core sends the request to that target's `url` with the configured authentication and timeout, and skips credential injection and `ModifyResponse`. + +```yaml +forward_targets: + customer-router: + url: "https://router.customer.example/v1/intake" + timeout: 15s + auth: + type: "bearer" + token: "${CUSTOMER_ROUTER_TOKEN}" + internal-relay: + url: "https://relay.internal.example/" + timeout: 10s + auth: + type: "none" +``` + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| `url` | string | — (required) | Absolute base URL of the forward target. Must be `https://` in production builds; `http://` is permitted only in dev builds. | +| `timeout` | duration | `0` (use upstream defaults) | Per-request timeout when calling the forward target. | +| `auth.type` | string | — (required) | `bearer` or `none`. Unknown values are rejected at startup. | +| `auth.token` | string | — | Bearer token used when `auth.type: bearer`. Required and must be non-empty for bearer auth. Supports `${VAR}` and `$VAR` environment variable interpolation. | + +#### Validation rules + +Forward targets are validated at startup. The proxy fails fast with a descriptive error when any of these rules is violated: + +- `url` must be present, parseable, and have a non-empty scheme and host. +- The scheme must be `https` in production builds. In dev builds, `http` is also accepted. +- `auth.type` must be set; the empty string is rejected. +- `auth.type` must be `bearer` or `none`; any other value is rejected. +- When `auth.type: bearer`, `auth.token` must be non-empty after environment variable interpolation. + +Routers that reference a `forward_target` name not defined here are also rejected at startup — see [`MuxConfig`](contrib-plugins.md#muxconfig) for how the contrib mux participates in this check. + ## Allow-List Syntax The allow-list enforces a **default-deny** policy. Only requests matching diff --git a/docs/reference/contrib-plugins.md b/docs/reference/contrib-plugins.md index 5c3c026..fe4d195 100644 --- a/docs/reference/contrib-plugins.md +++ b/docs/reference/contrib-plugins.md @@ -30,6 +30,9 @@ Sub-packages: [cred]: sdk.md#credential [cs]: sdk.md#certificatesigner [rm]: sdk.md#responsemodifier +[rr]: sdk.md#requestrouter-optional +[ra]: sdk.md#routeaction +[ft]: configuration.md#forward-targets --- @@ -39,9 +42,9 @@ Sub-packages: type Mux struct{ /* unexported */ } ``` -A request multiplexer that dispatches to the most specific matching [`CredentialProvider`][cp] based on transaction context fields. `Mux` implements [`Plugin`][plugin] and can be passed directly to `chaperone.Run()`. +A request multiplexer that dispatches to the most specific matching [`CredentialProvider`][cp] based on transaction context fields. `Mux` implements [`Plugin`][plugin] and the optional [`RequestRouter`][rr], and can be passed directly to `chaperone.Run()`. -Safe for concurrent use after construction. `Handle` and `Default` are not safe for concurrent calls — register all routes before serving traffic. +Safe for concurrent use after construction. `Handle`, `HandleForward`, and `Default` are not safe for concurrent calls — register all routes before serving traffic. ### `NewMux` @@ -73,6 +76,20 @@ func (m *Mux) Handle(route Route, provider sdk.CredentialProvider) Registers a route that dispatches matching requests to `provider`. Routes are evaluated by [specificity](#specificity) at dispatch time. Registration order breaks ties. +Mutually exclusive with [`HandleForward`](#handleforward) for the same route: every route in the mux dispatches to either a credential provider or a forward target, never both. + +### `HandleForward` + +```go +func (m *Mux) HandleForward(route Route, target string) +``` + +Registers a route that, when matched, forwards the request to the named [`forward_target`][ft] via the Core's forwarding path. Credential injection and [`ResponseModifier`][rm] are skipped for forwarded requests. + +`target` is the key of an entry in the proxy's `forward_targets` configuration. The Mux treats the name as opaque — cross-validation that every referenced target exists in the configuration happens at startup. Mutually exclusive with [`Handle`](#handle) for the same route. + +The Mux implements [`RequestRouter`][rr]: when a forward route matches, `RouteRequest` returns a [`*RouteAction`][ra] with `ForwardTo` set to `target`. When a [`Handle`](#handle) route matches (or nothing matches), `RouteRequest` returns `nil` and dispatch falls through to `GetCredentials`. + ### `Default` ```go @@ -129,6 +146,105 @@ Delegates to the configured modifier. Returns a nil action and nil error if no m --- +## MuxConfig + +YAML-friendly description of a request multiplexer. A `MuxConfig` can be unmarshalled directly from a YAML document — typically as a sub-section of the distributor's own configuration file — and passed to [`LoadMuxFromConfig`](#loadmuxfromconfig) to build a usable [`*Mux`](#mux). + +```go +type MuxConfig struct { + Routes []MuxRouteConfig `yaml:"routes"` + Fallback *MuxFallbackConfig `yaml:"fallback,omitempty"` +} + +type MuxRouteConfig struct { + Match MatchConfig `yaml:"match"` + Forward string `yaml:"forward,omitempty"` + Credentials *CredentialsConfig `yaml:"credentials,omitempty"` +} + +type MatchConfig struct { + VendorID string `yaml:"vendor_id,omitempty"` + MarketplaceID string `yaml:"marketplace_id,omitempty"` + ProductID string `yaml:"product_id,omitempty"` + EnvironmentID string `yaml:"environment_id,omitempty"` + TargetURL string `yaml:"target_url,omitempty"` + Data map[string]string `yaml:"data,omitempty"` +} + +type CredentialsConfig struct { + Type string `yaml:"type"` +} + +type MuxFallbackConfig struct { + Credentials *CredentialsConfig `yaml:"credentials,omitempty"` + Forward string `yaml:"forward,omitempty"` // rejected; see below +} +``` + +Each route must set **exactly one** of `forward` or `credentials`: + +- `forward` names a [`forward_target`][ft]. The matched request is sent there as-is by the Core, bypassing credential injection and `ModifyResponse`. +- `credentials.type` is a discriminator looked up in the providers map passed to [`LoadMuxFromConfig`](#loadmuxfromconfig). The distributor constructs the providers (OAuth, Microsoft SAM, etc.) and registers them in the map. + +`match.data` mirrors the [`Route.Data` matcher](#data-matcher) and follows the same semantics: missing keys, wrong-type values, and empty strings yield non-match. + +The optional `fallback` runs when no route matches. Only `fallback.credentials` is supported in v1 — a `fallback.forward` is rejected at load time. A silent fallback-forward would route every unmatched request, including misconfigured or unexpected traffic, to a customer-side service without credential injection; forward routes must be explicit per-match. + +### `LoadMuxFromConfig` + +```go +func LoadMuxFromConfig(cfg MuxConfig, providers map[string]sdk.CredentialProvider) (*Mux, error) +``` + +Builds a [`*Mux`](#mux) from a `MuxConfig` and a lookup of pre-built providers. Forward routes are registered via [`HandleForward`](#handleforward); credential routes via [`Handle`](#handle); the fallback (if any) via [`Default`](#default). + +Validation rules — the first violation returns an error that names the offending route by index (`routes[2]`) or `fallback`: + +- Every route must set exactly one of `forward` or `credentials`. +- A route's `credentials.type` must be non-empty and present in `providers`. +- The fallback, if present, must set `credentials` (not `forward`), and `credentials.type` must be non-empty and present in `providers`. + +#### Example + +```yaml +mux: + routes: + # Migrated tenants → customer-side handler (no credential injection). + - match: + vendor_id: "microsoft-*" + data: + ResellerId: "migrated-*" + forward: customer-router + + # Everyone else on Microsoft → SAM provider. + - match: + vendor_id: "microsoft-*" + credentials: + type: microsoft-sam + + # Acme → OAuth client credentials. + - match: + vendor_id: "acme" + credentials: + type: acme-oauth + + fallback: + credentials: + type: microsoft-sam +``` + +```go +providers := map[string]sdk.CredentialProvider{ + "microsoft-sam": msSource, + "acme-oauth": acmeOAuth, +} +mux, err := contrib.LoadMuxFromConfig(cfg.Mux, providers) +``` + +Pair this with the [`forward_targets`][ft] section in the proxy configuration so that every `forward:` name resolves to a real target. + +--- + ## Route ```go @@ -138,6 +254,7 @@ type Route struct { ProductID string TargetURL string EnvironmentID string + Data map[string]string } ``` @@ -150,6 +267,29 @@ Matching criteria for dispatching requests. Each non-empty field must match the | `ProductID` | `tx.ProductID` | `"MICROSOFT_*"` | | `TargetURL` | `tx.TargetURL` (scheme stripped) | `"*.graph.microsoft.com/**"` | | `EnvironmentID` | `tx.EnvironmentID` | `"prod-*"` | +| `Data` | `tx.Data[key]` per entry | `{"ResellerId": "migrated-*"}` | + +#### Data matcher + +`Data` is a map of `tx.Data` keys to glob patterns. The route matches only if **every** entry's pattern matches the corresponding `tx.DataString(key)` value. Each entry contributes 1 to [`Specificity()`](#specificity). + +Non-match cases (the route is skipped): + +- The key is absent from `tx.Data`. +- The value is present but has the wrong type (not a string). +- The value is present but is an empty string. + +Invalid or missing data must never silently dispatch to a provider, so these cases all yield a non-match rather than a partial match. The same semantics apply when the matcher is configured via [`MuxConfig`](#muxconfig) (the YAML `match.data` field). + +```go +mux.Handle( + contrib.Route{ + VendorID: "microsoft-*", + Data: map[string]string{"ResellerId": "migrated-*"}, + }, + legacyProvider, +) +``` ### `Matches` @@ -165,7 +305,7 @@ Reports whether every non-empty field in the route matches the corresponding `tx func (r Route) Specificity() int ``` -Returns the number of non-empty fields (0–5). The mux prefers routes with higher specificity when multiple routes match. +Returns the number of non-empty fields, where each `Data` entry counts as one. The mux prefers routes with higher specificity when multiple routes match. | Route | Specificity | |-------|-------------| @@ -173,6 +313,7 @@ Returns the number of non-empty fields (0–5). The mux prefers routes with high | `Route{VendorID: "acme"}` | 1 | | `Route{MarketplaceID: "MP-*", ProductID: "MICROSOFT_SAAS"}` | 2 | | `Route{EnvironmentID: "prod", VendorID: "acme", TargetURL: "api.acme.com/**"}` | 3 | +| `Route{VendorID: "acme", Data: map[string]string{"ResellerId": "migrated-*"}}` | 2 | ### Glob patterns @@ -860,6 +1001,8 @@ import "github.com/cloudblue/chaperone/plugins/contrib" | `ErrTokenExpiredOnArrival` | `"token expired on arrival"` | Token `expires_in` is less than or equal to the expiry margin. Token too short-lived to cache. | No | | `ErrSigningNotConfigured` | `"certificate signing not configured"` | `SignCSR` called on [`AsPlugin`](#asplugin) or [`Mux`](#mux) with no signer configured. | No | | `ErrTokenEndpointUnavailable` | `"token endpoint unavailable"` | Network error, HTTP 5xx, or HTTP 429 from the token endpoint. | Yes | +| `ErrUnexpectedForwardAction` | `"matched route is a forward action; GetCredentials should not have been called"` | The Core called `GetCredentials` for a route registered via [`HandleForward`](#handleforward). Indicates an integration bug — forwarding should have been handled by `RouteRequest`. | No | +| `ErrNilCredentialProvider` | `"credential action has nil provider"` | A credential route was registered with a nil provider. Programming error caught at dispatch time. | No | --- diff --git a/docs/reference/sdk.md b/docs/reference/sdk.md index 3b09804..7c350f6 100644 --- a/docs/reference/sdk.md +++ b/docs/reference/sdk.md @@ -144,6 +144,54 @@ or passing through ISV validation errors. --- +### RequestRouter (optional) + +```go +type RequestRouter interface { + RouteRequest(ctx context.Context, tx TransactionContext, req *http.Request) (*RouteAction, error) +} +``` + +`RequestRouter` is an **optional** plugin capability. Plugins that do not implement it retain the default behavior: every request flows through `GetCredentials` and the configured allow-list to the vendor. + +Implementations are invoked before `GetCredentials`. Returning a non-nil [`*RouteAction`](#routeaction) with a non-empty `ForwardTo` causes the Core to forward the request to the named [`forward_target`](configuration.md#forward-targets) and skip both credential injection and `ModifyResponse`. Returning `(nil, nil)` (or a non-nil action with an empty `ForwardTo`) is the fall-through signal: the Core continues with the normal credential-injection flow. + +Use `RequestRouter` when some requests should bypass credential injection entirely — for example, when a customer-side service handles authentication and response filtering on its own, and Chaperone's role is to forward the request as-is to that service. + +#### `RouteRequest` Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `ctx` | [`context.Context`][ctx] | Bounded by the Core with a request timeout. | +| `tx` | [`TransactionContext`](#transactioncontext) | Metadata extracted from inbound request headers. | +| `req` | [`*http.Request`][req] | The inbound request (not yet mutated by Core). | + +#### `RouteRequest` Return Values + +| Return | Type | Description | +|--------|------|-------------| +| action | [`*RouteAction`](#routeaction) | Non-nil with non-empty `ForwardTo` to forward the request; `nil` or empty `ForwardTo` to fall through to credential injection. | +| err | `error` | Any error during routing decision. Errors fail the request. | + +#### Example + +```go +type MyRouter struct{ /* ... */ } + +func (r *MyRouter) RouteRequest(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.RouteAction, error) { + // Forward migrated tenants to the new customer-side handler; + // everything else falls through to credential injection. + if v, ok, _ := tx.DataString("ResellerId"); ok && strings.HasPrefix(v, "migrated-") { + return &sdk.RouteAction{ForwardTo: "customer-router"}, nil + } + return nil, nil +} +``` + +Use [`compliance.VerifyRouter`](#verifyrouter) to test routers against the minimal contract. + +--- + ## Types ### TransactionContext @@ -281,6 +329,24 @@ type ResponseAction struct { --- +### RouteAction + +```go +type RouteAction struct { + ForwardTo string +} +``` + +Returned by a [`RequestRouter`](#requestrouter-optional) to tell the Core how to handle a request. + +#### Fields + +| Field | Type | Description | +|-------|------|-------------| +| `ForwardTo` | `string` | Name of a [`forward_target`](configuration.md#forward-targets) entry in the proxy configuration. When non-empty, the Core forwards the request to that target and skips credential injection and `ModifyResponse`. When empty, the action is equivalent to returning `nil` from `RouteRequest` — the request falls through to credential injection. | + +--- + ## Errors ### `ErrInvalidContextData` @@ -486,6 +552,22 @@ panicking, and that returned credentials have a valid `ExpiresAt`. See the [Plugin Development Guide](../guides/plugin-development.md) for usage in your test suite. +### `VerifyRouter` + +```go +func VerifyRouter(t *testing.T, router sdk.RequestRouter) +``` + +Exercises a [`RequestRouter`](#requestrouter-optional) implementation against the minimal contract: it must accept a cancelled context without panicking, and either return `(nil, nil)` (fall-through) or a non-nil [`*RouteAction`](#routeaction). + +`VerifyRouter` is opt-in: only plugins that implement `RequestRouter` need to call it. Plugins that do not implement it remain valid under `VerifyContract`. + +```go +func TestRouter(t *testing.T) { + compliance.VerifyRouter(t, NewMyRouter()) +} +``` + --- ## Module Versioning diff --git a/internal/config/config.go b/internal/config/config.go index dfc90d4..6ea0862 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,6 +17,37 @@ type Config struct { Upstream UpstreamConfig `yaml:"upstream"` // Observability holds logging and profiling configuration. Observability ObservabilityConfig `yaml:"observability"` + // ForwardTargets describes named upstreams that Chaperone can forward + // requests to instead of calling the vendor directly. Targets are + // referenced by name from a sdk.RouteAction returned by a RequestRouter. + // Keys are stable identifiers used by routers; values describe how to + // reach and authenticate to the target. + ForwardTargets map[string]ForwardTargetConfig `yaml:"forward_targets"` +} + +// ForwardTargetConfig describes a named upstream that Chaperone can forward +// requests to instead of calling the vendor. Targets are referenced by name +// from a sdk.RouteAction. +type ForwardTargetConfig struct { + // URL is the absolute base URL of the forward target. Must use https + // in production builds. http is permitted only in dev builds. + URL string `yaml:"url"` + // Timeout is the per-request timeout when calling the forward target. + // Zero means "use the default upstream timeouts". + Timeout time.Duration `yaml:"timeout"` + // Auth describes how Chaperone authenticates to the forward target. + Auth ForwardTargetAuthConfig `yaml:"auth"` +} + +// ForwardTargetAuthConfig describes how Chaperone authenticates to a forward +// target. Only "bearer" and "none" are supported in v1. +type ForwardTargetAuthConfig struct { + // Type selects the auth mechanism: "bearer" or "none". + Type string `yaml:"type"` + // Token is the bearer token used when Type == "bearer". The value + // supports ${VAR} and $VAR environment variable interpolation so that + // secrets can live outside of the config file. + Token string `yaml:"token"` } // ServerConfig holds the server binding and TLS configuration. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 1dcbcd8..e4890c6 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -7,6 +7,7 @@ import ( "errors" "os" "path/filepath" + "strings" "testing" "time" @@ -1347,6 +1348,278 @@ func TestValidate_InvalidAdminAddr_ReturnsError(t *testing.T) { } } +// ----------------------------------------------------------------------------- +// forward_targets tests +// ----------------------------------------------------------------------------- + +func TestConfig_ForwardTargets_HTTPSAndBearer_Parses(t *testing.T) { + t.Setenv("COMPANY_B_TOKEN", "secret-token-abc") + + yaml := ` +forward_targets: + company-b: + url: "https://company-b.internal/ingress" + timeout: 30s + auth: + type: bearer + token: "${COMPANY_B_TOKEN}" +` + cfg, err := LoadFromBytes([]byte(yaml)) + if err != nil { + t.Fatalf("LoadFromBytes: %v", err) + } + target, ok := cfg.ForwardTargets["company-b"] + if !ok { + t.Fatal("missing forward_targets[company-b]") + } + if target.URL != "https://company-b.internal/ingress" { + t.Errorf("URL = %q", target.URL) + } + if target.Timeout != 30*time.Second { + t.Errorf("Timeout = %v, want 30s", target.Timeout) + } + if target.Auth.Type != "bearer" { + t.Errorf("Auth.Type = %q", target.Auth.Type) + } + if target.Auth.Token != "secret-token-abc" { + t.Errorf("Auth.Token = %q (expected env-interpolated value)", target.Auth.Token) + } +} + +func TestConfig_ForwardTargets_BearerMissingToken_Fails(t *testing.T) { + yaml := ` +forward_targets: + company-b: + url: "https://company-b.internal/ingress" + auth: + type: bearer +` + _, err := LoadFromBytes([]byte(yaml)) + if err == nil { + t.Fatal("expected error for bearer auth without token, got nil") + } + if !errors.Is(err, ErrForwardTargetBearerTokenMissing) { + t.Errorf("error = %v, want ErrForwardTargetBearerTokenMissing", err) + } +} + +func TestConfig_ForwardTargets_HTTPRejected_InProductionBuild(t *testing.T) { + // Default build behaviour: http forward targets are rejected. + // (allowInsecureForwardTargets defaults to "false") + yaml := ` +forward_targets: + company-b: + url: "http://company-b.internal/ingress" + auth: { type: none } +` + _, err := LoadFromBytes([]byte(yaml)) + if err == nil { + t.Fatal("expected error for http:// forward target in production build") + } + if !errors.Is(err, ErrForwardTargetInsecureURL) { + t.Errorf("error = %v, want ErrForwardTargetInsecureURL", err) + } +} + +// TestConfig_ForwardTargets_Matrix exercises the validation matrix for both +// the URL and the auth subsection of forward targets. +func TestConfig_ForwardTargets_Matrix(t *testing.T) { + // Reserve a guaranteed-unset env var name for one of the cases. + const unsetVar = "CHAPERONE_TEST_DEFINITELY_UNSET_VAR_XYZ" + if err := os.Unsetenv(unsetVar); err != nil { + t.Fatalf("unsetenv: %v", err) + } + + tests := []struct { + name string + yaml string + wantErr bool + wantErrIs error // optional: errors.Is target + }{ + { + name: "auth_none_no_token_passes", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: { type: none } +`, + wantErr: false, + }, + { + name: "auth_none_with_token_passes_token_ignored", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: { type: none, token: "ignored" } +`, + wantErr: false, + }, + { + name: "auth_bearer_empty_token_fails", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: { type: bearer, token: "" } +`, + wantErr: true, + wantErrIs: ErrForwardTargetBearerTokenMissing, + }, + { + name: "auth_bearer_unset_env_var_resolves_to_empty_fails", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: { type: bearer, token: "${` + unsetVar + `}" } +`, + wantErr: true, + wantErrIs: ErrForwardTargetBearerTokenMissing, + }, + { + name: "auth_type_missing_fails", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: {} +`, + wantErr: true, + wantErrIs: ErrForwardTargetAuthTypeMissing, + }, + { + name: "auth_type_unsupported_fails", + yaml: ` +forward_targets: + x: + url: "https://x.example.com" + auth: { type: oauth2 } +`, + wantErr: true, + wantErrIs: ErrForwardTargetAuthTypeUnsupported, + }, + { + name: "url_empty_fails", + yaml: ` +forward_targets: + x: + url: "" + auth: { type: none } +`, + wantErr: true, + wantErrIs: ErrForwardTargetMissingURL, + }, + { + name: "url_not_a_url_fails", + yaml: ` +forward_targets: + x: + url: "not a url" + auth: { type: none } +`, + wantErr: true, + wantErrIs: ErrForwardTargetInvalidURL, + }, + { + name: "url_ftp_scheme_fails", + yaml: ` +forward_targets: + x: + url: "ftp://x.example.com/path" + auth: { type: none } +`, + wantErr: true, + wantErrIs: ErrForwardTargetInsecureURL, + }, + { + name: "url_http_in_prod_fails", + yaml: ` +forward_targets: + x: + url: "http://x.example.com" + auth: { type: none } +`, + wantErr: true, + wantErrIs: ErrForwardTargetInsecureURL, + }, + { + name: "two_targets_both_parse", + yaml: ` +forward_targets: + a: + url: "https://a.example.com" + auth: { type: none } + b: + url: "https://b.example.com" + auth: { type: bearer, token: "tok" } +`, + wantErr: false, + }, + { + name: "two_targets_one_invalid_surfaces_name", + yaml: ` +forward_targets: + good: + url: "https://good.example.com" + auth: { type: none } + bad: + url: "https://bad.example.com" + auth: { type: bearer } +`, + wantErr: true, + wantErrIs: ErrForwardTargetBearerTokenMissing, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := LoadFromBytes([]byte(tc.yaml)) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + if tc.wantErrIs != nil && !errors.Is(err, tc.wantErrIs) { + t.Errorf("error = %v, want errors.Is(%v) = true", err, tc.wantErrIs) + } + // For the "surfaces name" case, ensure the offending name appears. + if tc.name == "two_targets_one_invalid_surfaces_name" { + if !strings.Contains(err.Error(), `"bad"`) { + t.Errorf("expected error to mention bad target name, got %q", err.Error()) + } + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +// TestConfig_ForwardTargets_HTTPAllowed_InDevBuild verifies that the +// dev-build toggle permits http forward targets. Uses +// SetAllowInsecureForwardTargetsForTesting to simulate a dev build. +func TestConfig_ForwardTargets_HTTPAllowed_InDevBuild(t *testing.T) { + cleanup := SetAllowInsecureForwardTargetsForTesting(true) + defer cleanup() + + yaml := ` +forward_targets: + x: + url: "http://x.example.com" + auth: { type: none } +` + cfg, err := LoadFromBytes([]byte(yaml)) + if err != nil { + t.Fatalf("LoadFromBytes: %v", err) + } + if _, ok := cfg.ForwardTargets["x"]; !ok { + t.Fatal("missing forward_targets[x]") + } +} + func TestValidate_AllNegativeTimeouts_ReturnsErrors(t *testing.T) { // Arrange - test all timeout validations tlsDisabled := false diff --git a/internal/config/defaults.go b/internal/config/defaults.go index d037284..7c4cb03 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -9,6 +9,8 @@ package config import ( "strings" "time" + + "github.com/cloudblue/chaperone/internal/security" ) // Default server configuration values. @@ -78,15 +80,12 @@ const ( // defaultSensitiveHeaders returns the list of headers that MUST be redacted // in logs. This is a security-critical default per Design Spec Section 5.3. // Returns a new copy each time to prevent accidental mutation. +// +// The list itself is owned by internal/security so that the forward proxy +// path (which has no access to the user-merged config) and the vendor proxy +// path share a single source of truth. func defaultSensitiveHeaders() []string { - return []string{ - "Authorization", - "Proxy-Authorization", - "Cookie", - "Set-Cookie", - "X-API-Key", - "X-Auth-Token", - } + return security.DefaultSensitiveHeaders() } // durationPtr returns a pointer to the given duration. diff --git a/internal/config/forward_targets.go b/internal/config/forward_targets.go new file mode 100644 index 0000000..6754881 --- /dev/null +++ b/internal/config/forward_targets.go @@ -0,0 +1,146 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package config + +import ( + "errors" + "fmt" + "net/url" + "os" +) + +// allowInsecureForwardTargets controls whether HTTP (non-HTTPS) forward +// target URLs are permitted. This is set at compile time via ldflags; +// the default is "false" (secure). +// +// SECURITY: In production builds, this MUST be "false" to prevent the +// bearer token (or any other credential) from being sent over an +// unencrypted connection. +// +// Set via: -ldflags "-X 'github.com/cloudblue/chaperone/internal/config.allowInsecureForwardTargets=true'" +// +// The matching dev-build variable for vendor targets lives in +// internal/proxy.allowInsecureTargets; we mirror the same pattern here +// rather than reach across packages because internal/proxy already +// imports internal/config (so we cannot import the other direction). +var allowInsecureForwardTargets = "false" + +// testOverrideInsecureForwardTargets is used by tests to temporarily +// allow http forward targets. It is always nil unless set by test code. +var testOverrideInsecureForwardTargets *bool + +// AllowInsecureForwardTargets reports whether http forward target URLs +// are permitted. This is true only in dev builds or under an explicit +// test override. +func AllowInsecureForwardTargets() bool { + if testOverrideInsecureForwardTargets != nil { + return *testOverrideInsecureForwardTargets + } + return allowInsecureForwardTargets == "true" +} + +// SetAllowInsecureForwardTargetsForTesting temporarily enables http +// forward target URLs. It returns a cleanup function that restores the +// previous value. Intended for tests only. +func SetAllowInsecureForwardTargetsForTesting(allow bool) func() { + old := testOverrideInsecureForwardTargets + testOverrideInsecureForwardTargets = &allow + return func() { + testOverrideInsecureForwardTargets = old + } +} + +// Forward target auth types. +const ( + // ForwardAuthNone disables authentication on the forward target. + ForwardAuthNone = "none" + // ForwardAuthBearer attaches a static bearer token on the forward target. + ForwardAuthBearer = "bearer" +) + +// Forward target validation errors. They are exported so that callers +// (and tests) can match on them with errors.Is. +var ( + // ErrForwardTargetMissingURL is returned when a forward target has no url. + ErrForwardTargetMissingURL = errors.New("forward target: url is required") + // ErrForwardTargetInvalidURL is returned when a forward target url cannot be parsed. + ErrForwardTargetInvalidURL = errors.New("forward target: invalid url") + // ErrForwardTargetInsecureURL is returned when a forward target url is not https + // (and http is not allowed in this build). + ErrForwardTargetInsecureURL = errors.New("forward target: url must be https") + // ErrForwardTargetAuthTypeMissing is returned when auth.type is unset. + ErrForwardTargetAuthTypeMissing = errors.New("forward target: auth.type is required") + // ErrForwardTargetAuthTypeUnsupported is returned when auth.type is not one of the supported values. + ErrForwardTargetAuthTypeUnsupported = errors.New("forward target: unsupported auth.type") + // ErrForwardTargetBearerTokenMissing is returned when bearer auth is configured without a token. + ErrForwardTargetBearerTokenMissing = errors.New("forward target: bearer auth requires a non-empty token") +) + +// interpolateForwardTargetEnv expands environment variable references in +// the credential-bearing fields of every forward target. Only fields +// listed here are interpolated, so the rest of the config behaves +// exactly like before. Uses os.ExpandEnv semantics (${VAR} and $VAR). +func interpolateForwardTargetEnv(cfg *Config) { + if cfg == nil { + return + } + for name, t := range cfg.ForwardTargets { + t.Auth.Token = os.ExpandEnv(t.Auth.Token) + cfg.ForwardTargets[name] = t + } +} + +// validateForwardTargets validates every entry in cfg.ForwardTargets. +// When allowHTTP is true (dev builds), http urls are permitted; otherwise +// only https is accepted. Errors include the offending target name so +// operators can locate the misconfiguration quickly. +func validateForwardTargets(cfg *Config, allowHTTP bool) error { + var errs []error + for name, t := range cfg.ForwardTargets { + if err := validateForwardTarget(name, t, allowHTTP); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} + +// validateForwardTarget validates a single forward target entry. Split +// out from validateForwardTargets to keep cognitive complexity in check. +func validateForwardTarget(name string, t ForwardTargetConfig, allowHTTP bool) error { + if t.URL == "" { + return fmt.Errorf("forward_targets[%q]: %w", name, ErrForwardTargetMissingURL) + } + u, err := url.Parse(t.URL) + if err != nil { + return fmt.Errorf("forward_targets[%q]: %w: %w", name, ErrForwardTargetInvalidURL, err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("forward_targets[%q]: %w: %q", name, ErrForwardTargetInvalidURL, t.URL) + } + if u.Scheme != "https" && (!allowHTTP || u.Scheme != "http") { + return fmt.Errorf("forward_targets[%q]: %w (got %q)", name, ErrForwardTargetInsecureURL, u.Scheme) + } + return validateForwardTargetAuth(name, t.Auth) +} + +// validateForwardTargetAuth validates the auth subsection of a forward target. +func validateForwardTargetAuth(name string, auth ForwardTargetAuthConfig) error { + switch auth.Type { + case ForwardAuthNone: + return nil + case ForwardAuthBearer: + if auth.Token == "" { + return fmt.Errorf("forward_targets[%q]: %w", name, ErrForwardTargetBearerTokenMissing) + } + return nil + case "": + return fmt.Errorf("forward_targets[%q]: %w (expected %q or %q)", + name, ErrForwardTargetAuthTypeMissing, ForwardAuthBearer, ForwardAuthNone) + default: + return fmt.Errorf("forward_targets[%q]: %w: %q", name, ErrForwardTargetAuthTypeUnsupported, auth.Type) + } +} diff --git a/internal/config/loader.go b/internal/config/loader.go index 84c7b55..a8750fb 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -24,6 +24,34 @@ const ConfigEnvVar = "CHAPERONE_CONFIG" // DefaultConfigPath is the default configuration file path. const DefaultConfigPath = "./config.yaml" +// LoadFromBytes parses a YAML configuration document from memory, applies +// defaults, runs environment variable interpolation for credential-bearing +// fields, and validates fields that are entirely structural (currently +// forward_targets). Unlike Load, it does NOT enforce that the global +// configuration is complete — it deliberately skips checks that depend +// on the filesystem (TLS file existence) or on full deployment context +// (allow_list, addresses). Use this helper from tests and from any caller +// that needs to parse a config fragment without a file backing it. +func LoadFromBytes(data []byte) (*Config, error) { + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + applyDefaults(&cfg) + + // Env interpolation runs BEFORE validation so that an unset variable + // resolving to an empty string is caught by validateForwardTargets + // (e.g. bearer token must be non-empty). + interpolateForwardTargetEnv(&cfg) + + if err := validateForwardTargets(&cfg, AllowInsecureForwardTargets()); err != nil { + return nil, fmt.Errorf("configuration validation failed: %w", err) + } + + return &cfg, nil +} + // Load loads configuration from a YAML file with environment variable overrides. // If configPath is empty, it checks CHAPERONE_CONFIG env var, then falls back to DefaultConfigPath. // Environment variables take precedence over YAML values (12-Factor App methodology). @@ -45,6 +73,11 @@ func Load(configPath string) (*Config, error) { return nil, fmt.Errorf("environment variable override failed: %w", err) } + // Apply env interpolation for credential-bearing fields (e.g. bearer + // tokens in forward_targets). Done before validation so unset vars + // resolve to an empty string and fail the non-empty token check. + interpolateForwardTargetEnv(cfg) + // Validate the final configuration if err := Validate(cfg); err != nil { return nil, fmt.Errorf("configuration validation failed: %w", err) diff --git a/internal/config/validate.go b/internal/config/validate.go index c1d6104..64e406b 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -61,6 +61,11 @@ func Validate(cfg *Config) error { errs = append(errs, err) } + // Validate forward targets (optional section; no error if empty) + if err := validateForwardTargets(cfg, AllowInsecureForwardTargets()); err != nil { + errs = append(errs, err) + } + if len(errs) > 0 { return errors.Join(errs...) } diff --git a/internal/proxy/errors.go b/internal/proxy/errors.go new file mode 100644 index 0000000..817644d --- /dev/null +++ b/internal/proxy/errors.go @@ -0,0 +1,28 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy + +import ( + "encoding/json" + "net/http" +) + +// errorResponse is the JSON structure for error responses returned from +// the proxy. It mirrors the shape used by internal/router so that clients +// observe a consistent error envelope across the request lifecycle. +type errorResponse struct { + Error string `json:"error"` +} + +// respondError writes a JSON error response with the given status code and +// message. It is intentionally a small duplicate of the helper in +// internal/router/middleware.go — the helper is four lines and the duplication +// keeps the proxy package free of an extra dependency on router internals. +func respondError(w http.ResponseWriter, statusCode int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + resp := errorResponse{Error: message} + _ = json.NewEncoder(w).Encode(resp) // Error ignored: client may have disconnected +} diff --git a/internal/proxy/export_test.go b/internal/proxy/export_test.go index 46a8ac9..bec411d 100644 --- a/internal/proxy/export_test.go +++ b/internal/proxy/export_test.go @@ -11,3 +11,32 @@ import "net/http" func (s *Server) WithMiddlewareForTesting(handler http.Handler) http.Handler { return s.withMiddleware(handler) } + +// ForwardProxyForTesting returns the *ForwardProxy registered under the given +// name, or nil if no such target was configured. Exposed for external tests +// that verify the per-target forward registry built at startup. +func (s *Server) ForwardProxyForTesting(name string) *ForwardProxy { + return s.forwardProxies[name] +} + +// ForwardProxyCountForTesting returns the number of forward proxies built at +// startup. Exposed for external tests that verify the registry is non-nil +// even when no forward targets are configured. +func (s *Server) ForwardProxyCountForTesting() int { + return len(s.forwardProxies) +} + +// ForwardProxiesNilForTesting reports whether the forward proxy map is nil. +// Exposed so external tests can assert the registry is non-nil (the spec +// requires an empty map, not nil) even with zero configured targets. +func (s *Server) ForwardProxiesNilForTesting() bool { + return s.forwardProxies == nil +} + +// RouterForTesting returns the RequestRouter detected on the plugin at startup, +// or nil if the plugin does not implement RequestRouter. Exposed for external +// tests that verify RequestRouter type assertion and capability detection. +// This must remain unexported for proxy package but tests access via this method. +func (s *Server) RouterForTesting() interface{} { + return s.router +} diff --git a/internal/proxy/forward_integration_test.go b/internal/proxy/forward_integration_test.go new file mode 100644 index 0000000..2f6447e --- /dev/null +++ b/internal/proxy/forward_integration_test.go @@ -0,0 +1,411 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy_test + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/cloudblue/chaperone/internal/config" + "github.com/cloudblue/chaperone/internal/proxy" + "github.com/cloudblue/chaperone/sdk" +) + +// ============================================================================= +// Task 14: End-to-end forwarding with bearer auth. +// +// These tests drive a full request through the proxy.Server handler stack +// (AllowList middleware → handleProxy → router branch → ForwardProxy) and +// assert the forwarding contract: +// +// - The fake "Company B" target receives the request with X-Connect-* +// context headers intact, the configured bearer token, and NO trace of +// the inbound Authorization. +// - The fake target's response body, status, and (non-sensitive) headers +// reach the client recorder. +// - Sensitive headers reflected by the target are stripped before reaching +// the client (defense-in-depth). +// - No credentials are ever injected via the plugin: the forward path is +// mutually exclusive with the credential-injection path. +// +// The test uses an inline sdk.RequestRouter implementation (forwardRouter) +// rather than contrib.Mux because internal/proxy lives in the Core module +// and contrib lives in a separate Go module (importing it from here would +// add a stale require to the Core go.mod). The Mux's RouteRequest → +// RouteAction translation is exercised by contrib's own tests; this test +// targets the Core forwarding pipeline. +// ============================================================================= + +// forwardRouter is a minimal sdk.Plugin + sdk.RequestRouter that routes +// requests with the configured VendorID to a named forward target. Any +// invocation of GetCredentials is recorded so the test can assert the +// credential path was never taken. +type forwardRouter struct { + matchVendorID string + forwardTo string + credCallCount atomic.Int32 +} + +func (r *forwardRouter) GetCredentials(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + r.credCallCount.Add(1) + return nil, nil +} + +func (r *forwardRouter) SignCSR(_ context.Context, _ []byte) ([]byte, error) { + return nil, errors.New("not implemented") +} + +func (r *forwardRouter) ModifyResponse(_ context.Context, _ sdk.TransactionContext, _ *http.Response) (*sdk.ResponseAction, error) { + return nil, nil +} + +func (r *forwardRouter) RouteRequest(_ context.Context, tx sdk.TransactionContext, _ *http.Request) (*sdk.RouteAction, error) { + if tx.VendorID == r.matchVendorID { + return &sdk.RouteAction{ForwardTo: r.forwardTo}, nil + } + return nil, nil +} + +var _ sdk.Plugin = (*forwardRouter)(nil) +var _ sdk.RequestRouter = (*forwardRouter)(nil) + +// forwardIntegrationConfig builds the proxy.Config used by these tests with +// an allow-list that permits the X-Connect-Target-URL host (api.vendor.com) +// AND the actual fake-target host. The X-Connect-Target-URL value is +// validated by the allow-list middleware even though the router short- +// circuits the request to the forward target. +func forwardIntegrationConfig(t *testing.T, plugin sdk.Plugin, forwardURL string) proxy.Config { + t.Helper() + cfg := testConfig() + cfg.Plugin = plugin + cfg.AllowList = map[string][]string{ + "api.vendor.com": {"/**"}, + mustTargetHostPort(t, forwardURL): {"/**"}, + } + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: forwardURL, + Auth: config.ForwardTargetAuthConfig{ + Type: config.ForwardAuthBearer, + Token: "expected-token", + }, + }, + } + return cfg +} + +// makeProxyRequest constructs a /proxy request configured for the forward +// scenario: VendorID=vendor-a triggers the router, X-Connect-Target-URL is +// the vendor URL (never dialed — the router forwards before that point), +// and an inbound Authorization header is set so we can verify it does NOT +// leak to the forward target. +func makeProxyRequest(body string) *http.Request { + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + req := httptest.NewRequest(http.MethodPost, "/proxy", bodyReader) + req.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/foo") + req.Header.Set("X-Connect-Vendor-ID", "vendor-a") + req.Header.Set("X-Connect-Marketplace-ID", "marketplace-1") + req.Header.Set("Authorization", "Bearer connect-original") + req.Header.Set("Content-Type", "application/json") + return req +} + +// TestForwardIntegration_BearerAuth_FullFlow exercises the end-to-end +// forwarding path with bearer auth across a matrix of behavioral assertions. +// Each subtest spins up its own fake target with the response semantics +// relevant to that scenario. +func TestForwardIntegration_BearerAuth_FullFlow(t *testing.T) { + t.Run("headers, body, status, and bearer all propagate", func(t *testing.T) { + var ( + callCount atomic.Int32 + seenHeaders http.Header + seenBody string + seenTargetURL string + seenVendorID string + seenAuth string + seenContentType string + seenTraceID string + ) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount.Add(1) + seenHeaders = r.Header.Clone() + seenTargetURL = r.Header.Get("X-Connect-Target-URL") + seenVendorID = r.Header.Get("X-Connect-Vendor-ID") + seenAuth = r.Header.Get("Authorization") + seenContentType = r.Header.Get("Content-Type") + seenTraceID = r.Header.Get("Connect-Request-ID") + body, _ := io.ReadAll(r.Body) + seenBody = string(body) + + w.Header().Set("X-Custom-Reply", "vendor-says-hi") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"vendor":"reply"}`) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + requestBody := `{"action":"create","name":"test"}` + const traceID = "test-trace-1234" + rec := httptest.NewRecorder() + req := makeProxyRequest(requestBody) + req.Header.Set("Connect-Request-ID", traceID) + srv.Handler().ServeHTTP(rec, req) + + if got := callCount.Load(); got != 1 { + t.Fatalf("fake target call count = %d, want 1", got) + } + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + + // Inbound context headers propagated. + if seenTargetURL != "https://api.vendor.com/v1/foo" { + t.Errorf("forwarded X-Connect-Target-URL = %q, want %q", seenTargetURL, "https://api.vendor.com/v1/foo") + } + if seenVendorID != "vendor-a" { + t.Errorf("forwarded X-Connect-Vendor-ID = %q, want %q", seenVendorID, "vendor-a") + } + if seenContentType != "application/json" { + t.Errorf("forwarded Content-Type = %q, want %q", seenContentType, "application/json") + } + // Connect-Request-ID propagates to the forward target for trace + // continuity (Design Spec §8.3). TraceIDMiddleware preserves valid + // inbound IDs verbatim, so the target sees the exact ID we set. + if seenTraceID != traceID { + t.Errorf("forwarded Connect-Request-ID = %q, want %q", seenTraceID, traceID) + } + + // Outbound Authorization is the configured bearer; inbound is stripped. + if seenAuth != "Bearer expected-token" { + t.Errorf("forwarded Authorization = %q, want %q", seenAuth, "Bearer expected-token") + } + if strings.Contains(seenAuth, "connect-original") { + t.Errorf("inbound Authorization leaked to forward target: %q", seenAuth) + } + + // Inbound Authorization MUST NOT have been propagated under any name. + for _, v := range seenHeaders.Values("Authorization") { + if strings.Contains(v, "connect-original") { + t.Errorf("inbound Authorization leaked (full header list): %q", v) + } + } + + // Request body propagates verbatim. + if seenBody != requestBody { + t.Errorf("forwarded body = %q, want %q", seenBody, requestBody) + } + + // Response from target reaches the client. + if got := rec.Header().Get("X-Custom-Reply"); got != "vendor-says-hi" { + t.Errorf("client X-Custom-Reply = %q, want %q", got, "vendor-says-hi") + } + if got := rec.Body.String(); got != `{"vendor":"reply"}` { + t.Errorf("client body = %q, want %q", got, `{"vendor":"reply"}`) + } + + // The plugin's credential path must NOT have been taken. + if got := plugin.credCallCount.Load(); got != 0 { + t.Errorf("GetCredentials was called %d time(s) on the forward path; want 0", got) + } + }) + + t.Run("target status 418 propagates to client", func(t *testing.T) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusTeapot) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if rec.Code != http.StatusTeapot { + t.Errorf("status = %d, want %d", rec.Code, http.StatusTeapot) + } + if got := plugin.credCallCount.Load(); got != 0 { + t.Errorf("GetCredentials was called on the forward path; want 0") + } + }) + + t.Run("response body propagates verbatim", func(t *testing.T) { + const wantBody = "hello world" + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, wantBody) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if got := rec.Body.String(); got != wantBody { + t.Errorf("client body = %q, want %q", got, wantBody) + } + }) + + t.Run("custom response header reaches client", func(t *testing.T) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-Request-Trace", "from-target-42") + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if got := rec.Header().Get("X-Request-Trace"); got != "from-target-42" { + t.Errorf("client X-Request-Trace = %q, want %q", got, "from-target-42") + } + }) + + t.Run("bearer token equals the configured value exactly", func(t *testing.T) { + var seenAuth string + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if seenAuth != "Bearer expected-token" { + t.Errorf("Authorization at target = %q, want %q", seenAuth, "Bearer expected-token") + } + }) + + t.Run("inbound Authorization not reflected at target", func(t *testing.T) { + var seenAuth string + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if strings.Contains(seenAuth, "connect-original") { + t.Errorf("inbound Authorization leaked to target: %q", seenAuth) + } + }) + + t.Run("sensitive response headers reflected by target are stripped", func(t *testing.T) { + // Defense-in-depth: even if the target reflects Authorization or + // Set-Cookie back, ForwardProxy.modifyResponse strips them before + // they reach the client. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Authorization", "Bearer leaked") + w.Header().Set("Set-Cookie", "session=should-not-leak") + w.Header().Set("X-Safe", "ok-to-see") + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if got := rec.Header().Get("Authorization"); got != "" { + t.Errorf("reflected Authorization not stripped from client response: %q", got) + } + if got := rec.Header().Get("Set-Cookie"); got != "" { + t.Errorf("reflected Set-Cookie not stripped from client response: %q", got) + } + // Non-sensitive header should still propagate so the strip is targeted. + if got := rec.Header().Get("X-Safe"); got != "ok-to-see" { + t.Errorf("non-sensitive header X-Safe = %q, want %q", got, "ok-to-see") + } + }) + + t.Run("target unreachable returns 502 with sanitized JSON body", func(t *testing.T) { + // Start and immediately close a server to obtain a guaranteed-closed + // port. The returned URL is for a now-defunct listener; dials will + // fail at the transport layer and hit ForwardProxy.errorHandler. + closed := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + closedURL := closed.URL + closed.Close() + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, closedURL) + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + + if rec.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d. body=%s", rec.Code, http.StatusBadGateway, rec.Body.String()) + } + const wantBody = `{"error":"forward target unavailable"}` + if got := strings.TrimSpace(rec.Body.String()); got != wantBody { + t.Errorf("body = %q, want %q", got, wantBody) + } + if got := rec.Header().Get("Content-Type"); got != "application/json" { + t.Errorf("Content-Type = %q, want %q", got, "application/json") + } + // Still no credentials path even on transport failure. + if got := plugin.credCallCount.Load(); got != 0 { + t.Errorf("GetCredentials was called %d time(s) on the forward path; want 0", got) + } + }) + + t.Run("forward path bypasses credential injection entirely", func(t *testing.T) { + // Sanity check across two requests: the forward route consistently + // short-circuits before reaching the credential provider. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(target.Close) + + plugin := &forwardRouter{matchVendorID: "vendor-a", forwardTo: "company-b"} + cfg := forwardIntegrationConfig(t, plugin, target.URL) + srv := mustNewServer(t, cfg) + + for i := 0; i < 3; i++ { + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, makeProxyRequest("")) + if rec.Code != http.StatusOK { + t.Fatalf("iteration %d: status = %d, want 200. body=%s", i, rec.Code, rec.Body.String()) + } + } + if got := plugin.credCallCount.Load(); got != 0 { + t.Errorf("GetCredentials was called %d time(s) across 3 forwarded requests; want 0", got) + } + }) +} diff --git a/internal/proxy/forward_proxy.go b/internal/proxy/forward_proxy.go new file mode 100644 index 0000000..926b1b6 --- /dev/null +++ b/internal/proxy/forward_proxy.go @@ -0,0 +1,249 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/cloudblue/chaperone/internal/config" + "github.com/cloudblue/chaperone/internal/security" + "github.com/cloudblue/chaperone/internal/telemetry" +) + +// defaultForwardTimeout is applied when ForwardTargetConfig.Timeout is zero +// or negative. It bounds the response-header wait so that a hung forward +// target cannot pin a Connect goroutine indefinitely. +const defaultForwardTimeout = 30 * time.Second + +// ForwardProxy wraps a httputil.ReverseProxy for a single forward target. +// One instance is built at startup per named target in config.ForwardTargets +// and reused across requests. +// +// Compared with the vendor proxy path (server.createReverseProxy), the +// forward path is intentionally stripped down: +// +// - Inbound Authorization is dropped so Connect's auth posture cannot +// leak to the forward target. +// - A static bearer token is injected when auth.type == "bearer". +// - X-Connect-* context headers are forwarded verbatim (the forward +// target — typically the customer's own system — needs them). +// - The plugin's ResponseModifier is NOT invoked, and Core error +// normalization is NOT applied; the forward target's status code and +// body pass through unmodified. +// - Sensitive response headers (the static default set from +// internal/security) are stripped as a defense-in-depth measure +// against credential reflection. +type ForwardProxy struct { + name string + target *url.URL + auth config.ForwardTargetAuthConfig + proxy *httputil.ReverseProxy +} + +// NewForwardProxy builds a forward proxy for the given target configuration. +// The returned handler is safe for concurrent use and is intended to be +// cached at startup and reused across requests. +func NewForwardProxy(name string, cfg config.ForwardTargetConfig) (*ForwardProxy, error) { + u, err := url.Parse(cfg.URL) + if err != nil { + return nil, fmt.Errorf("forward_target[%q]: parse url: %w", name, err) + } + if u.Scheme == "" || u.Host == "" { + return nil, fmt.Errorf("forward_target[%q]: invalid url %q", name, cfg.URL) + } + + fp := &ForwardProxy{name: name, target: u, auth: cfg.Auth} + fp.proxy = &httputil.ReverseProxy{ + Director: fp.director, + ModifyResponse: fp.modifyResponse, + ErrorHandler: fp.errorHandler, + Transport: newForwardTransport(cfg.Timeout), + } + return fp, nil +} + +// ServeHTTP forwards the request to the configured target. +// +// Observability: +// - chaperone_forward_target_duration_seconds{target} is observed for every +// request that enters ServeHTTP, including those that fail at the +// transport layer. The deferred observation captures end-to-end time +// (entry → response written / error handled) so dashboards reflect the +// real wall-clock cost of forwarding even when the target is unreachable. +// - chaperone_forward_target_errors_total{target,kind} is incremented by +// errorHandler for infrastructure failures. 5xx responses returned by the +// target are NOT counted here — they are target responses, not Chaperone +// errors. +func (fp *ForwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + start := time.Now() + defer func() { + telemetry.ForwardTargetDuration.WithLabelValues(fp.name).Observe(time.Since(start).Seconds()) + }() + fp.proxy.ServeHTTP(w, r) // #nosec G704 -- target URL is fixed at startup from forward_targets config; the inbound request cannot influence the destination +} + +// director rewrites the outbound request: target host/scheme, path joining, +// inbound-Authorization stripping, and (optional) bearer-token injection. +// +// SECURITY: The bearer token must not be logged anywhere in this function. +// The static sensitive_headers redaction in the request logger already +// covers Authorization; do not emit log lines that include req.Header here. +func (fp *ForwardProxy) director(req *http.Request) { + req.URL.Scheme = fp.target.Scheme + req.URL.Host = fp.target.Host + req.URL.Path = singleJoiningSlash(fp.target.Path, req.URL.Path) + if fp.target.RawQuery != "" && req.URL.RawQuery != "" { + req.URL.RawQuery = fp.target.RawQuery + "&" + req.URL.RawQuery + } else { + req.URL.RawQuery = fp.target.RawQuery + req.URL.RawQuery + } + req.Host = fp.target.Host + + // Strip inbound Authorization to avoid leaking Connect's auth posture + // to the forward target. This happens regardless of fp.auth.Type — the + // forward target should only ever see credentials we choose to inject. + req.Header.Del("Authorization") + + if fp.auth.Type == config.ForwardAuthBearer { + req.Header.Set("Authorization", "Bearer "+fp.auth.Token) + } + + // X-Connect-* headers are intentionally preserved — the forward target + // (typically the customer's own system) needs the routing/context. + // Connect-Request-ID is likewise preserved by default; no action needed. +} + +// modifyResponse strips the static set of sensitive headers from the forward +// target's response. This is defense-in-depth: even if the forward target +// reflects an Authorization header back, it never reaches Connect. +// +// NOTE: Unlike the vendor proxy path, we do NOT invoke the plugin's +// ResponseModifier and we do NOT apply Core error normalization. Forward +// targets are by definition outside the plugin contract; their responses +// pass through verbatim modulo the credential-reflection sanitization. +func (fp *ForwardProxy) modifyResponse(resp *http.Response) error { + security.StripSensitiveResponseHeaders(resp.Header) + return nil +} + +// errorHandler returns 502 Bad Gateway when the forward target is +// unreachable. The error itself is not surfaced to the caller to avoid +// leaking internal infrastructure details (host names, ports, etc.). +// +// SECURITY: Do not include the error string in the response body. Internal +// observability of the cause belongs in logs, not in the wire response. +func (fp *ForwardProxy) errorHandler(w http.ResponseWriter, _ *http.Request, err error) { + telemetry.ForwardTargetErrors.WithLabelValues(fp.name, classifyForwardError(err)).Inc() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"forward target unavailable"}`)) +} + +// classifyForwardError maps a transport-level error from httputil.ReverseProxy +// into a small set of well-known kinds for the forward_target_errors_total +// metric. Go's net/http error surface is intentionally fuzzy, so we classify +// what we can confidently identify and fall back to "other" for the rest. +// +// Order matters: timeouts and TLS failures can both surface as *net.OpError +// with a wrapped underlying error, so we check the most specific signals +// first. +func classifyForwardError(err error) string { + if err == nil { + return "other" + } + + // Timeout: context deadline, response-header timeout, or any error that + // implements the net.Error timeout contract. + if errors.Is(err, context.DeadlineExceeded) { + return "timeout" + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return "timeout" + } + + // TLS: handshake / record errors. The TLS package's error types aren't all + // exported, so we fall back to a substring check against the well-known + // "tls:" prefix used by crypto/tls error messages. + var recordHeaderErr tls.RecordHeaderError + if errors.As(err, &recordHeaderErr) { + return "tls" + } + if msg := err.Error(); strings.Contains(msg, "tls:") || strings.Contains(msg, "x509:") { + return "tls" + } + + // Connection: DNS failure, refused, reset, EOF mid-handshake, etc. + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return "connection" + } + var opErr *net.OpError + if errors.As(err, &opErr) { + return "connection" + } + + return "other" +} + +// newForwardTransport returns the per-target transport. Timeouts apply to +// the response-header wait (i.e., how long we are willing to block before +// the target writes status); body streaming is not bounded here, which +// matches the streaming semantics of httputil.ReverseProxy. +// +// The transport is built by cloning http.DefaultTransport so we inherit +// HTTP/2 negotiation (ForceAttemptHTTP2), connection pooling (MaxIdleConns, +// IdleConnTimeout), the dialer defaults (DialContext) and the various +// stdlib-tuned timeouts (TLSHandshakeTimeout, ExpectContinueTimeout). +// We override only ResponseHeaderTimeout and TLSClientConfig. +func newForwardTransport(timeout time.Duration) *http.Transport { + if timeout <= 0 { + timeout = defaultForwardTimeout + } + // http.DefaultTransport is documented as *http.Transport. The comma-ok + // form keeps the linter happy without losing the invariant. + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + // Stdlib invariant broken — fall back to a fresh transport rather + // than panicking. This branch is effectively unreachable. + base = &http.Transport{} + } + t := base.Clone() + t.ResponseHeaderTimeout = timeout + t.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS13, + // Explicit: we always verify forward-target certificates. The + // linter flags this field because the zero value is also false, + // but being explicit guards against future refactors silently + // flipping the default. + InsecureSkipVerify: false, //nolint:gosec // explicit: always verify + } + return t +} + +// singleJoiningSlash mirrors httputil.singleJoiningSlash (unexported in +// net/http/httputil). Given a target-URL path and a request path, it joins +// them with exactly one separator slash. Used by director to rewrite +// req.URL.Path so that target paths with or without trailing slashes — and +// request paths with or without leading slashes — concatenate cleanly. +func singleJoiningSlash(a, b string) string { + aSlash := a != "" && a[len(a)-1] == '/' + bSlash := b != "" && b[0] == '/' + switch { + case aSlash && bSlash: + return a + b[1:] + case !aSlash && !bSlash: + return a + "/" + b + } + return a + b +} diff --git a/internal/proxy/forward_proxy_test.go b/internal/proxy/forward_proxy_test.go new file mode 100644 index 0000000..69f173d --- /dev/null +++ b/internal/proxy/forward_proxy_test.go @@ -0,0 +1,612 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" + + "github.com/cloudblue/chaperone/internal/config" + "github.com/cloudblue/chaperone/internal/telemetry" +) + +// errCtxDeadlineExceeded is captured once so test tables can reference it +// without re-importing context in every helper. +var errCtxDeadlineExceeded = context.DeadlineExceeded + +func newTestTarget(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + return httptest.NewServer(handler) +} + +func TestForwardProxy_PassesXConnectHeaders(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seen = r.Header.Clone() + }) + defer target.Close() + + h, err := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + if err != nil { + t.Fatalf("NewForwardProxy: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", "https://api.vendor.com/v1/foo") + req.Header.Set("X-Connect-Vendor-ID", "vendor-a") + h.ServeHTTP(httptest.NewRecorder(), req) + + if got := seen.Get("X-Connect-Target-URL"); got != "https://api.vendor.com/v1/foo" { + t.Errorf("X-Connect-Target-URL forwarded = %q", got) + } + if got := seen.Get("X-Connect-Vendor-ID"); got != "vendor-a" { + t.Errorf("X-Connect-Vendor-ID forwarded = %q", got) + } +} + +func TestForwardProxy_StripsInboundAuthorization_AddsBearer(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { seen = r.Header.Clone() }) + defer target.Close() + + h, err := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: "secret-xyz"}, + }) + if err != nil { + t.Fatalf("NewForwardProxy: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("Authorization", "Bearer connect-original") + h.ServeHTTP(httptest.NewRecorder(), req) + + auth := seen.Get("Authorization") + if auth != "Bearer secret-xyz" { + t.Errorf("forwarded Authorization = %q, want %q", auth, "Bearer secret-xyz") + } + if strings.Contains(auth, "connect-original") { + t.Errorf("inbound Authorization leaked: %q", auth) + } +} + +func TestForwardProxy_SanitizesReflectedSensitiveResponseHeaders(t *testing.T) { + target := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Authorization", "Bearer reflected-secret") + w.Header().Set("Set-Cookie", "session=abc") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "ok") + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Result().Header.Get("Authorization") != "" { + t.Errorf("reflected Authorization not stripped") + } + if rec.Result().Header.Get("Set-Cookie") != "" { + t.Errorf("reflected Set-Cookie not stripped") + } +} + +func TestForwardProxy_BearerToken_NotInLogOutput(t *testing.T) { + target := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + defer target.Close() + + getLogs := captureLogs(t) + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: "secret-not-in-logs"}, + }) + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if out := getLogs(); strings.Contains(out, "secret-not-in-logs") { + t.Errorf("bearer token leaked into log output: %s", out) + } +} + +func TestForwardProxy_HonorsTimeout(t *testing.T) { + target := newTestTarget(t, func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(2 * time.Second): + } + w.WriteHeader(http.StatusOK) + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Timeout: 50 * time.Millisecond, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code == http.StatusOK { + t.Errorf("expected non-200 due to timeout, got 200") + } +} + +// TestForwardProxy_PreservesAllXConnectHeaders confirms multiple X-Connect-* +// headers are forwarded verbatim — they are part of the customer's context +// (the forward target's system needs them) and must not be stripped. +func TestForwardProxy_PreservesAllXConnectHeaders(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seen = r.Header.Clone() + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + headers := map[string]string{ + "X-Connect-Target-URL": "https://api.vendor.com/v1/foo", + "X-Connect-Vendor-ID": "vendor-a", + "X-Connect-Marketplace-ID": "marketplace-1", + "X-Connect-Product-ID": "PRD-001", + "X-Connect-Subscription-ID": "AS-1234", + "X-Connect-Context-Data": "eyJrIjoidiJ9", + } + for k, v := range headers { + req.Header.Set(k, v) + } + h.ServeHTTP(httptest.NewRecorder(), req) + + for k, want := range headers { + if got := seen.Get(k); got != want { + t.Errorf("header %s = %q, want %q", k, got, want) + } + } +} + +// TestForwardProxy_AuthNoneStripsInboundAuthorization verifies that even +// when no bearer is configured (auth.type=none), any inbound Authorization +// header is still stripped before forwarding. This prevents Connect's auth +// posture from leaking to the forward target. +func TestForwardProxy_AuthNoneStripsInboundAuthorization(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seen = r.Header.Clone() + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + req.Header.Set("Authorization", "Bearer connect-original") + h.ServeHTTP(httptest.NewRecorder(), req) + + if got := seen.Get("Authorization"); got != "" { + t.Errorf("auth.type=none: inbound Authorization not stripped, got %q", got) + } +} + +// TestForwardProxy_AuthNoneNoAuthorizationAdded confirms no Authorization +// header is added when auth.type=none and the inbound request has none. +func TestForwardProxy_AuthNoneNoAuthorizationAdded(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seen = r.Header.Clone() + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + req := httptest.NewRequest(http.MethodGet, "/proxy", nil) + h.ServeHTTP(httptest.NewRecorder(), req) + + if got := seen.Get("Authorization"); got != "" { + t.Errorf("auth.type=none: Authorization should not be set, got %q", got) + } +} + +// TestForwardProxy_BearerTokenWithWhitespace verifies the token is +// forwarded verbatim (we do NOT trim/normalize whitespace — operators may +// have intentional content there, and any malformedness is their concern). +func TestForwardProxy_BearerTokenWithWhitespace(t *testing.T) { + var seen http.Header + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seen = r.Header.Clone() + }) + defer target.Close() + + token := "abc def ghi" + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: token}, + }) + + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + want := "Bearer " + token + if got := seen.Get("Authorization"); got != want { + t.Errorf("Authorization = %q, want %q", got, want) + } +} + +// TestForwardProxy_PathJoining covers singleJoiningSlash correctness across +// the four corner cases of trailing/leading slashes between the target URL +// path and the inbound request path. +func TestForwardProxy_PathJoining(t *testing.T) { + tests := []struct { + name string + targetPath string + requestPath string + wantPath string + }{ + {"both empty", "", "", "/"}, + {"target trailing, req leading", "/api/", "/v1/foo", "/api/v1/foo"}, + {"target no trailing, req no leading", "/api", "v1/foo", "/api/v1/foo"}, + {"target trailing, req no leading", "/api/", "v1/foo", "/api/v1/foo"}, + {"target no trailing, req leading", "/api", "/v1/foo", "/api/v1/foo"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var seenPath string + target := newTestTarget(t, func(_ http.ResponseWriter, r *http.Request) { + seenPath = r.URL.Path + }) + defer target.Close() + + h, err := NewForwardProxy("t", config.ForwardTargetConfig{ + URL: target.URL + tt.targetPath, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + if err != nil { + t.Fatalf("NewForwardProxy: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.test/", nil) + req.URL.Path = tt.requestPath + h.ServeHTTP(httptest.NewRecorder(), req) + + if seenPath != tt.wantPath { + t.Errorf("forwarded path = %q, want %q", seenPath, tt.wantPath) + } + }) + } +} + +// TestForwardProxy_Passes500Status verifies that a 5xx response from the +// forward target is passed through verbatim. The forward path explicitly +// does NOT perform error normalization (that is reserved for the vendor +// proxy path where ResponseModifier may opt out). +func TestForwardProxy_Passes500Status(t *testing.T) { + target := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "boom") + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + if body := rec.Body.String(); body != "boom" { + t.Errorf("body = %q, want %q", body, "boom") + } +} + +// TestForwardProxy_TargetUnreachable_Returns502 verifies the ErrorHandler +// returns 502 Bad Gateway when the target refuses the connection. +func TestForwardProxy_TargetUnreachable_Returns502(t *testing.T) { + // Start and immediately close a server to obtain a guaranteed-closed + // port. The returned URL is for a now-defunct listener; dials will fail. + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + url := srv.URL + srv.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: url, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadGateway) + } +} + +// ============================================================================= +// Task 8: forward-target metrics +// ============================================================================= +// +// These tests assert on the global telemetry.Forward* metrics. They MUST NOT +// use t.Parallel() because the metrics are registered with the default +// Prometheus registry. Test isolation is via telemetry.ResetMetrics(). + +func TestMetrics_ForwardTarget_DurationHistogram_Records(t *testing.T) { + telemetry.ResetMetrics(t) + + target := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + defer target.Close() + + h, err := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: "secret"}, + }) + if err != nil { + t.Fatalf("NewForwardProxy: %v", err) + } + + h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + // SampleCount: total observations under the {target=company-b} histogram. + if got := testutil.CollectAndCount(telemetry.ForwardTargetDuration); got == 0 { + t.Error("expected ForwardTargetDuration to have at least one observation") + } + // No infrastructure errors expected on a successful round-trip. + if got := testutil.ToFloat64(telemetry.ForwardTargetErrors.WithLabelValues("company-b", "connection")); got != 0 { + t.Errorf("connection errors counter = %v, want 0", got) + } +} + +func TestMetrics_ForwardTarget_Errors_IncrementsByKind_Connection(t *testing.T) { + telemetry.ResetMetrics(t) + + // Start and immediately close a server to obtain a guaranteed-closed port. + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + url := srv.URL + srv.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: url, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadGateway) + } + got := testutil.ToFloat64(telemetry.ForwardTargetErrors.WithLabelValues("company-b", "connection")) + if got != 1 { + t.Errorf("forward_target_errors_total{target=company-b,kind=connection} = %v, want 1", got) + } + // Duration is still observed: the deferred Observe in ServeHTTP runs + // regardless of whether the request succeeded or hit errorHandler. This + // is intentional — operators want end-to-end latency including failures. + if got := testutil.CollectAndCount(telemetry.ForwardTargetDuration); got == 0 { + t.Error("expected ForwardTargetDuration to record even on error path") + } +} + +func TestMetrics_ForwardTarget_Errors_IncrementsByKind_Timeout(t *testing.T) { + telemetry.ResetMetrics(t) + + // Target that never responds within the test budget. + target := newTestTarget(t, func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(2 * time.Second): + } + w.WriteHeader(http.StatusOK) + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Timeout: 25 * time.Millisecond, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", rec.Code, http.StatusBadGateway) + } + got := testutil.ToFloat64(telemetry.ForwardTargetErrors.WithLabelValues("company-b", "timeout")) + if got != 1 { + t.Errorf("forward_target_errors_total{target=company-b,kind=timeout} = %v, want 1", got) + } +} + +// TestMetrics_ForwardTarget_500Response_NoErrorCounter verifies that a 5xx +// response from the target is treated as a target response (not a Chaperone +// infrastructure error) — the duration histogram observes but the errors +// counter does NOT increment. +func TestMetrics_ForwardTarget_500Response_NoErrorCounter(t *testing.T) { + telemetry.ResetMetrics(t) + + target := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "boom") + }) + defer target.Close() + + h, _ := NewForwardProxy("company-b", config.ForwardTargetConfig{ + URL: target.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + rec := httptest.NewRecorder() + h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + // Duration: observed. + if got := testutil.CollectAndCount(telemetry.ForwardTargetDuration); got == 0 { + t.Error("expected duration histogram to record on 5xx response") + } + // Errors: NOT incremented (any kind). + for _, kind := range []string{"connection", "timeout", "tls", "other"} { + if got := testutil.ToFloat64(telemetry.ForwardTargetErrors.WithLabelValues("company-b", kind)); got != 0 { + t.Errorf("5xx response must not increment errors_total{kind=%s}, got %v", kind, got) + } + } +} + +// TestMetrics_ForwardTarget_MultipleTargets verifies each target gets its own +// histogram cell (no cross-aliasing). +func TestMetrics_ForwardTarget_MultipleTargets(t *testing.T) { + telemetry.ResetMetrics(t) + + targetA := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + defer targetA.Close() + targetB := newTestTarget(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + defer targetB.Close() + + hA, _ := NewForwardProxy("a", config.ForwardTargetConfig{ + URL: targetA.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + hB, _ := NewForwardProxy("b", config.ForwardTargetConfig{ + URL: targetB.URL, + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + + hA.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + hB.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + hB.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/proxy", nil)) + + // Each named target should have its own histogram cell. + // CollectAndCount counts the number of distinct label sets — two here. + count := testutil.CollectAndCount(telemetry.ForwardTargetDuration) + if count < 2 { + t.Errorf("expected at least 2 distinct duration histogram cells, got %d", count) + } +} + +// TestClassifyForwardError_Matrix exercises the error classifier directly to +// pin the kind labels we surface in the metric. +func TestClassifyForwardError_Matrix(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + {"nil", nil, "other"}, + {"deadline exceeded", errCtxDeadlineExceeded, "timeout"}, + {"net timeout", testNetTimeoutError{}, "timeout"}, + {"dns failure", &net.DNSError{Err: "no such host", Name: "nope.example"}, "connection"}, + {"op error refused", &net.OpError{Op: "dial", Net: "tcp", Err: stringError("connection refused")}, "connection"}, + {"tls substring", stringError("tls: handshake failure"), "tls"}, + {"x509 substring", stringError("x509: certificate signed by unknown authority"), "tls"}, + {"plain", stringError("something weird"), "other"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyForwardError(tt.err) + if got != tt.want { + t.Errorf("classifyForwardError(%v) = %q, want %q", tt.err, got, tt.want) + } + }) + } +} + +// TestNewForwardTransport_InheritsDefaultTransportSettings verifies that the +// per-target transport is built by cloning http.DefaultTransport so it keeps +// HTTP/2 negotiation, connection pooling, the dialer defaults, and the +// stdlib-tuned timeouts (TLSHandshakeTimeout, IdleConnTimeout, etc.). It also +// verifies our explicit overrides on TLSClientConfig stick. +func TestNewForwardTransport_InheritsDefaultTransportSettings(t *testing.T) { + fp, err := NewForwardProxy("x", config.ForwardTargetConfig{ + URL: "https://example.test/", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + if err != nil { + t.Fatalf("NewForwardProxy: %v", err) + } + tr, ok := fp.proxy.Transport.(*http.Transport) + if !ok { + t.Fatalf("proxy.Transport = %T, want *http.Transport", fp.proxy.Transport) + } + if !tr.ForceAttemptHTTP2 { + t.Error("ForceAttemptHTTP2 false; expected DefaultTransport defaults to be inherited") + } + if tr.TLSHandshakeTimeout == 0 { + t.Error("TLSHandshakeTimeout 0; expected DefaultTransport defaults to be inherited") + } + if tr.IdleConnTimeout == 0 { + t.Error("IdleConnTimeout 0; expected DefaultTransport defaults to be inherited") + } + if tr.MaxIdleConns == 0 { + t.Error("MaxIdleConns 0; expected DefaultTransport defaults to be inherited") + } + if tr.TLSClientConfig == nil || tr.TLSClientConfig.MinVersion != tls.VersionTLS13 { + t.Error("TLSClientConfig not configured with TLS 1.3 minimum") + } + if tr.TLSClientConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify true; expected false") + } +} + +func TestForwardProxy_InvalidTargetURL_ReturnsError(t *testing.T) { + _, err := NewForwardProxy("bad", config.ForwardTargetConfig{ + URL: "://not-a-url", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }) + if err == nil { + t.Fatal("NewForwardProxy with invalid URL: expected error, got nil") + } +} + +// ----------------------------------------------------------------------------- +// Test doubles for classifyForwardError matrix. +// ----------------------------------------------------------------------------- + +// stringError is a trivial error with a controllable message; used to +// exercise the substring-based TLS/x509 classification paths. +type stringError string + +func (e stringError) Error() string { return string(e) } + +// testNetTimeoutError satisfies net.Error with Timeout()=true so the classifier +// can identify it without depending on a real network call. +type testNetTimeoutError struct{} + +func (testNetTimeoutError) Error() string { return "i/o timeout" } +func (testNetTimeoutError) Timeout() bool { return true } +func (testNetTimeoutError) Temporary() bool { return true } diff --git a/internal/proxy/integration_test.go b/internal/proxy/integration_test.go index 460b371..ea2d238 100644 --- a/internal/proxy/integration_test.go +++ b/internal/proxy/integration_test.go @@ -4,7 +4,6 @@ package proxy_test import ( - "bytes" "context" "errors" "fmt" @@ -1208,11 +1207,7 @@ func TestHandlerStack_TraceID_ConsistentAcrossLogAndBackend(t *testing.T) { defer backend.Close() // Capture log output - var logBuf strings.Builder - logger := slog.New(slog.NewJSONHandler(&logBuf, nil)) - origLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(origLogger) + getLogs := captureLogs(t) srv := mustNewServerForTarget(t, testConfig(), backend.URL) handler := srv.Handler() @@ -1231,7 +1226,7 @@ func TestHandlerStack_TraceID_ConsistentAcrossLogAndBackend(t *testing.T) { t.Errorf("backend trace_id = %q, want %q", receivedTraceID, "consistency-check-123") } - logOutput := logBuf.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"trace_id":"consistency-check-123"`) { t.Errorf("log output should contain trace_id, got:\n%s", logOutput) } @@ -1914,11 +1909,7 @@ func TestIntegration_NonContextHeaders_PreservedOnForwarding(t *testing.T) { func TestIntegration_FastPath_LogsCredentialInjection(t *testing.T) { // Arrange - capture DEBUG log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogsAt(t, &slog.HandlerOptions{Level: slog.LevelDebug}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -1953,7 +1944,7 @@ func TestIntegration_FastPath_LogsCredentialInjection(t *testing.T) { } // Assert - "credentials injected" DEBUG log with Fast Path fields - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"msg":"credentials injected"`) { t.Errorf("expected credentials injected log, got: %s", logOutput) } @@ -1973,11 +1964,7 @@ func TestIntegration_FastPath_LogsCredentialInjection(t *testing.T) { func TestIntegration_SlowPath_LogsCredentialInjection(t *testing.T) { // Arrange - capture DEBUG log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogsAt(t, &slog.HandlerOptions{Level: slog.LevelDebug}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -2010,7 +1997,7 @@ func TestIntegration_SlowPath_LogsCredentialInjection(t *testing.T) { } // Assert - "credentials injected" DEBUG log with Slow Path fields - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"msg":"credentials injected"`) { t.Errorf("expected credentials injected log, got: %s", logOutput) } @@ -2028,11 +2015,7 @@ func TestIntegration_SlowPath_LogsCredentialInjection(t *testing.T) { func TestProxy_ContextParsed_DebugLog_LogsHostOnly(t *testing.T) { // Arrange - capture DEBUG log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogsAt(t, &slog.HandlerOptions{Level: slog.LevelDebug}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -2054,7 +2037,7 @@ func TestProxy_ContextParsed_DebugLog_LogsHostOnly(t *testing.T) { handler.ServeHTTP(rec, req) // Assert - only the host appears; path, query, and userinfo must not leak - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"msg":"transaction context parsed"`) { t.Errorf("expected 'transaction context parsed' debug log, got: %s", logOutput) } @@ -2081,11 +2064,7 @@ func TestProxy_ContextParsed_DebugLog_LogsHostOnly(t *testing.T) { func TestIntegration_ClientDisconnect_LogsStatus499(t *testing.T) { // Arrange - capture log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) plugin := &mockPlugin{ getCredentialsFn: func(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { @@ -2111,7 +2090,7 @@ func TestIntegration_ClientDisconnect_LogsStatus499(t *testing.T) { if rec.Code != proxy.StatusClientClosedRequest { t.Errorf("response status = %d, want %d", rec.Code, proxy.StatusClientClosedRequest) } - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"status":499`) { t.Errorf("log should contain status 499, got: %s", logOutput) } diff --git a/internal/proxy/log_capture_internal_test.go b/internal/proxy/log_capture_internal_test.go new file mode 100644 index 0000000..ab8b395 --- /dev/null +++ b/internal/proxy/log_capture_internal_test.go @@ -0,0 +1,48 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy + +import ( + "bytes" + "log/slog" + "sync" + "testing" +) + +// syncBuffer is a thread-safe bytes.Buffer for capturing log output in tests. +// +// slog.Default() is a process-global, so tests that swap it with a buffer-backed +// handler can race against production goroutines (or other tests) that continue +// to log through the same handler. Wrapping the buffer in a mutex eliminates +// the data race between concurrent Write and String calls. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +// Write appends p to the buffer under the mutex. Implements io.Writer. +func (s *syncBuffer) Write(p []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.buf.Write(p) +} + +// String returns the buffered contents under the mutex. +func (s *syncBuffer) String() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.buf.String() +} + +// captureLogs swaps slog.Default() with a JSON handler that writes to a +// thread-safe buffer. It registers a t.Cleanup to restore the previous +// default handler, and returns a closure that yields the captured output. +func captureLogs(t *testing.T) func() string { + t.Helper() + prev := slog.Default() + sb := &syncBuffer{} + slog.SetDefault(slog.New(slog.NewJSONHandler(sb, nil))) + t.Cleanup(func() { slog.SetDefault(prev) }) + return sb.String +} diff --git a/internal/proxy/log_capture_test.go b/internal/proxy/log_capture_test.go new file mode 100644 index 0000000..7e1a5d4 --- /dev/null +++ b/internal/proxy/log_capture_test.go @@ -0,0 +1,59 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy_test + +import ( + "bytes" + "log/slog" + "sync" + "testing" +) + +// syncBuffer is a thread-safe bytes.Buffer for capturing log output in tests. +// +// slog.Default() is a process-global, so tests that swap it with a buffer-backed +// handler can race against production goroutines (or other tests) that continue +// to log through the same handler. Wrapping the buffer in a mutex eliminates +// the data race between concurrent Write and String calls. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +// Write appends p to the buffer under the mutex. Implements io.Writer. +func (s *syncBuffer) Write(p []byte) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.buf.Write(p) +} + +// String returns the buffered contents under the mutex. +func (s *syncBuffer) String() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.buf.String() +} + +// captureLogs swaps slog.Default() with a JSON handler that writes to a +// thread-safe buffer. It registers a t.Cleanup to restore the previous +// default handler, and returns a closure that yields the captured output. +func captureLogs(t *testing.T) func() string { + t.Helper() + prev := slog.Default() + sb := &syncBuffer{} + slog.SetDefault(slog.New(slog.NewJSONHandler(sb, nil))) + t.Cleanup(func() { slog.SetDefault(prev) }) + return sb.String +} + +// captureLogsAt is like captureLogs but with explicit slog.HandlerOptions — +// used by tests that need to capture debug-level output. +func captureLogsAt(t *testing.T, opts *slog.HandlerOptions) func() string { + t.Helper() + prev := slog.Default() + sb := &syncBuffer{} + slog.SetDefault(slog.New(slog.NewJSONHandler(sb, opts))) + t.Cleanup(func() { slog.SetDefault(prev) }) + return sb.String +} diff --git a/internal/proxy/main_test.go b/internal/proxy/main_test.go index ebc78c8..a63e513 100644 --- a/internal/proxy/main_test.go +++ b/internal/proxy/main_test.go @@ -42,6 +42,7 @@ func testConfig() proxy.Config { PluginTimeout: 10 * time.Second, ConnectTimeout: 5 * time.Second, ShutdownTimeout: 30 * time.Second, + ForwardTargets: nil, // No forward targets by default } } diff --git a/internal/proxy/migration_integration_test.go b/internal/proxy/migration_integration_test.go new file mode 100644 index 0000000..2c56f0b --- /dev/null +++ b/internal/proxy/migration_integration_test.go @@ -0,0 +1,412 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy_test + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" + + "github.com/cloudblue/chaperone/internal/config" + chaperoneCtx "github.com/cloudblue/chaperone/internal/context" + "github.com/cloudblue/chaperone/internal/proxy" + "github.com/cloudblue/chaperone/internal/telemetry" + "github.com/cloudblue/chaperone/sdk" +) + +// ============================================================================= +// Task 15: End-to-end migration scenario. +// +// A single Chaperone server is configured with one plugin that ALSO implements +// sdk.RequestRouter. The router inspects tx.Data["ResellerId"] and forwards +// any reseller matching the glob pattern "migrated-*" to a fake "Company B" +// target. Other resellers (or requests missing ResellerId) fall through to +// credential injection, which adds a bearer token and forwards to the fake +// vendor target. +// +// This test exercises the full request flow across both branches in the same +// server, verifying: +// +// - Routing decisions are driven by tx.Data, which arrives via the +// X-Connect-Context-Data header (Base64-encoded JSON). +// - The forward path hits Company B and bypasses GetCredentials entirely. +// - The credentials path injects Authorization, strips X-Connect-* headers, +// and reaches the vendor target. +// - Per-decision metrics (chaperone_route_decisions_total) increment +// correctly under both branches. +// - Missing/empty Data falls through to the credentials path. +// ============================================================================= + +// migrationPlugin is an inline sdk.Plugin + sdk.RequestRouter used by the +// migration scenario. RouteRequest implements a simple glob match against +// tx.Data["ResellerId"]; the only supported pattern is a trailing "*" prefix +// match (sufficient for "migrated-*"). GetCredentials returns a fixed bearer +// credential and records its call count so the test can assert it never runs +// on the forward path. +type migrationPlugin struct { + pattern string // e.g. "migrated-*" + forwardTo string // forward target name + credToken string // bearer token returned from GetCredentials + getCredentialsCount atomic.Int32 +} + +func (p *migrationPlugin) GetCredentials(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + p.getCredentialsCount.Add(1) + return &sdk.Credential{ + Headers: map[string]string{ + "Authorization": "Bearer " + p.credToken, + }, + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil +} + +func (p *migrationPlugin) SignCSR(_ context.Context, _ []byte) ([]byte, error) { + return nil, errors.New("not implemented") +} + +func (p *migrationPlugin) ModifyResponse(_ context.Context, _ sdk.TransactionContext, _ *http.Response) (*sdk.ResponseAction, error) { + return nil, nil +} + +func (p *migrationPlugin) RouteRequest(_ context.Context, tx sdk.TransactionContext, _ *http.Request) (*sdk.RouteAction, error) { + rid, ok := tx.Data["ResellerId"].(string) + if !ok || rid == "" { + return nil, nil + } + if migrationGlobMatch(p.pattern, rid) { + return &sdk.RouteAction{ForwardTo: p.forwardTo}, nil + } + return nil, nil +} + +var _ sdk.Plugin = (*migrationPlugin)(nil) +var _ sdk.RequestRouter = (*migrationPlugin)(nil) + +// migrationGlobMatch implements the trailing-"*" wildcard semantics used in +// this test. "migrated-*" matches any input starting with "migrated-". +// Patterns without a trailing "*" require exact equality. This is a +// deliberately small implementation: the full glob semantics (multi-segment, +// "?" matching, etc.) belong in contrib/glob.go and are exercised there. +func migrationGlobMatch(pattern, input string) bool { + if strings.HasSuffix(pattern, "*") { + return strings.HasPrefix(input, strings.TrimSuffix(pattern, "*")) + } + return pattern == input +} + +// migrationProxyRequest builds a /proxy request for the migration scenario. +// resellerID is embedded in X-Connect-Context-Data as Base64-encoded JSON +// (matching the production wire format parsed by internal/context). When +// resellerID is empty, the Context-Data header is omitted entirely so the +// resulting tx.Data is nil (the "missing key" scenario). When emptyData is +// true, an empty JSON object is sent — tx.Data is non-nil but lacks +// ResellerId. +func migrationProxyRequest(t *testing.T, vendorTargetURL, resellerID string, emptyData bool) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", vendorTargetURL) + req.Header.Set("X-Connect-Vendor-ID", "vendor-a") + req.Header.Set("X-Connect-Marketplace-ID", "marketplace-1") + req.Header.Set("X-Connect-Product-ID", "product-1") + req.Header.Set("X-Connect-Subscription-ID", "sub-1") + + switch { + case emptyData: + req.Header.Set("X-Connect-Context-Data", base64.StdEncoding.EncodeToString([]byte(`{}`))) + case resellerID != "": + payload, err := json.Marshal(map[string]any{"ResellerId": resellerID}) + if err != nil { + t.Fatalf("encoding context data: %v", err) + } + req.Header.Set("X-Connect-Context-Data", base64.StdEncoding.EncodeToString(payload)) + } + return req +} + +// migrationServer wires up the proxy server, the migration plugin, both fake +// targets, and the per-target hit counters. It returns a teardown via +// t.Cleanup. +type migrationServer struct { + srv *proxy.Server + plugin *migrationPlugin + companyBHits *atomic.Int32 + vendorHits *atomic.Int32 + companyBBody string + vendorBody string + vendorURL string // real httptest URL of the vendor target; used as X-Connect-Target-URL +} + +// migrationVendorRecord records what the vendor target observed for a single +// request. The mutex-free assignment is safe because tests issue requests +// sequentially. +type migrationVendorRecord struct { + auth string + contextHeaders map[string]string +} + +func newMigrationServer(t *testing.T, lastVendorRecord *migrationVendorRecord) *migrationServer { + t.Helper() + + const companyBBody = `{"reply":"from-company-b"}` + const vendorBody = `{"reply":"from-vendor"}` + + companyBHits := &atomic.Int32{} + vendorHits := &atomic.Int32{} + + companyB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + companyBHits.Add(1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, companyBBody) + })) + t.Cleanup(companyB.Close) + + vendor := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vendorHits.Add(1) + if lastVendorRecord != nil { + lastVendorRecord.auth = r.Header.Get("Authorization") + lastVendorRecord.contextHeaders = make(map[string]string, len(chaperoneCtx.HeaderSuffixes())) + for _, suffix := range chaperoneCtx.HeaderSuffixes() { + lastVendorRecord.contextHeaders["X-Connect"+suffix] = r.Header.Get("X-Connect" + suffix) + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, vendorBody) + })) + t.Cleanup(vendor.Close) + + plugin := &migrationPlugin{ + pattern: "migrated-*", + forwardTo: "company-b", + credToken: "vendor-token", + } + + cfg := testConfig() + cfg.Plugin = plugin + cfg.AllowList = map[string][]string{ + mustTargetHostPort(t, vendor.URL): {"/**"}, + mustTargetHostPort(t, companyB.URL): {"/**"}, + } + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: companyB.URL, + Auth: config.ForwardTargetAuthConfig{ + Type: config.ForwardAuthNone, + }, + }, + } + + srv := mustNewServer(t, cfg) + + return &migrationServer{ + srv: srv, + plugin: plugin, + companyBHits: companyBHits, + vendorHits: vendorHits, + companyBBody: companyBBody, + vendorBody: vendorBody, + vendorURL: vendor.URL, + } +} + +// TestMigrationIntegration exercises the full migration scenario described in +// the implementation plan: one server, one plugin (router + credential +// provider), two fake upstreams (Company B and the vendor), and three +// requests that exercise both routing branches. +func TestMigrationIntegration(t *testing.T) { + telemetry.ResetMetrics(t) + + var vendorRecord migrationVendorRecord + env := newMigrationServer(t, &vendorRecord) + + t.Run("request A: migrated-001 is forwarded to Company B", func(t *testing.T) { + rec := httptest.NewRecorder() + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", "migrated-001", false) + env.srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + if got := env.companyBHits.Load(); got != 1 { + t.Errorf("companyB hits after request A = %d, want 1", got) + } + if got := env.vendorHits.Load(); got != 0 { + t.Errorf("vendor hits after request A = %d, want 0", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 0 { + t.Errorf("GetCredentials calls after request A = %d, want 0", got) + } + // Response body from Company B reaches the client. + if got := rec.Body.String(); got != env.companyBBody { + t.Errorf("client body after request A = %q, want %q", got, env.companyBBody) + } + }) + + t.Run("request B: legacy-99 falls through to credentials and hits the vendor", func(t *testing.T) { + rec := httptest.NewRecorder() + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", "legacy-99", false) + env.srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + if got := env.vendorHits.Load(); got != 1 { + t.Errorf("vendor hits after request B = %d, want 1", got) + } + // Company B must not have been hit by request B (still 1 from request A). + if got := env.companyBHits.Load(); got != 1 { + t.Errorf("companyB hits after request B = %d, want 1 (unchanged from request A)", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 1 { + t.Errorf("GetCredentials calls after request B = %d, want 1", got) + } + // Bearer token from the plugin's Credential reached the vendor target. + if got := vendorRecord.auth; got != "Bearer vendor-token" { + t.Errorf("vendor Authorization = %q, want %q", got, "Bearer vendor-token") + } + // X-Connect-* context headers MUST be stripped before reaching the vendor. + for header, value := range vendorRecord.contextHeaders { + if value != "" { + t.Errorf("context header %q leaked to vendor with value %q", header, value) + } + } + // Response body from the vendor reaches the client. + if got := rec.Body.String(); got != env.vendorBody { + t.Errorf("client body after request B = %q, want %q", got, env.vendorBody) + } + }) + + t.Run("request C: migrated-042 (glob match) is forwarded to Company B", func(t *testing.T) { + rec := httptest.NewRecorder() + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", "migrated-042", false) + env.srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + if got := env.companyBHits.Load(); got != 2 { + t.Errorf("companyB hits after request C = %d, want 2 (1 from request A + 1 from request C)", got) + } + if got := env.vendorHits.Load(); got != 1 { + t.Errorf("vendor hits after request C = %d, want 1 (only request B)", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 1 { + t.Errorf("GetCredentials calls after request C = %d, want 1 (only request B)", got) + } + }) + + t.Run("metrics: route_decisions_total tracks both paths across the 3 requests", func(t *testing.T) { + // 2 forwards (requests A and C) → action=forward,target=company-b + fwd := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "company-b")) + if fwd != 2 { + t.Errorf("route_decisions_total{action=forward,target=company-b} = %v, want 2", fwd) + } + // 1 credentials decision (request B) → action=credentials,target="" + cred := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "")) + if cred != 1 { + t.Errorf("route_decisions_total{action=credentials,target=\"\"} = %v, want 1", cred) + } + // No cross-contamination: forward with empty target, or credentials + // labeled with a forward target, must remain zero. + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "")); v != 0 { + t.Errorf("forward+empty-target leaked: %v", v) + } + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "company-b")); v != 0 { + t.Errorf("credentials+company-b leaked: %v", v) + } + }) +} + +// TestMigrationIntegration_MissingResellerID_FallsThrough verifies that a +// request whose Data map omits ResellerId falls through to the credentials +// path (the router cannot make a migration decision without the key). +func TestMigrationIntegration_MissingResellerID_FallsThrough(t *testing.T) { + telemetry.ResetMetrics(t) + + env := newMigrationServer(t, nil) + + rec := httptest.NewRecorder() + // No Context-Data header at all → tx.Data is nil. + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", "", false) + env.srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + if got := env.companyBHits.Load(); got != 0 { + t.Errorf("companyB hits = %d, want 0", got) + } + if got := env.vendorHits.Load(); got != 1 { + t.Errorf("vendor hits = %d, want 1", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 1 { + t.Errorf("GetCredentials calls = %d, want 1", got) + } +} + +// TestMigrationIntegration_EmptyDataMap_FallsThrough verifies that a request +// carrying X-Connect-Context-Data: "e30=" (base64 of `{}`) — a non-nil but +// empty Data map — also falls through to credentials. +func TestMigrationIntegration_EmptyDataMap_FallsThrough(t *testing.T) { + telemetry.ResetMetrics(t) + + env := newMigrationServer(t, nil) + + rec := httptest.NewRecorder() + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", "", true) + env.srv.Handler().ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want 200. body=%s", rec.Code, rec.Body.String()) + } + if got := env.companyBHits.Load(); got != 0 { + t.Errorf("companyB hits = %d, want 0", got) + } + if got := env.vendorHits.Load(); got != 1 { + t.Errorf("vendor hits = %d, want 1", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 1 { + t.Errorf("GetCredentials calls = %d, want 1", got) + } +} + +// TestMigrationIntegration_GlobMatch_MultipleResellerIDs verifies that the +// "migrated-*" pattern matches a range of migrated reseller IDs (not just an +// exact equality match), all of which take the forward path. +func TestMigrationIntegration_GlobMatch_MultipleResellerIDs(t *testing.T) { + telemetry.ResetMetrics(t) + + env := newMigrationServer(t, nil) + + migrated := []string{"migrated-001", "migrated-042", "migrated-foo-bar", "migrated-"} + for _, rid := range migrated { + rec := httptest.NewRecorder() + req := migrationProxyRequest(t, env.vendorURL+"/v1/foo", rid, false) + env.srv.Handler().ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("ResellerId=%q: status = %d, want 200. body=%s", rid, rec.Code, rec.Body.String()) + } + } + + if got := int(env.companyBHits.Load()); got != len(migrated) { + t.Errorf("companyB hits = %d, want %d (one per migrated ID)", got, len(migrated)) + } + if got := env.vendorHits.Load(); got != 0 { + t.Errorf("vendor hits = %d, want 0 (all migrated IDs should be forwarded)", got) + } + if got := env.plugin.getCredentialsCount.Load(); got != 0 { + t.Errorf("GetCredentials calls = %d, want 0 (forward path must bypass credentials)", got) + } +} diff --git a/internal/proxy/recovery_test.go b/internal/proxy/recovery_test.go index cf37eba..810ff6a 100644 --- a/internal/proxy/recovery_test.go +++ b/internal/proxy/recovery_test.go @@ -4,7 +4,6 @@ package proxy_test import ( - "bytes" "encoding/json" "io" "log/slog" @@ -64,11 +63,7 @@ func TestPanicRecovery_LogsStackTrace(t *testing.T) { // NOT parallel: this test mutates slog.SetDefault() (a global). // Arrange - capture log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("stack trace test") @@ -83,7 +78,7 @@ func TestPanicRecovery_LogsStackTrace(t *testing.T) { handler.ServeHTTP(rec, req) // Assert - log should contain panic info with stack trace - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, "panic recovered") { t.Errorf("log should contain 'panic recovered', got: %s", logOutput) } diff --git a/internal/proxy/route_metrics_test.go b/internal/proxy/route_metrics_test.go new file mode 100644 index 0000000..6c2529b --- /dev/null +++ b/internal/proxy/route_metrics_test.go @@ -0,0 +1,181 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package proxy_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + + "github.com/cloudblue/chaperone/internal/config" + "github.com/cloudblue/chaperone/internal/telemetry" + "github.com/cloudblue/chaperone/sdk" +) + +// NOTE: These tests must NOT use t.Parallel() — they share the global +// Prometheus registries (RouteDecisionsTotal, ForwardTargetDuration, +// ForwardTargetErrors). Test isolation is achieved via telemetry.ResetMetrics(). + +func TestMetrics_RouteDecision_Forward_IncrementsCounter(t *testing.T) { + telemetry.ResetMetrics(t) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "company-b"}, + actionSet: true, + } + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": {URL: target.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, target.URL) + + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, target.URL+"/v1/foo")) + + got := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "company-b")) + if got != 1 { + t.Errorf("route_decisions_total{action=forward,target=company-b} = %v, want 1", got) + } + // No credentials decision should fire for a forwarded request. + credCount := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "")) + if credCount != 0 { + t.Errorf("credentials counter must be 0 on forward path, got %v", credCount) + } +} + +func TestMetrics_RouteDecision_Credentials_IncrementsCounter(t *testing.T) { + telemetry.ResetMetrics(t) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + // Plain plugin (no RequestRouter) — straight credentials path. + cfg := testConfig() + cfg.Plugin = &plainPlugin{} + srv := mustNewServerForTarget(t, cfg, backend.URL) + + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, backend.URL)) + + got := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "")) + if got != 1 { + t.Errorf("route_decisions_total{action=credentials,target=\"\"} = %v, want 1", got) + } +} + +// TestMetrics_RouteDecisions_ForwardAndCredentials_NoCrossContamination drives +// both flows in the same test and verifies the counters track independently +// without label aliasing. +func TestMetrics_RouteDecisions_ForwardAndCredentials_NoCrossContamination(t *testing.T) { + telemetry.ResetMetrics(t) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + // First request: forward path. + { + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "company-b"}, + actionSet: true, + } + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": {URL: target.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, target.URL) + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, target.URL+"/v1/foo")) + } + + // Second request: credentials path (nil action falls through). + { + plugin := &routerPlugin{ + action: nil, + actionSet: true, + } + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, backend.URL)) + } + + fwd := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "company-b")) + if fwd != 1 { + t.Errorf("forward counter = %v, want 1", fwd) + } + cred := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "")) + if cred != 1 { + t.Errorf("credentials counter = %v, want 1", cred) + } + // Cross-contamination sanity checks: these label combinations must not exist. + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "")); v != 0 { + t.Errorf("forward+empty-target leaked: %v", v) + } + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "company-b")); v != 0 { + t.Errorf("credentials+company-b leaked: %v", v) + } +} + +// TestMetrics_RouteDecision_MultipleForwardTargets verifies each named target +// gets its own counter cell (no aliasing across targets). +func TestMetrics_RouteDecision_MultipleForwardTargets(t *testing.T) { + telemetry.ResetMetrics(t) + + a := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer a.Close() + b := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer b.Close() + + // Drive a request to "a" + { + plugin := &routerPlugin{action: &sdk.RouteAction{ForwardTo: "a"}, actionSet: true} + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "a": {URL: a.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + "b": {URL: b.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, a.URL) + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, a.URL+"/x")) + } + // Drive two requests to "b" + { + plugin := &routerPlugin{action: &sdk.RouteAction{ForwardTo: "b"}, actionSet: true} + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "a": {URL: a.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + "b": {URL: b.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, b.URL) + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, b.URL+"/x")) + srv.Handler().ServeHTTP(httptest.NewRecorder(), newProxyRequest(t, b.URL+"/x")) + } + + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "a")); v != 1 { + t.Errorf("forward/a counter = %v, want 1", v) + } + if v := testutil.ToFloat64(telemetry.RouteDecisionsTotal.WithLabelValues("forward", "b")); v != 2 { + t.Errorf("forward/b counter = %v, want 2", v) + } +} diff --git a/internal/proxy/security_test.go b/internal/proxy/security_test.go index c445356..43b2b54 100644 --- a/internal/proxy/security_test.go +++ b/internal/proxy/security_test.go @@ -4,9 +4,7 @@ package proxy import ( - "bytes" "errors" - "log/slog" "net/url" "strings" "testing" @@ -80,11 +78,7 @@ func TestLogStartup_InsecureTargetsEnabled_EmitsWarning(t *testing.T) { cleanup := SetAllowInsecureTargetsForTesting(true) defer cleanup() - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) srv := &Server{ config: Config{ @@ -97,7 +91,7 @@ func TestLogStartup_InsecureTargetsEnabled_EmitsWarning(t *testing.T) { srv.logStartup() // Assert - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, "INSECURE") { t.Errorf("expected INSECURE warning in log output, got: %s", logOutput) } @@ -111,11 +105,7 @@ func TestLogStartup_InsecureTargetsDisabled_NoWarning(t *testing.T) { cleanup := SetAllowInsecureTargetsForTesting(false) defer cleanup() - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) srv := &Server{ config: Config{ @@ -128,7 +118,7 @@ func TestLogStartup_InsecureTargetsDisabled_NoWarning(t *testing.T) { srv.logStartup() // Assert - logOutput := logBuffer.String() + logOutput := getLogs() if strings.Contains(logOutput, "INSECURE") { t.Errorf("unexpected INSECURE warning in log output: %s", logOutput) } diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 53c0208..1d13243 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -88,6 +88,12 @@ type Config struct { // Empty defaults to host-only, the safest behavior. LogTargetAddrMode observability.TargetAddrMode + // ForwardTargets describes named forward upstreams. One ForwardProxy is + // built per entry at startup and cached on the Server, keyed by name. + // Routers reference these targets by name via sdk.RouteAction. May be + // nil or empty when no router is registered. + ForwardTargets map[string]config.ForwardTargetConfig + // Timeouts ReadTimeout time.Duration WriteTimeout time.Duration @@ -105,6 +111,17 @@ type Server struct { httpSrv *http.Server transport *http.Transport + // forwardProxies holds one ForwardProxy per named entry in + // config.ForwardTargets, built once at startup and reused across requests. + // The map is always non-nil after NewServer (possibly empty) so callers + // can perform direct map lookups without a nil guard. + forwardProxies map[string]*ForwardProxy + + // router is set if the plugin implements sdk.RequestRouter. When non-nil, + // it is consulted before credential injection to decide whether to route + // the request to a forward target instead of the main vendor flow. + router sdk.RequestRouter + // started guards against calling Start() more than once, which would // panic on double-close of the ready channel. started atomic.Bool @@ -144,14 +161,55 @@ func NewServer(cfg Config) (*Server, error) { t.ResponseHeaderTimeout = cfg.ReadTimeout t.IdleConnTimeout = cfg.IdleTimeout + forwardProxies, err := buildForwardProxies(cfg.ForwardTargets) + if err != nil { + return nil, err + } + + // Cross-validate forward references: if the plugin implements ForwardReferences(), + // ensure all referenced targets exist in the config. + if err := validateForwardReferences(cfg.Plugin, cfg.ForwardTargets); err != nil { + return nil, err + } + + // Log warning for any forward_target entries that are never referenced + // (only if the plugin can report its references). + warnUnusedForwardTargets(cfg.Plugin, cfg.ForwardTargets) + + // Detect if the plugin implements RequestRouter capability + var requestRouter sdk.RequestRouter + if cfg.Plugin != nil { + if r, ok := cfg.Plugin.(sdk.RequestRouter); ok { + requestRouter = r + } + } + return &Server{ - config: cfg, - reflector: security.NewReflector(sensitiveHeaders), - transport: t, - ready: make(chan struct{}), + config: cfg, + reflector: security.NewReflector(sensitiveHeaders), + transport: t, + forwardProxies: forwardProxies, + router: requestRouter, + ready: make(chan struct{}), }, nil } +// buildForwardProxies constructs one *ForwardProxy per configured forward +// target, keyed by the target's name. The returned map is always non-nil so +// callers can perform direct lookups without a nil guard. Any failure to +// build a target is surfaced as an error mentioning the offending name. +func buildForwardProxies(targets map[string]config.ForwardTargetConfig) (map[string]*ForwardProxy, error) { + fps := make(map[string]*ForwardProxy, len(targets)) + for name, t := range targets { + fp, err := NewForwardProxy(name, t) + if err != nil { + return nil, fmt.Errorf("build forward proxy %q: %w", name, err) + } + fps[name] = fp + } + return fps, nil +} + // validateProxyConfig validates that all required proxy configuration fields // are set. This is defense-in-depth: the config loader validates too, but // NewServer may be called directly by tests or Distributor code. @@ -516,11 +574,49 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { return } - // Parse the target URL once. If parsing fails, target_addr defaults to "" - // (consistent with FormatTargetAddr's behavior for malformed input) so - // the DEBUG breadcrumb still fires before the bad-request response. + targetURL, targetAddr, ok := s.parseAndValidateTarget(w, traceID, txCtx) + if !ok { + return + } + + // Router branch: consult the plugin's RequestRouter (if any) BEFORE + // credential injection. If routeAndMaybeForward dispatches the request + // to a forward target (or fails), handled is true and we return. + if handled := s.routeAndMaybeForward(w, r, traceID, txCtx, targetAddr); handled { + return + } + + // Fall-through: credential injection + vendor call. Log BEFORE + // injectCredentials so the routed-action breadcrumb is recorded even when + // credential injection fails. + slog.Info("request routed", + "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "action", "credentials", + ) + // Use empty-string target for credentials path: there is no named forward + // target. Empty string is the natural "absence" value and avoids reserving + // a sentinel label like "vendor". + telemetry.RouteDecisionsTotal.WithLabelValues("credentials", "").Inc() + + r, err = s.injectCredentials(r, txCtx, targetAddr) + if err != nil { + s.handlePluginError(w, traceID, txCtx, targetAddr, err) + return + } + + //nolint:contextcheck // ModifyResponse uses resp.Request.Context() internally + s.forwardRequest(w, r, targetURL, traceID, txCtx, targetAddr) +} + +// parseAndValidateTarget parses the request target URL from the transaction +// context and validates its scheme. On any failure it writes the appropriate +// error response and returns ok=false; callers MUST return immediately when +// ok is false. The DEBUG "transaction context parsed" breadcrumb is emitted +// unconditionally so the trace is observable even when validation fails. +func (s *Server) parseAndValidateTarget(w http.ResponseWriter, traceID string, txCtx *sdk.TransactionContext) (target *url.URL, targetAddr string, ok bool) { targetURL, parseErr := url.Parse(txCtx.TargetURL) - var targetAddr string if parseErr == nil { targetAddr = observability.FormatTargetAddrFromURL(targetURL, s.config.LogTargetAddrMode) } @@ -535,19 +631,18 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { if parseErr != nil { s.respondBadRequest(w, traceID, "invalid target URL", parseErr) - return + return nil, "", false } // SECURITY: Validate target URL scheme (HTTPS required in production) - err = ValidateTargetScheme(targetURL) - if err != nil { + if err := ValidateTargetScheme(targetURL); err != nil { slog.Warn("insecure target URL rejected", "trace_id", traceID, "target_scheme", targetURL.Scheme, "target_addr", targetAddr, ) http.Error(w, "Bad Request: "+err.Error(), http.StatusBadRequest) - return + return nil, "", false } // Warn if using HTTP in development mode @@ -558,14 +653,49 @@ func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { ) } - r, err = s.injectCredentials(r, txCtx, targetAddr) + return targetURL, targetAddr, true +} + +// routeAndMaybeForward consults the plugin's RequestRouter (when present) and +// dispatches the request to a configured ForwardProxy when the router returns +// a non-empty ForwardTo. Returns true when the response has been fully +// written (success, router error, or unknown-target error), in which case the +// caller MUST NOT continue with credential injection / vendor forwarding. +// Returns false when the caller should fall through to the credential flow. +func (s *Server) routeAndMaybeForward(w http.ResponseWriter, r *http.Request, traceID string, txCtx *sdk.TransactionContext, targetAddr string) bool { + if s.router == nil { + return false + } + + action, err := s.router.RouteRequest(r.Context(), *txCtx, r) if err != nil { s.handlePluginError(w, traceID, txCtx, targetAddr, err) - return + return true + } + if action == nil || action.ForwardTo == "" { + return false } - //nolint:contextcheck // ModifyResponse uses resp.Request.Context() internally - s.forwardRequest(w, r, targetURL, traceID, txCtx, targetAddr) + fp, ok := s.forwardProxies[action.ForwardTo] + if !ok { + slog.Error("forward_target not found", + "trace_id", traceID, + "target", action.ForwardTo, + ) + respondError(w, http.StatusInternalServerError, "internal configuration error") + return true + } + + slog.Info("request routed", + "trace_id", traceID, + "vendor_id", txCtx.VendorID, + "marketplace_id", txCtx.MarketplaceID, + "action", "forward", + "target", action.ForwardTo, + ) + telemetry.RouteDecisionsTotal.WithLabelValues("forward", action.ForwardTo).Inc() + fp.ServeHTTP(w, r) + return true } // respondBadRequest logs and responds with a 400 Bad Request. @@ -960,3 +1090,69 @@ func (s *Server) applyErrorNormalization(traceID string, txCtx *sdk.TransactionC // Continue even if normalization fails - response will be sent as-is } } + +// validateForwardReferences checks that all forward target references in the +// plugin (if it implements the ForwardReferences() method) are defined in +// the forward targets configuration. +// +// If the plugin implements the interface { ForwardReferences() []string }, +// each name returned is verified to exist in cfg.ForwardTargets. If any +// reference is missing, an error is returned. If the plugin does not +// implement the interface, validation is skipped (we can't know what targets +// it uses). +func validateForwardReferences(plugin sdk.Plugin, targets map[string]config.ForwardTargetConfig) error { + if plugin == nil { + return nil + } + + // Anonymous interface type assertion: check if plugin has ForwardReferences() + lister, ok := plugin.(interface{ ForwardReferences() []string }) + if !ok { + // Plugin doesn't expose forward references — nothing to validate + return nil + } + + refs := lister.ForwardReferences() + for _, ref := range refs { + if _, ok := targets[ref]; !ok { + return fmt.Errorf("plugin references forward_target %q which is not defined in config", ref) + } + } + + return nil +} + +// warnUnusedForwardTargets logs a warning for any forward_target entry that +// is not referenced by the plugin. Only runs if the plugin implements the +// ForwardReferences() method. +// +// This helps catch configuration errors where targets are defined but never +// actually used by any route. +func warnUnusedForwardTargets(plugin sdk.Plugin, targets map[string]config.ForwardTargetConfig) { + if plugin == nil { + return + } + + lister, ok := plugin.(interface{ ForwardReferences() []string }) + if !ok { + // Plugin doesn't expose forward references — can't determine which + // targets are unused, so skip the warning. + return + } + + refs := lister.ForwardReferences() + // Build a set of referenced targets (deduplicated) + referenced := make(map[string]bool) + for _, ref := range refs { + referenced[ref] = true + } + + // Warn about any configured targets that are not referenced + for name := range targets { + if !referenced[name] { + slog.Warn("forward_target defined but never referenced", + "target", name, + ) + } + } +} diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go index 34f011b..7d37cf0 100644 --- a/internal/proxy/server_test.go +++ b/internal/proxy/server_test.go @@ -4,8 +4,9 @@ package proxy_test import ( - "bytes" + "context" "encoding/json" + "errors" "io" "log/slog" "net/http" @@ -14,8 +15,10 @@ import ( "strings" "testing" + "github.com/cloudblue/chaperone/internal/config" "github.com/cloudblue/chaperone/internal/observability" "github.com/cloudblue/chaperone/internal/proxy" + "github.com/cloudblue/chaperone/sdk" ) func TestHealth_ReturnsAlive(t *testing.T) { @@ -362,11 +365,7 @@ func TestProxy_MethodPassthrough_ForwardsOriginalMethod(t *testing.T) { // 2. RequestLoggerMiddleware's defer logs with the correct status (500, not 200) func TestMiddlewareStack_PanicLogsCorrectStatus(t *testing.T) { // Arrange - capture log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) // Handler that panics panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -391,7 +390,7 @@ func TestMiddlewareStack_PanicLogsCorrectStatus(t *testing.T) { } // Assert - log should contain status 500 - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"status":500`) { t.Errorf("log should contain status 500, got: %s", logOutput) } @@ -407,11 +406,7 @@ func TestMiddlewareStack_PanicLogsCorrectStatus(t *testing.T) { // (no panic) log the correct status code through the real middleware stack. func TestMiddlewareStack_NormalRequestLogsCorrectStatus(t *testing.T) { // Arrange - capture log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) // Handler that returns 201 Created handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -435,7 +430,7 @@ func TestMiddlewareStack_NormalRequestLogsCorrectStatus(t *testing.T) { } // Assert - log should contain status 201 - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"status":201`) { t.Errorf("log should contain status 201, got: %s", logOutput) } @@ -445,11 +440,7 @@ func TestMiddlewareStack_NormalRequestLogsCorrectStatus(t *testing.T) { // TraceIDMiddleware, the panic log includes the trace ID from context. func TestPanicRecovery_LogsTraceID(t *testing.T) { // Arrange - capture log output - var logBuffer bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&logBuffer, nil)) - originalLogger := slog.Default() - slog.SetDefault(logger) - defer slog.SetDefault(originalLogger) + getLogs := captureLogs(t) panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { panic("trace-id panic test") @@ -467,7 +458,7 @@ func TestPanicRecovery_LogsTraceID(t *testing.T) { handler.ServeHTTP(rec, req) // Assert - panic log should contain trace_id - logOutput := logBuffer.String() + logOutput := getLogs() if !strings.Contains(logOutput, `"trace_id":"panic-with-trace-789"`) { t.Errorf("panic log should contain trace_id, got: %s", logOutput) } @@ -786,6 +777,570 @@ func TestProxy_InvalidTargetURL_Returns400(t *testing.T) { } } +// ============================================================================= +// Forward proxy registry tests +// ============================================================================= + +// TestServer_BuildsForwardProxies_AtStartup is the spec-required test. It +// verifies that the named target "company-b" is built into the registry when +// NewServer returns. +func TestServer_BuildsForwardProxies_AtStartup(t *testing.T) { + t.Parallel() + + cfg := testConfig() + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + srv, err := proxy.NewServer(cfg) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + if srv.ForwardProxyForTesting("company-b") == nil { + t.Errorf("forwardProxies[company-b] not built") + } +} + +// TestServer_BuildsForwardProxies_Matrix covers the full behavior matrix: +// - zero targets → registry is non-nil and empty +// - one target → built and accessible +// - multiple targets → all built +// - invalid URL → NewServer returns error mentioning the offending name +// - bearer auth → built and accessible +func TestServer_BuildsForwardProxies_Matrix(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + targets map[string]config.ForwardTargetConfig + wantBuilt []string // names that must be present after NewServer + wantErrContain string // if non-empty, NewServer must fail with this substring + }{ + { + name: "zero targets — registry is non-nil and empty", + targets: nil, + wantBuilt: nil, + }, + { + name: "one target — built and accessible by name", + targets: map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + }, + wantBuilt: []string{"company-b"}, + }, + { + name: "multiple targets — all built and accessible", + targets: map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + "company-c": { + URL: "https://company-c.example/api", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: "tok"}, + }, + "company-d": { + URL: "https://company-d.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + }, + wantBuilt: []string{"company-b", "company-c", "company-d"}, + }, + { + name: "invalid URL — NewServer returns error mentioning the offending name", + targets: map[string]config.ForwardTargetConfig{ + "broken": { + URL: ":::not a url", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + }, + wantErrContain: "broken", + }, + { + name: "bearer auth — built and accessible", + targets: map[string]config.ForwardTargetConfig{ + "company-x": { + URL: "https://company-x.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthBearer, Token: "secret-token"}, + }, + }, + wantBuilt: []string{"company-x"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cfg := testConfig() + cfg.ForwardTargets = tt.targets + + srv, err := proxy.NewServer(cfg) + + if tt.wantErrContain != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrContain) + } + if !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErrContain) + } + return + } + + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + // Registry must be non-nil even when no targets are configured. + if srv.ForwardProxiesNilForTesting() { + t.Error("forwardProxies map must be non-nil") + } + + if got, want := srv.ForwardProxyCountForTesting(), len(tt.wantBuilt); got != want { + t.Errorf("forward proxy count = %d, want %d", got, want) + } + + for _, name := range tt.wantBuilt { + if srv.ForwardProxyForTesting(name) == nil { + t.Errorf("forwardProxies[%q] not built", name) + } + } + }) + } +} + +// ============================================================================= +// RequestRouter Detection Tests +// ============================================================================= + +func TestServer_DetectsRequestRouter(t *testing.T) { + t.Parallel() + + // Arrange + plugin := &routerPlugin{} // implements sdk.Plugin + sdk.RequestRouter + cfg := testConfig() + cfg.Plugin = plugin + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + if srv.RouterForTesting() == nil { + t.Fatal("router not detected on plugin implementing sdk.RequestRouter") + } + + // Verify the router is the same as the plugin + if srv.RouterForTesting() != plugin { + t.Error("router should be the same instance as the plugin") + } +} + +func TestServer_NoRouter_WhenPluginDoesNotImplement(t *testing.T) { + t.Parallel() + + // Arrange + plugin := &plainPlugin{} // implements sdk.Plugin only + cfg := testConfig() + cfg.Plugin = plugin + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + if srv.RouterForTesting() != nil { + t.Fatal("router should be nil for plugin without RequestRouter") + } +} + +func TestServer_NoRouter_WhenNoPluginConfigured(t *testing.T) { + t.Parallel() + + // Arrange + cfg := testConfig() + cfg.Plugin = nil + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + if srv.RouterForTesting() != nil { + t.Fatal("router should be nil when no plugin is configured") + } +} + +func TestServer_RouterIsAccessible_WhenImplemented(t *testing.T) { + t.Parallel() + + // Arrange + plugin := &routerPlugin{} // implements both sdk.Plugin and sdk.RequestRouter + cfg := testConfig() + cfg.Plugin = plugin + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + routerIface := srv.RouterForTesting() + if routerIface == nil { + t.Fatal("router should not be nil") + } + + // Type assert to sdk.RequestRouter + router, ok := routerIface.(sdk.RequestRouter) + if !ok { + t.Fatal("router should be an sdk.RequestRouter") + } + + // Create a test request to verify the router is callable + req := httptest.NewRequest(http.MethodPost, "https://vendor.example/api", nil) + ctx := req.Context() + + // Act - call RouteRequest on the retrieved router + tx := testTransactionContext() + action, err := router.RouteRequest(ctx, tx, req) + + // Assert - for the test plugin, should return a RouteAction with ForwardTo="test-target" + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if action == nil || action.ForwardTo != "test-target" { + t.Errorf("action.ForwardTo = %q, want %q", action.ForwardTo, "test-target") + } +} + +// ============================================================================= +// handleProxy router-branch tests (Task 7) +// ============================================================================= + +// newProxyRequest builds a /proxy request with the minimum X-Connect-* headers +// needed for handleProxy to reach the router branch. +func newProxyRequest(t *testing.T, targetURL string) *http.Request { + t.Helper() + req := httptest.NewRequest(http.MethodPost, "/proxy", nil) + req.Header.Set("X-Connect-Target-URL", targetURL) + req.Header.Set("X-Connect-Vendor-ID", "test-vendor") + req.Header.Set("X-Connect-Marketplace-ID", "test-marketplace") + return req +} + +func TestHandleProxy_ForwardAction_DispatchesToForwardProxy_AndSkipsCredentials(t *testing.T) { + var hitTarget bool + target := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + hitTarget = true + })) + defer target.Close() + + var injectedCreds bool + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "company-b"}, + actionSet: true, + onGetCredentials: func() { injectedCreds = true }, + } + + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": {URL: target.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, target.URL) + + req := newProxyRequest(t, target.URL+"/v1/foo") + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, req) + + if !hitTarget { + t.Error("forward target was not called") + } + if injectedCreds { + t.Error("GetCredentials was called for a forwarded request") + } +} + +func TestHandleProxy_NilRouteAction_FallsThroughToCredentialFlow(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + var injectedCreds bool + plugin := &routerPlugin{ + action: nil, // fall through + actionSet: true, + onGetCredentials: func() { injectedCreds = true }, + } + + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, backend.URL)) + + if !injectedCreds { + t.Error("GetCredentials should have been called for fall-through") + } +} + +func TestHandleProxy_UnknownForwardTarget_Returns500(t *testing.T) { + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "missing"}, + actionSet: true, + } + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, "https://api.vendor.example/v1/foo")) + + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", rec.Code) + } +} + +func TestHandleProxy_RouterError_Returns500(t *testing.T) { + plugin := &routerPlugin{ + actionErr: errors.New("router blew up"), + } + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServer(t, cfg) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, "https://api.vendor.example/v1/foo")) + + // handlePluginError maps a generic plugin error to 500. + if rec.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + // Defense: the wire response must not leak the router's internal + // error message verbatim. handlePluginError writes a generic body. + if strings.Contains(rec.Body.String(), "router blew up") { + t.Errorf("response body leaked router error: %s", rec.Body.String()) + } +} + +func TestHandleProxy_RouteActionEmptyForwardTo_FallsThroughToCredentialFlow(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + var injectedCreds bool + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: ""}, // empty == fall-through + actionSet: true, + onGetCredentials: func() { injectedCreds = true }, + } + + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, backend.URL)) + + if !injectedCreds { + t.Error("empty ForwardTo should fall through to credential flow") + } +} + +func TestHandleProxy_PluginWithoutRequestRouter_GoesDirectlyToCredentials(t *testing.T) { + getLogs := captureLogs(t) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + plugin := &plainPlugin{} // no RequestRouter + cfg := testConfig() + cfg.Plugin = plugin + srv := mustNewServerForTarget(t, cfg, backend.URL) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, backend.URL)) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want 200", rec.Code) + } + + out := getLogs() + if strings.Contains(out, `"action":"forward"`) { + t.Errorf("plain plugin must not log action=forward, got: %s", out) + } + if !strings.Contains(out, `"action":"credentials"`) { + t.Errorf("expected action=credentials log line, got: %s", out) + } +} + +func TestHandleProxy_ForwardedResponse_PropagatedToClient(t *testing.T) { + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("X-Forward-Echo", "from-target") + w.WriteHeader(http.StatusTeapot) + _, _ = io.WriteString(w, `{"forwarded":true}`) + })) + defer target.Close() + + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "company-b"}, + actionSet: true, + } + + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": {URL: target.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, target.URL) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, target.URL+"/v1/foo")) + + if rec.Code != http.StatusTeapot { + t.Errorf("status = %d, want %d", rec.Code, http.StatusTeapot) + } + if got := rec.Header().Get("X-Forward-Echo"); got != "from-target" { + t.Errorf("X-Forward-Echo = %q, want %q", got, "from-target") + } + if body := rec.Body.String(); !strings.Contains(body, `"forwarded":true`) { + t.Errorf("body = %q, want to contain forwarded payload", body) + } +} + +func TestHandleProxy_ForwardPath_LogsActionForward(t *testing.T) { + getLogs := captureLogs(t) + + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + plugin := &routerPlugin{ + action: &sdk.RouteAction{ForwardTo: "company-b"}, + actionSet: true, + } + + cfg := testConfig() + cfg.Plugin = plugin + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": {URL: target.URL, Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}}, + } + srv := mustNewServerForTarget(t, cfg, target.URL) + + rec := httptest.NewRecorder() + srv.Handler().ServeHTTP(rec, newProxyRequest(t, target.URL+"/v1/foo")) + + out := getLogs() + if !strings.Contains(out, `"action":"forward"`) { + t.Errorf("expected action=forward log line, got: %s", out) + } + if !strings.Contains(out, `"target":"company-b"`) { + t.Errorf("expected target=company-b log line, got: %s", out) + } + if strings.Contains(out, `"action":"credentials"`) { + t.Errorf("forward path must not emit action=credentials, got: %s", out) + } +} + +// Test doubles for RequestRouter detection tests + +// plainPlugin implements sdk.Plugin but NOT sdk.RequestRouter +type plainPlugin struct{} + +func (p *plainPlugin) GetCredentials(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { + return nil, nil +} + +func (p *plainPlugin) SignCSR(ctx context.Context, csrPEM []byte) ([]byte, error) { + return nil, nil +} + +func (p *plainPlugin) ModifyResponse(ctx context.Context, tx sdk.TransactionContext, resp *http.Response) (*sdk.ResponseAction, error) { + return nil, nil +} + +var _ sdk.Plugin = (*plainPlugin)(nil) + +// routerPlugin implements both sdk.Plugin and sdk.RequestRouter. +// +// Zero value (routerPlugin{}) preserves the legacy Task 6 behavior: +// RouteRequest returns &sdk.RouteAction{ForwardTo: "test-target"}. +// +// Tests may override behavior by setting any of: +// - action / actionErr → returned from RouteRequest +// - actionSet → when true, the (action, actionErr) pair is used +// verbatim even if action is nil (so callers can express a nil-action +// fall-through explicitly) +// - onGetCredentials → invoked when GetCredentials is called (records +// whether the credential path ran) +type routerPlugin struct { + action *sdk.RouteAction + actionErr error + actionSet bool + onGetCredentials func() +} + +func (r *routerPlugin) GetCredentials(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { + if r.onGetCredentials != nil { + r.onGetCredentials() + } + return nil, nil +} + +func (r *routerPlugin) SignCSR(ctx context.Context, csrPEM []byte) ([]byte, error) { + return nil, nil +} + +func (r *routerPlugin) ModifyResponse(ctx context.Context, tx sdk.TransactionContext, resp *http.Response) (*sdk.ResponseAction, error) { + return nil, nil +} + +func (r *routerPlugin) RouteRequest(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.RouteAction, error) { + if r.actionSet || r.actionErr != nil { + return r.action, r.actionErr + } + // Legacy default for Task 6 tests. + return &sdk.RouteAction{ForwardTo: "test-target"}, nil +} + +var _ sdk.Plugin = (*routerPlugin)(nil) +var _ sdk.RequestRouter = (*routerPlugin)(nil) + +// testTransactionContext returns a minimal TransactionContext for testing +func testTransactionContext() sdk.TransactionContext { + return sdk.TransactionContext{ + TraceID: "test-trace-123", + VendorID: "test-vendor", + MarketplaceID: "test-marketplace", + ProductID: "test-product", + TargetURL: "https://vendor.example/api", + } +} + // Helper functions for creating test files func createDummyFile(path string) error { @@ -814,3 +1369,340 @@ func createFile(path string) (*file, error) { type file = os.File var osCreate = os.Create + +// ============================================================================= +// Forward Reference Validation Tests (Task 13) +// ============================================================================= + +func TestRun_ForwardActionReferencingUnknownTarget_FailsAtStartup(t *testing.T) { + t.Parallel() + + // Arrange - Mux with forward action referencing non-existent target + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-x"), "missing-target") + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + _, err := proxy.NewServer(cfg) + + // Assert + if err == nil { + t.Fatal("expected startup error for unknown forward target") + } + if !strings.Contains(err.Error(), "missing-target") { + t.Errorf("error should mention missing-target: %v", err) + } +} + +func TestRun_AllForwardReferencesValid_SucceedsAtStartup(t *testing.T) { + t.Parallel() + + // Arrange - Mux with forward actions all referencing known targets + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "company-b") + mux.HandleForward(newTestRoute("vendor-c"), "company-d") + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example/ingress", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + "company-d": { + URL: "https://company-d.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if srv == nil { + t.Fatal("server should be created") + } +} + +func TestRun_PluginWithoutForwardReferencesMethods_NoValidationError(t *testing.T) { + t.Parallel() + + // Arrange - custom plugin without ForwardReferences() method + plugin := &plainPlugin{} // does not implement ForwardReferences() + + cfg := testConfig() + cfg.Plugin = plugin + // Configured forward_targets that are never referenced + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "unused-target": { + URL: "https://unused.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert - should succeed; no validation against custom plugins + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if srv == nil { + t.Fatal("server should be created") + } +} + +func TestRun_MultipleReferencesOneUnknown_ErrorMentionsUnknownOne(t *testing.T) { + t.Parallel() + + // Arrange - Mux with multiple forward references, one is unknown + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "valid-target") + mux.HandleForward(newTestRoute("vendor-b"), "unknown-target") + mux.HandleForward(newTestRoute("vendor-c"), "another-valid") + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "valid-target": { + URL: "https://valid.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + "another-valid": { + URL: "https://another.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + _, err := proxy.NewServer(cfg) + + // Assert + if err == nil { + t.Fatal("expected error for unknown target") + } + if !strings.Contains(err.Error(), "unknown-target") { + t.Errorf("error should mention unknown-target, got: %v", err) + } + // Error should NOT require the valid targets to also be mentioned +} + +func TestRun_EmptyForwardTargetsWithMuxNoForwardRoutes_Success(t *testing.T) { + t.Parallel() + + // Arrange - Mux with only credential routes, no forward targets configured + mux := newTestMux() + mux.Handle(newTestRoute("vendor-a"), &plainPlugin{}) + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = nil // or empty map + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if srv == nil { + t.Fatal("server should be created") + } +} + +func TestRun_EmptyForwardTargetsWithMuxForwardRoute_FailsAtStartup(t *testing.T) { + t.Parallel() + + // Arrange - Mux with forward route but no targets configured + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "missing-target") + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = nil // empty + + // Act + _, err := proxy.NewServer(cfg) + + // Assert + if err == nil { + t.Fatal("expected error for forward route with no targets") + } + if !strings.Contains(err.Error(), "missing-target") { + t.Errorf("error should mention missing-target, got: %v", err) + } +} + +func TestRun_UnusedForwardTarget_WarnsAtStartup(t *testing.T) { + t.Parallel() + + // Arrange - capture log output + getLogs := captureLogs(t) + + // Mux with one forward reference + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "used-target") + + // But config has two targets (one unused) + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "used-target": { + URL: "https://used.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + "unused-target": { + URL: "https://unused.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert - must succeed + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if srv == nil { + t.Fatal("server should be created") + } + + // Warning should be logged for unused target + logOut := getLogs() + if !strings.Contains(logOut, "unused-target") { + t.Errorf("expected warning about unused-target, got: %s", logOut) + } + if !strings.Contains(logOut, "forward_target") { + t.Errorf("expected 'forward_target' in warning, got: %s", logOut) + } +} + +func TestRun_AllForwardTargetsReferenced_NoWarning(t *testing.T) { + t.Parallel() + + // Arrange + getLogs := captureLogs(t) + + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "target-1") + mux.HandleForward(newTestRoute("vendor-b"), "target-2") + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "target-1": { + URL: "https://target1.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + "target-2": { + URL: "https://target2.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + _, err := proxy.NewServer(cfg) + + // Assert + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + + logOut := getLogs() + if strings.Contains(logOut, "forward_target defined but never referenced") { + t.Errorf("should not warn about unreferenced targets when all are referenced: %s", logOut) + } +} + +func TestRun_DuplicateForwardReferences_NotAnError(t *testing.T) { + t.Parallel() + + // Arrange - Mux with duplicate references to the same target + mux := newTestMux() + mux.HandleForward(newTestRoute("vendor-a"), "company-b") + mux.HandleForward(newTestRoute("vendor-b"), "company-b") // same target, different route + + cfg := testConfig() + cfg.Plugin = mux + cfg.ForwardTargets = map[string]config.ForwardTargetConfig{ + "company-b": { + URL: "https://company-b.example", + Auth: config.ForwardTargetAuthConfig{Type: config.ForwardAuthNone}, + }, + } + + // Act + srv, err := proxy.NewServer(cfg) + + // Assert - duplicates are fine + if err != nil { + t.Fatalf("expected no error for duplicate references, got: %v", err) + } + if srv == nil { + t.Fatal("server should be created") + } +} + +// Helper: creates a Mux compatible with the test contrib package. +// Since contrib is a separate module, we create a testMux that wraps it. +type testMux struct { + entries []testMuxEntry +} + +type testMuxEntry struct { + target string // for forward references + isRefs bool // true if this is a forward reference +} + +func newTestMux() *testMux { + return &testMux{} +} + +func (tm *testMux) Handle(route interface{}, provider interface{}) { + // Stub for testing +} + +func (tm *testMux) HandleForward(route interface{}, target string) { + tm.entries = append(tm.entries, testMuxEntry{target: target, isRefs: true}) +} + +// ForwardReferences returns the list of forward target references. +func (tm *testMux) ForwardReferences() []string { + refs := make([]string, 0, len(tm.entries)) + for _, e := range tm.entries { + if e.isRefs { + refs = append(refs, e.target) + } + } + return refs +} + +func (tm *testMux) GetCredentials(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { + return nil, nil +} + +func (tm *testMux) SignCSR(ctx context.Context, csrPEM []byte) ([]byte, error) { + return nil, nil +} + +func (tm *testMux) ModifyResponse(ctx context.Context, tx sdk.TransactionContext, resp *http.Response) (*sdk.ResponseAction, error) { + return nil, nil +} + +var _ sdk.Plugin = (*testMux)(nil) + +func newTestRoute(vendorID string) interface{} { + return struct{ VendorID string }{VendorID: vendorID} +} diff --git a/internal/proxy/target_addr_integration_test.go b/internal/proxy/target_addr_integration_test.go index 31fcc31..ea70edb 100644 --- a/internal/proxy/target_addr_integration_test.go +++ b/internal/proxy/target_addr_integration_test.go @@ -32,11 +32,7 @@ import ( func targetAddrTestSetup(t *testing.T, mode observability.TargetAddrMode, target, requestURL string) []map[string]any { t.Helper() - var buf bytes.Buffer - logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - original := slog.Default() - slog.SetDefault(logger) - t.Cleanup(func() { slog.SetDefault(original) }) + getLogs := captureLogsAt(t, &slog.HandlerOptions{Level: slog.LevelDebug}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) @@ -70,7 +66,7 @@ func targetAddrTestSetup(t *testing.T, mode observability.TargetAddrMode, target t.Fatalf("status = %d, want %d. body=%s", rec.Code, http.StatusOK, rec.Body.String()) } - return parseJSONLogLines(t, buf.Bytes()) + return parseJSONLogLines(t, []byte(getLogs())) } func parseJSONLogLines(t *testing.T, b []byte) []map[string]any { diff --git a/internal/security/reflector.go b/internal/security/reflector.go index 33e295e..258d7b1 100644 --- a/internal/security/reflector.go +++ b/internal/security/reflector.go @@ -9,6 +9,65 @@ import ( "strings" ) +// defaultSensitiveHeaders is the canonical list of headers that MUST be +// redacted in logs and stripped from responses. This is a security-critical +// default per Design Spec Section 5.3 and is the single source of truth used +// by both internal/config (for the merged sensitive_headers list applied to +// the vendor proxy path) and the forward proxy path (which uses these +// defaults verbatim). +// +// Do not duplicate this list elsewhere — consumers should call +// DefaultSensitiveHeaders or StripSensitiveResponseHeaders. +var defaultSensitiveHeaders = []string{ + "Authorization", + "Proxy-Authorization", + "Cookie", + "Set-Cookie", + "X-API-Key", + "X-Auth-Token", +} + +// defaultSensitiveHeadersSet is the lookup table built from +// defaultSensitiveHeaders, used by StripSensitiveResponseHeaders. Keys are +// stored lowercase for case-insensitive matching. +var defaultSensitiveHeadersSet = func() map[string]struct{} { + m := make(map[string]struct{}, len(defaultSensitiveHeaders)) + for _, h := range defaultSensitiveHeaders { + m[strings.ToLower(h)] = struct{}{} + } + return m +}() + +// DefaultSensitiveHeaders returns a fresh copy of the built-in static list +// of sensitive headers. The caller is free to mutate the returned slice. +func DefaultSensitiveHeaders() []string { + out := make([]string, len(defaultSensitiveHeaders)) + copy(out, defaultSensitiveHeaders) + return out +} + +// StripSensitiveResponseHeaders removes the built-in static set of sensitive +// headers (Authorization, Cookie, etc.) from headers in place. Matching is +// case-insensitive. +// +// This is the free-function counterpart to Reflector.StripResponseHeaders. +// Use it on code paths that do not have access to a configured Reflector +// (e.g., the forward proxy path), where the user-extended sensitive_headers +// list is intentionally not applied. +// +// Per Design Spec Section 5.3 "Credential Reflection Protection". +func StripSensitiveResponseHeaders(headers http.Header) { + var toDelete []string + for header := range headers { + if _, ok := defaultSensitiveHeadersSet[strings.ToLower(header)]; ok { + toDelete = append(toDelete, header) + } + } + for _, header := range toDelete { + headers.Del(header) + } +} + // Reflector handles stripping sensitive headers from HTTP responses. // Per Design Spec Section 5.3 "Credential Reflection Protection": // "The Proxy strips all Injection Headers (like Authorization) from the diff --git a/internal/security/reflector_test.go b/internal/security/reflector_test.go index 5c47422..5dcdcca 100644 --- a/internal/security/reflector_test.go +++ b/internal/security/reflector_test.go @@ -7,13 +7,11 @@ import ( "context" "net/http" "testing" - - "github.com/cloudblue/chaperone/internal/config" ) // testSensitiveHeaders returns the default sensitive headers for testing. func testSensitiveHeaders() []string { - return config.MergeSensitiveHeaders(nil) + return DefaultSensitiveHeaders() } func TestReflector_StripResponseHeaders_RemovesSensitiveHeaders(t *testing.T) { diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go index 2da42dc..4da9437 100644 --- a/internal/telemetry/metrics.go +++ b/internal/telemetry/metrics.go @@ -89,6 +89,60 @@ var ( Help: "Total number of recovered panics", }, ) + + // RouteDecisionsTotal counts per-request routing decisions made by the + // RequestRouter (or the default credential flow when no router is registered). + // + // Labels: + // - action: "forward" or "credentials" + // - target: forward target name when action="forward"; empty string ("") + // when action="credentials" (no named forward target involved). + // Empty string is preferred over a sentinel like "vendor" so that + // dashboards can naturally aggregate the credentials path without + // introducing a reserved label value. + RouteDecisionsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "chaperone", + Name: "route_decisions_total", + Help: "Per-request routing decisions made by the RequestRouter (or default).", + }, + []string{"action", "target"}, + ) + + // ForwardTargetDuration measures end-to-end duration of requests forwarded + // to a named forward target. Bounded by the configured forward-target names, + // so cardinality is bounded by deployment config (tens of targets at most). + // + // Labels: target (forward target name) + ForwardTargetDuration = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "chaperone", + Name: "forward_target_duration_seconds", + Help: "End-to-end duration of requests forwarded to a named target.", + Buckets: APILatencyBuckets, + }, + []string{"target"}, + ) + + // ForwardTargetErrors counts infrastructure errors encountered while + // forwarding to a named target. Does NOT count 5xx responses returned by + // the target itself (those are target responses, not Chaperone errors). + // + // Labels: + // - target: forward target name + // - kind: error classification + // - "timeout" — context deadline / response-header timeout + // - "tls" — TLS handshake failure + // - "connection" — DNS failure, connection refused, reset, etc. + // - "other" — any other transport-level error + ForwardTargetErrors = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "chaperone", + Name: "forward_target_errors_total", + Help: "Errors encountered while forwarding to a named target.", + }, + []string{"target", "kind"}, + ) ) // DefaultVendorID is the label value used when the X-Connect-Vendor-ID header diff --git a/internal/telemetry/testing.go b/internal/telemetry/testing.go index 06f7e65..1add339 100644 --- a/internal/telemetry/testing.go +++ b/internal/telemetry/testing.go @@ -22,6 +22,9 @@ func ResetMetrics(t *testing.T) { RequestDuration.Reset() UpstreamDuration.Reset() ActiveConnections.Set(0) + RouteDecisionsTotal.Reset() + ForwardTargetDuration.Reset() + ForwardTargetErrors.Reset() // PanicsTotal is a Counter (not CounterVec), so it cannot be Reset(). // It will accumulate across tests, but tests should assert relative increments. @@ -30,5 +33,8 @@ func ResetMetrics(t *testing.T) { RequestDuration.Reset() UpstreamDuration.Reset() ActiveConnections.Set(0) + RouteDecisionsTotal.Reset() + ForwardTargetDuration.Reset() + ForwardTargetErrors.Reset() }) } diff --git a/plugins/contrib/action.go b/plugins/contrib/action.go new file mode 100644 index 0000000..effcf40 --- /dev/null +++ b/plugins/contrib/action.go @@ -0,0 +1,36 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package contrib + +import "github.com/cloudblue/chaperone/sdk" + +// Action is the sealed interface implemented by [CredentialAction] and +// [ForwardAction]. The unexported isAction method prevents implementations +// outside this package, so the mux can exhaustively reason about the two +// dispatch outcomes: credential injection vs. raw forwarding. +type Action interface { + isAction() +} + +// CredentialAction routes a matched request to a [sdk.CredentialProvider] +// for normal credential injection. This is the default action installed +// by [Mux.Handle]. +type CredentialAction struct { + Provider sdk.CredentialProvider +} + +func (CredentialAction) isAction() {} + +// ForwardAction routes a matched request to a named forward_target. The +// Mux returns a [sdk.RouteAction] with ForwardTo set to Target from +// RouteRequest, and the Core handles the actual forwarding (the request +// never reaches a credential provider). +// +// Target validation (non-empty, references an existing forward_target) +// happens at config-load / cross-validation time, not here. +type ForwardAction struct { + Target string +} + +func (ForwardAction) isAction() {} diff --git a/plugins/contrib/errors.go b/plugins/contrib/errors.go index 630147f..6109fbf 100644 --- a/plugins/contrib/errors.go +++ b/plugins/contrib/errors.go @@ -51,3 +51,16 @@ var ErrTokenEndpointUnavailable = errors.New("token endpoint unavailable") // This is returned by adapters (AsPlugin, Mux) when SignCSR is called // without a configured signer. var ErrSigningNotConfigured = errors.New("certificate signing not configured") + +// ErrUnexpectedForwardAction indicates the mux's GetCredentials reached a +// route whose action is a ForwardAction. This is a defensive sentinel: +// when a ForwardAction matches, RouteRequest should have short-circuited +// at the Core boundary and the request should never reach the credential +// path. Receiving this error means an integration bug — the caller wired +// the mux into the credential path without consulting RouteRequest first. +var ErrUnexpectedForwardAction = errors.New("matched route is a forward action; GetCredentials should not have been called") + +// ErrNilCredentialProvider indicates a CredentialAction was registered +// with a nil Provider. The mux refuses to call a nil provider so we get a +// clear error instead of a nil-pointer panic. +var ErrNilCredentialProvider = errors.New("credential action has nil provider") diff --git a/plugins/contrib/go.mod b/plugins/contrib/go.mod index f56f38b..ba18929 100644 --- a/plugins/contrib/go.mod +++ b/plugins/contrib/go.mod @@ -5,3 +5,5 @@ go 1.26.3 require github.com/cloudblue/chaperone/sdk v0.1.0 require golang.org/x/sync v0.20.0 + +require gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/plugins/contrib/go.sum b/plugins/contrib/go.sum index a1c36be..ea7245e 100644 --- a/plugins/contrib/go.sum +++ b/plugins/contrib/go.sum @@ -2,3 +2,6 @@ github.com/cloudblue/chaperone/sdk v0.1.0 h1:OsrqjLfcaP35eSRLzsmL1r6wYf4IkmE/WG1 github.com/cloudblue/chaperone/sdk v0.1.0/go.mod h1:p6JOMXPqVfm8EqvnyDAozgrmkvhfbs1O32am/dthnFc= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/plugins/contrib/mux.go b/plugins/contrib/mux.go index 42f444a..c1ba43d 100644 --- a/plugins/contrib/mux.go +++ b/plugins/contrib/mux.go @@ -7,20 +7,22 @@ import ( "context" "log/slog" "net/http" + "sort" "strings" "github.com/cloudblue/chaperone/sdk" ) -// Compile-time check that Mux implements sdk.Plugin. +// Compile-time checks that Mux implements its expected interfaces. var _ sdk.Plugin = (*Mux)(nil) +var _ sdk.RequestRouter = (*Mux)(nil) -// routeEntry binds a route pattern to its credential provider, +// routeEntry binds a route pattern to its dispatch action, // preserving registration order for tie-breaking. type routeEntry struct { - route Route - provider sdk.CredentialProvider - index int + route Route + action Action + index int } // Mux is a request multiplexer that dispatches incoming requests to @@ -86,6 +88,35 @@ func (m *Mux) log() *slog.Logger { // (e.g., VendorID "acme" vs "globex") are recognized as non-overlapping // and do not trigger a warning. func (m *Mux) Handle(route Route, provider sdk.CredentialProvider) { + m.add(route, CredentialAction{Provider: provider}) +} + +// HandleForward registers a route that, when matched, forwards the request +// to the named forward_target via the Core's forwarding path. The target +// name is opaque to the Mux beyond a non-empty check — validation that the +// name references an existing forward_target happens at config-load / +// cross-validation time. +// +// An empty target name is a programmer error (it cannot reference any +// forward_target) and panics immediately so the misconfiguration surfaces +// at registration rather than producing silent dead routes. Config-driven +// users get a clean error before reaching here via the loader's own +// non-empty check on the `forward:` field. +// +// Overlap warnings work the same way as for Handle: equal-specificity +// overlaps with any other entry (CredentialAction or ForwardAction) are +// logged. +func (m *Mux) HandleForward(route Route, target string) { + if target == "" { + panic("contrib.Mux.HandleForward: empty target name") + } + m.add(route, ForwardAction{Target: target}) +} + +// add appends a routeEntry, logging an overlap warning when an equal- +// specificity overlap with an existing entry is detected. Shared by +// Handle and HandleForward so the warning fires regardless of action type. +func (m *Mux) add(route Route, action Action) { newSpec := route.Specificity() for _, e := range m.entries { if e.route.Specificity() == newSpec && routesMayOverlap(e.route, route) { @@ -98,9 +129,9 @@ func (m *Mux) Handle(route Route, provider sdk.CredentialProvider) { } m.entries = append(m.entries, routeEntry{ - route: route, - provider: provider, - index: len(m.entries), + route: route, + action: action, + index: len(m.entries), }) } @@ -122,14 +153,31 @@ func (m *Mux) SetResponseModifier(modifier sdk.ResponseModifier) { } // GetCredentials dispatches the request to the most specific matching -// route's provider. If no route matches, it falls back to the default -// provider. Returns [ErrNoRouteMatch] if nothing matches and no default -// is configured. +// route's CredentialAction provider. If no route matches, it falls back +// to the default provider. Returns [ErrNoRouteMatch] if nothing matches +// and no default is configured. +// +// If the matched entry is a [ForwardAction], the caller should have +// short-circuited via RouteRequest at the Core boundary; receiving a +// ForwardAction here indicates an integration bug and returns +// [ErrUnexpectedForwardAction] defensively rather than silently +// returning nil credentials. func (m *Mux) GetCredentials(ctx context.Context, tx sdk.TransactionContext, req *http.Request) (*sdk.Credential, error) { best := m.match(tx) if best != nil { - return best.provider.GetCredentials(ctx, tx, req) + switch a := best.action.(type) { + case CredentialAction: + if a.Provider == nil { + return nil, ErrNilCredentialProvider + } + return a.Provider.GetCredentials(ctx, tx, req) + case ForwardAction: + return nil, ErrUnexpectedForwardAction + default: + // Sealed interface — unreachable in practice. + return nil, ErrUnexpectedForwardAction + } } if m.fallback != nil { @@ -157,6 +205,37 @@ func (m *Mux) ModifyResponse(ctx context.Context, tx sdk.TransactionContext, res return nil, nil } +// RouteRequest implements sdk.RequestRouter. It returns a non-nil RouteAction +// only when the matched route is a ForwardAction. Credential matches fall +// through (return nil) so that the normal Mux.GetCredentials path handles them. +func (m *Mux) RouteRequest(_ context.Context, tx sdk.TransactionContext, _ *http.Request) (*sdk.RouteAction, error) { + best := m.match(tx) + if best == nil { + return nil, nil + } + if fa, ok := best.action.(ForwardAction); ok { + return &sdk.RouteAction{ForwardTo: fa.Target}, nil + } + return nil, nil +} + +// ForwardReferences returns the set of forward_target names referenced by all +// registered routes. Each entry in the returned slice corresponds to a route +// with a ForwardAction. Deduplication is not applied — if the same target is +// referenced by multiple routes, it may appear multiple times. +// +// This is used for startup validation to ensure all referenced targets are +// defined in the configuration. +func (m *Mux) ForwardReferences() []string { + var refs []string + for _, e := range m.entries { + if fa, ok := e.action.(ForwardAction); ok { + refs = append(refs, fa.Target) + } + } + return refs +} + // match finds the best matching route entry for the given transaction. // When multiple routes match at the same specificity, the first registered wins. func (m *Mux) match(tx sdk.TransactionContext) *routeEntry { @@ -205,6 +284,14 @@ func routesMayOverlap(a, b Route) bool { if a.EnvironmentID != "" && b.EnvironmentID != "" && !fieldsMayOverlap(a.EnvironmentID, b.EnvironmentID) { return false } + // For Data, only shared keys constitute shared dimensions. Keys present + // in only one route are wildcards in the other (same model as the + // top-level fields above), so they cannot prove disjointness. + for key, av := range a.Data { + if bv, ok := b.Data[key]; ok && !fieldsMayOverlap(av, bv) { + return false + } + } return true } @@ -247,6 +334,16 @@ func routeString(r Route) string { if r.EnvironmentID != "" { parts = append(parts, "EnvironmentID="+r.EnvironmentID) } + if len(r.Data) > 0 { + keys := make([]string, 0, len(r.Data)) + for k := range r.Data { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + parts = append(parts, "Data["+k+"]="+r.Data[k]) + } + } if len(parts) == 0 { return "{}" } diff --git a/plugins/contrib/mux_config.go b/plugins/contrib/mux_config.go new file mode 100644 index 0000000..83b9582 --- /dev/null +++ b/plugins/contrib/mux_config.go @@ -0,0 +1,169 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package contrib + +import ( + "fmt" + + "github.com/cloudblue/chaperone/sdk" +) + +// MuxConfig is the YAML-friendly description of a request multiplexer. +// It can be parsed directly from a YAML document (or constructed in code) +// and passed to [LoadMuxFromConfig] to build a usable [*Mux]. +// +// Mutual exclusion: every route — and the fallback, if present — must set +// exactly one of `forward` or `credentials`. See [LoadMuxFromConfig] for the +// validation rules. +type MuxConfig struct { + // Routes are evaluated by specificity at dispatch time; registration + // order is preserved and breaks ties. + Routes []MuxRouteConfig `yaml:"routes"` + // Fallback is the catch-all used when no route matches. Optional. + Fallback *MuxFallbackConfig `yaml:"fallback,omitempty"` +} + +// MuxRouteConfig is a single route entry in a YAML mux configuration. +// Exactly one of Forward or Credentials must be set. +type MuxRouteConfig struct { + // Match contains the route's matching criteria. Empty fields are + // wildcards. See [Route] for semantics. + Match MatchConfig `yaml:"match"` + // Forward names a forward_target. When set, the matched request is + // forwarded as-is to that target by the Core. Mutually exclusive with + // Credentials. + Forward string `yaml:"forward,omitempty"` + // Credentials selects a credential provider by type. Mutually exclusive + // with Forward. The Type must be a key in the providers map passed to + // [LoadMuxFromConfig]. + Credentials *CredentialsConfig `yaml:"credentials,omitempty"` +} + +// MatchConfig mirrors the [Route] fields in a YAML-friendly shape so the +// match criteria can be expressed as a nested YAML object. +type MatchConfig struct { + VendorID string `yaml:"vendor_id,omitempty"` + MarketplaceID string `yaml:"marketplace_id,omitempty"` + ProductID string `yaml:"product_id,omitempty"` + EnvironmentID string `yaml:"environment_id,omitempty"` + TargetURL string `yaml:"target_url,omitempty"` + Data map[string]string `yaml:"data,omitempty"` +} + +// CredentialsConfig identifies which pre-built credential provider should +// handle a route. Only Type is interpreted by [LoadMuxFromConfig]: it's a +// discriminator used to look up an [sdk.CredentialProvider] in the providers +// map. +// +// Provider-specific configuration (OAuth endpoints, scopes, etc.) is the +// caller's responsibility — they construct the providers and register them +// in the lookup map before calling [LoadMuxFromConfig]. This keeps the mux +// loader decoupled from provider internals, which vary widely. +type CredentialsConfig struct { + // Type is the discriminator used to look up the provider. + Type string `yaml:"type"` +} + +// MuxFallbackConfig is the catch-all entry used when no route matches. +// +// Only Credentials is supported in v1. Setting Forward returns a +// configuration error: a silent fallback-forward would route any +// unmatched request — including misconfigured or unexpected traffic — to +// an upstream without credential injection, which is ambiguous and unsafe. +// Forward routes must be explicit per-match. +type MuxFallbackConfig struct { + Credentials *CredentialsConfig `yaml:"credentials,omitempty"` + // Forward is rejected by [LoadMuxFromConfig]. Documented as a field + // so a misconfigured YAML can be diagnosed with a clear error rather + // than silently ignored. + Forward string `yaml:"forward,omitempty"` +} + +// LoadMuxFromConfig builds a [*Mux] from a [MuxConfig] and a lookup of +// pre-built credential providers keyed by [CredentialsConfig.Type]. +// +// Validation rules: +// - Every route must set exactly one of forward or credentials. +// - A route's credentials.type must be non-empty and present in providers. +// - The fallback, if present, must set credentials (not forward) and the +// credentials.type must be non-empty and present in providers. +// +// On the first validation failure, an error is returned that names the +// offending route by index (e.g. "routes[2]") or "fallback". No partial +// mux is returned on error. +func LoadMuxFromConfig(cfg MuxConfig, providers map[string]sdk.CredentialProvider) (*Mux, error) { + m := NewMux() + + for i, rc := range cfg.Routes { + if err := applyRoute(m, i, rc, providers); err != nil { + return nil, err + } + } + + if cfg.Fallback != nil { + if err := applyFallback(m, cfg.Fallback, providers); err != nil { + return nil, err + } + } + + return m, nil +} + +// applyRoute validates and registers a single route entry. +func applyRoute(m *Mux, i int, rc MuxRouteConfig, providers map[string]sdk.CredentialProvider) error { + hasForward := rc.Forward != "" + hasCreds := rc.Credentials != nil + if hasForward == hasCreds { + return fmt.Errorf("routes[%d]: exactly one of forward or credentials must be set", i) + } + + route := routeFromMatch(rc.Match) + + if hasForward { + m.HandleForward(route, rc.Forward) + return nil + } + + if rc.Credentials.Type == "" { + return fmt.Errorf("routes[%d]: credentials.type must be non-empty", i) + } + p, ok := providers[rc.Credentials.Type] + if !ok { + return fmt.Errorf("routes[%d]: unknown credentials provider type %q", i, rc.Credentials.Type) + } + m.Handle(route, p) + return nil +} + +// applyFallback validates and installs the catch-all provider. +// Forward is rejected — see [MuxFallbackConfig] for rationale. +func applyFallback(m *Mux, fc *MuxFallbackConfig, providers map[string]sdk.CredentialProvider) error { + if fc.Forward != "" { + return fmt.Errorf("fallback: forward is not supported; fallback must use credentials") + } + if fc.Credentials == nil { + return fmt.Errorf("fallback: credentials must be set") + } + if fc.Credentials.Type == "" { + return fmt.Errorf("fallback: credentials.type must be non-empty") + } + p, ok := providers[fc.Credentials.Type] + if !ok { + return fmt.Errorf("fallback: unknown credentials provider type %q", fc.Credentials.Type) + } + m.Default(p) + return nil +} + +// routeFromMatch projects a MatchConfig into a Route. +func routeFromMatch(mc MatchConfig) Route { + return Route{ + VendorID: mc.VendorID, + MarketplaceID: mc.MarketplaceID, + ProductID: mc.ProductID, + EnvironmentID: mc.EnvironmentID, + TargetURL: mc.TargetURL, + Data: mc.Data, + } +} diff --git a/plugins/contrib/mux_config_test.go b/plugins/contrib/mux_config_test.go new file mode 100644 index 0000000..b9a4b30 --- /dev/null +++ b/plugins/contrib/mux_config_test.go @@ -0,0 +1,454 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package contrib + +import ( + "context" + "net/http" + "strings" + "testing" + + "gopkg.in/yaml.v3" + + "github.com/cloudblue/chaperone/sdk" +) + +// namedStubProvider is a minimal sdk.CredentialProvider used to assert which +// provider the mux registered for a given route. +type namedStubProvider struct { + name string +} + +func (p *namedStubProvider) GetCredentials(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.Credential, error) { + return &sdk.Credential{Headers: map[string]string{"X-Provider": p.name}}, nil +} + +// --- spec-mandated tests --- + +func TestLoadMuxFromConfig_ForwardAndCredentials_Exclusive(t *testing.T) { + _, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + { + Match: MatchConfig{VendorID: "x"}, + Forward: "company-b", + Credentials: &CredentialsConfig{Type: "oauth2"}, + }, + }, + }, nil) + if err == nil { + t.Fatal("expected mutual-exclusion error, got nil") + } + if !strings.Contains(err.Error(), "routes[0]") { + t.Errorf("error message %q must reference routes[0]", err.Error()) + } +} + +func TestLoadMuxFromConfig_ForwardOnly_RegistersForwardAction(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}, Forward: "company-b"}, + }, + }, nil) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + fa, ok := m.entries[0].action.(ForwardAction) + if !ok { + t.Errorf("action = %T, want ForwardAction", m.entries[0].action) + } + if fa.Target != "company-b" { + t.Errorf("Target = %q, want %q", fa.Target, "company-b") + } +} + +// --- table-driven validation error matrix --- + +func TestLoadMuxFromConfig_ValidationErrors(t *testing.T) { + providers := map[string]sdk.CredentialProvider{ + "oauth2": &namedStubProvider{name: "oauth2"}, + } + + tests := []struct { + name string + cfg MuxConfig + wantErrSubs []string // all substrings must appear in the error message + }{ + { + name: "neither forward nor credentials", + cfg: MuxConfig{Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}}, + }}, + wantErrSubs: []string{"routes[0]", "forward", "credentials"}, + }, + { + name: "both forward and credentials", + cfg: MuxConfig{Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}, Forward: "t", Credentials: &CredentialsConfig{Type: "oauth2"}}, + }}, + wantErrSubs: []string{"routes[0]", "forward", "credentials"}, + }, + { + name: "unknown credentials provider type", + cfg: MuxConfig{Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}, Credentials: &CredentialsConfig{Type: "saml"}}, + }}, + wantErrSubs: []string{"routes[0]", `"saml"`}, + }, + { + name: "empty credentials.type", + cfg: MuxConfig{Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}, Credentials: &CredentialsConfig{Type: ""}}, + }}, + wantErrSubs: []string{"routes[0]", "credentials.type"}, + }, + { + name: "first invalid route reported when multiple invalid", + cfg: MuxConfig{Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "x"}}, // neither -> error at index 0 + {Match: MatchConfig{VendorID: "y"}, Forward: "t", Credentials: &CredentialsConfig{Type: "oauth2"}}, + }}, + wantErrSubs: []string{"routes[0]"}, + }, + { + name: "fallback with forward is disallowed", + cfg: MuxConfig{ + Fallback: &MuxFallbackConfig{Forward: "t"}, + }, + wantErrSubs: []string{"fallback", "forward"}, + }, + { + name: "fallback with both forward and credentials", + cfg: MuxConfig{ + Fallback: &MuxFallbackConfig{ + Forward: "t", + Credentials: &CredentialsConfig{Type: "oauth2"}, + }, + }, + wantErrSubs: []string{"fallback"}, + }, + { + name: "fallback credentials unknown type", + cfg: MuxConfig{ + Fallback: &MuxFallbackConfig{Credentials: &CredentialsConfig{Type: "saml"}}, + }, + wantErrSubs: []string{"fallback", `"saml"`}, + }, + { + name: "fallback credentials empty type", + cfg: MuxConfig{ + Fallback: &MuxFallbackConfig{Credentials: &CredentialsConfig{Type: ""}}, + }, + wantErrSubs: []string{"fallback", "credentials.type"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := LoadMuxFromConfig(tt.cfg, providers) + if err == nil { + t.Fatalf("expected error, got nil") + } + msg := err.Error() + for _, sub := range tt.wantErrSubs { + if !strings.Contains(msg, sub) { + t.Errorf("error %q missing substring %q", msg, sub) + } + } + }) + } +} + +// --- table-driven successful construction matrix --- + +func TestLoadMuxFromConfig_SuccessfulConstruction(t *testing.T) { + oauthProv := &namedStubProvider{name: "oauth2"} + bearerProv := &namedStubProvider{name: "bearer"} + fallbackProv := &namedStubProvider{name: "fallback"} + providers := map[string]sdk.CredentialProvider{ + "oauth2": oauthProv, + "bearer": bearerProv, + "fallback": fallbackProv, + } + + t.Run("single forward route", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "v"}, Forward: "company-b"}, + }, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + fa, ok := m.entries[0].action.(ForwardAction) + if !ok { + t.Fatalf("action = %T, want ForwardAction", m.entries[0].action) + } + if fa.Target != "company-b" { + t.Errorf("Target = %q, want %q", fa.Target, "company-b") + } + }) + + t.Run("single credential route", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "v"}, Credentials: &CredentialsConfig{Type: "oauth2"}}, + }, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + ca, ok := m.entries[0].action.(CredentialAction) + if !ok { + t.Fatalf("action = %T, want CredentialAction", m.entries[0].action) + } + if ca.Provider != oauthProv { + t.Errorf("Provider = %v, want oauthProv", ca.Provider) + } + }) + + t.Run("mixed forward and credential routes preserve order", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "v1"}, Forward: "f1"}, + {Match: MatchConfig{VendorID: "v2"}, Credentials: &CredentialsConfig{Type: "oauth2"}}, + {Match: MatchConfig{VendorID: "v3"}, Forward: "f3"}, + {Match: MatchConfig{VendorID: "v4"}, Credentials: &CredentialsConfig{Type: "bearer"}}, + }, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 4 { + t.Fatalf("entries = %d, want 4", len(m.entries)) + } + // Check order preserved by index and action type + if fa, ok := m.entries[0].action.(ForwardAction); !ok || fa.Target != "f1" { + t.Errorf("entries[0] = %v, want ForwardAction{Target:f1}", m.entries[0].action) + } + if ca, ok := m.entries[1].action.(CredentialAction); !ok || ca.Provider != oauthProv { + t.Errorf("entries[1] = %v, want CredentialAction{oauth2}", m.entries[1].action) + } + if fa, ok := m.entries[2].action.(ForwardAction); !ok || fa.Target != "f3" { + t.Errorf("entries[2] = %v, want ForwardAction{Target:f3}", m.entries[2].action) + } + if ca, ok := m.entries[3].action.(CredentialAction); !ok || ca.Provider != bearerProv { + t.Errorf("entries[3] = %v, want CredentialAction{bearer}", m.entries[3].action) + } + // Indexes should be 0..3 + for i, e := range m.entries { + if e.index != i { + t.Errorf("entries[%d].index = %d, want %d", i, e.index, i) + } + } + }) + + t.Run("full match dimensions propagated", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + { + Match: MatchConfig{ + VendorID: "vendor", + MarketplaceID: "mp", + ProductID: "prod", + EnvironmentID: "env", + TargetURL: "*.example.com/**", + Data: map[string]string{"ResellerId": "r-*", "Region": "us"}, + }, + Forward: "target", + }, + }, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + r := m.entries[0].route + if r.VendorID != "vendor" || r.MarketplaceID != "mp" || r.ProductID != "prod" || + r.EnvironmentID != "env" || r.TargetURL != "*.example.com/**" { + t.Errorf("route fields not propagated: %+v", r) + } + if got, want := r.Data["ResellerId"], "r-*"; got != want { + t.Errorf("Data[ResellerId] = %q, want %q", got, want) + } + if got, want := r.Data["Region"], "us"; got != want { + t.Errorf("Data[Region] = %q, want %q", got, want) + } + }) + + t.Run("fallback with credentials sets default", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Fallback: &MuxFallbackConfig{Credentials: &CredentialsConfig{Type: "fallback"}}, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if m.fallback != fallbackProv { + t.Errorf("fallback = %v, want fallbackProv", m.fallback) + } + if len(m.entries) != 0 { + t.Errorf("entries = %d, want 0", len(m.entries)) + } + }) + + t.Run("empty routes plus fallback", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: nil, + Fallback: &MuxFallbackConfig{Credentials: &CredentialsConfig{Type: "fallback"}}, + }, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 0 { + t.Errorf("entries = %d, want 0", len(m.entries)) + } + if m.fallback != fallbackProv { + t.Errorf("fallback = %v, want fallbackProv", m.fallback) + } + }) + + t.Run("empty everything yields valid empty mux", func(t *testing.T) { + m, err := LoadMuxFromConfig(MuxConfig{}, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 0 { + t.Errorf("entries = %d, want 0", len(m.entries)) + } + if m.fallback != nil { + t.Errorf("fallback = %v, want nil", m.fallback) + } + }) + + t.Run("nil providers map with only forward routes is fine", func(t *testing.T) { + // Forwards don't need providers — nil map should be acceptable. + m, err := LoadMuxFromConfig(MuxConfig{ + Routes: []MuxRouteConfig{ + {Match: MatchConfig{VendorID: "v"}, Forward: "f"}, + }, + }, nil) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + }) +} + +// --- YAML roundtrip end-to-end test --- + +func TestLoadMuxFromConfig_YAMLRoundtrip(t *testing.T) { + const doc = ` +routes: + - match: + vendor_id: "acme" + product_id: "WIDGET" + credentials: + type: oauth2 + - match: + vendor_id: "globex" + forward: company-b + - match: + vendor_id: "migrated" + data: + ResellerId: "legacy-*" + forward: company-b +fallback: + credentials: + type: fallback +` + var cfg MuxConfig + if err := yaml.Unmarshal([]byte(doc), &cfg); err != nil { + t.Fatalf("yaml.Unmarshal: %v", err) + } + + oauthProv := &namedStubProvider{name: "oauth2"} + fallbackProv := &namedStubProvider{name: "fallback"} + providers := map[string]sdk.CredentialProvider{ + "oauth2": oauthProv, + "fallback": fallbackProv, + } + + m, err := LoadMuxFromConfig(cfg, providers) + if err != nil { + t.Fatalf("LoadMuxFromConfig: %v", err) + } + + if len(m.entries) != 3 { + t.Fatalf("entries = %d, want 3", len(m.entries)) + } + + // Entry 0: acme + WIDGET → credentials oauth2 + { + e := m.entries[0] + if e.route.VendorID != "acme" || e.route.ProductID != "WIDGET" { + t.Errorf("entry[0] route = %+v", e.route) + } + ca, ok := e.action.(CredentialAction) + if !ok || ca.Provider != oauthProv { + t.Errorf("entry[0] action = %T %v, want CredentialAction{oauth2}", e.action, e.action) + } + } + // Entry 1: globex forward + { + e := m.entries[1] + if e.route.VendorID != "globex" || e.route.TargetURL != "" { + t.Errorf("entry[1] route = %+v", e.route) + } + fa, ok := e.action.(ForwardAction) + if !ok || fa.Target != "company-b" { + t.Errorf("entry[1] action = %T %v, want ForwardAction{company-b}", e.action, e.action) + } + } + // Entry 2: migrated forward with Data + { + e := m.entries[2] + if e.route.VendorID != "migrated" { + t.Errorf("entry[2] route.VendorID = %q, want %q", e.route.VendorID, "migrated") + } + if got, want := e.route.Data["ResellerId"], "legacy-*"; got != want { + t.Errorf("entry[2] Data[ResellerId] = %q, want %q", got, want) + } + fa, ok := e.action.(ForwardAction) + if !ok || fa.Target != "company-b" { + t.Errorf("entry[2] action = %T %v, want ForwardAction{company-b}", e.action, e.action) + } + } + if m.fallback != fallbackProv { + t.Errorf("fallback = %v, want fallbackProv", m.fallback) + } + + // Behavioral check: send a tx that matches entry 1 and verify RouteRequest + // returns a ForwardAction to "company-b". + tx := sdk.TransactionContext{ + VendorID: "globex", + TargetURL: "https://api.globex.com/v1/things", + } + ra, err := m.RouteRequest(context.Background(), tx, nil) + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if ra == nil || ra.ForwardTo != "company-b" { + t.Errorf("RouteRequest = %+v, want ForwardTo=company-b", ra) + } + + // And a tx that matches the credential route (entry 0) — RouteRequest + // must return nil so the credential path runs. + txCred := sdk.TransactionContext{VendorID: "acme", ProductID: "WIDGET"} + ra2, err := m.RouteRequest(context.Background(), txCred, nil) + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if ra2 != nil { + t.Errorf("RouteRequest = %+v, want nil for credential match", ra2) + } +} diff --git a/plugins/contrib/mux_test.go b/plugins/contrib/mux_test.go index 28301fa..4e1f93a 100644 --- a/plugins/contrib/mux_test.go +++ b/plugins/contrib/mux_test.go @@ -9,6 +9,8 @@ import ( "io" "log/slog" "net/http" + "net/http/httptest" + "strings" "sync" "testing" "time" @@ -604,6 +606,172 @@ func TestMux_Compliance(t *testing.T) { compliance.VerifyContract(t, mux) } +// --- RouteRequest tests (RequestRouter implementation) --- + +func TestMux_RouteRequest_ReturnsForward_ForForwardAction(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "microsoft-*"}, "company-b") + + action, err := m.RouteRequest(context.Background(), + sdk.TransactionContext{VendorID: "microsoft-azure"}, + httptest.NewRequest("GET", "https://example.com/x", nil)) + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if action == nil || action.ForwardTo != "company-b" { + t.Errorf("action = %#v, want ForwardTo=company-b", action) + } +} + +func TestMux_RouteRequest_ReturnsNil_ForCredentialAction(t *testing.T) { + m := NewMux() + m.Handle(Route{VendorID: "microsoft-*"}, &namedProvider{name: "test"}) + + action, err := m.RouteRequest(context.Background(), + sdk.TransactionContext{VendorID: "microsoft-azure"}, + httptest.NewRequest("GET", "https://example.com/x", nil)) + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if action != nil { + t.Errorf("action = %#v, want nil for CredentialAction match", action) + } +} + +func TestMux_RouteRequest_ReturnsNil_NoMatch(t *testing.T) { + m := NewMux() + m.Default(&namedProvider{name: "fallback"}) + + action, err := m.RouteRequest(context.Background(), + sdk.TransactionContext{VendorID: "globex"}, + httptest.NewRequest("GET", "https://example.com/x", nil)) + if err != nil { + t.Fatalf("RouteRequest: %v", err) + } + if action != nil { + t.Errorf("action = %#v, want nil when no forward route matches", action) + } +} + +// --- RouteRequest mandatory test matrix --- + +func TestMux_RouteRequest_Matrix(t *testing.T) { + tests := []struct { + name string + setup func(*Mux) + tx sdk.TransactionContext + wantAction *sdk.RouteAction + wantErr bool + description string + }{ + { + name: "ForwardAction_MatchedWithData_ReturnsCorrectTarget", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme", Data: map[string]string{"region": "us-east"}}, "acme-us-east") + }, + tx: sdk.TransactionContext{VendorID: "acme", Data: map[string]any{"region": "us-east"}}, + wantAction: &sdk.RouteAction{ForwardTo: "acme-us-east"}, + description: "ForwardAction matched at specific Data dimension returns correct ForwardTo", + }, + { + name: "HigherSpecificityForwardBeatsLowerSpecificityCredential", + setup: func(m *Mux) { + m.Handle(Route{VendorID: "acme"}, &namedProvider{name: "general"}) + m.HandleForward(Route{VendorID: "acme", EnvironmentID: "prod"}, "acme-prod") + }, + tx: sdk.TransactionContext{VendorID: "acme", EnvironmentID: "prod"}, + wantAction: &sdk.RouteAction{ForwardTo: "acme-prod"}, + description: "More specific ForwardAction wins over less specific CredentialAction", + }, + { + name: "HigherSpecificityCredentialBeatsLowerSpecificityForward", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme"}, "general-acme") + m.Handle(Route{VendorID: "acme", EnvironmentID: "prod"}, &namedProvider{name: "specific"}) + }, + tx: sdk.TransactionContext{VendorID: "acme", EnvironmentID: "prod"}, + wantAction: nil, + description: "More specific CredentialAction wins over less specific ForwardAction (returns nil)", + }, + { + name: "TwoForwardActionsAtSameSpecificity_FirstRegisteredWins", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "microsoft-*"}, "target-first") + m.HandleForward(Route{VendorID: "microsoft-*"}, "target-second") + }, + tx: sdk.TransactionContext{VendorID: "microsoft-azure"}, + wantAction: &sdk.RouteAction{ForwardTo: "target-first"}, + description: "Two ForwardActions matching at same specificity returns first registered", + }, + { + name: "NilHTTPRequest_DoesNotPanic", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme"}, "target-acme") + }, + tx: sdk.TransactionContext{VendorID: "acme"}, + wantAction: &sdk.RouteAction{ForwardTo: "target-acme"}, + description: "nil http.Request argument does not panic", + }, + { + name: "NilTXData_WithDataDimensionRoute_NoMatch", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme", Data: map[string]string{"region": "us"}}, "target-us") + }, + tx: sdk.TransactionContext{VendorID: "acme", Data: nil}, + wantAction: nil, + description: "nil tx.Data with a Data-dimension route does not match", + }, + { + name: "PreCancelledContext_StillReturnsAction", + setup: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme"}, "target-acme") + }, + tx: sdk.TransactionContext{VendorID: "acme"}, + wantAction: &sdk.RouteAction{ForwardTo: "target-acme"}, + description: "Pre-cancelled context still returns the same action", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := NewMux() + tt.setup(m) + + ctx := context.Background() + // For the "pre-cancelled context" case, cancel it before calling. + if tt.name == "PreCancelledContext_StillReturnsAction" { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(context.Background()) + cancel() + } + + action, err := m.RouteRequest(ctx, tt.tx, nil) // nil req is acceptable + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantAction == nil { + if action != nil { + t.Errorf("action = %#v, want nil (%s)", action, tt.description) + } + } else { + if action == nil { + t.Errorf("action = nil, want %#v (%s)", tt.wantAction, tt.description) + } else if action.ForwardTo != tt.wantAction.ForwardTo { + t.Errorf("action.ForwardTo = %q, want %q (%s)", action.ForwardTo, tt.wantAction.ForwardTo, tt.description) + } + } + }) + } +} + +// --- RouteRequest RequestRouter compliance test --- + +func TestMux_RouteRequest_Compliance(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "test"}, "target-test") + compliance.VerifyRouter(t, m) +} + func TestNewMux_NilLogger_LazyResolution(t *testing.T) { m := NewMux() @@ -624,3 +792,390 @@ func TestNewMux_WithLogger_UsesExplicitLogger(t *testing.T) { t.Error("log() should return the explicitly provided logger") } } + +// --- Action registration tests --- + +// countOverlapWarnings returns how many overlap warnings the capture saw. +func countOverlapWarnings(entries []logEntry) int { + const want = "routes registered with equal specificity may overlap, first registered wins on tie" + n := 0 + for _, e := range entries { + if e.level == slog.LevelWarn && e.message == want { + n++ + } + } + return n +} + +func TestMux_HandleForward_RegistersForwardAction(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "x"}, "company-b") + + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + fa, ok := m.entries[0].action.(ForwardAction) + if !ok { + t.Fatalf("action = %T, want ForwardAction", m.entries[0].action) + } + if fa.Target != "company-b" { + t.Errorf("Target = %q, want %q", fa.Target, "company-b") + } + if m.entries[0].index != 0 { + t.Errorf("index = %d, want 0", m.entries[0].index) + } +} + +func TestMux_HandleForward_MultipleRegistrations_PreserveOrder(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "a"}, "target-a") + m.HandleForward(Route{VendorID: "b"}, "target-b") + m.HandleForward(Route{VendorID: "c"}, "target-c") + + if len(m.entries) != 3 { + t.Fatalf("entries = %d, want 3", len(m.entries)) + } + wantTargets := []string{"target-a", "target-b", "target-c"} + for i, want := range wantTargets { + fa, ok := m.entries[i].action.(ForwardAction) + if !ok { + t.Fatalf("entry[%d].action = %T, want ForwardAction", i, m.entries[i].action) + } + if fa.Target != want { + t.Errorf("entry[%d].Target = %q, want %q", i, fa.Target, want) + } + if m.entries[i].index != i { + t.Errorf("entry[%d].index = %d, want %d", i, m.entries[i].index, i) + } + } +} + +func TestMux_Handle_RegistersCredentialAction(t *testing.T) { + m := NewMux() + provider := &namedProvider{name: "acme"} + m.Handle(Route{VendorID: "acme"}, provider) + + if len(m.entries) != 1 { + t.Fatalf("entries = %d, want 1", len(m.entries)) + } + ca, ok := m.entries[0].action.(CredentialAction) + if !ok { + t.Fatalf("action = %T, want CredentialAction", m.entries[0].action) + } + if ca.Provider != provider { + t.Errorf("Provider = %v, want %v", ca.Provider, provider) + } +} + +func TestMux_MixedHandleAndHandleForward_AllRegisteredWithCorrectTypes(t *testing.T) { + m := NewMux() + prov := &namedProvider{name: "acme"} + m.Handle(Route{VendorID: "acme"}, prov) + m.HandleForward(Route{VendorID: "globex"}, "target-globex") + m.Handle(Route{VendorID: "initech"}, &namedProvider{name: "initech"}) + m.HandleForward(Route{VendorID: "umbrella"}, "target-umbrella") + + if len(m.entries) != 4 { + t.Fatalf("entries = %d, want 4", len(m.entries)) + } + + cases := []struct { + idx int + wantType string + }{ + {0, "credential"}, + {1, "forward"}, + {2, "credential"}, + {3, "forward"}, + } + for _, tc := range cases { + switch tc.wantType { + case "credential": + if _, ok := m.entries[tc.idx].action.(CredentialAction); !ok { + t.Errorf("entry[%d].action = %T, want CredentialAction", tc.idx, m.entries[tc.idx].action) + } + case "forward": + if _, ok := m.entries[tc.idx].action.(ForwardAction); !ok { + t.Errorf("entry[%d].action = %T, want ForwardAction", tc.idx, m.entries[tc.idx].action) + } + } + } +} + +// TestMux_HandleForward_EmptyTarget_Panics verifies that registering a route +// with an empty target name is treated as a programmer error. An empty target +// cannot reference any forward_target, so silently registering a dead route +// would only delay the failure to dispatch time. Config-driven users hit the +// loader's own non-empty check before ever reaching HandleForward. +func TestMux_HandleForward_EmptyTarget_Panics(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic on empty target name") + } + msg, ok := r.(string) + if !ok || !strings.Contains(msg, "empty target name") { + t.Errorf("panic message = %v, want substring 'empty target name'", r) + } + }() + NewMux().HandleForward(Route{VendorID: "x"}, "") +} + +// --- Overlap warning tests across action types --- + +func TestMux_OverlapWarning_AcrossActionTypes(t *testing.T) { + tests := []struct { + name string + register func(m *Mux) + wantWarn int + }{ + { + name: "two HandleForward with overlapping routes at same specificity", + register: func(m *Mux) { + m.HandleForward(Route{VendorID: "microsoft-*"}, "target-1") + m.HandleForward(Route{VendorID: "microsoft-azure"}, "target-2") + }, + wantWarn: 1, + }, + { + name: "Handle then HandleForward with overlapping routes at same specificity", + register: func(m *Mux) { + m.Handle(Route{VendorID: "microsoft-*"}, &namedProvider{name: "first"}) + m.HandleForward(Route{VendorID: "microsoft-azure"}, "target-2") + }, + wantWarn: 1, + }, + { + name: "HandleForward then Handle with overlapping routes at same specificity", + register: func(m *Mux) { + m.HandleForward(Route{VendorID: "microsoft-*"}, "target-1") + m.Handle(Route{VendorID: "microsoft-azure"}, &namedProvider{name: "second"}) + }, + wantWarn: 1, + }, + { + name: "two non-overlapping HandleForward (disjoint literals)", + register: func(m *Mux) { + m.HandleForward(Route{VendorID: "acme"}, "target-acme") + m.HandleForward(Route{VendorID: "globex"}, "target-globex") + }, + wantWarn: 0, + }, + { + name: "non-overlapping mixed actions (disjoint literals)", + register: func(m *Mux) { + m.Handle(Route{VendorID: "acme"}, &namedProvider{name: "acme"}) + m.HandleForward(Route{VendorID: "globex"}, "target-globex") + }, + wantWarn: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + capture := &logCapture{} + m := NewMux(WithLogger(slog.New(capture))) + tt.register(m) + + if got := countOverlapWarnings(capture.getEntries()); got != tt.wantWarn { + t.Errorf("overlap warnings = %d, want %d", got, tt.wantWarn) + } + }) + } +} + +// --- GetCredentials fall-through tests for the new action model --- + +func TestMux_GetCredentials_CredentialAction_DelegatesToProvider(t *testing.T) { + m := NewMux() + m.Handle(Route{VendorID: "acme"}, &namedProvider{name: "acme"}) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme"} + cred, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := cred.Headers["Authorization"]; got != "Bearer acme" { + t.Errorf("Authorization = %q, want %q", got, "Bearer acme") + } +} + +func TestMux_GetCredentials_ForwardAction_ReturnsErrUnexpectedForwardAction(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "company-b") + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme"} + cred, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if !errors.Is(err, ErrUnexpectedForwardAction) { + t.Errorf("error = %v, want ErrUnexpectedForwardAction", err) + } + if cred != nil { + t.Errorf("cred = %v, want nil", cred) + } +} + +func TestMux_GetCredentials_NilProviderInCredentialAction_ReturnsErrNilCredentialProvider(t *testing.T) { + m := NewMux() + // Direct registration via Handle with nil provider. + m.Handle(Route{VendorID: "acme"}, nil) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme"} + cred, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if !errors.Is(err, ErrNilCredentialProvider) { + t.Errorf("error = %v, want ErrNilCredentialProvider", err) + } + if cred != nil { + t.Errorf("cred = %v, want nil", cred) + } +} + +func TestMux_GetCredentials_ForwardActionMatched_DoesNotConsultFallback(t *testing.T) { + // A ForwardAction match must NOT silently fall through to the default + // provider — that would be a security regression. + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "company-b") + m.Default(&namedProvider{name: "fallback"}) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme"} + _, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if !errors.Is(err, ErrUnexpectedForwardAction) { + t.Errorf("error = %v, want ErrUnexpectedForwardAction (must not fall through to default)", err) + } +} + +func TestMux_GetCredentials_NoMatch_FallbackStillWorks(t *testing.T) { + // Sanity: confirm fallback behavior still works after the refactor. + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "company-b") + m.Default(&namedProvider{name: "fallback"}) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "unknown-vendor"} + cred, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := cred.Headers["Authorization"]; got != "Bearer fallback" { + t.Errorf("Authorization = %q, want %q", got, "Bearer fallback") + } +} + +// --- Specificity interaction with mixed action types --- + +func TestMux_GetCredentials_MoreSpecificForwardBeatsLessSpecificCredential(t *testing.T) { + // Selection is by specificity, independent of action type. A more + // specific ForwardAction wins over a less specific CredentialAction, + // and the mux returns ErrUnexpectedForwardAction. + m := NewMux() + m.Handle(Route{VendorID: "acme"}, &namedProvider{name: "general"}) + m.HandleForward( + Route{VendorID: "acme", EnvironmentID: "prod"}, + "target-acme-prod", + ) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme", EnvironmentID: "prod"} + _, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if !errors.Is(err, ErrUnexpectedForwardAction) { + t.Errorf("error = %v, want ErrUnexpectedForwardAction", err) + } +} + +func TestMux_GetCredentials_MoreSpecificCredentialBeatsLessSpecificForward(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "target-general") + m.Handle( + Route{VendorID: "acme", EnvironmentID: "prod"}, + &namedProvider{name: "specific"}, + ) + + ctx := context.Background() + tx := sdk.TransactionContext{VendorID: "acme", EnvironmentID: "prod"} + cred, err := m.GetCredentials(ctx, tx, makeTestReq(ctx)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := cred.Headers["Authorization"]; got != "Bearer specific" { + t.Errorf("Authorization = %q, want %q", got, "Bearer specific") + } +} + +// --- ForwardReferences tests --- + +func TestMux_ForwardReferences_Empty(t *testing.T) { + m := NewMux() + + refs := m.ForwardReferences() + if len(refs) != 0 { + t.Errorf("refs = %v, want empty", refs) + } +} + +func TestMux_ForwardReferences_SingleForwardRoute(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "target-acme") + + refs := m.ForwardReferences() + if len(refs) != 1 || refs[0] != "target-acme" { + t.Errorf("refs = %v, want [target-acme]", refs) + } +} + +func TestMux_ForwardReferences_MultipleForwardRoutes(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "acme"}, "target-acme") + m.HandleForward(Route{VendorID: "globex"}, "target-globex") + m.HandleForward(Route{VendorID: "contoso"}, "target-contoso") + + refs := m.ForwardReferences() + if len(refs) != 3 { + t.Errorf("refs length = %d, want 3", len(refs)) + } + + // Order should match registration order + expected := []string{"target-acme", "target-globex", "target-contoso"} + for i, exp := range expected { + if i >= len(refs) || refs[i] != exp { + t.Errorf("refs[%d] = %q, want %q", i, refs[i], exp) + } + } +} + +func TestMux_ForwardReferences_MixedActions_OnlyIncludesForwards(t *testing.T) { + m := NewMux() + m.Handle(Route{VendorID: "cred-vendor"}, &namedProvider{name: "cred-provider"}) + m.HandleForward(Route{VendorID: "forward-vendor"}, "target-forward") + m.Handle(Route{VendorID: "another-cred"}, &namedProvider{name: "another-provider"}) + + refs := m.ForwardReferences() + if len(refs) != 1 || refs[0] != "target-forward" { + t.Errorf("refs = %v, want [target-forward]", refs) + } +} + +func TestMux_ForwardReferences_DuplicateTargets(t *testing.T) { + m := NewMux() + m.HandleForward(Route{VendorID: "vendor-a"}, "same-target") + m.HandleForward(Route{VendorID: "vendor-b"}, "same-target") + m.HandleForward(Route{VendorID: "vendor-c"}, "different-target") + + refs := m.ForwardReferences() + if len(refs) != 3 { + t.Errorf("refs length = %d, want 3 (including duplicates)", len(refs)) + } + + // Verify that duplicates are preserved (not deduplicated by ForwardReferences itself) + count := 0 + for _, ref := range refs { + if ref == "same-target" { + count++ + } + } + if count != 2 { + t.Errorf("same-target appears %d times, want 2", count) + } +} diff --git a/plugins/contrib/route.go b/plugins/contrib/route.go index 0c94461..8e3a198 100644 --- a/plugins/contrib/route.go +++ b/plugins/contrib/route.go @@ -20,6 +20,7 @@ import ( // Route{EnvironmentID: "prod", VendorID: "microsoft-*"} // 2-field route, higher specificity // Route{TargetURL: "*.graph.microsoft.com/**"} // matches any Graph API path // Route{MarketplaceID: "MP-*", ProductID: "MICROSOFT_SAAS"} // matches by marketplace and product +// Route{Data: map[string]string{"ResellerId": "migrated-*"}} // matches by tx.Data entry type Route struct { // VendorID matches against TransactionContext.VendorID. // Supports glob patterns (e.g., "microsoft-*"). @@ -41,6 +42,20 @@ type Route struct { // EnvironmentID matches against TransactionContext.EnvironmentID. // Supports glob patterns (e.g., "prod-*"). EnvironmentID string + + // Data matches against TransactionContext.Data entries. Each entry is + // : ; the route matches only if every entry's + // pattern matches the corresponding tx.DataString(key) value. + // + // Behavior: + // - Missing keys do not match. + // - Keys present but with a non-string value or an empty string + // do not match (sdk.TransactionContext.DataString returns an + // error for those cases, which we treat as a non-match for + // routing safety — invalid data must never silently route to + // a provider). + // - Each entry contributes 1 to Specificity(). + Data map[string]string } // Specificity returns the number of non-empty fields in the route. @@ -63,6 +78,7 @@ func (r Route) Specificity() int { if r.EnvironmentID != "" { n++ } + n += len(r.Data) return n } @@ -84,6 +100,23 @@ func (r Route) Matches(tx sdk.TransactionContext) bool { if r.TargetURL != "" && !matchTargetURL(r.TargetURL, tx.TargetURL) { return false } + return matchesData(r.Data, tx) +} + +// matchesData reports whether every entry in data matches the +// corresponding tx.Data[key] string value. Missing keys, wrong-type values, +// and empty strings are all treated as non-matches: a route MUST NOT silently +// dispatch when its declared data dimension is unusable. +func matchesData(data map[string]string, tx sdk.TransactionContext) bool { + for key, pattern := range data { + v, ok, err := tx.DataString(key) + if !ok || err != nil { + return false + } + if !GlobMatch(pattern, v, '/') { + return false + } + } return true } diff --git a/plugins/contrib/route_test.go b/plugins/contrib/route_test.go index d9801a2..b49cb3a 100644 --- a/plugins/contrib/route_test.go +++ b/plugins/contrib/route_test.go @@ -4,6 +4,7 @@ package contrib import ( + "strings" "testing" "github.com/cloudblue/chaperone/sdk" @@ -388,6 +389,381 @@ func TestRoute_Matches_MultipleFields(t *testing.T) { } } +// --- Spec tests (verbatim from Task 9 plan) --- + +func TestRoute_Matches_DataField_ExactKey(t *testing.T) { + r := Route{Data: map[string]string{"ResellerId": "migrated-*"}} + tx := sdk.TransactionContext{Data: map[string]any{"ResellerId": "migrated-42"}} + if !r.Matches(tx) { + t.Error("expected match for ResellerId=migrated-42 against migrated-*") + } +} + +func TestRoute_Matches_DataField_Mismatch(t *testing.T) { + r := Route{Data: map[string]string{"ResellerId": "migrated-*"}} + tx := sdk.TransactionContext{Data: map[string]any{"ResellerId": "legacy-42"}} + if r.Matches(tx) { + t.Error("expected no match for legacy-42") + } +} + +func TestRoute_Matches_DataField_MissingKey_DoesNotMatch(t *testing.T) { + r := Route{Data: map[string]string{"ResellerId": "migrated-*"}} + tx := sdk.TransactionContext{Data: map[string]any{}} + if r.Matches(tx) { + t.Error("expected no match when key is absent") + } +} + +func TestRoute_Specificity_IncludesDataEntries(t *testing.T) { + r := Route{VendorID: "microsoft-*", Data: map[string]string{"ResellerId": "x", "TenantId": "y"}} + if got, want := r.Specificity(), 3; got != want { + t.Errorf("Specificity = %d, want %d", got, want) + } +} + +// --- Matches table for the Data dimension --- + +func TestRoute_Matches_Data_Table(t *testing.T) { + tests := []struct { + name string + route Route + tx sdk.TransactionContext + want bool + }{ + { + name: "single Data key with literal exact value matches", + route: Route{Data: map[string]string{"ResellerId": "migrated-42"}}, + tx: sdk.TransactionContext{Data: map[string]any{"ResellerId": "migrated-42"}}, + want: true, + }, + { + name: "single Data key with glob pattern matches", + route: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + tx: sdk.TransactionContext{Data: map[string]any{"ResellerId": "migrated-42"}}, + want: true, + }, + { + name: "single Data key with wrong value does not match", + route: Route{Data: map[string]string{"ResellerId": "migrated-42"}}, + tx: sdk.TransactionContext{Data: map[string]any{"ResellerId": "legacy-42"}}, + want: false, + }, + { + name: "Data key empty string in tx is invalid and does not match (DataString returns err)", + route: Route{Data: map[string]string{"ResellerId": "*"}}, + tx: sdk.TransactionContext{Data: map[string]any{"ResellerId": ""}}, + want: false, + }, + { + name: "multiple Data keys all match", + route: Route{Data: map[string]string{"ResellerId": "migrated-*", "TenantId": "abc-*"}}, + tx: sdk.TransactionContext{Data: map[string]any{ + "ResellerId": "migrated-42", + "TenantId": "abc-001", + }}, + want: true, + }, + { + name: "multiple Data keys one mismatch", + route: Route{Data: map[string]string{"ResellerId": "migrated-*", "TenantId": "abc-*"}}, + tx: sdk.TransactionContext{Data: map[string]any{ + "ResellerId": "migrated-42", + "TenantId": "xyz-001", + }}, + want: false, + }, + { + name: "multiple Data keys one missing", + route: Route{Data: map[string]string{"ResellerId": "migrated-*", "TenantId": "abc-*"}}, + tx: sdk.TransactionContext{Data: map[string]any{ + "ResellerId": "migrated-42", + }}, + want: false, + }, + { + name: "Data combined with top-level VendorID, both match", + route: Route{ + VendorID: "microsoft-*", + Data: map[string]string{"ResellerId": "migrated-*"}, + }, + tx: sdk.TransactionContext{ + VendorID: "microsoft-azure", + Data: map[string]any{"ResellerId": "migrated-42"}, + }, + want: true, + }, + { + name: "Data matches but VendorID does not", + route: Route{ + VendorID: "microsoft-*", + Data: map[string]string{"ResellerId": "migrated-*"}, + }, + tx: sdk.TransactionContext{ + VendorID: "google-cloud", + Data: map[string]any{"ResellerId": "migrated-42"}, + }, + want: false, + }, + { + name: "VendorID matches but Data does not", + route: Route{ + VendorID: "microsoft-*", + Data: map[string]string{"ResellerId": "migrated-*"}, + }, + tx: sdk.TransactionContext{ + VendorID: "microsoft-azure", + Data: map[string]any{"ResellerId": "legacy-1"}, + }, + want: false, + }, + { + name: "Data with recursive ** glob matches multi-segment value", + route: Route{Data: map[string]string{"Scope": "tenant/**"}}, + tx: sdk.TransactionContext{Data: map[string]any{"Scope": "tenant/abc/sub/123"}}, + want: true, + }, + { + name: "empty Data map plus top-level VendorID match returns true", + route: Route{VendorID: "microsoft-*", Data: map[string]string{}}, + tx: sdk.TransactionContext{VendorID: "microsoft-azure"}, + want: true, + }, + { + name: "nil Data plus top-level VendorID match returns true", + route: Route{VendorID: "microsoft-*"}, + tx: sdk.TransactionContext{VendorID: "microsoft-azure"}, + want: true, + }, + { + name: "tx Data has key but value is wrong type (int) does not match", + route: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + tx: sdk.TransactionContext{Data: map[string]any{"ResellerId": 42}}, + want: false, + }, + { + name: "tx Data has key but value is wrong type (bool) does not match", + route: Route{Data: map[string]string{"Active": "*"}}, + tx: sdk.TransactionContext{Data: map[string]any{"Active": true}}, + want: false, + }, + { + name: "tx Data nil with non-empty route Data does not match", + route: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + tx: sdk.TransactionContext{Data: nil}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.route.Matches(tt.tx) + if got != tt.want { + t.Errorf("Route.Matches() = %v, want %v", got, tt.want) + } + }) + } +} + +// --- Specificity table for the Data dimension --- + +func TestRoute_Specificity_Data_Table(t *testing.T) { + tests := []struct { + name string + route Route + want int + }{ + { + name: "zero non-empty fields, nil Data", + route: Route{}, + want: 0, + }, + { + name: "zero non-empty fields, empty Data map", + route: Route{Data: map[string]string{}}, + want: 0, + }, + { + name: "only Data with one entry", + route: Route{Data: map[string]string{"ResellerId": "x"}}, + want: 1, + }, + { + name: "only Data with three entries", + route: Route{Data: map[string]string{"a": "1", "b": "2", "c": "3"}}, + want: 3, + }, + { + name: "VendorID plus two Data entries", + route: Route{VendorID: "microsoft-*", Data: map[string]string{"a": "1", "b": "2"}}, + want: 3, + }, + { + name: "all five top-level fields plus two Data entries", + route: Route{ + VendorID: "v", + MarketplaceID: "m", + ProductID: "p", + TargetURL: "t", + EnvironmentID: "e", + Data: map[string]string{"a": "1", "b": "2"}, + }, + want: 7, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.route.Specificity() + if got != tt.want { + t.Errorf("Route.Specificity() = %d, want %d", got, tt.want) + } + }) + } +} + +// --- routesMayOverlap regression checks for the Data dimension --- + +func TestRoutesMayOverlap_DataDimension(t *testing.T) { + tests := []struct { + name string + a Route + b Route + want bool + }{ + { + name: "both routes same Data key with same literal may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + b: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + want: true, + }, + { + name: "both routes same Data key with disjoint literals do not overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + b: Route{Data: map[string]string{"ResellerId": "migrated-2"}}, + want: false, + }, + { + name: "routes with different Data keys may overlap (no shared dimension)", + a: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + b: Route{Data: map[string]string{"TenantId": "abc"}}, + want: true, + }, + { + name: "same Data key, one glob one literal: conservatively may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + b: Route{Data: map[string]string{"ResellerId": "legacy-1"}}, + want: true, + }, + { + name: "same Data key, one glob one matching literal: may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + b: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + want: true, + }, + { + name: "one route with Data, other with nil Data: no shared dimension, may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + b: Route{VendorID: "microsoft-*"}, + want: true, + }, + { + name: "one route with Data, other with empty Data map: no shared dimension, may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1"}}, + b: Route{Data: map[string]string{}}, + want: true, + }, + { + name: "multi-key Data with one shared key disjoint literal: do not overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1", "TenantId": "abc"}}, + b: Route{Data: map[string]string{"ResellerId": "migrated-2", "TenantId": "abc"}}, + want: false, + }, + { + name: "multi-key Data with all shared keys literal and equal: may overlap", + a: Route{Data: map[string]string{"ResellerId": "migrated-1", "TenantId": "abc"}}, + b: Route{Data: map[string]string{"ResellerId": "migrated-1", "TenantId": "abc"}}, + want: true, + }, + { + name: "Data shared key matches, top-level VendorID disjoint literals: do not overlap", + a: Route{ + VendorID: "microsoft", + Data: map[string]string{"ResellerId": "migrated-1"}, + }, + b: Route{ + VendorID: "google", + Data: map[string]string{"ResellerId": "migrated-1"}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := routesMayOverlap(tt.a, tt.b); got != tt.want { + t.Errorf("routesMayOverlap(%+v, %+v) = %v, want %v", tt.a, tt.b, got, tt.want) + } + }) + } +} + +// --- routeString rendering for Data --- + +func TestRouteString_RendersDataEntries(t *testing.T) { + tests := []struct { + name string + route Route + // substrings that must appear in the rendered string + mustContain []string + }{ + { + name: "Data-only single entry", + route: Route{Data: map[string]string{"ResellerId": "migrated-*"}}, + mustContain: []string{"Data[ResellerId]=migrated-*"}, + }, + { + name: "Data combined with VendorID", + route: Route{ + VendorID: "microsoft-*", + Data: map[string]string{"ResellerId": "migrated-*"}, + }, + mustContain: []string{"VendorID=microsoft-*", "Data[ResellerId]=migrated-*"}, + }, + { + name: "Data with multiple entries is deterministically sorted", + route: Route{ + Data: map[string]string{"TenantId": "abc", "ResellerId": "migrated-*"}, + }, + // Sorted alphabetically: ResellerId before TenantId. + mustContain: []string{"Data[ResellerId]=migrated-*", "Data[TenantId]=abc"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := routeString(tt.route) + for _, s := range tt.mustContain { + if !strings.Contains(got, s) { + t.Errorf("routeString = %q, want it to contain %q", got, s) + } + } + }) + } + + // Additionally verify deterministic ordering: ResellerId appears + // before TenantId in the sorted multi-entry case. + got := routeString(Route{Data: map[string]string{"TenantId": "abc", "ResellerId": "migrated-*"}}) + iReseller := strings.Index(got, "Data[ResellerId]=") + iTenant := strings.Index(got, "Data[TenantId]=") + if iReseller == -1 || iTenant == -1 { + t.Fatalf("routeString = %q, missing expected entries", got) + } + if iReseller > iTenant { + t.Errorf("routeString = %q, expected Data[ResellerId] before Data[TenantId]", got) + } +} + func TestStripScheme(t *testing.T) { tests := []struct { name string diff --git a/sdk/compliance/router.go b/sdk/compliance/router.go new file mode 100644 index 0000000..190b7fc --- /dev/null +++ b/sdk/compliance/router.go @@ -0,0 +1,44 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package compliance + +import ( + "context" + "net/http" + "testing" + + "github.com/cloudblue/chaperone/sdk" +) + +// VerifyRouter exercises a sdk.RequestRouter implementation against the +// minimal contract: it must accept a cancelled context without panicking +// and return either (nil, nil) (fall-through) or a non-nil RouteAction. +// +// VerifyRouter is opt-in: only plugins that implement RequestRouter need +// to call it. Plugins that do not implement RequestRouter remain valid +// under VerifyContract. +func VerifyRouter(t *testing.T, router sdk.RequestRouter) { + t.Helper() + + t.Run("returns without panicking on cancelled context", func(t *testing.T) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.test", http.NoBody) + _, _ = router.RouteRequest(ctx, sdk.TransactionContext{}, req) // no panic = pass + }) + + t.Run("nil RouteAction is a valid fall-through signal", func(t *testing.T) { + t.Helper() + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.test", http.NoBody) + action, err := router.RouteRequest(context.Background(), sdk.TransactionContext{}, req) + if err != nil { + return // returning an error is also valid + } + if action != nil && action.ForwardTo == "" { + t.Fatalf("router returned non-nil RouteAction with empty ForwardTo; use nil instead") + } + }) +} diff --git a/sdk/compliance/router_test.go b/sdk/compliance/router_test.go new file mode 100644 index 0000000..6ac268e --- /dev/null +++ b/sdk/compliance/router_test.go @@ -0,0 +1,31 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package compliance_test + +import ( + "context" + "net/http" + "testing" + + "github.com/cloudblue/chaperone/sdk" + "github.com/cloudblue/chaperone/sdk/compliance" +) + +// stubRouter implements sdk.RequestRouter for the compliance test. +type stubRouter struct { + action *sdk.RouteAction + err error +} + +func (s *stubRouter) RouteRequest(_ context.Context, _ sdk.TransactionContext, _ *http.Request) (*sdk.RouteAction, error) { + return s.action, s.err +} + +func TestVerifyRouter_NilActionAndNilError_Passes(t *testing.T) { + compliance.VerifyRouter(t, &stubRouter{}) +} + +func TestVerifyRouter_NonEmptyForwardTo_Passes(t *testing.T) { + compliance.VerifyRouter(t, &stubRouter{action: &sdk.RouteAction{ForwardTo: "x"}}) +} diff --git a/sdk/router.go b/sdk/router.go new file mode 100644 index 0000000..4523415 --- /dev/null +++ b/sdk/router.go @@ -0,0 +1,33 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "net/http" +) + +// RequestRouter is an optional plugin capability that decides, per request, +// whether to forward the request to a different upstream instead of +// proceeding with credential injection and the vendor call. +// +// Plugins that do not implement RequestRouter retain today's behavior. +type RequestRouter interface { + // RouteRequest is invoked before GetCredentials. Returning a non-nil + // RouteAction with a non-empty ForwardTo causes the Core to forward + // the request to the named forward_target and skip both credential + // injection and ModifyResponse. + // + // Returning nil (or an empty ForwardTo) is the fall-through signal: + // the Core continues with the normal credential-injection flow. + RouteRequest(ctx context.Context, tx TransactionContext, req *http.Request) (*RouteAction, error) +} + +// RouteAction signals how the Core should handle this request. +type RouteAction struct { + // ForwardTo names a forward_target defined in the proxy configuration. + // When non-empty, the Core forwards the request to that target's URL + // and skips credential injection and ModifyResponse. + ForwardTo string +} diff --git a/sdk/router_test.go b/sdk/router_test.go new file mode 100644 index 0000000..1248c0d --- /dev/null +++ b/sdk/router_test.go @@ -0,0 +1,20 @@ +// Copyright 2026 CloudBlue LLC +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import "testing" + +func TestRouteAction_ZeroValue_IsNonForwarding(t *testing.T) { + var a RouteAction + if a.ForwardTo != "" { + t.Fatalf("zero RouteAction.ForwardTo = %q, want empty", a.ForwardTo) + } +} + +func TestRouteAction_WithForwardTo_PreservesName(t *testing.T) { + a := RouteAction{ForwardTo: "company-b"} + if a.ForwardTo != "company-b" { + t.Fatalf("RouteAction.ForwardTo = %q, want %q", a.ForwardTo, "company-b") + } +}