From 9cf104d2aec96abb94c208cc8cb1d96c0e1c0bbe Mon Sep 17 00:00:00 2001 From: Ashley Davis Date: Tue, 27 Jan 2026 11:39:32 +0000 Subject: [PATCH] add support for fetching keys from a JWKS endpoint This requires changing a few function signatures and plumbing some things together. Notably, I don't want to have a second service discovery client and send duplicate calls off, so I shared the service discovery client from the CyberArk client and added caching of responses to the service discovery client. I also had to share credentials for auth. Also removes encrypted-secrets example The machinehub mode is required for key fetching, but doesn't play nicely with one shot mode and the example hangs. Secret encryption is covered in the e2e tests, so just remove the example for simplicity Signed-off-by: Ashley Davis --- examples/encrypted-secrets/README.md | 47 -- examples/encrypted-secrets/config.yaml | 41 -- examples/encrypted-secrets/test.sh | 65 --- hack/ark/test-e2e.sh | 2 + internal/cyberark/client_test.go | 8 +- .../identity/cmd/testidentity/main.go | 4 +- internal/cyberark/identity/identity_test.go | 2 +- .../cyberark/servicediscovery/discovery.go | 50 ++- .../servicediscovery/discovery_test.go | 4 +- internal/envelope/keyfetch/client.go | 188 ++++++++ internal/envelope/keyfetch/client_test.go | 425 ++++++++++++++++++ internal/envelope/keyfetch/doc.go | 9 + internal/envelope/keyfetch/fake.go | 85 ++++ internal/envelope/keyfetch/fake_test.go | 89 ++++ internal/envelope/rsa/encryptor.go | 43 +- internal/envelope/rsa/encryptor_test.go | 109 +---- internal/envelope/rsa/keys_test.go | 6 +- internal/envelope/types.go | 7 +- pkg/agent/run.go | 48 +- pkg/client/client_cyberark.go | 21 +- pkg/datagatherer/k8sdynamic/dynamic.go | 10 +- pkg/datagatherer/k8sdynamic/dynamic_test.go | 7 +- 22 files changed, 952 insertions(+), 318 deletions(-) delete mode 100644 examples/encrypted-secrets/README.md delete mode 100644 examples/encrypted-secrets/config.yaml delete mode 100755 examples/encrypted-secrets/test.sh create mode 100644 internal/envelope/keyfetch/client.go create mode 100644 internal/envelope/keyfetch/client_test.go create mode 100644 internal/envelope/keyfetch/doc.go create mode 100644 internal/envelope/keyfetch/fake.go create mode 100644 internal/envelope/keyfetch/fake_test.go diff --git a/examples/encrypted-secrets/README.md b/examples/encrypted-secrets/README.md deleted file mode 100644 index 3276b7fe..00000000 --- a/examples/encrypted-secrets/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Encrypted Secrets Example - -This example demonstrates how to use the disco agent to gather Kubernetes secrets and encrypt their data fields. - -## Overview - -When the `ARK_SEND_SECRETS` environment variable is set to `"true"`, the disco agent will: - -0. Fetch an encryption key from the configured endpoint (if running in production) or use a local key for testing -1. Discover Kubernetes secrets in your cluster (excluding common system secret types) -2. Encrypt each secret's data fields using RSA envelope encryption with JWE (JSON Web Encryption) format -3. If running in production, send the encrypted secrets to the configured endpoint; otherwise, write them to `output.json` for testing - -The encryption uses: - -- **Key Algorithm**: RSA-OAEP-256 (for encrypting the content encryption key) -- **Content Encryption**: AES-256-GCM (for encrypting the actual secret data) -- **Format**: JWE Compact Serialization - -Metadata (names, namespaces, labels, annotations) remains in plaintext for discovery purposes, while the sensitive secret data is encrypted. Some keys in Secret data fields are also preserved in the `data` section, for backwards compatibility. - -## Prerequisites - -1. A running Kubernetes cluster with secrets to discover -3. Go installed - -## Configuration File - -The `config.yaml` file configures: - -- The data gatherer to collect Kubernetes secrets -- Field selectors to exclude system secrets (service account tokens, docker configs, etc.) -- The cluster ID and organization ID for grouping data - -## Running the Example - -Test the agent locally by running this script: - -```bash -./test.sh -``` - -This will: - -- Connect to your current Kubernetes context -- Gather all non-system secrets -- Write the raw data to `output.json` diff --git a/examples/encrypted-secrets/config.yaml b/examples/encrypted-secrets/config.yaml deleted file mode 100644 index b966e4d1..00000000 --- a/examples/encrypted-secrets/config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# encrypted-secrets config.yaml -# -# An example configuration file demonstrating how to use the disco agent -# to send encrypted secrets to CyberArk Discovery & Context. -# -# The agent will: -# 1. Discover Kubernetes secrets in the cluster -# 2. Encrypt the secret data fields using RSA envelope encryption (JWE format) -# 3. Upload the encrypted secrets to CyberArk Discovery & Context -# -# Example usage: -# -# export ARK_SUBDOMAIN="your-subdomain" -# export ARK_USERNAME="your-username" -# export ARK_SECRET="your-secret" -# export ARK_SEND_SECRETS="true" -# -# go run . agent \ -# --agent-config-file examples/encrypted-secrets/config.yaml \ -# --one-shot \ -# --output-path output.json -# -organization_id: "my-organization" -cluster_id: "my_cluster" -period: 1m -data-gatherers: -- kind: "k8s-dynamic" - name: "k8s/secrets" - config: - resource-type: - version: v1 - resource: secrets - # Filter out common system secret types to focus on application secrets - field-selectors: - - type!=kubernetes.io/service-account-token - - type!=kubernetes.io/dockercfg - - type!=kubernetes.io/dockerconfigjson - - type!=kubernetes.io/basic-auth - - type!=kubernetes.io/ssh-auth - - type!=bootstrap.kubernetes.io/token - - type!=helm.sh/release.v1 diff --git a/examples/encrypted-secrets/test.sh b/examples/encrypted-secrets/test.sh deleted file mode 100755 index 506001ea..00000000 --- a/examples/encrypted-secrets/test.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env bash -# test.sh - Test script for the encrypted secrets example -# -# This script demonstrates running the disco agent with encrypted secrets enabled. -# It will run in one-shot mode and output to a local file for inspection. - -set -euo pipefail - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -NC='\033[0m' # No Color - -echo -e "${GREEN}=== Encrypted Secrets Example Test ===${NC}\n" - -echo -e "${GREEN}Testing agent with Kubernetes secrets${NC}" -echo "" - -# Enable encrypted secrets -export ARK_SEND_SECRETS="true" - -# Check Kubernetes connectivity -if ! kubectl cluster-info &> /dev/null; then - echo -e "${RED}Error: Unable to connect to Kubernetes cluster${NC}" - echo "Please ensure your kubeconfig is configured correctly." - exit 1 -fi - -echo -e "${GREEN}✓ Connected to Kubernetes cluster${NC}" -CONTEXT=$(kubectl config current-context) -echo " Context: ${CONTEXT}" -echo "" - -# Check for secrets -SECRET_COUNT=$(kubectl get secrets --all-namespaces --no-headers 2>/dev/null | wc -l | tr -d ' ') -echo "Found ${SECRET_COUNT} secrets in cluster" -echo "" - -# Run the agent in one-shot mode with output to file -OUTPUT_FILE="output.json" -echo -e "${GREEN}Running disco agent with encrypted secrets enabled...${NC}" -echo "Command: go run ../.. agent --agent-config-file config.yaml --one-shot --output-path ${OUTPUT_FILE}" -echo "" - -if go run ../.. agent \ - --agent-config-file config.yaml \ - --one-shot \ - --output-path "${OUTPUT_FILE}"; then - - echo "" - echo -e "${GREEN}✓ Agent completed successfully${NC}" - - # Check if output file was created - if [ -f "${OUTPUT_FILE}" ]; then - echo -e "${GREEN}✓ Output file created: ${OUTPUT_FILE}${NC}" - else - echo -e "${RED}✗ Output file was not created${NC}" - exit 1 - fi -else - echo "" - echo -e "${RED}✗ Agent failed${NC}" - exit 1 -fi diff --git a/hack/ark/test-e2e.sh b/hack/ark/test-e2e.sh index 3c5eeb67..0bd86baa 100755 --- a/hack/ark/test-e2e.sh +++ b/hack/ark/test-e2e.sh @@ -101,6 +101,7 @@ kubectl apply -f "${root_dir}/hack/ark/cluster-external-secret.yaml" # We use a non-existent tag and omit the `--version` flag, to work around a Helm # v4 bug. See: https://github.com/helm/helm/issues/31600 +# TODO: shouldn't need to set config.sendSecretValues because it will default to true in future helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \ --install \ --wait \ @@ -113,6 +114,7 @@ helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \ --set config.clusterName="e2e-test-cluster" \ --set config.clusterDescription="A temporary cluster for E2E testing. Contact @wallrj-cyberark." \ --set config.period=60s \ + --set config.sendSecretValues=true \ --set-json "podLabels={\"disco-agent.cyberark.cloud/test-id\": \"${RANDOM}\"}" kubectl rollout status deployments/disco-agent --namespace "${NAMESPACE}" diff --git a/internal/cyberark/client_test.go b/internal/cyberark/client_test.go index 1c220d2d..b8efeb93 100644 --- a/internal/cyberark/client_test.go +++ b/internal/cyberark/client_test.go @@ -32,9 +32,9 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { Secret: "somepassword", } - discoveryClient := servicediscovery.New(httpClient) + discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain) - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context()) if err != nil { t.Fatalf("failed to discover mock services: %v", err) } @@ -76,9 +76,9 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) { cfg, err := cyberark.LoadClientConfigFromEnvironment() require.NoError(t, err) - discoveryClient := servicediscovery.New(httpClient) + discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain) - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context()) if err != nil { t.Fatalf("failed to discover services: %v", err) } diff --git a/internal/cyberark/identity/cmd/testidentity/main.go b/internal/cyberark/identity/cmd/testidentity/main.go index 916c81ea..0a8df80b 100644 --- a/internal/cyberark/identity/cmd/testidentity/main.go +++ b/internal/cyberark/identity/cmd/testidentity/main.go @@ -50,8 +50,8 @@ func run(ctx context.Context) error { var rootCAs *x509.CertPool httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) - sdClient := servicediscovery.New(httpClient) - services, _, err := sdClient.DiscoverServices(ctx, subdomain) + sdClient := servicediscovery.New(httpClient, subdomain) + services, _, err := sdClient.DiscoverServices(ctx) if err != nil { return fmt.Errorf("while performing service discovery: %s", err) } diff --git a/internal/cyberark/identity/identity_test.go b/internal/cyberark/identity/identity_test.go index 917ba15d..0915f46c 100644 --- a/internal/cyberark/identity/identity_test.go +++ b/internal/cyberark/identity/identity_test.go @@ -53,7 +53,7 @@ func TestLoginUsernamePassword_RealAPI(t *testing.T) { arktesting.SkipIfNoEnv(t) subdomain := os.Getenv("ARK_SUBDOMAIN") httpClient := http.DefaultClient - services, _, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain) + services, _, err := servicediscovery.New(httpClient, subdomain).DiscoverServices(t.Context()) require.NoError(t, err) loginUsernamePasswordTests(t, func(t testing.TB) inputs { diff --git a/internal/cyberark/servicediscovery/discovery.go b/internal/cyberark/servicediscovery/discovery.go index 82394ab3..93598d5c 100644 --- a/internal/cyberark/servicediscovery/discovery.go +++ b/internal/cyberark/servicediscovery/discovery.go @@ -9,6 +9,8 @@ import ( "net/url" "os" "path" + "sync" + "time" arkapi "github.com/jetstack/preflight/internal/cyberark/api" "github.com/jetstack/preflight/pkg/version" @@ -35,21 +37,34 @@ const ( // users to fetch URLs for various APIs available in CyberArk. This client is specialised to // fetch only API endpoints, since only API endpoints are required by the Venafi Kubernetes Agent currently. type Client struct { - client *http.Client - baseURL string + client *http.Client + baseURL string + subdomain string + + cachedResponse *Services + cachedTenantID string + cachedResponseTime time.Time + cachedResponseMutex sync.Mutex } // New creates a new CyberArk Service Discovery client. If the ARK_DISCOVERY_API // environment variable is set, it is used as the base URL for the service // discovery API. Otherwise, the production URL is used. -func New(httpClient *http.Client) *Client { +func New(httpClient *http.Client, subdomain string) *Client { baseURL := os.Getenv("ARK_DISCOVERY_API") if baseURL == "" { baseURL = ProdDiscoveryAPIBaseURL } + client := &Client{ - client: httpClient, - baseURL: baseURL, + client: httpClient, + baseURL: baseURL, + subdomain: subdomain, + + cachedResponse: nil, + cachedTenantID: "", + cachedResponseTime: time.Time{}, + cachedResponseMutex: sync.Mutex{}, } return client @@ -93,17 +108,24 @@ type Services struct { DiscoveryContext ServiceEndpoint } -// DiscoverServices fetches from the service discovery service for a given subdomain +// DiscoverServices fetches from the service discovery service for the configured subdomain // and parses the CyberArk Identity API URL and Inventory API URL. // It also returns the Tenant ID UUID corresponding to the subdomain. -func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, string, error) { +func (c *Client) DiscoverServices(ctx context.Context) (*Services, string, error) { + c.cachedResponseMutex.Lock() + defer c.cachedResponseMutex.Unlock() + + if c.cachedResponse != nil && time.Since(c.cachedResponseTime) < 1*time.Hour { + return c.cachedResponse, c.cachedTenantID, nil + } + u, err := url.Parse(c.baseURL) if err != nil { return nil, "", fmt.Errorf("invalid base URL for service discovery: %w", err) } u.Path = path.Join(u.Path, "api/public/tenant-discovery") - u.RawQuery = url.Values{"bySubdomain": []string{subdomain}}.Encode() + u.RawQuery = url.Values{"bySubdomain": []string{c.subdomain}}.Encode() endpoint := u.String() @@ -127,7 +149,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi // a 404 error is returned with an empty JSON body "{}" if the subdomain is unknown; at the time of writing, we haven't observed // any other errors and so we can't special case them if resp.StatusCode == http.StatusNotFound { - return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain) + return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", c.subdomain) } return nil, "", fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status) @@ -167,8 +189,14 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi } //TODO: Should add a check for discoveryContextAPI too? - return &Services{ + services := &Services{ Identity: ServiceEndpoint{API: identityAPI}, DiscoveryContext: ServiceEndpoint{API: discoveryContextAPI}, - }, discoveryResp.TenantID, nil + } + + c.cachedResponse = services + c.cachedTenantID = discoveryResp.TenantID + c.cachedResponseTime = time.Now() + + return services, discoveryResp.TenantID, nil } diff --git a/internal/cyberark/servicediscovery/discovery_test.go b/internal/cyberark/servicediscovery/discovery_test.go index 00d0fd58..618e63f9 100644 --- a/internal/cyberark/servicediscovery/discovery_test.go +++ b/internal/cyberark/servicediscovery/discovery_test.go @@ -64,9 +64,9 @@ func Test_DiscoverIdentityAPIURL(t *testing.T) { }, }) - client := New(httpClient) + client := New(httpClient, testSpec.subdomain) - services, _, err := client.DiscoverServices(ctx, testSpec.subdomain) + services, _, err := client.DiscoverServices(ctx) if testSpec.expectedError != nil { assert.EqualError(t, err, testSpec.expectedError.Error()) assert.Nil(t, services) diff --git a/internal/envelope/keyfetch/client.go b/internal/envelope/keyfetch/client.go new file mode 100644 index 00000000..3fbd67d3 --- /dev/null +++ b/internal/envelope/keyfetch/client.go @@ -0,0 +1,188 @@ +package keyfetch + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/jetstack/venafi-connection-lib/http_client" + "github.com/lestrrat-go/jwx/v3/jwk" + "k8s.io/klog/v2" + + "github.com/jetstack/preflight/internal/cyberark" + "github.com/jetstack/preflight/internal/cyberark/identity" + "github.com/jetstack/preflight/internal/cyberark/servicediscovery" + "github.com/jetstack/preflight/pkg/version" +) + +const ( + // minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor + // to enforce to ensure that a weak key can't accidentally be used + minRSAKeySize = 2048 +) + +// KeyFetcher is an interface for fetching public keys. +type KeyFetcher interface { + // FetchKey retrieves a public key from the key source. + FetchKey(ctx context.Context) (PublicKey, error) +} + +// Compile-time check that Client implements KeyFetcher +var _ KeyFetcher = (*Client)(nil) + +// PublicKey represents an RSA public key retrieved from the key server. +type PublicKey struct { + // KeyID is the unique identifier for this key + KeyID string + + // Key is the actual RSA public key + Key *rsa.PublicKey +} + +// Client fetches public keys from a CyberArk HTTP endpoint that provides keys in JWKS format. +// It can be expanded in future to support other key types and formats, but for now it only supports RSA keys +// and ignored other types. +type Client struct { + discoveryClient *servicediscovery.Client + identityClient *identity.Client + cfg cyberark.ClientConfig + + // httpClient is the HTTP client used for requests + httpClient *http.Client +} + +// NewClient creates a new key fetching client. +// Uses CyberArk service discovery to derive the JWKS endpoint and CyberArk identity client for authentication. +// Constructing the client involves a service discovery call to initialise the identity client, +// so this may return an error if the discovery client is not able to connect to the service discovery endpoint. +// If httpClient is nil, a default HTTP client will be created. +func NewClient(ctx context.Context, discoveryClient *servicediscovery.Client, cfg cyberark.ClientConfig, httpClient *http.Client) (*Client, error) { + if httpClient == nil { + var rootCAs *x509.CertPool + httpClient = http_client.NewDefaultClient(version.UserAgent(), rootCAs) + } + + services, _, err := discoveryClient.DiscoverServices(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get services from discovery client for initialising identity client: %w", err) + } + + return &Client{ + discoveryClient: discoveryClient, + identityClient: identity.New(httpClient, services.Identity.API, cfg.Subdomain), + cfg: cfg, + httpClient: httpClient, + }, nil +} + +// FetchKey retrieves the public keys from the configured endpoint. +// It returns a slice of PublicKey structs containing the key material and metadata. +func (c *Client) FetchKey(ctx context.Context) (PublicKey, error) { + services, _, err := c.discoveryClient.DiscoverServices(ctx) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to get services from discovery client: %w", err) + } + + err = c.identityClient.LoginUsernamePassword(ctx, c.cfg.Username, []byte(c.cfg.Secret)) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to authenticate for fetching JWKs: %w", err) + } + + endpoint, err := url.JoinPath(services.DiscoveryContext.API, "discovery-context/jwks") + if err != nil { + return PublicKey{}, fmt.Errorf("failed to construct endpoint URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to create request: %w", err) + } + + _, err = c.identityClient.AuthenticateRequest(req) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to authenticate request: %s", err) + } + + req.Header.Set("Accept", "application/json") + version.SetUserAgent(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to fetch keys from %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return PublicKey{}, fmt.Errorf("unexpected status code %d from %s: %s", resp.StatusCode, endpoint, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to read response body: %w", err) + } + + keySet, err := jwk.Parse(body) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to parse JWKs response: %w", err) + } + + for i := range keySet.Len() { + key, ok := keySet.Key(i) + if !ok { + continue + } + + // Only process RSA keys + if key.KeyType().String() != "RSA" { + continue + } + + var rawKey any + if err := jwk.Export(key, &rawKey); err != nil { + // skip unparseable keys + continue + } + + rsaKey, ok := rawKey.(*rsa.PublicKey) + if !ok { + // only process RSA keys (for now) + continue + } + + if rsaKey.N.BitLen() < minRSAKeySize { + // skip keys that are too small to be secure + continue + } + + kid, ok := key.KeyID() + if !ok { + // skip any keys which don't have an ID + continue + } + + alg, ok := key.Algorithm() + if !ok { + // skip any keys which don't have an algorithm specified + continue + } + + if alg.String() != "RSA-OAEP-256" { + // we only use RSA keys for RSA-OAEP-256 + continue + } + + klog.FromContext(ctx).WithName("keyfetch").Info("found valid RSA key", "kid", kid) + // return the first valid key we find + return PublicKey{ + KeyID: kid, + Key: rsaKey, + }, nil + } + + return PublicKey{}, fmt.Errorf("no valid RSA keys found at %s", endpoint) +} diff --git a/internal/envelope/keyfetch/client_test.go b/internal/envelope/keyfetch/client_test.go new file mode 100644 index 00000000..1dff3d53 --- /dev/null +++ b/internal/envelope/keyfetch/client_test.go @@ -0,0 +1,425 @@ +package keyfetch + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jetstack/preflight/internal/cyberark" + "github.com/jetstack/preflight/internal/cyberark/identity" + "github.com/jetstack/preflight/internal/cyberark/servicediscovery" +) + +// testClientSetup sets up a complete test environment with mock identity and discovery servers +// and returns a configured client along with the test ClientConfig +func testClientSetup(t *testing.T, jwksServerURL string) (*Client, cyberark.ClientConfig) { + t.Helper() + + // Create mock identity server + identityURL, httpClient := identity.MockIdentityServer(t) + + // Set up services for mock discovery server + services := servicediscovery.Services{ + Identity: servicediscovery.ServiceEndpoint{ + IsActive: true, + Type: "main", + API: identityURL, + }, + DiscoveryContext: servicediscovery.ServiceEndpoint{ + IsActive: true, + Type: "main", + API: jwksServerURL, + }, + } + + // Create mock discovery server + _ = servicediscovery.MockDiscoveryServer(t, services) + + // Create discovery client + discoveryClient := servicediscovery.New(httpClient, servicediscovery.MockDiscoverySubdomain) + + // Create test config with credentials that match the mock identity server + cfg := cyberark.ClientConfig{ + Subdomain: servicediscovery.MockDiscoverySubdomain, + Username: "test@example.com", // matches successUser in mock identity server + Secret: "somepassword", // matches successPassword in mock identity server + } + + // Create the keyfetch client with the properly configured httpClient + client, err := NewClient(t.Context(), discoveryClient, cfg, httpClient) + require.NoError(t, err) + + return client, cfg +} + +func mockJWKSServer(t *testing.T, statusCode int, jwksResponse string) *httptest.Server { + t.Helper() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if this is the JWKS endpoint + if r.URL.Path == "/discovery-context/jwks" { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + _, err := w.Write([]byte(jwksResponse)) + require.NoError(t, err) + } + })) + + t.Cleanup(server.Close) + + return server +} + +func TestClient_FetchKey(t *testing.T) { + // Sample JWKs response with a valid RSA key + // This is a minimal example with the required fields, used in multiple tests + jwksResponse := `{ +"keys": [ + { + "kty": "RSA", + "use": "enc", + "kid": "test-key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + t.Run("successful fetch", func(t *testing.T) { + + server := mockJWKSServer(t, http.StatusOK, jwksResponse) + + client, _ := testClientSetup(t, server.URL) + key, err := client.FetchKey(t.Context()) + + require.NoError(t, err) + + assert.Equal(t, "test-key-1", key.KeyID) + assert.NotNil(t, key.Key) + assert.NotNil(t, key.Key.N) + assert.Greater(t, key.Key.E, 0) + }) + + t.Run("multiple keys", func(t *testing.T) { + // want to check that FetchKey returns the first valid RSA key, even if there are multiple keys in the response + multiKeyResponse := `{ + "keys": [ + { + "kty": "RSA", + "kid": "key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "key-2", + "alg": "RSA-OAEP-256", + "n": "4J0VE8FK1rSQUBGiLpk4MkPyFApCyCugOfkuH0hiHclxZay96JgyZylH97eqs-ZmWXtv42ynYctIj2ZleaoqVDfMOqZ1GsbccyNAYReDtUYgeUtJEajpfUo1vitoh6OEB6nB0Hau07ELLqcUoxH_zkH5Kwoi_BgxByJDQ1HOut6nyEPTXLTMrAYK_pqL_kzsU0OtrCgSBh6j-11ToqUfxsLupbadRC0t5zrq4-3mZKqxBUz4XB2g3b9d2lH7mOTl5J_E8jcD4tK9DePzjdbkRWonBEJetWl9f2mh_VD1sxJbie1kzM5cdQylXzV_AvhSr58w00qy6XR_QXI10UU16Q", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, multiKeyResponse) + + client, _ := testClientSetup(t, server.URL) + key, err := client.FetchKey(t.Context()) + + require.NoError(t, err) + + assert.Equal(t, "key-1", key.KeyID) + }) + + t.Run("filters non-RSA keys", func(t *testing.T) { + // check that the client correctly filters out non-RSA keys and returns the first valid RSA key + mixedKeyResponse := `{ + "keys": [ + { + "kty": "EC", + "kid": "ec-key-1", + "alg": "ES256", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE" + }, + { + "kty": "RSA", + "kid": "rsa-key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, mixedKeyResponse) + + client, _ := testClientSetup(t, server.URL) + key, err := client.FetchKey(t.Context()) + + require.NoError(t, err) + assert.Equal(t, "rsa-key-1", key.KeyID) + }) + + t.Run("error on non-200 status", func(t *testing.T) { + server := mockJWKSServer(t, http.StatusInternalServerError, "") // Response body won't be used since we return 500 + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected status code 500") + }) + + t.Run("error on invalid JSON", func(t *testing.T) { + server := mockJWKSServer(t, http.StatusOK, "invalid json") + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse JWKs response") + }) + + t.Run("error on no RSA keys", func(t *testing.T) { + ecOnlyResponse := `{ + "keys": [ + { + "kty": "EC", + "kid": "ec-key-1", + "alg": "ES256", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, ecOnlyResponse) + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid RSA keys found") + }) + + t.Run("context cancellation", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This handler will never respond + <-r.Context().Done() + })) + defer server.Close() + + client, _ := testClientSetup(t, server.URL) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := client.FetchKey(ctx) + + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) + + t.Run("authentication failure", func(t *testing.T) { + server := mockJWKSServer(t, http.StatusOK, jwksResponse) + + // Create mock identity server + identityURL, httpClient := identity.MockIdentityServer(t) + + // Set up services for mock discovery server + services := servicediscovery.Services{ + Identity: servicediscovery.ServiceEndpoint{ + IsActive: true, + Type: "main", + API: identityURL, + }, + DiscoveryContext: servicediscovery.ServiceEndpoint{ + IsActive: true, + Type: "main", + API: server.URL, + }, + } + + // Create mock discovery server + _ = servicediscovery.MockDiscoveryServer(t, services) + + // Create discovery client + discoveryClient := servicediscovery.New(httpClient, servicediscovery.MockDiscoverySubdomain) + + // Create test config with WRONG credentials + // Use the failureUser from the mock identity server + cfg := cyberark.ClientConfig{ + Subdomain: servicediscovery.MockDiscoverySubdomain, + Username: "test-fail@example.com", // This user is configured to fail in the mock server // TODO: export these constants from the identity package to avoid hardcoding them here + Secret: "somepassword", + } + + // Create the keyfetch client + client, err := NewClient(t.Context(), discoveryClient, cfg, httpClient) + require.NoError(t, err) + + _, err = client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to authenticate") + }) + + t.Run("service discovery fails", func(t *testing.T) { + // Create mock identity server (won't be used but needed for setup) + identityURL, httpClient := identity.MockIdentityServer(t) + + // Set up services for mock discovery server + services := servicediscovery.Services{ + Identity: servicediscovery.ServiceEndpoint{ + IsActive: true, + Type: "main", + API: identityURL, + }, + } + + // Create mock discovery server + _ = servicediscovery.MockDiscoveryServer(t, services) + + // Create discovery client with a subdomain that triggers failure + discoveryClient := servicediscovery.New(httpClient, "bad-request") + + cfg := cyberark.ClientConfig{ + Subdomain: "bad-request", + Username: "test@example.com", + Secret: "somepassword", + } + + _, err := NewClient(t.Context(), discoveryClient, cfg, httpClient) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get services from discovery client") + }) + + t.Run("ignores small RSA keys", func(t *testing.T) { + // This is a 1024-bit RSA key (half the minimum size) + // Generated with: openssl genrsa 1024 | openssl rsa -pubin -outform der | base64url + smallKeyResponse := `{ + "keys": [ + { + "kty": "RSA", + "kid": "small-key-1", + "alg": "RSA-OAEP-256", + "n": "wKhJSKlx9aO_TmT4qAqN5EZ8FeXCXmh5F_hGHWL6c4lKvdKc_jBq1YI0H8pCIWZ6WhPKmBZ8JQ4Q2q0TjvdKLYQ8jqzMZxz4J_z4ySbN7yBn7N7xKqL5JN7KqVr7N8KQ", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "valid-key", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, smallKeyResponse) + + client, _ := testClientSetup(t, server.URL) + key, err := client.FetchKey(t.Context()) + + require.NoError(t, err) + // Should skip the small key and return the valid one + assert.Equal(t, "valid-key", key.KeyID) + }) + + t.Run("skips keys without kid", func(t *testing.T) { + noKidResponse := `{ + "keys": [ + { + "kty": "RSA", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, noKidResponse) + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid RSA keys found") + }) + + t.Run("filters keys with wrong algorithm", func(t *testing.T) { + wrongAlgResponse := `{ + "keys": [ + { + "kty": "RSA", + "kid": "wrong-alg-key", + "alg": "RS256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "correct-alg-key", + "alg": "RSA-OAEP-256", + "n": "4J0VE8FK1rSQUBGiLpk4MkPyFApCyCugOfkuH0hiHclxZay96JgyZylH97eqs-ZmWXtv42ynYctIj2ZleaoqVDfMOqZ1GsbccyNAYReDtUYgeUtJEajpfUo1vitoh6OEB6nB0Hau07ELLqcUoxH_zkH5Kwoi_BgxByJDQ1HOut6nyEPTXLTMrAYK_pqL_kzsU0OtrCgSBh6j-11ToqUfxsLupbadRC0t5zrq4-3mZKqxBUz4XB2g3b9d2lH7mOTl5J_E8jcD4tK9DePzjdbkRWonBEJetWl9f2mh_VD1sxJbie1kzM5cdQylXzV_AvhSr58w00qy6XR_QXI10UU16Q", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, wrongAlgResponse) + + client, _ := testClientSetup(t, server.URL) + key, err := client.FetchKey(t.Context()) + + require.NoError(t, err) + // Should skip the RS256 key and return the RSA-OAEP-256 key + assert.Equal(t, "correct-alg-key", key.KeyID) + }) + + t.Run("skips keys without algorithm", func(t *testing.T) { + noAlgResponse := `{ + "keys": [ + { + "kty": "RSA", + "kid": "no-alg-key", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + server := mockJWKSServer(t, http.StatusOK, noAlgResponse) + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid RSA keys found") + }) + + t.Run("handles empty key set", func(t *testing.T) { + emptyKeysResponse := `{ + "keys": [] + }` + + server := mockJWKSServer(t, http.StatusOK, emptyKeysResponse) + + client, _ := testClientSetup(t, server.URL) + _, err := client.FetchKey(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid RSA keys found") + }) +} diff --git a/internal/envelope/keyfetch/doc.go b/internal/envelope/keyfetch/doc.go new file mode 100644 index 00000000..d3cb09da --- /dev/null +++ b/internal/envelope/keyfetch/doc.go @@ -0,0 +1,9 @@ +// Package keyfetch provides a client for fetching encryption keys from an HTTP endpoint. +// +// The client retrieves public keys in JSON Web Key Set (JWKs) format from a remote +// server and converts them into usable cryptographic keys for envelope encryption. +// +// This package uses github.com/lestrrat-go/jwx/v3/jwk for JWK parsing and handling. +// +// Currently, keyfetch only supports RSA keys for envelope encryption. +package keyfetch diff --git a/internal/envelope/keyfetch/fake.go b/internal/envelope/keyfetch/fake.go new file mode 100644 index 00000000..d7226b2b --- /dev/null +++ b/internal/envelope/keyfetch/fake.go @@ -0,0 +1,85 @@ +package keyfetch + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" +) + +// Compile-time check that FakeClient implements KeyFetcher +var _ KeyFetcher = (*FakeClient)(nil) + +// FakeClient is a fake implementation of the key fetcher for testing. +// It can be configured to return specific keys or errors for testing different scenarios. +type FakeClient struct { + // Key is the public key that will be returned by FetchKey. + // If nil, a random key will be generated on the first call. + Key *PublicKey + + // Err is the error that will be returned by FetchKey. + // If both Key and Err are set, Err takes precedence. + Err error + + // FetchKeyCalls tracks how many times FetchKey was called + FetchKeyCalls int +} + +// NewFakeClient creates a new fake client for testing. +func NewFakeClient() *FakeClient { + return &FakeClient{} +} + +// NewFakeClientWithKey creates a new fake client that returns the specified key. +func NewFakeClientWithKey(keyID string, key *rsa.PublicKey) *FakeClient { + return &FakeClient{ + Key: &PublicKey{ + KeyID: keyID, + Key: key, + }, + } +} + +// NewFakeClientWithError creates a new fake client that returns the specified error. +func NewFakeClientWithError(err error) *FakeClient { + return &FakeClient{ + Err: err, + } +} + +// FetchKey implements the key fetching interface for testing. +// It returns the configured key or error, or generates a random key if none is configured. +func (f *FakeClient) FetchKey(ctx context.Context) (PublicKey, error) { + f.FetchKeyCalls++ + + // Check if context is canceled + if ctx.Err() != nil { + return PublicKey{}, ctx.Err() + } + + // If an error is configured, return it + if f.Err != nil { + return PublicKey{}, f.Err + } + + // If a key is configured, return it + if f.Key != nil { + return *f.Key, nil + } + + // Generate a random key for testing + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to generate test key: %w", err) + } + + generatedKey := PublicKey{ + KeyID: "test-key", + Key: &privateKey.PublicKey, + } + + // Cache the generated key for subsequent calls + f.Key = &generatedKey + + return generatedKey, nil +} diff --git a/internal/envelope/keyfetch/fake_test.go b/internal/envelope/keyfetch/fake_test.go new file mode 100644 index 00000000..c2bc255d --- /dev/null +++ b/internal/envelope/keyfetch/fake_test.go @@ -0,0 +1,89 @@ +package keyfetch + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFakeClient(t *testing.T) { + t.Run("returns generated key by default", func(t *testing.T) { + fake := NewFakeClient() + + key, err := fake.FetchKey(t.Context()) + require.NoError(t, err) + + assert.Equal(t, "test-key", key.KeyID) + assert.NotNil(t, key.Key) + assert.Equal(t, 1, fake.FetchKeyCalls) + + // Subsequent calls return the same key + key2, err := fake.FetchKey(t.Context()) + require.NoError(t, err) + assert.Equal(t, key.KeyID, key2.KeyID) + assert.Equal(t, key.Key, key2.Key) + assert.Equal(t, 2, fake.FetchKeyCalls) + }) + + t.Run("returns configured key", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + require.NoError(t, err) + + fake := NewFakeClientWithKey("custom-key", &privateKey.PublicKey) + + key, err := fake.FetchKey(t.Context()) + require.NoError(t, err) + + assert.Equal(t, "custom-key", key.KeyID) + assert.Equal(t, &privateKey.PublicKey, key.Key) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("returns configured error", func(t *testing.T) { + expectedErr := errors.New("test error") + fake := NewFakeClientWithError(expectedErr) + + _, err := fake.FetchKey(t.Context()) + require.Error(t, err) + + assert.Equal(t, expectedErr, err) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("respects context cancellation", func(t *testing.T) { + fake := NewFakeClient() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + _, err := fake.FetchKey(ctx) + require.Error(t, err) + + assert.Equal(t, context.Canceled, err) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("error takes precedence over key", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + require.NoError(t, err) + + expectedErr := errors.New("test error") + fake := &FakeClient{ + Key: &PublicKey{ + KeyID: "custom-key", + Key: &privateKey.PublicKey, + }, + Err: expectedErr, + } + + _, err = fake.FetchKey(t.Context()) + require.Error(t, err) + + assert.Equal(t, expectedErr, err) + }) +} diff --git a/internal/envelope/rsa/encryptor.go b/internal/envelope/rsa/encryptor.go index 8cc0e17a..80d87de3 100644 --- a/internal/envelope/rsa/encryptor.go +++ b/internal/envelope/rsa/encryptor.go @@ -1,20 +1,17 @@ package rsa import ( - "crypto/rsa" + "context" "fmt" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" ) const ( - // minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor - // to enforce to ensure that a weak key can't accidentally be used - minRSAKeySize = 2048 - // EncryptionType is the type identifier for RSA JWE encryption EncryptionType = "JWE-RSA" ) @@ -25,45 +22,33 @@ var _ envelope.Encryptor = (*Encryptor)(nil) // Encryptor provides envelope encryption using RSA-OAEP-256 for key wrapping // and AES-256-GCM for data encryption, outputting JWE Compact Serialization format. type Encryptor struct { - keyID string - publicKey *rsa.PublicKey + fetcher keyfetch.KeyFetcher } -// NewEncryptor creates a new Encryptor with the provided RSA public key. -// The RSA key must be at least minRSAKeySize bits. +// NewEncryptor creates a new Encryptor with the provided key fetcher. // The encryptor will use RSA-OAEP-256 for key encryption and A256GCM for content encryption. -func NewEncryptor(keyID string, publicKey *rsa.PublicKey) (*Encryptor, error) { - if publicKey == nil { - return nil, fmt.Errorf("RSA public key cannot be nil") - } - - // Validate key size - keySize := publicKey.N.BitLen() - if keySize < minRSAKeySize { - return nil, fmt.Errorf("RSA key size must be at least %d bits, got %d bits", minRSAKeySize, keySize) - } - - if len(keyID) == 0 { - return nil, fmt.Errorf("keyID cannot be empty") - } - +func NewEncryptor(fetcher keyfetch.KeyFetcher) (*Encryptor, error) { return &Encryptor{ - keyID: keyID, - publicKey: publicKey, + fetcher: fetcher, }, nil } // Encrypt performs envelope encryption on the provided data. // It returns an EncryptedData struct containing JWE Compact Serialization format and type metadata. // The JWE uses RSA-OAEP-256 for key encryption and A256GCM for content encryption. -func (e *Encryptor) Encrypt(data []byte) (*envelope.EncryptedData, error) { +func (e *Encryptor) Encrypt(ctx context.Context, data []byte) (*envelope.EncryptedData, error) { if len(data) == 0 { return nil, fmt.Errorf("data to encrypt cannot be empty") } + key, err := e.fetcher.FetchKey(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch encryption key: %w", err) + } + // Create headers with the key ID headers := jwe.NewHeaders() - if err := headers.Set("kid", e.keyID); err != nil { + if err := headers.Set("kid", key.KeyID); err != nil { return nil, fmt.Errorf("failed to set key ID header: %w", err) } @@ -71,7 +56,7 @@ func (e *Encryptor) Encrypt(data []byte) (*envelope.EncryptedData, error) { // TODO: in go1.26+, consider using secret.Do to wrap this call, since it will generate an AES key encrypted, err := jwe.Encrypt( data, - jwe.WithKey(jwa.RSA_OAEP_256(), e.publicKey, jwe.WithPerRecipientHeaders(headers)), + jwe.WithKey(jwa.RSA_OAEP_256(), key.Key, jwe.WithPerRecipientHeaders(headers)), jwe.WithContentEncryption(jwa.A256GCM()), jwe.WithCompact(), ) diff --git a/internal/envelope/rsa/encryptor_test.go b/internal/envelope/rsa/encryptor_test.go index 6763c1e5..50534fed 100644 --- a/internal/envelope/rsa/encryptor_test.go +++ b/internal/envelope/rsa/encryptor_test.go @@ -3,9 +3,7 @@ package rsa import ( "crypto/rand" "crypto/rsa" - "crypto/x509" "encoding/base64" - "encoding/pem" "strings" "sync" "testing" @@ -13,21 +11,15 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe" "github.com/stretchr/testify/require" -) -const testKeyID = "test-key-id" + "github.com/jetstack/preflight/internal/envelope/keyfetch" +) -// smallRSAKey1024 is a hardcoded 1024-bit RSA public key in PEM format (PKIX) -// used for testing key size validation. This key is intentionally weak and should -// only be used for testing purposes. -// This is hardcoded rather than generated in order to save compute, and also on the -// assumption that future Go releases might restrict the ability to generate such small keys. -const smallRSAKey1024 = `-----BEGIN PUBLIC KEY----- -MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDCNDoCM0OBt4HFxFxyU50FYsuZ -gK+lgel/Jlzb+ghkWpCL1Vk3Au7aet4KxNxQh5dFRxtMU7pe6fC5eZtdL3+0TCUu -XAUVgMhTRn3ZXlEmJXosuiFQ2y4+3nbWL51OxXRf3jsieSVqr4fbceakuOKXp4vX -wgiguV3/XqaysHs1uwIDAQAB ------END PUBLIC KEY-----` +const ( + testKeyID = "test-key-id" + // minRSAKeySize is the minimum RSA key size used for test key generation + minRSAKeySize = 2048 +) var ( testKeyOnce sync.Once @@ -49,67 +41,10 @@ func testKey() *rsa.PrivateKey { return internalTestKey } -func TestNewEncryptor_ValidKeys(t *testing.T) { - tests := []struct { - name string - keySize int - }{ - {"2048 bits", 2048}, - {"3072 bits", 3072}, - {"4096 bits", 4096}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, tt.keySize) - require.NoError(t, err) - - enc, err := NewEncryptor(testKeyID, &key.PublicKey) - require.NoError(t, err) - require.NotNil(t, enc) - }) - } -} - -func TestNewEncryptor_RejectsSmallKeys(t *testing.T) { - // Parse the hardcoded 1024-bit RSA public key from PEM format - block, _ := pem.Decode([]byte(smallRSAKey1024)) - require.NotNil(t, block, "failed to decode PEM block") - - // NB: a future Go update might restrict the ability to parse small keys; - // if that happens, this test will need to be removed or changed. - pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) - require.NoError(t, err, "failed to parse RSA public key") - - rsaPubKey, ok := pubKey.(*rsa.PublicKey) - require.True(t, ok, "key should be an RSA public key") - - enc, err := NewEncryptor(testKeyID, rsaPubKey) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "must be at least 2048 bits") -} - -func TestNewEncryptor_NilKey(t *testing.T) { - enc, err := NewEncryptor(testKeyID, nil) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "cannot be nil") -} - -func TestNewEncryptor_EmptyKeyID(t *testing.T) { - key := testKey() - - enc, err := NewEncryptor("", &key.PublicKey) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "keyID cannot be empty") -} - func TestEncrypt_VariousDataSizes(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) tests := []struct { @@ -127,7 +62,7 @@ func TestEncrypt_VariousDataSizes(t *testing.T) { _, err := rand.Read(data) require.NoError(t, err) - result, err := enc.Encrypt(data) + result, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, EncryptionType, result.Type, "Type should be JWE-RSA") @@ -152,31 +87,31 @@ func TestEncrypt_VariousDataSizes(t *testing.T) { } func TestEncrypt_EmptyData(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) - result, err := enc.Encrypt([]byte{}) + result, err := enc.Encrypt(t.Context(), []byte{}) require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "cannot be empty") } func TestEncrypt_NonDeterministic(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) data := []byte("test data for encryption") // Encrypt the same data twice - result1, err := enc.Encrypt(data) + result1, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result1.Type, "Type should be JWE-RSA") - result2, err := enc.Encrypt(data) + result2, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result2.Type, "Type should be JWE-RSA") @@ -186,12 +121,13 @@ func TestEncrypt_NonDeterministic(t *testing.T) { func TestEncrypt_JWEFormat(t *testing.T) { key := testKey() + fetcher := keyfetch.NewFakeClientWithKey(testKeyID, &key.PublicKey) - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) data := []byte("test data") - result, err := enc.Encrypt(data) + result, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result.Type, "Type should be JWE-RSA") @@ -203,14 +139,15 @@ func TestEncrypt_JWEFormat(t *testing.T) { func TestEncrypt_DecryptRoundtrip(t *testing.T) { key := testKey() + fetcher := keyfetch.NewFakeClientWithKey(testKeyID, &key.PublicKey) - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) originalData := []byte("test data for roundtrip encryption and decryption") // Encrypt the data - encrypted, err := enc.Encrypt(originalData) + encrypted, err := enc.Encrypt(t.Context(), originalData) require.NoError(t, err) require.Equal(t, EncryptionType, encrypted.Type, "Type should be JWE-RSA") diff --git a/internal/envelope/rsa/keys_test.go b/internal/envelope/rsa/keys_test.go index 83f86e19..1a138a35 100644 --- a/internal/envelope/rsa/keys_test.go +++ b/internal/envelope/rsa/keys_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/jetstack/preflight/internal/envelope/keyfetch" internalrsa "github.com/jetstack/preflight/internal/envelope/rsa" ) @@ -151,13 +152,14 @@ func TestLoadHardcodedPublicKey_CanBeUsedWithEncryptor(t *testing.T) { require.NotNil(t, key) require.NotEmpty(t, uid) - encryptor, err := internalrsa.NewEncryptor(uid, key) + fetcher := keyfetch.NewFakeClientWithKey(uid, key) + encryptor, err := internalrsa.NewEncryptor(fetcher) require.NoError(t, err) require.NotNil(t, encryptor) // Test that the encryptor can encrypt data testData := []byte("test data for encryption") - encryptedData, err := encryptor.Encrypt(testData) + encryptedData, err := encryptor.Encrypt(t.Context(), testData) require.NoError(t, err) require.NotNil(t, encryptedData) require.NotEmpty(t, encryptedData.Data) diff --git a/internal/envelope/types.go b/internal/envelope/types.go index b458f35d..6618ce6c 100644 --- a/internal/envelope/types.go +++ b/internal/envelope/types.go @@ -1,6 +1,9 @@ package envelope -import "encoding/json" +import ( + "context" + "encoding/json" +) // EncryptedData represents encrypted data along with metadata about the encryption type. type EncryptedData struct { @@ -34,5 +37,5 @@ func (ed *EncryptedData) ToMap() map[string]any { type Encryptor interface { // Encrypt encrypts data using envelope encryption, returning an EncryptedData struct // containing the encrypted payload and encryption type metadata. - Encrypt(data []byte) (*EncryptedData, error) + Encrypt(ctx context.Context, data []byte) (*EncryptedData, error) } diff --git a/pkg/agent/run.go b/pkg/agent/run.go index cba3c0a0..e8f1c234 100644 --- a/pkg/agent/run.go +++ b/pkg/agent/run.go @@ -32,6 +32,7 @@ import ( "github.com/jetstack/preflight/api" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" "github.com/jetstack/preflight/internal/envelope/rsa" "github.com/jetstack/preflight/pkg/client" "github.com/jetstack/preflight/pkg/datagatherer" @@ -164,6 +165,20 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { return fmt.Errorf("failed to create event recorder: %v", err) } + // Check if secret encryption is enabled via environment variable + // When enabled, secret data will be kept for encryption instead of being redacted + encryptSecrets := strings.ToLower(os.Getenv("ARK_SEND_SECRET_VALUES")) == "true" + + var encryptor envelope.Encryptor + + if encryptSecrets { + encryptor, err = loadEncryptor(gctx, preflightClient) + if err != nil { + log.Error(err, "Failed to set up encryptor for secrets, secret data will not be sent") + encryptSecrets = false + } + } + dataGatherers := map[string]datagatherer.DataGatherer{} // load datagatherer config and boot each one @@ -184,17 +199,11 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { dynDg.ExcludeAnnotKeys = config.ExcludeAnnotationKeysRegex dynDg.ExcludeLabelKeys = config.ExcludeLabelKeysRegex - // Check if secret encryption is enabled via environment variable - // When enabled, secret data will be kept for encryption instead of being redacted - encryptSecrets := strings.ToLower(os.Getenv("ARK_SEND_SECRET_VALUES")) + gvr := dynDg.GVR() - if encryptSecrets == "true" { - var err error - - dynDg.Encryptor, err = loadEncryptor() - if err != nil { - log.Error(err, "Failed to set up encryptor for secrets, secret data will not be sent") - } + if encryptSecrets && gvr.Resource == "secrets" && gvr.Group == "" { + log.Info("Secret encryption enabled for datagatherer") + dynDg.Encryptor = encryptor } } @@ -273,14 +282,23 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { } // loadEncryptor sets up an encryptor for encrypting secrets. For now, it just loads a hardcoded public key -func loadEncryptor() (envelope.Encryptor, error) { - // TODO(@SgtCoDFish): this will eventually fetch a key from JWKS endpoint when that endpoint is available - key, keyID, err := rsa.LoadHardcodedPublicKey() +func loadEncryptor(ctx context.Context, preflightClient client.Client) (envelope.Encryptor, error) { + cyberarkClient, ok := preflightClient.(*client.CyberArkClient) + if !ok { + return nil, fmt.Errorf("secret encryption is only supported for CyberArk clients") + } + + cfg, err := cyberarkClient.Config() + if err != nil { + return nil, fmt.Errorf("failed to get CyberArk client config: %w", err) + } + + fetcher, err := keyfetch.NewClient(ctx, cyberarkClient.DiscoveryClient(), cfg, nil) if err != nil { - return nil, fmt.Errorf("failed to load public key for secret encryption: %w", err) + return nil, fmt.Errorf("failed to create key fetcher for secret encryption: %w", err) } - encryptor, err := rsa.NewEncryptor(keyID, key) + encryptor, err := rsa.NewEncryptor(fetcher) if err != nil { return nil, fmt.Errorf("failed to create encryptor for secret encryption: %w", err) } diff --git a/pkg/client/client_cyberark.go b/pkg/client/client_cyberark.go index 826ab3a5..3232e7c6 100644 --- a/pkg/client/client_cyberark.go +++ b/pkg/client/client_cyberark.go @@ -28,6 +28,8 @@ import ( type CyberArkClient struct { configLoader cyberark.ClientConfigLoader httpClient *http.Client + + discoveryClient *servicediscovery.Client } var _ Client = &CyberArkClient{} @@ -41,14 +43,15 @@ var _ Client = &CyberArkClient{} func NewCyberArk(httpClient *http.Client) (*CyberArkClient, error) { configLoader := cyberark.LoadClientConfigFromEnvironment - _, err := configLoader() + cfg, err := configLoader() if err != nil { return nil, err } return &CyberArkClient{ - configLoader: configLoader, - httpClient: httpClient, + configLoader: configLoader, + httpClient: httpClient, + discoveryClient: servicediscovery.New(httpClient, cfg.Subdomain), }, nil } @@ -67,9 +70,7 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin return fmt.Errorf("failed to load config: %w", err) } - discoveryClient := servicediscovery.New(o.httpClient) - - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(ctx, cfg.Subdomain) + serviceMap, tenantUUID, err := o.discoveryClient.DiscoverServices(ctx) if err != nil { return err } @@ -95,6 +96,14 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin return nil } +func (o *CyberArkClient) DiscoveryClient() *servicediscovery.Client { + return o.discoveryClient +} + +func (o *CyberArkClient) Config() (cyberark.ClientConfig, error) { + return o.configLoader() +} + // baseSnapshotFromOptions creates a base snapshot with common fields from the provided options. // This includes the cluster name, description, and agent version. // Other fields like ClusterID and K8SVersion need to be populated separately. diff --git a/pkg/datagatherer/k8sdynamic/dynamic.go b/pkg/datagatherer/k8sdynamic/dynamic.go index df490db4..da805b07 100644 --- a/pkg/datagatherer/k8sdynamic/dynamic.go +++ b/pkg/datagatherer/k8sdynamic/dynamic.go @@ -348,6 +348,10 @@ type DataGathererDynamic struct { Encryptor envelope.Encryptor } +func (g *DataGathererDynamic) GVR() schema.GroupVersionResource { + return g.groupVersionResource +} + // Run starts the dynamic data gatherer's informers for resource collection. // Returns error if the data gatherer informer wasn't initialized, Run blocks // until the stopCh is closed. @@ -469,7 +473,7 @@ func (g *DataGathererDynamic) redactList(ctx context.Context, list []*api.Gather // If encryption is enabled, we encrypt the data and preserve it, but we still need to redact later. // If encryption is enabled and _fails_, we MUST still redact the data field to avoid leaking sensitive information. if g.Encryptor != nil { - err := g.encryptDataField(resource) + err := g.encryptDataField(ctx, resource) if err != nil { // WARNING: We CAN NOT return an error here, as that would leak the secret data log := klog.FromContext(ctx).WithName("encryptDataField") @@ -544,7 +548,7 @@ var encryptedDataField = FieldPath{encryptedDataFieldName} // in a new field with the name of [encryptedDataFieldName]. The original `data` field is left unchanged, on the // assumption that it will be redacted after the encryption step. // This function does not check that the given resource is actually a Secret; that is the caller's responsibility. -func (g *DataGathererDynamic) encryptDataField(secret *unstructured.Unstructured) error { +func (g *DataGathererDynamic) encryptDataField(ctx context.Context, secret *unstructured.Unstructured) error { if g.Encryptor == nil { return nil } @@ -569,7 +573,7 @@ func (g *DataGathererDynamic) encryptDataField(secret *unstructured.Unstructured return fmt.Errorf("failed to marshal secret data field for encryption: %w", err) } - encryptedData, err := g.Encryptor.Encrypt(plaintextData) + encryptedData, err := g.Encryptor.Encrypt(ctx, plaintextData) if err != nil { return fmt.Errorf("failed to encrypt secret data during redaction: %w", err) } diff --git a/pkg/datagatherer/k8sdynamic/dynamic_test.go b/pkg/datagatherer/k8sdynamic/dynamic_test.go index 335d6571..9b4651eb 100644 --- a/pkg/datagatherer/k8sdynamic/dynamic_test.go +++ b/pkg/datagatherer/k8sdynamic/dynamic_test.go @@ -1,6 +1,7 @@ package k8sdynamic import ( + "context" "crypto/rand" stdrsa "crypto/rsa" "encoding/base64" @@ -32,6 +33,7 @@ import ( "github.com/jetstack/preflight/api" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" "github.com/jetstack/preflight/internal/envelope/rsa" ) @@ -405,7 +407,7 @@ func init() { type failEncryptor struct{} -func (fe *failEncryptor) Encrypt(plaintext []byte) (*envelope.EncryptedData, error) { +func (fe *failEncryptor) Encrypt(_ context.Context, plaintext []byte) (*envelope.EncryptedData, error) { return nil, fmt.Errorf("encryption failed") } @@ -415,7 +417,8 @@ func TestDynamicGatherer_Fetch(t *testing.T) { keyID := "test-key-id" - encryptor, err := rsa.NewEncryptor(keyID, privKey.Public().(*stdrsa.PublicKey)) + fetcher := keyfetch.NewFakeClientWithKey(keyID, privKey.Public().(*stdrsa.PublicKey)) + encryptor, err := rsa.NewEncryptor(fetcher) if err != nil { t.Fatalf("failed to create encryptor: %v", err) }