Skip to content

Commit 1a1ca51

Browse files
committed
add support for fetching keys from a JWKS endpoint
This requires changing a few function signatures and plumbing some things together. Notably, I don't want to have a second service discovery client and send duplicate calls off, so I shared the service discovery client from the CyberArk client and added caching of responses to the service discovery client. Signed-off-by: Ashley Davis <ashley.davis@cyberark.com>
1 parent 6810e77 commit 1a1ca51

18 files changed

Lines changed: 726 additions & 162 deletions

File tree

internal/cyberark/client_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) {
3232
Secret: "somepassword",
3333
}
3434

35-
discoveryClient := servicediscovery.New(httpClient)
35+
discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain)
3636

37-
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain)
37+
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context())
3838
if err != nil {
3939
t.Fatalf("failed to discover mock services: %v", err)
4040
}
@@ -76,9 +76,9 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) {
7676
cfg, err := cyberark.LoadClientConfigFromEnvironment()
7777
require.NoError(t, err)
7878

79-
discoveryClient := servicediscovery.New(httpClient)
79+
discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain)
8080

81-
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain)
81+
serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context())
8282
if err != nil {
8383
t.Fatalf("failed to discover services: %v", err)
8484
}

internal/cyberark/identity/cmd/testidentity/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ func run(ctx context.Context) error {
5050
var rootCAs *x509.CertPool
5151
httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs)
5252

53-
sdClient := servicediscovery.New(httpClient)
54-
services, _, err := sdClient.DiscoverServices(ctx, subdomain)
53+
sdClient := servicediscovery.New(httpClient, subdomain)
54+
services, _, err := sdClient.DiscoverServices(ctx)
5555
if err != nil {
5656
return fmt.Errorf("while performing service discovery: %s", err)
5757
}

internal/cyberark/identity/identity_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func TestLoginUsernamePassword_RealAPI(t *testing.T) {
5353
arktesting.SkipIfNoEnv(t)
5454
subdomain := os.Getenv("ARK_SUBDOMAIN")
5555
httpClient := http.DefaultClient
56-
services, _, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain)
56+
services, _, err := servicediscovery.New(httpClient, subdomain).DiscoverServices(t.Context())
5757
require.NoError(t, err)
5858

