Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 88 additions & 23 deletions cmd/servers/ateapi/sessionidentity/sessionidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os"
"path"
"strings"
"sync/atomic"
"time"

"github.com/agent-substrate/substrate/internal/k8sjwt"
Expand All @@ -46,22 +47,97 @@ type Server struct {
clientJWTIssuer string
clientJWTAudience string

// TODO: Cache the signing keys in memory, so we don't read from a file every time.
sessionIDJWTPoolFile string
sessionIDCAPoolFile string

workerCACerts string

sessionIDJWTPool *fileCache[*localjwtauthority.Pool]
sessionIDCAPool *fileCache[*localca.Pool]
}

// fileCache periodically refreshes a parsed file-backed value, so callers avoid
// a read+unmarshal on every request while still picking up rare key rotations.
type fileCache[T any] struct {
path string
parse func([]byte) (T, error)

state atomic.Pointer[fileCacheState[T]]
}

type fileCacheState[T any] struct {
value T
err error
}

func newFileCache[T any](path string, parse func([]byte) (T, error)) *fileCache[T] {
return newFileCacheWithTicker(path, time.NewTicker(5*time.Minute).C, parse)
}

func newFileCacheWithTicker[T any](path string, c <-chan time.Time, parse func([]byte) (T, error)) *fileCache[T] {
cache := &fileCache[T]{
path: path,
parse: parse,
}
if err := cache.updateValue(); err != nil {
slog.Error("Initial file cache load failed", slog.String("path", path), slog.Any("err", err))
}
go cache.run(c)
return cache
}

func (c *fileCache[T]) run(tickerChannel <-chan time.Time) {
for range tickerChannel {
if err := c.updateValue(); err != nil {
slog.Error("File cache refresh failed", slog.String("path", c.path), slog.Any("err", err))
} else {
slog.Info("File cache refreshed successfully", slog.String("path", c.path))
}
}
}

func (c *fileCache[T]) updateValue() error {
b, err := os.ReadFile(c.path)
if err != nil {
c.storeErr(fmt.Errorf("read %s: %w", c.path, err))
return err
}
v, err := c.parse(b)
if err != nil {
c.storeErr(err)
return err
}
c.state.Store(&fileCacheState[T]{value: v})
return nil
}

func (c *fileCache[T]) storeErr(err error) {
if state := c.state.Load(); state != nil && state.err == nil {
// Don't overwrite a good value with an error.
return
}
c.state.Store(&fileCacheState[T]{err: err})
}

func (c *fileCache[T]) get() (T, error) {
var zero T

state := c.state.Load()
if state == nil {
return zero, fmt.Errorf("value not available")
}
if state.err != nil {
return zero, state.err
}
return state.value, nil
}

var _ ateapipb.SessionIdentityServer = (*Server)(nil)

func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string) *Server {
return &Server{
clientJWTIssuer: clientJWTIssuer,
clientJWTAudience: clientJWTAudience,
sessionIDJWTPoolFile: sessionIDJWTPoolFile,
sessionIDCAPoolFile: sessionIDCAPoolFile,
workerCACerts: workerCACerts,
clientJWTIssuer: clientJWTIssuer,
clientJWTAudience: clientJWTAudience,
sessionIDJWTPool: newFileCache(sessionIDJWTPoolFile, localjwtauthority.Unmarshal),
sessionIDCAPool: newFileCache(sessionIDCAPoolFile, localca.Unmarshal),
workerCACerts: workerCACerts,
}
}

