From 43f30bafd665d9995174cebce874885d91810e08 Mon Sep 17 00:00:00 2001 From: MollyAI Date: Sun, 17 May 2026 15:26:24 -0400 Subject: [PATCH] Add memory graph enrichment --- CHANGELOG.md | 1 + README.md | 8 +- docs/ARCHITECTURE.md | 16 +- docs/MEMORY_POLICY.md | 6 + docs/mcp-tools.md | 127 ++++- src/recallforge/entities.py | 157 ++++++ src/recallforge/server.py | 153 +++++- src/recallforge/storage/base.py | 34 ++ src/recallforge/storage/indexing_ops.py | 34 ++ src/recallforge/storage/lancedb_backend.py | 575 ++++++++++++++++++++- tests/test_batch_tool.py | 1 + tests/test_config_tools.py | 69 ++- tests/test_entities.py | 49 ++ tests/test_json_compliance.py | 16 + tests/test_storage.py | 54 ++ tests/uat/test_mcp_server.sh | 1 + tests/uat/test_uat_comprehensive.py | 1 + 17 files changed, 1290 insertions(+), 12 deletions(-) create mode 100644 src/recallforge/entities.py create mode 100644 tests/test_entities.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d3aaf9..c0c78eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to RecallForge will be documented in this file. ## [Unreleased] +- Added deterministic memory graph enrichment with entity/relation side tables and new `memory_graph_entities` / `memory_graph_related` MCP tools. - Replaced the tiny UAT video clips with compact episodic-memory fixtures, richer transcript sidecars, related artifact metadata, and regression coverage for the video corpus. - Added `memory_add_conversation` so conversation threads ingest as canonical parent memories with turn-level child memories and standard memory rollups. diff --git a/README.md b/README.md index 82821ff..18320a5 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ One query. Any modality. All local. | Audio transcript ingest | ✅ | ❌ | ❌ | ❌ | ❌ | | Document ingest (PDF/DOCX/PPTX) | ✅ | ❌ | ❌ | ❌ | ❌ | | Built-in reranking | ✅ Multimodal | ❌ | ❌ | ✅ ColBERT | ✅ Modules | -| MCP-native | ✅ 24 tools | ❌ | ❌ | ❌ | ❌ | +| MCP-native | ✅ 26 tools | ❌ | ❌ | ❌ | ❌ | | 100% local | ✅ | ✅ | ⚠️ Cloud default | ✅ | ✅ Docker | | Apple Silicon optimized | ✅ MLX 4-bit | ❌ | ❌ | ❌ | ❌ | | Cloud option | ❌ | ✅ | ✅ | ✅ | ✅ | @@ -146,7 +146,7 @@ Run over HTTP/SSE: recallforge serve --http --host 127.0.0.1 --port 7433 --mode embed ``` -RecallForge now exposes **24 MCP tools** across search, ingest, memory, collection admin, and runtime config. HTTP/SSE mode also exposes `/health`, `/sse`, and `/messages/`. +RecallForge now exposes **26 MCP tools** across search, ingest, memory graph navigation, collection admin, and runtime config. HTTP/SSE mode also exposes `/health`, `/sse`, and `/messages/`. See [docs/mcp-tools.md](docs/mcp-tools.md) for the full tool reference. @@ -161,7 +161,7 @@ See [docs/mcp-tools.md](docs/mcp-tools.md) for the full tool reference. ## How it works -RecallForge encodes text, images, video frames, documents, conversation turns, and audio transcripts into the same 2048-dimensional vector space using Qwen3-VL. This means "find notes about this diagram" works whether the diagram is text, an image, a conversation thread, or a frame from a video. A 3-stage pipeline handles the rest: +RecallForge encodes text, images, video frames, documents, conversation turns, and audio transcripts into the same 2048-dimensional vector space using Qwen3-VL. It also extracts lightweight entity and relation metadata so agents can navigate from one memory to other memories that mention the same people, projects, tickets, URLs, and organizations. This means "find notes about this diagram" works whether the diagram is text, an image, a conversation thread, or a frame from a video. A 3-stage pipeline handles the rest: ```mermaid graph TD @@ -276,7 +276,7 @@ src/recallforge/ │ └── lancedb_backend.py # LanceDB + Tantivy FTS ├── cache.py # LRU embedding cache ├── search.py # Hybrid search pipeline (BM25 + vector + RRF) -├── server.py # MCP server (24 tools, stdio + HTTP/SSE) +├── server.py # MCP server (26 tools, stdio + HTTP/SSE) ├── documents.py # PDF/DOCX/PPTX extraction ├── video.py # Frame/transcript extraction ├── audio.py # Transcript-first audio ingest diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index d7cfa50..8452c39 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -114,6 +114,18 @@ id | collection | file_path | title | content_hash | content_type | active created_at | updated_at ``` +**entities** (memory graph mentions) +``` +id | collection | entity_key | name | entity_type | memory_id | memory_root_path +file_path | content_hash | hash_seq | seq | evidence | namespace fields | created_at +``` + +**relations** (lightweight graph edges) +``` +id | collection | subject_key | subject_name | object_key | object_name | relation_type +memory_id | memory_root_path | file_path | content_hash | hash_seq | evidence | namespace fields +``` + Conversation memories use the same parent/child layout as media-derived memories: ``` @@ -122,7 +134,7 @@ conversation root path └──> turn child memories (`path::turn:0001`, `path::turn:0002`, ...) ``` -All turn children share the root `memory_id` and `memory_root_path`, so matching turns strengthen the parent conversation result through the standard memory rollup path. +All turn children share the root `memory_id` and `memory_root_path`, so matching turns strengthen the parent conversation result through the standard memory rollup path. Entity and relation rows keep the same IDs and evidence paths, which lets MCP clients navigate same-entity memories without relying only on lexical overlap. **content** (bodies, content-addressed) ``` @@ -193,7 +205,7 @@ backend = recallforge.get_backend() ### 5. MCP Server (`src/recallforge/server.py`) ``` -Tools: 24 MCP tools across search, ingest, memory, collection admin, batch, and runtime config +Tools: 26 MCP tools across search, ingest, memory, memory graph, collection admin, batch, and runtime config Transport: stdio (default) or HTTP/SSE (`/health`, `/sse`, `/messages/`) Startup: backend.warm_up() for predictable latency Signals: SIGTERM/SIGINT graceful shutdown diff --git a/docs/MEMORY_POLICY.md b/docs/MEMORY_POLICY.md index 58b690a..5a45855 100644 --- a/docs/MEMORY_POLICY.md +++ b/docs/MEMORY_POLICY.md @@ -47,3 +47,9 @@ Raw audio transcription and dedicated audio encoders are not part of the shipped ## Conversations Use `memory_add_conversation` when an agent or app wants to persist a thread. RecallForge stores the parent at the supplied `path` and stores each turn at `path::turn:0001`, `path::turn:0002`, and so on. All turns share the parent `memory_id`, so if several turns match a query, search and explanation output roll them up into the parent conversation with evidence paths. + +## Memory Graph + +Every indexed text-bearing evidence unit can also produce lightweight entity mentions and co-mention relation edges. That includes normal text memories, OCR/document sections, transcripts, captions, and conversation turns. Graph rows store `memory_id`, `memory_root_path`, `file_path`, and a short evidence snippet, so same-entity navigation remains traceable to the source memory. + +Use `memory_graph_entities` to inspect entities for a memory, path, or entity key. Use `memory_graph_related` to find other memories that share extracted entities with a seed memory or entity. This graph enrichment is intentionally local and deterministic: it adds navigation and grouping without introducing an external NLP service. diff --git a/docs/mcp-tools.md b/docs/mcp-tools.md index 64de4a6..2f5e4cf 100644 --- a/docs/mcp-tools.md +++ b/docs/mcp-tools.md @@ -55,6 +55,8 @@ Example MCP client config (Claude Desktop): - `memory_update` - `memory_delete` - `memory_get` +- `memory_graph_entities` +- `memory_graph_related` - `list_memories` ### Admin / Introspection @@ -853,6 +855,124 @@ Turn objects accept: --- +## memory_graph_entities + +**Description:** List entity mentions extracted from indexed memory text, OCR text, transcripts, captions, and conversation turns. Each mention includes the source memory/path and evidence snippet that produced it. + +**Parameters:** +| Name | Type | Required | Default | Description | +|------|------|----------|---------|-------------| +| memory_id | string | Conditionally* | — | Stable memory identifier | +| path | string | Conditionally* | — | Root or child memory path | +| entity | string | Conditionally* | — | Entity name or normalized entity key | +| collection | string | No | server default collection | Collection filter | +| limit | integer | No | 100 | Max entity mentions | +| user_id | string | No | — | User namespace | +| session_id | string | No | — | Session namespace | +| project_id | string | No | — | Project namespace | +| profile | string | No | — | Profile namespace | + +\* Provide at least one of `memory_id`, `path`, or `entity`. + +**Example Request:** +```json +{ + "path": "threads/customer-renewal", + "collection": "default" +} +``` + +**Example Response:** +```json +{ + "success": true, + "count": 2, + "entities": [ + { + "entity_key": "acme_robotics", + "name": "Acme Robotics", + "entity_type": "proper_noun", + "memory_id": "b4f7...", + "memory_root_path": "threads/customer-renewal", + "file_path": "threads/customer-renewal::turn:0002", + "evidence": "Acme Robotics asked to review the renewal timeline..." + } + ] +} +``` + +**Errors:** +- `INVALID_INPUT`: when no seed (`memory_id`, `path`, or `entity`) is provided. +- `BACKEND_ERROR`: when the storage backend does not support graph entities. +- `INTERNAL_ERROR`: uncaught exceptions. + +**Notes:** Entity keys are normalized for lookup. Evidence snippets are stored beside the graph rows so related-memory navigation can be traced back to source memories. + +--- + +## memory_graph_related + +**Description:** Find memories related by shared extracted entities. This is useful when two memories mention the same person, project, organization, ticket, URL, or topic even if ordinary lexical search would not rank them together. + +**Parameters:** +| Name | Type | Required | Default | Description | +|------|------|----------|---------|-------------| +| memory_id | string | Conditionally* | — | Stable memory identifier to use as the seed | +| path | string | Conditionally* | — | Root or child memory path to use as the seed | +| entity | string | Conditionally* | — | Entity name or normalized entity key to use as the seed | +| collection | string | No | server default collection | Collection filter | +| limit | integer | No | 20 | Max related memories | +| user_id | string | No | — | User namespace | +| session_id | string | No | — | Session namespace | +| project_id | string | No | — | Project namespace | +| profile | string | No | — | Profile namespace | + +\* Provide at least one of `memory_id`, `path`, or `entity`. + +**Example Request:** +```json +{ + "entity": "Acme Robotics", + "collection": "default", + "limit": 5 +} +``` + +**Example Response:** +```json +{ + "success": true, + "count": 1, + "related_memories": [ + { + "memory_id": "af31...", + "collection": "default", + "path": "notes/acme-budget", + "score": 11, + "shared_entities": [ + { "entity_key": "acme_robotics", "name": "Acme Robotics", "entity_type": "proper_noun" } + ], + "evidence": [ + { + "path": "notes/acme-budget", + "entity": "Acme Robotics", + "text": "The budget memo says Acme Robotics approved new sensors." + } + ] + } + ] +} +``` + +**Errors:** +- `INVALID_INPUT`: when no seed (`memory_id`, `path`, or `entity`) is provided. +- `BACKEND_ERROR`: when the storage backend does not support related memory graph lookup. +- `INTERNAL_ERROR`: uncaught exceptions. + +**Notes:** Relatedness is based on shared graph entity keys and includes evidence from the related memories. The score is intentionally simple: more distinct shared entities and mentions rank higher. + +--- + ## list_memories **Description:** List canonical root memories for a collection or namespace. @@ -1185,7 +1305,12 @@ Operation object schema: 2. Query with `search` or `explain_results`; matching turns roll up to the parent `memory_id`. 3. Inspect the full memory with `memory_get` when the agent needs turn evidence. -### 4) Configure mode, then ingest +### 4) Navigate the memory graph +1. Call `memory_graph_entities` for a `memory_id`, `path`, or entity name to inspect extracted entities and evidence. +2. Call `memory_graph_related` with the same seed to find memories that share those entities. +3. Use the returned evidence paths with `memory_get` or `search` when a client needs the full memory context. + +### 5) Configure mode, then ingest 1. Inspect config using `get_config`. 2. Set desired runtime defaults with `set_config` (for mode, default collection, max file size). 3. Run `ingest` without repeating shared defaults. diff --git a/src/recallforge/entities.py b/src/recallforge/entities.py new file mode 100644 index 0000000..5c31967 --- /dev/null +++ b/src/recallforge/entities.py @@ -0,0 +1,157 @@ +"""Lightweight entity and relation extraction for memory graph enrichment.""" + +from __future__ import annotations + +import hashlib +import re +from dataclasses import dataclass +from itertools import combinations +from typing import Iterable, Optional + + +_MAX_ENTITIES_PER_TEXT = 24 +_MAX_ENTITY_LEN = 80 +_MAX_EVIDENCE_CHARS = 360 + +_STOP_ENTITIES = { + "A", "An", "And", "Are", "As", "At", "By", "For", "From", "In", "Into", "Is", + "It", "Of", "On", "Or", "The", "This", "To", "With", + "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday", + "January", "February", "March", "April", "May", "June", "July", "August", + "September", "October", "November", "December", +} +_STOP_ENTITY_KEYS = {_entity.lower() for _entity in _STOP_ENTITIES} + +_PROPER_NOUN_RE = re.compile( + r"\b[A-Z][A-Za-z0-9&._-]{1,}(?:\s+[A-Z][A-Za-z0-9&._-]{1,}){0,4}\b" +) +_ACRONYM_RE = re.compile(r"\b[A-Z][A-Z0-9]{1,}(?:-[A-Z0-9]+)*\b") +_HANDLE_RE = re.compile(r"(?\]]+") +_DOMAIN_RE = re.compile(r"\b(?:[A-Za-z0-9-]+\.)+[A-Za-z]{2,}\b") + + +@dataclass(frozen=True) +class ExtractedEntity: + """One normalized entity mention with source evidence.""" + + name: str + entity_key: str + entity_type: str + evidence: str + + +@dataclass(frozen=True) +class ExtractedRelation: + """A lightweight relation edge between two entity mentions.""" + + subject_key: str + subject_name: str + object_key: str + object_name: str + relation_type: str + evidence: str + + +def _clean_entity(raw: str) -> str: + text = re.sub(r"\s+", " ", str(raw or "").strip()) + return text.strip(".,;:!?()[]{}\"'`") + + +def normalize_entity_key(name: str) -> str: + """Normalize an entity mention to a stable lookup key.""" + lowered = name.lower() + lowered = re.sub(r"^@", "", lowered) + lowered = re.sub(r"[^a-z0-9]+", "_", lowered).strip("_") + return lowered + + +def _classify_entity(name: str) -> str: + if name.startswith("@"): + return "person" + if _ISSUE_RE.fullmatch(name): + return "ticket" + if _URL_RE.fullmatch(name) or _DOMAIN_RE.fullmatch(name): + return "url" + if _ACRONYM_RE.fullmatch(name): + return "acronym" + if any(token in name.lower().split() for token in ("project", "program", "initiative")): + return "project" + return "proper_noun" + + +def _evidence_for(text: str, start: int, end: int) -> str: + left = max(0, start - 120) + right = min(len(text), end + 180) + snippet = re.sub(r"\s+", " ", text[left:right]).strip() + if len(snippet) > _MAX_EVIDENCE_CHARS: + snippet = snippet[: _MAX_EVIDENCE_CHARS - 3].rsplit(" ", 1)[0].strip() + "..." + return snippet + + +def _iter_entity_matches(text: str): + for pattern in (_URL_RE, _HANDLE_RE, _ISSUE_RE, _ACRONYM_RE, _PROPER_NOUN_RE, _DOMAIN_RE): + yield from pattern.finditer(text) + + +def extract_entities(text: str, *, max_entities: int = _MAX_ENTITIES_PER_TEXT) -> list[ExtractedEntity]: + """Extract deterministic entity mentions from text without external NLP deps.""" + if not isinstance(text, str) or not text.strip(): + return [] + + found: dict[str, ExtractedEntity] = {} + occupied_spans: list[tuple[int, int]] = [] + for match in sorted(_iter_entity_matches(text), key=lambda item: (item.start(), -(item.end() - item.start()))): + if any(match.start() < end and match.end() > start for start, end in occupied_spans): + continue + name = _clean_entity(match.group(0)) + if not name or len(name) > _MAX_ENTITY_LEN or name in _STOP_ENTITIES: + continue + key = normalize_entity_key(name) + if len(key) < 2 or key in _STOP_ENTITY_KEYS or key in found: + continue + found[key] = ExtractedEntity( + name=name, + entity_key=key, + entity_type=_classify_entity(name), + evidence=_evidence_for(text, match.start(), match.end()), + ) + occupied_spans.append((match.start(), match.end())) + if len(found) >= max_entities: + break + return list(found.values()) + + +def extract_relations( + entities: Iterable[ExtractedEntity], + *, + max_pairs: int = 48, +) -> list[ExtractedRelation]: + """Create co-mention relation edges for entities found in the same evidence unit.""" + unique: dict[str, ExtractedEntity] = {} + for entity in entities: + unique.setdefault(entity.entity_key, entity) + + relations: list[ExtractedRelation] = [] + for left, right in combinations(list(unique.values())[:12], 2): + evidence = left.evidence if len(left.evidence) >= len(right.evidence) else right.evidence + relations.append( + ExtractedRelation( + subject_key=left.entity_key, + subject_name=left.name, + object_key=right.entity_key, + object_name=right.name, + relation_type="co_mentions", + evidence=evidence, + ) + ) + if len(relations) >= max_pairs: + break + return relations + + +def stable_graph_id(*parts: Optional[str]) -> str: + """Build a stable hash ID for graph rows.""" + seed = "\x1f".join(str(part or "") for part in parts) + return hashlib.sha256(seed.encode("utf-8")).hexdigest() diff --git a/src/recallforge/server.py b/src/recallforge/server.py index 649b1e0..3b9ec0c 100644 --- a/src/recallforge/server.py +++ b/src/recallforge/server.py @@ -4,7 +4,8 @@ MCP protocol server with stdio or HTTP/SSE transport. Tools: search, search_fts, search_vec, explain_results, search_batch, ingest, index_document, index_image, index_audio, memory_add, memory_update, memory_delete, -memory_add_conversation, memory_get, list_memories, status, rebuild_fts, list_collections, +memory_add_conversation, memory_get, list_memories, memory_graph_entities, +memory_graph_related, status, rebuild_fts, list_collections, list_namespaces, rename_collection, delete_collection, batch, get_config, set_config. Resources expose canonical memories via memory:// URIs. @@ -233,6 +234,20 @@ def _get_memory_from_storage(storage, memory_id: Optional[str] = None, **kwargs) return get_memory(memory_id, **kwargs) +def _list_memory_entities_from_storage(storage, **kwargs) -> list[dict]: + list_memory_entities = getattr(storage, "list_memory_entities", None) + if not callable(list_memory_entities): + return [] + return list_memory_entities(**kwargs) + + +def _find_related_memories_from_storage(storage, **kwargs) -> list[dict]: + find_related_memories = getattr(storage, "find_related_memories", None) + if not callable(find_related_memories): + return [] + return find_related_memories(**kwargs) + + def _signal_handler(signum, frame): """Handle shutdown signals gracefully.""" global _shutdown_requested @@ -558,6 +573,42 @@ async def list_tools() -> list[Tool]: }, }, ), + Tool( + name="memory_graph_entities", + description="List extracted entity mentions for a memory, path, or entity key with source evidence", + inputSchema={ + "type": "object", + "properties": { + "memory_id": {"type": "string", "description": "Stable memory identifier to inspect"}, + "path": {"type": "string", "description": "Memory root path or child path to inspect"}, + "entity": {"type": "string", "description": "Optional entity name/key to navigate across memories"}, + "collection": {"type": "string", "description": "Optional collection filter"}, + "limit": {"type": "integer", "description": "Maximum entity mentions to return", "default": 100}, + "user_id": {"type": "string", "description": "Optional user namespace filter"}, + "session_id": {"type": "string", "description": "Optional session namespace filter"}, + "project_id": {"type": "string", "description": "Optional project namespace filter"}, + "profile": {"type": "string", "description": "Optional profile namespace filter"}, + }, + }, + ), + Tool( + name="memory_graph_related", + description="Find memories related by shared extracted entities, with supporting evidence", + inputSchema={ + "type": "object", + "properties": { + "memory_id": {"type": "string", "description": "Stable memory identifier to use as the seed"}, + "path": {"type": "string", "description": "Memory root path or child path to use as the seed"}, + "entity": {"type": "string", "description": "Entity name/key to use as the seed"}, + "collection": {"type": "string", "description": "Optional collection filter"}, + "limit": {"type": "integer", "description": "Maximum related memories to return", "default": 20}, + "user_id": {"type": "string", "description": "Optional user namespace filter"}, + "session_id": {"type": "string", "description": "Optional session namespace filter"}, + "project_id": {"type": "string", "description": "Optional project namespace filter"}, + "profile": {"type": "string", "description": "Optional profile namespace filter"}, + }, + }, + ), Tool( name="list_memories", description="List canonical root memories for a collection or namespace", @@ -833,6 +884,10 @@ async def _dispatch_tool( return await _handle_memory_delete(arguments, storage) elif name == "memory_get": return await _handle_memory_get(arguments, storage) + elif name == "memory_graph_entities": + return await _handle_memory_graph_entities(arguments, storage) + elif name == "memory_graph_related": + return await _handle_memory_graph_related(arguments, storage) elif name == "list_memories": return await _handle_list_memories(arguments, storage) elif name == "status": @@ -1747,6 +1802,102 @@ async def _handle_memory_get(arguments: dict, storage) -> list[TextContent]: return [TextContent(type="text", text=json.dumps(memory, indent=2))] +async def _handle_memory_graph_entities(arguments: dict, storage) -> list[TextContent]: + """Handle entity mention lookup for the memory graph.""" + memory_id = arguments.get("memory_id") + path = arguments.get("path") + entity = arguments.get("entity") + collection = arguments.get("collection") + limit = arguments.get("limit", 100) + user_id = arguments.get("user_id") + session_id = arguments.get("session_id") + project_id = arguments.get("project_id") + profile = arguments.get("profile") + + if not memory_id and not path and not entity: + return _error_response("INVALID_INPUT", "memory_id, path, or entity is required") + try: + limit = int(limit) + except (TypeError, ValueError): + return _error_response("INVALID_INPUT", "limit must be an integer") + + if not callable(getattr(storage, "list_memory_entities", None)): + return _error_response("BACKEND_ERROR", "Storage backend does not support memory graph entities") + + entities = await _run_blocking( + _list_memory_entities_from_storage, + storage, + memory_id=memory_id, + path=path, + entity=entity, + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + limit=limit, + ) + + output = { + "success": True, + "count": len(entities), + "entities": entities, + "memory_id": memory_id, + "path": path, + "entity": entity, + "collection": collection, + } + return [TextContent(type="text", text=json.dumps(output, indent=2))] + + +async def _handle_memory_graph_related(arguments: dict, storage) -> list[TextContent]: + """Handle related-memory lookup through shared graph entities.""" + memory_id = arguments.get("memory_id") + path = arguments.get("path") + entity = arguments.get("entity") + collection = arguments.get("collection") + limit = arguments.get("limit", 20) + user_id = arguments.get("user_id") + session_id = arguments.get("session_id") + project_id = arguments.get("project_id") + profile = arguments.get("profile") + + if not memory_id and not path and not entity: + return _error_response("INVALID_INPUT", "memory_id, path, or entity is required") + try: + limit = int(limit) + except (TypeError, ValueError): + return _error_response("INVALID_INPUT", "limit must be an integer") + + if not callable(getattr(storage, "find_related_memories", None)): + return _error_response("BACKEND_ERROR", "Storage backend does not support related memory graph lookup") + + related = await _run_blocking( + _find_related_memories_from_storage, + storage, + memory_id=memory_id, + path=path, + entity=entity, + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + limit=limit, + ) + + output = { + "success": True, + "count": len(related), + "related_memories": related, + "memory_id": memory_id, + "path": path, + "entity": entity, + "collection": collection, + } + return [TextContent(type="text", text=json.dumps(output, indent=2))] + + async def _handle_get_config(backend, storage, mutable_config: dict) -> list[TextContent]: """Return current server configuration.""" info = await _run_blocking(backend.get_info) diff --git a/src/recallforge/storage/base.py b/src/recallforge/storage/base.py index d1cc29c..a2a9bf2 100644 --- a/src/recallforge/storage/base.py +++ b/src/recallforge/storage/base.py @@ -298,6 +298,40 @@ def index_conversation( """Index a conversation root plus turn-level child memories.""" pass + @abstractmethod + def list_memory_entities( + self, + *, + memory_id: Optional[str] = None, + path: Optional[str] = None, + entity: Optional[str] = None, + collection: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + """List extracted entity mentions with source evidence.""" + pass + + @abstractmethod + def find_related_memories( + self, + *, + memory_id: Optional[str] = None, + path: Optional[str] = None, + entity: Optional[str] = None, + collection: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """Find memories related by shared extracted entities.""" + pass + @abstractmethod def delete_memory( self, diff --git a/src/recallforge/storage/indexing_ops.py b/src/recallforge/storage/indexing_ops.py index 8cd244c..0628770 100644 --- a/src/recallforge/storage/indexing_ops.py +++ b/src/recallforge/storage/indexing_ops.py @@ -362,6 +362,17 @@ def upsert_memory( self._backend._embeddings_table.delete(del_filter) except Exception as e: logger.warning(f"upsert_memory: failed to delete old vectors for {collection}/{normalized_path}: {e}") + try: + self._backend.delete_graph_entries( + collection=collection, + logical_path=normalized_path, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + ) + except Exception as e: + logger.warning(f"upsert_memory: failed to delete old graph rows for {collection}/{normalized_path}: {e}") self._backend.insert_content(content_hash, text, "text") self._backend.insert_document( @@ -599,6 +610,17 @@ def delete_memory( self._backend._embeddings_table.delete(del_filter) except Exception as e: logger.error(f"delete_memory: failed to delete embeddings for {collection}/{normalized_path}: {e}") + try: + self._backend.delete_graph_entries( + collection=collection, + logical_path=normalized_path, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + ) + except Exception as e: + logger.error(f"delete_memory: failed to delete graph rows for {collection}/{normalized_path}: {e}") self._backend.deactivate_document(collection, normalized_path) @@ -801,6 +823,18 @@ def _delete_path_entries( self._backend._embeddings_table.delete(filter_clause) except Exception as e: logger.debug(f"_delete_path_entries: failed to delete embeddings for {logical_path}: {e}") + try: + self._backend.delete_graph_entries( + collection=collection, + logical_path=logical_path, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + include_children=include_children, + ) + except Exception as e: + logger.debug(f"_delete_path_entries: failed to delete graph rows for {logical_path}: {e}") doc_filters = self._namespace_filters( collection=collection, diff --git a/src/recallforge/storage/lancedb_backend.py b/src/recallforge/storage/lancedb_backend.py index 2db6f82..78e962e 100644 --- a/src/recallforge/storage/lancedb_backend.py +++ b/src/recallforge/storage/lancedb_backend.py @@ -23,6 +23,7 @@ import pyarrow as pa import lancedb +from ..entities import extract_entities, extract_relations, normalize_entity_key, stable_graph_id from .base import StorageBackend, Document from .chunking import ( BREAK_PATTERNS, @@ -102,6 +103,8 @@ def __init__(self, store_path: Optional[str] = None): self._documents_table = None self._content_table = None self._cache_table = None + self._entities_table = None + self._relations_table = None # FTS rebuild debouncing state self._fts_rebuild_pending = 0 @@ -195,6 +198,34 @@ def initialize(self, store_path: Optional[str] = None) -> None: self._cache_table = self._conn.open_table("cache") else: raise + + if "entities" in existing: + self._entities_table = self._conn.open_table("entities") + else: + try: + self._entities_table = self._conn.create_table( + "entities", + schema=self._build_entities_schema() + ) + except ValueError as e: + if "already exists" in str(e): + self._entities_table = self._conn.open_table("entities") + else: + raise + + if "relations" in existing: + self._relations_table = self._conn.open_table("relations") + else: + try: + self._relations_table = self._conn.create_table( + "relations", + schema=self._build_relations_schema() + ) + except ValueError as e: + if "already exists" in str(e): + self._relations_table = self._conn.open_table("relations") + else: + raise self._ensure_indices() self._migrate_schema() @@ -224,6 +255,8 @@ def _migrate_schema(self) -> None: tables_and_schemas = [ (self._embeddings_table, "embeddings", self._build_embeddings_schema()), (self._documents_table, "documents", self._build_documents_schema()), + (self._entities_table, "entities", self._build_entities_schema()), + (self._relations_table, "relations", self._build_relations_schema()), ] for table, table_name, target_schema in tables_and_schemas: @@ -280,6 +313,8 @@ def close(self) -> None: self._documents_table = None self._content_table = None self._cache_table = None + self._entities_table = None + self._relations_table = None # ========================================================================= # Schema Definitions @@ -353,6 +388,52 @@ def _build_cache_schema(self) -> pa.Schema: pa.field("value", pa.string(), nullable=False), pa.field("created_at", pa.int64(), nullable=False), ]) + + def _build_entities_schema(self) -> pa.Schema: + """Schema for extracted entity mentions.""" + return pa.schema([ + pa.field("id", pa.string(), nullable=False), + pa.field("collection", pa.string(), nullable=False), + pa.field("entity_key", pa.string(), nullable=False), + pa.field("name", pa.string(), nullable=False), + pa.field("entity_type", pa.string(), nullable=False), + pa.field("memory_id", pa.string(), nullable=True), + pa.field("memory_root_path", pa.string(), nullable=True), + pa.field("file_path", pa.string(), nullable=False), + pa.field("content_hash", pa.string(), nullable=False), + pa.field("hash_seq", pa.string(), nullable=False), + pa.field("seq", pa.int32(), nullable=False), + pa.field("evidence", pa.string(), nullable=True), + pa.field("user_id", pa.string(), nullable=True), + pa.field("session_id", pa.string(), nullable=True), + pa.field("project_id", pa.string(), nullable=True), + pa.field("profile", pa.string(), nullable=True), + pa.field("created_at", pa.int64(), nullable=False), + ]) + + def _build_relations_schema(self) -> pa.Schema: + """Schema for lightweight relation edges between entity mentions.""" + return pa.schema([ + pa.field("id", pa.string(), nullable=False), + pa.field("collection", pa.string(), nullable=False), + pa.field("subject_key", pa.string(), nullable=False), + pa.field("subject_name", pa.string(), nullable=False), + pa.field("object_key", pa.string(), nullable=False), + pa.field("object_name", pa.string(), nullable=False), + pa.field("relation_type", pa.string(), nullable=False), + pa.field("memory_id", pa.string(), nullable=True), + pa.field("memory_root_path", pa.string(), nullable=True), + pa.field("file_path", pa.string(), nullable=False), + pa.field("content_hash", pa.string(), nullable=False), + pa.field("hash_seq", pa.string(), nullable=False), + pa.field("seq", pa.int32(), nullable=False), + pa.field("evidence", pa.string(), nullable=True), + pa.field("user_id", pa.string(), nullable=True), + pa.field("session_id", pa.string(), nullable=True), + pa.field("project_id", pa.string(), nullable=True), + pa.field("profile", pa.string(), nullable=True), + pa.field("created_at", pa.int64(), nullable=False), + ]) def _has_scalar_index(self, table, column: str) -> bool: """Return True if a scalar index already exists for a column.""" @@ -410,6 +491,32 @@ def _ensure_indices(self) -> None: ) except Exception as e: logger.warning(f"_ensure_indices: failed to create cache index: {e}") + + try: + self._create_scalar_index_safe( + self._entities_table, "entity_key", "entity_key_scalar", "entities" + ) + self._create_scalar_index_safe( + self._entities_table, "memory_id", "memory_id_scalar", "entities" + ) + self._create_scalar_index_safe( + self._entities_table, "memory_root_path", "memory_root_path_scalar", "entities" + ) + except Exception as e: + logger.warning(f"_ensure_indices: failed to create entity indices: {e}") + + try: + self._create_scalar_index_safe( + self._relations_table, "subject_key", "subject_key_scalar", "relations" + ) + self._create_scalar_index_safe( + self._relations_table, "object_key", "object_key_scalar", "relations" + ) + self._create_scalar_index_safe( + self._relations_table, "memory_id", "memory_id_scalar", "relations" + ) + except Exception as e: + logger.warning(f"_ensure_indices: failed to create relation indices: {e}") def _ensure_bulk_buffers(self) -> None: """Initialize bulk-write buffers for tests that construct via __new__.""" @@ -730,6 +837,7 @@ def insert_embedding( self._embeddings_table.delete(_safe_filter("hash_seq", hash_seq)) except Exception as e: logger.debug(f"insert_embedding: no existing embedding to delete for {hash_seq}: {e}") + self.delete_graph_entries(hash_seq=hash_seq) trace_log("insert_embedding", hash_seq=hash_seq, collection=collection, file_path=file_path, seq=seq, user_id=user_id, session_id=session_id, project_id=project_id, profile=profile, @@ -766,6 +874,120 @@ def insert_embedding( self._flush_pending_writes(force=False) else: self._embeddings_table.add(pa.Table.from_pylist([row], schema=self._build_embeddings_schema())) + + self._index_graph_rows_for_embedding( + collection=collection, + file_path=file_path, + content_hash=content_hash, + hash_seq=hash_seq, + seq=seq, + text_body=text_body, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + memory_id=normalized_memory_id, + memory_root_path=normalized_memory_root_path, + created_at=now, + ) + + def _index_graph_rows_for_embedding( + self, + *, + collection: str, + file_path: str, + content_hash: str, + hash_seq: str, + seq: int, + text_body: str, + user_id: Optional[str], + session_id: Optional[str], + project_id: Optional[str], + profile: Optional[str], + memory_id: Optional[str], + memory_root_path: Optional[str], + created_at: int, + ) -> None: + """Persist deterministic entity/relation rows for one indexed evidence unit.""" + if getattr(self, "_entities_table", None) is None or not isinstance(text_body, str) or not text_body.strip(): + return + + entities = extract_entities(text_body) + if not entities: + return + + root_path = memory_root_path or file_path + entity_rows = [ + { + "id": stable_graph_id("entity", hash_seq, entity.entity_key), + "collection": collection, + "entity_key": entity.entity_key, + "name": entity.name, + "entity_type": entity.entity_type, + "memory_id": memory_id, + "memory_root_path": root_path, + "file_path": file_path, + "content_hash": content_hash, + "hash_seq": hash_seq, + "seq": seq, + "evidence": entity.evidence, + "user_id": user_id, + "session_id": session_id, + "project_id": project_id, + "profile": profile, + "created_at": created_at, + } + for entity in entities + ] + + try: + self._entities_table.add(pa.Table.from_pylist(entity_rows, schema=self._build_entities_schema())) + except Exception as exc: + logger.warning("insert_embedding: failed to index entity graph rows for %s: %s", hash_seq, exc) + return + + if getattr(self, "_relations_table", None) is None: + return + + relations = extract_relations(entities) + if not relations: + return + + relation_rows = [ + { + "id": stable_graph_id( + "relation", + hash_seq, + relation.subject_key, + relation.object_key, + relation.relation_type, + ), + "collection": collection, + "subject_key": relation.subject_key, + "subject_name": relation.subject_name, + "object_key": relation.object_key, + "object_name": relation.object_name, + "relation_type": relation.relation_type, + "memory_id": memory_id, + "memory_root_path": root_path, + "file_path": file_path, + "content_hash": content_hash, + "hash_seq": hash_seq, + "seq": seq, + "evidence": relation.evidence, + "user_id": user_id, + "session_id": session_id, + "project_id": project_id, + "profile": profile, + "created_at": created_at, + } + for relation in relations + ] + + try: + self._relations_table.add(pa.Table.from_pylist(relation_rows, schema=self._build_relations_schema())) + except Exception as exc: + logger.warning("insert_embedding: failed to index relation graph rows for %s: %s", hash_seq, exc) def has_vectors(self) -> bool: """Check if index has any vectors.""" @@ -862,7 +1084,15 @@ def rename_collection( if not old_name or not new_name: raise ValueError("old_name and new_name are required") if old_name == new_name: - return {"success": True, "old_name": old_name, "new_name": new_name, "embeddings_updated": 0, "documents_updated": 0} + return { + "success": True, + "old_name": old_name, + "new_name": new_name, + "embeddings_updated": 0, + "documents_updated": 0, + "entities_updated": 0, + "relations_updated": 0, + } # Check if target collection already exists existing_collections = self.list_collections() @@ -930,6 +1160,44 @@ def rename_collection( logger.error(f"rename_collection: failed to update documents: {e}") raise + entities_updated = 0 + if self._entities_table is not None: + try: + all_entity_rows = list( + self._entities_table.search() + .where(_safe_filter("collection", old_name)) + .limit(10_000_000) + .to_list() + ) + entities_updated = len(all_entity_rows) + if all_entity_rows: + self._entities_table.delete(_safe_filter("collection", old_name)) + for row in all_entity_rows: + row["collection"] = new_name + self._entities_table.add(pa.Table.from_pylist(all_entity_rows, schema=self._build_entities_schema())) + except Exception as e: + logger.error(f"rename_collection: failed to update entities: {e}") + raise + + relations_updated = 0 + if self._relations_table is not None: + try: + all_relation_rows = list( + self._relations_table.search() + .where(_safe_filter("collection", old_name)) + .limit(10_000_000) + .to_list() + ) + relations_updated = len(all_relation_rows) + if all_relation_rows: + self._relations_table.delete(_safe_filter("collection", old_name)) + for row in all_relation_rows: + row["collection"] = new_name + self._relations_table.add(pa.Table.from_pylist(all_relation_rows, schema=self._build_relations_schema())) + except Exception as e: + logger.error(f"rename_collection: failed to update relations: {e}") + raise + # Schedule FTS rebuild since we modified embeddings self._schedule_fts_rebuild() @@ -939,6 +1207,8 @@ def rename_collection( "new_name": new_name, "embeddings_updated": embeddings_updated, "documents_updated": documents_updated, + "entities_updated": entities_updated, + "relations_updated": relations_updated, } def delete_collection( @@ -1000,6 +1270,38 @@ def delete_collection( logger.error(f"delete_collection: failed to delete documents: {e}") raise + entities_deleted = 0 + if self._entities_table is not None: + try: + entities_filter = _safe_filter("collection", name) + entities_deleted = len( + self._entities_table.search() + .where(entities_filter) + .select(["id"]) + .limit(10_000_000) + .to_list() + ) + self._entities_table.delete(entities_filter) + except Exception as e: + logger.error(f"delete_collection: failed to delete entities: {e}") + raise + + relations_deleted = 0 + if self._relations_table is not None: + try: + relations_filter = _safe_filter("collection", name) + relations_deleted = len( + self._relations_table.search() + .where(relations_filter) + .select(["id"]) + .limit(10_000_000) + .to_list() + ) + self._relations_table.delete(relations_filter) + except Exception as e: + logger.error(f"delete_collection: failed to delete relations: {e}") + raise + # Clean up orphaned content entries orphans_cleaned = 0 if content_hashes_to_check: @@ -1030,6 +1332,8 @@ def delete_collection( "name": name, "embeddings_deleted": embeddings_deleted, "documents_deleted": documents_deleted, + "entities_deleted": entities_deleted, + "relations_deleted": relations_deleted, "orphans_cleaned": orphans_cleaned, } @@ -1078,6 +1382,270 @@ def _memory_path_clause(self, path: str) -> str: f"OR file_path LIKE '{escaped_path}::%')" ) + def _graph_namespace_filters( + self, + *, + collection: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + ) -> List[str]: + filters: List[str] = [] + if collection: + filters.append(_safe_filter("collection", collection)) + if user_id is not None: + filters.append(_safe_filter("user_id", user_id)) + if session_id is not None: + filters.append(_safe_filter("session_id", session_id)) + if project_id is not None: + filters.append(_safe_filter("project_id", project_id)) + if profile is not None: + filters.append(_safe_filter("profile", profile)) + return filters + + def delete_graph_entries( + self, + *, + collection: Optional[str] = None, + logical_path: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + include_children: bool = False, + hash_seq: Optional[str] = None, + ) -> int: + """Delete entity/relation rows for an indexed chunk or logical memory path.""" + filters = self._graph_namespace_filters( + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + ) + if hash_seq: + filters.append(_safe_filter("hash_seq", hash_seq)) + if logical_path: + if include_children: + filters.append(self._memory_path_clause(logical_path)) + else: + validated_path = _validate_identifier(logical_path, "logical_path") + escaped_path = validated_path.replace("'", "''") + filters.append(f"file_path = '{escaped_path}'") + if not filters: + return 0 + + filter_clause = " AND ".join(filters) + removed = 0 + for table, label in ( + (getattr(self, "_entities_table", None), "entities"), + (getattr(self, "_relations_table", None), "relations"), + ): + if table is None: + continue + try: + removed += len(table.search().where(filter_clause).select(["id"]).limit(10_000_000).to_list()) + except Exception as exc: + logger.debug("delete_graph_entries: failed to count %s rows: %s", label, exc) + try: + table.delete(filter_clause) + except Exception as exc: + logger.debug("delete_graph_entries: failed to delete %s rows: %s", label, exc) + return removed + + def list_memory_entities( + self, + *, + memory_id: Optional[str] = None, + path: Optional[str] = None, + entity: Optional[str] = None, + collection: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + """List entity mentions with evidence for a memory, path, or entity key.""" + if getattr(self, "_entities_table", None) is None: + return [] + + filters = self._graph_namespace_filters( + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + ) + if memory_id: + filters.append(_safe_filter("memory_id", memory_id)) + if path: + filters.append(self._memory_path_clause(path)) + if entity: + normalized_entity = normalize_entity_key(entity) + if normalized_entity: + filters.append(_safe_filter("entity_key", normalized_entity)) + + query = self._entities_table.search() + if filters: + query = query.where(" AND ".join(filters)) + try: + rows = list( + query.select([ + "entity_key", "name", "entity_type", "memory_id", "memory_root_path", + "file_path", "content_hash", "hash_seq", "seq", "evidence", + "collection", "user_id", "session_id", "project_id", "profile", + "created_at", + ]) + .limit(max(1, min(int(limit or 100), 5000))) + .to_list() + ) + except Exception as exc: + logger.warning("list_memory_entities: graph lookup failed: %s", exc) + return [] + + rows.sort( + key=lambda row: ( + str(row.get("name", "")).lower(), + row.get("memory_root_path") or row.get("file_path") or "", + row.get("seq", 0) or 0, + ) + ) + return rows + + def find_related_memories( + self, + *, + memory_id: Optional[str] = None, + path: Optional[str] = None, + entity: Optional[str] = None, + collection: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + project_id: Optional[str] = None, + profile: Optional[str] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """Find memories that share graph entities with a seed memory/path/entity.""" + if getattr(self, "_entities_table", None) is None: + return [] + + if entity: + seed_rows = [] + seed_keys = {normalize_entity_key(entity)} + else: + seed_rows = self.list_memory_entities( + memory_id=memory_id, + path=path, + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + limit=250, + ) + seed_keys = {row.get("entity_key") for row in seed_rows if row.get("entity_key")} + + seed_keys = {key for key in seed_keys if key} + if not seed_keys: + return [] + + seed_memory_ids = {row.get("memory_id") for row in seed_rows if row.get("memory_id")} + seed_paths = { + row.get("memory_root_path") or row.get("file_path") + for row in seed_rows + if row.get("memory_root_path") or row.get("file_path") + } + if memory_id: + seed_memory_ids.add(memory_id) + if path: + seed_paths.add(path) + + filters = self._graph_namespace_filters( + collection=collection, + user_id=user_id, + session_id=session_id, + project_id=project_id, + profile=profile, + ) + filters.append("(" + " OR ".join(_safe_filter("entity_key", key) for key in sorted(seed_keys)) + ")") + try: + rows = list( + self._entities_table.search() + .where(" AND ".join(filters)) + .select([ + "entity_key", "name", "entity_type", "memory_id", "memory_root_path", + "file_path", "evidence", "collection", "seq", + "user_id", "session_id", "project_id", "profile", + ]) + .limit(max(500, min(10_000, int(limit or 20) * 250))) + .to_list() + ) + except Exception as exc: + logger.warning("find_related_memories: graph lookup failed: %s", exc) + return [] + + grouped: Dict[str, Dict[str, Any]] = {} + for row in rows: + root_path = row.get("memory_root_path") or row.get("file_path") + row_memory_id = row.get("memory_id") + if (row_memory_id and row_memory_id in seed_memory_ids) or (root_path and root_path in seed_paths): + continue + if not root_path: + continue + related_id = row_memory_id or build_memory_id( + row.get("collection", collection or ""), + root_path, + user_id=row.get("user_id"), + session_id=row.get("session_id"), + project_id=row.get("project_id"), + profile=row.get("profile"), + ) + bucket = grouped.setdefault( + related_id, + { + "memory_id": related_id, + "collection": row.get("collection"), + "path": root_path, + "score": 0, + "_entities": {}, + "evidence": [], + "_mentions": 0, + }, + ) + entity_key = row.get("entity_key") + if entity_key and entity_key not in bucket["_entities"]: + bucket["_entities"][entity_key] = { + "entity_key": entity_key, + "name": row.get("name"), + "entity_type": row.get("entity_type"), + } + bucket["_mentions"] += 1 + if len(bucket["evidence"]) < 5: + bucket["evidence"].append( + { + "path": row.get("file_path"), + "entity_key": entity_key, + "entity": row.get("name"), + "text": row.get("evidence"), + } + ) + + related = [] + for bucket in grouped.values(): + shared_entities = list(bucket.pop("_entities").values()) + mentions = bucket.pop("_mentions") + bucket["shared_entities"] = sorted( + shared_entities, + key=lambda item: str(item.get("name") or item.get("entity_key") or "").lower(), + ) + bucket["score"] = len(shared_entities) * 10 + mentions + related.append(bucket) + + related.sort(key=lambda item: (-item["score"], item.get("path") or "")) + return related[: max(1, min(int(limit or 20), 200))] + def _memory_root_key(self, row: Dict[str, Any]) -> Optional[str]: """Return the canonical root path for a memory or derived asset row.""" root_path = row.get("memory_root_path") @@ -1643,7 +2211,10 @@ def index_conversation( ttl_seconds=ttl_seconds, tags=tags, ) - + + # list_memory_entities and find_related_memories are implemented directly + # above because they query graph side tables rather than the indexing service. + def delete_path( self, path: str, diff --git a/tests/test_batch_tool.py b/tests/test_batch_tool.py index ee4124f..b25095b 100644 --- a/tests/test_batch_tool.py +++ b/tests/test_batch_tool.py @@ -279,6 +279,7 @@ async def test_all_existing_tools_still_present(self): "search", "search_fts", "search_vec", "ingest", "index_document", "index_image", "memory_add", "memory_add_conversation", "memory_update", "memory_delete", + "memory_graph_entities", "memory_graph_related", "status", "rebuild_fts", "batch", } self.assertTrue(expected.issubset(set(names)), f"Missing tools: {expected - set(names)}") diff --git a/tests/test_config_tools.py b/tests/test_config_tools.py index ccf7016..626f9b9 100644 --- a/tests/test_config_tools.py +++ b/tests/test_config_tools.py @@ -27,6 +27,8 @@ _handle_memory_get, _handle_search, _handle_memory_add_conversation, + _handle_memory_graph_entities, + _handle_memory_graph_related, _resolve_file_query_input, _handle_set_config, create_server, @@ -107,6 +109,27 @@ def _make_storage(store_path="/tmp/test-store"): "indexed_turns": 2, "tags": ["conversation"], } + s.list_memory_entities.return_value = [ + { + "entity_key": "acme_robotics", + "name": "Acme Robotics", + "entity_type": "proper_noun", + "memory_id": "mem-123", + "memory_root_path": "notes/demo.md", + "file_path": "notes/demo.md", + "evidence": "Acme Robotics launch note", + } + ] + s.find_related_memories.return_value = [ + { + "memory_id": "mem-456", + "collection": "default", + "path": "notes/related.md", + "score": 11, + "shared_entities": [{"entity_key": "acme_robotics", "name": "Acme Robotics"}], + "evidence": [{"path": "notes/related.md", "entity": "Acme Robotics", "text": "related"}], + } + ] return s @@ -443,6 +466,33 @@ async def test_memory_add_conversation_dispatched(self): self.assertEqual(data["operation"], "add_conversation") self.storage.index_conversation.assert_called_once() + async def test_memory_graph_entities_dispatched(self): + result = await _dispatch_tool( + "memory_graph_entities", + {"path": "notes/demo.md"}, + self.backend, + self.storage, + {}, + ) + data = json.loads(result[0].text) + self.assertTrue(data["success"]) + self.assertEqual(data["count"], 1) + self.assertEqual(data["entities"][0]["entity_key"], "acme_robotics") + self.storage.list_memory_entities.assert_called_once() + + async def test_memory_graph_related_dispatched(self): + result = await _dispatch_tool( + "memory_graph_related", + {"entity": "Acme Robotics"}, + self.backend, + self.storage, + {}, + ) + data = json.loads(result[0].text) + self.assertTrue(data["success"]) + self.assertEqual(data["count"], 1) + self.storage.find_related_memories.assert_called_once() + # --------------------------------------------------------------------------- # create_server: new tools appear in list_tools @@ -475,6 +525,8 @@ async def test_memory_tools_registered(self): self.assertIn("memory_add_conversation", names) self.assertIn("memory_get", names) self.assertIn("list_memories", names) + self.assertIn("memory_graph_entities", names) + self.assertIn("memory_graph_related", names) async def test_all_original_tools_still_present(self): names = await self._get_tool_names() @@ -482,6 +534,7 @@ async def test_all_original_tools_still_present(self): "search", "search_fts", "search_vec", "ingest", "index_document", "index_image", "index_audio", "memory_add", "memory_add_conversation", "memory_update", "memory_delete", + "memory_graph_entities", "memory_graph_related", "status", "rebuild_fts", "batch", "list_collections", "list_namespaces", "rename_collection", "delete_collection", @@ -529,8 +582,20 @@ async def test_list_memories_returns_json(self): data = json.loads(result[0].text) self.assertTrue(data["success"]) self.assertEqual(data["count"], 1) - self.assertEqual(data["memories"][0]["memory_id"], "mem-123") - self.assertEqual(data["memories"][0]["summary"], "Demo summary") + + async def test_memory_graph_entities_returns_json(self): + result = await _handle_memory_graph_entities({"path": "notes/demo.md"}, _make_storage()) + data = json.loads(result[0].text) + self.assertTrue(data["success"]) + self.assertEqual(data["count"], 1) + self.assertEqual(data["entities"][0]["entity_key"], "acme_robotics") + + async def test_memory_graph_related_returns_json(self): + result = await _handle_memory_graph_related({"entity": "Acme Robotics"}, _make_storage()) + data = json.loads(result[0].text) + self.assertTrue(data["success"]) + self.assertEqual(data["count"], 1) + self.assertEqual(data["related_memories"][0]["path"], "notes/related.md") async def test_memory_get_by_id_returns_json(self): result = await _handle_memory_get({"memory_id": "mem-123"}, _make_storage()) diff --git a/tests/test_entities.py b/tests/test_entities.py new file mode 100644 index 0000000..fe6a296 --- /dev/null +++ b/tests/test_entities.py @@ -0,0 +1,49 @@ +"""Tests for lightweight entity and relation extraction.""" + +from recallforge.entities import ( + extract_entities, + extract_relations, + normalize_entity_key, + stable_graph_id, +) + + +def test_extract_entities_deduplicates_and_classifies_common_memory_entities(): + text = ( + "Alice from Acme Robotics discussed REC-76 with @brian. " + "The notes live at https://recallforge.dev/roadmap." + ) + + entities = extract_entities(text) + by_key = {entity.entity_key: entity for entity in entities} + + assert "alice" in by_key + assert by_key["alice"].entity_type == "proper_noun" + assert "acme_robotics" in by_key + assert "rec_76" in by_key + assert by_key["rec_76"].entity_type == "ticket" + assert "brian" in by_key + assert by_key["brian"].entity_type == "person" + assert "https_recallforge_dev_roadmap" in by_key + assert by_key["https_recallforge_dev_roadmap"].entity_type == "url" + assert "Acme Robotics" in by_key["acme_robotics"].evidence + + +def test_extract_relations_creates_traceable_co_mentions(): + entities = extract_entities("Alice and Bob met at Acme Robotics for Project Atlas.") + + relations = extract_relations(entities) + + assert relations + assert any( + relation.relation_type == "co_mentions" + and {relation.subject_key, relation.object_key} >= {"alice", "bob"} + for relation in relations + ) + assert all(relation.evidence for relation in relations) + + +def test_normalize_entity_key_and_stable_graph_id_are_deterministic(): + assert normalize_entity_key("@Brian Meyer") == "brian_meyer" + assert stable_graph_id("entity", "hash_0", "acme") == stable_graph_id("entity", "hash_0", "acme") + assert stable_graph_id("entity", "hash_0", "acme") != stable_graph_id("entity", "hash_1", "acme") diff --git a/tests/test_json_compliance.py b/tests/test_json_compliance.py index 1f69dfb..ed06989 100644 --- a/tests/test_json_compliance.py +++ b/tests/test_json_compliance.py @@ -23,6 +23,8 @@ _handle_memory_add_conversation, _handle_memory_delete, _handle_memory_get, + _handle_memory_graph_entities, + _handle_memory_graph_related, _handle_memory_update, _handle_rebuild_fts, _handle_search, @@ -133,6 +135,12 @@ def get_memory(self, memory_id=None, path=None, **_kwargs): "snippets": [], } + def list_memory_entities(self, **_kwargs): + return [{"entity_key": "acme_robotics", "name": "Acme Robotics", "evidence": "Acme Robotics note"}] + + def find_related_memories(self, **_kwargs): + return [{"memory_id": "mem-2", "path": "mem/2", "shared_entities": [{"entity_key": "acme_robotics"}]}] + class _FakeSearchResult: def __init__(self, filepath="/tmp/a.txt"): @@ -277,6 +285,14 @@ async def test_all_tool_handlers_valid_and_invalid_calls_return_json(self): lambda: _handle_memory_get({"memory_id": "mem-1"}, self.storage), lambda: _handle_memory_get({"path": "missing"}, self.storage), ), + ( + lambda: _handle_memory_graph_entities({"path": "mem/1"}, self.storage), + lambda: _handle_memory_graph_entities({}, self.storage), + ), + ( + lambda: _handle_memory_graph_related({"entity": "Acme Robotics"}, self.storage), + lambda: _handle_memory_graph_related({}, self.storage), + ), ] for valid_call, invalid_call in cases: diff --git a/tests/test_storage.py b/tests/test_storage.py index 93ec706..7c28e6f 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -399,6 +399,60 @@ def test_upsert_memory_skip_delete_still_prevents_stale_duplicates(self): self.assertEqual(len(hashes), 1) self.assertGreater(len(rows), 0) + def test_memory_graph_entities_and_related_memories_track_evidence(self): + self.backend.upsert_memory( + path="notes/acme-review.md", + text="Mira from Acme Robotics tracks the launch review for Project Atlas.", + collection="test", + embed_func=mock_embed, + model="mock-embedder", + ) + self.backend.upsert_memory( + path="notes/acme-budget.md", + text="The budget memo says Acme Robotics approved new sensors.", + collection="test", + embed_func=mock_embed, + model="mock-embedder", + ) + + entities = self.backend.list_memory_entities(path="notes/acme-review.md", collection="test") + entity_keys = {row["entity_key"] for row in entities} + self.assertIn("acme_robotics", entity_keys) + self.assertTrue(any(row["evidence"] and "Acme Robotics" in row["evidence"] for row in entities)) + + related = self.backend.find_related_memories(path="notes/acme-review.md", collection="test") + self.assertTrue(related) + self.assertEqual(related[0]["path"], "notes/acme-budget.md") + self.assertIn("acme_robotics", {item["entity_key"] for item in related[0]["shared_entities"]}) + self.assertTrue(related[0]["evidence"]) + + def test_memory_graph_rows_are_replaced_on_memory_update(self): + self.backend.upsert_memory( + path="notes/company.md", + text="Mira from Acme Robotics owns the launch checklist.", + collection="test", + embed_func=mock_embed, + model="mock-embedder", + ) + self.assertTrue( + self.backend.list_memory_entities(path="notes/company.md", entity="Acme Robotics", collection="test") + ) + + self.backend.upsert_memory( + path="notes/company.md", + text="Mira from Globex Labs owns the launch checklist.", + collection="test", + embed_func=mock_embed, + model="mock-embedder", + ) + + self.assertFalse( + self.backend.list_memory_entities(path="notes/company.md", entity="Acme Robotics", collection="test") + ) + self.assertTrue( + self.backend.list_memory_entities(path="notes/company.md", entity="Globex Labs", collection="test") + ) + def test_delete_memory_deactivates_doc_and_removes_embeddings(self): self.backend.upsert_memory( path="notes/delete-me.md", diff --git a/tests/uat/test_mcp_server.sh b/tests/uat/test_mcp_server.sh index 602b536..739f89a 100755 --- a/tests/uat/test_mcp_server.sh +++ b/tests/uat/test_mcp_server.sh @@ -259,6 +259,7 @@ async def test_server(): required_tools = [ "search", "search_fts", "search_vec", "ingest", "index_document", "index_image", "memory_add", "memory_add_conversation", "memory_update", "memory_delete", + "memory_graph_entities", "memory_graph_related", "rename_collection", "delete_collection", "list_collections", "status", "rebuild_fts", "get_config", "set_config" ] diff --git a/tests/uat/test_uat_comprehensive.py b/tests/uat/test_uat_comprehensive.py index 8ecb099..b2edc40 100644 --- a/tests/uat/test_uat_comprehensive.py +++ b/tests/uat/test_uat_comprehensive.py @@ -460,6 +460,7 @@ async def test_server_exposes_required_tools(self, mock_backend, mock_storage): "search", "search_fts", "search_vec", "ingest", "index_document", "index_image", "memory_add", "memory_add_conversation", "memory_update", "memory_delete", + "memory_graph_entities", "memory_graph_related", "status", "rebuild_fts", "get_config", "set_config", "list_collections", "list_namespaces", "rename_collection", "delete_collection",