5959
loginUsernamePasswordTests(t, func(t testing.TB) inputs {

internal/cyberark/servicediscovery/discovery.go

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"net/url"
1010
"os"
1111
"path"
12+
"sync"
13+
"time"
1214

1315
arkapi "github.com/jetstack/preflight/internal/cyberark/api"
1416
"github.com/jetstack/preflight/pkg/version"
@@ -35,21 +37,34 @@ const (
3537
// users to fetch URLs for various APIs available in CyberArk. This client is specialised to
3638
// fetch only API endpoints, since only API endpoints are required by the Venafi Kubernetes Agent currently.
3739
type Client struct {
38-
client *http.Client
39-
baseURL string
40+
client *http.Client
41+
baseURL string
42+
subdomain string
43+
44+
cachedResponse *Services
45+
cachedTenantID string
46+
cachedResponseTime time.Time
47+
cachedResponseMutex sync.Mutex
4048
}
4149

4250
// New creates a new CyberArk Service Discovery client. If the ARK_DISCOVERY_API
4351
// environment variable is set, it is used as the base URL for the service
4452
// discovery API. Otherwise, the production URL is used.
45-
func New(httpClient *http.Client) *Client {
53+
func New(httpClient *http.Client, subdomain string) *Client {
4654
baseURL := os.Getenv("ARK_DISCOVERY_API")
4755
if baseURL == "" {
4856
baseURL = ProdDiscoveryAPIBaseURL
4957
}
58+
5059
client := &Client{
51-
client: httpClient,
52-
baseURL: baseURL,
60+
client: httpClient,
61+
baseURL: baseURL,
62+
subdomain: subdomain,
63+
64+
cachedResponse: nil,
65+
cachedTenantID: "",
66+
cachedResponseTime: time.Time{},
67+
cachedResponseMutex: sync.Mutex{},
5368
}
5469

5570
return client
@@ -93,17 +108,24 @@ type Services struct {
93108
DiscoveryContext ServiceEndpoint
94109
}
95110

96-
// DiscoverServices fetches from the service discovery service for a given subdomain
111+
// DiscoverServices fetches from the service discovery service for the configured subdomain
97112
// and parses the CyberArk Identity API URL and Inventory API URL.
98113
// It also returns the Tenant ID UUID corresponding to the subdomain.
99-
func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, string, error) {
114+
func (c *Client) DiscoverServices(ctx context.Context) (*Services, string, error) {
115+
c.cachedResponseMutex.Lock()
116+
defer c.cachedResponseMutex.Unlock()
117+
118+
if c.cachedResponse != nil && time.Since(c.cachedResponseTime) < 1*time.Hour {
119+
return c.cachedResponse, c.cachedTenantID, nil
120+
}
121+
100122
u, err := url.Parse(c.baseURL)
101123
if err != nil {
102124
return nil, "", fmt.Errorf("invalid base URL for service discovery: %w", err)
103125
}
104126

105127
u.Path = path.Join(u.Path, "api/public/tenant-discovery")
106-
u.RawQuery = url.Values{"bySubdomain": []string{subdomain}}.Encode()
128+
u.RawQuery = url.Values{"bySubdomain": []string{c.subdomain}}.Encode()
107129

108130
endpoint := u.String()
109131

@@ -127,7 +149,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi
127149
// a 404 error is returned with an empty JSON body "{}" if the subdomain is unknown; at the time of writing, we haven't observed
128150
// any other errors and so we can't special case them
129151
if resp.StatusCode == http.StatusNotFound {
130-
return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain)
152+
return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", c.subdomain)
131153
}
132154

133155
return nil, "", fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status)
@@ -167,8 +189,14 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi
167189
}
168190
//TODO: Should add a check for discoveryContextAPI too?
169191

170-
return &Services{
192+
services := &Services{
171193
Identity: ServiceEndpoint{API: identityAPI},
172194
DiscoveryContext: ServiceEndpoint{API: discoveryContextAPI},
173-
}, discoveryResp.TenantID, nil
195+
}
196+
197+
c.cachedResponse = services
198+
c.cachedTenantID = discoveryResp.TenantID
199+
c.cachedResponseTime = time.Now()
200+
201+
return services, discoveryResp.TenantID, nil
174202
}

internal/cyberark/servicediscovery/discovery_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ func Test_DiscoverIdentityAPIURL(t *testing.T) {
6464
},
6565
})
6666

67-
client := New(httpClient)
67+
client := New(httpClient, testSpec.subdomain)
6868

