13 changes: 10 additions & 3 deletions src/opengradient/agents/__init__.py
@@ -6,13 +6,18 @@
into existing applications and agent frameworks.
"""

from typing import Optional, Union

from ..client import Client
from ..types import TEE_LLM, x402SettlementMode
from .og_langchain import *


def langchain_adapter(
private_key: str,
model_cid: TEE_LLM,
model_cid: Union[TEE_LLM, str],
private_key: Optional[str] = None,
client: Optional[Client] = None,
temperature: float = 0.0,
max_tokens: int = 300,
x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
) -> OpenGradientChatModel:
@@ -21,8 +26,10 @@ def langchain_adapter(
and can be plugged into LangChain agents.
"""
return OpenGradientChatModel(
private_key=private_key,
model_cid=model_cid,
private_key=private_key,
client=client,
temperature=temperature,
max_tokens=max_tokens,
x402_settlement_mode=x402_settlement_mode,
)
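
Below is a minimal usage sketch of the updated langchain_adapter entry point (not part of this diff). It assumes the package is importable as opengradient; the model CID and private key values are placeholders, and the Client import path is inferred from the relative import above.

from opengradient.agents import langchain_adapter
from opengradient.client import Client  # path assumed from the "from ..client import Client" above

# Option 1: let the adapter construct its own Client from a private key.
chat_model = langchain_adapter(
    model_cid="example/model-cid",   # plain string CIDs are now accepted alongside TEE_LLM values
    private_key="0x...",             # placeholder key
    temperature=0.0,
    max_tokens=300,
)

# Option 2: reuse an existing Client instance. Passing both client and
# private_key (or neither) raises a ValueError in OpenGradientChatModel.
shared_client = Client(private_key="0x...")
chat_model = langchain_adapter(model_cid="example/model-cid", client=shared_client)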
170 changes: 135 additions & 35 deletions src/opengradient/agents/og_langchain.py
@@ -1,12 +1,14 @@
# mypy: ignore-errors
import asyncio
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Sequence, Union

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
@@ -15,6 +17,7 @@
from langchain_core.messages.tool import ToolMessage
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.runnables import Runnable
@@ -25,6 +28,7 @@
from ..types import TEE_LLM, x402SettlementMode

__all__ = ["OpenGradientChatModel"]
_STREAM_END = object()


def _extract_content(content: Any) -> str:
@@ -69,28 +73,38 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall:
class OpenGradientChatModel(BaseChatModel):
"""OpenGradient adapter class for LangChain chat model"""

model_cid: str
model_cid: Union[TEE_LLM, str]
temperature: float = 0.0
max_tokens: int = 300
x402_settlement_mode: Optional[str] = x402SettlementMode.SETTLE_BATCH
x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH

_client: Client = PrivateAttr()
_tools: List[Dict] = PrivateAttr(default_factory=list)

def __init__(
self,
private_key: str,
model_cid: TEE_LLM,
model_cid: Union[TEE_LLM, str],
private_key: Optional[str] = None,
client: Optional[Client] = None,
temperature: float = 0.0,
max_tokens: int = 300,
x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.SETTLE_BATCH,
x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH,
**kwargs,
):
super().__init__(
model_cid=model_cid,
temperature=temperature,
max_tokens=max_tokens,
x402_settlement_mode=x402_settlement_mode,
**kwargs,
)
self._client = Client(private_key=private_key)
if client is not None and private_key is not None:
raise ValueError("Pass either client or private_key, not both.")
if client is None:
if private_key is None:
raise ValueError("Either client or private_key must be provided.")
client = Client(private_key=private_key)
self._client = client

@property
def _llm_type(self) -> str:
@@ -136,7 +150,103 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
sdk_messages = []
sdk_messages = self._to_sdk_messages(messages)

chat_output = self._client.llm.chat(
model=self.model_cid,
messages=sdk_messages,
stop_sequence=stop,
temperature=self.temperature,
max_tokens=self.max_tokens,
tools=self._tools,
x402_settlement_mode=self.x402_settlement_mode,
)

finish_reason = chat_output.finish_reason or ""
chat_response = chat_output.chat_output or {}

if chat_response.get("tool_calls"):
tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]]
ai_message = AIMessage(content="", tool_calls=tool_calls)
else:
ai_message = AIMessage(content=_extract_content(chat_response.get("content", "")))

return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})])

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
sdk_messages = self._to_sdk_messages(messages)
stream = self._client.llm.chat(
model=self.model_cid,
messages=sdk_messages,
stop_sequence=stop,
temperature=self.temperature,
max_tokens=self.max_tokens,
tools=self._tools,
x402_settlement_mode=self.x402_settlement_mode,
stream=True,
)

for chunk in stream:
if not chunk.choices:
continue

choice = chunk.choices[0]
delta = choice.delta
content = _extract_content(delta.content)

additional_kwargs: Dict[str, Any] = {}
if delta.tool_calls:
additional_kwargs["tool_calls"] = delta.tool_calls

chunk_kwargs: Dict[str, Any] = {
"content": content,
"additional_kwargs": additional_kwargs,
}
if chunk.usage:
chunk_kwargs["usage_metadata"] = {
"input_tokens": chunk.usage.prompt_tokens,
"output_tokens": chunk.usage.completion_tokens,
"total_tokens": chunk.usage.total_tokens,
}

generation_info = {"finish_reason": choice.finish_reason} if choice.finish_reason else None
yield ChatGenerationChunk(
message=AIMessageChunk(**chunk_kwargs),
generation_info=generation_info,
)

async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
# Bridge the sync iterator from the SDK to LangChain's async streaming API.
iterator = self._stream(messages=messages, stop=stop, **kwargs)
while True:
# Use next(..., default) so StopIteration does not cross Future boundaries.
chunk = await asyncio.to_thread(next, iterator, _STREAM_END)
if chunk is _STREAM_END:
break
yield chunk

@property
def _identifying_params(self) -> Dict[str, Any]:
return {
"model_name": self.model_cid,
"temperature": self.temperature,
}

@staticmethod
def _to_sdk_messages(messages: List[Any]) -> List[Dict[str, Any]]:
sdk_messages: List[Dict[str, Any]] = []
for message in messages:
if isinstance(message, SystemMessage):
sdk_messages.append({"role": "system", "content": _extract_content(message.content)})
@@ -162,31 +272,21 @@ def _generate(
"tool_call_id": message.tool_call_id,
}
)
else:
raise ValueError(f"Unexpected message type: {message}")
elif isinstance(message, dict):
role = message.get("role")
if role not in {"system", "user", "assistant", "tool"}:
raise ValueError(f"Unexpected message role in dict message: {role}")

chat_output = self._client.llm.chat(
model=self.model_cid,
messages=sdk_messages,
stop_sequence=stop,
max_tokens=self.max_tokens,
tools=self._tools,
x402_settlement_mode=self.x402_settlement_mode,
)

finish_reason = chat_output.finish_reason or ""
chat_response = chat_output.chat_output or {}

if chat_response.get("tool_calls"):
tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]]
ai_message = AIMessage(content="", tool_calls=tool_calls)
else:
ai_message = AIMessage(content=_extract_content(chat_response.get("content", "")))
sdk_message: Dict[str, Any] = {
"role": role,
"content": _extract_content(message.get("content", "")),
}
if role == "assistant" and message.get("tool_calls"):
sdk_message["tool_calls"] = message["tool_calls"]
if role == "tool" and message.get("tool_call_id"):
sdk_message["tool_call_id"] = message["tool_call_id"]

return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})])

@property
def _identifying_params(self) -> Dict[str, Any]:
return {
"model_name": self.model_cid,
}
sdk_messages.append(sdk_message)
else:
raise ValueError(f"Unexpected message type: {message}")
return sdk_messages
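
The new _stream and _astream implementations surface through LangChain's standard streaming API. A minimal sketch (not part of this diff), reusing the adapter above with placeholder credentials:

import asyncio

from langchain_core.messages import HumanMessage

from opengradient.agents import langchain_adapter

model = langchain_adapter(model_cid="example/model-cid", private_key="0x...")  # placeholders

# Synchronous streaming dispatches to _stream() above.
for chunk in model.stream([HumanMessage(content="Hello")]):
    print(chunk.content, end="", flush=True)

# Asynchronous streaming dispatches to _astream(), which drives the same sync
# iterator on a worker thread via asyncio.to_thread.
async def main() -> None:
    async for chunk in model.astream([HumanMessage(content="Hello")]):
        print(chunk.content, end="", flush=True)

asyncio.run(main())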