diff --git a/cmd/root/debug.go b/cmd/root/debug.go index 415c6f126..a32ffcff3 100644 --- a/cmd/root/debug.go +++ b/cmd/root/debug.go @@ -10,6 +10,7 @@ import ( "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/sessiontitle" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/teamloader" @@ -158,7 +159,7 @@ func (f *debugFlags) runDebugTitleCommand(cmd *cobra.Command, args []string) (co } model := agent.Model(ctx) - if model == nil { + if reflectx.IsNil(model) { return fmt.Errorf("agent %q has no model configured", agent.Name()) } diff --git a/cmd/root/run.go b/cmd/root/run.go index 74f5ff784..fd6f027c2 100644 --- a/cmd/root/run.go +++ b/cmd/root/run.go @@ -24,6 +24,7 @@ import ( "github.com/docker/docker-agent/pkg/paths" "github.com/docker/docker-agent/pkg/permissions" "github.com/docker/docker-agent/pkg/profiling" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/teamloader" @@ -514,7 +515,7 @@ func (f *runExecFlags) buildAppOpts(args []string) ([]app.Opt, error) { if f.exitAfterResponse { opts = append(opts, app.WithExitAfterFirstResponse()) } - if f.snapshotController != nil { + if !reflectx.IsNil(f.snapshotController) { opts = append(opts, app.WithSnapshotController(f.snapshotController)) } return opts, nil @@ -586,7 +587,7 @@ func (f *runExecFlags) createSessionSpawner(agentSource config.Source, sessStore if gen := localRt.TitleGenerator(); gen != nil { appOpts = append(appOpts, app.WithTitleGenerator(gen)) } - if ctrl != nil { + if !reflectx.IsNil(ctrl) { appOpts = append(appOpts, app.WithSnapshotController(ctrl)) } diff --git a/go.mod b/go.mod index 86a57c447..16092aca4 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( golang.org/x/sync v0.20.0 golang.org/x/sys v0.44.0 golang.org/x/term v0.43.0 + golang.org/x/tools v0.45.0 google.golang.org/adk v1.2.0 google.golang.org/genai v1.57.0 gopkg.in/dnaeon/go-vcr.v4 v4.0.6 @@ -99,7 +100,6 @@ require ( github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect golang.org/x/mod v0.36.0 // indirect - golang.org/x/tools v0.45.0 // indirect google.golang.org/api v0.272.0 // indirect ) diff --git a/lint/interface_nil_comparison.go b/lint/interface_nil_comparison.go new file mode 100644 index 000000000..b770680ee --- /dev/null +++ b/lint/interface_nil_comparison.go @@ -0,0 +1,261 @@ +package main + +import ( + "go/ast" + "go/token" + "go/types" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/dgageot/rubocop-go/cop" + "golang.org/x/tools/go/packages" +) + +// InterfaceNilComparison rejects nil comparisons against Docker-owned +// non-empty interface types. Interface values can hold typed nil pointers, so +// `x == nil` and `x != nil` only test the interface header and can miss a nil +// implementation value. Use reflectx.IsNil instead. +// +// The cop intentionally ignores error, any/interface{}, and interfaces owned by +// external packages to keep the rule focused on project abstractions we control. +var InterfaceNilComparison = &cop.Func{ + Meta: cop.Meta{ + Name: "Lint/InterfaceNilComparison", + Description: "use reflectx.IsNil instead of comparing project interface values to nil", + Severity: cop.Warning, + }, + Run: func(p *cop.Pass) { + if typed, ok := typedFileForInterfaceNilComparison(p); ok { + checkInterfaceNilComparisons(p, typed.file, typed.fset, typed.info, typed.pkg, typed.modulePath) + return + } + if p.Info != nil { + checkInterfaceNilComparisons(p, p.File, p.FileSet, p.Info, p.Package, "") + } + }, +} + +type interfaceNilTypedFile struct { + file *ast.File + fset *token.FileSet + info *types.Info + pkg *types.Package + modulePath string +} + +type interfaceNilCache struct { + files map[string][]interfaceNilTypedFile +} + +var ( + interfaceNilMu sync.Mutex + interfaceNilByRoot = map[string]*interfaceNilCache{} +) + +func typedFileForInterfaceNilComparison(p *cop.Pass) (interfaceNilTypedFile, bool) { + filename, err := filepath.Abs(p.Filename()) + if err != nil { + return interfaceNilTypedFile{}, false + } + root, ok := moduleRoot(filepath.Dir(filename)) + if !ok { + return interfaceNilTypedFile{}, false + } + + cache := loadInterfaceNilCache(root) + matches := cache.files[filepath.Clean(filename)] + for _, match := range matches { + if match.file.Name != nil && match.file.Name.Name == p.PackageName() { + return match, true + } + } + if len(matches) > 0 { + return matches[0], true + } + return interfaceNilTypedFile{}, false +} + +func loadInterfaceNilCache(root string) *interfaceNilCache { + root = filepath.Clean(root) + + interfaceNilMu.Lock() + cached := interfaceNilByRoot[root] + interfaceNilMu.Unlock() + if cached != nil { + return cached + } + + modulePath := readModulePath(root) + cache := &interfaceNilCache{files: map[string][]interfaceNilTypedFile{}} + cfg := &packages.Config{ + Dir: root, + Tests: true, + Mode: packages.NeedName | + packages.NeedFiles | + packages.NeedCompiledGoFiles | + packages.NeedSyntax | + packages.NeedTypes | + packages.NeedTypesInfo, + } + pkgs, err := packages.Load(cfg, "./...") + if err != nil && len(pkgs) == 0 { + return storeInterfaceNilCache(root, cache) + } + for _, pkg := range pkgs { + if pkg.Fset == nil || pkg.TypesInfo == nil || pkg.Types == nil { + continue + } + for _, file := range pkg.Syntax { + pos := pkg.Fset.Position(file.Package) + filename, absErr := filepath.Abs(pos.Filename) + if absErr != nil { + continue + } + filename = filepath.Clean(filename) + cache.files[filename] = append(cache.files[filename], interfaceNilTypedFile{ + file: file, + fset: pkg.Fset, + info: pkg.TypesInfo, + pkg: pkg.Types, + modulePath: modulePath, + }) + } + } + + return storeInterfaceNilCache(root, cache) +} + +func storeInterfaceNilCache(root string, cache *interfaceNilCache) *interfaceNilCache { + interfaceNilMu.Lock() + defer interfaceNilMu.Unlock() + if existing := interfaceNilByRoot[root]; existing != nil { + return existing + } + interfaceNilByRoot[root] = cache + return cache +} + +func checkInterfaceNilComparisons( + p *cop.Pass, + file *ast.File, + fset *token.FileSet, + info *types.Info, + pkg *types.Package, + modulePath string, +) { + ast.Inspect(file, func(n ast.Node) bool { + binary, ok := n.(*ast.BinaryExpr) + if !ok || (binary.Op != token.EQL && binary.Op != token.NEQ) { + return true + } + + expr, ok := nilComparisonOperand(binary) + if !ok { + return true + } + typ := info.TypeOf(expr) + if !forbiddenInterfaceNilType(typ, pkg, modulePath) { + return true + } + + start, end, ok := equivalentSpan(p, fset, binary.Pos(), binary.End()) + if !ok { + return true + } + p.ReportAtf(start, end, + "do not compare interface value of type %s to nil; use reflectx.IsNil so typed nil implementations are treated as nil", + typ.String()) + return true + }) +} + +func nilComparisonOperand(binary *ast.BinaryExpr) (ast.Expr, bool) { + if isNilIdent(binary.X) { + return binary.Y, true + } + if isNilIdent(binary.Y) { + return binary.X, true + } + return nil, false +} + +func isNilIdent(expr ast.Expr) bool { + id, ok := expr.(*ast.Ident) + return ok && id.Name == "nil" +} + +func forbiddenInterfaceNilType(typ types.Type, pkg *types.Package, modulePath string) bool { + if typ == nil || isErrorType(typ) { + return false + } + iface, ok := types.Unalias(typ).Underlying().(*types.Interface) + if !ok || iface.NumMethods() == 0 { + return false + } + if modulePath == "" { + return true + } + if path := namedTypePackagePath(typ); path != "" { + return path == modulePath || strings.HasPrefix(path, modulePath+"/") + } + return pkg != nil && (pkg.Path() == modulePath || strings.HasPrefix(pkg.Path(), modulePath+"/")) +} + +func isErrorType(typ types.Type) bool { + errType := types.Universe.Lookup("error").Type() + return types.Identical(typ, errType) || types.Identical(types.Unalias(typ).Underlying(), errType.Underlying()) +} + +func namedTypePackagePath(typ types.Type) string { + named, ok := types.Unalias(typ).(*types.Named) + if !ok || named.Obj() == nil || named.Obj().Pkg() == nil { + return "" + } + return named.Obj().Pkg().Path() +} + +func equivalentSpan(p *cop.Pass, sourceFSet *token.FileSet, pos, end token.Pos) (token.Pos, token.Pos, bool) { + if sourceFSet == p.FileSet { + return pos, end, true + } + sourceFile := sourceFSet.File(pos) + destFile := p.FileSet.File(p.File.Package) + if sourceFile == nil || destFile == nil { + return token.NoPos, token.NoPos, false + } + startOffset := sourceFile.Offset(pos) + endOffset := sourceFile.Offset(end) + if startOffset < 0 || endOffset < startOffset || endOffset > destFile.Size() { + return token.NoPos, token.NoPos, false + } + return destFile.Pos(startOffset), destFile.Pos(endOffset), true +} + +func moduleRoot(dir string) (string, bool) { + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir, true + } + parent := filepath.Dir(dir) + if parent == dir { + return "", false + } + dir = parent + } +} + +func readModulePath(root string) string { + data, err := os.ReadFile(filepath.Join(root, "go.mod")) + if err != nil { + return "" + } + for line := range strings.SplitSeq(string(data), "\n") { + fields := strings.Fields(line) + if len(fields) == 2 && fields[0] == "module" { + return fields[1] + } + } + return "" +} diff --git a/lint/interface_nil_comparison_test.go b/lint/interface_nil_comparison_test.go new file mode 100644 index 000000000..bbfcfadbe --- /dev/null +++ b/lint/interface_nil_comparison_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" + + "github.com/dgageot/rubocop-go/coptest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInterfaceNilComparisonReportsInterfaceNilChecks(t *testing.T) { + offenses := coptest.RunTyped(t, InterfaceNilComparison, `package sample + +type Worker interface { Work() } + +func f(w Worker) { + if w == nil { + } + if nil != w { + } +} +`) + + require.Len(t, offenses, 2) + assert.Equal(t, "Lint/InterfaceNilComparison", offenses[0].CopName) + assert.Contains(t, offenses[0].Message, "reflectx.IsNil") +} + +func TestInterfaceNilComparisonIgnoresErrorAnyAndConcreteTypes(t *testing.T) { + offenses := coptest.RunTyped(t, InterfaceNilComparison, `package sample + +func f(err error, value any, ptr *int) { + if err == nil { + } + if value == nil { + } + if ptr == nil { + } +} +`) + + assert.Empty(t, offenses) +} diff --git a/lint/main.go b/lint/main.go index 270bfa3f6..d42176ea0 100644 --- a/lint/main.go +++ b/lint/main.go @@ -29,6 +29,7 @@ var cops = []cop.Cop{ HookConfigSync, HookBuiltinsRegistered, SlogContextual, + InterfaceNilComparison, } func main() { diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index c98102e1d..92f624128 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -13,6 +13,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -185,7 +186,7 @@ func (a *Agent) SetModelOverride(models ...provider.Provider) ModelOverrideSnaps // Filter out nil providers var validModels []provider.Provider for _, m := range models { - if m != nil { + if !reflectx.IsNil(m) { validModels = append(validModels, m) } } diff --git a/pkg/app/app.go b/pkg/app/app.go index 35e1a89e5..90b0f9b47 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/hooks/builtins" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" @@ -861,7 +862,7 @@ func (a *App) SetCurrentAgentModel(ctx context.Context, modelRef string) error { } // Persist the session - if store := a.runtime.SessionStore(); store != nil { + if store := a.runtime.SessionStore(); !reflectx.IsNil(store) { if err := store.UpdateSession(ctx, a.session); err != nil { return fmt.Errorf("failed to persist model override: %w", err) } diff --git a/pkg/app/undo.go b/pkg/app/undo.go index ca1023a34..36a6f34a1 100644 --- a/pkg/app/undo.go +++ b/pkg/app/undo.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "os" + + "github.com/docker/docker-agent/pkg/reflectx" ) var ErrNothingToUndo = errors.New("nothing to undo") @@ -17,13 +19,13 @@ type UndoSnapshotResult struct { // active. The answer is a controller-level capability check and does // not depend on having an active session attached. func (a *App) SnapshotsEnabled() bool { - return a.snapshotController != nil && a.snapshotController.Enabled() + return !reflectx.IsNil(a.snapshotController) && a.snapshotController.Enabled() } // UndoLastSnapshot restores the files captured in the most recent // snapshot checkpoint for the current session. func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) { - if a.snapshotController == nil || a.session == nil { + if reflectx.IsNil(a.snapshotController) || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } return snapshotResult(a.snapshotController.UndoLast(ctx, a.session.ID, a.snapshotCwd())) @@ -33,7 +35,7 @@ func (a *App) UndoLastSnapshot(ctx context.Context) (UndoSnapshotResult, error) // the current session, oldest first. Returns nil when no snapshots exist // or when no controller is configured. func (a *App) ListSnapshots() []int { - if a.snapshotController == nil || a.session == nil { + if reflectx.IsNil(a.snapshotController) || a.session == nil { return nil } infos := a.snapshotController.List(a.session.ID) @@ -48,7 +50,7 @@ func (a *App) ListSnapshots() []int { // returns to the state captured at that snapshot. keep == 0 resets to // the original pre-agent state. func (a *App) ResetSnapshot(ctx context.Context, keep int) (UndoSnapshotResult, error) { - if a.snapshotController == nil || a.session == nil { + if reflectx.IsNil(a.snapshotController) || a.session == nil { return UndoSnapshotResult{}, ErrNothingToUndo } return snapshotResult(a.snapshotController.Reset(ctx, a.session.ID, a.snapshotCwd(), keep)) diff --git a/pkg/chatserver/runtime_pool.go b/pkg/chatserver/runtime_pool.go index d79f03448..3dbcd7173 100644 --- a/pkg/chatserver/runtime_pool.go +++ b/pkg/chatserver/runtime_pool.go @@ -4,6 +4,7 @@ import ( "errors" "sync" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/team" ) @@ -53,7 +54,7 @@ func (p *runtimePool) Get(agent string) (runtime.Runtime, error) { if p == nil { return nil, errInvalidRuntime } - if rt := p.takeIdle(agent); rt != nil { + if rt := p.takeIdle(agent); !reflectx.IsNil(rt) { return rt, nil } rt, err := runtime.New(p.team, runtime.WithCurrentAgent(agent)) @@ -68,7 +69,7 @@ func (p *runtimePool) Get(agent string) (runtime.Runtime, error) { // underlying toolsets). The runtime must not be used by the caller // after Put returns. func (p *runtimePool) Put(agent string, rt runtime.Runtime) { - if p == nil || rt == nil || p.maxIdle == 0 { + if p == nil || reflectx.IsNil(rt) || p.maxIdle == 0 { return } ch := p.channelFor(agent) diff --git a/pkg/config/runtime.go b/pkg/config/runtime.go index 455baa0ae..2674d2462 100644 --- a/pkg/config/runtime.go +++ b/pkg/config/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" ) type RuntimeConfig struct { @@ -86,7 +87,7 @@ func (runConfig *RuntimeConfig) ModelsDevStore() (*modelsdev.Store, error) { } func (runConfig *RuntimeConfig) EnvProvider() environment.Provider { - if runConfig.EnvProviderForTests != nil { + if !reflectx.IsNil(runConfig.EnvProviderForTests) { return runConfig.EnvProviderForTests } diff --git a/pkg/config/sources.go b/pkg/config/sources.go index 62256ff50..de374bb6b 100644 --- a/pkg/config/sources.go +++ b/pkg/config/sources.go @@ -20,6 +20,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/paths" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/remote" ) @@ -338,7 +339,7 @@ func isGitHubURL(urlStr string) bool { // - An environment provider is configured // - GITHUB_TOKEN is available in the environment func (a urlSource) addGitHubAuth(ctx context.Context, req *http.Request) { - if a.envProvider == nil { + if reflectx.IsNil(a.envProvider) { return } diff --git a/pkg/evaluation/eval.go b/pkg/evaluation/eval.go index f067fae4d..47afb16b6 100644 --- a/pkg/evaluation/eval.go +++ b/pkg/evaluation/eval.go @@ -26,6 +26,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" ) @@ -48,7 +49,7 @@ type Runner struct { // newRunner creates a new evaluation runner. func newRunner(agentSource config.Source, runConfig *config.RuntimeConfig, judgeModel provider.Provider, cfg Config) *Runner { var judge *Judge - if judgeModel != nil { + if !reflectx.IsNil(judgeModel) { judge = NewJudge(judgeModel, cfg.Concurrency) } return &Runner{ diff --git a/pkg/hooks/model_handler.go b/pkg/hooks/model_handler.go index c7388cfc4..ecf1a943b 100644 --- a/pkg/hooks/model_handler.go +++ b/pkg/hooks/model_handler.go @@ -11,6 +11,7 @@ import ( "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/reflectx" ) // ModelClient is the runtime-provided seam between [HookTypeModel] @@ -117,7 +118,7 @@ func defaultShape(raw string, in *Input) (*Output, error) { // invocation. The shape/schema lookup happens at handler-construction // time so adding new shapes after factory registration still works. func NewModelFactory(client ModelClient) HandlerFactory { - if client == nil { + if reflectx.IsNil(client) { // The empty client always errors; this lets a runtime register // the factory without a credentialed client and fail at the // first use rather than at construction. diff --git a/pkg/js/expand.go b/pkg/js/expand.go index cda94aa1e..48f0d7635 100644 --- a/pkg/js/expand.go +++ b/pkg/js/expand.go @@ -12,6 +12,7 @@ import ( "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -52,7 +53,7 @@ func (*dynamicLookup) Keys() []string { return nil } func (exp *Expander) newVMWithBindings(ctx context.Context) *goja.Runtime { vm := newVM() - if exp.env != nil { + if !reflectx.IsNil(exp.env) { _ = vm.Set("env", vm.NewDynamicObject(&dynamicLookup{ vm: vm, lookup: func(k string) string { v, _ := exp.env.Get(ctx, k); return v }, diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 14a85b938..c89a07f65 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -23,6 +23,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/model/provider/providerutil" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -47,7 +48,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro return nil, errors.New("model type must be 'anthropic'") } - if env == nil { + if reflectx.IsNil(env) { slog.ErrorContext(ctx, "Anthropic client creation failed", "error", "environment provider is required") return nil, errors.New("environment provider is required") } diff --git a/pkg/model/provider/anthropic/vertex.go b/pkg/model/provider/anthropic/vertex.go index 7ed11d264..7501d99db 100644 --- a/pkg/model/provider/anthropic/vertex.go +++ b/pkg/model/provider/anthropic/vertex.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/reflectx" ) // vertexCloudPlatformScope is the OAuth2 scope required for Vertex AI API access. @@ -34,7 +35,7 @@ func NewVertexClient(ctx context.Context, cfg *latest.ModelConfig, env environme if cfg == nil { return nil, errors.New("model configuration is required") } - if env == nil { + if reflectx.IsNil(env) { return nil, errors.New("environment provider is required") } if project == "" { diff --git a/pkg/model/provider/openai/headers_test.go b/pkg/model/provider/openai/headers_test.go index e94c220bc..ac504c07c 100644 --- a/pkg/model/provider/openai/headers_test.go +++ b/pkg/model/provider/openai/headers_test.go @@ -12,6 +12,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/reflectx" ) func TestUserHeaders(t *testing.T) { @@ -196,7 +197,7 @@ func captureHeaders(t *testing.T, cfg *latest.ModelConfig, envVars map[string]st []chat.Message{{Role: chat.MessageRoleUser, Content: "hi"}}, nil, ) - if err == nil && stream != nil { + if err == nil && !reflectx.IsNil(stream) { // Drain the stream so the HTTP request is actually sent. for { if _, err := stream.Recv(); err != nil { diff --git a/pkg/model/provider/rulebased/client.go b/pkg/model/provider/rulebased/client.go index 9f28a6b91..5eee40c5e 100644 --- a/pkg/model/provider/rulebased/client.go +++ b/pkg/model/provider/rulebased/client.go @@ -23,6 +23,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -165,7 +166,7 @@ func (c *Client) CreateChatCompletionStream( availableTools []tools.Tool, ) (chat.MessageStream, error) { provider := c.selectProvider(messages) - if provider == nil { + if reflectx.IsNil(provider) { return nil, errors.New("no provider available for routing") } @@ -240,7 +241,7 @@ func parseRouteIndex(docID string) (int, bool) { } func (c *Client) defaultProvider() Provider { - if c.fallback != nil { + if !reflectx.IsNil(c.fallback) { return c.fallback } if len(c.routes) > 0 { diff --git a/pkg/rag/manager.go b/pkg/rag/manager.go index 3a68f6e8d..02806ce7a 100644 --- a/pkg/rag/manager.go +++ b/pkg/rag/manager.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" "github.com/docker/docker-agent/pkg/rag/types" + "github.com/docker/docker-agent/pkg/reflectx" ) // ToolConfig represents tool-specific configuration @@ -106,7 +107,7 @@ func New(_ context.Context, name string, config Config, strategyEvents <-chan ty } // Ensure fusion was actually created - if fusionStrategy == nil { + if reflectx.IsNil(fusionStrategy) { return nil, errors.New("fusion strategy is nil after creation (this is a bug)") } } @@ -243,7 +244,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "num_results", len(results)) // Apply reranking if configured - if m.reranker != nil { + if !reflectx.IsNil(m.reranker) { beforeCount := len(results) slog.DebugContext(ctx, "[RAG Manager] Applying reranking to single-strategy results", "rag_name", m.name, @@ -355,7 +356,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "num_strategies", len(strategyResults)) // Safety check: fusion should never be nil with multiple strategies - if m.fusion == nil { + if reflectx.IsNil(m.fusion) { return nil, errors.New("fusion strategy is nil but multiple strategies are configured (this is a bug)") } @@ -373,7 +374,7 @@ func (m *Manager) Query(ctx context.Context, query string) ([]database.SearchRes "result_limit", m.config.Results.Limit) // Apply reranking if configured (before limit and deduplication) - if m.reranker != nil { + if !reflectx.IsNil(m.reranker) { beforeCount := len(fusedResults) slog.DebugContext(ctx, "[RAG Manager] Applying reranking to fused results", "rag_name", m.name, diff --git a/pkg/rag/rerank/rerank.go b/pkg/rag/rerank/rerank.go index 065acc814..381822b88 100644 --- a/pkg/rag/rerank/rerank.go +++ b/pkg/rag/rerank/rerank.go @@ -12,6 +12,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/rag/database" "github.com/docker/docker-agent/pkg/rag/types" + "github.com/docker/docker-agent/pkg/reflectx" ) // Reranker re-scores search results using a reranking model @@ -39,7 +40,7 @@ type LLMReranker struct { // NewLLMReranker creates a new LLM-based reranker func NewLLMReranker(config Config) (*LLMReranker, error) { - if config.Model == nil { + if reflectx.IsNil(config.Model) { return nil, errors.New("reranking model is required") } diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index c478dc529..8b4075025 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/rag/chunk" "github.com/docker/docker-agent/pkg/rag/types" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -502,7 +503,7 @@ func humanizeMetadataKey(key string) string { // calculateSemanticUsageCost calculates cost for semantic LLM usage. func calculateSemanticUsageCost(modelsStore modelStore, id modelsdev.ID, usage *chat.Usage) float64 { - if usage == nil || modelsStore == nil || !id.IsValid() || id.Provider == "dmr" { + if usage == nil || reflectx.IsNil(modelsStore) || !id.IsValid() || id.Provider == "dmr" { return 0 } diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 9d4935b92..c72584a91 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -20,6 +20,7 @@ import ( "github.com/docker/docker-agent/pkg/rag/embed" "github.com/docker/docker-agent/pkg/rag/treesitter" "github.com/docker/docker-agent/pkg/rag/types" + "github.com/docker/docker-agent/pkg/reflectx" ) // vectorStoreDB is the internal database interface used by VectorStore. @@ -162,7 +163,7 @@ func NewVectorStore(cfg VectorStoreConfig) *VectorStore { // before being sent to the embedding model. Passing nil resets to the default // behavior (raw chunk content). func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) { - if builder == nil { + if reflectx.IsNil(builder) { s.embeddingInputBuilder = DefaultEmbeddingInputBuilder{} return } @@ -171,7 +172,7 @@ func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) { // calculateCost calculates embedding cost using models.dev pricing func (s *VectorStore) calculateCost(tokens int64) float64 { - if s.modelsStore == nil || s.modelID.Provider == "dmr" { + if reflectx.IsNil(s.modelsStore) || s.modelID.Provider == "dmr" { return 0 } @@ -483,7 +484,7 @@ func (s *VectorStore) Close() error { } // Close database connection - if s.db != nil { + if !reflectx.IsNil(s.db) { if err := s.db.Close(); err != nil { slog.Error("Failed to close database", "strategy", s.name, "error", err) if firstErr == nil { diff --git a/pkg/reflectx/nil.go b/pkg/reflectx/nil.go new file mode 100644 index 000000000..dd91108fe --- /dev/null +++ b/pkg/reflectx/nil.go @@ -0,0 +1,19 @@ +// Package reflectx contains small reflection helpers. +package reflectx + +import "reflect" + +// IsNil reports whether v is nil or wraps a typed nil value. +func IsNil(v any) bool { + if v == nil { + return true + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice, reflect.UnsafePointer: + return rv.IsNil() + default: + return false + } +} diff --git a/pkg/reflectx/nil_test.go b/pkg/reflectx/nil_test.go new file mode 100644 index 000000000..edd53d5e9 --- /dev/null +++ b/pkg/reflectx/nil_test.go @@ -0,0 +1,30 @@ +package reflectx + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testReader struct{} + +func (*testReader) Read([]byte) (int, error) { return 0, io.EOF } + +func TestIsNil(t *testing.T) { + var ptr *testReader + var reader io.Reader = ptr + var values map[string]string + var items []string + var fn func() + + assert.True(t, IsNil(nil)) + assert.True(t, IsNil(ptr)) + assert.True(t, IsNil(reader)) + assert.True(t, IsNil(values)) + assert.True(t, IsNil(items)) + assert.True(t, IsNil(fn)) + assert.False(t, IsNil(&testReader{})) + assert.False(t, IsNil(42)) + assert.False(t, IsNil("")) +} diff --git a/pkg/runtime/compactor/compactor.go b/pkg/runtime/compactor/compactor.go index f810296ca..e9b7550a3 100644 --- a/pkg/runtime/compactor/compactor.go +++ b/pkg/runtime/compactor/compactor.go @@ -29,6 +29,7 @@ import ( "github.com/docker/docker-agent/pkg/compaction" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" ) @@ -114,7 +115,7 @@ func RunLLM(ctx context.Context, args LLMArgs) (*Result, error) { if args.ContextLimit <= 0 { return nil, errors.New("compactor: ContextLimit must be > 0") } - if args.Agent.Model(ctx) == nil { + if reflectx.IsNil(args.Agent.Model(ctx)) { return nil, errors.New("compactor: agent has no model") } diff --git a/pkg/runtime/hooks.go b/pkg/runtime/hooks.go index 6ca312f77..3a9134eca 100644 --- a/pkg/runtime/hooks.go +++ b/pkg/runtime/hooks.go @@ -9,6 +9,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/hooks" "github.com/docker/docker-agent/pkg/hooks/builtins" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime/toolexec" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tools" @@ -61,7 +62,7 @@ func applyAutoInjectors(cfg *hooks.Config, injectors []builtins.AutoInjector) *h cfg = &hooks.Config{} } for _, inj := range injectors { - if inj != nil { + if !reflectx.IsNil(inj) { inj.AutoInject(cfg) } } @@ -105,11 +106,11 @@ func (r *LocalRuntime) dispatchHook( } started := time.Now() - if events != nil { + if !reflectx.IsNil(events) { events.Emit(HookStarted(event, input.SessionID, a.Name())) } result, err := exec.Dispatch(ctx, event, input) - if events != nil { + if !reflectx.IsNil(events) { events.Emit(HookFinished(event, input.SessionID, result, err, time.Since(started), a.Name())) } if err != nil { @@ -117,7 +118,7 @@ func (r *LocalRuntime) dispatchHook( return nil } - if events != nil && result.SystemMessage != "" { + if !reflectx.IsNil(events) && result.SystemMessage != "" { events.Emit(Warning(result.SystemMessage, a.Name())) } return result diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index aaa1ba6cd..a19c72546 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime/toolexec" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tools" @@ -571,7 +572,7 @@ func (r *LocalRuntime) runTurn( // handlers can reuse the same parsing. r.executeAfterLLMCallHooks(ctx, sess, a, res.Content) - if usedModel != nil && usedModel.ID() != model.ID() { + if !reflectx.IsNil(usedModel) && usedModel.ID() != model.ID() { slog.InfoContext(ctx, "Used fallback model", "agent", a.Name(), "primary", model.ID().String(), "used", usedModel.ID().String()) events.Emit(AgentInfo(a.Name(), usedModel.ID().String(), a.Description(), a.WelcomeMessage())) } diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 81b6caea6..151d8e403 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" ) // ModelChoice represents a model available for selection in the model picker. @@ -489,7 +490,7 @@ func mapModelsDevProvider(providerID string) (string, bool) { // provider/model pair and copies it onto choice. It silently does // nothing when the lookup fails or when the runtime has no models store. func (r *LocalRuntime) populateCatalogMetadata(ctx context.Context, choice *ModelChoice, providerID, modelID string) { - if r.modelsStore == nil { + if reflectx.IsNil(r.modelsStore) { return } m, err := r.modelsStore.GetModel(ctx, modelsdev.NewID(providerID, modelID)) @@ -629,7 +630,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest // Use max_tokens from config if specified, otherwise look up from models.dev if cfg.MaxTokens != nil { opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens)) - } else if r.modelsStore != nil { + } else if !reflectx.IsNil(r.modelsStore) { m, err := r.modelsStore.GetModel(ctx, modelsdev.NewID(cfg.Provider, cfg.Model)) if err == nil && m != nil { opts = append(opts, options.WithMaxTokens(m.Limit.Output)) diff --git a/pkg/runtime/observer.go b/pkg/runtime/observer.go index d9ff008f8..26df2ff0f 100644 --- a/pkg/runtime/observer.go +++ b/pkg/runtime/observer.go @@ -3,6 +3,7 @@ package runtime import ( "context" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" ) @@ -50,7 +51,7 @@ type EventObserver interface { // alongside that one. func WithEventObserver(o EventObserver) Opt { return func(r *LocalRuntime) { - if o == nil { + if reflectx.IsNil(o) { return } r.observers = append(r.observers, o) diff --git a/pkg/runtime/persistence_observer.go b/pkg/runtime/persistence_observer.go index f6b0a5ed6..e7a35a392 100644 --- a/pkg/runtime/persistence_observer.go +++ b/pkg/runtime/persistence_observer.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" ) @@ -47,7 +48,7 @@ type streamingState struct { // nil when store is nil so the constructor can call [WithEventObserver] // unconditionally without a guard. func newPersistenceObserver(store session.Store) *PersistenceObserver { - if store == nil { + if reflectx.IsNil(store) { return nil } return &PersistenceObserver{store: store} diff --git a/pkg/runtime/remote_runtime.go b/pkg/runtime/remote_runtime.go index 38d2383db..dd185e439 100644 --- a/pkg/runtime/remote_runtime.go +++ b/pkg/runtime/remote_runtime.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/api" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" "github.com/docker/docker-agent/pkg/team" @@ -70,7 +71,7 @@ func WithRemoteAgentFilename(filename string) RemoteRuntimeOption { // NewRemoteRuntime creates a new remote runtime that implements the Runtime interface. // It accepts any client that implements the RemoteClient interface. func NewRemoteRuntime(client RemoteClient, opts ...RemoteRuntimeOption) (*RemoteRuntime, error) { - if client == nil { + if reflectx.IsNil(client) { return nil, errors.New("client cannot be nil") } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 3ea8441fa..fff95aa7c 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/hooks/builtins" "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/sessiontitle" "github.com/docker/docker-agent/pkg/team" @@ -374,7 +375,7 @@ func WithClock(now func() time.Time) Opt { // events without setting up an OTel client. func WithTelemetry(t Telemetry) Opt { return func(r *LocalRuntime) { - if t != nil { + if !reflectx.IsNil(t) { r.telemetry = t } } @@ -424,7 +425,7 @@ func WithRetryOnRateLimit() Opt { // Multiple calls accumulate; injectors run in registration order. func WithAutoInjector(inj builtins.AutoInjector) Opt { return func(r *LocalRuntime) { - if inj != nil { + if !reflectx.IsNil(inj) { r.autoInjectors = append(r.autoInjectors, inj) } } @@ -546,7 +547,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { } } - if r.modelsStore == nil { + if reflectx.IsNil(r.modelsStore) { r.modelsStore = &lazyModelStore{} } @@ -557,7 +558,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) { return nil, err } - if defaultAgent.Model(context.TODO()) == nil && !defaultAgent.HasHarness() { + if reflectx.IsNil(defaultAgent.Model(context.TODO())) && !defaultAgent.HasHarness() { return nil, fmt.Errorf("agent %s has no valid model", defaultAgent.Name()) } @@ -853,7 +854,7 @@ func (r *LocalRuntime) TitleGenerator() *sessiontitle.Generator { // generator carries its own ctx when actually invoked. context.TODO is // the right marker here. model := a.Model(context.TODO()) - if model == nil { + if reflectx.IsNil(model) { return nil } return sessiontitle.New(model, a.FallbackModels()...) @@ -865,7 +866,7 @@ func getAgentModelID(a *agent.Agent) modelsdev.ID { if a == nil { return modelsdev.ID{} } - if model := a.Model(context.TODO()); model != nil { + if model := a.Model(context.TODO()); !reflectx.IsNil(model) { return model.ID() } return modelsdev.ID{} @@ -928,7 +929,7 @@ func (r *LocalRuntime) SessionStore() session.Store { // Close releases resources held by the runtime, including the session store. func (r *LocalRuntime) Close() error { r.bgAgents.StopAll() - if r.sessionStore != nil { + if !reflectx.IsNil(r.sessionStore) { return r.sessionStore.Close() } return nil @@ -937,7 +938,7 @@ func (r *LocalRuntime) Close() error { // UpdateSessionTitle persists the session title via the session store. func (r *LocalRuntime) UpdateSessionTitle(ctx context.Context, sess *session.Session, title string) error { sess.Title = title - if r.sessionStore != nil { + if !reflectx.IsNil(r.sessionStore) { return r.sessionStore.UpdateSession(ctx, sess) } return nil diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index d67459585..5f4b33a76 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -24,6 +24,7 @@ import ( "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -246,7 +247,7 @@ func assertEventsEqual(t *testing.T, expected, actual []Event) { // clearTimestamps sets Timestamp fields to zero value in events for comparison. func clearTimestamps(event Event) { - if event == nil { + if reflectx.IsNil(event) { return } diff --git a/pkg/runtime/sampling.go b/pkg/runtime/sampling.go index c1ff91559..0c12f19e1 100644 --- a/pkg/runtime/sampling.go +++ b/pkg/runtime/sampling.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/reflectx" ) // Limits applied to inbound sampling requests to keep a misbehaving or @@ -64,7 +65,7 @@ func (r *LocalRuntime) samplingHandler(ctx context.Context, req *mcp.CreateMessa } baseModel := a.Model(ctx) - if baseModel == nil { + if reflectx.IsNil(baseModel) { return nil, errors.New("current agent has no model configured") } diff --git a/pkg/runtime/session_compaction.go b/pkg/runtime/session_compaction.go index 6132c2d5d..fa3f88285 100644 --- a/pkg/runtime/session_compaction.go +++ b/pkg/runtime/session_compaction.go @@ -13,6 +13,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime/compactor" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" @@ -170,7 +171,7 @@ func summaryFromHook(sess *session.Session, a *agent.Agent, pre *hooks.Result) * // pass the cloned summary-call provider so its provider_opts (which // match the underlying model) are considered. func (r *LocalRuntime) compactionContextLimit(ctx context.Context, a *agent.Agent) int64 { - if a == nil || a.Model(ctx) == nil { + if a == nil || reflectx.IsNil(a.Model(ctx)) { return 0 } summaryModel := provider.CloneWithOptions(ctx, a.Model(ctx), @@ -213,7 +214,7 @@ func (r *LocalRuntime) resolveContextLimit(ctx context.Context, p provider.Provi // treated as "unset" so callers don't accidentally trigger // compaction with a degenerate limit. func providerContextLimit(p provider.Provider) int64 { - if p == nil { + if reflectx.IsNil(p) { return 0 } opts := p.BaseConfig().ModelConfig.ProviderOpts diff --git a/pkg/runtime/toolexec/dispatcher.go b/pkg/runtime/toolexec/dispatcher.go index bf901243b..6886f977f 100644 --- a/pkg/runtime/toolexec/dispatcher.go +++ b/pkg/runtime/toolexec/dispatcher.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/hooks" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/telemetry" "github.com/docker/docker-agent/pkg/tools" @@ -373,7 +374,7 @@ func (c *call) approveAndRun(ctx context.Context, runTool func() CallOutcome) Ca // arguments — this is the only place pre-call argument rewriting // happens. func (c *call) consultPreToolUseHook(ctx context.Context, runTool func() CallOutcome) (CallOutcome, bool) { - if c.d.Hooks == nil { + if reflectx.IsNil(c.d.Hooks) { return CallOutcome{}, false } @@ -425,7 +426,7 @@ func (c *call) applyHookModifiedInput(result *hooks.Result) { // HookDispatcher, when one is configured. Centralised so the nil-guard // stays in one place. func (c *call) notifyApproval(ctx context.Context, decision, source string) { - if c.d.Hooks == nil { + if reflectx.IsNil(c.d.Hooks) { return } c.d.Hooks.NotifyApprovalDecision(ctx, c.sess, c.a, c.tc, decision, source) @@ -490,7 +491,7 @@ func (c *call) askUser(ctx context.Context, runTool func() CallOutcome) CallOutc slog.DebugContext(ctx, "Tools not approved, waiting for resume", "tool", c.tc.Function.Name, "session_id", c.sess.ID) c.em.EmitToolCallConfirmation(c.tc, c.tool, c.a.Name()) - if c.d.Hooks != nil { + if !reflectx.IsNil(c.d.Hooks) { c.d.Hooks.NotifyUserInput(ctx, c.sess.ID, "tool confirmation") } @@ -512,7 +513,7 @@ func (c *call) askUser(ctx context.Context, runTool func() CallOutcome) CallOutc // "block" without permission_decision) is also honoured. Returning // nothing keeps the existing behaviour and asks the user. func (c *call) runPermissionRequestHook(ctx context.Context, runTool func() CallOutcome) (CallOutcome, bool) { - if c.d.Hooks == nil { + if reflectx.IsNil(c.d.Hooks) { return CallOutcome{}, false } @@ -666,7 +667,7 @@ func (c *call) invoke(ctx context.Context, spanName string, exec func(ctx contex // the persisted session file, the input the post_tool_use hook sees, // and the messages going to the next LLM call. func (c *call) applyToolResponseTransform(ctx context.Context, payload string, isError bool) string { - if c.d.Hooks == nil { + if reflectx.IsNil(c.d.Hooks) { return payload } in := NewPostToolHooksInput(c.sess, c.tc, &tools.ToolCallResult{Output: payload, IsError: isError}) @@ -741,7 +742,7 @@ func buildMultiContent(text string, images []tools.MediaContent) []chat.MessageP // (stop, message) return. The tool result is forwarded to the hook so // post_tool_use handlers can inspect ToolResponse / ToolError. func (c *call) postHook(ctx context.Context, res *tools.ToolCallResult) (stop bool, message string) { - if c.d.Hooks == nil { + if reflectx.IsNil(c.d.Hooks) { return false, "" } result := c.d.Hooks.Dispatch(ctx, c.a, hooks.EventPostToolUse, NewPostToolHooksInput(c.sess, c.tc, res)) diff --git a/pkg/sessiontitle/generator.go b/pkg/sessiontitle/generator.go index 5e59d4399..7b3709715 100644 --- a/pkg/sessiontitle/generator.go +++ b/pkg/sessiontitle/generator.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/reflectx" ) const ( @@ -46,7 +47,7 @@ type Generator struct { func New(model provider.Provider, fallbackModels ...provider.Provider) *Generator { models := slices.DeleteFunc( append([]provider.Provider{model}, fallbackModels...), - func(p provider.Provider) bool { return p == nil }, + func(p provider.Provider) bool { return reflectx.IsNil(p) }, ) return &Generator{models: models} } diff --git a/pkg/team/team.go b/pkg/team/team.go index dc1ad87a0..cb0b649cc 100644 --- a/pkg/team/team.go +++ b/pkg/team/team.go @@ -9,6 +9,7 @@ import ( "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/permissions" + "github.com/docker/docker-agent/pkg/reflectx" ) type Team struct { @@ -64,7 +65,7 @@ func (t *Team) AgentsInfo() []AgentInfo { Description: a.Description(), Commands: a.Commands(), } - if model := a.Model(context.TODO()); model != nil { + if model := a.Model(context.TODO()); !reflectx.IsNil(model) { id := model.ID() info.Provider = id.Provider info.Model = id.Model diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index a7fd35a03..a0f1d3182 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/skills" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -620,7 +621,7 @@ func loadExternalAgent(ctx context.Context, ref string, runConfig *config.Runtim } var opts []Opt - if loadOpts.toolsetRegistry != nil { + if !reflectx.IsNil(loadOpts.toolsetRegistry) { opts = append(opts, WithToolsetRegistry(loadOpts.toolsetRegistry)) } diff --git a/pkg/tools/builtin/mcpcatalog/mcpcatalog.go b/pkg/tools/builtin/mcpcatalog/mcpcatalog.go index ed0a7f1ea..c57a9c9d6 100644 --- a/pkg/tools/builtin/mcpcatalog/mcpcatalog.go +++ b/pkg/tools/builtin/mcpcatalog/mcpcatalog.go @@ -46,6 +46,7 @@ import ( "sync" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/mcp" ) @@ -615,7 +616,7 @@ func (t *Toolset) expandHeaders(ctx context.Context, in map[string]string) map[s for k, v := range in { out[k] = unresolvedHeaderEnv.ReplaceAllStringFunc(v, func(match string) string { name := match[2 : len(match)-1] // strip ${ and } - if t.env == nil { + if reflectx.IsNil(t.env) { return match } if val, ok := t.env.Get(ctx, name); ok && val != "" { @@ -631,7 +632,7 @@ func (t *Toolset) expandHeaders(ctx context.Context, in map[string]string) map[s // available from the toolset's env provider. Empty result means "all good". // Returns nil for non api_key servers. func (t *Toolset) missingAPIKeyEnv(ctx context.Context, s Server) []string { - if s.Auth.Type != "api_key" || t.env == nil { + if s.Auth.Type != "api_key" || reflectx.IsNil(t.env) { return nil } var missing []string diff --git a/pkg/tools/builtin/todo/todo.go b/pkg/tools/builtin/todo/todo.go index ec9a87226..c6680e38b 100644 --- a/pkg/tools/builtin/todo/todo.go +++ b/pkg/tools/builtin/todo/todo.go @@ -10,6 +10,7 @@ import ( "github.com/docker/docker-agent/pkg/concurrent" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tools" ) @@ -140,7 +141,7 @@ type Option func(*ToolSet) // WithStorage sets a custom storage implementation for the Tool. // The provided storage must not be nil. func WithStorage(storage Storage) Option { - if storage == nil { + if reflectx.IsNil(storage) { panic("todo: storage must not be nil") } return func(t *ToolSet) { diff --git a/pkg/tools/lifecycle/supervisor.go b/pkg/tools/lifecycle/supervisor.go index ef3c2c4ae..a86c6a50f 100644 --- a/pkg/tools/lifecycle/supervisor.go +++ b/pkg/tools/lifecycle/supervisor.go @@ -8,6 +8,8 @@ import ( "math/rand/v2" "sync" "time" + + "github.com/docker/docker-agent/pkg/reflectx" ) // Connector creates new sessions for a Supervisor. Implementations are @@ -187,7 +189,7 @@ func (s *Supervisor) Start(ctx context.Context) error { defer s.startMu.Unlock() s.mu.Lock() - if s.session != nil { + if !reflectx.IsNil(s.session) { s.mu.Unlock() return nil } @@ -255,7 +257,7 @@ func (s *Supervisor) Stop(ctx context.Context) error { s.tracker.Set(StateStopped) s.signalDone() - if sess == nil { + if reflectx.IsNil(sess) { return nil } if err := sess.Close(context.WithoutCancel(ctx)); err != nil && ctx.Err() == nil { @@ -287,7 +289,7 @@ func (s *Supervisor) RestartAndWait(ctx context.Context, timeout time.Duration) // Only force-close if currently usable. If the watcher already detected // the disconnect, closing now would race with tryRestart. - if state.IsUsable() && sess != nil { + if state.IsUsable() && !reflectx.IsNil(sess) { _ = sess.Close(context.WithoutCancel(ctx)) } @@ -335,7 +337,7 @@ func (s *Supervisor) watch(ctx context.Context) { s.mu.Lock() sess := s.session s.mu.Unlock() - if sess == nil { + if reflectx.IsNil(sess) { return // defensive: shouldn't happen after a successful Start. } diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 0e2600323..40693d482 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -9,6 +9,7 @@ import ( gomcp "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/upstream" ) @@ -38,7 +39,7 @@ func newRemoteClient( "allow_private_ips", allowPrivateIPs, ) - if tokenStore == nil { + if reflectx.IsNil(tokenStore) { tokenStore = NewInMemoryTokenStore() } diff --git a/pkg/tools/named.go b/pkg/tools/named.go index 5212127c0..e26a1be8f 100644 --- a/pkg/tools/named.go +++ b/pkg/tools/named.go @@ -1,5 +1,7 @@ package tools +import "github.com/docker/docker-agent/pkg/reflectx" + // Named is implemented by toolsets that carry a user-visible name. // // The convention is: @@ -31,7 +33,7 @@ func GetName(ts ToolSet) string { // The returned wrapper participates in As[T]: every capability of ts // remains reachable through the wrapper. func WithName(ts ToolSet, name string) ToolSet { - if ts == nil || name == "" { + if reflectx.IsNil(ts) || name == "" { return ts } if existing, ok := As[Named](ts); ok && existing.Name() != "" { diff --git a/pkg/tools/startable.go b/pkg/tools/startable.go index f550a4553..d25dc2812 100644 --- a/pkg/tools/startable.go +++ b/pkg/tools/startable.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "sync" + + "github.com/docker/docker-agent/pkg/reflectx" ) // Describer can be implemented by a ToolSet to provide a short, user-visible @@ -144,7 +146,7 @@ type Unwrapper interface { // prompts, _ := pp.ListPrompts(ctx) // } func As[T any](ts ToolSet) (T, bool) { - for ts != nil { + for !reflectx.IsNil(ts) { if result, ok := ts.(T); ok { return result, true } diff --git a/pkg/tui/components/editor/editor.go b/pkg/tui/components/editor/editor.go index a4e1fbcf5..08d8bd543 100644 --- a/pkg/tui/components/editor/editor.go +++ b/pkg/tui/components/editor/editor.go @@ -23,6 +23,7 @@ import ( "github.com/docker/docker-agent/pkg/history" "github.com/docker/docker-agent/pkg/paths" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tui/components/completion" "github.com/docker/docker-agent/pkg/tui/components/editor/completions" "github.com/docker/docker-agent/pkg/tui/core" @@ -467,7 +468,7 @@ func deleteLastGraphemeCluster(s string) string { // textarea value and available history entries. func (e *editor) refreshSuggestion() { // Don't overwrite completion-managed suggestions with history suggestions. - if e.currentCompletion != nil { + if !reflectx.IsNil(e.currentCompletion) { return } @@ -655,7 +656,7 @@ func (e *editor) Update(msg tea.Msg) (layout.Model, tea.Cmd) { return e, cmd case completion.SelectedMsg: - if e.currentCompletion == nil { + if reflectx.IsNil(e.currentCompletion) { return e, nil } @@ -738,7 +739,7 @@ func (e *editor) Update(msg tea.Msg) (layout.Model, tea.Cmd) { case completion.SelectionChangedMsg: // Show the selected completion item as a suggestion in the editor. e.clearSuggestion() - if msg.Value != "" && e.currentCompletion != nil { + if msg.Value != "" && !reflectx.IsNil(e.currentCompletion) { currentText := e.textarea.Value() if strings.HasPrefix(msg.Value, currentText) { e.suggestion = msg.Value[len(currentText):] @@ -1003,7 +1004,7 @@ func (e *editor) handleGraphemeBackspace() (layout.Model, tea.Cmd) { func (e *editor) updateCompletionQuery() tea.Cmd { currentWord := e.textarea.Word() - if e.currentCompletion != nil && strings.HasPrefix(currentWord, e.currentCompletion.Trigger()) { + if !reflectx.IsNil(e.currentCompletion) && strings.HasPrefix(currentWord, e.currentCompletion.Trigger()) { e.completionWord = strings.TrimPrefix(currentWord, e.currentCompletion.Trigger()) // For @ completion, start full file loading when user starts typing (if not already started) @@ -1050,7 +1051,7 @@ func (e *editor) startFullFileLoad() tea.Cmd { } } - if asyncLoader == nil { + if reflectx.IsNil(asyncLoader) { return nil } @@ -1129,7 +1130,7 @@ func (e *editor) startInitialFileLoad() tea.Cmd { } } - if asyncLoader == nil { + if reflectx.IsNil(asyncLoader) { return nil } diff --git a/pkg/tui/components/statusbar/statusbar.go b/pkg/tui/components/statusbar/statusbar.go index 5c43094ef..18f954a4c 100644 --- a/pkg/tui/components/statusbar/statusbar.go +++ b/pkg/tui/components/statusbar/statusbar.go @@ -6,6 +6,7 @@ import ( "charm.land/lipgloss/v2" "github.com/charmbracelet/x/ansi" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tui/core" "github.com/docker/docker-agent/pkg/tui/styles" ) @@ -116,7 +117,7 @@ func (s *StatusBar) rebuild() { var left string var leftW int - if s.help != nil { + if !reflectx.IsNil(s.help) { if help := s.help.Help(); help != nil { var parts []string for _, b := range help.ShortHelp() { diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index 90df59def..ed0a5c4fe 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/app" "github.com/docker/docker-agent/pkg/browser" "github.com/docker/docker-agent/pkg/evaluation" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/shellpath" "github.com/docker/docker-agent/pkg/tools" @@ -34,7 +35,7 @@ import ( func (m *appModel) handleBranchFromEdit(msg messages.BranchFromEditMsg) (tea.Model, tea.Cmd) { store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.ErrorCmd("No session store configured") } if msg.ParentSessionID == "" { @@ -104,7 +105,7 @@ func (m *appModel) handleForkSession() (tea.Model, tea.Cmd) { } store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.ErrorCmd("No session store configured") } @@ -144,7 +145,7 @@ func (m *appModel) handleForkSession() (tea.Model, tea.Cmd) { func (m *appModel) handleToggleSessionStar(sessionID string) (tea.Model, tea.Cmd) { store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.ErrorCmd("No session store configured") } @@ -197,7 +198,7 @@ func (m *appModel) handleRegenerateTitle() (tea.Model, tea.Cmd) { func (m *appModel) handleDeleteSession(sessionID string) (tea.Model, tea.Cmd) { store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.ErrorCmd("No session store configured") } diff --git a/pkg/tui/tabs.go b/pkg/tui/tabs.go index 87ea2b286..c3f75bf6f 100644 --- a/pkg/tui/tabs.go +++ b/pkg/tui/tabs.go @@ -5,6 +5,7 @@ import ( "log/slog" "github.com/docker/docker-agent/pkg/app" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/tui/service/supervisor" "github.com/docker/docker-agent/pkg/tui/service/tuistate" "github.com/docker/docker-agent/pkg/userconfig" @@ -79,7 +80,7 @@ func (m *appModel) restoreTabs( for _, saved := range savedTabs { // Validate the saved session still exists. - if sessionStore != nil && saved.SessionID != "" { + if !reflectx.IsNil(sessionStore) && saved.SessionID != "" { if _, err := sessionStore.GetSession(ctx, saved.SessionID); err != nil { slog.WarnContext(ctx, "Saved session no longer exists, removing stale tab", "session_id", saved.SessionID, "error", err) @@ -119,7 +120,7 @@ func (m *appModel) restoreTabs( } // Peek at the session title so the tab bar shows a name before lazy load. - if sessionStore != nil && saved.SessionID != "" { + if !reflectx.IsNil(sessionStore) && saved.SessionID != "" { if oldSess, err := sessionStore.GetSession(ctx, saved.SessionID); err == nil && oldSess.Title != "" { sv.SetRunnerTitle(runtimeID, oldSess.Title) } diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index 8a4810567..dc5d150a7 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -20,6 +20,7 @@ import ( "github.com/docker/docker-agent/pkg/app" "github.com/docker/docker-agent/pkg/audio/transcribe" "github.com/docker/docker-agent/pkg/history" + "github.com/docker/docker-agent/pkg/reflectx" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tui/animation" @@ -252,7 +253,7 @@ func WithCommandBuilder( // connecting to a real audio device or external API. func WithTranscriber(t Transcriber) Option { return func(m *appModel) { - if t != nil { + if !reflectx.IsNil(t) { m.transcriber = t } } @@ -466,7 +467,7 @@ func (m *appModel) Init() tea.Cmd { activeID := m.supervisor.ActiveID() if oldSessionID, ok := m.pendingRestores[activeID]; ok { delete(m.pendingRestores, activeID) - if store := m.application.SessionStore(); store != nil { + if store := m.application.SessionStore(); !reflectx.IsNil(store) { if sess, err := store.GetSession(context.Background(), oldSessionID); err == nil { _, cmd := m.replaceActiveSession(context.Background(), sess) @@ -1038,7 +1039,7 @@ func (m *appModel) handleWorkingStateChanged(msg messages.WorkingStateChangedMsg // handleOpenSessionBrowser opens the session browser dialog. func (m *appModel) handleOpenSessionBrowser() (tea.Model, tea.Cmd) { store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.InfoCmd("No session store configured") } @@ -1058,7 +1059,7 @@ func (m *appModel) handleOpenSessionBrowser() (tea.Model, tea.Cmd) { // handleLoadSession loads a saved session into the current tab (if empty) or a new tab. func (m *appModel) handleLoadSession(sessionID string) (tea.Model, tea.Cmd) { store := m.application.SessionStore() - if store == nil { + if reflectx.IsNil(store) { return m, notification.ErrorCmd("No session store configured") } @@ -1312,7 +1313,7 @@ func (m *appModel) handleSwitchTab(sessionID string) (tea.Model, tea.Cmd) { var closeBackgroundDialogCmd tea.Cmd if backgroundEvent != nil && outgoingTabID != "" && outgoingTabID != sessionID { m.supervisor.SetPendingEvent(outgoingTabID, backgroundEvent) - if backgroundDialog != nil { + if !reflectx.IsNil(backgroundDialog) { m.stashedDialogs[outgoingTabID] = stashedDialog{ dialog: backgroundDialog, event: backgroundEvent, @@ -1329,7 +1330,7 @@ func (m *appModel) handleSwitchTab(sessionID string) (tea.Model, tea.Cmd) { if oldSessionID, ok := m.pendingRestores[sessionID]; ok { delete(m.pendingRestores, sessionID) m.application = runner.App - if store := runner.App.SessionStore(); store != nil { + if store := runner.App.SessionStore(); !reflectx.IsNil(store) { if sess, err := store.GetSession(context.Background(), oldSessionID); err == nil { m.persistActiveTab(sess.ID) model, cmd := m.replaceActiveSession(context.Background(), sess) @@ -1438,7 +1439,7 @@ func (m *appModel) replayPendingEvent(sessionID string) tea.Cmd { // in-progress input is preserved. if stash, ok := m.stashedDialogs[sessionID]; ok { delete(m.stashedDialogs, sessionID) - if stash.event == pendingEvent && stash.dialog != nil { + if stash.event == pendingEvent && !reflectx.IsNil(stash.dialog) { return core.CmdHandler(dialog.OpenDialogMsg{ Model: stash.dialog, OriginatingEvent: pendingEvent,