69-
services, _, err := client.DiscoverServices(ctx, testSpec.subdomain)
69+
services, _, err := client.DiscoverServices(ctx)
7070
if testSpec.expectedError != nil {
7171
assert.EqualError(t, err, testSpec.expectedError.Error())
7272
assert.Nil(t, services)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package keyfetch
2+
3+
import (
4+
"context"
5+
"crypto/rsa"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"time"
11+
12+
"github.com/lestrrat-go/jwx/v3/jwk"
13+
14+
"github.com/jetstack/preflight/internal/cyberark/servicediscovery"
15+
)
16+
17+
const (
18+
// minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor
19+
// to enforce to ensure that a weak key can't accidentally be used
20+
minRSAKeySize = 2048
21+
)
22+
23+
// KeyFetcher is an interface for fetching public keys.
24+
type KeyFetcher interface {
25+
// FetchKey retrieves a public key from the key source.
26+
FetchKey(ctx context.Context) (PublicKey, error)
27+
}
28+
29+
// Compile-time check that Client implements KeyFetcher
30+
var _ KeyFetcher = (*Client)(nil)
31+
32+
// PublicKey represents an RSA public key retrieved from the key server.
33+
type PublicKey struct {
34+
// KeyID is the unique identifier for this key
35+
KeyID string
36+
37+
// Key is the actual RSA public key
38+
Key *rsa.PublicKey
39+
}
40+
41+
// Client fetches public keys from a CyberArk HTTP endpoint that provides keys in JWKS format.
42+
// It can be expanded in future to support other key types and formats, but for now it only supports RSA keys
43+
// and ignored other types.
44+
type Client struct {
45+
discoveryClient *servicediscovery.Client
46+
47+
// httpClient is the HTTP client used for requests
48+
httpClient *http.Client
49+
}
50+
51+
// NewClient creates a new key fetching client.
52+
// Uses CyberArk service discovery to derive the JWKS endpoint
53+
func NewClient(discoveryClient *servicediscovery.Client) *Client {
54+
return &Client{
55+
discoveryClient: discoveryClient,
56+
httpClient: &http.Client{
57+
Timeout: 10 * time.Second,
58+
},
59+
}
60+
}
61+
62+
// FetchKey retrieves the public keys from the configured endpoint.
63+
// It returns a slice of PublicKey structs containing the key material and metadata.
64+
func (c *Client) FetchKey(ctx context.Context) (PublicKey, error) {
65+
services, _, err := c.discoveryClient.DiscoverServices(ctx)
66+
if err != nil {
67+
return PublicKey{}, fmt.Errorf("failed to get services from discovery client: %w", err)
68+
}
69+
70+
endpoint, err := url.JoinPath(services.DiscoveryContext.API, "foo")
71+
if err != nil {
72+
return PublicKey{}, fmt.Errorf("failed to construct endpoint URL: %w", err)
73+
}
74+
75+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
76+
if err != nil {
77+
return PublicKey{}, fmt.Errorf("failed to create request: %w", err)
78+
}
79+
80+
req.Header.Set("Accept", "application/json")
81+
82+
resp, err := c.httpClient.Do(req)
83+
if err != nil {
84+
return PublicKey{}, fmt.Errorf("failed to fetch keys from %s: %w", endpoint, err)
85+
}
86+
defer resp.Body.Close()
87+
88+
if resp.StatusCode != http.StatusOK {
89+
body, _ := io.ReadAll(resp.Body)
90+
return PublicKey{}, fmt.Errorf("unexpected status code %d from %s: %s", resp.StatusCode, endpoint, string(body))
91+
}
92+
93+
body, err := io.ReadAll(resp.Body)
94+
if err != nil {
95+
return PublicKey{}, fmt.Errorf("failed to read response body: %w", err)
96+
}
97+
98+
keySet, err := jwk.Parse(body)
99+
if err != nil {
100+
return PublicKey{}, fmt.Errorf("failed to parse JWKs response: %w", err)
101+
}
102+
103+
for i := range keySet.Len() {
104+
key, ok := keySet.Key(i)
105+
if !ok {
106+
continue
107+
}
108+
109+
// Only process RSA keys
110+
if key.KeyType().String() != "RSA" {
111+
continue
112+
}
113+
114+
var rawKey any
115+
if err := jwk.Export(key, &rawKey); err != nil {
116+
// skip unparseable keys
117+
continue
118+
}
119+
120+
rsaKey, ok := rawKey.(*rsa.PublicKey)
121+
if !ok {
122+
// only process RSA keys (for now)
123+
continue
124+
}
125+
126+
if rsaKey.N.BitLen() < minRSAKeySize {
127+
// skip keys that are too small to be secure
128+
continue
129+
}
130+
131+
kid, ok := key.KeyID()
132+
if !ok {
133+
// skip any keys which don't have an ID
134+
continue
135+
}
136+
137+
alg, ok := key.Algorithm()
138+
if !ok {
139+
// skip any keys which don't have an algorithm specified
140+
continue
141+
}
142+
143+
if alg.String() != "RSA-OAEP-256" {
144+
// we only use RSA keys for RSA-OAEP-256
145+
continue
146+
}
147+
148+
// return the first valid key we find
149+
return PublicKey{
150+
KeyID: kid,
151+
Key: rsaKey,
152+
}, nil
153+
}
154+
155+
return PublicKey{}, fmt.Errorf("no valid RSA keys found at %s", endpoint)
156+
}

0 commit comments

Comments
 (0)