diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity.go b/cmd/servers/ateapi/sessionidentity/sessionidentity.go
index 55f3654..115b675 100644
--- a/cmd/servers/ateapi/sessionidentity/sessionidentity.go
+++ b/cmd/servers/ateapi/sessionidentity/sessionidentity.go
@@ -25,6 +25,7 @@ import (
 	"os"
 	"path"
 	"strings"
+	"sync/atomic"
 	"time"
 
 	"github.com/agent-substrate/substrate/internal/k8sjwt"
@@ -46,21 +47,114 @@ type Server struct {
 
 	clientJWTIssuer   string
 	clientJWTAudience string
-	// TODO: Cache the signing keys in memory, so we don't read from a file every time.
-	sessionIDJWTPoolFile string
-	sessionIDCAPoolFile  string
-	workerCACerts        string
+
+	sessionIDJWTPool *fileCache[*localjwtauthority.Pool]
+	sessionIDCAPool  *fileCache[*localca.Pool]
+	workerCACerts    string
+}
+
+// fileCache periodically refreshes a parsed file-backed value, so callers avoid
+// a read+unmarshal on every request while still picking up rare key rotations.
+type fileCache[T any] struct {
+	path  string
+	parse func([]byte) (T, error)
+
+	state atomic.Pointer[fileCacheState[T]]
+}
+
+// fileCacheState is one snapshot published by fileCache: either a successfully
+// parsed value or the error that prevented loading one.
+type fileCacheState[T any] struct {
+	value T
+	err   error
+}
+
+// newFileCache returns a fileCache for path that is reloaded every five
+// minutes. The ticker is never stopped, so the refresh goroutine runs for the
+// rest of the process lifetime.
+func newFileCache[T any](path string, parse func([]byte) (T, error)) *fileCache[T] {
+	return newFileCacheWithTicker(path, time.NewTicker(5*time.Minute).C, parse)
+}
+
+// newFileCacheWithTicker loads path once synchronously, then starts a goroutine
+// that reloads it each time c delivers a tick. A failed initial load is logged
+// rather than returned; get reports the error until a later refresh succeeds.
+// The goroutine exits when c is closed.
+func newFileCacheWithTicker[T any](path string, c <-chan time.Time, parse func([]byte) (T, error)) *fileCache[T] {
+	cache := &fileCache[T]{
+		path:  path,
+		parse: parse,
+	}
+	if err := cache.updateValue(); err != nil {
+		slog.Error("Initial file cache load failed", slog.String("path", path), slog.Any("err", err))
+	}
+	go cache.run(c)
+	return cache
+}
+
+// run reloads the file on every tick and logs the outcome. It returns when
+// tickerChannel is closed.
+func (c *fileCache[T]) run(tickerChannel <-chan time.Time) {
+	for range tickerChannel {
+		if err := c.updateValue(); err != nil {
+			slog.Error("File cache refresh failed", slog.String("path", c.path), slog.Any("err", err))
+		} else {
+			slog.Info("File cache refreshed successfully", slog.String("path", c.path))
+		}
+	}
+}
+
+// updateValue reads and parses the file. On success the new value is published;
+// on failure the error is returned and a copy is recorded through storeErr.
+func (c *fileCache[T]) updateValue() error {
+	b, err := os.ReadFile(c.path)
+	if err != nil {
+		c.storeErr(fmt.Errorf("read %s: %w", c.path, err))
+		return err
+	}
+	v, err := c.parse(b)
+	if err != nil {
+		c.storeErr(err)
+		return err
+	}
+	c.state.Store(&fileCacheState[T]{value: v})
+	return nil
+}
+
+// storeErr publishes err only when no good value is currently cached, so a
+// failed refresh never evicts the last successfully loaded value.
+func (c *fileCache[T]) storeErr(err error) {
+	if state := c.state.Load(); state != nil && state.err == nil {
+		// Don't overwrite a good value with an error.
+		return
+	}
+	c.state.Store(&fileCacheState[T]{err: err})
+}
+
+// get returns the most recently published value, or the recorded load error
+// when no value has been loaded successfully yet.
+func (c *fileCache[T]) get() (T, error) {
+	var zero T
+
+	state := c.state.Load()
+	if state == nil {
+		return zero, fmt.Errorf("value not available")
+	}
+	if state.err != nil {
+		return zero, state.err
+	}
+	return state.value, nil
 }
 
 var _ ateapipb.SessionIdentityServer = (*Server)(nil)
 
 func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string) *Server {
 	return &Server{
-		clientJWTIssuer:      clientJWTIssuer,
-		clientJWTAudience:    clientJWTAudience,
-		sessionIDJWTPoolFile: sessionIDJWTPoolFile,
-		sessionIDCAPoolFile:  sessionIDCAPoolFile,
-		workerCACerts:        workerCACerts,
+		clientJWTIssuer:   clientJWTIssuer,
+		clientJWTAudience: clientJWTAudience,
+		sessionIDJWTPool:  newFileCache(sessionIDJWTPoolFile, localjwtauthority.Unmarshal),
+		sessionIDCAPool:   newFileCache(sessionIDCAPoolFile, localca.Unmarshal),
+		workerCACerts:     workerCACerts,
 	}
 }
 
