From b60eb7aa727dfc370c9c797fac322effac712c86 Mon Sep 17 00:00:00 2001 From: Test User Date: Sun, 5 Apr 2026 22:47:29 -0700 Subject: [PATCH 1/3] feat(adapters): abstract streaming layer behind LLMProvider (#548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add StreamChunk dataclass to base.py as normalized streaming type - Add async_stream() + supports() to LLMProvider ABC - Implement async_stream() in AnthropicProvider (moves streaming out of streaming_chat.py; supports("extended_thinking") returns True) - Implement async_stream() in OpenAIProvider (SSE → StreamChunk) - Add async_stream() to MockProvider with configurable StreamChunk sequences - Refactor StreamingChatAdapter to accept LLMProvider; no import anthropic - Update session_chat_ws.py to construct AnthropicProvider explicitly - Export StreamChunk from adapters.llm.__init__ - Update streaming chat tests to use MockProvider instead of _async_client patch - Fix test_e2b_adapter.py to skip TestE2BAgentAdapter when e2b not installed --- codeframe/adapters/llm/__init__.py | 2 + codeframe/adapters/llm/anthropic.py | 93 +++- codeframe/adapters/llm/base.py | 80 +++- codeframe/adapters/llm/mock.py | 50 ++- codeframe/adapters/llm/openai.py | 128 +++++- codeframe/core/adapters/streaming_chat.py | 195 ++++----- codeframe/ui/routers/session_chat_ws.py | 3 + tests/adapters/test_e2b_adapter.py | 8 + tests/core/test_streaming_chat.py | 503 ++++++++-------------- 9 files changed, 622 insertions(+), 440 deletions(-) diff --git a/codeframe/adapters/llm/__init__.py b/codeframe/adapters/llm/__init__.py index b70adf6d..b94ab347 100644 --- a/codeframe/adapters/llm/__init__.py +++ b/codeframe/adapters/llm/__init__.py @@ -26,6 +26,7 @@ Message, ModelSelector, Purpose, + StreamChunk, Tool, ToolCall, ToolResult, @@ -40,6 +41,7 @@ "Message", "ModelSelector", "Purpose", + "StreamChunk", "Tool", "ToolCall", "ToolResult", diff --git a/codeframe/adapters/llm/anthropic.py b/codeframe/adapters/llm/anthropic.py index 19fb60d4..7b7d24d2 100644 --- a/codeframe/adapters/llm/anthropic.py +++ b/codeframe/adapters/llm/anthropic.py @@ -3,14 +3,16 @@ Provides Claude model access via the Anthropic API. """ +import asyncio import os -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional from codeframe.adapters.llm.base import ( LLMProvider, LLMResponse, ModelSelector, Purpose, + StreamChunk, Tool, ToolCall, ) @@ -172,6 +174,95 @@ async def async_complete( except APIConnectionError as exc: raise LLMConnectionError(str(exc)) from exc + def supports(self, capability: str) -> bool: + """Return True for capabilities this provider supports.""" + return capability == "extended_thinking" + + async def async_stream( + self, + messages: list[dict], + system: str, + tools: list[dict], + model: str, + max_tokens: int, + interrupt_event: Optional[asyncio.Event] = None, + ) -> AsyncIterator[StreamChunk]: + """Stream using Anthropic AsyncAnthropic SDK, yielding StreamChunk objects. + + Translates Anthropic SDK events into the normalized StreamChunk format. + Tool inputs are collected and emitted in the final message_stop chunk + via tool_inputs_by_id, which is more reliable than streaming input deltas. + """ + from anthropic import AsyncAnthropic + + if self._async_client is None: + self._async_client = AsyncAnthropic(api_key=self.api_key) + + kwargs: dict = { + "model": model, + "system": system, + "messages": messages, + "tools": tools, + "max_tokens": max_tokens, + } + + active_tool_id: Optional[str] = None + + async with self._async_client.messages.stream(**kwargs) as stream: + async for sdk_event in stream: + if interrupt_event and interrupt_event.is_set(): + return + + event_type = sdk_event.type + + if event_type == "content_block_start": + block = sdk_event.content_block + if block.type == "tool_use": + active_tool_id = block.id + yield StreamChunk( + type="tool_use_start", + tool_id=block.id, + tool_name=block.name, + tool_input=getattr(block, "input", {}), + ) + + elif event_type == "content_block_delta": + delta = sdk_event.delta + if delta.type == "text_delta": + yield StreamChunk(type="text_delta", text=delta.text) + elif delta.type == "thinking_delta": + yield StreamChunk(type="thinking_delta", text=delta.thinking) + # input_json_delta: final inputs are rebuilt from message_stop + + elif event_type == "content_block_stop": + if active_tool_id is not None: + yield StreamChunk(type="tool_use_stop") + active_tool_id = None + + elif event_type == "message_stop": + # Flush any open tool block + if active_tool_id is not None: + yield StreamChunk(type="tool_use_stop") + active_tool_id = None + + final_msg = stream.get_final_message() + stop_reason = final_msg.stop_reason or "end_turn" + + # Build tool_inputs_by_id from final content blocks + tool_inputs_by_id: dict = {} + if hasattr(final_msg, "content"): + for block in final_msg.content: + if getattr(block, "type", None) == "tool_use" and hasattr(block, "id"): + tool_inputs_by_id[block.id] = getattr(block, "input", {}) + + yield StreamChunk( + type="message_stop", + stop_reason=stop_reason, + input_tokens=final_msg.usage.input_tokens, + output_tokens=final_msg.usage.output_tokens, + tool_inputs_by_id=tool_inputs_by_id, + ) + def stream( self, messages: list[dict], diff --git a/codeframe/adapters/llm/base.py b/codeframe/adapters/llm/base.py index a88b12b4..104866a2 100644 --- a/codeframe/adapters/llm/base.py +++ b/codeframe/adapters/llm/base.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Iterator, Optional +from typing import AsyncIterator, Iterator, Optional # --------------------------------------------------------------------------- @@ -120,6 +120,40 @@ def for_purpose(self, purpose: Purpose) -> str: return self.execution_model # Default fallback +@dataclass +class StreamChunk: + """A normalized chunk from a streaming LLM response. + + Provider-specific streaming formats are translated into this common type + by each :class:`LLMProvider` implementation. + + Attributes: + type: Event type — one of ``"text_delta"``, ``"thinking_delta"``, + ``"tool_use_start"``, ``"tool_use_stop"``, ``"message_stop"``. + text: Text content for ``text_delta`` and ``thinking_delta`` types. + tool_id: Tool call ID for ``tool_use_start``. + tool_name: Tool name for ``tool_use_start``. + tool_input: Tool input dict for ``tool_use_start`` (may be empty; + final inputs are provided in the ``message_stop`` chunk). + input_tokens: Input token count, populated for ``message_stop``. + output_tokens: Output token count, populated for ``message_stop``. + stop_reason: Why the model stopped, populated for ``message_stop``. + tool_inputs_by_id: Mapping of tool_id → final input dict, populated + for ``message_stop``. More reliable than streaming incremental + input deltas. + """ + + type: str + text: Optional[str] = None + tool_id: Optional[str] = None + tool_name: Optional[str] = None + tool_input: Optional[dict] = None + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + stop_reason: Optional[str] = None + tool_inputs_by_id: Optional[dict] = None + + @dataclass class ToolCall: """Represents a tool call requested by the LLM. @@ -332,6 +366,50 @@ async def async_complete( lambda: self.complete(messages, purpose, tools, max_tokens, temperature, system), ) + def supports(self, capability: str) -> bool: + """Check whether this provider supports an optional capability. + + Args: + capability: Capability name, e.g. ``"extended_thinking"``. + + Returns: + ``True`` if the capability is supported, ``False`` otherwise. + """ + return False + + async def async_stream( + self, + messages: list[dict], + system: str, + tools: list[dict], + model: str, + max_tokens: int, + interrupt_event: Optional[asyncio.Event] = None, + ) -> AsyncIterator["StreamChunk"]: + """Stream a completion as normalized :class:`StreamChunk` objects. + + Subclasses should override this with a provider-specific implementation. + The default raises :exc:`NotImplementedError`. + + Args: + messages: Conversation messages in the provider's expected format. + system: System prompt string. + tools: Already-serialized tool definitions (list of dicts). + model: Model identifier to use for this call. + max_tokens: Maximum output tokens. + interrupt_event: When set, the stream should stop at the next + opportunity. + + Yields: + :class:`StreamChunk` objects in order of generation. + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement async_stream(). " + "Override this method in your provider subclass." + ) + if False: # pragma: no cover # makes this an async generator + yield # type: ignore[misc] + def get_model(self, purpose: Purpose) -> str: """Get the model for a given purpose. diff --git a/codeframe/adapters/llm/mock.py b/codeframe/adapters/llm/mock.py index 501190f0..408b04b6 100644 --- a/codeframe/adapters/llm/mock.py +++ b/codeframe/adapters/llm/mock.py @@ -4,13 +4,15 @@ Supports configurable responses and call tracking. """ -from typing import Callable, Iterator, Optional +import asyncio +from typing import AsyncIterator, Callable, Iterator, Optional from codeframe.adapters.llm.base import ( LLMProvider, LLMResponse, ModelSelector, Purpose, + StreamChunk, Tool, ToolCall, ) @@ -40,6 +42,8 @@ def __init__( self.responses: list[LLMResponse] = [] self.response_index = 0 self.response_handler: Optional[Callable[[list[dict]], LLMResponse]] = None + self.stream_chunks: list[list[StreamChunk]] = [] + self.stream_index = 0 def add_response(self, response: LLMResponse) -> None: """Add a canned response to the queue. @@ -175,12 +179,56 @@ def stream( for word in response.content.split(): yield word + " " + def add_stream_chunks(self, chunks: list[StreamChunk]) -> None: + """Add a sequence of StreamChunks for the next async_stream() call. + + Args: + chunks: Ordered list of StreamChunk objects to yield. + """ + self.stream_chunks.append(chunks) + + async def async_stream( + self, + messages: list[dict], + system: str, + tools: list[dict], + model: str, + max_tokens: int, + interrupt_event: Optional[asyncio.Event] = None, + ) -> AsyncIterator[StreamChunk]: + """Yield pre-configured StreamChunk sequences for testing. + + Falls back to a minimal text_delta + message_stop pair when no + stream chunks have been configured via add_stream_chunks(). + """ + if self.stream_index < len(self.stream_chunks): + chunks = self.stream_chunks[self.stream_index] + self.stream_index += 1 + else: + # Default: simple text response followed by message_stop + chunks = [ + StreamChunk(type="text_delta", text=self.default_response), + StreamChunk( + type="message_stop", + stop_reason="end_turn", + input_tokens=len(str(messages)), + output_tokens=len(self.default_response), + tool_inputs_by_id={}, + ), + ] + for chunk in chunks: + if interrupt_event and interrupt_event.is_set(): + return + yield chunk + def reset(self) -> None: """Reset call tracking and response queue.""" self.calls.clear() self.responses.clear() self.response_index = 0 self.response_handler = None + self.stream_chunks.clear() + self.stream_index = 0 @property def call_count(self) -> int: diff --git a/codeframe/adapters/llm/openai.py b/codeframe/adapters/llm/openai.py index 858a28c4..c17b8b93 100644 --- a/codeframe/adapters/llm/openai.py +++ b/codeframe/adapters/llm/openai.py @@ -4,9 +4,10 @@ (Ollama, vLLM, LM Studio, Groq, Together, etc.) via the openai SDK. """ +import asyncio import json import os -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional import openai @@ -15,6 +16,7 @@ LLMResponse, ModelSelector, Purpose, + StreamChunk, Tool, ToolCall, ) @@ -196,6 +198,130 @@ async def async_complete( except _openai.APIConnectionError as exc: raise LLMConnectionError(str(exc)) from exc + async def async_stream( + self, + messages: list[dict], + system: str, + tools: list[dict], + model: str, + max_tokens: int, + interrupt_event: Optional[asyncio.Event] = None, + ) -> AsyncIterator[StreamChunk]: + """Stream using OpenAI async client, yielding StreamChunk objects. + + Translates OpenAI SSE chunks into the normalized StreamChunk format. + Tool calls are emitted as tool_use_start chunks; final inputs are + collected and emitted in the message_stop chunk via tool_inputs_by_id. + """ + import openai as _openai + + if self._async_client is None: + self._async_client = _openai.AsyncOpenAI( + api_key=self.api_key, base_url=self.base_url + ) + + converted = self._convert_messages(messages) + if system: + converted = [{"role": "system", "content": system}] + converted + + kwargs: dict = { + "model": model, + "max_tokens": max_tokens, + "messages": converted, + "stream": True, + "stream_options": {"include_usage": True}, + } + + # Convert tools from Anthropic-style (input_schema) to OpenAI format + if tools: + kwargs["tools"] = [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("input_schema", {}), + }, + } + for t in tools + ] + kwargs["tool_choice"] = "auto" + + # Track partial tool calls across chunks (OpenAI streams them incrementally) + # key: index → {id, name, arguments_parts} + partial_tool_calls: dict[int, dict] = {} + usage_input: int = 0 + usage_output: int = 0 + stop_reason: str = "end_turn" + + async for chunk in await self._async_client.chat.completions.create(**kwargs): + if interrupt_event and interrupt_event.is_set(): + return + + # Usage is in the final chunk when stream_options.include_usage is set + if chunk.usage is not None: + usage_input = chunk.usage.prompt_tokens or 0 + usage_output = chunk.usage.completion_tokens or 0 + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + if choice.finish_reason: + stop_reason = _STOP_REASON_MAP.get(choice.finish_reason, choice.finish_reason) + + if delta.content: + yield StreamChunk(type="text_delta", text=delta.content) + + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in partial_tool_calls: + partial_tool_calls[idx] = { + "id": tc_delta.id or "", + "name": tc_delta.function.name if tc_delta.function else "", + "arguments_parts": [], + } + # Emit start immediately when we first see the tool call + yield StreamChunk( + type="tool_use_start", + tool_id=tc_delta.id or "", + tool_name=(tc_delta.function.name if tc_delta.function else ""), + tool_input={}, + ) + else: + # Accumulate id/name updates + if tc_delta.id: + partial_tool_calls[idx]["id"] = tc_delta.id + if tc_delta.function and tc_delta.function.name: + partial_tool_calls[idx]["name"] = tc_delta.function.name + if tc_delta.function and tc_delta.function.arguments: + partial_tool_calls[idx]["arguments_parts"].append( + tc_delta.function.arguments + ) + + # Build tool_inputs_by_id from accumulated partial tool calls + tool_inputs_by_id: dict = {} + for tc in partial_tool_calls.values(): + try: + tool_inputs_by_id[tc["id"]] = json.loads( + "".join(tc["arguments_parts"]) or "{}" + ) + except json.JSONDecodeError: + tool_inputs_by_id[tc["id"]] = {} + # Emit tool_use_stop for each completed tool call + yield StreamChunk(type="tool_use_stop") + + yield StreamChunk( + type="message_stop", + stop_reason=stop_reason, + input_tokens=usage_input, + output_tokens=usage_output, + tool_inputs_by_id=tool_inputs_by_id, + ) + def stream( self, messages: list[dict], diff --git a/codeframe/core/adapters/streaming_chat.py b/codeframe/core/adapters/streaming_chat.py index 6a1a9845..1fe98640 100644 --- a/codeframe/core/adapters/streaming_chat.py +++ b/codeframe/core/adapters/streaming_chat.py @@ -1,11 +1,11 @@ -"""Streaming chat adapter for Anthropic SDK. +"""Streaming chat adapter — provider-agnostic. -Wraps ``anthropic.AsyncAnthropic().messages.stream()`` and emits typed -``ChatEvent`` objects for consumption by the WebSocket relay layer. +Wraps any :class:`LLMProvider` that implements ``async_stream()`` and emits +typed ``ChatEvent`` objects for consumption by the WebSocket relay layer. Supports: - Token-by-token text streaming (TEXT_DELTA) -- Extended thinking tokens (THINKING) +- Extended thinking tokens (THINKING) on providers that support it - Safe read-only tool calls: read_file, list_files, search_codebase - Interrupt via asyncio.Event - Message persistence to session_messages after each complete turn @@ -16,13 +16,12 @@ import asyncio import logging -import os from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import AsyncIterator, Optional -from codeframe.adapters.llm.base import Tool, ToolCall, ToolResult +from codeframe.adapters.llm.base import LLMProvider, Tool, ToolCall, ToolResult from codeframe.core.tools import ( execute_tool, _READ_FILE_SCHEMA, @@ -132,7 +131,7 @@ def to_dict(self) -> dict: class StreamingChatAdapter: - """Async streaming adapter over ``anthropic.AsyncAnthropic().messages.stream()``. + """Async streaming adapter over any :class:`LLMProvider` with ``async_stream()``. Each call to :meth:`send_message` is a single conversational turn. History is loaded from the DB at call time and persisted after the turn @@ -150,6 +149,7 @@ def __init__( workspace_path: Path, model: str = _DEFAULT_MODEL, api_key: Optional[str] = None, + provider: Optional[LLMProvider] = None, ) -> None: """Initialise the adapter. @@ -157,34 +157,26 @@ def __init__( session_id: ID of the interactive session (used for DB access). db_repo: ``InteractiveSessionRepository`` instance. workspace_path: Absolute path used to scope file-system tool calls. - model: Anthropic model identifier. - api_key: Override API key (falls back to ``ANTHROPIC_API_KEY`` env var). + model: Model identifier passed to the provider's ``async_stream()``. + api_key: API key used when constructing the default + ``AnthropicProvider`` (when ``provider`` is ``None``). + provider: LLM provider to use for streaming. When ``None``, + an ``AnthropicProvider`` is constructed automatically using + ``api_key`` or the ``ANTHROPIC_API_KEY`` environment variable. Raises: - ValueError: If no Anthropic API key is available. + ValueError: If no provider is given and no Anthropic API key is + available. """ - resolved_key = api_key or os.getenv("ANTHROPIC_API_KEY") - if not resolved_key: - raise ValueError( - "ANTHROPIC_API_KEY not set. " - "Set the environment variable or pass api_key to StreamingChatAdapter." - ) + if provider is None: + from codeframe.adapters.llm.anthropic import AnthropicProvider + provider = AnthropicProvider(api_key=api_key) self._session_id = session_id self._db_repo = db_repo self._workspace_path = workspace_path self._model = model - - # Lazily initialised — avoids importing anthropic at module import time - self._async_client = None - self._api_key = resolved_key - - @property - def _client(self): - if self._async_client is None: - from anthropic import AsyncAnthropic - self._async_client = AsyncAnthropic(api_key=self._api_key) - return self._async_client + self._provider = provider # ------------------------------------------------------------------ # History helpers @@ -228,7 +220,7 @@ def _count(msgs: list[dict]) -> int: # Drop in pairs so we don't strand an assistant message at index 0 messages = messages[2:] if len(messages) >= 2 else messages[1:] - # Anthropic requires the first message to have role "user" + # First message must have role "user" while messages and messages[0].get("role") != "user": messages = messages[1:] @@ -357,90 +349,67 @@ async def _stream_turn( """ current_messages = list(messages) + system_prompt = ( + f"You are a CodeFrame assistant helping the user understand and navigate " + f"their codebase. You have read-only access to the workspace at " + f"{self._workspace_path}. Available tools: read_file, list_files, " + f"search_codebase. Do not attempt to modify files or execute shell commands." + ) + while True: - # Track tool calls seen in this API turn for the follow-up message - pending_tool_calls: list[dict] = [] # {id, name, input, result} - active_tool: dict | None = None # buffering the current tool_use block + pending_tool_calls: list[dict] = [] # {id, name, input} stop_reason = "end_turn" - async with self._client.messages.stream( - model=self._model, - system=( - f"You are a CodeFrame assistant helping the user understand and navigate " - f"their codebase. You have read-only access to the workspace at " - f"{self._workspace_path}. Available tools: read_file, list_files, " - f"search_codebase. Do not attempt to modify files or execute shell commands." - ), + async for chunk in self._provider.async_stream( messages=current_messages, + system=system_prompt, tools=_TOOLS_FOR_API, + model=self._model, max_tokens=4096, - ) as stream: - async for sdk_event in stream: - # Honour interrupt between chunks - if interrupt_event and interrupt_event.is_set(): - return - - event_type = sdk_event.type - - if event_type == "content_block_start": - block = sdk_event.content_block - if block.type == "tool_use": - active_tool = { - "id": block.id, - "name": block.name, - "input": getattr(block, "input", {}), - } - yield ChatEvent( - type=ChatEventType.TOOL_USE_START, - tool_name=block.name, - tool_input=getattr(block, "input", {}), - ) - - elif event_type == "content_block_delta": - delta = sdk_event.delta - if delta.type == "text_delta": - yield ChatEvent( - type=ChatEventType.TEXT_DELTA, - content=delta.text, - ) - elif delta.type == "thinking_delta": - yield ChatEvent( - type=ChatEventType.THINKING, - content=delta.thinking, - ) - elif delta.type == "input_json_delta" and active_tool is not None: - # The SDK may stream tool input as JSON deltas; accumulate - pass # Full input is available on content_block_stop via final msg - - elif event_type == "content_block_stop": - if active_tool is not None: - pending_tool_calls.append(active_tool) - active_tool = None - - elif event_type == "message_stop": - # Flush any tool block that didn't get a content_block_stop - if active_tool is not None: - pending_tool_calls.append(active_tool) - active_tool = None - - # Collect final usage stats - final_msg = stream.get_final_message() - stop_reason = final_msg.stop_reason or "end_turn" - - # Rebuild tool inputs from final message (more reliable than streaming) - if pending_tool_calls and hasattr(final_msg, "content"): - _rebuild_tool_inputs(final_msg.content, pending_tool_calls) - - yield ChatEvent( - type=ChatEventType.COST_UPDATE, - input_tokens=final_msg.usage.input_tokens, - output_tokens=final_msg.usage.output_tokens, - cost_usd=_estimate_cost( - final_msg.usage.input_tokens, - final_msg.usage.output_tokens, - self._model, - ), - ) + interrupt_event=interrupt_event, + ): + if interrupt_event and interrupt_event.is_set(): + return + + if chunk.type == "text_delta": + yield ChatEvent(type=ChatEventType.TEXT_DELTA, content=chunk.text) + + elif chunk.type == "thinking_delta": + yield ChatEvent(type=ChatEventType.THINKING, content=chunk.text) + + elif chunk.type == "tool_use_start": + pending_tool_calls.append({ + "id": chunk.tool_id, + "name": chunk.tool_name, + "input": chunk.tool_input or {}, + }) + yield ChatEvent( + type=ChatEventType.TOOL_USE_START, + tool_name=chunk.tool_name, + tool_input=chunk.tool_input or {}, + ) + + elif chunk.type == "message_stop": + stop_reason = chunk.stop_reason or "end_turn" + + # Back-fill tool inputs from final message (more reliable) + if chunk.tool_inputs_by_id and pending_tool_calls: + for tc in pending_tool_calls: + if tc["id"] in chunk.tool_inputs_by_id: + tc["input"] = chunk.tool_inputs_by_id[tc["id"]] + + yield ChatEvent( + type=ChatEventType.COST_UPDATE, + input_tokens=chunk.input_tokens, + output_tokens=chunk.output_tokens, + cost_usd=_estimate_cost( + chunk.input_tokens or 0, + chunk.output_tokens or 0, + self._model, + ), + ) + + # tool_use_stop is informational only — no ChatEvent needed if stop_reason == "end_turn" or not pending_tool_calls: yield ChatEvent(type=ChatEventType.DONE) @@ -485,20 +454,6 @@ async def _stream_turn( # --------------------------------------------------------------------------- -def _rebuild_tool_inputs(content_blocks, pending_tool_calls: list[dict]) -> None: - """Back-fill tool inputs from the final message content blocks. - - The streaming API may emit input_json_delta events that are tricky to - reconstruct incrementally. Reading inputs off the final message is simpler - and more reliable. - """ - by_id = {tc["id"]: tc for tc in pending_tool_calls} - for block in content_blocks: - block_id = getattr(block, "id", None) - if block_id and block_id in by_id and getattr(block, "type", None) == "tool_use": - by_id[block_id]["input"] = getattr(block, "input", {}) - - def _estimate_cost(input_tokens: int, output_tokens: int, model: str) -> float: """Rough cost estimate in USD. diff --git a/codeframe/ui/routers/session_chat_ws.py b/codeframe/ui/routers/session_chat_ws.py index 7f67a475..b3f987ea 100644 --- a/codeframe/ui/routers/session_chat_ws.py +++ b/codeframe/ui/routers/session_chat_ws.py @@ -115,10 +115,13 @@ async def _run_streaming_adapter( workspace_path: Absolute path to the workspace for file-system tools. """ try: + from codeframe.adapters.llm.anthropic import AnthropicProvider + provider = AnthropicProvider() adapter = StreamingChatAdapter( session_id=session_id, db_repo=db_repo, workspace_path=workspace_path, + provider=provider, ) async for event in adapter.send_message( content=user_message, diff --git a/tests/adapters/test_e2b_adapter.py b/tests/adapters/test_e2b_adapter.py index 350d0784..f1a62eb7 100644 --- a/tests/adapters/test_e2b_adapter.py +++ b/tests/adapters/test_e2b_adapter.py @@ -6,6 +6,7 @@ from __future__ import annotations +import importlib import os import sqlite3 from pathlib import Path @@ -14,6 +15,12 @@ pytestmark = pytest.mark.v2 +# E2B is an optional cloud extra; skip adapter tests when the package is absent +_e2b_available = importlib.util.find_spec("e2b") is not None +_skip_if_no_e2b = pytest.mark.skipif( + not _e2b_available, reason="e2b package not installed (pip install codeframe[cloud])" +) + # --------------------------------------------------------------------------- # credential_scanner tests @@ -244,6 +251,7 @@ def _make_mock_sandbox(exit_code: int = 0, stdout: str = "", stderr: str = ""): return sbx +@_skip_if_no_e2b class TestE2BAgentAdapter: """Tests for E2BAgentAdapter.""" diff --git a/tests/core/test_streaming_chat.py b/tests/core/test_streaming_chat.py index c19f44bd..0170cdc7 100644 --- a/tests/core/test_streaming_chat.py +++ b/tests/core/test_streaming_chat.py @@ -1,8 +1,8 @@ """Unit tests for StreamingChatAdapter. -Uses mocked AsyncAnthropic client to avoid real API calls. -All streaming events are exercised: TEXT_DELTA, THINKING, TOOL_USE_START, -TOOL_RESULT, COST_UPDATE, DONE, ERROR. +Uses MockProvider with async_stream() to drive streaming scenarios without +real API calls. All streaming events are exercised: TEXT_DELTA, THINKING, +TOOL_USE_START, TOOL_RESULT, COST_UPDATE, DONE, ERROR. """ from __future__ import annotations @@ -10,11 +10,12 @@ import asyncio import uuid from pathlib import Path -from typing import AsyncIterator from unittest.mock import AsyncMock, MagicMock, patch import pytest +from codeframe.adapters.llm.base import StreamChunk +from codeframe.adapters.llm.mock import MockProvider from codeframe.core.adapters.streaming_chat import ( ChatEventType, StreamingChatAdapter, @@ -34,78 +35,35 @@ def _make_db_repo(messages: list[dict] | None = None): return repo -def _make_stream_events(events: list[dict]) -> AsyncIterator: - """Async iterator of mock SDK events for use in stream context manager.""" - - class _MockStreamCM: - def __init__(self, evts): - self._events = evts - self._final_message = MagicMock() - self._final_message.usage.input_tokens = 10 - self._final_message.usage.output_tokens = 20 - self._final_message.stop_reason = "end_turn" - - async def __aenter__(self): - return self - - async def __aexit__(self, *_): - pass - - def __aiter__(self): - return self._iter() - - async def _iter(self): - for evt in self._events: - yield evt - - def get_final_message(self): - return self._final_message - - return _MockStreamCM(events) - - -def _text_event(text: str): - evt = MagicMock() - evt.type = "content_block_delta" - evt.delta = MagicMock() - evt.delta.type = "text_delta" - evt.delta.text = text - return evt - - -def _thinking_event(text: str): - evt = MagicMock() - evt.type = "content_block_delta" - evt.delta = MagicMock() - evt.delta.type = "thinking_delta" - evt.delta.thinking = text - return evt - - -def _tool_start_event(tool_name: str, tool_id: str, tool_input: dict): - evt = MagicMock() - evt.type = "content_block_start" - evt.content_block = MagicMock() - evt.content_block.type = "tool_use" - evt.content_block.name = tool_name - evt.content_block.id = tool_id - evt.content_block.input = tool_input - return evt - - -def _tool_stop_event(tool_id: str): - evt = MagicMock() - evt.type = "content_block_stop" - evt.index = 0 - # Signal that this stop corresponds to a tool_use block - evt._tool_id = tool_id - return evt - - -def _message_stop_event(): - evt = MagicMock() - evt.type = "message_stop" - return evt +def _stop_chunk( + stop_reason: str = "end_turn", + input_tokens: int = 10, + output_tokens: int = 20, + tool_inputs_by_id: dict | None = None, +) -> StreamChunk: + return StreamChunk( + type="message_stop", + stop_reason=stop_reason, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_inputs_by_id=tool_inputs_by_id or {}, + ) + + +def _make_adapter( + session_id: str = "s1", + db_repo=None, + workspace_path: Path | None = None, + provider: MockProvider | None = None, + model: str = "claude-sonnet-4-20250514", +) -> StreamingChatAdapter: + return StreamingChatAdapter( + session_id=session_id, + db_repo=db_repo or _make_db_repo(), + workspace_path=workspace_path or Path("/tmp"), + model=model, + provider=provider or MockProvider(), + ) # --------------------------------------------------------------------------- @@ -115,6 +73,7 @@ def _message_stop_event(): class TestStreamingChatAdapterInit: def test_raises_if_no_api_key(self, monkeypatch): + """When no provider is supplied, the default AnthropicProvider raises.""" monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"): StreamingChatAdapter( @@ -123,33 +82,28 @@ def test_raises_if_no_api_key(self, monkeypatch): workspace_path=Path("/tmp"), ) - def test_accepts_api_key_from_env(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - adapter = StreamingChatAdapter( - session_id="s1", - db_repo=_make_db_repo(), - workspace_path=Path("/tmp"), - ) + def test_accepts_explicit_provider(self): + adapter = _make_adapter() assert adapter._session_id == "s1" + assert isinstance(adapter._provider, MockProvider) - def test_default_model(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - adapter = StreamingChatAdapter( - session_id="s1", - db_repo=_make_db_repo(), - workspace_path=Path("/tmp"), - ) + def test_default_model(self): + adapter = _make_adapter() assert adapter._model == "claude-sonnet-4-20250514" - def test_custom_model(self, monkeypatch): + def test_custom_model(self): + adapter = _make_adapter(model="claude-opus-4-20250514") + assert adapter._model == "claude-opus-4-20250514" + + def test_accepts_api_key_from_env(self, monkeypatch): + """Legacy path: no provider given but ANTHROPIC_API_KEY set.""" monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") adapter = StreamingChatAdapter( session_id="s1", db_repo=_make_db_repo(), workspace_path=Path("/tmp"), - model="claude-opus-4-20250514", ) - assert adapter._model == "claude-opus-4-20250514" + assert adapter._session_id == "s1" # --------------------------------------------------------------------------- @@ -158,25 +112,19 @@ def test_custom_model(self, monkeypatch): class TestLoadHistory: - def test_empty_history(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + def test_empty_history(self): repo = _make_db_repo([]) - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=Path("/tmp") - ) + adapter = _make_adapter(db_repo=repo) history = adapter._load_history() assert history == [] - def test_converts_messages_to_anthropic_format(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + def test_converts_messages(self): stored = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there"}, ] repo = _make_db_repo(stored) - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=Path("/tmp") - ) + adapter = _make_adapter(db_repo=repo) history = adapter._load_history() assert history == [ {"role": "user", "content": "Hello"}, @@ -190,15 +138,8 @@ def test_converts_messages_to_anthropic_format(self, monkeypatch): class TestTruncateHistory: - def _make_adapter(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - return StreamingChatAdapter( - session_id="s1", db_repo=_make_db_repo(), workspace_path=Path("/tmp") - ) - - def test_result_starts_with_user_message(self, monkeypatch): - adapter = self._make_adapter(monkeypatch) - # Simulate history that would start with assistant after truncation + def test_result_starts_with_user_message(self): + adapter = _make_adapter() msgs = [ {"role": "user", "content": "a" * 100}, {"role": "assistant", "content": "b" * 100}, @@ -207,12 +148,12 @@ def test_result_starts_with_user_message(self, monkeypatch): result = adapter._truncate_history(msgs) assert result[0]["role"] == "user" - def test_empty_list_unchanged(self, monkeypatch): - adapter = self._make_adapter(monkeypatch) + def test_empty_list_unchanged(self): + adapter = _make_adapter() assert adapter._truncate_history([]) == [] - def test_no_truncation_when_within_budget(self, monkeypatch): - adapter = self._make_adapter(monkeypatch) + def test_no_truncation_when_within_budget(self): + adapter = _make_adapter() msgs = [ {"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}, @@ -228,91 +169,62 @@ def test_no_truncation_when_within_budget(self, monkeypatch): @pytest.mark.asyncio class TestSendMessageTextOnly: - async def test_yields_text_delta_events(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - - sdk_events = [ - _text_event("Hello"), - _text_event(" world"), - _message_stop_event(), - ] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - collected = [] - async for event in adapter.send_message("hi", []): - collected.append(event) + async def test_yields_text_delta_events(self, tmp_path): + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="Hello"), + StreamChunk(type="text_delta", text=" world"), + _stop_chunk(), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + collected = [e async for e in adapter.send_message("hi", [])] types = [e.type for e in collected] assert ChatEventType.TEXT_DELTA in types assert ChatEventType.COST_UPDATE in types assert ChatEventType.DONE in types - async def test_text_delta_content(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) + async def test_text_delta_content(self, tmp_path): + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="Hello"), + StreamChunk(type="text_delta", text=" world"), + _stop_chunk(), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) - sdk_events = [ - _text_event("Hello"), - _text_event(" world"), - _message_stop_event(), + deltas = [ + e.content + async for e in adapter.send_message("hi", []) + if e.type == ChatEventType.TEXT_DELTA ] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - deltas = [ - e.content - async for e in adapter.send_message("hi", []) - if e.type == ChatEventType.TEXT_DELTA - ] - assert deltas == ["Hello", " world"] - async def test_cost_update_has_token_counts(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - - sdk_events = [_text_event("ok"), _message_stop_event()] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - cost_events = [ - e - async for e in adapter.send_message("hi", []) - if e.type == ChatEventType.COST_UPDATE - ] + async def test_cost_update_has_token_counts(self, tmp_path): + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="ok"), + _stop_chunk(input_tokens=10, output_tokens=20), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + cost_events = [ + e + async for e in adapter.send_message("hi", []) + if e.type == ChatEventType.COST_UPDATE + ] assert len(cost_events) == 1 assert cost_events[0].input_tokens == 10 assert cost_events[0].output_tokens == 20 - async def test_done_is_last_event(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - - sdk_events = [_text_event("hi"), _message_stop_event()] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - collected = [e async for e in adapter.send_message("hi", [])] - + async def test_done_is_last_event(self, tmp_path): + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="hi"), + _stop_chunk(), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + collected = [e async for e in adapter.send_message("hi", [])] assert collected[-1].type == ChatEventType.DONE @@ -323,31 +235,39 @@ async def test_done_is_last_event(self, monkeypatch, tmp_path): @pytest.mark.asyncio class TestThinkingEvents: - async def test_yields_thinking_events(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) + async def test_yields_thinking_events(self, tmp_path): + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="thinking_delta", text="Let me think..."), + StreamChunk(type="text_delta", text="Answer"), + _stop_chunk(), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) - sdk_events = [ - _thinking_event("Let me think..."), - _text_event("Answer"), - _message_stop_event(), + thinking_events = [ + e + async for e in adapter.send_message("hi", []) + if e.type == ChatEventType.THINKING ] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - thinking_events = [ - e - async for e in adapter.send_message("hi", []) - if e.type == ChatEventType.THINKING - ] - assert len(thinking_events) == 1 assert thinking_events[0].content == "Let me think..." + async def test_non_anthropic_provider_no_thinking(self, tmp_path): + """Providers that don't emit thinking_delta produce no THINKING events.""" + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="Answer"), + _stop_chunk(), + ]) + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + + thinking_events = [ + e + async for e in adapter.send_message("hi", []) + if e.type == ChatEventType.THINKING + ] + assert thinking_events == [] + # --------------------------------------------------------------------------- # send_message — tool calls @@ -356,43 +276,36 @@ async def test_yields_thinking_events(self, monkeypatch, tmp_path): @pytest.mark.asyncio class TestToolCallEvents: - async def test_yields_tool_use_start_and_result(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - + async def test_yields_tool_use_start_and_result(self, tmp_path): tool_id = "tool_abc" tool_input = {"path": "README.md"} - # First turn: tool_use stop_reason, second turn: end_turn - first_stream = _make_stream_events([ - _tool_start_event("read_file", tool_id, tool_input), - _message_stop_event(), + provider = MockProvider() + # First turn: tool_use with stop_reason="tool_use" + provider.add_stream_chunks([ + StreamChunk( + type="tool_use_start", + tool_id=tool_id, + tool_name="read_file", + tool_input=tool_input, + ), + StreamChunk(type="tool_use_stop"), + _stop_chunk( + stop_reason="tool_use", + tool_inputs_by_id={tool_id: tool_input}, + ), ]) - # Override stop_reason on first stream's final message - first_stream._final_message.stop_reason = "tool_use" - - second_stream = _make_stream_events([ - _text_event("Here is the file content."), - _message_stop_event(), + # Second turn: text response + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="Here is the file content."), + _stop_chunk(), ]) - call_count = 0 + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) - def _fake_stream(**kwargs): - nonlocal call_count - call_count += 1 - return first_stream if call_count == 1 else second_stream - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.side_effect = _fake_stream - - with patch.object(adapter, "_execute_tool", new_callable=AsyncMock) as mock_tool: - mock_tool.return_value = "file contents here" - - collected = [e async for e in adapter.send_message("read README", [])] + with patch.object(adapter, "_execute_tool", new_callable=AsyncMock) as mock_tool: + mock_tool.return_value = "file contents here" + collected = [e async for e in adapter.send_message("read README", [])] types = [e.type for e in collected] assert ChatEventType.TOOL_USE_START in types @@ -413,54 +326,29 @@ def _fake_stream(**kwargs): @pytest.mark.asyncio class TestInterrupt: - async def test_interrupt_event_stops_stream(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - + async def test_interrupt_event_stops_stream(self, tmp_path): interrupt = asyncio.Event() - class _SlowStream: - def __init__(self): - self._final_message = MagicMock() - self._final_message.usage.input_tokens = 0 - self._final_message.usage.output_tokens = 0 - self._final_message.stop_reason = "end_turn" - - async def __aenter__(self): - return self - - async def __aexit__(self, *_): - pass - - def __aiter__(self): - return self._iter() - - async def _iter(self): - for i in range(100): - if interrupt.is_set(): - return - evt = MagicMock() - evt.type = "content_block_delta" - evt.delta = MagicMock() - evt.delta.type = "text_delta" - evt.delta.text = f"chunk{i}" - yield evt - # Set interrupt mid-stream - if i == 2: - interrupt.set() - - def get_final_message(self): - return self._final_message - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _SlowStream() - - collected = [e async for e in adapter.send_message("hi", [], interrupt)] - - # Should stop well before 100 chunks + # Build a lot of chunks so interrupt has time to fire + chunks = [StreamChunk(type="text_delta", text=f"chunk{i}") for i in range(100)] + # Inject message_stop at the end + chunks.append(_stop_chunk()) + + provider = MockProvider() + + async def _slow_stream(*args, **kwargs): + for i, chunk in enumerate(chunks): + if interrupt.is_set(): + return + yield chunk + if i == 2: + interrupt.set() + + provider.async_stream = _slow_stream + + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + collected = [e async for e in adapter.send_message("hi", [], interrupt)] + text_deltas = [e for e in collected if e.type == ChatEventType.TEXT_DELTA] assert len(text_deltas) < 20 @@ -472,19 +360,15 @@ def get_final_message(self): @pytest.mark.asyncio class TestPersistence: - async def test_messages_persisted_after_turn(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + async def test_messages_persisted_after_turn(self, tmp_path): repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - - sdk_events = [_text_event("response text"), _message_stop_event()] - - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _make_stream_events(sdk_events) - - _ = [e async for e in adapter.send_message("user input", [])] + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="response text"), + _stop_chunk(), + ]) + adapter = _make_adapter(db_repo=repo, workspace_path=tmp_path, provider=provider) + _ = [e async for e in adapter.send_message("user input", [])] # Two calls: one for user message, one for assistant message assert repo.add_message.call_count == 2 @@ -501,31 +385,18 @@ async def test_messages_persisted_after_turn(self, monkeypatch, tmp_path): @pytest.mark.asyncio class TestErrorHandling: - async def test_yields_error_event_on_api_failure(self, monkeypatch, tmp_path): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") - repo = _make_db_repo() - adapter = StreamingChatAdapter( - session_id="s1", db_repo=repo, workspace_path=tmp_path - ) - - class _ErrorStream: - async def __aenter__(self): - return self - - async def __aexit__(self, *_): - pass - - def __aiter__(self): - return self._iter() + async def test_yields_error_event_on_api_failure(self, tmp_path): + provider = MockProvider() - async def _iter(self): - raise RuntimeError("API failure") - yield # make it a generator + async def _error_stream(*args, **kwargs): + raise RuntimeError("API failure") + if False: + yield # make it an async generator - with patch.object(adapter, "_async_client") as mock_client: - mock_client.messages.stream.return_value = _ErrorStream() + provider.async_stream = _error_stream - collected = [e async for e in adapter.send_message("hi", [])] + adapter = _make_adapter(workspace_path=tmp_path, provider=provider) + collected = [e async for e in adapter.send_message("hi", [])] error_events = [e for e in collected if e.type == ChatEventType.ERROR] assert len(error_events) == 1 From a57085fe10541a1ded38f03b1d6458ba823341db Mon Sep 17 00:00:00 2001 From: Test User Date: Sun, 5 Apr 2026 23:10:20 -0700 Subject: [PATCH 2/3] fix(adapters): address claude-review and CodeRabbit feedback on #548 - Add extended_thinking param to async_stream() across all providers; streaming_chat.py now passes supports("extended_thinking") flag so providers can gate thinking tokens without checking capability themselves - AnthropicProvider: add betas=[interleaved-thinking] when extended_thinking=True - OpenAIProvider: defer tool_use_start emission until both id+name are known - OpenAIProvider: add exception handling (LLMAuthError/RateLimitError/ LLMConnectionError) to async_stream() matching async_complete() semantics - OpenAIProvider: extract _convert_raw_tools() helper to avoid duplicate tool conversion logic; add logger + warning on JSON decode failures - streaming_chat.py: remove workspace absolute path from system prompt; replace with generic "current workspace" phrase - base.py + StreamChunk docstring: document tool_use_stop ordering difference between Anthropic (interleaved) and OpenAI (batched at end) - base.py: add comment explaining why async_stream is not @abstractmethod - streaming_chat.py: add backward-compat comment on provider=None fallback --- codeframe/adapters/llm/anthropic.py | 13 ++ codeframe/adapters/llm/base.py | 22 +++ codeframe/adapters/llm/mock.py | 1 + codeframe/adapters/llm/openai.py | 171 ++++++++++++++-------- codeframe/core/adapters/streaming_chat.py | 16 +- 5 files changed, 154 insertions(+), 69 deletions(-) diff --git a/codeframe/adapters/llm/anthropic.py b/codeframe/adapters/llm/anthropic.py index 7b7d24d2..82fec1a2 100644 --- a/codeframe/adapters/llm/anthropic.py +++ b/codeframe/adapters/llm/anthropic.py @@ -186,12 +186,17 @@ async def async_stream( model: str, max_tokens: int, interrupt_event: Optional[asyncio.Event] = None, + extended_thinking: bool = False, ) -> AsyncIterator[StreamChunk]: """Stream using Anthropic AsyncAnthropic SDK, yielding StreamChunk objects. Translates Anthropic SDK events into the normalized StreamChunk format. Tool inputs are collected and emitted in the final message_stop chunk via tool_inputs_by_id, which is more reliable than streaming input deltas. + + When ``extended_thinking=True``, requests interleaved thinking via the + Anthropic betas API. The flag is silently ignored on SDK versions that + do not support it. """ from anthropic import AsyncAnthropic @@ -206,6 +211,14 @@ async def async_stream( "max_tokens": max_tokens, } + if extended_thinking: + # interleaved-thinking requires the beta header; degrade gracefully + # if the running SDK version doesn't recognise the param. + try: + kwargs["betas"] = ["interleaved-thinking-2025-05-14"] + except Exception: # pragma: no cover + pass + active_tool_id: Optional[str] = None async with self._async_client.messages.stream(**kwargs) as stream: diff --git a/codeframe/adapters/llm/base.py b/codeframe/adapters/llm/base.py index 104866a2..9cd62a73 100644 --- a/codeframe/adapters/llm/base.py +++ b/codeframe/adapters/llm/base.py @@ -141,6 +141,19 @@ class StreamChunk: tool_inputs_by_id: Mapping of tool_id → final input dict, populated for ``message_stop``. More reliable than streaming incremental input deltas. + + .. note:: ``tool_use_stop`` ordering differs by provider: + + - **Anthropic**: emitted immediately when each tool call's content + block ends (``content_block_stop`` event), so consumers see + ``tool_use_start → [deltas] → tool_use_stop`` interleaved. + - **OpenAI-compatible**: emitted after the full stream ends (before + ``message_stop``), because the SSE protocol has no per-tool stop + marker. All ``tool_use_stop`` chunks arrive together at the end. + + Consumers MUST use ``tool_inputs_by_id`` from the ``message_stop`` + chunk for final tool inputs rather than relying on ``tool_use_stop`` + ordering. """ type: str @@ -377,6 +390,10 @@ def supports(self, capability: str) -> bool: """ return False + # Not decorated with @abstractmethod intentionally: providers that only + # support synchronous completion (e.g. thin wrappers) don't need to + # implement streaming. Calling async_stream() on such a provider raises + # NotImplementedError at call time rather than at instantiation. async def async_stream( self, messages: list[dict], @@ -385,6 +402,7 @@ async def async_stream( model: str, max_tokens: int, interrupt_event: Optional[asyncio.Event] = None, + extended_thinking: bool = False, ) -> AsyncIterator["StreamChunk"]: """Stream a completion as normalized :class:`StreamChunk` objects. @@ -399,6 +417,10 @@ async def async_stream( max_tokens: Maximum output tokens. interrupt_event: When set, the stream should stop at the next opportunity. + extended_thinking: When ``True``, request extended thinking tokens + from providers that support them (see :meth:`supports`). + Providers that do not support this capability should silently + ignore the flag. Yields: :class:`StreamChunk` objects in order of generation. diff --git a/codeframe/adapters/llm/mock.py b/codeframe/adapters/llm/mock.py index 408b04b6..9f80afe2 100644 --- a/codeframe/adapters/llm/mock.py +++ b/codeframe/adapters/llm/mock.py @@ -195,6 +195,7 @@ async def async_stream( model: str, max_tokens: int, interrupt_event: Optional[asyncio.Event] = None, + extended_thinking: bool = False, ) -> AsyncIterator[StreamChunk]: """Yield pre-configured StreamChunk sequences for testing. diff --git a/codeframe/adapters/llm/openai.py b/codeframe/adapters/llm/openai.py index c17b8b93..c581ea6d 100644 --- a/codeframe/adapters/llm/openai.py +++ b/codeframe/adapters/llm/openai.py @@ -6,6 +6,7 @@ import asyncio import json +import logging import os from typing import TYPE_CHECKING, AsyncIterator, Iterator, Optional @@ -24,6 +25,8 @@ if TYPE_CHECKING: from codeframe.core.credentials import CredentialManager +logger = logging.getLogger(__name__) + _STOP_REASON_MAP = { "stop": "end_turn", "tool_calls": "tool_use", @@ -206,14 +209,24 @@ async def async_stream( model: str, max_tokens: int, interrupt_event: Optional[asyncio.Event] = None, + extended_thinking: bool = False, ) -> AsyncIterator[StreamChunk]: """Stream using OpenAI async client, yielding StreamChunk objects. Translates OpenAI SSE chunks into the normalized StreamChunk format. - Tool calls are emitted as tool_use_start chunks; final inputs are - collected and emitted in the message_stop chunk via tool_inputs_by_id. + Tool calls are emitted as tool_use_start chunks (deferred until both + id and name are known); final inputs are collected and emitted in the + message_stop chunk via tool_inputs_by_id. + + ``extended_thinking`` is silently ignored — OpenAI-compatible endpoints + do not support Anthropic extended thinking. """ import openai as _openai + from codeframe.adapters.llm.base import ( + LLMAuthError, + LLMConnectionError, + LLMRateLimitError, + ) if self._async_client is None: self._async_client = _openai.AsyncOpenAI( @@ -232,84 +245,92 @@ async def async_stream( "stream_options": {"include_usage": True}, } - # Convert tools from Anthropic-style (input_schema) to OpenAI format if tools: - kwargs["tools"] = [ - { - "type": "function", - "function": { - "name": t["name"], - "description": t.get("description", ""), - "parameters": t.get("input_schema", {}), - }, - } - for t in tools - ] + kwargs["tools"] = self._convert_raw_tools(tools) kwargs["tool_choice"] = "auto" - # Track partial tool calls across chunks (OpenAI streams them incrementally) - # key: index → {id, name, arguments_parts} + # Track partial tool calls across chunks (OpenAI streams them incrementally). + # key: index → {id, name, arguments_parts, emitted_start} partial_tool_calls: dict[int, dict] = {} usage_input: int = 0 usage_output: int = 0 stop_reason: str = "end_turn" - async for chunk in await self._async_client.chat.completions.create(**kwargs): - if interrupt_event and interrupt_event.is_set(): - return - - # Usage is in the final chunk when stream_options.include_usage is set - if chunk.usage is not None: - usage_input = chunk.usage.prompt_tokens or 0 - usage_output = chunk.usage.completion_tokens or 0 - - if not chunk.choices: - continue - - choice = chunk.choices[0] - delta = choice.delta - - if choice.finish_reason: - stop_reason = _STOP_REASON_MAP.get(choice.finish_reason, choice.finish_reason) - - if delta.content: - yield StreamChunk(type="text_delta", text=delta.content) + try: + async for chunk in await self._async_client.chat.completions.create(**kwargs): + if interrupt_event and interrupt_event.is_set(): + return + + # Usage is in the final chunk when stream_options.include_usage is set + if chunk.usage is not None: + usage_input = chunk.usage.prompt_tokens or 0 + usage_output = chunk.usage.completion_tokens or 0 + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + if choice.finish_reason: + stop_reason = _STOP_REASON_MAP.get(choice.finish_reason, choice.finish_reason) + + if delta.content: + yield StreamChunk(type="text_delta", text=delta.content) + + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in partial_tool_calls: + partial_tool_calls[idx] = { + "id": tc_delta.id or "", + "name": (tc_delta.function.name if tc_delta.function else ""), + "arguments_parts": [], + "emitted_start": False, + } + else: + # Accumulate id/name as they arrive across deltas + if tc_delta.id: + partial_tool_calls[idx]["id"] = tc_delta.id + if tc_delta.function and tc_delta.function.name: + partial_tool_calls[idx]["name"] = tc_delta.function.name + + if tc_delta.function and tc_delta.function.arguments: + partial_tool_calls[idx]["arguments_parts"].append( + tc_delta.function.arguments + ) + + # Defer tool_use_start until both id and name are known + tc_info = partial_tool_calls[idx] + if not tc_info["emitted_start"] and tc_info["id"] and tc_info["name"]: + yield StreamChunk( + type="tool_use_start", + tool_id=tc_info["id"], + tool_name=tc_info["name"], + tool_input={}, + ) + tc_info["emitted_start"] = True - if delta.tool_calls: - for tc_delta in delta.tool_calls: - idx = tc_delta.index - if idx not in partial_tool_calls: - partial_tool_calls[idx] = { - "id": tc_delta.id or "", - "name": tc_delta.function.name if tc_delta.function else "", - "arguments_parts": [], - } - # Emit start immediately when we first see the tool call - yield StreamChunk( - type="tool_use_start", - tool_id=tc_delta.id or "", - tool_name=(tc_delta.function.name if tc_delta.function else ""), - tool_input={}, - ) - else: - # Accumulate id/name updates - if tc_delta.id: - partial_tool_calls[idx]["id"] = tc_delta.id - if tc_delta.function and tc_delta.function.name: - partial_tool_calls[idx]["name"] = tc_delta.function.name - if tc_delta.function and tc_delta.function.arguments: - partial_tool_calls[idx]["arguments_parts"].append( - tc_delta.function.arguments - ) + except _openai.AuthenticationError as exc: + raise LLMAuthError(str(exc)) from exc + except _openai.RateLimitError as exc: + raise LLMRateLimitError(str(exc)) from exc + except _openai.APIConnectionError as exc: + raise LLMConnectionError(str(exc)) from exc # Build tool_inputs_by_id from accumulated partial tool calls tool_inputs_by_id: dict = {} for tc in partial_tool_calls.values(): + raw_args = "".join(tc["arguments_parts"]) or "{}" try: - tool_inputs_by_id[tc["id"]] = json.loads( - "".join(tc["arguments_parts"]) or "{}" - ) + tool_inputs_by_id[tc["id"]] = json.loads(raw_args) except json.JSONDecodeError: + logger.warning( + "Failed to parse tool arguments for tool '%s' (id=%s): %r", + tc["name"], + tc["id"], + raw_args, + ) tool_inputs_by_id[tc["id"]] = {} # Emit tool_use_stop for each completed tool call yield StreamChunk(type="tool_use_stop") @@ -412,6 +433,26 @@ def _convert_tools(self, tools: list[Tool]) -> list[dict]: for tool in tools ] + def _convert_raw_tools(self, tools: list[dict]) -> list[dict]: + """Convert already-serialized tool dicts (Anthropic-style) to OpenAI format. + + The ``async_stream()`` interface receives tools as ``list[dict]`` with an + ``input_schema`` key (Anthropic API format). This helper converts them to + the OpenAI ``function`` calling format, mirroring :meth:`_convert_tools` + for raw dicts instead of :class:`Tool` objects. + """ + return [ + { + "type": "function", + "function": { + "name": t["name"], + "description": t.get("description", ""), + "parameters": t.get("input_schema", {}), + }, + } + for t in tools + ] + def _parse_response(self, response) -> LLMResponse: """Parse OpenAI ChatCompletion into LLMResponse.""" choice = response.choices[0] diff --git a/codeframe/core/adapters/streaming_chat.py b/codeframe/core/adapters/streaming_chat.py index 1fe98640..3a3fdec1 100644 --- a/codeframe/core/adapters/streaming_chat.py +++ b/codeframe/core/adapters/streaming_chat.py @@ -169,6 +169,11 @@ def __init__( available. """ if provider is None: + # Backward-compatibility fallback: callers that haven't been + # updated to pass an explicit provider (e.g. tests using the old + # api_key= constructor argument) still get an AnthropicProvider. + # New callers should construct the provider themselves and pass it + # in — see session_chat_ws.py for the recommended pattern. from codeframe.adapters.llm.anthropic import AnthropicProvider provider = AnthropicProvider(api_key=api_key) @@ -350,12 +355,14 @@ async def _stream_turn( current_messages = list(messages) system_prompt = ( - f"You are a CodeFrame assistant helping the user understand and navigate " - f"their codebase. You have read-only access to the workspace at " - f"{self._workspace_path}. Available tools: read_file, list_files, " - f"search_codebase. Do not attempt to modify files or execute shell commands." + "You are a CodeFrame assistant helping the user understand and navigate " + "their codebase. You have read-only access to the current workspace. " + "Available tools: read_file, list_files, search_codebase. " + "Do not attempt to modify files or execute shell commands." ) + use_extended_thinking = self._provider.supports("extended_thinking") + while True: pending_tool_calls: list[dict] = [] # {id, name, input} stop_reason = "end_turn" @@ -367,6 +374,7 @@ async def _stream_turn( model=self._model, max_tokens=4096, interrupt_event=interrupt_event, + extended_thinking=use_extended_thinking, ): if interrupt_event and interrupt_event.is_set(): return From a60e078a8ca3db58b2f46c6128fdc8444346651a Mon Sep 17 00:00:00 2001 From: Test User Date: Mon, 6 Apr 2026 17:05:42 -0700 Subject: [PATCH 3/3] fix(adapters): fix get_final_message await + convert_messages + mock tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - anthropic.py: await stream.get_final_message() — it IS async in the SDK (get_final_message is a coroutine; the test mock masked this) - anthropic.py: call _convert_messages() before streaming to normalize tool_calls/tool_results payload format - anthropic.py: move extended_thinking betas retry to stream() call site rather than the dict assignment - mock.py: async_stream() now tracks calls in self.calls (same metadata as complete()) and derives default chunks from response queue / response_handler / default_response - tests/adapters/test_llm_async.py: add 5 new async_stream() lifecycle tests covering default chunks, preconfigured chunks, call tracking, interrupt, and extended_thinking flag propagation --- codeframe/adapters/llm/anthropic.py | 27 +++++--- codeframe/adapters/llm/mock.py | 36 +++++++++-- tests/adapters/test_llm_async.py | 98 +++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 14 deletions(-) diff --git a/codeframe/adapters/llm/anthropic.py b/codeframe/adapters/llm/anthropic.py index 82fec1a2..7d52b702 100644 --- a/codeframe/adapters/llm/anthropic.py +++ b/codeframe/adapters/llm/anthropic.py @@ -203,25 +203,34 @@ async def async_stream( if self._async_client is None: self._async_client = AsyncAnthropic(api_key=self.api_key) + # Convert messages to Anthropic API format (handles tool_calls/tool_results) + converted = self._convert_messages(messages) + kwargs: dict = { "model": model, "system": system, - "messages": messages, + "messages": converted, "tools": tools, "max_tokens": max_tokens, } if extended_thinking: - # interleaved-thinking requires the beta header; degrade gracefully - # if the running SDK version doesn't recognise the param. - try: - kwargs["betas"] = ["interleaved-thinking-2025-05-14"] - except Exception: # pragma: no cover - pass + kwargs["betas"] = ["interleaved-thinking-2025-05-14"] active_tool_id: Optional[str] = None - async with self._async_client.messages.stream(**kwargs) as stream: + # When extended_thinking is set, the beta header may be unsupported on + # older SDK versions. Retry without it rather than hard-failing. + try: + stream_ctx = self._async_client.messages.stream(**kwargs) + except Exception: # pragma: no cover + if extended_thinking: + kwargs.pop("betas", None) + stream_ctx = self._async_client.messages.stream(**kwargs) + else: + raise + + async with stream_ctx as stream: async for sdk_event in stream: if interrupt_event and interrupt_event.is_set(): return @@ -258,7 +267,7 @@ async def async_stream( yield StreamChunk(type="tool_use_stop") active_tool_id = None - final_msg = stream.get_final_message() + final_msg = await stream.get_final_message() stop_reason = final_msg.stop_reason or "end_turn" # Build tool_inputs_by_id from final content blocks diff --git a/codeframe/adapters/llm/mock.py b/codeframe/adapters/llm/mock.py index 9f80afe2..4545912b 100644 --- a/codeframe/adapters/llm/mock.py +++ b/codeframe/adapters/llm/mock.py @@ -199,24 +199,50 @@ async def async_stream( ) -> AsyncIterator[StreamChunk]: """Yield pre-configured StreamChunk sequences for testing. - Falls back to a minimal text_delta + message_stop pair when no - stream chunks have been configured via add_stream_chunks(). + Tracks each call in :attr:`calls` (same metadata as :meth:`complete`). + When pre-configured ``stream_chunks`` are available, yields them in + order. Otherwise falls back to a minimal ``text_delta`` + + ``message_stop`` pair derived from the normal response queue + (``responses`` / ``response_handler`` / ``default_response``). """ + # Track the call so tests can assert on it + self.calls.append( + { + "messages": messages, + "system": system, + "tools": tools, + "model": model, + "max_tokens": max_tokens, + "extended_thinking": extended_thinking, + } + ) + if self.stream_index < len(self.stream_chunks): chunks = self.stream_chunks[self.stream_index] self.stream_index += 1 else: - # Default: simple text response followed by message_stop + # Derive response text from the normal queue / handler + if self.response_handler: + resp = self.response_handler(messages) + text = resp.content + elif self.response_index < len(self.responses): + resp = self.responses[self.response_index] + self.response_index += 1 + text = resp.content + else: + text = self.default_response + chunks = [ - StreamChunk(type="text_delta", text=self.default_response), + StreamChunk(type="text_delta", text=text), StreamChunk( type="message_stop", stop_reason="end_turn", input_tokens=len(str(messages)), - output_tokens=len(self.default_response), + output_tokens=len(text), tool_inputs_by_id={}, ), ] + for chunk in chunks: if interrupt_event and interrupt_event.is_set(): return diff --git a/tests/adapters/test_llm_async.py b/tests/adapters/test_llm_async.py index 4d8fa6f8..ff573386 100644 --- a/tests/adapters/test_llm_async.py +++ b/tests/adapters/test_llm_async.py @@ -47,6 +47,104 @@ async def test_async_complete_tracks_call(self): assert provider.call_count == 2 +class TestMockProviderAsyncStream: + """async_stream() tests using MockProvider.""" + + @pytest.mark.asyncio + async def test_async_stream_yields_default_chunks(self): + """Default async_stream yields text_delta + message_stop.""" + provider = MockProvider(default_response="streamed reply") + chunks = [ + chunk + async for chunk in provider.async_stream( + messages=[{"role": "user", "content": "hi"}], + system="", + tools=[], + model="mock", + max_tokens=100, + ) + ] + types = [c.type for c in chunks] + assert "text_delta" in types + assert "message_stop" in types + text = next(c.text for c in chunks if c.type == "text_delta") + assert text == "streamed reply" + + @pytest.mark.asyncio + async def test_async_stream_uses_preconfigured_chunks(self): + """add_stream_chunks() controls what async_stream yields.""" + from codeframe.adapters.llm.base import StreamChunk + + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="custom"), + StreamChunk(type="message_stop", stop_reason="end_turn", + input_tokens=1, output_tokens=1, tool_inputs_by_id={}), + ]) + chunks = [ + c + async for c in provider.async_stream( + messages=[], system="", tools=[], model="mock", max_tokens=10 + ) + ] + assert chunks[0].text == "custom" + assert chunks[1].type == "message_stop" + + @pytest.mark.asyncio + async def test_async_stream_tracks_call(self): + """async_stream records the call in provider.calls.""" + provider = MockProvider() + _ = [ + c + async for c in provider.async_stream( + messages=[{"role": "user", "content": "hi"}], + system="sys", + tools=[], + model="mock-model", + max_tokens=50, + ) + ] + assert provider.call_count == 1 + assert provider.last_call["model"] == "mock-model" + + @pytest.mark.asyncio + async def test_async_stream_honours_interrupt(self): + """async_stream stops early when interrupt_event is set.""" + import asyncio + from codeframe.adapters.llm.base import StreamChunk + + interrupt = asyncio.Event() + provider = MockProvider() + provider.add_stream_chunks([ + StreamChunk(type="text_delta", text="a"), + StreamChunk(type="text_delta", text="b"), + StreamChunk(type="message_stop", stop_reason="end_turn", + input_tokens=0, output_tokens=0, tool_inputs_by_id={}), + ]) + interrupt.set() + chunks = [ + c + async for c in provider.async_stream( + messages=[], system="", tools=[], model="m", max_tokens=10, + interrupt_event=interrupt, + ) + ] + assert chunks == [] + + @pytest.mark.asyncio + async def test_async_stream_supports_extended_thinking_param(self): + """extended_thinking param is accepted and stored in call metadata.""" + provider = MockProvider() + _ = [ + c + async for c in provider.async_stream( + messages=[], system="", tools=[], model="m", max_tokens=10, + extended_thinking=True, + ) + ] + assert provider.last_call["extended_thinking"] is True + + class TestLLMExceptions: """Common LLM exception hierarchy."""