Expand Down Expand Up @@ -90,15 +166,9 @@ func (s *Server) MintJWT(ctx context.Context, req *ateapipb.MintJWTRequest) (*at

// TODO: Cross-check requested session and user claims against the session database.

// TODO: Cache signing keys in memory, so we don't read from disk every time.
signingPoolBytes, err := os.ReadFile(s.sessionIDJWTPoolFile)
signingPool, err := s.sessionIDJWTPool.get()
if err != nil {
return nil, fmt.Errorf("while reading signing pool bytes: %w", err)
}

signingPool, err := localjwtauthority.Unmarshal(signingPoolBytes)
if err != nil {
return nil, fmt.Errorf("while unmarshaling signing pool: %w", err)
return nil, fmt.Errorf("while loading signing pool: %w", err)
}

// We only issue tokens with audience bindings.
Expand Down Expand Up @@ -163,12 +233,7 @@ func (s *Server) MintCert(ctx context.Context, req *ateapipb.MintCertRequest) (*
}

// Load the CA pool for signing
poolBytes, err := os.ReadFile(s.sessionIDCAPoolFile)
if err != nil {
slog.ErrorContext(ctx, "Failed to read session CA pool file", slog.Any("err", err))
return nil, status.Errorf(codes.Internal, "Failed to load session CA")
}
caPool, err := localca.Unmarshal(poolBytes)
caPool, err := s.sessionIDCAPool.get()
if err != nil || len(caPool.CAs) == 0 {
slog.ErrorContext(ctx, "Failed to load session CA", slog.Any("err", err))
return nil, status.Errorf(codes.Internal, "Failed to load session CA")
Expand Down
187 changes: 187 additions & 0 deletions cmd/servers/ateapi/sessionidentity/sessionidentity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sessionidentity

import (
"errors"
"os"
"path/filepath"
"testing"
"time"
)

func writeFile(t *testing.T, path, contents string) {
t.Helper()
if err := os.WriteFile(path, []byte(contents), 0o600); err != nil {
t.Fatalf("write %s: %v", path, err)
}
}

func newTestFileCache[T any](t *testing.T, path string, parse func([]byte) (T, error)) (*fileCache[T], chan time.Time) {
t.Helper()
ticks := make(chan time.Time)
t.Cleanup(func() {
close(ticks)
})
return newFileCacheWithTicker(path, ticks, parse), ticks
}

func waitForValue(t *testing.T, c *fileCache[string], want string) {
t.Helper()

deadline := time.Now().Add(time.Second)
for {
got, err := c.get()
if err == nil && got == want {
return
}
if time.Now().After(deadline) {
t.Fatalf("value = %q, err = %v; want value %q", got, err, want)
}
time.Sleep(time.Millisecond)
}
}

func TestFileCache_LoadsOnCreation(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
writeFile(t, p, "hello")

calls := 0
c, _ := newTestFileCache(t, p, func(b []byte) (string, error) {
calls++
return string(b), nil
})

got, err := c.get()
if err != nil {
t.Fatalf("get: %v", err)
}
if got != "hello" {
t.Fatalf("value = %q, want %q", got, "hello")
}
if calls != 1 {
t.Fatalf("parse calls = %d, want 1", calls)
}
}

func TestFileCache_DoesNotReloadOnGet(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
writeFile(t, p, "v1")

calls := 0
c, _ := newTestFileCache(t, p, func(b []byte) (string, error) {
calls++
return string(b), nil
})

writeFile(t, p, "v2")

for range 5 {
v, err := c.get()
if err != nil {
t.Fatalf("get: %v", err)
}
if v != "v1" {
t.Fatalf("value = %q, want v1", v)
}
}
if calls != 1 {
t.Fatalf("parse calls = %d, want 1", calls)
}
}

func TestFileCache_ReloadsOnTick(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
writeFile(t, p, "v1")

calls := 0
c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) {
calls++
return string(b), nil
})

if _, err := c.get(); err != nil {
t.Fatalf("get v1: %v", err)
}

writeFile(t, p, "v2")
ticks <- time.Now()

waitForValue(t, c, "v2")
if calls != 2 {
t.Fatalf("parse calls = %d, want 2", calls)
}
}

func TestFileCache_KeepsLastValueWhenRefreshFails(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
writeFile(t, p, "v1")

want := errors.New("boom")
fail := false
refreshAttempted := make(chan struct{}, 1)
c, ticks := newTestFileCache(t, p, func(b []byte) (string, error) {
if fail {
refreshAttempted <- struct{}{}
return "", want
}
return string(b), nil
})

writeFile(t, p, "v2")
fail = true
ticks <- time.Now()

select {
case <-refreshAttempted:
case <-time.After(time.Second):
t.Fatal("timed out waiting for refresh attempt")
}

got, err := c.get()
if err != nil {
t.Fatalf("get: %v", err)
}
if got != "v1" {
t.Fatalf("value = %q, want v1", got)
}
}

func TestFileCache_PropagatesParseError(t *testing.T) {
dir := t.TempDir()
p := filepath.Join(dir, "f")
writeFile(t, p, "bad")

want := errors.New("boom")
c, _ := newTestFileCache(t, p, func(b []byte) (string, error) { return "", want })

_, err := c.get()
if !errors.Is(err, want) {
t.Fatalf("err = %v, want %v", err, want)
}
}

func TestFileCache_MissingFile(t *testing.T) {
c, _ := newTestFileCache(t, filepath.Join(t.TempDir(), "nope"), func(b []byte) (string, error) {
return string(b), nil
})
if _, err := c.get(); err == nil {
t.Fatalf("expected error for missing file")
}
}
Loading