@@ -90,15 +184,9 @@ func (s *Server) MintJWT(ctx context.Context, req *ateapipb.MintJWTRequest) (*at
 
 	// TODO: Cross-check requested session and user claims against the session database.
 
-	// TODO: Cache signing keys in memory, so we don't read from disk every time.
-	signingPoolBytes, err := os.ReadFile(s.sessionIDJWTPoolFile)
+	signingPool, err := s.sessionIDJWTPool.get()
 	if err != nil {
-		return nil, fmt.Errorf("while reading signing pool bytes: %w", err)
-	}
-
-	signingPool, err := localjwtauthority.Unmarshal(signingPoolBytes)
-	if err != nil {
-		return nil, fmt.Errorf("while unmarshaling signing pool: %w", err)
+		return nil, fmt.Errorf("while loading signing pool: %w", err)
 	}
 
 	// We only issue tokens with audience bindings.
@@ -163,12 +251,7 @@ func (s *Server) MintCert(ctx context.Context, req *ateapipb.MintCertRequest) (*
 	}
 
 	// Load the CA pool for signing
-	poolBytes, err := os.ReadFile(s.sessionIDCAPoolFile)
-	if err != nil {
-		slog.ErrorContext(ctx, "Failed to read session CA pool file", slog.Any("err", err))
-		return nil, status.Errorf(codes.Internal, "Failed to load session CA")
-	}
-	caPool, err := localca.Unmarshal(poolBytes)
+	caPool, err := s.sessionIDCAPool.get()
 	if err != nil || len(caPool.CAs) == 0 {
 		slog.ErrorContext(ctx, "Failed to load session CA", slog.Any("err", err))
 		return nil, status.Errorf(codes.Internal, "Failed to load session CA")
diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go
new file mode 100644
index 0000000..44b29eb
--- /dev/null
+++ b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go
@@ -0,0 +1,199 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sessionidentity
+
+import (
+	"errors"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+// writeFile writes contents to path, failing the test on error.
+func writeFile(t *testing.T, path, contents string) {
+	t.Helper()
+	if err := os.WriteFile(path, []byte(contents), 0o600); err != nil {
+		t.Fatalf("write %s: %v", path, err)
+	}
+}
+
+// newTestFileCache builds a fileCache driven by a manual tick channel, which
+// is returned so tests can trigger refreshes. The channel is closed during
+// test cleanup, which stops the refresh goroutine.
+func newTestFileCache[T any](t *testing.T, path string, parse func([]byte) (T, error)) (*fileCache[T], chan time.Time) {
+	t.Helper()
+	ticks := make(chan time.Time)
+	t.Cleanup(func() {
+		close(ticks)
+	})
+	return newFileCacheWithTicker(path, ticks, parse), ticks
+}
+
+// waitForValue polls c for up to one second until get returns want without
+// error, and fails the test otherwise.
+func waitForValue(t *testing.T, c *fileCache[string], want string) {
+	t.Helper()
+
+	deadline := time.Now().Add(time.Second)
+	for {
+		got, err := c.get()
+		if err == nil && got == want {
+			return
+		}
+		if time.Now().After(deadline) {
+			t.Fatalf("value = %q, err = %v; want value %q", got, err, want)
+		}
+		time.Sleep(time.Millisecond)
+	}
+}
+
+func TestFileCache_LoadsOnCreation(t *testing.T) {
+	dir := t.TempDir()
+	p := filepath.Join(dir, "f")
+	writeFile(t, p, "hello")
+
+	calls := 0
+	c, _ := newTestFileCache(t, p, func(b []byte) (string, error) {
+		calls++
+		return string(b), nil
+	})
+
+	got, err := c.get()
+	if err != nil {
+		t.Fatalf("get: %v", err)
+	}
+	if got != "hello" {
+		t.Fatalf("value = %q, want %q", got, "hello")
+	}
+	if calls != 1 {
+		t.Fatalf("parse calls = %d, want 1", calls)
+	}
+}
+
+func TestFileCache_DoesNotReloadOnGet(t *testing.T) {
+	dir := t.TempDir()
+	p := filepath.Join(dir, "f")
+	writeFile(t, p, "v1")
+
+	calls := 0
+	c, _ := newTestFileCache(t, p, func(b []byte) (string, error) {
+		calls++
+		return string(b), nil
+	})
+
+	writeFile(t, p, "v2")
+
+	for range 5 {
+		v, err := c.get()
+		if err != nil {
+			t.Fatalf("get: %v", err)
+		}
+		if v != "v1" {
+			t.Fatalf("value = %q, want v1", v)
+		}
+	}
+	if calls != 1 {
+		t.Fatalf("parse calls = %d, want 1", calls)
+	}
+}
+
+func TestFileCache_ReloadsOnTick(t *testing.T) {
+	dir := t.TempDir()
+	p := filepath.Join(dir, "f")
+	writeFile(t, p, "v1")
+
+	calls := 0
+	c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) {
+		calls++
+		return string(b), nil
+	})
+
+	if _, err := c.get(); err != nil {
+		t.Fatalf("get v1: %v", err)
+	}
+
+	writeFile(t, p, "v2")
+	ticks <- time.Now()
+
+	waitForValue(t, c, "v2")
+	if calls != 2 {
+		t.Fatalf("parse calls = %d, want 2", calls)
+	}
+}
+
+func TestFileCache_KeepsLastValueWhenRefreshFails(t *testing.T) {
+	dir := t.TempDir()
+	p := filepath.Join(dir, "f")
+	writeFile(t, p, "v1")
+
+	want := errors.New("boom")
+	fail := false
+	refreshAttempted := make(chan struct{}, 1)
+	c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) {
+		if fail {
+			refreshAttempted <- struct{}{}
+			return "", want
+		}
+		return string(b), nil
+	})
+
+	writeFile(t, p, "v2")
+	fail = true
+	ticks <- time.Now()
+
+	select {
+	case <-refreshAttempted:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for refresh attempt")
+	}
+
+	// parse signals refreshAttempted before updateValue has recorded the
+	// failure. ticks is unbuffered, so this second send is only accepted once
+	// the refresh goroutine has finished handling the first tick; the
+	// assertions below therefore observe the post-failure state.
+	ticks <- time.Now()
+
+	got, err := c.get()
+	if err != nil {
+		t.Fatalf("get: %v", err)
+	}
+	if got != "v1" {
+		t.Fatalf("value = %q, want v1", got)
+	}
+}
+
+func TestFileCache_PropagatesParseError(t *testing.T) {
+	dir := t.TempDir()
+	p := filepath.Join(dir, "f")
+	writeFile(t, p, "bad")
+
+	want := errors.New("boom")
+	c, _ := newTestFileCache(t, p, func(b []byte) (string, error) { return "", want })
+
+	_, err := c.get()
+	if !errors.Is(err, want) {
+		t.Fatalf("err = %v, want %v", err, want)
+	}
+}
+
+func TestFileCache_MissingFile(t *testing.T) {
+	c, _ := newTestFileCache(t, filepath.Join(t.TempDir(), "nope"), func(b []byte) (string, error) {
+		return string(b), nil
+	})
+	if _, err := c.get(); err == nil {
+		t.Fatalf("expected error for missing file")
+	}
+}