From 2cd3e26e53c03d7bce3cb9142f080f360fbaaaca Mon Sep 17 00:00:00 2001 From: Yuval Kohavi Date: Wed, 20 May 2026 16:11:54 -0400 Subject: [PATCH 1/2] fix: use a cache instead of reading from disk every time For sessionIDJWTPoolFile, sessionIDCAPoolFile check modification time, and if not changed, use cached value. --- .../ateapi/sessionidentity/sessionidentity.go | 81 +++++--- .../sessionidentity/sessionidentity_test.go | 178 ++++++++++++++++++ 2 files changed, 236 insertions(+), 23 deletions(-) create mode 100644 cmd/servers/ateapi/sessionidentity/sessionidentity_test.go diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity.go b/cmd/servers/ateapi/sessionidentity/sessionidentity.go index 55f3654..f197794 100644 --- a/cmd/servers/ateapi/sessionidentity/sessionidentity.go +++ b/cmd/servers/ateapi/sessionidentity/sessionidentity.go @@ -25,6 +25,7 @@ import ( "os" "path" "strings" + "sync" "time" "github.com/agent-substrate/substrate/internal/k8sjwt" @@ -46,22 +47,67 @@ type Server struct { clientJWTIssuer string clientJWTAudience string - // TODO: Cache the signing keys in memory, so we don't read from a file every time. - sessionIDJWTPoolFile string - sessionIDCAPoolFile string - workerCACerts string + + sessionIDJWTPool mtimeCache[*localjwtauthority.Pool] + sessionIDCAPool mtimeCache[*localca.Pool] +} + +// mtimeCache holds a parsed value alongside the mod-time of the file it was +// parsed from. It re-parses only when the file mod-time changes on disk, so callers +// avoid a read+unmarshal on every request while still picking up rotations. +type mtimeCache[T any] struct { + path string + parse func([]byte) (T, error) + + mu sync.Mutex + mtime time.Time + value *T +} + +func newMtimeCache[T any](path string, parse func([]byte) (T, error)) mtimeCache[T] { + return mtimeCache[T]{ + path: path, + parse: parse, + } +} + +func (c *mtimeCache[T]) get() (T, error) { + c.mu.Lock() + defer c.mu.Unlock() + + info, err := os.Stat(c.path) + if err != nil { + var zero T + return zero, fmt.Errorf("stat %s: %w", c.path, err) + } + if c.value != nil && info.ModTime().Equal(c.mtime) { + return *c.value, nil + } + b, err := os.ReadFile(c.path) + if err != nil { + var zero T + return zero, fmt.Errorf("read %s: %w", c.path, err) + } + v, err := c.parse(b) + if err != nil { + var zero T + return zero, err + } + c.value = &v + c.mtime = info.ModTime() + return v, nil } var _ ateapipb.SessionIdentityServer = (*Server)(nil) func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string) *Server { return &Server{ - clientJWTIssuer: clientJWTIssuer, - clientJWTAudience: clientJWTAudience, - sessionIDJWTPoolFile: sessionIDJWTPoolFile, - sessionIDCAPoolFile: sessionIDCAPoolFile, - workerCACerts: workerCACerts, + clientJWTIssuer: clientJWTIssuer, + clientJWTAudience: clientJWTAudience, + sessionIDJWTPool: newMtimeCache(sessionIDJWTPoolFile, localjwtauthority.Unmarshal), + sessionIDCAPool: newMtimeCache(sessionIDCAPoolFile, localca.Unmarshal), + workerCACerts: workerCACerts, } } @@ -90,15 +136,9 @@ func (s *Server) MintJWT(ctx context.Context, req *ateapipb.MintJWTRequest) (*at // TODO: Cross-check requested session and user claims against the session database. - // TODO: Cache signing keys in memory, so we don't read from disk every time. - signingPoolBytes, err := os.ReadFile(s.sessionIDJWTPoolFile) + signingPool, err := s.sessionIDJWTPool.get() if err != nil { - return nil, fmt.Errorf("while reading signing pool bytes: %w", err) - } - - signingPool, err := localjwtauthority.Unmarshal(signingPoolBytes) - if err != nil { - return nil, fmt.Errorf("while unmarshaling signing pool: %w", err) + return nil, fmt.Errorf("while loading signing pool: %w", err) } // We only issue tokens with audience bindings. @@ -163,12 +203,7 @@ func (s *Server) MintCert(ctx context.Context, req *ateapipb.MintCertRequest) (* } // Load the CA pool for signing - poolBytes, err := os.ReadFile(s.sessionIDCAPoolFile) - if err != nil { - slog.ErrorContext(ctx, "Failed to read session CA pool file", slog.Any("err", err)) - return nil, status.Errorf(codes.Internal, "Failed to load session CA") - } - caPool, err := localca.Unmarshal(poolBytes) + caPool, err := s.sessionIDCAPool.get() if err != nil || len(caPool.CAs) == 0 { slog.ErrorContext(ctx, "Failed to load session CA", slog.Any("err", err)) return nil, status.Errorf(codes.Internal, "Failed to load session CA") diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go new file mode 100644 index 0000000..9a52056 --- /dev/null +++ b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go @@ -0,0 +1,178 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessionidentity + +import ( + "errors" + "os" + "path/filepath" + "testing" + "time" +) + +func writeFile(t *testing.T, path, contents string, mtime time.Time) { + t.Helper() + if err := os.WriteFile(path, []byte(contents), 0o600); err != nil { + t.Fatalf("write %s: %v", path, err) + } + if err := os.Chtimes(path, mtime, mtime); err != nil { + t.Fatalf("chtimes %s: %v", path, err) + } +} + +func TestMtimeCache_LoadsOnFirstCall(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + writeFile(t, p, "hello", time.Now()) + + calls := 0 + c := newMtimeCache(p, func(b []byte) (string, error) { + calls++ + return string(b), nil + }) + + got, err := c.get() + if err != nil { + t.Fatalf("get: %v", err) + } + if got != "hello" { + t.Fatalf("value = %q, want %q", got, "hello") + } + if calls != 1 { + t.Fatalf("parse calls = %d, want 1", calls) + } +} + +func TestMtimeCache_CachesWhenMtimeUnchanged(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + mt := time.Now().Add(-time.Hour) + writeFile(t, p, "v1", mt) + + calls := 0 + c := newMtimeCache(p, func(b []byte) (string, error) { + calls++ + return string(b), nil + }) + + for range 5 { + v, err := c.get() + if err != nil { + t.Fatalf("get: %v", err) + } + if v != "v1" { + t.Fatalf("value = %q, want v1", v) + } + } + if calls != 1 { + t.Fatalf("parse calls = %d, want 1 (cache miss on unchanged file)", calls) + } +} + +func TestMtimeCache_ReloadsWhenMtimeChanges(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + mt := time.Now().Add(-time.Hour) + writeFile(t, p, "v1", mt) + + calls := 0 + c := newMtimeCache(p, func(b []byte) (string, error) { + calls++ + return string(b), nil + }) + + if _, err := c.get(); err != nil { + t.Fatalf("get v1: %v", err) + } + + // Simulate a kubelet AtomicWriter rotation: new contents, new mtime. + writeFile(t, p, "v2", mt.Add(time.Minute)) + + got, err := c.get() + if err != nil { + t.Fatalf("get v2: %v", err) + } + if got != "v2" { + t.Fatalf("value = %q, want v2", got) + } + if calls != 2 { + t.Fatalf("parse calls = %d, want 2", calls) + } +} + +func TestMtimeCache_ReloadsViaSymlinkSwap(t *testing.T) { + // Mimics ConfigMap/Secret mounts: stable file path is a symlink whose + // target is replaced atomically on rotation. + dir := t.TempDir() + v1 := filepath.Join(dir, "v1") + v2 := filepath.Join(dir, "v2") + link := filepath.Join(dir, "current") + + writeFile(t, v1, "v1", time.Now().Add(-time.Hour)) + writeFile(t, v2, "v2", time.Now()) + if err := os.Symlink(v1, link); err != nil { + t.Fatalf("symlink: %v", err) + } + + c := newMtimeCache(link, func(b []byte) (string, error) { return string(b), nil }) + + got, err := c.get() + if err != nil { + t.Fatalf("get v1: %v", err) + } + if got != "v1" { + t.Fatalf("value = %q, want v1", got) + } + + // Atomic-style swap: write new symlink, rename over the old one. + tmp := link + ".new" + if err := os.Symlink(v2, tmp); err != nil { + t.Fatalf("symlink new: %v", err) + } + if err := os.Rename(tmp, link); err != nil { + t.Fatalf("rename: %v", err) + } + + got, err = c.get() + if err != nil { + t.Fatalf("get v2: %v", err) + } + if got != "v2" { + t.Fatalf("value = %q, want v2 after rotation", got) + } +} + +func TestMtimeCache_PropagatesParseError(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + writeFile(t, p, "bad", time.Now()) + + want := errors.New("boom") + c := newMtimeCache(p, func(b []byte) (string, error) { return "", want }) + + _, err := c.get() + if !errors.Is(err, want) { + t.Fatalf("err = %v, want %v", err, want) + } +} + +func TestMtimeCache_MissingFile(t *testing.T) { + c := newMtimeCache(filepath.Join(t.TempDir(), "nope"), func(b []byte) (string, error) { + return string(b), nil + }) + if _, err := c.get(); err == nil { + t.Fatalf("expected error for missing file") + } +} From b48dd1e9e9cca95a87b3971327ce45c847f8e75a Mon Sep 17 00:00:00 2001 From: Yuval Kohavi Date: Fri, 22 May 2026 17:31:46 -0400 Subject: [PATCH 2/2] PR comments --- .../ateapi/sessionidentity/sessionidentity.go | 94 ++++++++----- .../sessionidentity/sessionidentity_test.go | 129 ++++++++++-------- 2 files changed, 131 insertions(+), 92 deletions(-) diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity.go b/cmd/servers/ateapi/sessionidentity/sessionidentity.go index f197794..115b675 100644 --- a/cmd/servers/ateapi/sessionidentity/sessionidentity.go +++ b/cmd/servers/ateapi/sessionidentity/sessionidentity.go @@ -25,7 +25,7 @@ import ( "os" "path" "strings" - "sync" + "sync/atomic" "time" "github.com/agent-substrate/substrate/internal/k8sjwt" @@ -49,54 +49,84 @@ type Server struct { workerCACerts string - sessionIDJWTPool mtimeCache[*localjwtauthority.Pool] - sessionIDCAPool mtimeCache[*localca.Pool] + sessionIDJWTPool *fileCache[*localjwtauthority.Pool] + sessionIDCAPool *fileCache[*localca.Pool] } -// mtimeCache holds a parsed value alongside the mod-time of the file it was -// parsed from. It re-parses only when the file mod-time changes on disk, so callers -// avoid a read+unmarshal on every request while still picking up rotations. -type mtimeCache[T any] struct { +// fileCache periodically refreshes a parsed file-backed value, so callers avoid +// a read+unmarshal on every request while still picking up rare key rotations. +type fileCache[T any] struct { path string parse func([]byte) (T, error) - mu sync.Mutex - mtime time.Time - value *T + state atomic.Pointer[fileCacheState[T]] } -func newMtimeCache[T any](path string, parse func([]byte) (T, error)) mtimeCache[T] { - return mtimeCache[T]{ +type fileCacheState[T any] struct { + value T + err error +} + +func newFileCache[T any](path string, parse func([]byte) (T, error)) *fileCache[T] { + return newFileCacheWithTicker(path, time.NewTicker(5*time.Minute).C, parse) +} + +func newFileCacheWithTicker[T any](path string, c <-chan time.Time, parse func([]byte) (T, error)) *fileCache[T] { + cache := &fileCache[T]{ path: path, parse: parse, } + if err := cache.updateValue(); err != nil { + slog.Error("Initial file cache load failed", slog.String("path", path), slog.Any("err", err)) + } + go cache.run(c) + return cache } -func (c *mtimeCache[T]) get() (T, error) { - c.mu.Lock() - defer c.mu.Unlock() - - info, err := os.Stat(c.path) - if err != nil { - var zero T - return zero, fmt.Errorf("stat %s: %w", c.path, err) - } - if c.value != nil && info.ModTime().Equal(c.mtime) { - return *c.value, nil +func (c *fileCache[T]) run(tickerChannel <-chan time.Time) { + for range tickerChannel { + if err := c.updateValue(); err != nil { + slog.Error("File cache refresh failed", slog.String("path", c.path), slog.Any("err", err)) + } else { + slog.Info("File cache refreshed successfully", slog.String("path", c.path)) + } } +} + +func (c *fileCache[T]) updateValue() error { b, err := os.ReadFile(c.path) if err != nil { - var zero T - return zero, fmt.Errorf("read %s: %w", c.path, err) + c.storeErr(fmt.Errorf("read %s: %w", c.path, err)) + return err } v, err := c.parse(b) if err != nil { - var zero T - return zero, err + c.storeErr(err) + return err + } + c.state.Store(&fileCacheState[T]{value: v}) + return nil +} + +func (c *fileCache[T]) storeErr(err error) { + if state := c.state.Load(); state != nil && state.err == nil { + // Don't overwrite a good value with an error. + return + } + c.state.Store(&fileCacheState[T]{err: err}) +} + +func (c *fileCache[T]) get() (T, error) { + var zero T + + state := c.state.Load() + if state == nil { + return zero, fmt.Errorf("value not available") + } + if state.err != nil { + return zero, state.err } - c.value = &v - c.mtime = info.ModTime() - return v, nil + return state.value, nil } var _ ateapipb.SessionIdentityServer = (*Server)(nil) @@ -105,8 +135,8 @@ func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPo return &Server{ clientJWTIssuer: clientJWTIssuer, clientJWTAudience: clientJWTAudience, - sessionIDJWTPool: newMtimeCache(sessionIDJWTPoolFile, localjwtauthority.Unmarshal), - sessionIDCAPool: newMtimeCache(sessionIDCAPoolFile, localca.Unmarshal), + sessionIDJWTPool: newFileCache(sessionIDJWTPoolFile, localjwtauthority.Unmarshal), + sessionIDCAPool: newFileCache(sessionIDCAPoolFile, localca.Unmarshal), workerCACerts: workerCACerts, } } diff --git a/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go index 9a52056..44b29eb 100644 --- a/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go +++ b/cmd/servers/ateapi/sessionidentity/sessionidentity_test.go @@ -22,23 +22,45 @@ import ( "time" ) -func writeFile(t *testing.T, path, contents string, mtime time.Time) { +func writeFile(t *testing.T, path, contents string) { t.Helper() if err := os.WriteFile(path, []byte(contents), 0o600); err != nil { t.Fatalf("write %s: %v", path, err) } - if err := os.Chtimes(path, mtime, mtime); err != nil { - t.Fatalf("chtimes %s: %v", path, err) +} + +func newTestFileCache[T any](t *testing.T, path string, parse func([]byte) (T, error)) (*fileCache[T], chan time.Time) { + t.Helper() + ticks := make(chan time.Time) + t.Cleanup(func() { + close(ticks) + }) + return newFileCacheWithTicker(path, ticks, parse), ticks +} + +func waitForValue(t *testing.T, c *fileCache[string], want string) { + t.Helper() + + deadline := time.Now().Add(time.Second) + for { + got, err := c.get() + if err == nil && got == want { + return + } + if time.Now().After(deadline) { + t.Fatalf("value = %q, err = %v; want value %q", got, err, want) + } + time.Sleep(time.Millisecond) } } -func TestMtimeCache_LoadsOnFirstCall(t *testing.T) { +func TestFileCache_LoadsOnCreation(t *testing.T) { dir := t.TempDir() p := filepath.Join(dir, "f") - writeFile(t, p, "hello", time.Now()) + writeFile(t, p, "hello") calls := 0 - c := newMtimeCache(p, func(b []byte) (string, error) { + c, _ := newTestFileCache(t, p, func(b []byte) (string, error) { calls++ return string(b), nil }) @@ -55,18 +77,19 @@ func TestMtimeCache_LoadsOnFirstCall(t *testing.T) { } } -func TestMtimeCache_CachesWhenMtimeUnchanged(t *testing.T) { +func TestFileCache_DoesNotReloadOnGet(t *testing.T) { dir := t.TempDir() p := filepath.Join(dir, "f") - mt := time.Now().Add(-time.Hour) - writeFile(t, p, "v1", mt) + writeFile(t, p, "v1") calls := 0 - c := newMtimeCache(p, func(b []byte) (string, error) { + c, _ := newTestFileCache(t, p, func(b []byte) (string, error) { calls++ return string(b), nil }) + writeFile(t, p, "v2") + for range 5 { v, err := c.get() if err != nil { @@ -77,18 +100,17 @@ func TestMtimeCache_CachesWhenMtimeUnchanged(t *testing.T) { } } if calls != 1 { - t.Fatalf("parse calls = %d, want 1 (cache miss on unchanged file)", calls) + t.Fatalf("parse calls = %d, want 1", calls) } } -func TestMtimeCache_ReloadsWhenMtimeChanges(t *testing.T) { +func TestFileCache_ReloadsOnTick(t *testing.T) { dir := t.TempDir() p := filepath.Join(dir, "f") - mt := time.Now().Add(-time.Hour) - writeFile(t, p, "v1", mt) + writeFile(t, p, "v1") calls := 0 - c := newMtimeCache(p, func(b []byte) (string, error) { + c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) { calls++ return string(b), nil }) @@ -97,70 +119,57 @@ func TestMtimeCache_ReloadsWhenMtimeChanges(t *testing.T) { t.Fatalf("get v1: %v", err) } - // Simulate a kubelet AtomicWriter rotation: new contents, new mtime. - writeFile(t, p, "v2", mt.Add(time.Minute)) + writeFile(t, p, "v2") + ticks <- time.Now() - got, err := c.get() - if err != nil { - t.Fatalf("get v2: %v", err) - } - if got != "v2" { - t.Fatalf("value = %q, want v2", got) - } + waitForValue(t, c, "v2") if calls != 2 { t.Fatalf("parse calls = %d, want 2", calls) } } -func TestMtimeCache_ReloadsViaSymlinkSwap(t *testing.T) { - // Mimics ConfigMap/Secret mounts: stable file path is a symlink whose - // target is replaced atomically on rotation. +func TestFileCache_KeepsLastValueWhenRefreshFails(t *testing.T) { dir := t.TempDir() - v1 := filepath.Join(dir, "v1") - v2 := filepath.Join(dir, "v2") - link := filepath.Join(dir, "current") - - writeFile(t, v1, "v1", time.Now().Add(-time.Hour)) - writeFile(t, v2, "v2", time.Now()) - if err := os.Symlink(v1, link); err != nil { - t.Fatalf("symlink: %v", err) - } + p := filepath.Join(dir, "f") + writeFile(t, p, "v1") - c := newMtimeCache(link, func(b []byte) (string, error) { return string(b), nil }) + want := errors.New("boom") + fail := false + refreshAttempted := make(chan struct{}, 1) + c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) { + if fail { + refreshAttempted <- struct{}{} + return "", want + } + return string(b), nil + }) - got, err := c.get() - if err != nil { - t.Fatalf("get v1: %v", err) - } - if got != "v1" { - t.Fatalf("value = %q, want v1", got) - } + writeFile(t, p, "v2") + fail = true + ticks <- time.Now() - // Atomic-style swap: write new symlink, rename over the old one. - tmp := link + ".new" - if err := os.Symlink(v2, tmp); err != nil { - t.Fatalf("symlink new: %v", err) - } - if err := os.Rename(tmp, link); err != nil { - t.Fatalf("rename: %v", err) + select { + case <-refreshAttempted: + case <-time.After(time.Second): + t.Fatal("timed out waiting for refresh attempt") } - got, err = c.get() + got, err := c.get() if err != nil { - t.Fatalf("get v2: %v", err) + t.Fatalf("get: %v", err) } - if got != "v2" { - t.Fatalf("value = %q, want v2 after rotation", got) + if got != "v1" { + t.Fatalf("value = %q, want v1", got) } } -func TestMtimeCache_PropagatesParseError(t *testing.T) { +func TestFileCache_PropagatesParseError(t *testing.T) { dir := t.TempDir() p := filepath.Join(dir, "f") - writeFile(t, p, "bad", time.Now()) + writeFile(t, p, "bad") want := errors.New("boom") - c := newMtimeCache(p, func(b []byte) (string, error) { return "", want }) + c, _ := newTestFileCache(t, p, func(b []byte) (string, error) { return "", want }) _, err := c.get() if !errors.Is(err, want) { @@ -168,8 +177,8 @@ func TestMtimeCache_PropagatesParseError(t *testing.T) { } } -func TestMtimeCache_MissingFile(t *testing.T) { - c := newMtimeCache(filepath.Join(t.TempDir(), "nope"), func(b []byte) (string, error) { +func TestFileCache_MissingFile(t *testing.T) { + c, _ := newTestFileCache(t, filepath.Join(t.TempDir(), "nope"), func(b []byte) (string, error) { return string(b), nil }) if _, err := c.get(); err == nil {