From 906846e9e4ea402b2b0a1b1fea3124cc18a32c9f Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 15:08:36 +0530 Subject: [PATCH 1/6] langchain updates --- src/opengradient/agents/__init__.py | 11 +- src/opengradient/agents/og_langchain.py | 162 +++++++++++++++++++----- 2 files changed, 136 insertions(+), 37 deletions(-) diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index a2dd27c..e2df564 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -6,13 +6,17 @@ into existing applications and agent frameworks. """ +from typing import Optional, Union + +from ..client import Client from ..types import TEE_LLM, x402SettlementMode from .og_langchain import * def langchain_adapter( - private_key: str, - model_cid: TEE_LLM, + model_cid: Union[TEE_LLM, str], + private_key: Optional[str] = None, + client: Optional[Client] = None, max_tokens: int = 300, x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> OpenGradientChatModel: @@ -21,8 +25,9 @@ def langchain_adapter( and can be plugged into LangChain agents. """ return OpenGradientChatModel( - private_key=private_key, model_cid=model_cid, + private_key=private_key, + client=client, max_tokens=max_tokens, x402_settlement_mode=x402_settlement_mode, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index b62e443..48b20c5 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -1,12 +1,14 @@ # mypy: ignore-errors +import asyncio import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union -from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, HumanMessage, SystemMessage, @@ -15,6 +17,7 @@ from langchain_core.messages.tool import ToolMessage from langchain_core.outputs import ( ChatGeneration, + ChatGenerationChunk, ChatResult, ) from langchain_core.runnables import Runnable @@ -25,6 +28,7 @@ from ..types import TEE_LLM, x402SettlementMode __all__ = ["OpenGradientChatModel"] +_STREAM_END = object() def _extract_content(content: Any) -> str: @@ -69,17 +73,18 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall: class OpenGradientChatModel(BaseChatModel): """OpenGradient adapter class for LangChain chat model""" - model_cid: str + model_cid: Union[TEE_LLM, str] max_tokens: int = 300 - x402_settlement_mode: Optional[str] = x402SettlementMode.SETTLE_BATCH + x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH _client: Client = PrivateAttr() _tools: List[Dict] = PrivateAttr(default_factory=list) def __init__( self, - private_key: str, - model_cid: TEE_LLM, + model_cid: Union[TEE_LLM, str], + private_key: Optional[str] = None, + client: Optional[Client] = None, max_tokens: int = 300, x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, **kwargs, @@ -90,7 +95,13 @@ def __init__( x402_settlement_mode=x402_settlement_mode, **kwargs, ) - self._client = Client(private_key=private_key) + if client is not None and private_key is not None: + raise 
ValueError("Pass either client or private_key, not both.") + if client is None: + if private_key is None: + raise ValueError("Either client or private_key must be provided.") + client = Client(private_key=private_key) + self._client = client @property def _llm_type(self) -> str: @@ -136,7 +147,100 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - sdk_messages = [] + sdk_messages = self._to_sdk_messages(messages) + + chat_output = self._client.llm.chat( + model=self.model_cid, + messages=sdk_messages, + stop_sequence=stop, + max_tokens=self.max_tokens, + tools=self._tools, + x402_settlement_mode=self.x402_settlement_mode, + ) + + finish_reason = chat_output.finish_reason or "" + chat_response = chat_output.chat_output or {} + + if chat_response.get("tool_calls"): + tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] + ai_message = AIMessage(content="", tool_calls=tool_calls) + else: + ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + + return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + sdk_messages = self._to_sdk_messages(messages) + stream = self._client.llm.chat( + model=self.model_cid, + messages=sdk_messages, + stop_sequence=stop, + max_tokens=self.max_tokens, + tools=self._tools, + x402_settlement_mode=self.x402_settlement_mode, + stream=True, + ) + + for chunk in stream: + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + content = _extract_content(delta.content) + + additional_kwargs: Dict[str, Any] = {} + if delta.tool_calls: + additional_kwargs["tool_calls"] = delta.tool_calls + + chunk_kwargs: Dict[str, Any] = { + "content": content, + "additional_kwargs": additional_kwargs, + } + if chunk.usage: + chunk_kwargs["usage_metadata"] = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + "total_tokens": chunk.usage.total_tokens, + } + + generation_info = {"finish_reason": choice.finish_reason} if choice.finish_reason else None + yield ChatGenerationChunk( + message=AIMessageChunk(**chunk_kwargs), + generation_info=generation_info, + ) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + # Bridge the sync iterator from the SDK to LangChain's async streaming API. + iterator = self._stream(messages=messages, stop=stop, **kwargs) + while True: + # Use next(..., default) so StopIteration does not cross Future boundaries. 
+ chunk = await asyncio.to_thread(next, iterator, _STREAM_END) + if chunk is _STREAM_END: + break + yield chunk + + @property + def _identifying_params(self) -> Dict[str, Any]: + return { + "model_name": self.model_cid, + } + + @staticmethod + def _to_sdk_messages(messages: List[Any]) -> List[Dict[str, Any]]: + sdk_messages: List[Dict[str, Any]] = [] for message in messages: if isinstance(message, SystemMessage): sdk_messages.append({"role": "system", "content": _extract_content(message.content)}) @@ -162,31 +266,21 @@ def _generate( "tool_call_id": message.tool_call_id, } ) - else: - raise ValueError(f"Unexpected message type: {message}") + elif isinstance(message, dict): + role = message.get("role") + if role not in {"system", "user", "assistant", "tool"}: + raise ValueError(f"Unexpected message role in dict message: {role}") - chat_output = self._client.llm.chat( - model=self.model_cid, - messages=sdk_messages, - stop_sequence=stop, - max_tokens=self.max_tokens, - tools=self._tools, - x402_settlement_mode=self.x402_settlement_mode, - ) - - finish_reason = chat_output.finish_reason or "" - chat_response = chat_output.chat_output or {} - - if chat_response.get("tool_calls"): - tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] - ai_message = AIMessage(content="", tool_calls=tool_calls) - else: - ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + sdk_message: Dict[str, Any] = { + "role": role, + "content": _extract_content(message.get("content", "")), + } + if role == "assistant" and message.get("tool_calls"): + sdk_message["tool_calls"] = message["tool_calls"] + if role == "tool" and message.get("tool_call_id"): + sdk_message["tool_call_id"] = message["tool_call_id"] - return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) - - @property - def _identifying_params(self) -> Dict[str, Any]: - return { - "model_name": self.model_cid, - } + sdk_messages.append(sdk_message) + else: + raise ValueError(f"Unexpected message type: {message}") + return sdk_messages From 629522744e9cf37427f177db746ce0048be0acd0 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 15:28:54 +0530 Subject: [PATCH 2/6] langchain updates --- src/opengradient/agents/__init__.py | 2 + src/opengradient/agents/og_langchain.py | 6 ++ src/opengradient/client/llm.py | 137 ++++++++++++++++++------ 3 files changed, 112 insertions(+), 33 deletions(-) diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index e2df564..5a15424 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -17,6 +17,7 @@ def langchain_adapter( model_cid: Union[TEE_LLM, str], private_key: Optional[str] = None, client: Optional[Client] = None, + temperature: float = 0.0, max_tokens: int = 300, x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> OpenGradientChatModel: @@ -28,6 +29,7 @@ def langchain_adapter( model_cid=model_cid, private_key=private_key, client=client, + temperature=temperature, max_tokens=max_tokens, x402_settlement_mode=x402_settlement_mode, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 48b20c5..7b35f7a 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -74,6 +74,7 @@ class OpenGradientChatModel(BaseChatModel): """OpenGradient adapter class for LangChain chat model""" model_cid: Union[TEE_LLM, str] + 
temperature: float = 0.0 max_tokens: int = 300 x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH @@ -85,12 +86,14 @@ def __init__( model_cid: Union[TEE_LLM, str], private_key: Optional[str] = None, client: Optional[Client] = None, + temperature: float = 0.0, max_tokens: int = 300, x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, **kwargs, ): super().__init__( model_cid=model_cid, + temperature=temperature, max_tokens=max_tokens, x402_settlement_mode=x402_settlement_mode, **kwargs, @@ -153,6 +156,7 @@ def _generate( model=self.model_cid, messages=sdk_messages, stop_sequence=stop, + temperature=self.temperature, max_tokens=self.max_tokens, tools=self._tools, x402_settlement_mode=self.x402_settlement_mode, @@ -181,6 +185,7 @@ def _stream( model=self.model_cid, messages=sdk_messages, stop_sequence=stop, + temperature=self.temperature, max_tokens=self.max_tokens, tools=self._tools, x402_settlement_mode=self.x402_settlement_mode, @@ -236,6 +241,7 @@ async def _astream( def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_cid, + "temperature": self.temperature, } @staticmethod diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 83f8eb8..c5d0234 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -25,6 +25,10 @@ X402_PROCESSING_HASH_HEADER = "x-processing-hash" X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" BASE_TESTNET_NETWORK = "eip155:84532" +RETRYABLE_PAYMENT_ERROR_SNIPPETS = ( + "invalid payment required response", + "failed to handle payment", +) TIMEOUT = httpx.Timeout( timeout=90.0, @@ -122,10 +126,7 @@ def __init__(self, wallet_account: LocalAccount, og_llm_server_url: str, og_llm_ _fetch_tls_cert_as_ssl_context(self._og_llm_streaming_server_url) or True ) - signer = EthAccountSignerv2(self._wallet_account) - self._x402_client = x402Clientv2() - register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + self._initialize_x402_client() self._request_client_ctx = None self._request_client = None @@ -156,6 +157,24 @@ async def _initialize_http_clients(self) -> None: self._stream_client_ctx = x402HttpxClientv2(self._x402_client, verify=self._streaming_tls_verify) self._stream_client = await self._stream_client_ctx.__aenter__() + def _initialize_x402_client(self) -> None: + signer = EthAccountSignerv2(self._wallet_account) + self._x402_client = x402Clientv2() + register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + + @staticmethod + def _is_retryable_payment_error(error: Exception) -> bool: + message = str(error).lower() + if any(snippet in message for snippet in RETRYABLE_PAYMENT_ERROR_SNIPPETS): + return True + return "paymenterror" in error.__class__.__name__.lower() + + async def _refresh_payment_clients(self) -> None: + await self._close_http_clients() + self._initialize_x402_client() + await self._initialize_http_clients() + async def _close_http_clients(self) -> None: if self._request_client_ctx is not None: await self._request_client_ctx.__aexit__(None, None, None) @@ -198,14 +217,39 @@ def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: raise ValueError("OPG amount must be at least 0.05.") return 
ensure_opg_approval(self._wallet_account, opg_amount) + @staticmethod + def _resolve_model_id(model: Union[TEE_LLM, str]) -> str: + return str(model) + + @staticmethod + def _resolve_settlement_mode( + mode: Optional[Union[x402SettlementMode, str]], + ) -> x402SettlementMode: + if mode is None: + return x402SettlementMode.SETTLE_BATCH + if isinstance(mode, x402SettlementMode): + return mode + + normalized = mode.strip() + try: + return x402SettlementMode(normalized) + except Exception: + # Handle strings like "x402SettlementMode.SETTLE_BATCH" + if normalized.startswith("x402SettlementMode."): + normalized = normalized.split(".", 1)[1] + by_name = getattr(x402SettlementMode, normalized, None) + if isinstance(by_name, x402SettlementMode): + return by_name + raise OpenGradientError(f"Invalid x402 settlement mode: {mode}") + def completion( self, - model: TEE_LLM, + model: Union[TEE_LLM, str], prompt: str, max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Perform inference on an LLM model using completions via TEE. @@ -232,7 +276,7 @@ def completion( OpenGradientError: If the inference fails. """ return self._tee_llm_completion( - model=model.split("/")[1], + model=self._resolve_model_id(model).split("/")[1], prompt=prompt, max_tokens=max_tokens, stop_sequence=stop_sequence, @@ -247,17 +291,18 @@ def _tee_llm_completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Route completion request to OpenGradient TEE LLM server with x402 payments. """ async def make_request_v2(): + settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": x402_settlement_mode.value, + "X-SETTLEMENT-TYPE": settlement_mode.value, } payload = { @@ -290,21 +335,24 @@ async def make_request_v2(): try: return self._run_coroutine(make_request_v2()) - except OpenGradientError: - raise except Exception as e: + if self._is_retryable_payment_error(e): + self._run_coroutine(self._refresh_payment_clients()) + return self._run_coroutine(make_request_v2()) + if isinstance(e, OpenGradientError): + raise raise OpenGradientError(f"TEE LLM completion failed: {str(e)}") def chat( self, - model: TEE_LLM, + model: Union[TEE_LLM, str], messages: List[Dict], max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, stream: bool = False, ) -> Union[TextGenerationOutput, TextGenerationStream]: """ @@ -333,10 +381,12 @@ def chat( Raises: OpenGradientError: If the inference fails. 
""" + resolved_model = self._resolve_model_id(model) + if stream: # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( - model=model.split("/")[1], + model=resolved_model.split("/")[1], messages=messages, max_tokens=max_tokens, stop_sequence=stop_sequence, @@ -348,7 +398,7 @@ def chat( else: # Non-streaming return self._tee_llm_chat( - model=model.split("/")[1], + model=resolved_model.split("/")[1], messages=messages, max_tokens=max_tokens, stop_sequence=stop_sequence, @@ -367,17 +417,18 @@ def _tee_llm_chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Route chat request to OpenGradient TEE LLM server with x402 payments. """ async def make_request_v2(): + settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": x402_settlement_mode.value, + "X-SETTLEMENT-TYPE": settlement_mode.value, } payload = { @@ -429,9 +480,12 @@ async def make_request_v2(): try: return self._run_coroutine(make_request_v2()) - except OpenGradientError: - raise except Exception as e: + if self._is_retryable_payment_error(e): + self._run_coroutine(self._refresh_payment_clients()) + return self._run_coroutine(make_request_v2()) + if isinstance(e, OpenGradientError): + raise raise OpenGradientError(f"TEE LLM chat failed: {str(e)}") def _tee_llm_chat_stream_sync( @@ -443,7 +497,7 @@ def _tee_llm_chat_stream_sync( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, ): """ Sync streaming using threading bridge - TRUE real-time streaming. 
@@ -456,17 +510,33 @@ def _tee_llm_chat_stream_sync( async def _stream(): try: - async for chunk in self._tee_llm_chat_stream_async( - model=model, - messages=messages, - max_tokens=max_tokens, - stop_sequence=stop_sequence, - temperature=temperature, - tools=tools, - tool_choice=tool_choice, - x402_settlement_mode=x402_settlement_mode, - ): - queue.put(chunk) + try: + async for chunk in self._tee_llm_chat_stream_async( + model=model, + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ): + queue.put(chunk) + except Exception as e: + if not self._is_retryable_payment_error(e): + raise + await self._refresh_payment_clients() + async for chunk in self._tee_llm_chat_stream_async( + model=model, + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ): + queue.put(chunk) except Exception as e: queue.put(e) finally: @@ -499,17 +569,18 @@ async def _tee_llm_chat_stream_async( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, ): """ Internal async streaming implementation for TEE LLM with x402 payments. Yields StreamChunk objects as they arrive from the server. """ + settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": x402_settlement_mode.value, + "X-SETTLEMENT-TYPE": settlement_mode.value, } payload = { From 0bfc95af2a086d9b18181eb0e1b08d88f5dddffd Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 15:37:00 +0530 Subject: [PATCH 3/6] cleanup --- src/opengradient/client/llm.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index c5d0234..64a5e0e 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -217,10 +217,6 @@ def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: raise ValueError("OPG amount must be at least 0.05.") return ensure_opg_approval(self._wallet_account, opg_amount) - @staticmethod - def _resolve_model_id(model: Union[TEE_LLM, str]) -> str: - return str(model) - @staticmethod def _resolve_settlement_mode( mode: Optional[Union[x402SettlementMode, str]], @@ -244,7 +240,7 @@ def _resolve_settlement_mode( def completion( self, - model: Union[TEE_LLM, str], + model: TEE_LLM, prompt: str, max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, @@ -276,7 +272,7 @@ def completion( OpenGradientError: If the inference fails. """ return self._tee_llm_completion( - model=self._resolve_model_id(model).split("/")[1], + model=model.split("/")[1], prompt=prompt, max_tokens=max_tokens, stop_sequence=stop_sequence, @@ -345,7 +341,7 @@ async def make_request_v2(): def chat( self, - model: Union[TEE_LLM, str], + model: TEE_LLM, messages: List[Dict], max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, @@ -381,12 +377,10 @@ def chat( Raises: OpenGradientError: If the inference fails. 
""" - resolved_model = self._resolve_model_id(model) - if stream: # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( - model=resolved_model.split("/")[1], + model=model.split("/")[1], messages=messages, max_tokens=max_tokens, stop_sequence=stop_sequence, @@ -398,7 +392,7 @@ def chat( else: # Non-streaming return self._tee_llm_chat( - model=resolved_model.split("/")[1], + model=model.split("/")[1], messages=messages, max_tokens=max_tokens, stop_sequence=stop_sequence, From f4df8ee6fc37b381480bc91f6bd21d0ba66d2cb3 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 15:59:19 +0530 Subject: [PATCH 4/6] updates --- src/opengradient/agents/og_langchain.py | 4 +-- src/opengradient/client/llm.py | 42 ++++++------------------- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 7b35f7a..cacd156 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -76,7 +76,7 @@ class OpenGradientChatModel(BaseChatModel): model_cid: Union[TEE_LLM, str] temperature: float = 0.0 max_tokens: int = 300 - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH _client: Client = PrivateAttr() _tools: List[Dict] = PrivateAttr(default_factory=list) @@ -88,7 +88,7 @@ def __init__( client: Optional[Client] = None, temperature: float = 0.0, max_tokens: int = 300, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, **kwargs, ): super().__init__( diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 64a5e0e..e628a4b 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -217,27 +217,6 @@ def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: raise ValueError("OPG amount must be at least 0.05.") return ensure_opg_approval(self._wallet_account, opg_amount) - @staticmethod - def _resolve_settlement_mode( - mode: Optional[Union[x402SettlementMode, str]], - ) -> x402SettlementMode: - if mode is None: - return x402SettlementMode.SETTLE_BATCH - if isinstance(mode, x402SettlementMode): - return mode - - normalized = mode.strip() - try: - return x402SettlementMode(normalized) - except Exception: - # Handle strings like "x402SettlementMode.SETTLE_BATCH" - if normalized.startswith("x402SettlementMode."): - normalized = normalized.split(".", 1)[1] - by_name = getattr(x402SettlementMode, normalized, None) - if isinstance(by_name, x402SettlementMode): - return by_name - raise OpenGradientError(f"Invalid x402 settlement mode: {mode}") - def completion( self, model: TEE_LLM, @@ -245,7 +224,7 @@ def completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Perform inference on an LLM model using completions via TEE. 
@@ -287,18 +266,17 @@ def _tee_llm_completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Route completion request to OpenGradient TEE LLM server with x402 payments. """ async def make_request_v2(): - settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": settlement_mode.value, + "X-SETTLEMENT-TYPE": x402_settlement_mode.value, } payload = { @@ -348,7 +326,7 @@ def chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, stream: bool = False, ) -> Union[TextGenerationOutput, TextGenerationStream]: """ @@ -411,18 +389,17 @@ def _tee_llm_chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ) -> TextGenerationOutput: """ Route chat request to OpenGradient TEE LLM server with x402 payments. """ async def make_request_v2(): - settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": settlement_mode.value, + "X-SETTLEMENT-TYPE": x402_settlement_mode.value, } payload = { @@ -491,7 +468,7 @@ def _tee_llm_chat_stream_sync( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ): """ Sync streaming using threading bridge - TRUE real-time streaming. @@ -563,18 +540,17 @@ async def _tee_llm_chat_stream_async( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[Union[x402SettlementMode, str]] = x402SettlementMode.SETTLE_BATCH, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, ): """ Internal async streaming implementation for TEE LLM with x402 payments. Yields StreamChunk objects as they arrive from the server. 
""" - settlement_mode = self._resolve_settlement_mode(x402_settlement_mode) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {X402_PLACEHOLDER_API_KEY}", - "X-SETTLEMENT-TYPE": settlement_mode.value, + "X-SETTLEMENT-TYPE": x402_settlement_mode.value, } payload = { From 5a69bf4e96247bbebb770065e3d53434a82954d8 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 16:10:03 +0530 Subject: [PATCH 5/6] updates --- tests/langchain_adapter_test.py | 127 ++++++++++++++++++++++++++++++-- 1 file changed, 122 insertions(+), 5 deletions(-) diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index 1671c7f..d607c63 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -11,7 +11,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call -from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode +from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, StreamUsage, TEE_LLM, TextGenerationOutput, x402SettlementMode @pytest.fixture @@ -26,7 +26,7 @@ def mock_client(): @pytest.fixture def model(mock_client): """Create an OpenGradientChatModel with a mocked client.""" - return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_5) + return OpenGradientChatModel(model_cid=TEE_LLM.GPT_5, private_key="0x" + "a" * 64) class TestOpenGradientChatModel: @@ -39,21 +39,41 @@ def test_initialization(self, model): def test_initialization_custom_max_tokens(self, mock_client): """Test model initializes with custom max_tokens.""" - model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, max_tokens=1000) + model = OpenGradientChatModel(model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, private_key="0x" + "a" * 64, max_tokens=1000) assert model.max_tokens == 1000 def test_initialization_custom_settlement_mode(self, mock_client): """Test model initializes with custom settlement mode.""" model = OpenGradientChatModel( - private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_5, + private_key="0x" + "a" * 64, x402_settlement_mode=x402SettlementMode.SETTLE, ) assert model.x402_settlement_mode == x402SettlementMode.SETTLE + def test_initialization_with_injected_client(self, mock_client): + """Test model can reuse an injected SDK client.""" + model = OpenGradientChatModel(model_cid="openai/gpt-4.1", client=mock_client) + assert model.model_cid == "openai/gpt-4.1" + assert model._client is mock_client + + def test_initialization_requires_client_or_private_key(self): + """Test model requires either a private key or SDK client.""" + with pytest.raises(ValueError, match="Either client or private_key must be provided."): + OpenGradientChatModel(model_cid=TEE_LLM.GPT_5) + + def test_initialization_rejects_client_and_private_key(self, mock_client): + """Test model rejects duplicate client configuration.""" + with pytest.raises(ValueError, match="Pass either client or private_key, not both."): + OpenGradientChatModel( + model_cid=TEE_LLM.GPT_5, + private_key="0x" + "a" * 64, + client=mock_client, + ) + def test_identifying_params(self, model): """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} + assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0} class TestGenerate: @@ -155,6 +175,88 @@ def test_empty_chat_output(self, model, mock_client): assert 
result.generations[0].message.content == "" + def test_stream_response(self, model, mock_client): + """Test _stream yields incremental chunks and usage metadata.""" + stream_chunks = iter( + [ + StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta(content="Hello"), + index=0, + finish_reason=None, + ) + ], + model="openai/gpt-5", + ), + StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta(content=" world"), + index=0, + finish_reason="stop", + ) + ], + model="openai/gpt-5", + usage=StreamUsage(prompt_tokens=10, completion_tokens=2, total_tokens=12), + is_final=True, + ), + ] + ) + mock_client.llm.chat.return_value = stream_chunks + + generations = list(model._stream([HumanMessage(content="Hi")])) + + assert len(generations) == 2 + assert generations[0].message.content == "Hello" + assert generations[1].message.content == " world" + assert generations[1].generation_info == {"finish_reason": "stop"} + assert generations[1].message.usage_metadata == { + "input_tokens": 10, + "output_tokens": 2, + "total_tokens": 12, + } + assert mock_client.llm.chat.call_args.kwargs["stream"] is True + + @pytest.mark.asyncio + async def test_astream_response(self, model, mock_client): + """Test _astream yields incremental chunks via async interface.""" + stream_chunks = iter( + [ + StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta(content="Hello"), + index=0, + finish_reason=None, + ) + ], + model="openai/gpt-5", + ), + StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta(content=" world"), + index=0, + finish_reason="stop", + ) + ], + model="openai/gpt-5", + is_final=True, + ), + ] + ) + mock_client.llm.chat.return_value = stream_chunks + + generations = [] + async for generation in model._astream([HumanMessage(content="Hi")]): + generations.append(generation) + + assert len(generations) == 2 + assert generations[0].message.content == "Hello" + assert generations[1].message.content == " world" + assert generations[1].generation_info == {"finish_reason": "stop"} + class TestMessageConversion: def test_converts_all_message_types(self, model, mock_client): @@ -199,6 +301,20 @@ def test_unsupported_message_type_raises(self, model, mock_client): with pytest.raises(ValueError, match="Unexpected message type"): model._generate([MagicMock(spec=[])]) + def test_accepts_dict_messages(self, model, mock_client): + """Test that dict messages are accepted for compatibility with existing routes.""" + mock_client.llm.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "ok"}, + ) + + model._generate([{"role": "user", "content": "Hi"}]) + + assert mock_client.llm.chat.call_args.kwargs["messages"] == [ + {"role": "user", "content": "Hi"} + ] + def test_passes_correct_params_to_client(self, model, mock_client): """Test that _generate passes model params correctly to the SDK client.""" mock_client.llm.chat.return_value = TextGenerationOutput( @@ -213,6 +329,7 @@ def test_passes_correct_params_to_client(self, model, mock_client): model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], + temperature=0.0, max_tokens=300, tools=[], x402_settlement_mode=x402SettlementMode.SETTLE_BATCH, From 10072735938a67d76f12b5e078320800e23598c0 Mon Sep 17 00:00:00 2001 From: Aniket Dixit Date: Thu, 5 Mar 2026 16:19:47 +0530 Subject: [PATCH 6/6] updates --- tests/langchain_adapter_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git 
a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index d607c63..d8ec348 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -1,6 +1,7 @@ import json import os import sys +import asyncio from unittest.mock import MagicMock, patch import pytest @@ -218,8 +219,7 @@ def test_stream_response(self, model, mock_client): } assert mock_client.llm.chat.call_args.kwargs["stream"] is True - @pytest.mark.asyncio - async def test_astream_response(self, model, mock_client): + def test_astream_response(self, model, mock_client): """Test _astream yields incremental chunks via async interface.""" stream_chunks = iter( [ @@ -248,9 +248,13 @@ async def test_astream_response(self, model, mock_client): ) mock_client.llm.chat.return_value = stream_chunks - generations = [] - async for generation in model._astream([HumanMessage(content="Hi")]): - generations.append(generation) + async def collect_generations(): + generations = [] + async for generation in model._astream([HumanMessage(content="Hi")]): + generations.append(generation) + return generations + + generations = asyncio.run(collect_generations()) assert len(generations) == 2 assert generations[0].message.content == "Hello"
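---

A minimal usage sketch of the adapter surface this series ends up with, for reviewers who want to exercise the new construction paths. Illustrative only: the OG_PRIVATE_KEY environment variable and the prompt text are assumptions, while langchain_adapter, Client, TEE_LLM, x402SettlementMode, and the "openai/gpt-4.1" model id all come from the patches and tests above.

import asyncio
import os

from langchain_core.messages import HumanMessage

from opengradient.agents import langchain_adapter
from opengradient.client import Client
from opengradient.types import TEE_LLM, x402SettlementMode

# Path 1: the adapter constructs its own SDK client from a private key.
chat_model = langchain_adapter(
    model_cid=TEE_LLM.GPT_5,
    private_key=os.environ["OG_PRIVATE_KEY"],  # assumption: placeholder env var
    temperature=0.0,
    max_tokens=300,
    x402_settlement_mode=x402SettlementMode.SETTLE_BATCH,
)

# Path 2 (new in PATCH 1/6): reuse an existing SDK client instead. Passing
# both client and private_key raises ValueError.
shared_client = Client(private_key=os.environ["OG_PRIVATE_KEY"])
chat_model = langchain_adapter(model_cid="openai/gpt-4.1", client=shared_client)

# _generate backs invoke(); _stream and _astream back LangChain's standard
# stream()/astream() entry points.
print(chat_model.invoke([HumanMessage(content="Say hi")]).content)

for chunk in chat_model.stream([HumanMessage(content="Say hi")]):
    print(chunk.content, end="", flush=True)


async def demo_astream() -> None:
    async for chunk in chat_model.astream([HumanMessage(content="Say hi")]):
        print(chunk.content, end="", flush=True)


asyncio.run(demo_astream())

Injecting a client this way lets several integrations share one wallet and one x402 payment session rather than each adapter building its own; under the hood, astream() reuses the synchronous stream through the asyncio.to_thread bridge added in PATCH 1/6.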