diff --git a/architecture.png b/architecture.png deleted file mode 100644 index f9279b2..0000000 Binary files a/architecture.png and /dev/null differ diff --git a/benchmark.py b/benchmark.py index 17f664c..f2d2d7c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -17,239 +17,237 @@ from __future__ import annotations import asyncio -import hashlib +import json import logging -import math -import struct import time from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest.mock import patch logging.basicConfig(level=logging.WARNING) from langchain_core.language_models import BaseChatModel -from src.config import settings from src.models import get_model -from src.agents.classifier import ClassifierAgent -from src.agents.profiler import ProfilerAgent -from src.agents.temporal import TemporalAgent -from src.agents.summarizer import SummarizerAgent -from src.agents.judge import JudgeAgent from src.pipelines.ingest import IngestPipeline from src.pipelines.retrieval import RetrievalPipeline -from src.storage.base import SearchResult TEST_QUERY = ( - "My name is Alice and I work at Google as a senior software engineer. " - "My birthday is April 5th. I love sushi and hiking on weekends." + "My name is Bob, and I started a new job at Vercel as a frontend developer today!" ) -TEST_RESPONSE = "Nice to meet you Alice! That sounds like a great lifestyle." -SESSION_DT = "4:04 pm on 20 January, 2025" +TEST_RESPONSE = "Congratulations on your new role Bob! That's wonderful news." +SESSION_DT = "4:00 pm on 20 May, 2026" BENCH_USER_ID = "bench-user" -# ── In-memory stores (mirrors Go benchmark — no Pinecone/Neo4j required) ── - - -class InMemoryVectorStore: - def __init__(self) -> None: - self.records: dict[str, dict[str, Any]] = {} - self.next_id = 1 - - def add(self, texts, embeddings, ids=None, metadata=None): - created = [] - for idx, text in enumerate(texts): - record_id = ids[idx] if ids else f"vec-{self.next_id}" - self.next_id += 1 - self.records[record_id] = { - "content": text, - "embedding": embeddings[idx], - "metadata": (metadata or [{}])[idx], - "score": 1.0, - } - created.append(record_id) - return created - - def update(self, id, text=None, embedding=None, metadata=None): - if id not in self.records: - return False - current = self.records[id] - if text is not None: - current["content"] = text - if embedding is not None: - current["embedding"] = embedding - if metadata is not None: - current["metadata"] = metadata - return True - - def delete(self, ids): - for record_id in ids: - self.records.pop(record_id, None) - return True - - def get(self, ids): - return [ - {"id": record_id, **self.records[record_id]} - for record_id in ids - if record_id in self.records - ] - - def search_by_metadata(self, filters, top_k=10): - matches = [] - for record_id, record in self.records.items(): - metadata = record["metadata"] - if all(metadata.get(key) == value for key, value in filters.items()): - matches.append( - SearchResult( - id=record_id, - content=record["content"], - score=record.get("score", 1.0), - metadata=metadata, - ) - ) - return matches[:top_k] - - async def search_by_text(self, query_text, top_k=10, filters=None): - filters = filters or {} - query = query_text.lower() - matches = [] - for record_id, record in self.records.items(): - metadata = record["metadata"] - if filters and not all(metadata.get(k) == v for k, v in filters.items()): - continue - content = record["content"].lower() - score = 1.0 if query and query in content else record.get("score", 0.5) - if filters or score > 0: - matches.append( - SearchResult( - id=record_id, - content=record["content"], - score=score, - metadata=metadata, - ) - ) - matches.sort(key=lambda r: r.score, reverse=True) - return matches[:top_k] - - def search(self, query_embedding, top_k=5, filters=None): - return self.search_by_metadata(filters or {}, top_k=top_k) - - def health_check(self): - return True - - -class FakeNeo4jClient: - def __init__(self) -> None: - self.events: list[dict[str, Any]] = [] - self.connected = False - - def connect(self): - self.connected = True - - def close(self): - pass - - def search_events_by_name(self, event_name: str, user_id: str, top_k: int = 1): - return [ - event for event in self.events - if event.get("user_id", user_id) == user_id - and event_name.lower() in event.get("event_name", "").lower() - ][:top_k] - - def search_events_by_embedding(self, user_id: str, query_text: str, top_k: int = 3, similarity_threshold: float = 0.0): - query = query_text.lower() - matches = [] - for event in self.events: - if event.get("user_id", user_id) != user_id: - continue - text = " ".join( - str(event.get(k, "")) for k in ("event_name", "desc", "date_expression") - ).lower() - score = 1.0 if query and any(w in text for w in query.split()) else 0.1 - matches.append({**event, "similarity_score": score}) - matches.sort(key=lambda e: e.get("similarity_score", 0), reverse=True) - return matches[:top_k] - - def create_event(self, user_id: str, date_str: str, event_data: dict[str, Any]): - self.events.append({"user_id": user_id, "date": date_str, **event_data}) - - def update_event(self, user_id: str, date_str: str, event_data: dict[str, Any]): - for event in self.events: - if event.get("user_id") == user_id and event.get("date") == date_str: - event.update(event_data) - return True - self.create_event(user_id, date_str, event_data) - return True - - def delete_event(self, user_id: str, date_str: str, event_name: str | None = None): - self.events = [ - event for event in self.events - if not ( - event.get("user_id") == user_id - and event.get("date") == date_str - and (event_name is None or event.get("event_name") == event_name) - ) - ] - return True - - -class FakeCodeGraphClient: - def connect(self): - pass - - def close(self): - pass - - def setup(self): - pass - - def create_annotation(self, **kwargs): - return "ann-1" - - -def hash_embed(text: str) -> list[float]: - """Local hash embedder — matches Go benchmark (no embedding API calls).""" - dim = int(settings.pinecone_dimension or 384) - vec = [0.0] * dim - words = text.lower().split() or [text.lower()] - for word in words: - digest = hashlib.sha256(word.encode()).digest() - idx = struct.unpack(">Q", digest[:8])[0] % dim - vec[idx] += 1.0 - norm = math.sqrt(sum(v * v for v in vec)) - if norm: - vec = [v / norm for v in vec] - return vec +def indent_lines(text: str, spaces: int = 4) -> str: + indent = " " * spaces + return "\n".join(indent + line for line in str(text).splitlines()) + + +def estimate_tokens(text: str) -> int: + text = str(text or "").strip() + if not text: + return 0 + return (len(text) + 3) // 4 + + +def message_role(message: Any) -> str: + if isinstance(message, dict): + return str(message.get("role", "user")) + role = getattr(message, "type", None) or getattr(message, "role", None) + if role == "human": + return "user" + if role == "ai": + return "assistant" + return str(role or "user") + + +def message_content(message: Any) -> str: + if isinstance(message, dict): + content = message.get("content", "") + else: + content = getattr(message, "content", "") + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif item.get("type") == "image_url": + parts.append("") + else: + parts.append(json.dumps(item, default=str)) + else: + parts.append(str(item)) + return "\n".join(parts) + return str(content or "") + + +def response_content(response: Any) -> str: + content = getattr(response, "content", response) + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict) and "text" in item: + parts.append(str(item["text"])) + else: + parts.append(str(item)) + return "\n".join(parts) + return str(content or "") + + +def message_tokens(messages: Any) -> int: + if isinstance(messages, (str, bytes)): + return estimate_tokens(str(messages)) + if not isinstance(messages, list): + return estimate_tokens(message_content(messages)) + return sum(estimate_tokens(message_role(m)) + estimate_tokens(message_content(m)) for m in messages) + + +def response_tokens(response: Any) -> tuple[int, int, bool]: + usage = getattr(response, "usage_metadata", None) or {} + if isinstance(usage, dict): + input_tokens = int(usage.get("input_tokens") or usage.get("prompt_tokens") or 0) + output_tokens = int(usage.get("output_tokens") or usage.get("completion_tokens") or 0) + if input_tokens or output_tokens: + return input_tokens, output_tokens, False + + metadata = getattr(response, "response_metadata", None) or {} + if isinstance(metadata, dict): + token_usage = metadata.get("token_usage") or metadata.get("usage") or {} + if isinstance(token_usage, dict): + input_tokens = int(token_usage.get("prompt_tokens") or token_usage.get("input_tokens") or 0) + output_tokens = int(token_usage.get("completion_tokens") or token_usage.get("output_tokens") or 0) + if input_tokens or output_tokens: + return input_tokens, output_tokens, False + return 0, estimate_tokens(response_content(response)), True + + +def infer_agent(messages: Any, default_agent: str) -> str: + texts = [] + if isinstance(messages, list): + texts = [message_content(m) for m in messages] + else: + texts = [str(messages)] + combined = "\n".join(texts).lower() + if "analyze this user input" in combined or "intent router" in combined: + return "classifier" + if "extract all temporal events" in combined or "event extraction assistant" in combined: + return "temporal" + if "summarize this conversation" in combined: + return "summarizer" + if "profile facts" in combined or "extracts structured user facts" in combined or "extract important user profiles" in combined or "build a complete picture of the user" in combined: + return "profiler" + if "## domain:" in combined or "judge agent" in combined: + return "judge" + if "image analysis" in combined or "analyse this image" in combined: + return "image" + if "extract code snippets" in combined: + return "snippet" + if "extract code annotations" in combined: + return "code" + return default_agent class TimedModel: - """Proxy that wraps a real LangChain model and tracks LLM call time.""" - - def __init__(self, inner: BaseChatModel, tracker: Optional[dict[str, int]] = None): + """Proxy that wraps a real LangChain model and traces/times LLM calls.""" + + def __init__( + self, + inner: BaseChatModel, + tracker: Optional[dict[str, Any]] = None, + agent: str = "model", + trace: bool = True, + ): self._inner = inner - self._tracker = tracker or {"llm_time_ns": 0, "call_count": 0} + self._tracker = tracker or { + "llm_time_ns": 0, + "call_count": 0, + "input_tokens": 0, + "output_tokens": 0, + "agents": {}, + } + self._agent = agent + self._trace = trace def __getattr__(self, name: str) -> Any: return getattr(self._inner, name) + def _count_tokens(self, messages: Any, result: Any) -> tuple[int, int]: + input_tokens, output_tokens, _ = response_tokens(result) + if input_tokens == 0: + input_tokens = message_tokens(messages) + return input_tokens, output_tokens + + def _record_metrics(self, agent: str, elapsed_ns: int, input_tokens: int, output_tokens: int) -> None: + self._tracker["llm_time_ns"] += elapsed_ns + self._tracker["call_count"] += 1 + self._tracker["input_tokens"] += input_tokens + self._tracker["output_tokens"] += output_tokens + + agents = self._tracker.setdefault("agents", {}) + metrics = agents.setdefault( + agent, + {"llm_time_ns": 0, "call_count": 0, "input_tokens": 0, "output_tokens": 0}, + ) + metrics["llm_time_ns"] += elapsed_ns + metrics["call_count"] += 1 + metrics["input_tokens"] += input_tokens + metrics["output_tokens"] += output_tokens + async def ainvoke(self, messages: Any, **kwargs) -> Any: start = time.perf_counter_ns() result = await self._inner.ainvoke(messages, **kwargs) - self._tracker["llm_time_ns"] += time.perf_counter_ns() - start - self._tracker["call_count"] += 1 + elapsed_ns = time.perf_counter_ns() - start + input_tokens, output_tokens = self._count_tokens(messages, result) + agent = infer_agent(messages, self._agent) + self._record_metrics(agent, elapsed_ns, input_tokens, output_tokens) + if self._trace: + self._print_trace("ainvoke", messages, result, elapsed_ns) return result def invoke(self, messages: Any, **kwargs) -> Any: start = time.perf_counter_ns() result = self._inner.invoke(messages, **kwargs) - self._tracker["llm_time_ns"] += time.perf_counter_ns() - start - self._tracker["call_count"] += 1 + elapsed_ns = time.perf_counter_ns() - start + input_tokens, output_tokens = self._count_tokens(messages, result) + agent = infer_agent(messages, self._agent) + self._record_metrics(agent, elapsed_ns, input_tokens, output_tokens) + if self._trace: + self._print_trace("invoke", messages, result, elapsed_ns) return result def bind_tools(self, tools, **kwargs): - return TimedModel(self._inner.bind_tools(tools, **kwargs), tracker=self._tracker) + return TimedModel( + self._inner.bind_tools(tools, **kwargs), + tracker=self._tracker, + agent=self._agent, + trace=self._trace, + ) + + def for_agent(self, agent: str) -> "TimedModel": + return TimedModel(self._inner, tracker=self._tracker, agent=agent, trace=self._trace) + + def _print_trace(self, call_type: str, messages: Any, result: Any, elapsed_ns: int) -> None: + agent = infer_agent(messages, self._agent) + + print() + print(f" \033[35m┌─── [LLM Call: {agent} / {call_type}] ───────────────────────────────────┐\033[0m") + if isinstance(messages, list): + for message in messages: + role = message_role(message) + if role == "system": + continue + print(f" \033[35m│\033[0m \033[1;33m{role.upper()}:\033[0m") + print(indent_lines(message_content(message))) + else: + print(" \033[35m│\033[0m \033[1;33mPROMPT:\033[0m") + print(indent_lines(str(messages))) + print(" \033[35m├─────────────────────────────────────────────────────────────────────────────────┤\033[0m") + print(" \033[35m│\033[0m \033[1;32mResponse:\033[0m") + print(indent_lines(response_content(result))) + print(" \033[35m└─────────────────────────────────────────────────────────────────────────────────┘\033[0m") @property def llm_duration_ms(self) -> float: @@ -259,6 +257,18 @@ def llm_duration_ms(self) -> float: def call_count(self) -> int: return self._tracker["call_count"] + @property + def input_tokens(self) -> int: + return self._tracker["input_tokens"] + + @property + def output_tokens(self) -> int: + return self._tracker["output_tokens"] + + @property + def agent_metrics(self) -> dict[str, dict[str, int]]: + return self._tracker.setdefault("agents", {}) + @dataclass class Timing: @@ -267,6 +277,8 @@ class Timing: llm_ms: float calls: int concurrent: bool = False + input_tokens: int = 0 + output_tokens: int = 0 @property def overhead_ms(self) -> float: @@ -279,124 +291,36 @@ def _truncate(text: str, max_len: int = 80) -> str: def _ingest_stats(state: dict[str, Any]) -> str: - parts = [] + lines = [] cls = state.get("classification_result") if cls and cls.classifications: - parts.append(f"classifications={len(cls.classifications)}") + lines.append(f"classifications={len(cls.classifications)}") + for c in cls.classifications: + lines.append(f" {c['source']}: {c['query']}") for domain in ("profile", "temporal", "summary"): judge = state.get(f"{domain}_judge") if judge and judge.operations: - parts.append(f"{domain}_ops={len(judge.operations)}") - return " ".join(parts) - - -async def bench_classifier(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = ClassifierAgent(model=tm) - start = time.perf_counter_ns() - result = await agent.arun({"user_query": TEST_QUERY}) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - n = len(result.classifications) if result.classifications else 0 - print(f" Classifier Agent results={n}") - return Timing("Classifier Agent", total_ms, tm.llm_duration_ms, tm.call_count) - + lines.append(f"Judge ({domain}): {len(judge.operations)} op(s)") + for op in judge.operations: + preview = op.content[:70] + "..." if len(op.content) > 70 else op.content + lines.append(f" {op.type.value}: {preview}") + return "\n ".join(lines) -async def bench_profiler(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = ProfilerAgent(model=tm) - start = time.perf_counter_ns() - result = await agent.arun({"classifier_output": TEST_QUERY}) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - n = len(result.facts) if result.facts else 0 - print(f" Profiler Agent facts={n}") - return Timing("Profiler Agent", total_ms, tm.llm_duration_ms, tm.call_count) - -async def bench_temporal(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = TemporalAgent(model=tm) - start = time.perf_counter_ns() - result = await agent.arun({ - "classifier_output": TEST_QUERY, - "session_datetime": SESSION_DT, - }) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - n = len(result.events) if result.events else 0 - print(f" Temporal Agent events={n}") - return Timing("Temporal Agent", total_ms, tm.llm_duration_ms, tm.call_count) - - -async def bench_summarizer(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = SummarizerAgent(model=tm) - start = time.perf_counter_ns() - result = await agent.arun({ - "user_query": TEST_QUERY, - "agent_response": TEST_RESPONSE, - }) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - summary = result.summary if result.summary else "" - n = len([line for line in summary.splitlines() if line.strip()]) - print(f" Summarizer Agent bullets={n}") - return Timing("Summarizer Agent", total_ms, tm.llm_duration_ms, tm.call_count) - - -async def bench_judge_deterministic(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = JudgeAgent(model=tm, vector_store=None, graph_event_search=None, top_k=3) - items = [ - {"topic": "basic_info", "sub_topic": "name", "memo": "Alice"}, - {"topic": "work", "sub_topic": "company", "memo": "Google"}, - {"topic": "work", "sub_topic": "title", "memo": "Senior Software Engineer"}, - ] - start = time.perf_counter_ns() - result = await agent.arun_deterministic({ - "domain": "profile", - "new_items": items, - "user_id": BENCH_USER_ID, - }) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - n = len(result.operations) if result.operations else 0 - print(f" Judge (deterministic) ops={n}") - return Timing("Judge (deterministic)", total_ms, tm.llm_duration_ms, tm.call_count) - - -async def bench_judge_llm(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - agent = JudgeAgent(model=tm, vector_store=None, graph_event_search=None, top_k=3) - items = [ - "User's name is Alice and works at Google as a senior software engineer", - "User's birthday is April 5th", - "User loves sushi and hiking on weekends", - ] - start = time.perf_counter_ns() - result = await agent.arun({ - "domain": "summary", - "new_items": items, - "user_id": BENCH_USER_ID, - }) - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - n = len(result.operations) if result.operations else 0 - print(f" Judge (LLM) ops={n}") - return Timing("Judge (LLM)", total_ms, tm.llm_duration_ms, tm.call_count) - - -async def bench_full_ingest(real_model: BaseChatModel, vector_store: InMemoryVectorStore, neo4j: FakeNeo4jClient) -> Timing: - tm = TimedModel(real_model) +async def bench_full_ingest(real_model: BaseChatModel, tracker: dict[str, Any]) -> tuple[Timing, IngestPipeline]: + tm = TimedModel(real_model, tracker=tracker) def _timed_get_model(*_args, **_kwargs): return tm with patch("src.pipelines.ingest.get_model", _timed_get_model), \ - patch("src.pipelines.ingest.get_vision_model", _timed_get_model), \ - patch("src.storage.factory.get_vector_store", lambda *args, **kwargs: vector_store): - pipeline = IngestPipeline( - vector_store=vector_store, - neo4j_client=neo4j, - code_graph_client=FakeCodeGraphClient(), - embed_fn=hash_embed, - ) + patch("src.pipelines.ingest.get_vision_model", _timed_get_model): + pipeline = IngestPipeline() + start_calls = tm.call_count + start_input_tokens = tm.input_tokens + start_output_tokens = tm.output_tokens + start_llm_ms = tm.llm_duration_ms start = time.perf_counter_ns() state = await pipeline.run( user_query=TEST_QUERY, @@ -405,76 +329,103 @@ def _timed_get_model(*_args, **_kwargs): session_datetime=SESSION_DT, ) total_ms = (time.perf_counter_ns() - start) / 1_000_000 - print(f" Full Ingest Pipeline calls={tm.call_count} {_ingest_stats(state)} (parallel — LLM sum > wall clock)") - return Timing("Full Ingest Pipeline", total_ms, tm.llm_duration_ms, tm.call_count, concurrent=True) - - -async def bench_full_retrieval(real_model: BaseChatModel, vector_store: InMemoryVectorStore, neo4j: FakeNeo4jClient) -> Timing: - tm = TimedModel(real_model) + calls = tm.call_count - start_calls + input_tokens = tm.input_tokens - start_input_tokens + output_tokens = tm.output_tokens - start_output_tokens + llm_ms = tm.llm_duration_ms - start_llm_ms + stats = _ingest_stats(state) + print(f" Full Ingest Pipeline calls={calls} (parallel — LLM sum > wall clock)") + if stats: + print(f" {indent_lines(stats, 2).lstrip()}") + timing = Timing("Full Ingest Pipeline", total_ms, llm_ms, calls, + concurrent=True, input_tokens=input_tokens, output_tokens=output_tokens) + return timing, pipeline + + +async def bench_full_retrieval(real_model: BaseChatModel, vector_store: Any, neo4j: Any, tracker: dict[str, Any]) -> Timing: + tm = TimedModel(real_model, tracker=tracker, agent="retrieval") pipeline = RetrievalPipeline( model=tm, vector_store=vector_store, neo4j_client=neo4j, ) + start_calls = tm.call_count + start_input_tokens = tm.input_tokens + start_output_tokens = tm.output_tokens + start_llm_ms = tm.llm_duration_ms start = time.perf_counter_ns() result = await pipeline.run( query="What is my name and where do I work?", user_id=BENCH_USER_ID, ) total_ms = (time.perf_counter_ns() - start) / 1_000_000 + calls = tm.call_count - start_calls + input_tokens = tm.input_tokens - start_input_tokens + output_tokens = tm.output_tokens - start_output_tokens + llm_ms = tm.llm_duration_ms - start_llm_ms print( - f" Full Retrieval Pipeline calls={tm.call_count} " + f" Full Retrieval Pipeline calls={calls} " f"answer={result.answer!r} sources={result.source_count} confidence={result.confidence:.2f}" ) - return Timing("Full Retrieval Pipeline", total_ms, tm.llm_duration_ms, tm.call_count) + return Timing("Full Retrieval Pipeline", total_ms, llm_ms, calls, + input_tokens=input_tokens, output_tokens=output_tokens) -async def bench_concurrent_agents(real_model: BaseChatModel) -> Timing: - tm = TimedModel(real_model) - start = time.perf_counter_ns() - - classifier = ClassifierAgent(model=tm) - await classifier.arun({"user_query": TEST_QUERY}) - - profiler = ProfilerAgent(model=tm) - temporal = TemporalAgent(model=tm) - summarizer = SummarizerAgent(model=tm) - - await asyncio.gather( - profiler.arun({"classifier_output": TEST_QUERY}), - temporal.arun({"classifier_output": TEST_QUERY, "session_datetime": SESSION_DT}), - summarizer.arun({"user_query": TEST_QUERY, "agent_response": TEST_RESPONSE}), - ) - - total_ms = (time.perf_counter_ns() - start) / 1_000_000 - print(f" Concurrent Pipeline Sim calls={tm.call_count} (parallel — LLM sum > wall clock)") - return Timing("Concurrent Pipeline Sim", total_ms, tm.llm_duration_ms, tm.call_count, concurrent=True) - +def _print_summary(tracker: dict[str, Any], model_name: str, pipeline_timings: list[Timing] | None = None) -> None: + metrics_by_agent = tracker.get("agents", {}) -def _print_summary(timings: List[Timing], model_name: str) -> None: print() - print("╔════════════════════════════════════════════════════════════════════════════════════════╗") - print("║ XMem-Python Benchmark Summary ║") - print(f"║ Model: {str(model_name):<77}║") - print("╠════════════════════════════════════════════════════════════════════════════════════════╣") - print(f"║ {'Component':<30} {'Total':>10} {'LLM Time':>12} {'Overhead':>10} {'Calls':>6} {'':>8} ║") - print("╠════════════════════════════════════════════════════════════════════════════════════════╣") - for t in timings: - if t.concurrent: - saved = max(0.0, t.llm_ms - t.total_ms) - print( - f"║ {t.name:<30} {t.total_ms:>9.0f}ms {t.llm_ms:>10.0f}ms† {saved:>9.0f}ms {t.calls:>6} {'parallel':>8} ║" - ) - else: + print("╔══════════════════════════════════════════════════════════════════════════════════════════════════╗") + print("║ XMem-Python Pipeline Metrics Summary ║") + print(f"║ Model: {str(model_name):<85}║") + print("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + print(f"║ {'Agent':<20} {'Calls':>6} {'LLM Time':>12} {'Overhead':>12} {'In Tokens':>12} {'Out Tokens':>12} ║") + print("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + for agent in sorted(metrics_by_agent): + metrics = metrics_by_agent[agent] + llm_ms = metrics.get("llm_time_ns", 0) / 1_000_000 + print( + f"║ {agent:<20} {metrics.get('call_count', 0):>6} " + f"{llm_ms:>11.0f}ms {'—':>12} {metrics.get('input_tokens', 0):>12} " + f"{metrics.get('output_tokens', 0):>12} ║" + ) + if pipeline_timings: + print("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + print(f"║ {'Pipeline':<20} {'Calls':>6} {'LLM Time':>12} {'Overhead':>12} {'Wall Clock':>12} {'':>12} ║") + print("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + for pt in pipeline_timings: + overhead = max(0, pt.total_ms - pt.llm_ms) print( - f"║ {t.name:<30} {t.total_ms:>9.0f}ms {t.llm_ms:>11.0f}ms {t.overhead_ms:>9.1f}ms {t.calls:>6} {'':>8} ║" + f"║ {pt.name:<20} {pt.calls:>6} " + f"{pt.llm_ms:>11.0f}ms {overhead:>11.0f}ms {pt.total_ms:>11.0f}ms {'':>12} ║" ) - print("╚════════════════════════════════════════════════════════════════════════════════════════╝") - print() - print("Sequential agents: Overhead = Total - LLM Time (prompt building, parsing, etc.)") - print("Parallel agents: LLM Time† = cumulative across asyncio tasks; Overhead = time saved by concurrency") - print("Compare sequential 'Overhead' with Go benchmark output.") + print("╚══════════════════════════════════════════════════════════════════════════════════════════════════╝") + + if pipeline_timings: + total_in = tracker.get("input_tokens", 0) + total_out = tracker.get("output_tokens", 0) + query_tokens = estimate_tokens(TEST_QUERY) + + input_price = 0.15 / 1_000_000 + output_price = 0.60 / 1_000_000 + + ingest_pt = next((pt for pt in pipeline_timings if "Ingest" in pt.name), None) + retrieve_pt = next((pt for pt in pipeline_timings if "Retrieval" in pt.name), None) + + print() + print(f" User Query Tokens: ~{query_tokens}") + print(f" Total Input Tokens: {total_in}") + print(f" Total Output Tokens: {total_out}") + if ingest_pt: + cost = ingest_pt.input_tokens * input_price + ingest_pt.output_tokens * output_price + print(f" Cost to Ingest: ${cost:.6f} ({ingest_pt.input_tokens} in / {ingest_pt.output_tokens} out)") + if retrieve_pt: + cost = retrieve_pt.input_tokens * input_price + retrieve_pt.output_tokens * output_price + print(f" Cost to Retrieve: ${cost:.6f} ({retrieve_pt.input_tokens} in / {retrieve_pt.output_tokens} out)") + print() + print(" Cost estimate based on gpt-4o-mini pricing ($0.15/1M input, $0.60/1M output).") + print(" Actual cost varies by provider/model. Tokens may be estimated if provider didn't return usage.") async def main(): @@ -482,33 +433,25 @@ async def main(): model_name = getattr(real_model, "model_name", getattr(real_model, "model", "unknown")) print(f"Model: {model_name}\n") - timings: List[Timing] = [] - - print("Running individual agent benchmarks (real LLM calls)...") - print("─" * 70) - timings.append(await bench_classifier(real_model)) - timings.append(await bench_profiler(real_model)) - timings.append(await bench_temporal(real_model)) - timings.append(await bench_summarizer(real_model)) - timings.append(await bench_judge_deterministic(real_model)) - timings.append(await bench_judge_llm(real_model)) - print("─" * 70) + tracker: dict[str, Any] = { + "llm_time_ns": 0, + "call_count": 0, + "input_tokens": 0, + "output_tokens": 0, + "agents": {}, + } - # Shared stores: ingest writes here, retrieval reads the same data. - vector_store = InMemoryVectorStore() - neo4j = FakeNeo4jClient() - neo4j.connect() + pipeline_timings: list[Timing] = [] - print("\nRunning full ingest pipeline...") - timings.append(await bench_full_ingest(real_model, vector_store, neo4j)) + print("Running full ingest pipeline...") + ingest_timing, ingest_pipeline = await bench_full_ingest(real_model, tracker) + pipeline_timings.append(ingest_timing) print("\nRunning full retrieval pipeline (after ingest)...") - timings.append(await bench_full_retrieval(real_model, vector_store, neo4j)) - - print("\nRunning concurrent agent benchmark (classifier → profiler+temporal+summarizer in parallel)...") - timings.append(await bench_concurrent_agents(real_model)) + retrieval_timing = await bench_full_retrieval(real_model, ingest_pipeline.vector_store, ingest_pipeline.neo4j, tracker) + pipeline_timings.append(retrieval_timing) - _print_summary(timings, str(model_name)) + _print_summary(tracker, str(model_name), pipeline_timings) if __name__ == "__main__": diff --git a/demo.mp4 b/demo.mp4 deleted file mode 100644 index c5e637d..0000000 Binary files a/demo.mp4 and /dev/null differ diff --git a/scripts/xmem.js b/scripts/xmem.js index d9158de..93c1a0b 100644 --- a/scripts/xmem.js +++ b/scripts/xmem.js @@ -519,6 +519,33 @@ function ensurePrerequisites(skipPython = false) { } } +function pythonHasPip(pythonPath) { + return run(pythonPath, ["-m", "pip", "--version"], { capture: true, allowFailure: true }).status === 0; +} + +function ensureVirtualenv() { + const venvPython = venvPythonPath(); + if (!fs.existsSync(venvPython)) { + log("Creating XMem virtualenv"); + run(systemPythonCommand(), ["-m", "venv", path.join(root, ".venv")]); + } + + if (!pythonHasPip(venvPython)) { + log("Repairing XMem virtualenv pip"); + const result = run(venvPython, ["-m", "ensurepip", "--upgrade"], { + allowFailure: true, + }); + if (result.status !== 0 || !pythonHasPip(venvPython)) { + fail( + "XMem virtualenv was created, but pip is unavailable. Reinstall Python with venv/pip support, delete .venv, and rerun npm run setup.", + 2, + ); + } + } + + return venvPython; +} + function setupLooksComplete(reposDir) { return ( fs.existsSync(path.join(root, "pyproject.toml")) && @@ -586,11 +613,7 @@ function runSetup(args) { } if (!options.skipPythonInstall) { - const venvPython = venvPythonPath(); - if (!fs.existsSync(venvPython)) { - log("Creating XMem virtualenv"); - run(systemPythonCommand(), ["-m", "venv", path.join(root, ".venv")]); - } + const venvPython = ensureVirtualenv(); log("Installing XMem local dependencies"); run(venvPython, ["-m", "pip", "install", "--upgrade", "pip"]); run(venvPython, ["-m", "pip", "install", "-e", `${root}[local,dev]`]); diff --git a/src/prompts/examples/temporal.py b/src/prompts/examples/temporal.py index 889e863..032b027 100644 --- a/src/prompts/examples/temporal.py +++ b/src/prompts/examples/temporal.py @@ -63,6 +63,12 @@ "2:35 pm on 16 March, 2023", "DATE: 03-09\nEVENT_NAME: Started Gym\nYEAR: 2023\nDESC: Started going to the gym\nTIME: \nDATE_EXPRESSION: last week", ), + # Relative date — today + ( + "I started a new job at Vercel as a frontend developer today!", + "4:00 pm on 20 May, 2026", + "DATE: 05-20\nEVENT_NAME: Started New Job\nYEAR: 2026\nDESC: Started a new job at Vercel as a frontend developer\nTIME: \nDATE_EXPRESSION: today", + ), # Relative date — next month ( "I'm getting ready for a dance comp near me next month.", diff --git a/src/prompts/temporal.py b/src/prompts/temporal.py index fc3ebdb..79fc48d 100644 --- a/src/prompts/temporal.py +++ b/src/prompts/temporal.py @@ -33,6 +33,7 @@ You will be given a CONTEXT_DATE which is the date/time when the conversation occurred. Use this to resolve relative expressions: +- "today" → use CONTEXT_DATE - "yesterday" → subtract 1 day from CONTEXT_DATE - "tomorrow" → add 1 day to CONTEXT_DATE - "next Friday" → find the next Friday after CONTEXT_DATE diff --git a/xmem-go/cmd/benchmark/main.go b/xmem-go/cmd/smoke_test/main.go similarity index 85% rename from xmem-go/cmd/benchmark/main.go rename to xmem-go/cmd/smoke_test/main.go index 26386cc..c8c5961 100644 --- a/xmem-go/cmd/benchmark/main.go +++ b/xmem-go/cmd/smoke_test/main.go @@ -48,6 +48,14 @@ func (t *TimedModel) GenerateWithMessages(ctx context.Context, msgs []models.Mes return resp, err } +func (t *TimedModel) GenerateVision(ctx context.Context, systemPrompt string, userText string, imageURL string) (models.Response, error) { + start := time.Now() + resp, err := t.inner.GenerateVision(ctx, systemPrompt, userText, imageURL) + atomic.AddInt64(&t.llmTime, int64(time.Since(start))) + atomic.AddInt64(&t.calls, 1) + return resp, err +} + func (t *TimedModel) SelectTools(ctx context.Context, query string, catalog []map[string]string) (models.Response, error) { start := time.Now() resp, err := t.inner.SelectTools(ctx, query, catalog) @@ -85,7 +93,11 @@ func main() { os.Exit(1) } - realModel := models.NewRegistry(settings) + realModel, err := models.NewRegistry(settings) + if err != nil { + fmt.Fprintf(os.Stderr, "model registry error: %v\n", err) + os.Exit(1) + } fmt.Printf("Model: %s\n\n", realModel.Name()) ctx := context.Background() @@ -109,7 +121,7 @@ func main() { total := time.Since(start) t := timing{"Classifier Agent", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d results=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result)) + fmt.Printf(" %-30s results=%d\n", t.name, len(result)) } // Profiler @@ -121,7 +133,7 @@ func main() { total := time.Since(start) t := timing{"Profiler Agent", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d facts=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result)) + fmt.Printf(" %-30s facts=%d\n", t.name, len(result)) } // Temporal @@ -133,7 +145,7 @@ func main() { total := time.Since(start) t := timing{"Temporal Agent", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d events=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result)) + fmt.Printf(" %-30s events=%d\n", t.name, len(result)) } // Summarizer @@ -145,7 +157,7 @@ func main() { total := time.Since(start) t := timing{"Summarizer Agent", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d bullets=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result)) + fmt.Printf(" %-30s bullets=%d\n", t.name, len(result)) } // Judge (deterministic profile path — no LLM) @@ -158,11 +170,11 @@ func main() { {Topic: "work", SubTopic: "title", Memo: "Senior Software Engineer"}, } start := time.Now() - result := agent.JudgeProfile(ctx, facts) + result := agent.JudgeProfile(ctx, facts, "bench-user") total := time.Since(start) t := timing{"Judge (deterministic)", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d ops=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result.Operations)) + fmt.Printf(" %-30s ops=%d\n", t.name, len(result.Operations)) } // Judge (LLM path — summary domain) @@ -179,7 +191,7 @@ func main() { total := time.Since(start) t := timing{"Judge (LLM)", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d ops=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls, len(result.Operations)) + fmt.Printf(" %-30s ops=%d\n", t.name, len(result.Operations)) } fmt.Println(strings.Repeat("─", 70)) @@ -205,7 +217,7 @@ func main() { Summarizer: agents.SummarizerAgent{Model: tm}, Image: agents.ImageAgent{Model: tm}, Snippet: agents.SnippetAgent{Model: tm}, - Judge: agents.JudgeAgent{Model: tm, VectorStore: memStore, TopK: 3}, + Judge: agents.JudgeAgent{Model: tm, VectorStore: memStore, TemporalStore: tempStore, TopK: 3}, } req := contracts.IngestRequest{ @@ -222,7 +234,7 @@ func main() { } else { t := timing{"Full Ingest Pipeline", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), true} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm_sum=%-10s calls=%d (parallel — LLM sum > wall clock)\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.calls) + fmt.Printf(" %-30s calls=%d (parallel)\n", t.name, t.calls) fmt.Printf(" classifications=%d", len(resp.Classification)) if resp.Profile != nil { fmt.Printf(" profile_ops=%d", len(resp.Profile.Operations)) @@ -262,7 +274,7 @@ func main() { } else { t := timing{"Full Retrieval Pipeline", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), false} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm=%-10s overhead=%-10s calls=%d\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.overhead.Round(time.Microsecond), t.calls) + fmt.Printf(" %-30s calls=%d\n", t.name, t.calls) fmt.Printf(" answer=%q sources=%d confidence=%.2f\n", truncate(resp.Answer, 80), len(resp.Sources), resp.Confidence) } } @@ -294,7 +306,7 @@ func main() { total := time.Since(start) t := timing{"Concurrent Pipeline Sim", total, tm.LLMDuration(), total - tm.LLMDuration(), tm.CallCount(), true} timings = append(timings, t) - fmt.Printf(" %-30s total=%-10s llm_sum=%-10s calls=%d (parallel — LLM sum > wall clock)\n", t.name, t.total.Round(time.Millisecond), t.llm.Round(time.Millisecond), t.calls) + fmt.Printf(" %-30s calls=%d (parallel)\n", t.name, t.calls) } // --- Summary Table --- diff --git a/xmem-go/cmd/smoke_test/trace.go b/xmem-go/cmd/smoke_test/trace.go new file mode 100644 index 0000000..e959ca2 --- /dev/null +++ b/xmem-go/cmd/smoke_test/trace.go @@ -0,0 +1,558 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + json "github.com/goccy/go-json" + "fmt" + "log/slog" + "os" + "sort" + "strings" + "sync" + "time" + + "github.com/xortexai/xmem-go/internal/agents" + "github.com/xortexai/xmem-go/internal/config" + "github.com/xortexai/xmem-go/internal/contracts" + "github.com/xortexai/xmem-go/internal/graph" + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/pipelines" + "github.com/xortexai/xmem-go/internal/storage" + "github.com/xortexai/xmem-go/internal/weaver" +) + +// ── Metrics Collector ───────────────────────────────────────────────────── + +type MetricsEntry struct { + Agent string + CallCount int + LLMTime time.Duration + TotalInputTokens int + TotalOutputTokens int +} + +type PipelineTiming struct { + Name string + WallClock time.Duration + LLMTime time.Duration + Calls int + InputTokens int + OutputTokens int +} + +type MetricsCollector struct { + mu sync.Mutex + entries map[string]*MetricsEntry + pipelineTimings []PipelineTiming + queryTokens int +} + +func NewMetricsCollector() *MetricsCollector { + return &MetricsCollector{entries: make(map[string]*MetricsEntry)} +} + +func (mc *MetricsCollector) Record(agent string, latency time.Duration, inputTokens, outputTokens int) { + mc.mu.Lock() + defer mc.mu.Unlock() + entry, ok := mc.entries[agent] + if !ok { + entry = &MetricsEntry{Agent: agent} + mc.entries[agent] = entry + } + entry.CallCount++ + entry.LLMTime += latency + entry.TotalInputTokens += inputTokens + entry.TotalOutputTokens += outputTokens +} + +func (mc *MetricsCollector) TotalLLMTime() time.Duration { + mc.mu.Lock() + defer mc.mu.Unlock() + var total time.Duration + for _, e := range mc.entries { + total += e.LLMTime + } + return total +} + +func (mc *MetricsCollector) TotalCalls() int { + mc.mu.Lock() + defer mc.mu.Unlock() + total := 0 + for _, e := range mc.entries { + total += e.CallCount + } + return total +} + +func (mc *MetricsCollector) RecordPipeline(name string, wallClock time.Duration, llmTime time.Duration, calls int, inputTokens int, outputTokens int) { + mc.mu.Lock() + defer mc.mu.Unlock() + mc.pipelineTimings = append(mc.pipelineTimings, PipelineTiming{Name: name, WallClock: wallClock, LLMTime: llmTime, Calls: calls, InputTokens: inputTokens, OutputTokens: outputTokens}) +} + +func (mc *MetricsCollector) TotalInputTokens() int { + mc.mu.Lock() + defer mc.mu.Unlock() + total := 0 + for _, e := range mc.entries { + total += e.TotalInputTokens + } + return total +} + +func (mc *MetricsCollector) TotalOutputTokens() int { + mc.mu.Lock() + defer mc.mu.Unlock() + total := 0 + for _, e := range mc.entries { + total += e.TotalOutputTokens + } + return total +} + +func (mc *MetricsCollector) PrintSummary() { + agents := make([]string, 0, len(mc.entries)) + for name := range mc.entries { + agents = append(agents, name) + } + sort.Strings(agents) + + fmt.Println() + fmt.Println("╔══════════════════════════════════════════════════════════════════════════════════════════════════╗") + fmt.Println("║ XMem-Go Pipeline Metrics Summary ║") + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + fmt.Printf("║ %-20s %6s %12s %12s %12s %12s ║\n", "Agent", "Calls", "LLM Time", "Overhead", "In Tokens", "Out Tokens") + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + for _, name := range agents { + entry := mc.entries[name] + fmt.Printf("║ %-20s %6d %12s %12s %12d %12d ║\n", + entry.Agent, + entry.CallCount, + entry.LLMTime.Round(time.Millisecond), + "—", + entry.TotalInputTokens, + entry.TotalOutputTokens, + ) + } + if len(mc.pipelineTimings) > 0 { + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + fmt.Printf("║ %-20s %6s %12s %12s %12s %12s ║\n", "Pipeline", "Calls", "LLM Time", "Overhead", "Wall Clock", "") + fmt.Println("╠══════════════════════════════════════════════════════════════════════════════════════════════════╣") + for _, pt := range mc.pipelineTimings { + overhead := pt.WallClock - pt.LLMTime + if overhead < 0 { + overhead = 0 + } + fmt.Printf("║ %-20s %6d %12s %12s %12s %12s ║\n", + pt.Name, + pt.Calls, + pt.LLMTime.Round(time.Millisecond), + overhead.Round(time.Millisecond), + pt.WallClock.Round(time.Millisecond), + "", + ) + } + } + fmt.Println("╚══════════════════════════════════════════════════════════════════════════════════════════════════╝") + + if len(mc.pipelineTimings) > 0 { + fmt.Println() + totalIn := mc.TotalInputTokens() + totalOut := mc.TotalOutputTokens() + + var ingestPT, retrievePT *PipelineTiming + for i := range mc.pipelineTimings { + if strings.Contains(mc.pipelineTimings[i].Name, "Ingest") { + ingestPT = &mc.pipelineTimings[i] + } else if strings.Contains(mc.pipelineTimings[i].Name, "Retrieval") { + retrievePT = &mc.pipelineTimings[i] + } + } + + inputPrice := 0.15 / 1_000_000.0 + outputPrice := 0.60 / 1_000_000.0 + + fmt.Printf(" User Query Tokens: ~%d\n", mc.queryTokens) + fmt.Printf(" Total Input Tokens: %d\n", totalIn) + fmt.Printf(" Total Output Tokens: %d\n", totalOut) + if ingestPT != nil { + cost := float64(ingestPT.InputTokens)*inputPrice + float64(ingestPT.OutputTokens)*outputPrice + fmt.Printf(" Cost to Ingest: $%.6f (%d in / %d out)\n", cost, ingestPT.InputTokens, ingestPT.OutputTokens) + } + if retrievePT != nil { + cost := float64(retrievePT.InputTokens)*inputPrice + float64(retrievePT.OutputTokens)*outputPrice + fmt.Printf(" Cost to Retrieve: $%.6f (%d in / %d out)\n", cost, retrievePT.InputTokens, retrievePT.OutputTokens) + } + fmt.Println() + fmt.Println(" Cost estimate based on gpt-4o-mini pricing ($0.15/1M input, $0.60/1M output).") + fmt.Println(" Actual cost varies by provider/model. Tokens may be estimated if provider didn't return usage.") + } +} + +// TracingModel intercepts ChatModel calls and prints the user message and response (system prompts are hidden). +// Metrics are recorded in a shared MetricsCollector for the final summary table. +type TracingModel struct { + inner models.ChatModel + agent string + mu *sync.Mutex + metrics *MetricsCollector +} + +func NewTracingModel(inner models.ChatModel) *TracingModel { + return &TracingModel{inner: inner, agent: "model", mu: &sync.Mutex{}, metrics: NewMetricsCollector()} +} + +func (t *TracingModel) ForAgent(agent string) *TracingModel { + return &TracingModel{inner: t.inner, agent: agent, mu: t.mu, metrics: t.metrics} +} + +func (t *TracingModel) Metrics() *MetricsCollector { + return t.metrics +} + +func (t *TracingModel) Name() string { + return t.inner.Name() +} + +func indentLines(str string, spaces int) string { + indent := strings.Repeat(" ", spaces) + lines := strings.Split(str, "\n") + for i, line := range lines { + lines[i] = indent + line + } + return strings.Join(lines, "\n") +} + +func estimateTokens(text string) int { + text = strings.TrimSpace(text) + if text == "" { + return 0 + } + // Good enough for trace readability without provider-specific tokenizers. + return (len([]rune(text)) + 3) / 4 +} + +func estimateMessageTokens(msgs []models.Message) int { + total := 0 + for _, msg := range msgs { + total += estimateTokens(msg.Role) + estimateTokens(msg.Content) + } + return total +} + +func tokenCounts(resp models.Response, estimatedInput int, estimatedOutput int) (int, int, bool) { + if resp.InputTokens > 0 || resp.OutputTokens > 0 { + return resp.InputTokens, resp.OutputTokens, false + } + return estimatedInput, estimatedOutput, true +} + +func (t *TracingModel) Generate(ctx context.Context, prompt string) (models.Response, error) { + start := time.Now() + resp, err := t.inner.Generate(ctx, prompt) + elapsed := time.Since(start) + + inputTokens, outputTokens, _ := tokenCounts(resp, estimateTokens(prompt), estimateTokens(resp.Content)) + t.metrics.Record(t.agent, elapsed, inputTokens, outputTokens) + + t.mu.Lock() + defer t.mu.Unlock() + + fmt.Println() + fmt.Printf(" \x1b[35m┌─── [LLM Call: %s / Generate] ────────────────────────────────────────────────┐\x1b[0m\n", t.agent) + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;33mPrompt:\x1b[0m") + fmt.Println(indentLines(prompt, 4)) + fmt.Println(" \x1b[35m├─────────────────────────────────────────────────────────────────────────────────┤\x1b[0m") + if err != nil { + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;31mError:\x1b[0m %v\n", err) + } else { + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;32mResponse:\x1b[0m") + fmt.Println(indentLines(resp.Content, 4)) + } + fmt.Println(" \x1b[35m└─────────────────────────────────────────────────────────────────────────────────┘\x1b[0m") + + return resp, err +} + +func (t *TracingModel) GenerateWithMessages(ctx context.Context, msgs []models.Message) (models.Response, error) { + start := time.Now() + resp, err := t.inner.GenerateWithMessages(ctx, msgs) + elapsed := time.Since(start) + + inputTokens, outputTokens, _ := tokenCounts(resp, estimateMessageTokens(msgs), estimateTokens(resp.Content)) + t.metrics.Record(t.agent, elapsed, inputTokens, outputTokens) + + t.mu.Lock() + defer t.mu.Unlock() + + fmt.Println() + fmt.Printf(" \x1b[35m┌─── [LLM Call: %s / GenerateWithMessages] ───────────────────────────────────┐\x1b[0m\n", t.agent) + for _, m := range msgs { + if m.Role == "system" { + continue // Skip printing system prompts — they are extremely long + } + roleColor := "\x1b[1;33m" // user -> Yellow + if m.Role == "assistant" || m.Role == "model" { + roleColor = "\x1b[1;32m" // assistant -> Green + } + fmt.Printf(" \x1b[35m│\x1b[0m %s%s:\x1b[0m\n", roleColor, strings.ToUpper(m.Role)) + fmt.Println(indentLines(m.Content, 4)) + } + fmt.Println(" \x1b[35m├─────────────────────────────────────────────────────────────────────────────────┤\x1b[0m") + if err != nil { + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;31mError:\x1b[0m %v\n", err) + } else { + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;32mResponse:\x1b[0m") + fmt.Println(indentLines(resp.Content, 4)) + } + fmt.Println(" \x1b[35m└─────────────────────────────────────────────────────────────────────────────────┘\x1b[0m") + + return resp, err +} + +func (t *TracingModel) GenerateVision(ctx context.Context, systemPrompt string, userText string, imageURL string) (models.Response, error) { + start := time.Now() + resp, err := t.inner.GenerateVision(ctx, systemPrompt, userText, imageURL) + elapsed := time.Since(start) + inputTokens, outputTokens, _ := tokenCounts(resp, estimateTokens(userText)+estimateTokens(systemPrompt), estimateTokens(resp.Content)) + t.metrics.Record(t.agent, elapsed, inputTokens, outputTokens) + + t.mu.Lock() + defer t.mu.Unlock() + fmt.Println() + fmt.Printf(" \x1b[35m┌─── [LLM Call: %s / GenerateVision] ───────────────────────────────────┐\x1b[0m\n", t.agent) + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;33mUser Text:\x1b[0m %s\n", userText[:min(80, len(userText))]) + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;34mImage URL:\x1b[0m %s\n", imageURL[:min(80, len(imageURL))]) + fmt.Println(" \x1b[35m├─────────────────────────────────────────────────────────────────────────────────┤\x1b[0m") + if err != nil { + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;31mError:\x1b[0m %v\n", err) + } else { + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;32mResponse:\x1b[0m") + fmt.Println(indentLines(resp.Content, 4)) + } + fmt.Println(" \x1b[35m└─────────────────────────────────────────────────────────────────────────────────┘\x1b[0m") + return resp, err +} + +func (t *TracingModel) SelectTools(ctx context.Context, query string, catalog []map[string]string) (models.Response, error) { + start := time.Now() + resp, err := t.inner.SelectTools(ctx, query, catalog) + elapsed := time.Since(start) + + toolCallsJSON, _ := json.Marshal(resp.ToolCalls) + catalogJSON, _ := json.Marshal(catalog) + inputTokens, outputTokens, _ := tokenCounts(resp, estimateTokens(query)+estimateTokens(string(catalogJSON)), estimateTokens(string(toolCallsJSON))) + t.metrics.Record(t.agent, elapsed, inputTokens, outputTokens) + + t.mu.Lock() + defer t.mu.Unlock() + + fmt.Println() + fmt.Println(" \x1b[35m┌─── [LLM Call: SelectTools] ─────────────────────────────────────────────────────┐\x1b[0m") + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;33mQuery:\x1b[0m", query) + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;34mProfile Catalog:\x1b[0m", string(catalogJSON)) + fmt.Println(" \x1b[35m├─────────────────────────────────────────────────────────────────────────────────┤\x1b[0m") + if err != nil { + fmt.Printf(" \x1b[35m│\x1b[0m \x1b[1;31mError:\x1b[0m %v\n", err) + } else { + fmt.Println(" \x1b[35m│\x1b[0m \x1b[1;32mSelected Tools:\x1b[0m") + prettyToolCallsJSON, _ := json.MarshalIndent(resp.ToolCalls, " ", " ") + fmt.Println(indentLines(string(prettyToolCallsJSON), 2)) + } + fmt.Println(" \x1b[35m└─────────────────────────────────────────────────────────────────────────────────┘\x1b[0m") + + return resp, err +} + +func printBanner(title string) { + border := strings.Repeat("═", len(title)+8) + fmt.Printf("\n\x1b[1;36m╔%s╗\x1b[0m\n", border) + fmt.Printf("\x1b[1;36m║ %s ║\x1b[0m\n", title) + fmt.Printf("\x1b[1;36m╚%s╝\x1b[0m\n\n", border) +} + +func printStep(step string) { + fmt.Printf("\n\x1b[1;34m─── %s ──────────────────────────────────────────────────────────\x1b[0m\n", step) +} + +// buildRealStores uses the configured cloud stores. It intentionally fails hard +// instead of falling back to memory so benchmark runs cannot silently avoid cloud DBs. +func buildRealStores(ctx context.Context, settings config.Settings, logger *slog.Logger) (storage.Embedder, storage.VectorStore, storage.VectorStore, graph.TemporalStore, error) { + // --- Embedder --- + var embedder storage.Embedder + if settings.EmbeddingProvider == "openai" { + oai, err := storage.NewOpenAIEmbedder(settings) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("openai embedder unavailable: %w", err) + } + embedder = oai + logger.Info("using OpenAI embedder", "model", settings.OpenAIEmbeddingModel, "dimension", settings.PineconeDimension) + } else { + return nil, nil, nil, nil, fmt.Errorf("cloud benchmark requires EMBEDDING_PROVIDER=openai, got %q", settings.EmbeddingProvider) + } + + // --- Vector stores --- + if !strings.EqualFold(settings.VectorStoreProvider, "pinecone") { + return nil, nil, nil, nil, fmt.Errorf("cloud benchmark requires VECTOR_STORE_PROVIDER=pinecone, got %q", settings.VectorStoreProvider) + } + vectorStore, err := storage.NewPineconeVectorStore(ctx, settings, embedder, settings.PineconeNamespace) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("pinecone vector store unavailable: %w", err) + } + logger.Info("using Pinecone vector store", "namespace", settings.PineconeNamespace) + + snippetNS := settings.PineconeNamespace + "-snippets" + snippetStore, err := storage.NewPineconeVectorStore(ctx, settings, embedder, snippetNS) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("pinecone snippet store unavailable: %w", err) + } + logger.Info("using Pinecone snippet store", "namespace", snippetNS) + + // --- Temporal store --- + if settings.Neo4jPassword == "" { + return nil, nil, nil, nil, fmt.Errorf("cloud benchmark requires NEO4J_PASSWORD") + } + temporalStore, err := graph.NewNeo4jTemporalStore(ctx, settings, embedder) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("neo4j temporal store unavailable: %w", err) + } + logger.Info("using Neo4j temporal store") + + return embedder, vectorStore, snippetStore, temporalStore, nil +} + +func main() { + settings, err := config.Load() + if err != nil { + fmt.Fprintf(os.Stderr, "config error: %v\n", err) + os.Exit(1) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo})) + + ctx := context.Background() + + realModel, err := models.NewRegistry(settings) + if err != nil { + fmt.Fprintf(os.Stderr, "model registry error: %v\n", err) + os.Exit(1) + } + tracingModel := NewTracingModel(realModel) + + // Build real stores from .env (Pinecone, OpenAI embeddings, Neo4j) + embedder, vectorStore, snippetStore, temporalStore, err := buildRealStores(ctx, settings, logger) + if err != nil { + fmt.Fprintf(os.Stderr, "cloud store error: %v\n", err) + os.Exit(1) + } + + printBanner(fmt.Sprintf("XMem-Go Pipeline Flow Trace (Model: %s)", realModel.Name())) + + userID := "trace-user" + + makePipeline := func() *pipelines.IngestPipeline { + return &pipelines.IngestPipeline{ + ModelName: tracingModel.Name(), + Weaver: &weaver.Weaver{VectorStore: vectorStore, SnippetVectorStore: snippetStore, Embedder: embedder, TemporalStore: temporalStore}, + Classifier: agents.ClassifierAgent{Model: tracingModel.ForAgent("classifier")}, + Profiler: agents.ProfilerAgent{Model: tracingModel.ForAgent("profiler")}, + Temporal: agents.TemporalAgent{Model: tracingModel.ForAgent("temporal")}, + Summarizer: agents.SummarizerAgent{Model: tracingModel.ForAgent("summarizer")}, + Image: agents.ImageAgent{Model: tracingModel.ForAgent("image")}, + Code: agents.CodeAgent{Model: tracingModel.ForAgent("code")}, + Snippet: agents.SnippetAgent{Model: tracingModel.ForAgent("snippet")}, + Judge: agents.JudgeAgent{Model: tracingModel.ForAgent("judge"), VectorStore: vectorStore, TemporalStore: temporalStore, TopK: 3}, + } + } + + // ────────────────────────────────────────────────────────────────────── + // 1. Full Ingest Pipeline Flow + // ────────────────────────────────────────────────────────────────────── + printStep("1. Full Ingest Pipeline Flow") + { + pipeline := makePipeline() + + req := contracts.IngestRequest{ + UserQuery: "My name is Bob, and I started a new job at Vercel as a frontend developer today!", + AgentResponse: "Congratulations on your new role Bob! That's wonderful news.", + SessionDatetime: "4:00 pm on 20 May, 2026", + } + + tracingModel.Metrics().queryTokens = estimateTokens(req.UserQuery) + + fmt.Printf("\x1b[1;33mIngest Request:\x1b[0m\n User Query: %s\n Response: %s\n", req.UserQuery, req.AgentResponse) + fmt.Println("\nExecuting Ingest Pipeline (data goes to Pinecone / Neo4j)...") + + llmBefore := tracingModel.Metrics().TotalLLMTime() + callsBefore := tracingModel.Metrics().TotalCalls() + inBefore := tracingModel.Metrics().TotalInputTokens() + outBefore := tracingModel.Metrics().TotalOutputTokens() + ingestStart := time.Now() + resp, err := pipeline.Run(ctx, req, userID) + ingestWall := time.Since(ingestStart) + ingestLLM := tracingModel.Metrics().TotalLLMTime() - llmBefore + ingestCalls := tracingModel.Metrics().TotalCalls() - callsBefore + ingestIn := tracingModel.Metrics().TotalInputTokens() - inBefore + ingestOut := tracingModel.Metrics().TotalOutputTokens() - outBefore + tracingModel.Metrics().RecordPipeline("Full Ingest", ingestWall, ingestLLM, ingestCalls, ingestIn, ingestOut) + + if err != nil { + fmt.Printf("\x1b[1;31mPipeline Error:\x1b[0m %v\n", err) + } else { + fmt.Println("\n\x1b[1;32mIngest Pipeline Completed. Response:\x1b[0m") + respBytes, _ := json.MarshalIndent(resp, "", " ") + fmt.Println(string(respBytes)) + } + } + + // ────────────────────────────────────────────────────────────────────── + // 2. Full Retrieval Pipeline Flow + // ────────────────────────────────────────────────────────────────────── + printStep("2. Full Retrieval Pipeline Flow") + { + pipeline := &pipelines.RetrievalPipeline{ + Model: tracingModel.ForAgent("retrieval"), + VectorStore: vectorStore, + SnippetStore: snippetStore, + TemporalStore: temporalStore, + } + + req := contracts.RetrieveRequest{ + Query: "What is my name and where do I work?", + } + + fmt.Printf("\x1b[1;33mRetrieval Request Query:\x1b[0m %s\n", req.Query) + fmt.Println("\nExecuting Retrieval Pipeline (querying Pinecone / Neo4j)...") + + llmBefore := tracingModel.Metrics().TotalLLMTime() + callsBefore := tracingModel.Metrics().TotalCalls() + inBefore := tracingModel.Metrics().TotalInputTokens() + outBefore := tracingModel.Metrics().TotalOutputTokens() + retrieveStart := time.Now() + resp, err := pipeline.Run(ctx, req, userID) + retrieveWall := time.Since(retrieveStart) + retrieveLLM := tracingModel.Metrics().TotalLLMTime() - llmBefore + retrieveCalls := tracingModel.Metrics().TotalCalls() - callsBefore + retrieveIn := tracingModel.Metrics().TotalInputTokens() - inBefore + retrieveOut := tracingModel.Metrics().TotalOutputTokens() - outBefore + tracingModel.Metrics().RecordPipeline("Full Retrieval", retrieveWall, retrieveLLM, retrieveCalls, retrieveIn, retrieveOut) + + if err != nil { + fmt.Printf("\x1b[1;31mPipeline Error:\x1b[0m %v\n", err) + } else { + fmt.Println("\n\x1b[1;32mRetrieval Response:\x1b[0m") + fmt.Printf(" Answer: %s\n", resp.Answer) + fmt.Printf(" Confidence: %.2f\n", resp.Confidence) + fmt.Printf(" Sources retrieved: %d\n", len(resp.Sources)) + for i, src := range resp.Sources { + fmt.Printf(" [%d] Domain: %s | Score: %.3f | Content: %s\n", i+1, src.Domain, src.Score, src.Content) + } + } + } + + tracingModel.Metrics().PrintSummary() +} diff --git a/xmem-go/cmd/xmem/bootstrap.go b/xmem-go/cmd/xmem/bootstrap.go index 60ae66c..9a7361d 100644 --- a/xmem-go/cmd/xmem/bootstrap.go +++ b/xmem-go/cmd/xmem/bootstrap.go @@ -34,12 +34,15 @@ func buildRuntime(ctx context.Context, settings config.Settings, logger *slog.Lo return runtimeDeps{}, err } - temporalStore, err := buildTemporalStore(ctx, settings, logger) + temporalStore, err := buildTemporalStore(ctx, settings, embedder, logger) if err != nil { return runtimeDeps{}, err } - model := models.NewRegistry(settings) + model, err := models.NewRegistry(settings) + if err != nil { + return runtimeDeps{}, fmt.Errorf("LLM model initialization failed: %w", err) + } ingest, retrieval := buildPipelines(model, vectorStore, snippetStore, temporalStore, embedder) keyStore, jobStore, err := buildAppStores(ctx, settings, logger) @@ -56,21 +59,23 @@ func buildRuntime(ctx context.Context, settings config.Settings, logger *slog.Lo } func buildEmbedder(settings config.Settings, logger *slog.Logger) (storage.Embedder, error) { - fallback := storage.HashEmbedder{Dimension: settings.PineconeDimension} - if settings.EmbeddingProvider != "openai" { - return fallback, nil - } - - openAIEmbedder, err := storage.NewOpenAIEmbedder(settings) - if err == nil { + if settings.EmbeddingProvider == "openai" || settings.EmbeddingProvider == "" { + openAIEmbedder, err := storage.NewOpenAIEmbedder(settings) + if err != nil { + if production(settings) { + return nil, fmt.Errorf("openai embedder initialization failed: %w", err) + } + logger.Warn("openai embedder unavailable, falling back to hash embedder (dev only)", "error", err) + return storage.HashEmbedder{Dimension: settings.PineconeDimension}, nil + } logger.Info("using OpenAI embedder", "model", settings.OpenAIEmbeddingModel, "dimension", settings.PineconeDimension) return openAIEmbedder, nil } if production(settings) { - return nil, fmt.Errorf("openai embedder initialization failed: %w", err) + return nil, fmt.Errorf("unsupported EMBEDDING_PROVIDER=%q; only 'openai' is supported in production", settings.EmbeddingProvider) } - logger.Warn("openai embedder unavailable, using hash embedder", "error", err) - return fallback, nil + logger.Warn("using hash embedder (dev only)", "provider", settings.EmbeddingProvider) + return storage.HashEmbedder{Dimension: settings.PineconeDimension}, nil } func buildVectorStores(ctx context.Context, settings config.Settings, embedder storage.Embedder, logger *slog.Logger) (storage.VectorStore, storage.VectorStore, error) { @@ -106,13 +111,16 @@ func buildVectorStores(ctx context.Context, settings config.Settings, embedder s return vectorStore, snippetStore, nil } -func buildTemporalStore(ctx context.Context, settings config.Settings, logger *slog.Logger) (graph.TemporalStore, error) { - fallback := graph.NewMemoryTemporalStore() +func buildTemporalStore(ctx context.Context, settings config.Settings, embedder storage.Embedder, logger *slog.Logger) (graph.TemporalStore, error) { if settings.Neo4jPassword == "" { - return fallback, nil + if production(settings) { + return nil, fmt.Errorf("NEO4J_PASSWORD is required in production") + } + logger.Warn("NEO4J_PASSWORD not set, using memory temporal store (dev only)") + return graph.NewMemoryTemporalStore(), nil } - neoStore, err := graph.NewNeo4jTemporalStore(ctx, settings) + neoStore, err := graph.NewNeo4jTemporalStore(ctx, settings, embedder) if err == nil { logger.Info("using Neo4j temporal store") return neoStore, nil @@ -121,7 +129,7 @@ func buildTemporalStore(ctx context.Context, settings config.Settings, logger *s return nil, fmt.Errorf("neo4j initialization failed: %w", err) } logger.Warn("neo4j unavailable, using memory temporal store", "error", err) - return fallback, nil + return graph.NewMemoryTemporalStore(), nil } func buildPipelines(model models.ChatModel, vectorStore storage.VectorStore, snippetStore storage.VectorStore, temporalStore graph.TemporalStore, embedder storage.Embedder) (*pipelines.IngestPipeline, *pipelines.RetrievalPipeline) { @@ -140,8 +148,9 @@ func buildPipelines(model models.ChatModel, vectorStore storage.VectorStore, sni Temporal: agents.TemporalAgent{Model: model}, Summarizer: agents.SummarizerAgent{Model: model}, Image: agents.ImageAgent{Model: model}, + Code: agents.CodeAgent{Model: model}, Snippet: agents.SnippetAgent{Model: model}, - Judge: agents.JudgeAgent{Model: model, VectorStore: vectorStore, TopK: 3}, + Judge: agents.JudgeAgent{Model: model, VectorStore: vectorStore, TemporalStore: temporalStore, TopK: 3}, } retrieval := &pipelines.RetrievalPipeline{ Model: model, diff --git a/xmem-go/go.mod b/xmem-go/go.mod index a2ca528..209f5a0 100644 --- a/xmem-go/go.mod +++ b/xmem-go/go.mod @@ -4,8 +4,10 @@ go 1.26 require ( github.com/go-chi/chi/v5 v5.2.5 + github.com/goccy/go-json v0.10.6 github.com/neo4j/neo4j-go-driver/v5 v5.28.4 go.mongodb.org/mongo-driver/v2 v2.6.0 + golang.org/x/net v0.55.0 ) require ( @@ -14,7 +16,7 @@ require ( github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect - golang.org/x/crypto v0.33.0 // indirect - golang.org/x/sync v0.11.0 // indirect - golang.org/x/text v0.22.0 // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.37.0 // indirect ) diff --git a/xmem-go/go.sum b/xmem-go/go.sum index 2dca09b..ecbf875 100644 --- a/xmem-go/go.sum +++ b/xmem-go/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/goccy/go-json v0.10.6 h1:p8HrPJzOakx/mn/bQtjgNjdTcN+/S6FcG2CTtQOrHVU= +github.com/goccy/go-json v0.10.6/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= @@ -21,16 +23,18 @@ go.mongodb.org/mongo-driver/v2 v2.6.0 h1:b9sJOYrkmt4l8bY43ZenFBcPlhYIjaOfYHLtbB/ go.mongodb.org/mongo-driver/v2 v2.6.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -42,8 +46,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/xmem-go/internal/agents/agents.go b/xmem-go/internal/agents/agents.go deleted file mode 100644 index 0cffaa5..0000000 --- a/xmem-go/internal/agents/agents.go +++ /dev/null @@ -1,572 +0,0 @@ -package agents - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "strings" - - "github.com/xortexai/xmem-go/internal/models" - "github.com/xortexai/xmem-go/internal/prompts" - "github.com/xortexai/xmem-go/internal/storage" - "github.com/xortexai/xmem-go/internal/utils" - "github.com/xortexai/xmem-go/internal/weaver" -) - -const summaryJudgeSimilarityThreshold = 0.4 - -// ---------- shared LLM call helper ---------- - -func callModel(ctx context.Context, model models.ChatModel, systemPrompt, userMessage string) (string, error) { - messages := []models.Message{ - {Role: "system", Content: systemPrompt}, - {Role: "user", Content: userMessage}, - } - resp, err := model.GenerateWithMessages(ctx, messages) - if err != nil { - return "", fmt.Errorf("LLM call failed (%s): %w", model.Name(), err) - } - return resp.Content, nil -} - -// ---------- Classification ---------- - -type Classification struct { - Source string `json:"source"` - Query string `json:"query"` -} - -type ClassifierAgent struct { - Model models.ChatModel -} - -func (a ClassifierAgent) Run(ctx context.Context, userQuery string, imageURL string) []Classification { - systemPrompt := prompts.BuildClassifierSystemPrompt() - userMessage := prompts.PackClassificationQuery(userQuery) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - parsed := utils.ParseRawResponseToClassifications(raw) - - out := make([]Classification, 0, len(parsed)) - for _, c := range parsed { - out = append(out, Classification{Source: c.Source, Query: c.Query}) - } - - if strings.TrimSpace(imageURL) != "" { - hasImage := false - for _, c := range out { - if c.Source == "image" { - hasImage = true - break - } - } - if !hasImage { - out = append(out, Classification{Source: "image", Query: "Analyze this image for memory-relevant details."}) - } - } - - return out -} - -// ---------- Profile ---------- - -type ProfileFact struct { - Topic string - SubTopic string - Memo string -} - -type ProfilerAgent struct { - Model models.ChatModel -} - -func (a ProfilerAgent) Run(ctx context.Context, text string) []ProfileFact { - systemPrompt := prompts.BuildProfilerSystemPrompt() - userMessage := prompts.PackProfilerQuery(text) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - parsed := utils.ParseRawResponseToProfiles(raw) - - facts := make([]ProfileFact, 0, len(parsed)) - for _, f := range parsed { - facts = append(facts, ProfileFact{ - Topic: f.Topic, - SubTopic: f.SubTopic, - Memo: f.Memo, - }) - } - return facts -} - -// ---------- Temporal ---------- - -type Event struct { - Date string - EventName string - Desc string - Year string - Time string - DateExpression string -} - -type TemporalAgent struct { - Model models.ChatModel -} - -func (a TemporalAgent) Run(ctx context.Context, text string, sessionDatetime string) []Event { - systemPrompt := prompts.BuildTemporalSystemPrompt() - userMessage := prompts.PackTemporalQuery(text, sessionDatetime) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - parsed := utils.ParseRawResponseToEvents(raw) - - events := make([]Event, 0, len(parsed)) - for _, e := range parsed { - if !isValidDate(e.Date) { - continue - } - events = append(events, Event{ - Date: e.Date, - EventName: e.EventName, - Desc: e.Desc, - Year: e.Year, - Time: e.Time, - DateExpression: e.DateExpression, - }) - } - return events -} - -func isValidDate(date string) bool { - parts := strings.SplitN(date, "-", 2) - if len(parts) != 2 { - return false - } - month, err := strconv.Atoi(parts[0]) - if err != nil || month < 1 || month > 12 { - return false - } - day, err := strconv.Atoi(parts[1]) - if err != nil || day < 1 || day > 31 { - return false - } - return true -} - -// ---------- Summarizer ---------- - -type SummarizerAgent struct { - Model models.ChatModel -} - -var emptySentinels = map[string]struct{}{ - `""`: {}, - `''`: {}, - "empty": {}, - "(empty)": {}, - "(empty string)": {}, -} - -func (a SummarizerAgent) Run(ctx context.Context, userQuery string, agentResponse string) []string { - systemPrompt := prompts.BuildSummarizerSystemPrompt() - userMessage := prompts.PackSummaryQuery(userQuery, agentResponse) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - trimmed := strings.TrimSpace(raw) - if _, isEmpty := emptySentinels[strings.ToLower(trimmed)]; isEmpty || trimmed == "" { - return nil - } - - lines := strings.Split(trimmed, "\n") - bullets := make([]string, 0, len(lines)) - for _, line := range lines { - cleaned := strings.TrimSpace(line) - cleaned = strings.TrimLeft(cleaned, "-•*") - cleaned = strings.TrimSpace(cleaned) - if cleaned != "" { - bullets = append(bullets, cleaned) - } - } - return bullets -} - -// ---------- Image ---------- - -type ImageAgent struct { - Model models.ChatModel -} - -func (a ImageAgent) Run(ctx context.Context, imageURL string) []string { - if strings.TrimSpace(imageURL) == "" { - return nil - } - - systemPrompt := prompts.BuildImageSystemPrompt() - userMessage := prompts.PackImageQuery("", imageURL) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - result := utils.ParseRawResponseToImage(raw) - - observations := make([]string, 0, len(result.Observations)) - for _, obs := range result.Observations { - entry := "[" + obs.Category + "] " + obs.Description - if obs.Confidence != "" { - entry += " (confidence: " + obs.Confidence + ")" - } - observations = append(observations, entry) - } - if len(observations) == 0 && result.Description != "" { - observations = append(observations, result.Description) - } - return observations -} - -// ---------- Snippet ---------- - -type SnippetAgent struct { - Model models.ChatModel -} - -type snippetJSON struct { - Content string `json:"content"` - CodeSnippet string `json:"code_snippet"` - Language string `json:"language"` - SnippetType string `json:"snippet_type"` - Tags string `json:"tags"` -} - -type snippetsResponse struct { - Snippets []snippetJSON `json:"snippets"` -} - -func (a SnippetAgent) Run(ctx context.Context, text string) []string { - systemPrompt := prompts.BuildSnippetSystemPrompt() - userMessage := prompts.PackSnippetQuery(text) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return nil - } - - items := parseSnippetResponse(raw) - results := make([]string, 0, len(items)) - for _, s := range items { - line := joinPipe(s.Content, s.CodeSnippet, s.Language, s.SnippetType, s.Tags) - results = append(results, line) - } - return results -} - -func parseSnippetResponse(raw string) []snippetJSON { - jsonStr := extractJSONObject(raw) - if jsonStr == "" { - jsonStr = extractJSONArray(raw) - } - - var resp snippetsResponse - if err := json.Unmarshal([]byte(jsonStr), &resp); err == nil && len(resp.Snippets) > 0 { - return resp.Snippets - } - - var arr []snippetJSON - if err := json.Unmarshal([]byte(jsonStr), &arr); err == nil && len(arr) > 0 { - return arr - } - - var single snippetJSON - if err := json.Unmarshal([]byte(jsonStr), &single); err == nil && single.Content != "" { - return []snippetJSON{single} - } - - return nil -} - -// ---------- Judge ---------- - -type JudgeAgent struct { - Model models.ChatModel - VectorStore storage.VectorStore - TopK int -} - -func (a JudgeAgent) judgeTopK() int { - if a.TopK <= 0 { - return 3 - } - return a.TopK -} - -func (a JudgeAgent) JudgeItems(ctx context.Context, domain weaver.JudgeDomain, items []string, userID string, confidence float64) weaver.JudgeResult { - if len(items) == 0 { - return weaver.JudgeResult{} - } - - if domain == weaver.DomainSummary { - matches := a.fetchSimilarSummaries(ctx, items, userID) - if !hasSummaryJudgeCandidates(matches) { - return judgeDeterministicSummary(items, confidence) - } - similarBlock := formatSummarySimilarBlock(items, filterMatchesByThreshold(matches, summaryJudgeSimilarityThreshold)) - return a.judgeItemsWithLLM(ctx, domain, items, similarBlock, confidence) - } - - return a.judgeItemsWithLLM(ctx, domain, items, nil, confidence) -} - -func (a JudgeAgent) judgeItemsWithLLM(ctx context.Context, domain weaver.JudgeDomain, items []string, similarLines []string, confidence float64) weaver.JudgeResult { - systemPrompt := prompts.BuildJudgeSystemPrompt() - userMessage := prompts.PackJudgeQuery(items, similarLines, string(domain)) - - raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) - if err != nil { - return judgeFallback(items, confidence) - } - - result, ok := parseJudgeResponse(raw) - if !ok || len(result.Operations) == 0 { - return judgeFallback(items, confidence) - } - - if confidence > 0 { - result.Confidence = confidence - } - return result -} - -func (a JudgeAgent) JudgeProfile(ctx context.Context, facts []ProfileFact) weaver.JudgeResult { - items := make([]string, 0, len(facts)) - for _, fact := range facts { - items = append(items, fact.Topic+" / "+fact.SubTopic+" = "+fact.Memo) - } - return judgeDeterministic(items, 1.0) -} - -func (a JudgeAgent) JudgeTemporal(ctx context.Context, events []Event) weaver.JudgeResult { - items := make([]string, 0, len(events)) - for _, event := range events { - items = append(items, strings.Join([]string{ - event.Date, event.EventName, event.Desc, event.Year, event.Time, event.DateExpression, - }, " | ")) - } - return judgeDeterministic(items, 1.0) -} - -func judgeDeterministic(items []string, confidence float64) weaver.JudgeResult { - ops := make([]weaver.Operation, 0, len(items)) - for _, item := range items { - item = strings.TrimSpace(item) - if item == "" { - continue - } - ops = append(ops, weaver.Operation{ - Type: weaver.OperationAdd, - Content: item, - Reason: "Deterministic extraction — no deduplication needed.", - }) - } - return weaver.JudgeResult{Operations: ops, Confidence: confidence} -} - -func judgeDeterministicSummary(items []string, confidence float64) weaver.JudgeResult { - ops := make([]weaver.Operation, 0, len(items)) - for _, item := range items { - item = strings.TrimSpace(item) - if item == "" { - continue - } - ops = append(ops, weaver.Operation{ - Type: weaver.OperationAdd, - Content: item, - Reason: "No similar summary at or above 0.4 — defaulting to ADD.", - }) - } - if confidence == 0 { - confidence = 0.8 - } - return weaver.JudgeResult{Operations: ops, Confidence: confidence} -} - -func (a JudgeAgent) fetchSimilarSummaries(ctx context.Context, items []string, userID string) map[string][]storage.SearchResult { - out := make(map[string][]storage.SearchResult, len(items)) - if a.VectorStore == nil { - return out - } - filters := map[string]any{"domain": string(weaver.DomainSummary)} - if strings.TrimSpace(userID) != "" { - filters["user_id"] = userID - } - for _, item := range items { - item = strings.TrimSpace(item) - if item == "" { - continue - } - results, err := a.VectorStore.SearchByText(ctx, item, a.judgeTopK(), filters) - if err != nil { - out[item] = nil - continue - } - out[item] = results - } - return out -} - -func hasSummaryJudgeCandidates(matches map[string][]storage.SearchResult) bool { - for _, results := range matches { - for _, result := range results { - if result.Score >= summaryJudgeSimilarityThreshold { - return true - } - } - } - return false -} - -func filterMatchesByThreshold(matches map[string][]storage.SearchResult, threshold float64) map[string][]storage.SearchResult { - out := make(map[string][]storage.SearchResult, len(matches)) - for item, results := range matches { - filtered := make([]storage.SearchResult, 0, len(results)) - for _, result := range results { - if result.Score >= threshold { - filtered = append(filtered, result) - } - } - out[item] = filtered - } - return out -} - -func formatSummarySimilarBlock(items []string, matches map[string][]storage.SearchResult) []string { - if len(matches) == 0 { - return nil - } - lines := make([]string, 0) - for _, item := range items { - item = strings.TrimSpace(item) - if item == "" { - continue - } - results := matches[item] - lines = append(lines, fmt.Sprintf("For item: %q", item)) - if len(results) == 0 { - lines = append(lines, " - (no similar records above threshold)") - continue - } - for _, result := range results { - lines = append(lines, fmt.Sprintf(" - ID: %s | Score: %.2f | %q", result.ID, result.Score, result.Content)) - } - } - return lines -} - -func judgeFallback(items []string, confidence float64) weaver.JudgeResult { - ops := make([]weaver.Operation, 0, len(items)) - for _, item := range items { - item = strings.TrimSpace(item) - if item == "" { - continue - } - ops = append(ops, weaver.Operation{ - Type: weaver.OperationAdd, - Content: item, - Reason: "LLM judge unavailable — defaulting to ADD.", - }) - } - if confidence == 0 { - confidence = 0.8 - } - return weaver.JudgeResult{Operations: ops, Confidence: confidence} -} - -type judgeResponse struct { - Operations []struct { - Type string `json:"type"` - Content string `json:"content"` - EmbeddingID string `json:"embedding_id"` - Reason string `json:"reason"` - } `json:"operations"` - Confidence float64 `json:"confidence"` -} - -func parseJudgeResponse(raw string) (weaver.JudgeResult, bool) { - jsonStr := extractJSONObject(raw) - if jsonStr == "" { - return weaver.JudgeResult{}, false - } - - var resp judgeResponse - if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil { - return weaver.JudgeResult{}, false - } - - ops := make([]weaver.Operation, 0, len(resp.Operations)) - for _, o := range resp.Operations { - opType := weaver.OperationType(strings.ToUpper(strings.TrimSpace(o.Type))) - switch opType { - case weaver.OperationAdd, weaver.OperationUpdate, weaver.OperationDelete, weaver.OperationNoop: - default: - opType = weaver.OperationAdd - } - if strings.TrimSpace(o.Content) == "" && opType != weaver.OperationDelete { - continue - } - ops = append(ops, weaver.Operation{ - Type: opType, - Content: o.Content, - EmbeddingID: o.EmbeddingID, - Reason: o.Reason, - }) - } - - conf := resp.Confidence - if conf <= 0 || conf > 1 { - conf = 0.8 - } - return weaver.JudgeResult{Operations: ops, Confidence: conf}, len(ops) > 0 -} - -// ---------- helpers ---------- - -func joinPipe(parts ...string) string { - return strings.Join(parts, " | ") -} - -func extractJSONObject(text string) string { - text = strings.TrimSpace(text) - start := strings.Index(text, "{") - end := strings.LastIndex(text, "}") - if start >= 0 && end > start { - return text[start : end+1] - } - return "" -} - -func extractJSONArray(text string) string { - text = strings.TrimSpace(text) - start := strings.Index(text, "[") - end := strings.LastIndex(text, "]") - if start >= 0 && end > start { - return text[start : end+1] - } - return "" -} diff --git a/xmem-go/internal/agents/agents_test.go b/xmem-go/internal/agents/agents_test.go new file mode 100644 index 0000000..759d451 --- /dev/null +++ b/xmem-go/internal/agents/agents_test.go @@ -0,0 +1,86 @@ +package agents + +import ( + "context" + "strings" + "testing" + + "github.com/xortexai/xmem-go/internal/models" +) + +type stubModel struct { + content string +} + +func (m stubModel) Name() string { return "stub" } + +func (m stubModel) Generate(context.Context, string) (models.Response, error) { + return models.Response{Content: m.content}, nil +} + +func (m stubModel) GenerateWithMessages(context.Context, []models.Message) (models.Response, error) { + return models.Response{Content: m.content}, nil +} + +func (m stubModel) GenerateVision(_ context.Context, _ string, _ string, _ string) (models.Response, error) { + return models.Response{Content: m.content}, nil +} + +func (m stubModel) SelectTools(context.Context, string, []map[string]string) (models.Response, error) { + return models.Response{Content: m.content}, nil +} + +func TestTemporalDateValidationMatchesPythonAgent(t *testing.T) { + valid := []string{"01-31", "02-29", "04-30", "12-31"} + for _, date := range valid { + if !isValidDate(date) { + t.Fatalf("expected %s to be valid", date) + } + } + + invalid := []string{"", "1-01", "01-1", "00-10", "13-01", "02-30", "04-31"} + for _, date := range invalid { + if isValidDate(date) { + t.Fatalf("expected %s to be invalid", date) + } + } +} + +func TestSnippetAgentParsesArrayTags(t *testing.T) { + agent := SnippetAgent{Model: stubModel{content: `{"snippets":[{"content":"Binary search handles empty arrays","code_snippet":"return -1","language":"cpp","snippet_type":"algorithm","tags":["dsa","binary-search"]}]}`}} + + items := agent.Run(context.Background(), "save this snippet") + if len(items) != 1 { + t.Fatalf("items length = %d", len(items)) + } + if !strings.Contains(items[0], "dsa,binary-search") { + t.Fatalf("expected comma-joined tags, got %q", items[0]) + } +} + +func TestCodeAgentParsesAnnotations(t *testing.T) { + agent := CodeAgent{Model: stubModel{content: `{"annotations":[{"target_symbol":"PaymentProcessor.process","target_file":"src/payments.go","content":"Add an idempotency key before retrying charges.","annotation_type":"fix","severity":"high","repo":"payments"}]}`}} + + items := agent.Run(context.Background(), "remember this fix") + if len(items) != 1 { + t.Fatalf("items length = %d", len(items)) + } + if got := items[0]; !strings.Contains(got, "fix | PaymentProcessor.process | src/payments.go | payments | high | Add an idempotency key") { + t.Fatalf("unexpected annotation content: %q", got) + } +} + +func TestImageAgentFormatsPythonPipelineItems(t *testing.T) { + agent := ImageAgent{Model: stubModel{content: "DESCRIPTION: Whiteboard with API architecture\n\nOBSERVATIONS:\n- [technical] Shows service-to-database flow (confidence: high)"}} + + items := agent.Run(context.Background(), "analyze architecture", "https://example.com/image.png") + if len(items) != 2 { + t.Fatalf("items length = %d", len(items)) + } + if items[0] != "[Image] Whiteboard with API architecture" { + t.Fatalf("unexpected description item: %q", items[0]) + } + if items[1] != "[Image/technical] Shows service-to-database flow (high)" { + t.Fatalf("unexpected observation item: %q", items[1]) + } +} diff --git a/xmem-go/internal/agents/base.go b/xmem-go/internal/agents/base.go new file mode 100644 index 0000000..e116780 --- /dev/null +++ b/xmem-go/internal/agents/base.go @@ -0,0 +1,25 @@ +package agents + +import ( + "context" + "fmt" + "time" + + "github.com/xortexai/xmem-go/internal/models" +) + +const llmCallTimeout = 45 * time.Second + +func callModel(ctx context.Context, model models.ChatModel, systemPrompt, userMessage string) (string, error) { + ctx, cancel := context.WithTimeout(ctx, llmCallTimeout) + defer cancel() + messages := []models.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userMessage}, + } + resp, err := model.GenerateWithMessages(ctx, messages) + if err != nil { + return "", fmt.Errorf("LLM call failed (%s): %w", model.Name(), err) + } + return resp.Content, nil +} diff --git a/xmem-go/internal/agents/classifier.go b/xmem-go/internal/agents/classifier.go new file mode 100644 index 0000000..289366c --- /dev/null +++ b/xmem-go/internal/agents/classifier.go @@ -0,0 +1,40 @@ +package agents + +import ( + "context" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" + "github.com/xortexai/xmem-go/internal/utils" +) + +type Classification struct { + Source string `json:"source"` + Query string `json:"query"` +} + +type ClassifierAgent struct { + Model models.ChatModel +} + +func (a ClassifierAgent) Run(ctx context.Context, userQuery string, _ string) []Classification { + if userQuery == "" { + return nil + } + + systemPrompt := prompts.BuildClassifierSystemPrompt() + userMessage := prompts.PackClassificationQuery(userQuery) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + parsed := utils.ParseRawResponseToClassifications(raw) + out := make([]Classification, 0, len(parsed)) + for _, c := range parsed { + out = append(out, Classification{Source: c.Source, Query: c.Query}) + } + + return out +} diff --git a/xmem-go/internal/agents/code.go b/xmem-go/internal/agents/code.go new file mode 100644 index 0000000..dc18ee7 --- /dev/null +++ b/xmem-go/internal/agents/code.go @@ -0,0 +1,74 @@ +package agents + +import ( + "context" + json "github.com/goccy/go-json" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" +) + +type CodeAgent struct { + Model models.ChatModel +} + +type codeAnnotationJSON struct { + TargetSymbol string `json:"target_symbol"` + TargetFile string `json:"target_file"` + Content string `json:"content"` + AnnotationType string `json:"annotation_type"` + Severity string `json:"severity"` + Repo string `json:"repo"` + AssignedToName string `json:"assigned_to_name"` +} + +type codeAnnotationsResponse struct { + Annotations []codeAnnotationJSON `json:"annotations"` +} + +func (a CodeAgent) Run(ctx context.Context, text string) []string { + if strings.TrimSpace(text) == "" { + return nil + } + + systemPrompt := prompts.BuildCodeSystemPrompt() + userMessage := prompts.PackCodeQuery(text) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + items := parseCodeResponse(raw) + results := make([]string, 0, len(items)) + for _, ann := range items { + content := strings.TrimSpace(ann.Content) + if content == "" { + continue + } + line := joinPipe( + defaultString(ann.AnnotationType, "explanation"), + ann.TargetSymbol, + ann.TargetFile, + ann.Repo, + ann.Severity, + content, + ) + results = append(results, line) + } + return results +} + +func parseCodeResponse(raw string) []codeAnnotationJSON { + jsonStr := extractJSONObject(raw) + if jsonStr == "" { + jsonStr = extractJSONArray(raw) + } + + var resp codeAnnotationsResponse + if err := json.Unmarshal([]byte(jsonStr), &resp); err == nil && len(resp.Annotations) > 0 { + return resp.Annotations + } + return nil +} diff --git a/xmem-go/internal/agents/helpers.go b/xmem-go/internal/agents/helpers.go new file mode 100644 index 0000000..4b2052e --- /dev/null +++ b/xmem-go/internal/agents/helpers.go @@ -0,0 +1,166 @@ +package agents + +import ( + "fmt" + "strings" + + "github.com/xortexai/xmem-go/internal/graph" + "github.com/xortexai/xmem-go/internal/storage" +) + +func joinPipe(parts ...string) string { + return strings.Join(parts, " | ") +} + +func extractJSONObject(text string) string { + text = strings.TrimSpace(text) + start := strings.Index(text, "{") + end := strings.LastIndex(text, "}") + if start >= 0 && end > start { + return text[start : end+1] + } + return "" +} + +func extractJSONArray(text string) string { + text = strings.TrimSpace(text) + start := strings.Index(text, "[") + end := strings.LastIndex(text, "]") + if start >= 0 && end > start { + return text[start : end+1] + } + return "" +} + +func buildProfileMetadataKey(fact ProfileFact) string { + topic := strings.TrimSpace(fact.Topic) + subTopic := strings.TrimSpace(fact.SubTopic) + if topic == "" || subTopic == "" { + return "" + } + key := topic + "_" + subTopic + key = strings.ReplaceAll(key, " ", "_") + return strings.ToLower(key) +} + +func dedupeProfileItems(facts []ProfileFact) []ProfileFact { + latest := make(map[string]ProfileFact) + order := []string{} + passthrough := []ProfileFact{} + for _, fact := range facts { + key := buildProfileMetadataKey(fact) + if key != "" { + if _, exists := latest[key]; !exists { + order = append(order, key) + } + latest[key] = fact + } else { + passthrough = append(passthrough, fact) + } + } + out := make([]ProfileFact, 0, len(latest)+len(passthrough)) + for _, key := range order { + out = append(out, latest[key]) + } + out = append(out, passthrough...) + return out +} + +func dedupeTemporalItems(events []Event) []Event { + latest := make(map[string]Event) + order := []string{} + passthrough := []Event{} + for _, event := range events { + name := normText(event.EventName) + if name != "" { + if _, exists := latest[name]; !exists { + order = append(order, name) + } + latest[name] = event + } else { + passthrough = append(passthrough, event) + } + } + out := make([]Event, 0, len(latest)+len(passthrough)) + for _, key := range order { + out = append(out, latest[key]) + } + out = append(out, passthrough...) + return out +} + +func normText(val string) string { + fields := strings.Fields(strings.ToLower(strings.TrimSpace(val))) + return strings.Join(fields, " ") +} + +func profileMemoFromContent(content string) string { + if !strings.Contains(content, " = ") { + return content + } + parts := strings.SplitN(content, " = ", 2) + return strings.TrimSpace(parts[1]) +} + +func profileMemoFromMatch(match storage.SearchResult) string { + if match.Metadata != nil { + if sub, ok := match.Metadata["subcontent"]; ok { + return fmt.Sprintf("%v", sub) + } + } + return profileMemoFromContent(match.Content) +} + +func temporalFieldsFromContent(content string) Event { + parts := strings.Split(content, " | ") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + e := Event{} + if len(parts) > 0 { + e.Date = parts[0] + } + if len(parts) > 1 { + e.EventName = parts[1] + } + if len(parts) > 2 { + e.Desc = parts[2] + } + if len(parts) > 3 { + e.Year = parts[3] + } + if len(parts) > 4 { + e.Time = parts[4] + } + if len(parts) > 5 { + e.DateExpression = parts[5] + } + return e +} + +func temporalFieldsFromMatch(match graph.Event) Event { + return Event{ + Date: match.Date, + EventName: match.EventName, + Desc: match.Description, + Year: match.Year, + Time: match.Time, + DateExpression: match.DateExpression, + } +} + +func sameTemporalEvent(incoming, existing Event) bool { + return normText(incoming.Date) == normText(existing.Date) && + normText(incoming.EventName) == normText(existing.EventName) && + normText(incoming.Desc) == normText(existing.Desc) && + normText(incoming.Year) == normText(existing.Year) && + normText(incoming.Time) == normText(existing.Time) && + normText(incoming.DateExpression) == normText(existing.DateExpression) +} + +func defaultString(value, fallback string) string { + if strings.TrimSpace(value) == "" { + return fallback + } + return strings.TrimSpace(value) +} diff --git a/xmem-go/internal/agents/image.go b/xmem-go/internal/agents/image.go new file mode 100644 index 0000000..cbab2b6 --- /dev/null +++ b/xmem-go/internal/agents/image.go @@ -0,0 +1,54 @@ +package agents + +import ( + "context" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" + "github.com/xortexai/xmem-go/internal/utils" +) + +type ImageAgent struct { + Model models.ChatModel +} + +func (a ImageAgent) Run(ctx context.Context, query string, imageURL string) []string { + if strings.TrimSpace(query) == "" && strings.TrimSpace(imageURL) == "" { + return nil + } + + systemPrompt := prompts.BuildImageSystemPrompt() + userText := prompts.PackImageQuery(query, "") + + var raw string + var err error + if strings.TrimSpace(imageURL) != "" { + ctx, cancel := context.WithTimeout(ctx, llmCallTimeout) + defer cancel() + resp, visionErr := a.Model.GenerateVision(ctx, systemPrompt, userText, imageURL) + if visionErr != nil { + return nil + } + raw = resp.Content + } else { + raw, err = callModel(ctx, a.Model, systemPrompt, userText) + if err != nil { + return nil + } + } + + result := utils.ParseRawResponseToImage(raw) + items := make([]string, 0, len(result.Observations)+1) + if strings.TrimSpace(result.Description) != "" { + items = append(items, "[Image] "+result.Description) + } + for _, obs := range result.Observations { + entry := "[Image/" + obs.Category + "] " + obs.Description + if obs.Confidence != "" { + entry += " (" + obs.Confidence + ")" + } + items = append(items, entry) + } + return items +} diff --git a/xmem-go/internal/agents/judge.go b/xmem-go/internal/agents/judge.go new file mode 100644 index 0000000..5f4c691 --- /dev/null +++ b/xmem-go/internal/agents/judge.go @@ -0,0 +1,292 @@ +package agents + +import ( + "context" + json "github.com/goccy/go-json" + "fmt" + "strings" + + "github.com/xortexai/xmem-go/internal/graph" + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" + "github.com/xortexai/xmem-go/internal/storage" + "github.com/xortexai/xmem-go/internal/weaver" +) + +const summaryJudgeSimilarityThreshold = 0.4 + +type JudgeAgent struct { + Model models.ChatModel + VectorStore storage.VectorStore + TemporalStore graph.TemporalStore + TopK int +} + +func (a JudgeAgent) judgeTopK() int { + if a.TopK <= 0 { + return 1 + } + return a.TopK +} + +func (a JudgeAgent) JudgeItems(ctx context.Context, domain weaver.JudgeDomain, items []string, userID string, confidence float64) weaver.JudgeResult { + if len(items) == 0 { + return weaver.JudgeResult{} + } + + if domain == weaver.DomainSummary { + matches := a.fetchSimilarSummaries(ctx, items, userID) + if !hasSummaryJudgeCandidates(matches) { + return judgeDeterministicSummary(items, confidence) + } + similarBlock := formatSummarySimilarBlock(items, filterMatchesByThreshold(matches, summaryJudgeSimilarityThreshold)) + return a.judgeItemsWithLLM(ctx, domain, items, similarBlock, confidence) + } + + return a.judgeItemsWithLLM(ctx, domain, items, nil, confidence) +} + +func (a JudgeAgent) judgeItemsWithLLM(ctx context.Context, domain weaver.JudgeDomain, items []string, similarLines []string, confidence float64) weaver.JudgeResult { + systemPrompt := prompts.BuildJudgeSystemPrompt() + userMessage := prompts.PackJudgeQuery(items, similarLines, string(domain)) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return judgeFallback(items, confidence) + } + + result, ok := parseJudgeResponse(raw) + if !ok || len(result.Operations) == 0 { + return judgeParseFallback(items) + } + + if confidence > 0 { + result.Confidence = confidence + } + return result +} + +func (a JudgeAgent) JudgeProfile(ctx context.Context, facts []ProfileFact, userID string) weaver.JudgeResult { + dedupedFacts := dedupeProfileItems(facts) + ops := make([]weaver.Operation, 0, len(dedupedFacts)) + + for _, fact := range dedupedFacts { + itemStr := fact.Topic + " / " + fact.SubTopic + " = " + fact.Memo + key := buildProfileMetadataKey(fact) + if key == "" || a.VectorStore == nil { + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: itemStr, Reason: "No vector store or invalid metadata key — defaulting to ADD."}) + continue + } + + filters := map[string]any{"domain": "profile", "main_content": key} + if userID != "" { + filters["user_id"] = userID + } + + results, err := a.VectorStore.SearchByMetadata(ctx, filters, a.judgeTopK()) + if err != nil || len(results) == 0 { + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: itemStr, Reason: "No profile record with the same topic/sub_topic."}) + continue + } + + match := results[0] + incomingMemo := profileMemoFromContent(itemStr) + existingMemo := profileMemoFromMatch(match) + if normText(incomingMemo) == normText(existingMemo) { + ops = append(ops, weaver.Operation{Type: weaver.OperationNoop, Content: itemStr, EmbeddingID: match.ID, Reason: "Existing profile fact is unchanged."}) + } else { + ops = append(ops, weaver.Operation{Type: weaver.OperationUpdate, Content: itemStr, EmbeddingID: match.ID, Reason: "Existing profile fact has new content."}) + } + } + + return weaver.JudgeResult{Operations: ops, Confidence: 1.0} +} + +func (a JudgeAgent) JudgeTemporal(ctx context.Context, events []Event, userID string) weaver.JudgeResult { + dedupedEvents := dedupeTemporalItems(events) + ops := make([]weaver.Operation, 0, len(dedupedEvents)) + + for _, event := range dedupedEvents { + itemStr := strings.Join([]string{event.Date, event.EventName, event.Desc, event.Year, event.Time, event.DateExpression}, " | ") + if event.EventName == "" || a.TemporalStore == nil { + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: itemStr, Reason: "No temporal store or invalid event name — defaulting to ADD."}) + continue + } + + results, err := a.TemporalStore.SearchEventsByName(ctx, event.EventName, userID, a.judgeTopK()) + if err != nil || len(results) == 0 { + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: itemStr, Reason: "No temporal event with the same event_name."}) + continue + } + + match := results[0] + incoming := temporalFieldsFromContent(itemStr) + existing := temporalFieldsFromMatch(match) + if sameTemporalEvent(incoming, existing) { + ops = append(ops, weaver.Operation{Type: weaver.OperationNoop, Content: itemStr, EmbeddingID: match.EmbeddingID(), Reason: "Existing temporal event is unchanged."}) + } else if normText(incoming.Date) != normText(existing.Date) { + ops = append(ops, weaver.Operation{Type: weaver.OperationDelete, Content: strings.Join([]string{match.Date, match.EventName, match.Description, match.Year, match.Time, match.DateExpression}, " | "), EmbeddingID: match.EmbeddingID(), Reason: "Existing temporal event moved to a different date."}) + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: itemStr, Reason: "Re-created temporal event on the updated date."}) + } else { + ops = append(ops, weaver.Operation{Type: weaver.OperationUpdate, Content: itemStr, EmbeddingID: match.EmbeddingID(), Reason: "Existing temporal event has new content."}) + } + } + + return weaver.JudgeResult{Operations: ops, Confidence: 1.0} +} + +func judgeDeterministicSummary(items []string, confidence float64) weaver.JudgeResult { + ops := make([]weaver.Operation, 0, len(items)) + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: item, Reason: "No similar summary at or above 0.4 — defaulting to ADD."}) + } + if confidence == 0 { + confidence = 0.8 + } + return weaver.JudgeResult{Operations: ops, Confidence: confidence} +} + +func (a JudgeAgent) fetchSimilarSummaries(ctx context.Context, items []string, userID string) map[string][]storage.SearchResult { + out := make(map[string][]storage.SearchResult, len(items)) + if a.VectorStore == nil { + return out + } + filters := map[string]any{"domain": string(weaver.DomainSummary)} + if strings.TrimSpace(userID) != "" { + filters["user_id"] = userID + } + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + results, err := a.VectorStore.SearchByText(ctx, item, a.judgeTopK(), filters) + if err != nil { + out[item] = nil + continue + } + out[item] = results + } + return out +} + +func hasSummaryJudgeCandidates(matches map[string][]storage.SearchResult) bool { + for _, results := range matches { + for _, result := range results { + if result.Score >= summaryJudgeSimilarityThreshold { + return true + } + } + } + return false +} + +func filterMatchesByThreshold(matches map[string][]storage.SearchResult, threshold float64) map[string][]storage.SearchResult { + out := make(map[string][]storage.SearchResult, len(matches)) + for item, results := range matches { + filtered := make([]storage.SearchResult, 0, len(results)) + for _, result := range results { + if result.Score >= threshold { + filtered = append(filtered, result) + } + } + out[item] = filtered + } + return out +} + +func formatSummarySimilarBlock(items []string, matches map[string][]storage.SearchResult) []string { + if len(matches) == 0 { + return nil + } + lines := make([]string, 0) + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + results := matches[item] + lines = append(lines, fmt.Sprintf("For item: %q", item)) + if len(results) == 0 { + lines = append(lines, " - (no similar records)") + continue + } + for _, result := range results { + lines = append(lines, fmt.Sprintf(" - ID: %s | Score: %.2f | %q", result.ID, result.Score, result.Content)) + } + } + return lines +} + +func judgeParseFallback(items []string) weaver.JudgeResult { + ops := make([]weaver.Operation, 0, len(items)) + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: item, Reason: "Fallback — JSON parse failed"}) + } + return weaver.JudgeResult{Operations: ops, Confidence: 0.5} +} + +func judgeFallback(items []string, confidence float64) weaver.JudgeResult { + ops := make([]weaver.Operation, 0, len(items)) + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + ops = append(ops, weaver.Operation{Type: weaver.OperationAdd, Content: item, Reason: "LLM judge unavailable — defaulting to ADD."}) + } + if confidence == 0 { + confidence = 0.8 + } + return weaver.JudgeResult{Operations: ops, Confidence: confidence} +} + +type judgeResponse struct { + Operations []struct { + Type string `json:"type"` + Content string `json:"content"` + EmbeddingID string `json:"embedding_id"` + Reason string `json:"reason"` + } `json:"operations"` + Confidence float64 `json:"confidence"` +} + +func parseJudgeResponse(raw string) (weaver.JudgeResult, bool) { + jsonStr := extractJSONObject(raw) + if jsonStr == "" { + return weaver.JudgeResult{}, false + } + + var resp judgeResponse + if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil { + return weaver.JudgeResult{}, false + } + + ops := make([]weaver.Operation, 0, len(resp.Operations)) + for _, o := range resp.Operations { + opType := weaver.OperationType(strings.ToUpper(strings.TrimSpace(o.Type))) + switch opType { + case weaver.OperationAdd, weaver.OperationUpdate, weaver.OperationDelete, weaver.OperationNoop: + default: + opType = weaver.OperationAdd + } + if strings.TrimSpace(o.Content) == "" && opType != weaver.OperationDelete && opType != weaver.OperationNoop { + continue + } + ops = append(ops, weaver.Operation{Type: opType, Content: o.Content, EmbeddingID: o.EmbeddingID, Reason: o.Reason}) + } + + conf := resp.Confidence + if conf <= 0 || conf > 1 { + conf = 0.8 + } + return weaver.JudgeResult{Operations: ops, Confidence: conf}, len(ops) > 0 +} diff --git a/xmem-go/internal/agents/profiler.go b/xmem-go/internal/agents/profiler.go new file mode 100644 index 0000000..de7e402 --- /dev/null +++ b/xmem-go/internal/agents/profiler.go @@ -0,0 +1,45 @@ +package agents + +import ( + "context" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" + "github.com/xortexai/xmem-go/internal/utils" +) + +type ProfileFact struct { + Topic string + SubTopic string + Memo string +} + +type ProfilerAgent struct { + Model models.ChatModel +} + +func (a ProfilerAgent) Run(ctx context.Context, text string) []ProfileFact { + if strings.TrimSpace(text) == "" { + return nil + } + + systemPrompt := prompts.BuildProfilerSystemPrompt() + userMessage := prompts.PackProfilerQuery(text) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + parsed := utils.ParseRawResponseToProfiles(raw) + facts := make([]ProfileFact, 0, len(parsed)) + for _, f := range parsed { + facts = append(facts, ProfileFact{ + Topic: f.Topic, + SubTopic: f.SubTopic, + Memo: f.Memo, + }) + } + return facts +} diff --git a/xmem-go/internal/agents/snippet.go b/xmem-go/internal/agents/snippet.go new file mode 100644 index 0000000..5be0d6b --- /dev/null +++ b/xmem-go/internal/agents/snippet.go @@ -0,0 +1,89 @@ +package agents + +import ( + "context" + json "github.com/goccy/go-json" + "fmt" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" +) + +type SnippetAgent struct { + Model models.ChatModel +} + +type snippetJSON struct { + Content string `json:"content"` + CodeSnippet string `json:"code_snippet"` + Language string `json:"language"` + SnippetType string `json:"snippet_type"` + Tags any `json:"tags"` +} + +type snippetsResponse struct { + Snippets []snippetJSON `json:"snippets"` +} + +func (a SnippetAgent) Run(ctx context.Context, text string) []string { + if strings.TrimSpace(text) == "" { + return nil + } + + systemPrompt := prompts.BuildSnippetSystemPrompt() + userMessage := prompts.PackSnippetQuery(text) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + items := parseSnippetResponse(raw) + results := make([]string, 0, len(items)) + for _, s := range items { + content := strings.TrimSpace(s.Content) + if content == "" { + continue + } + line := joinPipe(content, s.CodeSnippet, s.Language, defaultString(s.SnippetType, "algorithm"), normalizeTags(s.Tags)) + results = append(results, line) + } + return results +} + +func parseSnippetResponse(raw string) []snippetJSON { + jsonStr := extractJSONObject(raw) + if jsonStr == "" { + jsonStr = extractJSONArray(raw) + } + + var resp snippetsResponse + if err := json.Unmarshal([]byte(jsonStr), &resp); err == nil && len(resp.Snippets) > 0 { + return resp.Snippets + } + return nil +} + +func normalizeTags(tags any) string { + switch v := tags.(type) { + case nil: + return "" + case string: + return v + case []any: + out := make([]string, 0, len(v)) + for _, tag := range v { + cleaned := strings.TrimSpace(fmt.Sprintf("%v", tag)) + if cleaned != "" { + out = append(out, cleaned) + } + } + if len(out) > 10 { + out = out[:10] + } + return strings.Join(out, ",") + default: + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } +} diff --git a/xmem-go/internal/agents/summarizer.go b/xmem-go/internal/agents/summarizer.go new file mode 100644 index 0000000..470857b --- /dev/null +++ b/xmem-go/internal/agents/summarizer.go @@ -0,0 +1,52 @@ +package agents + +import ( + "context" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" +) + +type SummarizerAgent struct { + Model models.ChatModel +} + +var emptySentinels = map[string]struct{}{ + `""`: {}, + `''`: {}, + "empty": {}, + "(empty)": {}, + "(empty string)": {}, +} + +func (a SummarizerAgent) Run(ctx context.Context, userQuery string, agentResponse string) []string { + if strings.TrimSpace(userQuery) == "" && strings.TrimSpace(agentResponse) == "" { + return nil + } + + systemPrompt := prompts.BuildSummarizerSystemPrompt() + userMessage := prompts.PackSummaryQuery(userQuery, agentResponse) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + trimmed := strings.TrimSpace(raw) + if _, isEmpty := emptySentinels[trimmed]; isEmpty || trimmed == "" { + return nil + } + + lines := strings.Split(trimmed, "\n") + bullets := make([]string, 0, len(lines)) + for _, line := range lines { + cleaned := strings.TrimSpace(line) + cleaned = strings.TrimLeft(cleaned, "-•*") + cleaned = strings.TrimSpace(cleaned) + if cleaned != "" { + bullets = append(bullets, cleaned) + } + } + return bullets +} diff --git a/xmem-go/internal/agents/temporal.go b/xmem-go/internal/agents/temporal.go new file mode 100644 index 0000000..1dcc36f --- /dev/null +++ b/xmem-go/internal/agents/temporal.go @@ -0,0 +1,79 @@ +package agents + +import ( + "context" + "strconv" + "strings" + + "github.com/xortexai/xmem-go/internal/models" + "github.com/xortexai/xmem-go/internal/prompts" + "github.com/xortexai/xmem-go/internal/utils" +) + +var daysInMonth = map[int]int{ + 1: 31, 2: 29, 3: 31, 4: 30, 5: 31, 6: 30, + 7: 31, 8: 31, 9: 30, 10: 31, 11: 30, 12: 31, +} + +type Event struct { + Date string + EventName string + Desc string + Year string + Time string + DateExpression string +} + +type TemporalAgent struct { + Model models.ChatModel +} + +func (a TemporalAgent) Run(ctx context.Context, text string, sessionDatetime string) []Event { + if strings.TrimSpace(text) == "" { + return nil + } + + systemPrompt := prompts.BuildTemporalSystemPrompt() + userMessage := prompts.PackTemporalQuery(text, sessionDatetime) + + raw, err := callModel(ctx, a.Model, systemPrompt, userMessage) + if err != nil { + return nil + } + + parsed := utils.ParseRawResponseToEvents(raw) + events := make([]Event, 0, len(parsed)) + for _, e := range parsed { + if !isValidDate(e.Date) { + continue + } + events = append(events, Event{ + Date: e.Date, + EventName: e.EventName, + Desc: e.Desc, + Year: e.Year, + Time: e.Time, + DateExpression: e.DateExpression, + }) + } + return events +} + +func isValidDate(date string) bool { + if len(date) != 5 { + return false + } + parts := strings.Split(date, "-") + if len(parts) != 2 { + return false + } + month, err := strconv.Atoi(parts[0]) + if err != nil || month < 1 || month > 12 { + return false + } + day, err := strconv.Atoi(parts[1]) + if err != nil || day < 1 || day > daysInMonth[month] { + return false + } + return true +} diff --git a/xmem-go/internal/api/auth.go b/xmem-go/internal/api/auth.go index de40d50..e6aa2af 100644 --- a/xmem-go/internal/api/auth.go +++ b/xmem-go/internal/api/auth.go @@ -4,7 +4,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" - "encoding/json" + json "github.com/goccy/go-json" "net/http" "strings" diff --git a/xmem-go/internal/api/response.go b/xmem-go/internal/api/response.go index 23b3eba..e2ee9ae 100644 --- a/xmem-go/internal/api/response.go +++ b/xmem-go/internal/api/response.go @@ -1,7 +1,7 @@ package api import ( - "encoding/json" + json "github.com/goccy/go-json" "fmt" "net/http" "time" diff --git a/xmem-go/internal/api/server_test.go b/xmem-go/internal/api/server_test.go index 685fb54..665b776 100644 --- a/xmem-go/internal/api/server_test.go +++ b/xmem-go/internal/api/server_test.go @@ -2,7 +2,8 @@ package api import ( "bytes" - "encoding/json" + "context" + json "github.com/goccy/go-json" "log/slog" "net/http" "net/http/httptest" @@ -21,6 +22,29 @@ import ( "github.com/xortexai/xmem-go/internal/weaver" ) +type testModel struct{} + +func (testModel) Name() string { return "test-model" } +func (testModel) Generate(_ context.Context, prompt string) (models.Response, error) { + return models.Response{Content: prompt, ModelName: "test-model"}, nil +} +func (testModel) GenerateWithMessages(_ context.Context, msgs []models.Message) (models.Response, error) { + content := "" + for _, m := range msgs { + content += m.Content + "\n" + } + return models.Response{Content: strings.TrimSpace(content), ModelName: "test-model"}, nil +} +func (testModel) GenerateVision(_ context.Context, _ string, userText string, _ string) (models.Response, error) { + return models.Response{Content: userText, ModelName: "test-model"}, nil +} +func (testModel) SelectTools(_ context.Context, query string, _ []map[string]string) (models.Response, error) { + return models.Response{ + ToolCalls: []models.ToolCall{{ID: "call-1", Name: "search_summary", Args: map[string]any{"query": query}}}, + ModelName: "test-model", + }, nil +} + func testHandler() http.Handler { settings := config.Settings{ APIHost: "127.0.0.1", @@ -35,7 +59,7 @@ func testHandler() http.Handler { vectorStore := storage.NewMemoryVectorStore() snippetStore := storage.NewMemoryVectorStore() temporalStore := graph.NewMemoryTemporalStore() - model := models.NewLocalModel("test-model") + model := testModel{} w := &weaver.Weaver{ VectorStore: vectorStore, SnippetVectorStore: snippetStore, @@ -51,7 +75,7 @@ func testHandler() http.Handler { Summarizer: agents.SummarizerAgent{Model: model}, Image: agents.ImageAgent{Model: model}, Snippet: agents.SnippetAgent{Model: model}, - Judge: agents.JudgeAgent{Model: model, VectorStore: vectorStore, TopK: 3}, + Judge: agents.JudgeAgent{Model: model, VectorStore: vectorStore, TemporalStore: temporalStore, TopK: 3}, } retrieval := &pipelines.RetrievalPipeline{ Model: model, diff --git a/xmem-go/internal/config/config.go b/xmem-go/internal/config/config.go index a37a35e..1865ae5 100644 --- a/xmem-go/internal/config/config.go +++ b/xmem-go/internal/config/config.go @@ -2,7 +2,7 @@ package config import ( "bufio" - "encoding/json" + json "github.com/goccy/go-json" "errors" "net" "os" @@ -89,14 +89,14 @@ func Load() (Settings, error) { PineconeMetric: "cosine", PineconeCloud: "aws", PineconeRegion: "us-east-1", - VectorStoreProvider: "memory", - EmbeddingProvider: "local", - EmbeddingModel: "local-hash", + VectorStoreProvider: "pinecone", + EmbeddingProvider: "openai", + EmbeddingModel: "text-embedding-3-small", MongoDBURI: "mongodb://localhost:27017", MongoDBDatabase: "xmem_go", Neo4jURI: "bolt://localhost:7687", Neo4jUsername: "neo4j", - AppStoreProvider: "memory", + AppStoreProvider: "mongo", APIHost: "0.0.0.0", APIPort: 8081, CORSOrigins: []string{"http://localhost:3000", "http://localhost:5173"}, @@ -175,6 +175,17 @@ func Load() (Settings, error) { if net.ParseIP(strings.TrimSpace(s.APIHost)) == nil && s.APIHost != "localhost" { return s, errors.New("API_HOST must be an IP address or localhost") } + if s.Environment != "development" && s.Environment != "test" { + if s.Neo4jPassword == "" { + return s, errors.New("NEO4J_PASSWORD is required in non-development environments") + } + if s.PineconeAPIKey == "" { + return s, errors.New("PINECONE_API_KEY is required in non-development environments") + } + if s.OpenAIAPIKey == "" { + return s, errors.New("OPENAI_API_KEY is required in non-development environments") + } + } return s, nil } diff --git a/xmem-go/internal/graph/neo4j.go b/xmem-go/internal/graph/neo4j.go index 24a181c..d8766a9 100644 --- a/xmem-go/internal/graph/neo4j.go +++ b/xmem-go/internal/graph/neo4j.go @@ -7,14 +7,16 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j" "github.com/xortexai/xmem-go/internal/config" + "github.com/xortexai/xmem-go/internal/storage" + "github.com/xortexai/xmem-go/internal/utils" ) type Neo4jTemporalStore struct { - driver neo4j.DriverWithContext - db string + driver neo4j.DriverWithContext + embedder storage.Embedder } -func NewNeo4jTemporalStore(ctx context.Context, settings config.Settings) (*Neo4jTemporalStore, error) { +func NewNeo4jTemporalStore(ctx context.Context, settings config.Settings, embedder storage.Embedder) (*Neo4jTemporalStore, error) { driver, err := neo4j.NewDriverWithContext( settings.Neo4jURI, neo4j.BasicAuth(settings.Neo4jUsername, settings.Neo4jPassword, ""), @@ -28,31 +30,94 @@ func NewNeo4jTemporalStore(ctx context.Context, settings config.Settings) (*Neo4 _ = driver.Close(context.Background()) return nil, err } - store := &Neo4jTemporalStore{driver: driver} - _, _ = neo4j.ExecuteQuery(ctx, driver, - "CREATE CONSTRAINT xmem_go_event_key IF NOT EXISTS FOR (e:XMemGoEvent) REQUIRE (e.user_id, e.date, e.event_name) IS UNIQUE", - nil, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) + store := &Neo4jTemporalStore{driver: driver, embedder: embedder} + store.initSchema(ctx) return store, nil } +func (s *Neo4jTemporalStore) initSchema(ctx context.Context) { + constraints := []string{ + "CREATE CONSTRAINT user_id_unique IF NOT EXISTS FOR (u:User) REQUIRE u.user_id IS UNIQUE", + "CREATE CONSTRAINT date_val_unique IF NOT EXISTS FOR (d:Date) REQUIRE d.date IS UNIQUE", + } + for _, q := range constraints { + _, _ = neo4j.ExecuteQuery(ctx, s.driver, q, nil, + neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) + } +} + func (s *Neo4jTemporalStore) Close(ctx context.Context) error { return s.driver.Close(ctx) } func (s *Neo4jTemporalStore) CreateEvent(ctx context.Context, userID string, date string, event Event) error { - return s.upsertEvent(ctx, userID, date, event) + return s.withRetry(ctx, func() error { + return s.createEventOnce(ctx, userID, date, event) + }) +} + +func (s *Neo4jTemporalStore) createEventOnce(ctx context.Context, userID string, date string, event Event) error { + props := s.buildEventProps(ctx, event) + + _, err := neo4j.ExecuteQuery(ctx, s.driver, ` + MERGE (u:User {user_id: $user_id}) + MERGE (d:Date {date: $date_str}) + CREATE (u)-[r:HAS_EVENT]->(d) + SET r += $props + `, map[string]any{ + "user_id": userID, + "date_str": date, + "props": props, + }, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) + return err } func (s *Neo4jTemporalStore) UpdateEvent(ctx context.Context, userID string, date string, event Event) error { - return s.upsertEvent(ctx, userID, date, event) + return s.withRetry(ctx, func() error { + return s.updateEventOnce(ctx, userID, date, event) + }) +} + +func (s *Neo4jTemporalStore) updateEventOnce(ctx context.Context, userID string, date string, event Event) error { + props := s.buildEventProps(ctx, event) + + _, err := neo4j.ExecuteQuery(ctx, s.driver, ` + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT]-> + (d:Date {date: $date_str}) + SET r += $props + `, map[string]any{ + "user_id": userID, + "date_str": date, + "props": props, + }, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) + return err } func (s *Neo4jTemporalStore) DeleteEvent(ctx context.Context, userID string, embeddingID string) error { + return s.withRetry(ctx, func() error { + return s.deleteEventOnce(ctx, userID, embeddingID) + }) +} + +func (s *Neo4jTemporalStore) deleteEventOnce(ctx context.Context, userID string, embeddingID string) error { date, name := splitEmbeddingID(embeddingID) + if name != "" { + _, err := neo4j.ExecuteQuery(ctx, s.driver, ` + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT {event_name: $event_name}]-> + (d:Date {date: $date_str}) + DELETE r + `, map[string]any{"user_id": userID, "date_str": date, "event_name": name}, + neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) + return err + } _, err := neo4j.ExecuteQuery(ctx, s.driver, ` - MATCH (e:XMemGoEvent {user_id: $user_id, date: $date, event_name: $event_name}) - DELETE e - `, map[string]any{"user_id": userID, "date": date, "event_name": name}, + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT]-> + (d:Date {date: $date_str}) + DELETE r + `, map[string]any{"user_id": userID, "date_str": date}, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) return err } @@ -61,58 +126,129 @@ func (s *Neo4jTemporalStore) SearchEventsByName(ctx context.Context, eventName s if topK <= 0 { topK = 10 } - result, err := neo4j.ExecuteQuery(ctx, s.driver, ` - MATCH (e:XMemGoEvent {user_id: $user_id}) - WHERE toLower(e.event_name) CONTAINS toLower($event_name) - RETURN e.date AS date, e.event_name AS event_name, e.desc AS desc, e.year AS year, - e.time AS time, e.date_expression AS date_expression, 1.0 AS score - LIMIT $top_k - `, map[string]any{"user_id": userID, "event_name": eventName, "top_k": topK}, - neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithReadersRouting()) - if err != nil { - return nil, err + var result *neo4j.EagerResult + var err error + retryErr := s.withRetry(ctx, func() error { + result, err = neo4j.ExecuteQuery(ctx, s.driver, ` + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT]-> + (d:Date) + WHERE toLower(r.event_name) = toLower($event_name) + RETURN r.event_name AS event_name, r.desc AS desc, r.year AS year, + r.time AS time, r.date_expression AS date_expression, + d.date AS date + LIMIT $top_k + `, map[string]any{"user_id": userID, "event_name": eventName, "top_k": topK}, + neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithReadersRouting()) + return err + }) + if retryErr != nil { + return nil, retryErr } return recordsToEvents(result.Records), nil } -func (s *Neo4jTemporalStore) SearchEventsByEmbedding(ctx context.Context, userID string, queryText string, topK int, _ float64) ([]Event, error) { +func (s *Neo4jTemporalStore) SearchEventsByEmbedding(ctx context.Context, userID string, queryText string, topK int, threshold float64) ([]Event, error) { if topK <= 0 { topK = 10 } + + if s.embedder == nil { + return s.searchEventsByKeyword(ctx, userID, queryText, topK) + } + + queryEmbedding, err := s.embedder.Embed(ctx, queryText) + if err != nil || len(queryEmbedding) == 0 { + return s.searchEventsByKeyword(ctx, userID, queryText, topK) + } + + var result *neo4j.EagerResult + retryErr := s.withRetry(ctx, func() error { + result, err = neo4j.ExecuteQuery(ctx, s.driver, ` + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT]-> + (d:Date) + WHERE r.embedding IS NOT NULL + AND size(r.embedding) = size($query_embedding) + WITH r, d, + 2.0 * vector.similarity.cosine(r.embedding, $query_embedding) - 1.0 + AS similarity_score + WHERE similarity_score >= $similarity_threshold + RETURN r.event_name AS event_name, r.desc AS desc, r.year AS year, + r.time AS time, r.date_expression AS date_expression, + d.date AS date, similarity_score + ORDER BY similarity_score DESC + LIMIT $top_k + `, map[string]any{ + "user_id": userID, + "query_embedding": queryEmbedding, + "similarity_threshold": threshold, + "top_k": topK, + }, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithReadersRouting()) + return err + }) + if retryErr != nil { + return nil, retryErr + } + return recordsToEvents(result.Records), nil +} + +func (s *Neo4jTemporalStore) searchEventsByKeyword(ctx context.Context, userID string, queryText string, topK int) ([]Event, error) { terms := strings.ToLower(queryText) - result, err := neo4j.ExecuteQuery(ctx, s.driver, ` - MATCH (e:XMemGoEvent {user_id: $user_id}) - WITH e, - CASE - WHEN toLower(coalesce(e.event_name, '') + ' ' + coalesce(e.desc, '') + ' ' + coalesce(e.date_expression, '')) CONTAINS $query THEN 1.0 - ELSE 0.25 - END AS score - RETURN e.date AS date, e.event_name AS event_name, e.desc AS desc, e.year AS year, - e.time AS time, e.date_expression AS date_expression, score - ORDER BY score DESC - LIMIT $top_k - `, map[string]any{"user_id": userID, "query": terms, "top_k": topK}, - neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithReadersRouting()) - if err != nil { - return nil, err + var result *neo4j.EagerResult + var err error + retryErr := s.withRetry(ctx, func() error { + result, err = neo4j.ExecuteQuery(ctx, s.driver, ` + MATCH (u:User {user_id: $user_id}) + -[r:HAS_EVENT]-> + (d:Date) + WITH r, d, + CASE + WHEN toLower(coalesce(r.event_name, '') + ' ' + coalesce(r.desc, '') + ' ' + coalesce(r.date_expression, '')) CONTAINS $query THEN 1.0 + ELSE 0.25 + END AS similarity_score + WHERE similarity_score > 0.25 + RETURN r.event_name AS event_name, r.desc AS desc, r.year AS year, + r.time AS time, r.date_expression AS date_expression, + d.date AS date, similarity_score + ORDER BY similarity_score DESC + LIMIT $top_k + `, map[string]any{"user_id": userID, "query": terms, "top_k": topK}, + neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithReadersRouting()) + return err + }) + if retryErr != nil { + return nil, retryErr } return recordsToEvents(result.Records), nil } -func (s *Neo4jTemporalStore) upsertEvent(ctx context.Context, userID string, date string, event Event) error { - _, err := neo4j.ExecuteQuery(ctx, s.driver, ` - MERGE (e:XMemGoEvent {user_id: $user_id, date: $date, event_name: $event_name}) - SET e.desc = $desc, - e.year = $year, - e.time = $time, - e.date_expression = $date_expression, - e.updated_at = datetime() - `, map[string]any{ - "user_id": userID, "date": date, "event_name": event.EventName, - "desc": event.Description, "year": event.Year, "time": event.Time, +func (s *Neo4jTemporalStore) buildEventProps(ctx context.Context, event Event) map[string]any { + props := map[string]any{ + "event_name": event.EventName, + "desc": event.Description, + "year": event.Year, + "time": event.Time, "date_expression": event.DateExpression, - }, neo4j.EagerResultTransformer, neo4j.ExecuteQueryWithWritersRouting()) - return err + } + + if s.embedder != nil { + searchable := event.EventName + if event.Description != "" { + searchable = event.EventName + ": " + event.Description + } + if searchable != "" { + if embedding, err := s.embedder.Embed(ctx, searchable); err == nil { + props["embedding"] = embedding + } + } + } + + return props +} + +func (s *Neo4jTemporalStore) withRetry(ctx context.Context, fn func() error) error { + return utils.RetryWithBackoff(ctx, 3, time.Second, fn) } func recordsToEvents(records []*neo4j.Record) []Event { @@ -125,7 +261,7 @@ func recordsToEvents(records []*neo4j.Record) []Event { Year: asString(record, "year"), Time: asString(record, "time"), DateExpression: asString(record, "date_expression"), - SimilarityScore: asFloat(record, "score"), + SimilarityScore: asFloat(record, "similarity_score"), }) } return out diff --git a/xmem-go/internal/graph/types.go b/xmem-go/internal/graph/types.go index ac4a0ca..7d42288 100644 --- a/xmem-go/internal/graph/types.go +++ b/xmem-go/internal/graph/types.go @@ -10,6 +10,7 @@ type Event struct { Year string Time string DateExpression string + Embedding []float64 SimilarityScore float64 } diff --git a/xmem-go/internal/jobs/payload.go b/xmem-go/internal/jobs/payload.go index eafb38c..12c76e4 100644 --- a/xmem-go/internal/jobs/payload.go +++ b/xmem-go/internal/jobs/payload.go @@ -3,7 +3,7 @@ package jobs import ( "crypto/sha256" "encoding/hex" - "encoding/json" + json "github.com/goccy/go-json" "sort" "strings" ) @@ -43,11 +43,43 @@ func Redact(payload map[string]any) map[string]any { } func stableHash(value any) string { - encoded, _ := json.Marshal(value) + encoded := canonicalJSON(value) sum := sha256.Sum256(encoded) return hex.EncodeToString(sum[:]) } +func canonicalJSON(v any) []byte { + switch val := v.(type) { + case map[string]any: + keys := SortedKeys(val) + buf := []byte("{") + for i, k := range keys { + if i > 0 { + buf = append(buf, ',') + } + keyBytes, _ := json.Marshal(k) + buf = append(buf, keyBytes...) + buf = append(buf, ':') + buf = append(buf, canonicalJSON(val[k])...) + } + buf = append(buf, '}') + return buf + case []any: + buf := []byte("[") + for i, item := range val { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, canonicalJSON(item)...) + } + buf = append(buf, ']') + return buf + default: + b, _ := json.Marshal(v) + return b + } +} + func toMap(value any) map[string]any { if value == nil { return map[string]any{} diff --git a/xmem-go/internal/models/models.go b/xmem-go/internal/models/models.go index 0927768..6f8a7af 100644 --- a/xmem-go/internal/models/models.go +++ b/xmem-go/internal/models/models.go @@ -3,7 +3,6 @@ package models import ( "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -11,7 +10,9 @@ import ( "strings" "time" + json "github.com/goccy/go-json" "github.com/xortexai/xmem-go/internal/config" + "golang.org/x/net/http2" ) type Message struct { @@ -19,6 +20,17 @@ type Message struct { Content string `json:"content"` } +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` +} + +type MultimodalMessage struct { + Role string + Content []ContentBlock +} + type ToolCall struct { ID string `json:"id"` Name string `json:"name"` @@ -26,152 +38,47 @@ type ToolCall struct { } type Response struct { - Content string - ToolCalls []ToolCall - ModelName string + Content string + ToolCalls []ToolCall + ModelName string + InputTokens int + OutputTokens int + TotalTokens int } type ChatModel interface { Name() string Generate(ctx context.Context, prompt string) (Response, error) GenerateWithMessages(ctx context.Context, messages []Message) (Response, error) + GenerateVision(ctx context.Context, systemPrompt string, userText string, imageURL string) (Response, error) SelectTools(ctx context.Context, query string, profileCatalog []map[string]string) (Response, error) } -type LocalModel struct { - name string -} - -func NewLocalModel(name string) LocalModel { - if name == "" { - name = "local-rule-model" - } - return LocalModel{name: name} -} - -func NewRegistry(settings config.Settings) ChatModel { +func NewRegistry(settings config.Settings) (ChatModel, error) { for _, provider := range settings.FallbackOrder { provider = strings.ToLower(provider) switch provider { case "ollama": - return NewOllamaModel(settings) + return NewOllamaModel(settings), nil case "gemini": if settings.GeminiAPIKey != "" { - return NewGeminiModel(settings) + return NewGeminiModel(settings), nil } case "claude": if settings.ClaudeAPIKey != "" { - return NewClaudeModel(settings) + return NewClaudeModel(settings), nil } case "openai": if settings.OpenAIAPIKey != "" { - return NewOpenAICompatibleModel("openai", settings.OpenAIModel, "https://api.openai.com/v1/chat/completions", settings.OpenAIAPIKey) + return NewOpenAICompatibleModel("openai", settings.OpenAIModel, "https://api.openai.com/v1/chat/completions", settings.OpenAIAPIKey), nil } case "openrouter": if settings.OpenRouterAPIKey != "" { - return NewOpenAICompatibleModel("openrouter", settings.OpenRouterModel, "https://openrouter.ai/api/v1/chat/completions", settings.OpenRouterAPIKey) - } - case "bedrock": - if settings.AWSAccessKeyID != "" { - return NewLocalModel(settings.BedrockModel) + return NewOpenAICompatibleModel("openrouter", settings.OpenRouterModel, "https://openrouter.ai/api/v1/chat/completions", settings.OpenRouterAPIKey), nil } } } - return NewLocalModel("local-rule-model") -} - -func (m LocalModel) Name() string { - return m.name -} - -func (m LocalModel) Generate(_ context.Context, prompt string) (Response, error) { - answer := strings.TrimSpace(prompt) - if strings.Contains(prompt, "CONTEXT:") { - answer = synthesizeFromPrompt(prompt) - } - return Response{Content: answer, ModelName: m.name}, nil -} - -func (m LocalModel) GenerateWithMessages(_ context.Context, messages []Message) (Response, error) { - combined := "" - for _, msg := range messages { - combined += msg.Content + "\n" - } - return Response{Content: strings.TrimSpace(combined), ModelName: m.name}, nil -} - -func (m LocalModel) SelectTools(_ context.Context, query string, catalog []map[string]string) (Response, error) { - lowered := strings.ToLower(query) - calls := []ToolCall{} - id := 1 - add := func(name string, args map[string]any) { - calls = append(calls, ToolCall{ID: "call-" + string(rune('0'+id)), Name: name, Args: args}) - id++ - } - - for _, item := range catalog { - topic := item["topic"] - if topic != "" && strings.Contains(lowered, strings.ToLower(topic)) { - add("search_profile", map[string]any{"topic": topic}) - break - } - } - if len(calls) == 0 && containsAny(lowered, "name", "work", "job", "company", "hobby", "food", "like", "prefer", "profile") { - topic := "personal" - if containsAny(lowered, "work", "job", "company") { - topic = "work" - } else if containsAny(lowered, "food", "eat", "prefer") { - topic = "food" - } else if containsAny(lowered, "hobby", "like", "enjoy") { - topic = "interest" - } - add("search_profile", map[string]any{"topic": topic}) - } - if containsAny(lowered, "when", "date", "schedule", "appointment", "birthday", "tomorrow", "today", "event") { - add("search_temporal", map[string]any{"query": query}) - } - if containsAny(lowered, "code", "script", "function", "snippet") { - add("search_snippet", map[string]any{"query": query}) - } - if len(calls) == 0 || containsAny(lowered, "remember", "conversation", "summary", "context", "what") { - add("search_summary", map[string]any{"query": query}) - } - return Response{ToolCalls: calls, ModelName: m.name}, nil -} - -func synthesizeFromPrompt(prompt string) string { - context := after(prompt, "CONTEXT:") - if idx := strings.Index(context, "QUERY:"); idx >= 0 { - context = context[:idx] - } - lines := []string{} - for _, line := range strings.Split(context, "\n") { - line = strings.TrimSpace(line) - if line != "" && line != "No results found." { - lines = append(lines, line) - } - } - if len(lines) == 0 { - return "I could not find any stored memories that answer that." - } - return "Based on stored memories, " + strings.Trim(strings.Join(lines, " "), ". ") + "." -} - -func after(text, marker string) string { - idx := strings.Index(text, marker) - if idx < 0 { - return "" - } - return strings.TrimSpace(text[idx+len(marker):]) -} - -func containsAny(text string, words ...string) bool { - for _, word := range words { - if strings.Contains(text, word) { - return true - } - } - return false + return nil, errors.New("no LLM provider configured: set at least one API key (OPENROUTER_API_KEY, GEMINI_API_KEY, CLAUDE_API_KEY, or OPENAI_API_KEY)") } func MarshalJSON(v any) string { @@ -185,24 +92,34 @@ type HTTPModel struct { url string apiKey string client *http.Client - local LocalModel +} + +func newHTTP2Client(timeout time.Duration) *http.Client { + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + ForceAttemptHTTP2: true, + } + http2.ConfigureTransport(transport) + return &http.Client{Timeout: timeout, Transport: transport} } func NewOpenAICompatibleModel(provider, model, url, apiKey string) HTTPModel { - return HTTPModel{provider: provider, model: model, url: url, apiKey: apiKey, client: &http.Client{Timeout: 90 * time.Second}, local: NewLocalModel(model)} + return HTTPModel{provider: provider, model: model, url: url, apiKey: apiKey, client: newHTTP2Client(90 * time.Second)} } func NewGeminiModel(settings config.Settings) HTTPModel { url := "https://generativelanguage.googleapis.com/v1beta/models/" + settings.GeminiModel + ":generateContent?key=" + settings.GeminiAPIKey - return HTTPModel{provider: "gemini", model: settings.GeminiModel, url: url, apiKey: settings.GeminiAPIKey, client: &http.Client{Timeout: 90 * time.Second}, local: NewLocalModel(settings.GeminiModel)} + return HTTPModel{provider: "gemini", model: settings.GeminiModel, url: url, apiKey: settings.GeminiAPIKey, client: newHTTP2Client(90 * time.Second)} } func NewClaudeModel(settings config.Settings) HTTPModel { - return HTTPModel{provider: "claude", model: settings.ClaudeModel, url: "https://api.anthropic.com/v1/messages", apiKey: settings.ClaudeAPIKey, client: &http.Client{Timeout: 90 * time.Second}, local: NewLocalModel(settings.ClaudeModel)} + return HTTPModel{provider: "claude", model: settings.ClaudeModel, url: "https://api.anthropic.com/v1/messages", apiKey: settings.ClaudeAPIKey, client: newHTTP2Client(90 * time.Second)} } func NewOllamaModel(settings config.Settings) HTTPModel { - return HTTPModel{provider: "ollama", model: settings.OllamaModel, url: strings.TrimRight(settings.OllamaBaseURL, "/") + "/api/chat", client: &http.Client{Timeout: 120 * time.Second}, local: NewLocalModel(settings.OllamaModel)} + return HTTPModel{provider: "ollama", model: settings.OllamaModel, url: strings.TrimRight(settings.OllamaBaseURL, "/") + "/api/chat", client: newHTTP2Client(120 * time.Second)} } func (m HTTPModel) Name() string { @@ -210,26 +127,123 @@ func (m HTTPModel) Name() string { } func (m HTTPModel) Generate(ctx context.Context, prompt string) (Response, error) { - content, err := m.complete(ctx, prompt, false) - if err != nil { - return m.local.Generate(ctx, prompt) - } - return Response{Content: content, ModelName: m.model}, nil + return m.complete(ctx, prompt, false) } func (m HTTPModel) GenerateWithMessages(ctx context.Context, messages []Message) (Response, error) { - content, err := m.completeWithMessages(ctx, messages) - if err != nil { - combined := "" - for _, msg := range messages { - combined += strings.ToUpper(msg.Role) + ":\n" + msg.Content + "\n\n" + return m.completeWithMessages(ctx, messages) +} + +func (m HTTPModel) GenerateVision(ctx context.Context, systemPrompt string, userText string, imageURL string) (Response, error) { + switch m.provider { + case "openai", "openrouter": + contentParts := []map[string]any{ + {"type": "text", "text": userText}, + } + if imageURL != "" { + contentParts = append(contentParts, map[string]any{ + "type": "image_url", "image_url": map[string]string{"url": imageURL}, + }) + } + messages := []map[string]any{ + {"role": "system", "content": systemPrompt}, + {"role": "user", "content": contentParts}, } - return m.Generate(ctx, strings.TrimSpace(combined)) + body := map[string]any{"model": m.model, "messages": messages, "temperature": 0.1} + var out struct { + Choices []struct { + Message struct{ Content string `json:"content"` } `json:"message"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+m.apiKey) + if m.provider == "openrouter" { + req.Header.Set("HTTP-Referer", "http://localhost:8081") + req.Header.Set("X-Title", "xmem-go") + } + }); err != nil { + return Response{}, err + } + if len(out.Choices) == 0 { + return Response{}, errors.New("empty model response") + } + return Response{Content: out.Choices[0].Message.Content, ModelName: m.model, InputTokens: out.Usage.PromptTokens, OutputTokens: out.Usage.CompletionTokens, TotalTokens: out.Usage.TotalTokens}, nil + case "gemini": + parts := []map[string]any{{"text": userText}} + if imageURL != "" { + parts = append(parts, map[string]any{ + "inline_data": map[string]string{"mime_type": "image/jpeg", "data": imageURL}, + }) + } + body := map[string]any{ + "contents": []map[string]any{{"role": "user", "parts": parts}}, + "generationConfig": map[string]any{"temperature": 0.1}, + "system_instruction": map[string]any{ + "parts": []map[string]string{{"text": systemPrompt}}, + }, + } + var out struct { + Candidates []struct { + Content struct { + Parts []struct{ Text string `json:"text"` } `json:"parts"` + } `json:"content"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` + } + if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, nil); err != nil { + return Response{}, err + } + if len(out.Candidates) == 0 || len(out.Candidates[0].Content.Parts) == 0 { + return Response{}, errors.New("empty gemini response") + } + return Response{Content: out.Candidates[0].Content.Parts[0].Text, ModelName: m.model, InputTokens: out.UsageMetadata.PromptTokenCount, OutputTokens: out.UsageMetadata.CandidatesTokenCount, TotalTokens: out.UsageMetadata.TotalTokenCount}, nil + case "claude": + contentBlocks := []map[string]any{ + {"type": "text", "text": userText}, + } + if imageURL != "" { + contentBlocks = append(contentBlocks, map[string]any{ + "type": "image", "source": map[string]string{"type": "url", "url": imageURL}, + }) + } + body := map[string]any{ + "model": m.model, + "max_tokens": 4096, + "system": systemPrompt, + "messages": []map[string]any{{"role": "user", "content": contentBlocks}}, + } + var out struct { + Content []struct{ Text string `json:"text"` } `json:"content"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } + if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { + req.Header.Set("x-api-key", m.apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + }); err != nil { + return Response{}, err + } + if len(out.Content) == 0 { + return Response{}, errors.New("empty claude response") + } + return Response{Content: out.Content[0].Text, ModelName: m.model, InputTokens: out.Usage.InputTokens, OutputTokens: out.Usage.OutputTokens, TotalTokens: out.Usage.InputTokens + out.Usage.OutputTokens}, nil + default: + return m.GenerateWithMessages(ctx, []Message{{Role: "system", Content: systemPrompt}, {Role: "user", Content: userText + "\n[Image URL: " + imageURL + "]"}}) } - return Response{Content: content, ModelName: m.model}, nil } -func (m HTTPModel) completeWithMessages(ctx context.Context, messages []Message) (string, error) { +func (m HTTPModel) completeWithMessages(ctx context.Context, messages []Message) (Response, error) { switch m.provider { case "openai", "openrouter": return m.completeOpenAIMessages(ctx, messages) @@ -240,11 +254,11 @@ func (m HTTPModel) completeWithMessages(ctx context.Context, messages []Message) case "ollama": return m.completeOllamaMessages(ctx, messages) default: - return "", errors.New("unsupported provider") + return Response{}, errors.New("unsupported provider") } } -func (m HTTPModel) completeOpenAIMessages(ctx context.Context, messages []Message) (string, error) { +func (m HTTPModel) completeOpenAIMessages(ctx context.Context, messages []Message) (Response, error) { apiMessages := make([]map[string]string, 0, len(messages)) for _, msg := range messages { apiMessages = append(apiMessages, map[string]string{"role": msg.Role, "content": msg.Content}) @@ -260,6 +274,11 @@ func (m HTTPModel) completeOpenAIMessages(ctx context.Context, messages []Messag Content string `json:"content"` } `json:"message"` } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { req.Header.Set("Authorization", "Bearer "+m.apiKey) @@ -268,15 +287,15 @@ func (m HTTPModel) completeOpenAIMessages(ctx context.Context, messages []Messag req.Header.Set("X-Title", "xmem-go") } }); err != nil { - return "", err + return Response{}, err } if len(out.Choices) == 0 { - return "", errors.New("empty model response") + return Response{}, errors.New("empty model response") } - return out.Choices[0].Message.Content, nil + return Response{Content: out.Choices[0].Message.Content, ModelName: m.model, InputTokens: out.Usage.PromptTokens, OutputTokens: out.Usage.CompletionTokens, TotalTokens: out.Usage.TotalTokens}, nil } -func (m HTTPModel) completeGeminiMessages(ctx context.Context, messages []Message) (string, error) { +func (m HTTPModel) completeGeminiMessages(ctx context.Context, messages []Message) (Response, error) { var systemText string contents := []map[string]any{} for _, msg := range messages { @@ -310,17 +329,22 @@ func (m HTTPModel) completeGeminiMessages(ctx context.Context, messages []Messag } `json:"parts"` } `json:"content"` } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, nil); err != nil { - return "", err + return Response{}, err } if len(out.Candidates) == 0 || len(out.Candidates[0].Content.Parts) == 0 { - return "", errors.New("empty gemini response") + return Response{}, errors.New("empty gemini response") } - return out.Candidates[0].Content.Parts[0].Text, nil + return Response{Content: out.Candidates[0].Content.Parts[0].Text, ModelName: m.model, InputTokens: out.UsageMetadata.PromptTokenCount, OutputTokens: out.UsageMetadata.CandidatesTokenCount, TotalTokens: out.UsageMetadata.TotalTokenCount}, nil } -func (m HTTPModel) completeClaudeMessages(ctx context.Context, messages []Message) (string, error) { +func (m HTTPModel) completeClaudeMessages(ctx context.Context, messages []Message) (Response, error) { var systemText string apiMessages := []map[string]string{} for _, msg := range messages { @@ -342,20 +366,24 @@ func (m HTTPModel) completeClaudeMessages(ctx context.Context, messages []Messag Content []struct { Text string `json:"text"` } `json:"content"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { req.Header.Set("x-api-key", m.apiKey) req.Header.Set("anthropic-version", "2023-06-01") }); err != nil { - return "", err + return Response{}, err } if len(out.Content) == 0 { - return "", errors.New("empty claude response") + return Response{}, errors.New("empty claude response") } - return out.Content[0].Text, nil + return Response{Content: out.Content[0].Text, ModelName: m.model, InputTokens: out.Usage.InputTokens, OutputTokens: out.Usage.OutputTokens, TotalTokens: out.Usage.InputTokens + out.Usage.OutputTokens}, nil } -func (m HTTPModel) completeOllamaMessages(ctx context.Context, messages []Message) (string, error) { +func (m HTTPModel) completeOllamaMessages(ctx context.Context, messages []Message) (Response, error) { apiMessages := make([]map[string]string, 0, len(messages)) for _, msg := range messages { apiMessages = append(apiMessages, map[string]string{"role": msg.Role, "content": msg.Content}) @@ -369,11 +397,13 @@ func (m HTTPModel) completeOllamaMessages(ctx context.Context, messages []Messag Message struct { Content string `json:"content"` } `json:"message"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, nil); err != nil { - return "", err + return Response{}, err } - return out.Message.Content, nil + return Response{Content: out.Message.Content, ModelName: m.model, InputTokens: out.PromptEvalCount, OutputTokens: out.EvalCount, TotalTokens: out.PromptEvalCount + out.EvalCount}, nil } func (m HTTPModel) SelectTools(ctx context.Context, query string, catalog []map[string]string) (Response, error) { @@ -383,25 +413,27 @@ Return only JSON in this exact shape: Allowed names: search_profile, search_temporal, search_summary, search_snippet. Available profile catalog: ` + MarshalJSON(catalog) + ` Query: ` + query - content, err := m.complete(ctx, prompt, true) + resp, err := m.complete(ctx, prompt, true) if err != nil { - return m.local.SelectTools(ctx, query, catalog) + return Response{}, err } var parsed struct { ToolCalls []ToolCall `json:"tool_calls"` } - if err := json.Unmarshal([]byte(extractJSONObject(content)), &parsed); err != nil || len(parsed.ToolCalls) == 0 { - return m.local.SelectTools(ctx, query, catalog) + if err := json.Unmarshal([]byte(extractJSONObject(resp.Content)), &parsed); err != nil || len(parsed.ToolCalls) == 0 { + return resp, nil } for i := range parsed.ToolCalls { if parsed.ToolCalls[i].ID == "" { parsed.ToolCalls[i].ID = fmt.Sprintf("call-%d", i+1) } } - return Response{ToolCalls: parsed.ToolCalls, ModelName: m.model}, nil + resp.ToolCalls = parsed.ToolCalls + resp.Content = "" + return resp, nil } -func (m HTTPModel) complete(ctx context.Context, prompt string, jsonMode bool) (string, error) { +func (m HTTPModel) complete(ctx context.Context, prompt string, jsonMode bool) (Response, error) { switch m.provider { case "openai", "openrouter": return m.completeOpenAI(ctx, prompt, jsonMode) @@ -412,11 +444,11 @@ func (m HTTPModel) complete(ctx context.Context, prompt string, jsonMode bool) ( case "ollama": return m.completeOllama(ctx, prompt, jsonMode) default: - return "", errors.New("unsupported provider") + return Response{}, errors.New("unsupported provider") } } -func (m HTTPModel) completeOpenAI(ctx context.Context, prompt string, jsonMode bool) (string, error) { +func (m HTTPModel) completeOpenAI(ctx context.Context, prompt string, jsonMode bool) (Response, error) { body := map[string]any{ "model": m.model, "messages": []map[string]string{ @@ -433,6 +465,11 @@ func (m HTTPModel) completeOpenAI(ctx context.Context, prompt string, jsonMode b Content string `json:"content"` } `json:"message"` } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { req.Header.Set("Authorization", "Bearer "+m.apiKey) @@ -441,15 +478,15 @@ func (m HTTPModel) completeOpenAI(ctx context.Context, prompt string, jsonMode b req.Header.Set("X-Title", "xmem-go") } }); err != nil { - return "", err + return Response{}, err } if len(out.Choices) == 0 { - return "", errors.New("empty model response") + return Response{}, errors.New("empty model response") } - return out.Choices[0].Message.Content, nil + return Response{Content: out.Choices[0].Message.Content, ModelName: m.model, InputTokens: out.Usage.PromptTokens, OutputTokens: out.Usage.CompletionTokens, TotalTokens: out.Usage.TotalTokens}, nil } -func (m HTTPModel) completeGemini(ctx context.Context, prompt string) (string, error) { +func (m HTTPModel) completeGemini(ctx context.Context, prompt string) (Response, error) { body := map[string]any{ "contents": []map[string]any{{"parts": []map[string]string{{"text": prompt}}}}, "generationConfig": map[string]any{ @@ -464,17 +501,22 @@ func (m HTTPModel) completeGemini(ctx context.Context, prompt string) (string, e } `json:"parts"` } `json:"content"` } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, nil); err != nil { - return "", err + return Response{}, err } if len(out.Candidates) == 0 || len(out.Candidates[0].Content.Parts) == 0 { - return "", errors.New("empty gemini response") + return Response{}, errors.New("empty gemini response") } - return out.Candidates[0].Content.Parts[0].Text, nil + return Response{Content: out.Candidates[0].Content.Parts[0].Text, ModelName: m.model, InputTokens: out.UsageMetadata.PromptTokenCount, OutputTokens: out.UsageMetadata.CandidatesTokenCount, TotalTokens: out.UsageMetadata.TotalTokenCount}, nil } -func (m HTTPModel) completeClaude(ctx context.Context, prompt string) (string, error) { +func (m HTTPModel) completeClaude(ctx context.Context, prompt string) (Response, error) { body := map[string]any{ "model": m.model, "max_tokens": 1024, @@ -486,20 +528,24 @@ func (m HTTPModel) completeClaude(ctx context.Context, prompt string) (string, e Content []struct { Text string `json:"text"` } `json:"content"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, func(req *http.Request) { req.Header.Set("x-api-key", m.apiKey) req.Header.Set("anthropic-version", "2023-06-01") }); err != nil { - return "", err + return Response{}, err } if len(out.Content) == 0 { - return "", errors.New("empty claude response") + return Response{}, errors.New("empty claude response") } - return out.Content[0].Text, nil + return Response{Content: out.Content[0].Text, ModelName: m.model, InputTokens: out.Usage.InputTokens, OutputTokens: out.Usage.OutputTokens, TotalTokens: out.Usage.InputTokens + out.Usage.OutputTokens}, nil } -func (m HTTPModel) completeOllama(ctx context.Context, prompt string, jsonMode bool) (string, error) { +func (m HTTPModel) completeOllama(ctx context.Context, prompt string, jsonMode bool) (Response, error) { body := map[string]any{ "model": m.model, "stream": false, @@ -514,11 +560,13 @@ func (m HTTPModel) completeOllama(ctx context.Context, prompt string, jsonMode b Message struct { Content string `json:"content"` } `json:"message"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` } if err := m.doJSON(ctx, http.MethodPost, m.url, body, &out, nil); err != nil { - return "", err + return Response{}, err } - return out.Message.Content, nil + return Response{Content: out.Message.Content, ModelName: m.model, InputTokens: out.PromptEvalCount, OutputTokens: out.EvalCount, TotalTokens: out.PromptEvalCount + out.EvalCount}, nil } func (m HTTPModel) doJSON(ctx context.Context, method, url string, body any, out any, decorate func(*http.Request)) error { diff --git a/xmem-go/internal/pipelines/ingest.go b/xmem-go/internal/pipelines/ingest.go index 2c44e80..2c649d2 100644 --- a/xmem-go/internal/pipelines/ingest.go +++ b/xmem-go/internal/pipelines/ingest.go @@ -2,6 +2,7 @@ package pipelines import ( "context" + "strings" "sync" "github.com/xortexai/xmem-go/internal/agents" @@ -17,6 +18,7 @@ type IngestPipeline struct { Temporal agents.TemporalAgent Summarizer agents.SummarizerAgent Image agents.ImageAgent + Code agents.CodeAgent Snippet agents.SnippetAgent Judge agents.JudgeAgent } @@ -69,18 +71,27 @@ func (p *IngestPipeline) invoke(ctx context.Context, req contracts.IngestRequest state.Classification = p.Classifier.Run(ctx, req.UserQuery, req.ImageURL) hasProfile, hasTemporal, hasCode, hasImage := false, false, false, false + profileQueries, temporalQueries, codeQueries, imageQueries := []string{}, []string{}, []string{}, []string{} for _, c := range state.Classification { switch c.Source { case "profile": hasProfile = true + profileQueries = append(profileQueries, c.Query) case "event": hasTemporal = true + temporalQueries = append(temporalQueries, c.Query) case "code": hasCode = true + codeQueries = append(codeQueries, c.Query) case "image": hasImage = true + imageQueries = append(imageQueries, c.Query) } } + if req.ImageURL != "" && len(imageQueries) == 0 { + hasImage = true + imageQueries = append(imageQueries, "Analyze this image for memory-relevant details.") + } isTrivial := len(splitWords(req.UserQuery)) < 4 && !hasProfile && !hasTemporal && !hasCode && !hasImage var wg sync.WaitGroup @@ -105,28 +116,28 @@ func (p *IngestPipeline) invoke(ctx context.Context, req contracts.IngestRequest } if hasProfile { run(func() IngestState { - facts := p.Profiler.Run(ctx, req.UserQuery) - judge := p.Judge.JudgeProfile(ctx, facts) + facts := p.Profiler.Run(ctx, strings.Join(profileQueries, " ")) + judge := p.Judge.JudgeProfile(ctx, facts, userID) return IngestState{ProfileJudge: judge, ProfileWeaver: p.Weaver.Execute(ctx, judge, weaver.DomainProfile, userID)} }) } if hasTemporal { run(func() IngestState { - events := p.Temporal.Run(ctx, req.UserQuery, req.SessionDatetime) - judge := p.Judge.JudgeTemporal(ctx, events) + events := p.Temporal.Run(ctx, strings.Join(temporalQueries, " "), req.SessionDatetime) + judge := p.Judge.JudgeTemporal(ctx, events, userID) return IngestState{TemporalJudge: judge, TemporalWeaver: p.Weaver.Execute(ctx, judge, weaver.DomainTemporal, userID)} }) } if hasImage { run(func() IngestState { - items := p.Image.Run(ctx, req.ImageURL) + items := p.Image.Run(ctx, strings.Join(imageQueries, " "), req.ImageURL) judge := p.Judge.JudgeItems(ctx, weaver.DomainSummary, items, userID, 0.8) return IngestState{ImageJudge: judge, ImageWeaver: p.Weaver.Execute(ctx, judge, weaver.DomainSummary, userID)} }) } if hasCode { run(func() IngestState { - items := p.Snippet.Run(ctx, req.UserQuery) + items := p.Snippet.Run(ctx, strings.Join(codeQueries, " ")) judge := p.Judge.JudgeItems(ctx, weaver.DomainSnippet, items, userID, 0.8) return IngestState{SnippetJudge: judge, SnippetWeaver: p.Weaver.Execute(ctx, judge, weaver.DomainSnippet, userID)} }) diff --git a/xmem-go/internal/pipelines/retrieval.go b/xmem-go/internal/pipelines/retrieval.go index b6713da..eae03ae 100644 --- a/xmem-go/internal/pipelines/retrieval.go +++ b/xmem-go/internal/pipelines/retrieval.go @@ -2,6 +2,7 @@ package pipelines import ( "context" + json "github.com/goccy/go-json" "fmt" "math" "strings" @@ -29,17 +30,17 @@ func (p *RetrievalPipeline) Run(ctx context.Context, req contracts.RetrieveReque catalogStr := formatProfileCatalog(catalog) systemPrompt := prompts.BuildRetrievalSystemPrompt(catalogStr) - toolResp, err := p.Model.SelectTools(ctx, req.Query, catalog) + toolResp, err := selectToolsWithRetrievalPrompt(ctx, p.Model, req.Query, catalog, systemPrompt) if err != nil { return contracts.RetrieveResponse{}, err } - _ = systemPrompt - var mu sync.Mutex var wg sync.WaitGroup sources := []contracts.SourceRecord{} + toolResults := make([][]contracts.SourceRecord, len(toolResp.ToolCalls)) calledSummary := false - for _, call := range toolResp.ToolCalls { + for i, call := range toolResp.ToolCalls { + i := i call := call if normalizeToolName(call.Name) == "searchsummary" { calledSummary = true @@ -47,18 +48,27 @@ func (p *RetrievalPipeline) Run(ctx context.Context, req contracts.RetrieveReque wg.Add(1) go func() { defer wg.Done() - records := p.executeTool(ctx, call, req.Query, userID, req.TopK, profileRecords) - mu.Lock() - sources = append(sources, records...) - mu.Unlock() + toolResults[i] = p.executeTool(ctx, call, req.Query, userID, req.TopK, profileRecords) }() } wg.Wait() + if len(toolResp.ToolCalls) == 0 { + answer := strings.TrimSpace(toolResp.Content) + return contracts.RetrieveResponse{Model: p.Model.Name(), Answer: answer, Sources: sources, Confidence: 0.1}, nil + } + + contextBlocks := make([]string, 0, len(toolResults)+1) + for _, records := range toolResults { + sources = append(sources, records...) + contextBlocks = append(contextBlocks, formatToolResults(records)) + } if !calledSummary { - sources = append(sources, p.searchSummary(ctx, req.Query, userID, 20)...) + extra := p.searchSummary(ctx, req.Query, userID, 20) + sources = append(sources, extra...) + contextBlocks = append(contextBlocks, "[Auto-fetched summary context]\n"+formatToolResults(extra)) } - contextText := formatSources(sources) + contextText := strings.Join(contextBlocks, "\n") answerPrompt := prompts.BuildAnswerPrompt(contextText, req.Query) answerResp, err := p.Model.GenerateWithMessages(ctx, []models.Message{ {Role: "user", Content: answerPrompt}, @@ -73,6 +83,66 @@ func (p *RetrievalPipeline) Run(ctx context.Context, req contracts.RetrieveReque return contracts.RetrieveResponse{Model: p.Model.Name(), Answer: answerResp.Content, Sources: sources, Confidence: confidence}, nil } +func selectToolsWithRetrievalPrompt(ctx context.Context, model models.ChatModel, query string, catalog []map[string]string, systemPrompt string) (models.Response, error) { + messages := []models.Message{ + {Role: "system", Content: systemPrompt + "\n\n" + toolSelectionOutputInstructions}, + {Role: "user", Content: query}, + } + resp, err := model.GenerateWithMessages(ctx, messages) + if err != nil { + return model.SelectTools(ctx, query, catalog) + } + + toolCalls, directAnswer, ok := parseToolCalls(resp.Content) + if !ok { + fallbackResp, fallbackErr := model.SelectTools(ctx, query, catalog) + if fallbackErr != nil { + return resp, nil + } + return fallbackResp, nil + } + resp.ToolCalls = toolCalls + if directAnswer != "" { + resp.Content = directAnswer + } + return resp, nil +} + +const toolSelectionOutputInstructions = `You have access to these retrieval tools: search_profile, search_temporal, search_summary, search_snippet. +Return only JSON in this exact shape, with no markdown or commentary: +{"tool_calls":[{"name":"search_profile","args":{"topic":"work"}},{"name":"search_temporal","args":{"query":"dentist appointment"}},{"name":"search_summary","args":{"query":"..."}},{"name":"search_snippet","args":{"query":"..."}}]} +If you can answer directly without searching, return {"tool_calls":[],"answer":"..."}.` + +func parseToolCalls(text string) ([]models.ToolCall, string, bool) { + if strings.Contains(text, "Return only JSON in this exact shape") { + return nil, "", false + } + + var parsed struct { + ToolCalls []models.ToolCall `json:"tool_calls"` + Answer string `json:"answer"` + } + if err := json.Unmarshal([]byte(extractJSONObject(text)), &parsed); err != nil { + return nil, "", false + } + for i := range parsed.ToolCalls { + if parsed.ToolCalls[i].ID == "" { + parsed.ToolCalls[i].ID = fmt.Sprintf("call-%d", i+1) + } + } + return parsed.ToolCalls, parsed.Answer, true +} + +func extractJSONObject(text string) string { + text = strings.TrimSpace(text) + start := strings.Index(text, "{") + end := strings.LastIndex(text, "}") + if start >= 0 && end > start { + return text[start : end+1] + } + return "" +} + func (p *RetrievalPipeline) Search(ctx context.Context, req contracts.SearchRequest, userID string) (contracts.SearchResponse, error) { if req.TopK == 0 { req.TopK = 10 @@ -124,7 +194,7 @@ func (p *RetrievalPipeline) executeTool(ctx context.Context, call models.ToolCal store = p.VectorStore } records, _ := store.SearchByText(ctx, q, 5, map[string]any{"domain": "snippet"}) - return toSources("snippet", records) + return toSnippetSources(records) default: return nil } @@ -163,6 +233,13 @@ func searchProfile(topic string, records []storage.SearchResult) []contracts.Sou } meta := cloneMeta(record.Metadata) meta["id"] = record.ID + meta["topic"] = topic + parts := strings.SplitN(main, "_", 2) + if len(parts) == 2 { + meta["sub_topic"] = parts[1] + } else { + meta["sub_topic"] = "" + } out = append(out, contracts.SourceRecord{Domain: "profile", Content: record.Content, Score: round3(record.Score), Metadata: meta}) } return out @@ -186,6 +263,21 @@ func toSources(domain string, records []storage.SearchResult) []contracts.Source return out } +func toSnippetSources(records []storage.SearchResult) []contracts.SourceRecord { + out := make([]contracts.SourceRecord, 0, len(records)) + for _, record := range records { + meta := cloneMeta(record.Metadata) + meta["id"] = record.ID + content := record.Content + if snippet, ok := record.Metadata["code_snippet"].(string); ok && snippet != "" { + lang, _ := record.Metadata["language"].(string) + content += fmt.Sprintf("\n```%s\n%s\n```", lang, snippet) + } + out = append(out, contracts.SourceRecord{Domain: "snippet", Content: content, Score: round3(record.Score), Metadata: meta}) + } + return out +} + func eventsToSources(events []graph.Event) []contracts.SourceRecord { out := make([]contracts.SourceRecord, 0, len(events)) for _, ev := range events { @@ -220,13 +312,17 @@ func eventsToSources(events []graph.Event) []contracts.SourceRecord { return out } -func formatSources(sources []contracts.SourceRecord) string { - if len(sources) == 0 { +func formatToolResults(records []contracts.SourceRecord) string { + if len(records) == 0 { return "No results found." } - lines := make([]string, 0, len(sources)) - for i, src := range sources { - lines = append(lines, fmt.Sprintf("%d. [%s] %s", i+1, src.Domain, src.Content)) + lines := make([]string, 0, len(records)) + for i, rec := range records { + score := "" + if rec.Score > 0 { + score = fmt.Sprintf(" (score: %.2f)", rec.Score) + } + lines = append(lines, fmt.Sprintf("%d. [%s]%s %s", i+1, rec.Domain, score, rec.Content)) } return strings.Join(lines, "\n") } @@ -255,11 +351,7 @@ func formatProfileCatalog(catalog []map[string]string) string { for _, item := range catalog { topic := item["topic"] subTopic := item["sub_topic"] - if subTopic != "" { - lines = append(lines, fmt.Sprintf(" - %s / %s", topic, subTopic)) - } else { - lines = append(lines, fmt.Sprintf(" - %s", topic)) - } + lines = append(lines, fmt.Sprintf(" - %s / %s", topic, subTopic)) } return strings.Join(lines, "\n") } diff --git a/xmem-go/internal/prompts/judge.go b/xmem-go/internal/prompts/judge.go index dbdce6e..31c31c3 100644 --- a/xmem-go/internal/prompts/judge.go +++ b/xmem-go/internal/prompts/judge.go @@ -225,7 +225,7 @@ func PackJudgeQuery(newItems []string, similarExisting []string, domain string) } newBlock := strings.Join(numberedItems, "\n") - similarBlock := "(No similar records found — vector store is empty or search returned nothing)" + similarBlock := "(No similar records found — store is empty or search returned nothing)" if len(similarExisting) > 0 { similarBlock = strings.Join(similarExisting, "\n") } diff --git a/xmem-go/internal/prompts/temporal.go b/xmem-go/internal/prompts/temporal.go index 35951fb..b41401a 100644 --- a/xmem-go/internal/prompts/temporal.go +++ b/xmem-go/internal/prompts/temporal.go @@ -61,6 +61,11 @@ var temporalExamples = []struct { "2:35 pm on 16 March, 2023", "DATE: 03-09\nEVENT_NAME: Started Gym\nYEAR: 2023\nDESC: Started going to the gym\nTIME: \nDATE_EXPRESSION: last week", }, + { + "I started a new job at Vercel as a frontend developer today!", + "4:00 pm on 20 May, 2026", + "DATE: 05-20\nEVENT_NAME: Started New Job\nYEAR: 2026\nDESC: Started a new job at Vercel as a frontend developer\nTIME: \nDATE_EXPRESSION: today", + }, { "I'm getting ready for a dance comp near me next month.", "10:43 am on 4 February, 2023", @@ -118,6 +123,7 @@ Your task is to extract ALL structured temporal event information from user inpu You will be given a CONTEXT_DATE which is the date/time when the conversation occurred. Use this to resolve relative expressions: +- "today" → use CONTEXT_DATE - "yesterday" → subtract 1 day from CONTEXT_DATE - "tomorrow" → add 1 day to CONTEXT_DATE - "next Friday" → find the next Friday after CONTEXT_DATE diff --git a/xmem-go/internal/storage/openai.go b/xmem-go/internal/storage/openai.go index 35ba9e2..24ccbe0 100644 --- a/xmem-go/internal/storage/openai.go +++ b/xmem-go/internal/storage/openai.go @@ -3,7 +3,7 @@ package storage import ( "bytes" "context" - "encoding/json" + json "github.com/goccy/go-json" "errors" "fmt" "io" diff --git a/xmem-go/internal/storage/pinecone.go b/xmem-go/internal/storage/pinecone.go index fcf468f..11ab753 100644 --- a/xmem-go/internal/storage/pinecone.go +++ b/xmem-go/internal/storage/pinecone.go @@ -3,7 +3,7 @@ package storage import ( "bytes" "context" - "encoding/json" + json "github.com/goccy/go-json" "errors" "fmt" "io" @@ -12,6 +12,7 @@ import ( "time" "github.com/xortexai/xmem-go/internal/config" + "github.com/xortexai/xmem-go/internal/utils" ) type PineconeVectorStore struct { @@ -171,6 +172,12 @@ func (s *PineconeVectorStore) resolveHost(ctx context.Context) (string, error) { } func (s *PineconeVectorStore) do(ctx context.Context, method string, path string, body any, out any) error { + return utils.RetryWithBackoff(ctx, 3, time.Second, func() error { + return s.doOnce(ctx, method, path, body, out) + }) +} + +func (s *PineconeVectorStore) doOnce(ctx context.Context, method string, path string, body any, out any) error { var reader io.Reader if body != nil { b, err := json.Marshal(body) diff --git a/xmem-go/internal/utils/retry.go b/xmem-go/internal/utils/retry.go new file mode 100644 index 0000000..dc94072 --- /dev/null +++ b/xmem-go/internal/utils/retry.go @@ -0,0 +1,43 @@ +package utils + +import ( + "context" + "strings" + "time" +) + +func RetryWithBackoff(ctx context.Context, maxRetries int, baseDelay time.Duration, fn func() error) error { + var lastErr error + for attempt := 0; attempt < maxRetries; attempt++ { + lastErr = fn() + if lastErr == nil { + return nil + } + if !isRetryable(lastErr) { + return lastErr + } + if attempt < maxRetries-1 { + delay := baseDelay * (1 << uint(attempt)) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + } + } + } + return lastErr +} + +func isRetryable(err error) bool { + msg := strings.ToLower(err.Error()) + retryableKeywords := []string{ + "connection", "timeout", "eof", "reset", "refused", + "ssl", "routing", "temporary", "unavailable", "503", "429", + } + for _, kw := range retryableKeywords { + if strings.Contains(msg, kw) { + return true + } + } + return false +}