From d5d5c9734debd4fe06e666c826f6d929fe040f8f Mon Sep 17 00:00:00 2001 From: MollyAI Date: Sun, 17 May 2026 16:26:41 -0400 Subject: [PATCH] Add MCP progress notifications --- CHANGELOG.md | 1 + README.md | 2 +- docs/ARCHITECTURE.md | 2 + docs/mcp-tools.md | 11 + .../recallforge-memory-mcp-roadmap.md | 8 +- src/recallforge/search.py | 12 +- src/recallforge/server.py | 330 ++++++++++++++++-- tests/test_mcp_progress.py | 206 +++++++++++ 8 files changed, 533 insertions(+), 39 deletions(-) create mode 100644 tests/test_mcp_progress.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 39a7bd3..7e55bd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ All notable changes to RecallForge will be documented in this file. - Added staged background reindex promotion so document, video, audio, and conversation replacements stay hidden until their parent/child memory batches are complete. - Added index-version-aware query caching for repeated text/media embeddings and generated expansion branches. +- Added MCP progress notifications for long-running search, ingest, batch, memory write, and FTS rebuild tool calls when clients provide a progress token. - Added deterministic memory graph enrichment with entity/relation side tables and new `memory_graph_entities` / `memory_graph_related` MCP tools. - Replaced the tiny UAT video clips with compact episodic-memory fixtures, richer transcript sidecars, related artifact metadata, and regression coverage for the video corpus. - Added `memory_add_conversation` so conversation threads ingest as canonical parent memories with turn-level child memories and standard memory rollups. diff --git a/README.md b/README.md index db52195..b969bd9 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ Run over HTTP/SSE: recallforge serve --http --host 127.0.0.1 --port 7433 --mode embed ``` -RecallForge now exposes **26 MCP tools** across search, ingest, memory graph navigation, collection admin, and runtime config. HTTP/SSE mode also exposes `/health`, `/sse`, and `/messages/`. +RecallForge now exposes **26 MCP tools** across search, ingest, memory graph navigation, collection admin, and runtime config. HTTP/SSE mode also exposes `/health`, `/sse`, and `/messages/`. Long-running tools emit MCP `notifications/progress` when the client supplies a request `_meta.progressToken`, so compatible HTTP/SSE clients can show live progress for ingest, search, batch, memory writes, and FTS rebuilds. See [docs/mcp-tools.md](docs/mcp-tools.md) for the full tool reference. diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 6b209e9..d55f59c 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -216,6 +216,7 @@ Tools: 26 MCP tools across search, ingest, memory, memory graph, collection admi Transport: stdio (default) or HTTP/SSE (`/health`, `/sse`, `/messages/`) Startup: backend.warm_up() for predictable latency Signals: SIGTERM/SIGINT graceful shutdown +Progress: request `_meta.progressToken` enables `notifications/progress` during long-running tool calls ``` Key runtime details: @@ -223,6 +224,7 @@ Key runtime details: - Blocking tool work is routed through a bounded async semaphore to avoid overloading local model/runtime resources - HTTP mode requires the optional `server` extra (`starlette` + `uvicorn`) - Runtime-safe config changes (`mode`, `collection`, `rerank_top_k`, `caption_media`, model IDs) are exposed through `get_config` / `set_config` +- Progress notifications are best-effort and preserve stable final response JSON. `search_batch` reports per-query completion before returning the final merged results; `batch` reports per-operation completion. ## Storage Layout diff --git a/docs/mcp-tools.md b/docs/mcp-tools.md index ad915a5..766e2a5 100644 --- a/docs/mcp-tools.md +++ b/docs/mcp-tools.md @@ -21,6 +21,17 @@ HTTP mode also exposes: - `/sse` - `/messages/` +## Progress Notifications + +RecallForge supports MCP progress notifications for long-running tool calls. When a client includes `_meta.progressToken` in a request, compatible transports receive `notifications/progress` events with numeric progress, optional total, and a human-readable status message. + +Progress is best-effort and does not change the final tool response shape. It currently covers: + +- search and explain phases +- vector and full-text search phases +- `search_batch` per-query completion updates before the final merged result +- `ingest`, individual index/memory writes, `batch`, and `rebuild_fts` + Example MCP client config (Claude Desktop): ```json diff --git a/docs/research/recallforge-memory-mcp-roadmap.md b/docs/research/recallforge-memory-mcp-roadmap.md index 8a5e44f..ad2bfdb 100644 --- a/docs/research/recallforge-memory-mcp-roadmap.md +++ b/docs/research/recallforge-memory-mcp-roadmap.md @@ -135,9 +135,10 @@ Goal: - Prove RecallForge as a memory MCP, not just a benchmark pipeline. Current Linear fit: -- `REC-160` -- `REC-153` - `REC-33` + +Shipped Linear work: +- `REC-153` - `REC-61` What this phase delivers: @@ -145,6 +146,7 @@ What this phase delivers: - explanation quality checks - latency and RSS budget enforcement - real episodic corpora coverage +- MCP progress notifications for long-running search, ingest, batch, and rebuild workflows - alpha and beta validation with real workflows Why this comes last: @@ -156,7 +158,7 @@ Why this comes last: - Keep `Retrieval and Ranking` for cheap broad retrieval work like `REC-169`, `REC-148`, `REC-72`, `REC-71`, `REC-146` - Add a milestone such as `Memory Policy and Enrichment` for `REC-84`, `REC-83`, `REC-75`, `REC-76`, `REC-78` - Keep `Research Queue` for gated expensive-stage work like `REC-130`, `REC-115`, `REC-147`, `REC-168` -- Keep `Benchmark Integrity` and `Launch and Distribution` for `REC-160`, `REC-153`, `REC-33`, `REC-61` +- Keep `Benchmark Integrity` and `Launch and Distribution` for `REC-33` and any future public validation work ## Architecture Principle diff --git a/src/recallforge/search.py b/src/recallforge/search.py index 37b8697..906ab28 100644 --- a/src/recallforge/search.py +++ b/src/recallforge/search.py @@ -23,7 +23,7 @@ import time from dataclasses import dataclass, field, replace from hashlib import sha256 -from typing import List, Dict, Any, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from .backends.base import ModelBackend from .cache import EmbeddingCache @@ -1769,6 +1769,7 @@ def search_batch( profile: Optional[str] = None, max_workers: int = 4, rrf_k: int = 60, + progress_callback: Optional[Callable[[int, int, int], None]] = None, ) -> List[BatchSearchResult]: """ Run multiple search queries in parallel and merge results using RRF. @@ -1789,6 +1790,8 @@ def search_batch( profile: Optional profile namespace filter max_workers: Maximum parallel threads rrf_k: RRF fusion constant + progress_callback: Optional callback invoked as each query branch + completes with (completed_count, total_count, branch_result_count) Returns: List of BatchSearchResult objects, sorted by best merged score @@ -1845,6 +1848,7 @@ def run_single_query(q: BatchQuery) -> List[tuple]: # Run all queries in parallel all_results: List[List[tuple]] = [[] for _ in batch_queries] + completed_queries = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_idx = { executor.submit(run_single_query, q): i @@ -1857,6 +1861,12 @@ def run_single_query(q: BatchQuery) -> List[tuple]: except Exception as e: logger.error("Batch query %d failed: %s", idx, e) all_results[idx] = [] + completed_queries += 1 + if progress_callback is not None: + try: + progress_callback(completed_queries, len(batch_queries), len(all_results[idx])) + except Exception as exc: + logger.debug("search_batch progress callback failed: %s", exc) # Merge results using RRF with best-score-wins merged: Dict[str, Dict[str, Any]] = {} diff --git a/src/recallforge/server.py b/src/recallforge/server.py index 3b9ec0c..1c85b44 100644 --- a/src/recallforge/server.py +++ b/src/recallforge/server.py @@ -22,7 +22,7 @@ import time from pathlib import Path from urllib.parse import unquote -from typing import Optional, Callable, TypeVar +from typing import Awaitable, Callable, Optional, TypeVar from mcp.server import Server from mcp.server.stdio import stdio_server @@ -66,6 +66,83 @@ def trace_log(operation: str, **kwargs) -> None: ) _T = TypeVar("_T") +_ProgressSend = Callable[[float, Optional[float], Optional[str]], Awaitable[None]] + + +class _ToolProgressReporter: + """Best-effort MCP progress notification helper for long-running tools.""" + + def __init__(self, send: Optional[_ProgressSend] = None): + self._send = send + + @property + def enabled(self) -> bool: + return self._send is not None + + async def report( + self, + progress: float, + total: Optional[float] = None, + message: Optional[str] = None, + ) -> None: + if self._send is None: + return + try: + await self._send(float(progress), None if total is None else float(total), message) + except Exception as exc: + logger.debug("Failed to send MCP progress notification: %s", exc) + + +def _progress_reporter_for_request(server: Server) -> _ToolProgressReporter: + """Create a progress reporter for the active MCP request, if requested.""" + try: + request_context = server.request_context + except LookupError: + return _ToolProgressReporter() + + meta = getattr(request_context, "meta", None) + progress_token = getattr(meta, "progressToken", None) + session = getattr(request_context, "session", None) + if progress_token is None or session is None: + return _ToolProgressReporter() + + send_progress = getattr(session, "send_progress_notification", None) + if not callable(send_progress): + return _ToolProgressReporter() + + request_id = str(getattr(request_context, "request_id", "")) or None + + async def _send(progress: float, total: Optional[float], message: Optional[str]) -> None: + await send_progress( + progress_token, + progress, + total=total, + message=message, + related_request_id=request_id, + ) + + return _ToolProgressReporter(_send) + + +def _schedule_progress_from_thread( + loop: asyncio.AbstractEventLoop, + progress: _ToolProgressReporter, + value: float, + total: Optional[float], + message: str, +) -> None: + """Schedule a progress notification from worker-thread callbacks.""" + if not progress.enabled: + return + future = asyncio.run_coroutine_threadsafe(progress.report(value, total, message), loop) + + def _log_failure(done_future): + try: + done_future.result() + except Exception as exc: + logger.debug("Scheduled MCP progress notification failed: %s", exc) + + future.add_done_callback(_log_failure) def _get_tool_semaphore() -> asyncio.Semaphore: @@ -826,10 +903,11 @@ async def list_tools() -> list[Tool]: @server.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent | ImageContent | EmbeddedResource]: """Execute a tool.""" + progress = _progress_reporter_for_request(server) try: if name == "batch": - return await _handle_batch(arguments, backend, storage, _mutable_config) - return await _dispatch_tool(name, arguments, backend, storage, _mutable_config) + return await _handle_batch(arguments, backend, storage, _mutable_config, progress=progress) + return await _dispatch_tool(name, arguments, backend, storage, _mutable_config, progress=progress) except Exception as e: return _error_response("INTERNAL_ERROR", str(e), {"exception_type": type(e).__name__}) @@ -845,8 +923,10 @@ async def _dispatch_tool( backend, storage, mutable_config: Optional[dict] = None, + progress: Optional[_ToolProgressReporter] = None, ) -> list[TextContent | ImageContent | EmbeddedResource]: """Route a single tool call to the appropriate handler.""" + progress = progress or _ToolProgressReporter() # Apply mutable config defaults: if a handler expects collection/max_file_size_mb # and the caller didn't explicitly provide them, use the mutable config values. if mutable_config: @@ -859,29 +939,29 @@ async def _dispatch_tool( if "caption_media" not in arguments and "caption_media" in mutable_config: arguments.setdefault("caption_media", mutable_config["caption_media"]) if name == "search": - return await _handle_search(arguments, backend, storage) + return await _handle_search(arguments, backend, storage, progress=progress) elif name == "explain_results": - return await _handle_explain_results(arguments, backend, storage) + return await _handle_explain_results(arguments, backend, storage, progress=progress) elif name == "search_fts": - return await _handle_search_fts(arguments, storage) + return await _handle_search_fts(arguments, storage, progress=progress) elif name == "search_vec": - return await _handle_search_vec(arguments, backend, storage) + return await _handle_search_vec(arguments, backend, storage, progress=progress) elif name == "ingest": - return await _handle_ingest(arguments, backend, storage) + return await _handle_ingest(arguments, backend, storage, progress=progress) elif name == "index_document": - return await _handle_index_document(arguments, backend, storage) + return await _handle_index_document(arguments, backend, storage, progress=progress) elif name == "index_image": - return await _handle_index_image(arguments, backend, storage) + return await _handle_index_image(arguments, backend, storage, progress=progress) elif name == "index_audio": - return await _handle_index_audio(arguments, backend, storage) + return await _handle_index_audio(arguments, backend, storage, progress=progress) elif name == "memory_add": - return await _handle_memory_add(arguments, backend, storage) + return await _handle_memory_add(arguments, backend, storage, progress=progress) elif name == "memory_add_conversation": - return await _handle_memory_add_conversation(arguments, backend, storage) + return await _handle_memory_add_conversation(arguments, backend, storage, progress=progress) elif name == "memory_update": - return await _handle_memory_update(arguments, backend, storage) + return await _handle_memory_update(arguments, backend, storage, progress=progress) elif name == "memory_delete": - return await _handle_memory_delete(arguments, storage) + return await _handle_memory_delete(arguments, storage, progress=progress) elif name == "memory_get": return await _handle_memory_get(arguments, storage) elif name == "memory_graph_entities": @@ -893,7 +973,7 @@ async def _dispatch_tool( elif name == "status": return await _handle_status(backend, storage) elif name == "rebuild_fts": - return await _handle_rebuild_fts(storage) + return await _handle_rebuild_fts(storage, progress=progress) elif name == "list_collections": return await _handle_list_collections(arguments, storage) elif name == "list_namespaces": @@ -907,7 +987,7 @@ async def _dispatch_tool( elif name == "set_config": return await _handle_set_config(arguments, backend, storage, mutable_config if mutable_config is not None else {}) elif name == "search_batch": - return await _handle_search_batch(arguments, backend, storage) + return await _handle_search_batch(arguments, backend, storage, progress=progress) else: raise ValueError(f"Unknown tool: {name}") @@ -917,8 +997,10 @@ async def _handle_batch( backend, storage, mutable_config: Optional[dict] = None, + progress: Optional[_ToolProgressReporter] = None, ) -> list[TextContent]: """Execute multiple RecallForge operations in a single call.""" + progress = progress or _ToolProgressReporter() operations = arguments.get("operations") if not isinstance(operations, list): return [TextContent(type="text", text=json.dumps({"error": "operations must be a list"}))] @@ -936,10 +1018,19 @@ async def _handle_batch( batch_results = [] succeeded = 0 failed = 0 + total = len(operations) + progress_total = max(total, 1) + + await progress.report(0, progress_total, f"Starting batch with {total} operation(s)") for i, op in enumerate(operations): tool_name = op.get("tool", "") op_args = op.get("arguments", {}) + await progress.report( + i, + progress_total, + f"Starting batch operation {i + 1}/{total}: {tool_name or 'unknown'}", + ) # Reject nested batch calls if tool_name == "batch": @@ -950,10 +1041,21 @@ async def _handle_batch( "result": {"error": "Nested batch operations are not allowed"}, }) failed += 1 + await progress.report( + i + 1, + progress_total, + f"Finished batch operation {i + 1}/{total}: {tool_name} (error)", + ) continue try: - content_list = await _dispatch_tool(tool_name, op_args, backend, storage, mutable_config) + content_list = await _dispatch_tool( + tool_name, + op_args, + backend, + storage, + mutable_config, + ) # Unwrap the first TextContent result into a parsed dict when possible if content_list and hasattr(content_list[0], "text"): try: @@ -969,6 +1071,11 @@ async def _handle_batch( "result": result_payload, }) succeeded += 1 + await progress.report( + i + 1, + progress_total, + f"Finished batch operation {i + 1}/{total}: {tool_name} (success)", + ) except Exception as exc: batch_results.append({ "index": i, @@ -977,6 +1084,11 @@ async def _handle_batch( "result": {"error": str(exc)}, }) failed += 1 + await progress.report( + i + 1, + progress_total, + f"Finished batch operation {i + 1}/{total}: {tool_name} (error)", + ) output = { "batch_results": batch_results, @@ -984,11 +1096,18 @@ async def _handle_batch( "succeeded": succeeded, "failed": failed, } + await progress.report(progress_total, progress_total, f"Batch complete: {succeeded} succeeded, {failed} failed") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_search(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_search( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle hybrid search.""" + progress = progress or _ToolProgressReporter() query, image_path, video_path, file_path, input_error = _resolve_query_inputs(arguments) limit = arguments.get("limit", 10) collection = arguments.get("collection") @@ -1004,10 +1123,12 @@ async def _handle_search(arguments: dict, backend, storage) -> list[TextContent] if input_error: return _error_response("INVALID_INPUT", input_error) + await progress.report(0, 4, "Validating search input") if file_path: query, image_path, video_path, file_error = await _run_blocking(_resolve_file_query_input, file_path) if file_error: return _error_response("INVALID_INPUT", file_error, {"file_path": file_path}) + await progress.report(1, 4, "Resolved search input") trace_log("search_start", query=(query or image_path or video_path or file_path or "")[:50], limit=limit, collection=collection, content_type=content_type, user_id=user_id, session_id=session_id, project_id=project_id, profile=profile, intent=intent, rerank_top_k=rerank_top_k, expand=expand) @@ -1028,13 +1149,17 @@ async def _handle_search(arguments: dict, backend, storage) -> list[TextContent] ) if image_path: + await progress.report(2, 4, "Running image search") results = await _run_blocking(searcher.search_image, image_path) elif video_path: + await progress.report(2, 4, "Running video search") results = await _run_blocking(searcher.search_video, video_path) else: + await progress.report(2, 4, "Running hybrid search") results = await _run_blocking(searcher.search, query) trace_log("search_done", query=(query or image_path or video_path or file_path or "")[:50], count=len(results)) + await progress.report(3, 4, f"Search retrieved {len(results)} result(s)") output = { "query": query, @@ -1068,11 +1193,18 @@ async def _handle_search(arguments: dict, backend, storage) -> list[TextContent] ], } + await progress.report(4, 4, "Search response ready") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_explain_results(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_explain_results( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle explain_results - returns detailed scoring provenance for each result.""" + progress = progress or _ToolProgressReporter() query, image_path, video_path, file_path, input_error = _resolve_query_inputs(arguments) limit = arguments.get("limit", 10) collection = arguments.get("collection") @@ -1088,10 +1220,12 @@ async def _handle_explain_results(arguments: dict, backend, storage) -> list[Tex if input_error: return _error_response("INVALID_INPUT", input_error) + await progress.report(0, 4, "Validating explanation input") if file_path: query, image_path, video_path, file_error = await _run_blocking(_resolve_file_query_input, file_path) if file_error: return _error_response("INVALID_INPUT", file_error, {"file_path": file_path}) + await progress.report(1, 4, "Resolved explanation input") trace_log("explain_results_start", query=(query or image_path or video_path or file_path or "")[:50], limit=limit, collection=collection, content_type=content_type, user_id=user_id, session_id=session_id, project_id=project_id, profile=profile, intent=intent, rerank_top_k=rerank_top_k, expand=expand) @@ -1112,13 +1246,17 @@ async def _handle_explain_results(arguments: dict, backend, storage) -> list[Tex ) if image_path: + await progress.report(2, 4, "Running image search for explanation") results = await _run_blocking(searcher.search_image, image_path) elif video_path: + await progress.report(2, 4, "Running video search for explanation") results = await _run_blocking(searcher.search_video, video_path) else: + await progress.report(2, 4, "Running hybrid search for explanation") results = await _run_blocking(searcher.search, query) trace_log("explain_results_done", query=(query or image_path or video_path or file_path or "")[:50], count=len(results)) + await progress.report(3, 4, f"Building explanations for {len(results)} result(s)") # Build detailed explanation for each result explained_results = [] @@ -1182,11 +1320,17 @@ async def _handle_explain_results(arguments: dict, backend, storage) -> list[Tex "results": explained_results, } + await progress.report(4, 4, "Explanation response ready") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]: +async def _handle_search_fts( + arguments: dict, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle FTS search.""" + progress = progress or _ToolProgressReporter() query = arguments.get("query", "") limit = arguments.get("limit", 20) collection = arguments.get("collection") @@ -1202,6 +1346,8 @@ async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]: if not query: return _error_response("INVALID_INPUT", "Query is required") + await progress.report(0, 3, "Validating full-text search input") + await progress.report(1, 3, "Running full-text search") results = await _run_blocking( storage.search_fts, query=query, @@ -1215,6 +1361,7 @@ async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]: ) trace_log("search_fts_done", query=query[:50], count=len(results)) + await progress.report(2, 3, f"Full-text search retrieved {len(results)} result(s)") output = { "query": query, @@ -1235,11 +1382,18 @@ async def _handle_search_fts(arguments: dict, storage) -> list[TextContent]: ], } + await progress.report(3, 3, "Full-text search response ready") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_search_vec( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle vector search.""" + progress = progress or _ToolProgressReporter() query, image_path, video_path, file_path, input_error = _resolve_query_inputs(arguments) limit = arguments.get("limit", 20) collection = arguments.get("collection") @@ -1252,10 +1406,12 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont if input_error: return _error_response("INVALID_INPUT", input_error) + await progress.report(0, 4, "Validating vector search input") if file_path: query, image_path, video_path, file_error = await _run_blocking(_resolve_file_query_input, file_path) if file_error: return _error_response("INVALID_INPUT", file_error, {"file_path": file_path}) + await progress.report(1, 4, "Embedding vector search input") trace_log("search_vec_start", query=(query or image_path or video_path or file_path or "")[:50], limit=limit, collection=collection, content_type=content_type, user_id=user_id, session_id=session_id, project_id=project_id, profile=profile) @@ -1270,6 +1426,7 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont else: vector = await _run_blocking(backend.embed_text, query) + await progress.report(2, 4, "Running vector search") results = await _run_blocking( storage.search_vec, vector=vector.tolist() if hasattr(vector, 'tolist') else list(vector), @@ -1283,6 +1440,7 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont ) trace_log("search_vec_done", query=(query or image_path or video_path or file_path or "")[:50], count=len(results)) + await progress.report(3, 4, f"Vector search retrieved {len(results)} result(s)") output = { "query": query, @@ -1306,13 +1464,20 @@ async def _handle_search_vec(arguments: dict, backend, storage) -> list[TextCont ], } + await progress.report(4, 4, "Vector search response ready") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_search_batch( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle parallel batch search with RRF merge.""" from .search import BatchQuery, search_batch + progress = progress or _ToolProgressReporter() queries_raw = arguments.get("queries", []) limit = arguments.get("limit", 10) collection = arguments.get("collection") @@ -1354,6 +1519,18 @@ async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextCo else: return _error_response("INVALID_INPUT", f"queries[{i}] must be a string or object") + await progress.report(0, len(queries), f"Starting batch search with {len(queries)} query item(s)") + loop = asyncio.get_running_loop() + + def _query_progress(completed: int, total: int, result_count: int) -> None: + _schedule_progress_from_thread( + loop, + progress, + completed, + total, + f"Batch search completed query {completed}/{total}; last branch returned {result_count} candidate(s)", + ) + results = await _run_blocking( search_batch, queries=queries, @@ -1366,9 +1543,11 @@ async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextCo session_id=session_id, project_id=project_id, profile=profile, + progress_callback=_query_progress if progress.enabled else None, ) trace_log("search_batch_done", query_count=len(queries), count=len(results)) + await progress.report(len(queries), len(queries), f"Batch search merged {len(results)} result(s)") output = { "query_count": len(queries), @@ -1398,11 +1577,18 @@ async def _handle_search_batch(arguments: dict, backend, storage) -> list[TextCo ], } + await progress.report(len(queries), len(queries), "Batch search response ready") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_ingest(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_ingest( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle unified ingest.""" + progress = progress or _ToolProgressReporter() text = arguments.get("text") path = arguments.get("path") file_path = arguments.get("file_path") @@ -1422,6 +1608,7 @@ async def _handle_ingest(arguments: dict, backend, storage) -> list[TextContent] trace_log("ingest_start", collection=collection, text=bool(text), file_path=file_path, folder_path=folder_path, user_id=user_id, session_id=session_id, project_id=project_id, profile=profile) + await progress.report(0, 2, "Starting ingest") output = await _run_blocking( storage.ingest, collection=collection, @@ -1446,11 +1633,31 @@ async def _handle_ingest(arguments: dict, backend, storage) -> list[TextContent] ) trace_log("ingest_done", collection=collection, indexed_text=output.get("indexed_text", 0), indexed_images=output.get("indexed_images", 0)) + indexed_total = sum( + int(output.get(key, 0) or 0) + for key in ( + "indexed_text", + "indexed_images", + "indexed_videos", + "indexed_audio", + "indexed_documents", + "indexed_document_sections", + "indexed_video_frames", + "indexed_video_transcripts", + ) + ) + await progress.report(2, 2, f"Ingest complete; indexed {indexed_total} item(s)") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_index_document(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_index_document( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle document indexing.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") text = arguments.get("text", "") collection = arguments.get("collection", "default") @@ -1459,7 +1666,8 @@ async def _handle_index_document(arguments: dict, backend, storage) -> list[Text if not path or not text: return _error_response("INVALID_INPUT", "path and text are required") - + + await progress.report(0, 1, f"Indexing document {path}") content_hash = await _run_blocking( storage.index_document, path=path, @@ -1470,6 +1678,7 @@ async def _handle_index_document(arguments: dict, backend, storage) -> list[Text ) trace_log("index_document_done", path=path, hash=content_hash[:8]) + await progress.report(1, 1, f"Indexed document {path}") output = { "success": True, @@ -1481,8 +1690,14 @@ async def _handle_index_document(arguments: dict, backend, storage) -> list[Text return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_index_image(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_index_image( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle image indexing.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") collection = arguments.get("collection", "default") @@ -1493,7 +1708,8 @@ async def _handle_index_image(arguments: dict, backend, storage) -> list[TextCon if not os.path.exists(path): return _error_response("NOT_FOUND", f"File not found: {path}", {"path": path}) - + + await progress.report(0, 1, f"Indexing image {path}") content_hash = await _run_blocking( storage.index_image, path=path, @@ -1502,6 +1718,7 @@ async def _handle_index_image(arguments: dict, backend, storage) -> list[TextCon ) trace_log("index_image_done", path=path, hash=content_hash[:8]) + await progress.report(1, 1, f"Indexed image {path}") output = { "success": True, @@ -1513,8 +1730,14 @@ async def _handle_index_image(arguments: dict, backend, storage) -> list[TextCon return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_index_audio(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_index_audio( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle transcript-first audio indexing.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") collection = arguments.get("collection", "default") @@ -1526,6 +1749,7 @@ async def _handle_index_audio(arguments: dict, backend, storage) -> list[TextCon if not os.path.exists(path): return _error_response("NOT_FOUND", f"File not found: {path}", {"path": path}) + await progress.report(0, 1, f"Indexing audio {path}") output = await _run_blocking( storage.index_audio, path=path, @@ -1534,11 +1758,18 @@ async def _handle_index_audio(arguments: dict, backend, storage) -> list[TextCon ) trace_log("index_audio_done", path=path, hash=str(output.get("hash", ""))[:8]) + await progress.report(1, 1, f"Indexed audio {path}") return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_memory_add(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_memory_add( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle memory add.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") text = arguments.get("text", "") collection = arguments.get("collection", "default") @@ -1557,6 +1788,7 @@ async def _handle_memory_add(arguments: dict, backend, storage) -> list[TextCont if not path or not text: return _error_response("INVALID_INPUT", "path and text are required") + await progress.report(0, 1, f"Adding memory {path}") content_hash = await _run_blocking( storage.upsert_memory, path=path, @@ -1574,6 +1806,7 @@ async def _handle_memory_add(arguments: dict, backend, storage) -> list[TextCont ) trace_log("memory_add_done", path=path, hash=content_hash[:8]) + await progress.report(1, 1, f"Added memory {path}") output = { "success": True, @@ -1592,8 +1825,14 @@ async def _handle_memory_add(arguments: dict, backend, storage) -> list[TextCont return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_memory_add_conversation(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_memory_add_conversation( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle conversation memory ingest.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") turns = arguments.get("turns") collection = arguments.get("collection", "default") @@ -1626,6 +1865,7 @@ async def _handle_memory_add_conversation(arguments: dict, backend, storage) -> return _error_response("BACKEND_ERROR", "Storage backend does not support conversation memories") try: + await progress.report(0, 1, f"Indexing conversation memory {path}") output = await _run_blocking( index_conversation, path=path, @@ -1652,14 +1892,21 @@ async def _handle_memory_add_conversation(arguments: dict, backend, storage) -> hash=str(output.get("hash", ""))[:8], indexed_turns=output.get("indexed_turns", 0), ) + await progress.report(1, 1, f"Indexed conversation memory {path}") output = dict(output) output["operation"] = "add_conversation" return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_memory_update(arguments: dict, backend, storage) -> list[TextContent]: +async def _handle_memory_update( + arguments: dict, + backend, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle memory update.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") text = arguments.get("text", "") collection = arguments.get("collection", "default") @@ -1678,6 +1925,7 @@ async def _handle_memory_update(arguments: dict, backend, storage) -> list[TextC if not path or not text: return _error_response("INVALID_INPUT", "path and text are required") + await progress.report(0, 1, f"Updating memory {path}") content_hash = await _run_blocking( storage.upsert_memory, path=path, @@ -1695,6 +1943,7 @@ async def _handle_memory_update(arguments: dict, backend, storage) -> list[TextC ) trace_log("memory_update_done", path=path, hash=content_hash[:8]) + await progress.report(1, 1, f"Updated memory {path}") output = { "success": True, @@ -1713,8 +1962,13 @@ async def _handle_memory_update(arguments: dict, backend, storage) -> list[TextC return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_memory_delete(arguments: dict, storage) -> list[TextContent]: +async def _handle_memory_delete( + arguments: dict, + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle memory delete.""" + progress = progress or _ToolProgressReporter() path = arguments.get("path", "") collection = arguments.get("collection", "default") user_id = arguments.get("user_id") @@ -1728,6 +1982,7 @@ async def _handle_memory_delete(arguments: dict, storage) -> list[TextContent]: if not path: return _error_response("INVALID_INPUT", "path is required") + await progress.report(0, 1, f"Deleting memory {path}") output = await _run_blocking( storage.delete_memory, path=path, @@ -1737,6 +1992,7 @@ async def _handle_memory_delete(arguments: dict, storage) -> list[TextContent]: project_id=project_id, profile=profile, ) + await progress.report(1, 1, f"Deleted memory {path}") trace_log("memory_delete_done", path=path, removed_vectors=output.get("removed_vectors", 0)) @@ -2074,11 +2330,17 @@ async def _handle_status(backend, storage) -> list[TextContent]: return [TextContent(type="text", text=json.dumps(output, indent=2))] -async def _handle_rebuild_fts(storage) -> list[TextContent]: +async def _handle_rebuild_fts( + storage, + progress: Optional[_ToolProgressReporter] = None, +) -> list[TextContent]: """Handle FTS rebuild.""" + progress = progress or _ToolProgressReporter() try: + await progress.report(0, 1, "Rebuilding full-text search index") await _run_blocking(storage.rebuild_fts_index) output = {"success": True, "message": "FTS index rebuilt"} + await progress.report(1, 1, "Full-text search index rebuilt") return [TextContent(type="text", text=json.dumps(output, indent=2))] except Exception as e: return _error_response("BACKEND_ERROR", str(e), {"exception_type": type(e).__name__}) diff --git a/tests/test_mcp_progress.py b/tests/test_mcp_progress.py new file mode 100644 index 0000000..c6de472 --- /dev/null +++ b/tests/test_mcp_progress.py @@ -0,0 +1,206 @@ +""" +test_mcp_progress.py - Progress notification coverage for MCP tool handlers. + +These tests use mocked backends/storage and a fake progress sink. They verify +that handlers emit protocol-ready progress updates without needing a live MCP +HTTP/SSE client. +""" + +import asyncio +import json +import os +import sys +import unittest +from unittest.mock import MagicMock + +import numpy as np +from mcp.shared.memory import create_connected_server_and_client_session + +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "src")) + +from recallforge.server import ( + _ToolProgressReporter, + create_server, + _handle_batch, + _handle_ingest, + _handle_search_batch, +) +from recallforge.storage.base import SearchResult + + +class _FakeInfo: + name = "stub-backend" + device = "cpu" + dtype = "float32" + embedder_loaded = True + reranker_loaded = False + memory_allocated_gb = 0.0 + quantization = None + + +class ProgressRecorder: + def __init__(self): + self.events = [] + + @property + def reporter(self): + return _ToolProgressReporter(self.send) + + async def send(self, progress, total, message): + self.events.append( + { + "progress": progress, + "total": total, + "message": message, + } + ) + + +def _make_backend(): + backend = MagicMock() + backend.get_mode.return_value = "embed" + backend.embed_text.return_value = np.ones(8, dtype=np.float32) + backend.embed_image.return_value = np.ones(8, dtype=np.float32) + backend.get_info.return_value = _FakeInfo() + backend.needs_reranker.return_value = False + return backend + + +def _result(path: str, score: float = 1.0) -> SearchResult: + return SearchResult( + filepath=path, + display_path=path, + title=path, + context=None, + hash=f"hash-{path}", + docid=f"doc-{path}", + collection="default", + modified_at="now", + body_length=20, + score=score, + source="fts", + body=f"Body for {path}", + ) + + +def _make_storage(): + storage = MagicMock() + storage.count_embeddings.return_value = 0 + storage.count_documents.return_value = 0 + storage.delete_memory.return_value = {"success": True, "removed_vectors": 1} + storage.ingest.return_value = { + "success": True, + "indexed_text": 1, + "indexed_images": 0, + "indexed_videos": 0, + "indexed_audio": 0, + "indexed_documents": 0, + } + storage.search_fts.side_effect = [ + [_result("alpha.md", 0.9)], + [_result("beta.md", 0.8)], + ] + return storage + + +class TestMcpProgress(unittest.IsolatedAsyncioTestCase): + async def test_batch_emits_operation_progress(self): + backend = _make_backend() + storage = _make_storage() + recorder = ProgressRecorder() + + result = await _handle_batch( + { + "operations": [ + {"tool": "status", "arguments": {}}, + {"tool": "memory_delete", "arguments": {"path": "notes/demo.md"}}, + ] + }, + backend, + storage, + progress=recorder.reporter, + ) + + data = json.loads(result[0].text) + self.assertEqual(data["succeeded"], 2) + messages = [event["message"] for event in recorder.events] + self.assertIn("Starting batch with 2 operation(s)", messages) + self.assertTrue(any("Finished batch operation 1/2: status (success)" == msg for msg in messages)) + self.assertTrue(any("Finished batch operation 2/2: memory_delete (success)" == msg for msg in messages)) + self.assertEqual(recorder.events[-1]["progress"], recorder.events[-1]["total"]) + + async def test_ingest_emits_start_and_completion_progress(self): + backend = _make_backend() + storage = _make_storage() + recorder = ProgressRecorder() + + result = await _handle_ingest( + {"text": "hello", "path": "notes/hello.md", "collection": "default"}, + backend, + storage, + progress=recorder.reporter, + ) + + data = json.loads(result[0].text) + self.assertTrue(data["success"]) + self.assertEqual([event["progress"] for event in recorder.events], [0.0, 2.0]) + self.assertIn("Ingest complete; indexed 1 item(s)", recorder.events[-1]["message"]) + + async def test_search_batch_emits_per_query_partial_progress(self): + backend = _make_backend() + storage = _make_storage() + recorder = ProgressRecorder() + + result = await _handle_search_batch( + { + "queries": [ + {"query": "alpha", "mode": "fts"}, + {"query": "beta", "mode": "fts"}, + ], + "limit": 5, + }, + backend, + storage, + progress=recorder.reporter, + ) + await asyncio.sleep(0.05) + + data = json.loads(result[0].text) + self.assertEqual(data["query_count"], 2) + partial_messages = [ + event["message"] + for event in recorder.events + if "Batch search completed query" in event["message"] + ] + self.assertEqual(len(partial_messages), 2) + self.assertTrue(any("last branch returned 1 candidate(s)" in msg for msg in partial_messages)) + self.assertEqual(recorder.events[-1]["progress"], 2.0) + self.assertEqual(recorder.events[-1]["total"], 2.0) + + async def test_client_session_receives_progress_notifications(self): + backend = _make_backend() + storage = _make_storage() + server = await create_server(backend=backend, storage=storage, mode="embed") + events = [] + + async def on_progress(progress, total, message): + events.append((progress, total, message)) + + async with create_connected_server_and_client_session(server) as session: + result = await session.call_tool( + "ingest", + arguments={"text": "hello", "path": "notes/hello.md", "collection": "default"}, + progress_callback=on_progress, + ) + + data = json.loads(result.content[0].text) + self.assertTrue(data["success"]) + self.assertGreaterEqual(len(events), 2) + self.assertEqual(events[0][0], 0.0) + self.assertIn("Starting ingest", events[0][2]) + self.assertEqual(events[-1][0], events[-1][1]) + self.assertIn("Ingest complete", events[-1][2]) + + +if __name__ == "__main__": + unittest.main()