diff --git a/README.md b/README.md index 48412a3..bd736c2 100644 --- a/README.md +++ b/README.md @@ -115,28 +115,70 @@ mimir analyze --hint # Hint for faster patch planning ## MCP Tools -AI agents get access to 7 tools via MCP: +AI agents get access to 12 tools via MCP: | Tool | Description | Example | |---|---|---| | `query` | Hybrid search (BM25 + vector) | "Find all auth-related processes" | | `context` | 360-degree symbol view | "Show handleRequest definition and callers" | +| `find_referencing` | Who directly calls/imports/extends a symbol (1-hop) | "What calls UserService.GetUser?" | +| `symbol_coordinates` | Exact file path + line range for a symbol | "Where is ProcessOrder defined?" | +| `get_symbols_overview` | All top-level symbols in a file, sorted by line | "What's exported from store.go?" | | `impact` | Blast radius analysis | "What breaks if I change UserService?" | | `detect_changes` | Analyze uncommitted git changes | "What processes did my commit affect?" | | `rename` | Plan coordinated multi-file rename | "Rename AuthController to SessionController" | | `cypher` | Raw graph queries | "Find unused exported functions" | -| `list_repos` | List indexed repositories | — | +| `list_repos` | List all registered repositories | — | +| `query_repo` | Run a read-only tool against a different repo | "Query symbols in repo B from repo A" | ### Recommended Workflow 1. **Discovery**: Use `query()` for semantic search 2. **Deep Dive**: Use `context()` to understand a symbol -3. **Before Editing**: Always run `impact()` first +3. **Find callers**: Use `find_referencing()` for a lightweight 1-hop caller list +4. **Understand a file**: Use `get_symbols_overview()` to see what's defined in a file +5. **Before Editing**: Run `symbol_coordinates()` to get the exact location, then `impact()` for blast radius +6. **Cross-repo**: Use `list_repos()` to discover repos, then `query_repo()` to query them For detailed usage, see [docs/guide.md](docs/guide.md). 
--- +## Testing MCP Tools Locally + +Use [MCP Inspector](https://github.com/modelcontextprotocol/inspector) — the official browser-based UI for exploring and calling MCP tools interactively: + +```bash +# Build first +make build + +# Index the current repo +./bin/mimir analyze . + +# Launch MCP Inspector (requires Node.js) +npx @modelcontextprotocol/inspector ./bin/mimir mcp +``` + +Then open **http://localhost:5173** in your browser. You can browse all 12 tools, fill in arguments via a form, and see the JSON responses in real time. + +Alternatively, test via raw stdin: + +```bash +# List all tools +echo '{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}' | ./bin/mimir mcp + +# Call a tool +echo '{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"list_repos","arguments":{}}}' | ./bin/mimir mcp + +# Query a symbol in the current repo +echo '{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"query","arguments":{"query":"handleRequest"}}}' | ./bin/mimir mcp + +# Cross-repo query: search for a symbol in a different repo +echo '{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"query_repo","arguments":{"tool_name":"query","arguments":{"query":"auth"},"target_repo":"other-project","current_repo":"git-mimir"}}}' | ./bin/mimir mcp +``` + +--- + ## Auto-Analyze Features Running `mimir analyze` automatically sets up: diff --git a/cmd/mimir/analyze.go b/cmd/mimir/analyze.go index 13f2032..0f68a43 100644 --- a/cmd/mimir/analyze.go +++ b/cmd/mimir/analyze.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/schollz/progressbar/v3" + "github.com/spf13/cobra" "github.com/thuongh2/git-mimir/internal/cluster" "github.com/thuongh2/git-mimir/internal/daemon" @@ -26,17 +28,17 @@ import ( ) var ( - analyzeForce bool - analyzeSkipEmbeds bool - analyzeResolution float64 - analyzeName string - analyzeIncremental bool - analyzeHint string - analyzeRepo string - analyzeSkipHooks bool - analyzeSkipSkills bool - analyzeSkipDaemon bool - 
analyzeQuiet bool + analyzeForce bool + analyzeSkipEmbeds bool + analyzeResolution float64 + analyzeName string + analyzeIncremental bool + analyzeHint string + analyzeRepo string + analyzeSkipHooks bool + analyzeSkipSkills bool + analyzeSkipDaemon bool + analyzeQuiet bool ) func init() { @@ -342,29 +344,57 @@ func fullIndex(ctx context.Context, repoPath string, s *store.Store) ([]graph.Fi concurrency := runtime.GOMAXPROCS(0) start := time.Now() - fileCh := walker.WalkRepo(repoPath, concurrency) + // Walk synchronously so we know the total before parsing starts. + allFiles, err := walker.CollectFiles(repoPath) + if err != nil { + return nil, 0, fmt.Errorf("walk: %w", err) + } + // Pre-filter to supported extensions so total matches actual parser output. + files := allFiles[:0] + for _, f := range allFiles { + if parser.LangForExt(f.Ext) != "" { + files = append(files, f) + } + } + total := len(files) - // Count files while forwarding to parser pool - counted := make(chan walker.FileInfo, concurrency*16) - fileCount := 0 + var bar *progressbar.ProgressBar + if !analyzeQuiet { + bar = progressbar.NewOptions(total, + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionSetDescription("Parsing"), + progressbar.OptionSetWidth(30), + progressbar.OptionShowCount(), + progressbar.OptionShowIts(), + progressbar.OptionSetItsString("files"), + progressbar.OptionOnCompletion(func() { fmt.Fprint(os.Stderr, "\n") }), + ) + } + + // Fan collected files into a channel for the parser pool. 
+ fileCh := make(chan walker.FileInfo, concurrency*16) go func() { - for f := range fileCh { - counted <- f - fileCount++ + for _, f := range files { + fileCh <- f } - close(counted) + close(fileCh) }() pool := parser.NewParserPool(concurrency) - symsCh := pool.Run(ctx, counted) + symsCh := pool.Run(ctx, fileCh) var allSymbols []graph.FileSymbols for fs := range symsCh { allSymbols = append(allSymbols, fs) + if bar != nil { + _ = bar.Add(1) + } } - fmt.Printf("Walked %s in %s\n", repoPath, time.Since(start).Round(time.Millisecond)) - return allSymbols, fileCount, nil + if !analyzeQuiet { + fmt.Fprintf(os.Stderr, "Parsed %d files (%s)\n", len(allSymbols), time.Since(start).Round(time.Millisecond)) + } + return allSymbols, len(allSymbols), nil } func isIncrementalCandidate(s *store.Store) bool { diff --git a/docs/features.md b/docs/features.md index 63eb1ab..915b0dd 100644 --- a/docs/features.md +++ b/docs/features.md @@ -6,13 +6,17 @@ Mimir provides a suite of features designed specifically for AI agents through t | Tool | Description | |---|---| -| `query` | Hybrid search (Keyword + Vector) that returns results grouped by logical processes. | -| `context` | Returns a "360-degree" view of a specific symbol, including its definition, callers, called functions, and related documentation. | -| `impact` | Performs blast-radius analysis to show what might break if a specific function or class is modified. | +| `query` | Hybrid search (BM25 + vector + centrality) that returns results grouped by logical processes. | +| `context` | Returns a "360-degree" view of a specific symbol: definition, outgoing calls, all incoming edges, and cluster membership. | +| `find_referencing` | Lightweight 1-hop lookup of all symbols that directly call, import, extend, or implement a given symbol. Accepts an optional `edge_types` filter (`CALLS`, `IMPORTS`, `EXTENDS`, `IMPLEMENTS`, `MEMBER_OF`) and `min_confidence` threshold. Faster than `context` when you only need the caller list. 
| +| `symbol_coordinates` | Returns the exact `file_path`, `start_line`, and `end_line` for every definition of a symbol. Always run this before editing a symbol body so the agent knows the precise location to modify. | +| `get_symbols_overview` | Returns all top-level symbols (functions, classes, interfaces, variables, constants) defined in a given file, sorted by line number. Excludes nested methods and members that belong to a class. Accepts `include_private` (default: `true`) to filter to exported symbols only. Use this to understand the structure of a file before editing. | +| `impact` | Performs blast-radius analysis (BFS up to configurable depth) to show what might break if a specific function or class is modified. Supports `direction` (upstream/downstream) and `min_confidence`. | | `detect_changes` | Analyzes uncommitted git changes and identifies which high-level processes are affected. | -| `rename` | Plans a coordinated rename across the graph and the filesystem. | -| `cypher` | Allows advanced users to run graph queries using a subset of the Cypher query language. | +| `rename` | Plans a coordinated rename across the graph and the filesystem. Use `dry_run=true` first. | +| `cypher` | Allows advanced users to run graph queries using a subset of the Cypher query language (`MATCH` only). | | `list_repos` | Lists all indexed repositories available for querying. | +| `query_repo` | Executes a read-only tool against a different indexed repository without switching context. Accepts `tool_name` (one of: `query`, `context`, `find_referencing`, `symbol_coordinates`, `get_symbols_overview`, `impact`), `arguments`, `target_repo`, and optional `current_repo` for correlation. The response includes a `meta` field with `queried_repo`, `current_repo`, and `tool_used`. 
| ## Key Features diff --git a/docs/guide.md b/docs/guide.md index 81a703c..a59ee26 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -28,9 +28,23 @@ Before writing code, run `mimir serve` and open `http://localhost:7842`. When chatting with Claude, you can prompt it to use Mimir effectively: * **"Query the graph"**: Instead of letting Claude search files with `grep`, tell it: *"Use Mimir to find all processes related to 'user authentication'."* Mimir will return logical flows, not just lines of code. +* **"Find callers (lightweight)"**: Before touching a function, ask: *"Use `find_referencing` to show me every symbol that calls `ProcessOrder`."* This is faster than `context` when you only need the caller list — it does a single 1-hop edge lookup. You can also filter by edge type: *"find all symbols that **import** the logger package."* +* **"See what's in a file"**: Before editing a file, ask: *"Run `get_symbols_overview` on `internal/store/store.go`."* This returns every top-level symbol sorted by line number — instantly orient yourself without reading the whole file. Use `include_private: false` to focus on the public API. +* **"Get exact location before editing"**: Before asking Claude to modify a function body, say: *"Run `symbol_coordinates` on `ProcessOrder` first."* Claude will know the exact file and lines to replace without reading the whole file. * **"Check Impact"**: Before asking Claude to refactor, say: *"Run a Mimir impact analysis on the `AuthService` interface."* This prevents Claude from making breaking changes in distant parts of the repo. * **"Trace the process"**: If a bug happens in a specific flow, say: *"Mimir, show me the full process trace starting from the `login` endpoint."* +### Recommended Tool Order for Editing a Symbol + +``` +1. query() → find candidate symbols +2. find_referencing() → who calls it? (decide scope of change) +3. get_symbols_overview() → understand the file's public surface +4. 
symbol_coordinates() → get file_path + line range +5. impact() → blast radius if it's a public API +6. make the edit → agent replaces lines start_line–end_line +``` + ## 5. Advanced: Raw Graph Queries (Cypher) For power users, Mimir supports a subset of Cypher. You can ask Claude: *"Run a Cypher query to find all exported functions in the 'internal/store' package that have no incoming call edges."* diff --git a/docs/usage.md b/docs/usage.md index ef269a8..b18494e 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -16,12 +16,21 @@ When using Claude Code to develop or query Mimir, use the following guidelines: - **Always start with `query`**: If you don't know where to look, use `mcp__mimir__query` to find the relevant code paths. - **Use `context` for deep dives**: Once you find a symbol, use `mcp__mimir__context` to see its full relationship tree. +- **Use `find_referencing` for lightweight caller lookup**: When you only need to know who calls or imports a symbol (not the full 360° view), `mcp__mimir__find_referencing` is faster. Supports `edge_types` filter — e.g., `["CALLS"]` or `["IMPORTS"]`. +- **Always run `symbol_coordinates` before editing**: Before replacing a function or class body, call `mcp__mimir__symbol_coordinates` to get the exact `file_path`, `start_line`, and `end_line`. This avoids reading entire files just to locate the target. - **Check `impact` before refactoring**: If you are about to change a core interface, run `mcp__mimir__impact` to see the blast radius. - **Analyze on save**: You can configure a git hook to run `mimir analyze` automatically on commit to keep the knowledge graph fresh. -## 3. Example Workflow +## 3. Example Workflow — Understanding a Symbol 1. **User**: "How does the resolver handle interfaces?" 2. **Claude**: Calls `mcp__mimir__query(query="interface resolution")`. 3. **Mimir**: Returns relevant symbols and the "Resolve" process. 4. **Claude**: Calls `mcp__mimir__context(name="Resolve")` to see the logic. 5. 
**Claude**: Explains the logic to the user using the high-fidelity graph data. + +## 4. Example Workflow — Editing a Symbol Safely +1. **User**: "Refactor `ProcessOrder` to accept a context parameter." +2. **Claude**: Calls `mcp__mimir__find_referencing(name="ProcessOrder", edge_types=["CALLS"])` → sees 4 direct callers. +3. **Claude**: Calls `mcp__mimir__symbol_coordinates(name="ProcessOrder")` → gets `{file_path: "internal/order/service.go", start_line: 42, end_line: 67}`. +4. **Claude**: Calls `mcp__mimir__impact(target="ProcessOrder")` → confirms blast radius before changing the signature. +5. **Claude**: Edits lines 42–67 in `internal/order/service.go` and updates the 4 call sites. diff --git a/go.mod b/go.mod index 32e7bad..d6e08b6 100644 --- a/go.mod +++ b/go.mod @@ -34,9 +34,12 @@ require ( github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-pointer v0.0.1 // indirect + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pjbgf/sha1cd v0.3.2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/schollz/progressbar/v3 v3.19.0 // indirect github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.3.1 // indirect github.com/spf13/pflag v1.0.9 // indirect @@ -44,6 +47,7 @@ require ( golang.org/x/crypto v0.45.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.42.0 // indirect + golang.org/x/term v0.41.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index 8057b1c..4bb4c03 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod 
h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0= github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= @@ -72,11 +74,15 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs= +github.com/schollz/progressbar/v3 v3.19.0 h1:Ea18xuIRQXLAUidVDox3AbwfUhD0/1IvohyTutOIFoc= +github.com/schollz/progressbar/v3 v3.19.0/go.mod 
h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= @@ -147,6 +153,8 @@ golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= diff --git a/internal/store/query.go b/internal/store/query.go index 1d11536..3153c90 100644 --- a/internal/store/query.go +++ b/internal/store/query.go @@ -86,6 +86,40 @@ func (s *Store) QueryByFile(filePath string) ([]graph.Node, error) { return nodes, err } +// QueryTopLevelByFile returns all top-level symbols in a file — excludes nested +// methods and any symbol that is the source of a MEMBER_OF edge. +// Matches by exact path OR by suffix (handles relative vs absolute path mismatches). +func (s *Store) QueryTopLevelByFile(filePath string) ([]graph.Node, error) { + var nodes []graph.Node + err := s.Read(func(db *sql.DB) error { + rows, err := db.Query(` + SELECT uid, name, kind, file_path, start_line, end_line, exported, package_path, COALESCE(cluster_id,'') + FROM nodes + WHERE (file_path = ? 
OR file_path LIKE ('%/' || ?)) + AND kind != 'Method' + AND NOT EXISTS ( + SELECT 1 FROM edges WHERE from_uid = uid AND type = 'MEMBER_OF' + ) + ORDER BY start_line`, filePath, filePath) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var n graph.Node + var exp int + if err := rows.Scan(&n.UID, &n.Name, &n.Kind, &n.FilePath, + &n.StartLine, &n.EndLine, &exp, &n.PackagePath, &n.ClusterID); err != nil { + return err + } + n.Exported = exp == 1 + nodes = append(nodes, n) + } + return rows.Err() + }) + return nodes, err +} + // QueryNodeByUID fetches a single node by UID. func (s *Store) QueryNodeByUID(uid string) (*graph.Node, error) { var n graph.Node diff --git a/internal/walker/walker.go b/internal/walker/walker.go index ff050b5..9185f01 100644 --- a/internal/walker/walker.go +++ b/internal/walker/walker.go @@ -123,3 +123,52 @@ func WalkRepo(root string, concurrency int) <-chan FileInfo { return out } + +// CollectFiles walks the repository synchronously and returns all discovered +// files as a slice. The total count is known upfront, making it ideal for +// driving progress bars. Uses the same gitignore and skip rules as WalkRepo. +func CollectFiles(root string) ([]FileInfo, error) { + // Load .gitignore if present. + var gi *ignore.GitIgnore + gitignorePath := filepath.Join(root, ".gitignore") + if _, err := os.Stat(gitignorePath); err == nil { + gi, _ = ignore.CompileIgnoreFile(gitignorePath) + } + + var files []FileInfo + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil // skip unreadable entries + } + name := d.Name() + if d.IsDir() { + if skipDirs[name] { + return filepath.SkipDir + } + return nil + } + // Apply gitignore. + rel, _ := filepath.Rel(root, path) + if gi != nil && gi.MatchesPath(rel) { + return nil + } + // Skip unwanted suffixes. 
+ for _, suf := range skipSuffixes { + if strings.HasSuffix(name, suf) { + return nil + } + } + info, err := d.Info() + if err != nil { + return nil + } + files = append(files, FileInfo{ + Path: path, + Ext: filepath.Ext(name), + Size: info.Size(), + ModTime: info.ModTime(), + }) + return nil + }) + return files, err +} diff --git a/mcp/tools.go b/mcp/tools.go index 1611caf..07f1895 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -1,9 +1,12 @@ package mcp import ( + "bufio" "context" "encoding/json" "fmt" + "os" + "path/filepath" "strings" "github.com/thuongh2/git-mimir/internal/incremental" @@ -66,6 +69,31 @@ func (t *Tools) ListTools() map[string]interface{} { Description: "List all indexed repositories.", InputSchema: schema(`{"type":"object","properties":{}}`), }, + { + Name: "find_referencing", + Description: "Find all symbols that directly reference (call, import, extend, implement) a given symbol. Lighter than impact — returns 1-hop inbound edges only.", + InputSchema: schema(`{"type":"object","properties":{"name":{"type":"string"},"edge_types":{"type":"array","items":{"type":"string"},"description":"Filter by edge type. One of: CALLS, IMPORTS, EXTENDS, IMPLEMENTS, MEMBER_OF. Defaults to all."},"min_confidence":{"type":"number"},"repo":{"type":"string"}},"required":["name"]}`), + }, + { + Name: "symbol_coordinates", + Description: "Return the exact file path and line range for a symbol. Use before editing — gives the precise location to replace.", + InputSchema: schema(`{"type":"object","properties":{"name":{"type":"string"},"repo":{"type":"string"}},"required":["name"]}`), + }, + { + Name: "get_symbols_overview", + Description: "Gets an overview of all top-level symbols defined in a given file, sorted by line number. Excludes nested methods and members. 
Use to understand file structure before editing.", + InputSchema: schema(`{"type":"object","properties":{"file_path":{"type":"string"},"include_private":{"type":"boolean","description":"Include non-exported symbols. Defaults to true."},"repo":{"type":"string","description":"Name of the indexed repository to query. Required."}},"required":["file_path","repo"]}`), + }, + { + Name: "find_symbol_body", + Description: "Returns the exact source code body of a symbol (function, method, class) by name, including file_path, start_line, end_line, and the full implementation text. Use when you see a function name in logs or stack traces — fetches only the relevant lines instead of reading the whole file.", + InputSchema: schema(`{"type":"object","properties":{"name":{"type":"string","description":"Function, method, or class name to look up"},"repo":{"type":"string"}},"required":["name"]}`), + }, + { + Name: "query_repo", + Description: "Execute a whitelisted read-only tool against a different indexed repository. Pass tool_name, arguments, target_repo, and optional current_repo. Allowed tools: query, context, find_referencing, symbol_coordinates, get_symbols_overview, impact.", + InputSchema: schema(`{"type":"object","properties":{"tool_name":{"type":"string","description":"Tool to invoke on the target repo. 
One of: query, context, find_referencing, symbol_coordinates, get_symbols_overview, impact."},"arguments":{"type":"object","description":"Arguments to pass to the tool."},"target_repo":{"type":"string","description":"Name of the target repository to query."},"current_repo":{"type":"string","description":"Name of the current repository (optional, for context)."}},"required":["tool_name","arguments","target_repo"]}`), + }, }, } } @@ -97,6 +125,16 @@ func (t *Tools) Call(ctx context.Context, params json.RawMessage) Response { return t.rename(ctx, p.Arguments) case "cypher": return t.cypher(ctx, p.Arguments) + case "find_referencing": + return t.findReferencing(ctx, p.Arguments) + case "symbol_coordinates": + return t.symbolCoordinates(ctx, p.Arguments) + case "get_symbols_overview": + return t.getSymbolsOverview(ctx, p.Arguments) + case "query_repo": + return t.queryRepo(ctx, p.Arguments) + case "find_symbol_body": + return t.findSymbolBody(ctx, p.Arguments) default: return errResp(ErrMethodNotFound, "unknown tool: "+p.Name) } @@ -375,6 +413,364 @@ func (t *Tools) cypher(ctx context.Context, args json.RawMessage) Response { }) } +func (t *Tools) findReferencing(ctx context.Context, args json.RawMessage) Response { + var input struct { + Name string `json:"name"` + EdgeTypes []string `json:"edge_types"` + MinConfidence *float64 `json:"min_confidence"` + Repo *string `json:"repo"` + } + if err := json.Unmarshal(args, &input); err != nil { + return errResp(ErrInvalidParams, err.Error()) + } + + logDebug("findReferencing: name=%q edgeTypes=%v minConf=%v", input.Name, input.EdgeTypes, input.MinConfidence) + + s, err := t.openStore(input.Repo) + if err != nil { + logError("findReferencing.openStore", err) + return errResp(ErrInternal, err.Error()) + } + defer s.Close() + + nodes, err := s.QuerySymbol(input.Name) + if err != nil { + logError("findReferencing.QuerySymbol", err) + return errResp(ErrInternal, err.Error()) + } + if len(nodes) == 0 { + logDebug("findReferencing: 
symbol not found: %s", input.Name) + return toolResult(map[string]interface{}{"symbol": nil, "message": "symbol not found"}) + } + + target := nodes[0] + + minConf := 0.0 + if input.MinConfidence != nil { + minConf = *input.MinConfidence + } + + // Build edge-type filter set. + edgeTypeSet := map[string]bool{} + for _, et := range input.EdgeTypes { + edgeTypeSet[strings.ToUpper(et)] = true + } + + inEdges, err := s.QueryEdgesTo(target.UID) + if err != nil { + logError("findReferencing.QueryEdgesTo", err) + return errResp(ErrInternal, err.Error()) + } + + type refEntry struct { + Symbol interface{} `json:"symbol"` + EdgeType string `json:"edge_type"` + Confidence float64 `json:"confidence"` + } + + refs := make([]refEntry, 0, len(inEdges)) + for _, e := range inEdges { + if len(edgeTypeSet) > 0 && !edgeTypeSet[e.Type] { + continue + } + if e.Confidence < minConf { + continue + } + caller, err := s.QueryNodeByUID(e.FromUID) + if err != nil || caller == nil { + continue + } + refs = append(refs, refEntry{ + Symbol: caller, + EdgeType: e.Type, + Confidence: e.Confidence, + }) + } + + logDebug("findReferencing: %d references found for %s", len(refs), input.Name) + return toolResult(map[string]interface{}{ + "target": target.Name, + "total": len(refs), + "references": refs, + }) +} + +func (t *Tools) symbolCoordinates(ctx context.Context, args json.RawMessage) Response { + var input struct { + Name string `json:"name"` + Repo *string `json:"repo"` + } + if err := json.Unmarshal(args, &input); err != nil { + return errResp(ErrInvalidParams, err.Error()) + } + + logDebug("symbolCoordinates: name=%q repo=%v", input.Name, input.Repo) + + s, err := t.openStore(input.Repo) + if err != nil { + logError("symbolCoordinates.openStore", err) + return errResp(ErrInternal, err.Error()) + } + defer s.Close() + + nodes, err := s.QuerySymbol(input.Name) + if err != nil { + logError("symbolCoordinates.QuerySymbol", err) + return errResp(ErrInternal, err.Error()) + } + if len(nodes) == 0 
{ + logDebug("symbolCoordinates: symbol not found: %s", input.Name) + return toolResult(map[string]interface{}{"symbol": nil, "message": "symbol not found"}) + } + + results := make([]map[string]interface{}, 0, len(nodes)) + for _, n := range nodes { + results = append(results, map[string]interface{}{ + "name": n.Name, + "kind": n.Kind, + "file_path": n.FilePath, + "start_line": n.StartLine, + "end_line": n.EndLine, + "package_path": n.PackagePath, + "exported": n.Exported, + }) + } + + logDebug("symbolCoordinates: found %d locations for %s", len(results), input.Name) + return toolResult(map[string]interface{}{ + "name": input.Name, + "matches": results, + }) +} + +func (t *Tools) getSymbolsOverview(ctx context.Context, args json.RawMessage) Response { + var input struct { + FilePath string `json:"file_path"` + IncludePrivate *bool `json:"include_private"` + Repo *string `json:"repo"` + } + if err := json.Unmarshal(args, &input); err != nil { + return errResp(ErrInvalidParams, err.Error()) + } + + logDebug("getSymbolsOverview: file_path=%q repo=%v", input.FilePath, input.Repo) + + s, err := t.openStore(input.Repo) + if err != nil { + logError("getSymbolsOverview.openStore", err) + return errResp(ErrInternal, err.Error()) + } + defer s.Close() + + // Nodes are stored with absolute paths. Normalize a relative input path + // using the repo root stored in index_meta so agents can pass relative paths. + filePath := input.FilePath + if !filepath.IsAbs(filePath) { + if repoRoot, err := s.GetMeta("repo_path"); err == nil && repoRoot != "" { + filePath = filepath.Join(repoRoot, filePath) + } + } + + nodes, err := s.QueryTopLevelByFile(filePath) + if err != nil { + logError("getSymbolsOverview.QueryTopLevelByFile", err) + return errResp(ErrInternal, err.Error()) + } + + // By default include all symbols; filter private only when explicitly disabled. 
+ includePrivate := true + if input.IncludePrivate != nil { + includePrivate = *input.IncludePrivate + } + + results := make([]map[string]interface{}, 0, len(nodes)) + for _, n := range nodes { + if !includePrivate && !n.Exported { + continue + } + results = append(results, map[string]interface{}{ + "name": n.Name, + "kind": n.Kind, + "start_line": n.StartLine, + "end_line": n.EndLine, + "exported": n.Exported, + "package_path": n.PackagePath, + }) + } + + logDebug("getSymbolsOverview: %d top-level symbols in %s", len(results), filePath) + return toolResult(map[string]interface{}{ + "file_path": input.FilePath, + "total": len(results), + "symbols": results, + }) +} + +// queryRepoWhitelist contains tools safe to execute cross-repo. +var queryRepoWhitelist = map[string]bool{ + "query": true, + "context": true, + "find_referencing": true, + "symbol_coordinates": true, + "get_symbols_overview": true, + "impact": true, +} + +func (t *Tools) queryRepo(ctx context.Context, args json.RawMessage) Response { + var input struct { + ToolName string `json:"tool_name"` + Arguments json.RawMessage `json:"arguments"` + TargetRepo string `json:"target_repo"` + CurrentRepo *string `json:"current_repo"` + } + if err := json.Unmarshal(args, &input); err != nil { + return errResp(ErrInvalidParams, err.Error()) + } + if !queryRepoWhitelist[input.ToolName] { + return errResp(ErrInvalidParams, fmt.Sprintf("tool %q is not allowed in query_repo; allowed: query, context, find_referencing, symbol_coordinates, get_symbols_overview, impact", input.ToolName)) + } + + // Inject target_repo into arguments. + var argMap map[string]interface{} + if len(input.Arguments) > 0 { + if err := json.Unmarshal(input.Arguments, &argMap); err != nil { + return errResp(ErrInvalidParams, "invalid arguments: "+err.Error()) + } + } else { + argMap = map[string]interface{}{} + } + argMap["repo"] = input.TargetRepo + injectedArgs, _ := json.Marshal(argMap) + + // Delegate to existing dispatcher via Call(). 
+ delegateParams, _ := json.Marshal(map[string]interface{}{ + "name": input.ToolName, + "arguments": json.RawMessage(injectedArgs), + }) + resp := t.Call(ctx, delegateParams) + if resp.Error != nil { + return resp + } + + // Attach meta inside the content text so the agent can read it. + currentRepo := "" + if input.CurrentRepo != nil { + currentRepo = *input.CurrentRepo + } + meta := map[string]interface{}{ + "queried_repo": input.TargetRepo, + "current_repo": currentRepo, + "tool_used": input.ToolName, + } + if resultMap, ok := resp.Result.(map[string]interface{}); ok { + if content, ok := resultMap["content"].([]map[string]interface{}); ok && len(content) > 0 { + if text, ok := content[0]["text"].(string); ok { + var innerData map[string]interface{} + if err := json.Unmarshal([]byte(text), &innerData); err == nil { + innerData["meta"] = meta + if b, err := json.Marshal(innerData); err == nil { + content[0]["text"] = string(b) + } + } + } + } + } + logDebug("queryRepo: tool=%s target=%s", input.ToolName, input.TargetRepo) + return resp +} + +func (t *Tools) findSymbolBody(ctx context.Context, args json.RawMessage) Response { + var input struct { + Name string `json:"name"` + Repo *string `json:"repo"` + } + if err := json.Unmarshal(args, &input); err != nil { + return errResp(ErrInvalidParams, err.Error()) + } + + logDebug("findSymbolBody: name=%q repo=%v", input.Name, input.Repo) + + s, err := t.openStore(input.Repo) + if err != nil { + logError("findSymbolBody.openStore", err) + return errResp(ErrInternal, err.Error()) + } + defer s.Close() + + nodes, err := s.QuerySymbol(input.Name) + if err != nil { + logError("findSymbolBody.QuerySymbol", err) + return errResp(ErrInternal, err.Error()) + } + if len(nodes) == 0 { + logDebug("findSymbolBody: symbol not found: %s", input.Name) + return toolResult(map[string]interface{}{"symbol": nil, "message": "symbol not found"}) + } + + type bodyResult struct { + Name string `json:"name"` + Kind string `json:"kind"` + 
FilePath string `json:"file_path"` + StartLine uint `json:"start_line"` + EndLine uint `json:"end_line"` + PackagePath string `json:"package_path"` + Exported bool `json:"exported"` + Body string `json:"body"` + } + + results := make([]bodyResult, 0, len(nodes)) + for _, n := range nodes { + body, readErr := readLines(n.FilePath, n.StartLine, n.EndLine) + if readErr != nil { + logDebug("findSymbolBody: read %s: %v", n.FilePath, readErr) + body = "" + } + results = append(results, bodyResult{ + Name: n.Name, + Kind: n.Kind, + FilePath: n.FilePath, + StartLine: n.StartLine, + EndLine: n.EndLine, + PackagePath: n.PackagePath, + Exported: n.Exported, + Body: body, + }) + } + + logDebug("findSymbolBody: %d matches for %s", len(results), input.Name) + return toolResult(map[string]interface{}{ + "name": input.Name, + "matches": results, + }) +} + +// readLines reads lines [start, end] (1-based, inclusive) from a file. +func readLines(filePath string, start, end uint) (string, error) { + f, err := os.Open(filePath) + if err != nil { + return "", fmt.Errorf("open %s: %w", filePath, err) + } + defer f.Close() + + var sb strings.Builder + scanner := bufio.NewScanner(f) + var lineNum uint = 1 + for scanner.Scan() { + if lineNum >= start && lineNum <= end { + sb.WriteString(scanner.Text()) + sb.WriteByte('\n') + } + if lineNum > end { + break + } + lineNum++ + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("scan %s: %w", filePath, err) + } + return sb.String(), nil +} + // openStore opens the store for the given repo name. 
func (t *Tools) openStore(repoName *string) (*store.Store, error) { name := t.resolveRepoName(repoName) @@ -560,8 +956,8 @@ func translateEdgeWhereClause(where string) string { func translatePropRefs(s string) string { replacements := map[string]string{ - "filePath": "file_path", - "fileName": "file_path", + "filePath": "file_path", + "fileName": "file_path", "startLine": "start_line", "endLine": "end_line", } diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 7e476d0..7934a5b 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -3,12 +3,47 @@ package mcp_test import ( "context" "encoding/json" + "os" + "strings" "testing" + "github.com/thuongh2/git-mimir/internal/graph" "github.com/thuongh2/git-mimir/internal/registry" + "github.com/thuongh2/git-mimir/internal/store" "github.com/thuongh2/git-mimir/mcp" ) +// seedTestRepo creates an isolated store in a temp HOME, seeds it with data, +// and returns a Tools instance pointing at that repo. +func seedTestRepo(t *testing.T, nodes []graph.Node, edges []graph.Edge) *mcp.Tools { + t.Helper() + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + dbPath, err := registry.DBPath("testrepo") + if err != nil { + t.Fatalf("DBPath: %v", err) + } + s, err := store.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore: %v", err) + } + if len(nodes) > 0 { + if err := s.BatchUpsertNodes(nodes); err != nil { + t.Fatalf("BatchUpsertNodes: %v", err) + } + } + if len(edges) > 0 { + if err := s.BatchUpsertEdges(edges); err != nil { + t.Fatalf("BatchUpsertEdges: %v", err) + } + } + s.Close() + + reg := ®istry.Registry{Repos: []registry.RepoInfo{{Name: "testrepo", Path: tmpHome}}} + return mcp.NewTools(reg) +} + func TestTools_ListTools(t *testing.T) { reg := ®istry.Registry{} tools := mcp.NewTools(reg) @@ -24,8 +59,8 @@ func TestTools_ListTools(t *testing.T) { t.Fatal("tools is not []ToolDefinition") } - if len(defs) != 7 { - t.Errorf("expected 7 tools, got %d", len(defs)) + if len(defs) != 12 { + t.Errorf("expected 12 
tools, got %d", len(defs)) } names := map[string]bool{} @@ -34,7 +69,7 @@ func TestTools_ListTools(t *testing.T) { t.Logf(" tool: %s", d.Name) } - for _, want := range []string{"query", "context", "impact", "detect_changes", "rename", "cypher", "list_repos"} { + for _, want := range []string{"query", "context", "impact", "detect_changes", "rename", "cypher", "list_repos", "find_referencing", "symbol_coordinates", "get_symbols_overview", "query_repo", "find_symbol_body"} { if !names[want] { t.Errorf("missing tool: %s", want) } @@ -60,3 +95,535 @@ func TestTools_ListRepos(t *testing.T) { t.Fatal("list_repos returned nil result") } } + +func TestTools_FindReferencing(t *testing.T) { + caller := graph.Node{UID: "n1", Name: "callMe", Kind: "Function", FilePath: "a.go", StartLine: 1, EndLine: 5} + callee := graph.Node{UID: "n2", Name: "target", Kind: "Function", FilePath: "b.go", StartLine: 10, EndLine: 20} + edge := graph.Edge{FromUID: "n1", ToUID: "n2", Type: graph.EdgeCalls, Confidence: 0.9} + + tools := seedTestRepo(t, []graph.Node{caller, callee}, []graph.Edge{edge}) + + cases := []struct { + name string + symbolName string + edgeTypes []string + minConf *float64 + wantTotal int + wantNotFound bool + }{ + { + name: "finds direct caller", + symbolName: "target", + wantTotal: 1, + }, + { + name: "filters by edge type match", + symbolName: "target", + edgeTypes: []string{"CALLS"}, + wantTotal: 1, + }, + { + name: "filters by edge type no match", + symbolName: "target", + edgeTypes: []string{"IMPORTS"}, + wantTotal: 0, + }, + { + name: "symbol not found", + symbolName: "nonexistent", + wantNotFound: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + input := map[string]any{"name": tc.symbolName} + if len(tc.edgeTypes) > 0 { + input["edge_types"] = tc.edgeTypes + } + if tc.minConf != nil { + input["min_confidence"] = *tc.minConf + } + args, _ := json.Marshal(map[string]any{ + "name": "find_referencing", + "arguments": input, + }) + resp 
:= tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("find_referencing error: %s", resp.Error.Message) + } + + content := extractText(t, resp) + t.Logf("find_referencing response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if tc.wantNotFound { + if result["symbol"] != nil { + t.Errorf("expected symbol=nil, got %v", result["symbol"]) + } + return + } + + total, ok := result["total"].(float64) + if !ok { + t.Fatalf("total missing or wrong type: %v", result["total"]) + } + if int(total) != tc.wantTotal { + t.Errorf("total = %d, want %d", int(total), tc.wantTotal) + } + }) + } +} + +func TestTools_SymbolCoordinates(t *testing.T) { + fn := graph.Node{UID: "n1", Name: "MyFunc", Kind: "Function", FilePath: "pkg/foo.go", StartLine: 42, EndLine: 67, Exported: true, PackagePath: "pkg"} + tools := seedTestRepo(t, []graph.Node{fn}, nil) + + cases := []struct { + name string + symbolName string + wantNotFound bool + wantStartLine uint + wantEndLine uint + }{ + { + name: "returns coordinates for known symbol", + symbolName: "MyFunc", + wantStartLine: 42, + wantEndLine: 67, + }, + { + name: "symbol not found", + symbolName: "UnknownFunc", + wantNotFound: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + args, _ := json.Marshal(map[string]any{ + "name": "symbol_coordinates", + "arguments": map[string]any{"name": tc.symbolName}, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("symbol_coordinates error: %s", resp.Error.Message) + } + + content := extractText(t, resp) + t.Logf("symbol_coordinates response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if tc.wantNotFound { + if result["symbol"] != nil { + t.Errorf("expected symbol=nil, got %v", 
result["symbol"]) + } + return + } + + matches, ok := result["matches"].([]any) + if !ok || len(matches) == 0 { + t.Fatalf("matches missing or empty: %v", result["matches"]) + } + m := matches[0].(map[string]any) + + if got := m["file_path"]; got != fn.FilePath { + t.Errorf("file_path = %v, want %v", got, fn.FilePath) + } + if got := m["start_line"].(float64); uint(got) != tc.wantStartLine { + t.Errorf("start_line = %v, want %v", got, tc.wantStartLine) + } + if got := m["end_line"].(float64); uint(got) != tc.wantEndLine { + t.Errorf("end_line = %v, want %v", got, tc.wantEndLine) + } + }) + } +} + +func TestTools_GetSymbolsOverview(t *testing.T) { + class := graph.Node{UID: "c1", Name: "MyClass", Kind: "Class", FilePath: "src/foo.go", StartLine: 1, EndLine: 50, Exported: true, PackagePath: "src"} + method := graph.Node{UID: "m1", Name: "doWork", Kind: "Method", FilePath: "src/foo.go", StartLine: 10, EndLine: 20, Exported: false, PackagePath: "src"} + topFn := graph.Node{UID: "f1", Name: "HelperFunc", Kind: "Function", FilePath: "src/foo.go", StartLine: 55, EndLine: 60, Exported: true, PackagePath: "src"} + privateFn := graph.Node{UID: "f2", Name: "internalFn", Kind: "Function", FilePath: "src/foo.go", StartLine: 62, EndLine: 70, Exported: false, PackagePath: "src"} + memberEdge := graph.Edge{FromUID: "m1", ToUID: "c1", Type: graph.EdgeMemberOf, Confidence: 1.0} + + tools := seedTestRepo(t, + []graph.Node{class, method, topFn, privateFn}, + []graph.Edge{memberEdge}, + ) + + cases := []struct { + name string + filePath string + includePrivate *bool + wantTotal int + wantNames []string + wantAbsent []string + }{ + { + name: "all top-level symbols (default includes private)", + filePath: "src/foo.go", + wantTotal: 3, // class + topFn + privateFn (method excluded via MEMBER_OF) + wantNames: []string{"MyClass", "HelperFunc", "internalFn"}, + wantAbsent: []string{"doWork"}, + }, + { + name: "only exported symbols", + filePath: "src/foo.go", + includePrivate: 
boolPtr(false), + wantTotal: 2, + wantNames: []string{"MyClass", "HelperFunc"}, + wantAbsent: []string{"internalFn", "doWork"}, + }, + { + name: "empty file returns zero symbols", + filePath: "nonexistent/file.go", + wantTotal: 0, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + input := map[string]any{"file_path": tc.filePath} + if tc.includePrivate != nil { + input["include_private"] = *tc.includePrivate + } + args, _ := json.Marshal(map[string]any{ + "name": "get_symbols_overview", + "arguments": input, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("get_symbols_overview error: %s", resp.Error.Message) + } + + content := extractText(t, resp) + t.Logf("get_symbols_overview response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + total, ok := result["total"].(float64) + if !ok { + t.Fatalf("total missing or wrong type: %v", result["total"]) + } + if int(total) != tc.wantTotal { + t.Errorf("total = %d, want %d", int(total), tc.wantTotal) + } + + syms, _ := result["symbols"].([]any) + gotNames := map[string]bool{} + for _, sym := range syms { + if m, ok := sym.(map[string]any); ok { + gotNames[m["name"].(string)] = true + } + } + for _, want := range tc.wantNames { + if !gotNames[want] { + t.Errorf("expected symbol %q in results, got %v", want, gotNames) + } + } + for _, absent := range tc.wantAbsent { + if gotNames[absent] { + t.Errorf("symbol %q should not appear in results", absent) + } + } + }) + } +} + +func boolPtr(b bool) *bool { return &b } + +// TestTools_GetSymbolsOverview_LikeSuffix verifies that a relative file_path +// matches nodes stored with absolute paths via the LIKE suffix fallback in SQL, +// even when there is NO repo_path stored in index_meta (handler normalization +// cannot help). This is the primary regression test for the empty-results bug. 
+func TestTools_GetSymbolsOverview_LikeSuffix(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + dbPath, err := registry.DBPath("testrepo") + if err != nil { + t.Fatalf("DBPath: %v", err) + } + s, err := store.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore: %v", err) + } + // Store nodes with absolute paths, no repo_path meta set. + nodes := []graph.Node{ + {UID: "b1", Name: "FlashcardHandler", Kind: "Class", FilePath: "/abs/repo/internal/handler/flashcard_handler.go", StartLine: 1, EndLine: 40, Exported: true, PackagePath: "handler"}, + {UID: "b2", Name: "NewFlashcardHandler", Kind: "Function", FilePath: "/abs/repo/internal/handler/flashcard_handler.go", StartLine: 42, EndLine: 50, Exported: true, PackagePath: "handler"}, + } + if err := s.BatchUpsertNodes(nodes); err != nil { + t.Fatalf("BatchUpsertNodes: %v", err) + } + // Intentionally no SetMeta("repo_path", ...) — handler normalization won't fire. + s.Close() + + reg := ®istry.Registry{Repos: []registry.RepoInfo{{Name: "testrepo", Path: tmpHome}}} + tools := mcp.NewTools(reg) + + args, _ := json.Marshal(map[string]any{ + "name": "get_symbols_overview", + "arguments": map[string]any{"file_path": "internal/handler/flashcard_handler.go"}, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("get_symbols_overview error: %s", resp.Error.Message) + } + + content := extractText(t, resp) + t.Logf("response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + total, _ := result["total"].(float64) + if int(total) != 2 { + t.Errorf("total = %d, want 2 (LIKE suffix should match absolute stored paths)", int(total)) + } + syms, _ := result["symbols"].([]any) + found := map[string]bool{} + for _, sym := range syms { + if m, ok := sym.(map[string]any); ok { + found[m["name"].(string)] = true + } + } + if !found["FlashcardHandler"] || 
!found["NewFlashcardHandler"] { + t.Errorf("expected FlashcardHandler and NewFlashcardHandler, got %v", found) + } +} + +// TestTools_GetSymbolsOverview_RelativePath verifies that the handler correctly +// resolves a relative file_path against the repo_path stored in index_meta. +// Nodes are stored with absolute paths (as the indexer does) and the agent +// supplies a relative path — this should still return results. +func TestTools_GetSymbolsOverview_RelativePath(t *testing.T) { + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + const repoRoot = "/abs/fake/root" + + dbPath, err := registry.DBPath("testrepo") + if err != nil { + t.Fatalf("DBPath: %v", err) + } + s, err := store.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore: %v", err) + } + nodes := []graph.Node{ + {UID: "a1", Name: "Handler", Kind: "Function", FilePath: repoRoot + "/internal/handler/file.go", StartLine: 1, EndLine: 20, Exported: true, PackagePath: "handler"}, + {UID: "a2", Name: "helper", Kind: "Function", FilePath: repoRoot + "/internal/handler/file.go", StartLine: 22, EndLine: 30, Exported: false, PackagePath: "handler"}, + } + if err := s.BatchUpsertNodes(nodes); err != nil { + t.Fatalf("BatchUpsertNodes: %v", err) + } + if err := s.SetMeta("repo_path", repoRoot); err != nil { + t.Fatalf("SetMeta repo_path: %v", err) + } + s.Close() + + reg := ®istry.Registry{Repos: []registry.RepoInfo{{Name: "testrepo", Path: tmpHome}}} + tools := mcp.NewTools(reg) + + args, _ := json.Marshal(map[string]any{ + "name": "get_symbols_overview", + "arguments": map[string]any{"file_path": "internal/handler/file.go"}, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("get_symbols_overview error: %s", resp.Error.Message) + } + + content := extractText(t, resp) + t.Logf("response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + total, _ := 
result["total"].(float64) + if int(total) != 2 { + t.Errorf("total = %d, want 2 (relative path should resolve against repo_path)", int(total)) + } + + syms, _ := result["symbols"].([]any) + found := map[string]bool{} + for _, sym := range syms { + if m, ok := sym.(map[string]any); ok { + found[m["name"].(string)] = true + } + } + if !found["Handler"] || !found["helper"] { + t.Errorf("expected Handler and helper in results, got %v", found) + } +} + +func TestTools_QueryRepo(t *testing.T) { + fn := graph.Node{UID: "n1", Name: "Greet", Kind: "Function", FilePath: "src/a.go", StartLine: 1, EndLine: 5, Exported: true} + tools := seedTestRepo(t, []graph.Node{fn}, nil) + + t.Run("delegates query tool with meta", func(t *testing.T) { + args, _ := json.Marshal(map[string]any{ + "name": "query_repo", + "arguments": map[string]any{ + "tool_name": "query", + "arguments": map[string]any{"query": "Greet"}, + "target_repo": "testrepo", + "current_repo": "otherrepo", + }, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("query_repo error: %s", resp.Error.Message) + } + content := extractText(t, resp) + t.Logf("query_repo response: %s", content) + + var result map[string]any + if err := json.Unmarshal([]byte(content), &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + meta, ok := result["meta"].(map[string]any) + if !ok { + t.Fatalf("meta missing from response: %v", result) + } + if meta["queried_repo"] != "testrepo" { + t.Errorf("meta.queried_repo = %v, want testrepo", meta["queried_repo"]) + } + if meta["current_repo"] != "otherrepo" { + t.Errorf("meta.current_repo = %v, want otherrepo", meta["current_repo"]) + } + if meta["tool_used"] != "query" { + t.Errorf("meta.tool_used = %v, want query", meta["tool_used"]) + } + }) + + t.Run("rejects non-whitelisted tool", func(t *testing.T) { + args, _ := json.Marshal(map[string]any{ + "name": "query_repo", + "arguments": map[string]any{ + "tool_name": "rename", + "arguments": 
map[string]any{}, + "target_repo": "testrepo", + }, + }) + resp := tools.Call(context.Background(), args) + if resp.Error == nil { + t.Fatal("expected error for non-whitelisted tool, got nil") + } + t.Logf("error (expected): %s", resp.Error.Message) + }) + + t.Run("rejects query_repo itself (recursion protection)", func(t *testing.T) { + args, _ := json.Marshal(map[string]any{ + "name": "query_repo", + "arguments": map[string]any{ + "tool_name": "query_repo", + "arguments": map[string]any{}, + "target_repo": "testrepo", + }, + }) + resp := tools.Call(context.Background(), args) + if resp.Error == nil { + t.Fatal("expected error for query_repo recursion, got nil") + } + t.Logf("error (expected): %s", resp.Error.Message) + }) +} + +func TestTools_FindSymbolBody(t *testing.T) { + fn := graph.Node{ + UID: "n1", Name: "ProcessOrder", Kind: "Function", + FilePath: "src/order.go", StartLine: 3, EndLine: 5, + Exported: true, PackagePath: "src", + } + tools := seedTestRepo(t, []graph.Node{fn}, nil) + + // Write a real file matching the stored coordinates. + dir := t.TempDir() + srcDir := dir + "/src" + _ = os.MkdirAll(srcDir, 0755) + filePath := srcDir + "/order.go" + content := "package src\n\n// ProcessOrder handles orders\nfunc ProcessOrder() {}\n// end\n" + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("WriteFile: %v", err) + } + + // Update the node's FilePath to the real temp file. 
+ fn.FilePath = filePath + dbPath, err := registry.DBPath("testrepo") + if err != nil { + t.Fatalf("DBPath: %v", err) + } + s2, err := store.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore: %v", err) + } + if err := s2.BatchUpsertNodes([]graph.Node{fn}); err != nil { + t.Fatalf("BatchUpsertNodes: %v", err) + } + s2.Close() + + args, _ := json.Marshal(map[string]any{ + "name": "find_symbol_body", + "arguments": map[string]any{"name": "ProcessOrder"}, + }) + resp := tools.Call(context.Background(), args) + if resp.Error != nil { + t.Fatalf("find_symbol_body error: %s", resp.Error.Message) + } + + text := extractText(t, resp) + t.Logf("find_symbol_body response: %s", text) + + var result map[string]any + if err := json.Unmarshal([]byte(text), &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + matches, _ := result["matches"].([]any) + if len(matches) == 0 { + t.Fatal("expected at least one match") + } + m := matches[0].(map[string]any) + if m["name"] != "ProcessOrder" { + t.Errorf("name = %v, want ProcessOrder", m["name"]) + } + body, _ := m["body"].(string) + if body == "" { + t.Error("body is empty") + } + if !strings.Contains(body, "ProcessOrder") { + t.Errorf("body does not contain 'ProcessOrder': %q", body) + } +} + +// extractText pulls the text content from a tool response. +func extractText(t *testing.T, resp mcp.Response) string { + t.Helper() + resultMap, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatalf("resp.Result wrong type: %T", resp.Result) + } + content, _ := resultMap["content"].([]map[string]interface{}) + if len(content) == 0 { + t.Fatal("content is empty") + } + text, _ := content[0]["text"].(string) + return text +}