Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hack/ark/test-e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ kubectl apply -f "${root_dir}/hack/ark/cluster-external-secret.yaml"

# We use a non-existent tag and omit the `--version` flag, to work around a Helm
# v4 bug. See: https://github.com/helm/helm/issues/31600
# TODO: shouldn't need to set config.sendSecretValues because it will default to true in future
helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \
--install \
--wait \
Expand All @@ -113,6 +114,7 @@ helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \
--set config.clusterName="e2e-test-cluster" \
--set config.clusterDescription="A temporary cluster for E2E testing. Contact @wallrj-cyberark." \
--set config.period=60s \
--set config.sendSecretValues=true \
--set-json "podLabels={\"disco-agent.cyberark.cloud/test-id\": \"${RANDOM}\"}"

kubectl rollout status deployments/disco-agent --namespace "${NAMESPACE}"
Expand Down
8 changes: 4 additions & 4 deletions internal/cyberark/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) {
Secret: "somepassword",
}

discoveryClient := servicediscovery.New(httpClient)
discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain)

serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain)
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context())
if err != nil {
t.Fatalf("failed to discover mock services: %v", err)
}
Expand Down Expand Up @@ -76,9 +76,9 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) {
cfg, err := cyberark.LoadClientConfigFromEnvironment()
require.NoError(t, err)

discoveryClient := servicediscovery.New(httpClient)
discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain)

serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain)
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context())
if err != nil {
t.Fatalf("failed to discover services: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/cyberark/identity/cmd/testidentity/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func run(ctx context.Context) error {
var rootCAs *x509.CertPool
httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs)

sdClient := servicediscovery.New(httpClient)
services, _, err := sdClient.DiscoverServices(ctx, subdomain)
sdClient := servicediscovery.New(httpClient, subdomain)
services, _, err := sdClient.DiscoverServices(ctx)
if err != nil {
return fmt.Errorf("while performing service discovery: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/cyberark/identity/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestLoginUsernamePassword_RealAPI(t *testing.T) {
arktesting.SkipIfNoEnv(t)
subdomain := os.Getenv("ARK_SUBDOMAIN")
httpClient := http.DefaultClient
services, _, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain)
services, _, err := servicediscovery.New(httpClient, subdomain).DiscoverServices(t.Context())
require.NoError(t, err)

loginUsernamePasswordTests(t, func(t testing.TB) inputs {
Expand Down
50 changes: 39 additions & 11 deletions internal/cyberark/servicediscovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"os"
"path"
"sync"
"time"

arkapi "github.com/jetstack/preflight/internal/cyberark/api"
"github.com/jetstack/preflight/pkg/version"
Expand All @@ -35,21 +37,34 @@ const (
// users to fetch URLs for various APIs available in CyberArk. This client is specialised to
// fetch only API endpoints, since only API endpoints are required by the Venafi Kubernetes Agent currently.
type Client struct {
client *http.Client
baseURL string
client *http.Client
baseURL string
subdomain string

cachedResponse *Services
cachedTenantID string
cachedResponseTime time.Time
cachedResponseMutex sync.Mutex
}

// New creates a new CyberArk Service Discovery client. If the ARK_DISCOVERY_API
// environment variable is set, it is used as the base URL for the service
// discovery API. Otherwise, the production URL is used.
func New(httpClient *http.Client) *Client {
func New(httpClient *http.Client, subdomain string) *Client {
baseURL := os.Getenv("ARK_DISCOVERY_API")
if baseURL == "" {
baseURL = ProdDiscoveryAPIBaseURL
}

client := &Client{
client: httpClient,
baseURL: baseURL,
client: httpClient,
baseURL: baseURL,
subdomain: subdomain,

cachedResponse: nil,
cachedTenantID: "",
cachedResponseTime: time.Time{},
cachedResponseMutex: sync.Mutex{},
}

return client
Expand Down Expand Up @@ -93,17 +108,24 @@ type Services struct {
DiscoveryContext ServiceEndpoint
}

// DiscoverServices fetches from the service discovery service for a given subdomain
// DiscoverServices fetches from the service discovery service for the configured subdomain
// and parses the CyberArk Identity API URL and Inventory API URL.
// It also returns the Tenant ID UUID corresponding to the subdomain.
func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, string, error) {
func (c *Client) DiscoverServices(ctx context.Context) (*Services, string, error) {
c.cachedResponseMutex.Lock()
defer c.cachedResponseMutex.Unlock()

if c.cachedResponse != nil && time.Since(c.cachedResponseTime) < 1*time.Hour {
return c.cachedResponse, c.cachedTenantID, nil
}

u, err := url.Parse(c.baseURL)
if err != nil {
return nil, "", fmt.Errorf("invalid base URL for service discovery: %w", err)
}

u.Path = path.Join(u.Path, "api/public/tenant-discovery")
u.RawQuery = url.Values{"bySubdomain": []string{subdomain}}.Encode()
u.RawQuery = url.Values{"bySubdomain": []string{c.subdomain}}.Encode()

endpoint := u.String()

Expand All @@ -127,7 +149,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi
// a 404 error is returned with an empty JSON body "{}" if the subdomain is unknown; at the time of writing, we haven't observed
// any other errors and so we can't special case them
if resp.StatusCode == http.StatusNotFound {
return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain)
return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", c.subdomain)
}

return nil, "", fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status)
Expand Down Expand Up @@ -167,8 +189,14 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi
}
//TODO: Should add a check for discoveryContextAPI too?

return &Services{
services := &Services{
Identity: ServiceEndpoint{API: identityAPI},
DiscoveryContext: ServiceEndpoint{API: discoveryContextAPI},
}, discoveryResp.TenantID, nil
}

c.cachedResponse = services
c.cachedTenantID = discoveryResp.TenantID
c.cachedResponseTime = time.Now()

return services, discoveryResp.TenantID, nil
}
4 changes: 2 additions & 2 deletions internal/cyberark/servicediscovery/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ func Test_DiscoverIdentityAPIURL(t *testing.T) {
},
})

client := New(httpClient)
client := New(httpClient, testSpec.subdomain)

services, _, err := client.DiscoverServices(ctx, testSpec.subdomain)
services, _, err := client.DiscoverServices(ctx)
if testSpec.expectedError != nil {
assert.EqualError(t, err, testSpec.expectedError.Error())
assert.Nil(t, services)
Expand Down
175 changes: 175 additions & 0 deletions internal/envelope/keyfetch/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package keyfetch

import (
"context"
"crypto/rsa"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/url"

"github.com/jetstack/venafi-connection-lib/http_client"
"github.com/lestrrat-go/jwx/v3/jwk"

"github.com/jetstack/preflight/internal/cyberark"
"github.com/jetstack/preflight/internal/cyberark/identity"
"github.com/jetstack/preflight/internal/cyberark/servicediscovery"
"github.com/jetstack/preflight/pkg/version"
)

const (
// minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor
// to enforce to ensure that a weak key can't accidentally be used
minRSAKeySize = 2048
)

// KeyFetcher is an interface for fetching public keys.
type KeyFetcher interface {
// FetchKey retrieves a public key from the key source.
FetchKey(ctx context.Context) (PublicKey, error)
}

// Compile-time check that Client implements KeyFetcher
var _ KeyFetcher = (*Client)(nil)

// PublicKey represents an RSA public key retrieved from the key server.
type PublicKey struct {
// KeyID is the unique identifier for this key
KeyID string

// Key is the actual RSA public key
Key *rsa.PublicKey
}

// Client fetches public keys from a CyberArk HTTP endpoint that provides keys in JWKS format.
// It can be expanded in future to support other key types and formats, but for now it only supports RSA keys
// and ignored other types.
type Client struct {
discoveryClient *servicediscovery.Client
cfg cyberark.ClientConfig

// httpClient is the HTTP client used for requests
httpClient *http.Client
}

// NewClient creates a new key fetching client.
// Uses CyberArk service discovery to derive the JWKS endpoint
func NewClient(discoveryClient *servicediscovery.Client, cfg cyberark.ClientConfig) *Client {
var rootCAs *x509.CertPool

return &Client{
discoveryClient: discoveryClient,
cfg: cfg,
httpClient: http_client.NewDefaultClient(version.UserAgent(), rootCAs),
}
}

// FetchKey retrieves the public keys from the configured endpoint.
// It returns a slice of PublicKey structs containing the key material and metadata.
func (c *Client) FetchKey(ctx context.Context) (PublicKey, error) {
services, _, err := c.discoveryClient.DiscoverServices(ctx)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to get services from discovery client: %w", err)
}

identityClient := identity.New(c.httpClient, services.Identity.API, c.cfg.Subdomain)

err = identityClient.LoginUsernamePassword(ctx, c.cfg.Username, []byte(c.cfg.Secret))
if err != nil {
return PublicKey{}, fmt.Errorf("failed to authenticate for fetching JWKs: %w", err)
}

endpoint, err := url.JoinPath(services.DiscoveryContext.API, "discovery-context/jwks")
if err != nil {
return PublicKey{}, fmt.Errorf("failed to construct endpoint URL: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to create request: %w", err)
}

_, err = identityClient.AuthenticateRequest(req)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to authenticate request: %s", err)
}

req.Header.Set("Accept", "application/json")
version.SetUserAgent(req)

resp, err := c.httpClient.Do(req)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to fetch keys from %s: %w", endpoint, err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return PublicKey{}, fmt.Errorf("unexpected status code %d from %s: %s", resp.StatusCode, endpoint, string(body))
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to read response body: %w", err)
}

keySet, err := jwk.Parse(body)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to parse JWKs response: %w", err)
}

for i := range keySet.Len() {
key, ok := keySet.Key(i)
if !ok {
continue
}

// Only process RSA keys
if key.KeyType().String() != "RSA" {
continue
}

var rawKey any
if err := jwk.Export(key, &rawKey); err != nil {
// skip unparseable keys
continue
}

rsaKey, ok := rawKey.(*rsa.PublicKey)
if !ok {
// only process RSA keys (for now)
continue
}

if rsaKey.N.BitLen() < minRSAKeySize {
// skip keys that are too small to be secure
continue
}

kid, ok := key.KeyID()
if !ok {
// skip any keys which don't have an ID
continue
}

alg, ok := key.Algorithm()
if !ok {
// skip any keys which don't have an algorithm specified
continue
}

if alg.String() != "RSA-OAEP-256" {
// we only use RSA keys for RSA-OAEP-256
continue
}

// return the first valid key we find
return PublicKey{
KeyID: kid,
Key: rsaKey,
}, nil
}

return PublicKey{}, fmt.Errorf("no valid RSA keys found at %s", endpoint)
}
Loading