diff --git a/Makefile b/Makefile index acc46ee..cdb51c6 100644 --- a/Makefile +++ b/Makefile @@ -87,16 +87,16 @@ chat-stream: chat-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 chat-stream-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 \ + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 \ --stream .PHONY: install build publish check docs test utils_test client_test langchain_adapter_test opg_token_test integrationtest examples \ diff --git a/pyproject.toml b/pyproject.toml index 262637b..fd0428d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "opengradient" -version = "0.7.5" +version = "0.7.6" description = "Python SDK for OpenGradient decentralized model management & inference services" authors = [{name = "OpenGradient", email = "adam@vannalabs.ai"}] readme = "README.md" diff --git a/src/opengradient/abi/TEERegistry.abi b/src/opengradient/abi/TEERegistry.abi new file mode 100644 index 0000000..9dcf4a5 --- /dev/null +++ b/src/opengradient/abi/TEERegistry.abi @@ -0,0 +1,48 @@ +[ + { + "inputs": [], + "name": "getActiveTEEs", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getTEEsByType", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "getTEE", + "outputs": [ + { + "components": [ + {"internalType": "address", "name": "owner", "type": "address"}, + {"internalType": "address", "name": "paymentAddress", "type": "address"}, + {"internalType": "string", "name": "endpoint", "type": "string"}, + {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, + {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, + {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, + {"internalType": "uint8", "name": "teeType", "type": "uint8"}, + {"internalType": "bool", "name": "active", "type": "bool"}, + {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, + {"internalType": "uint256", "name": "lastUpdatedAt", "type": "uint256"} + ], + "internalType": "struct TEERegistry.TEEInfo", + "name": "", + "type": "tuple" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "isActive", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "stateMutability": "view", + "type": "function" + } +] diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index 979f676..650a86e 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -413,13 +413,29 @@ def completion( x402_settlement_mode=x402SettlementModes[x402_settlement_mode], ) - print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False) + print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False, result=completion_output) except Exception as e: click.echo(f"Error running LLM completion: {str(e)}") -def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True): +def _print_tee_info(tee_id, tee_endpoint, tee_payment_address): + """Print TEE node info if available.""" + if not any([tee_id, tee_endpoint, tee_payment_address]): + return + click.secho("TEE Node:", fg="magenta", bold=True) + if tee_endpoint: + click.echo(" Endpoint: ", nl=False) + click.secho(tee_endpoint, fg="magenta") + if tee_id: + click.echo(" TEE ID: ", nl=False) + click.secho(tee_id, fg="magenta") + if tee_payment_address: + click.echo(" Payment address: ", nl=False) + click.secho(tee_payment_address, fg="magenta") + + +def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True, result=None): click.secho("✅ LLM completion Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -435,6 +451,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True) click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("LLM Output:", fg="yellow", bold=True) click.echo() @@ -578,13 +597,13 @@ def chat( if stream: print_streaming_chat_result(model_cid, result, is_tee=True) else: - print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False) + print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False, result=result) except Exception as e: click.echo(f"Error running LLM chat inference: {str(e)}") -def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True): +def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True, result=None): click.secho("✅ LLM Chat Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -600,6 +619,9 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("Finish Reason: ", fg="yellow", bold=True) click.echo() @@ -608,16 +630,24 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.secho("Chat Output:", fg="yellow", bold=True) click.echo() for key, value in chat_output.items(): - if value is not None and value not in ("", "[]", []): + if value is None or value in ("", "[]", []): + continue + if key == "tool_calls": + # Format tool calls the same way as the streaming path + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in value: + fn = tool_call.get("function", {}) + click.echo(f" Function: {fn.get('name', '')}") + click.echo(f" Arguments: {fn.get('arguments', '')}") + elif key == "content" and isinstance(value, list): # Normalize list-of-blocks content (e.g. Gemini 3 thought signatures) - if key == "content" and isinstance(value, list): - text = " ".join( - block.get("text", "") for block in value - if isinstance(block, dict) and block.get("type") == "text" - ).strip() - click.echo(f"{key}: {text}") - else: - click.echo(f"{key}: {value}") + text = " ".join( + block.get("text", "") for block in value + if isinstance(block, dict) and block.get("type") == "text" + ).strip() + click.echo(f"{key}: {text}") + else: + click.echo(f"{key}: {value}") click.echo() @@ -641,20 +671,21 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): for chunk in stream: chunk_count += 1 - if chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - sys.stdout.write(content) - sys.stdout.flush() - content_parts.append(content) - - # Handle tool calls - if chunk.choices[0].delta.tool_calls: - sys.stdout.write("\n") - sys.stdout.flush() - click.secho("Tool Calls:", fg="yellow", bold=True) - for tool_call in chunk.choices[0].delta.tool_calls: - click.echo(f" Function: {tool_call['function']['name']}") - click.echo(f" Arguments: {tool_call['function']['arguments']}") + if chunk.choices: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + sys.stdout.write(content) + sys.stdout.flush() + content_parts.append(content) + + # Handle tool calls + if chunk.choices[0].delta.tool_calls: + sys.stdout.write("\n") + sys.stdout.flush() + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in chunk.choices[0].delta.tool_calls: + click.echo(f" Function: {tool_call['function']['name']}") + click.echo(f" Arguments: {tool_call['function']['arguments']}") # Print final info when stream completes if chunk.is_final: @@ -669,10 +700,12 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): click.echo(f" Total tokens: {chunk.usage.total_tokens}") click.echo() - if chunk.choices[0].finish_reason: + if chunk.choices and chunk.choices[0].finish_reason: click.echo("Finish reason: ", nl=False) click.secho(chunk.choices[0].finish_reason, fg="green") + _print_tee_info(chunk.tee_id, chunk.tee_endpoint, chunk.tee_payment_address) + click.echo("──────────────────────────────────────") click.echo(f"Chunks received: {chunk_count}") click.echo(f"Content length: {len(''.join(content_parts))} characters") diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index 8bb35e6..1af1dae 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -1,5 +1,6 @@ """Main Client class that unifies all OpenGradient service namespaces.""" +import logging from typing import Optional from web3 import Web3 @@ -7,15 +8,18 @@ from ..defaults import ( DEFAULT_API_URL, DEFAULT_INFERENCE_CONTRACT_ADDRESS, - DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, DEFAULT_RPC_URL, + DEFAULT_TEE_REGISTRY_ADDRESS, + DEFAULT_TEE_REGISTRY_RPC_URL, ) from .alpha import Alpha from .llm import LLM from .model_hub import ModelHub +from .tee_registry import TEERegistry from .twins import Twins +logger = logging.getLogger(__name__) + class Client: """ @@ -62,8 +66,10 @@ def __init__( rpc_url: str = DEFAULT_RPC_URL, api_url: str = DEFAULT_API_URL, contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, - og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, + og_llm_server_url: Optional[str] = None, + og_llm_streaming_server_url: Optional[str] = None, + tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, + tee_registry_rpc_url: str = DEFAULT_TEE_REGISTRY_RPC_URL, ): """ Initialize the OpenGradient client. @@ -74,6 +80,11 @@ def __init__( You can supply a separate ``alpha_private_key`` so each chain uses its own funded wallet. When omitted, ``private_key`` is used for both. + By default the LLM server endpoint and its TLS certificate are fetched from + the on-chain TEE Registry, which stores certificates that were verified during + enclave attestation. You can override the endpoint by passing + ``og_llm_server_url`` explicitly (the system CA bundle is used for that URL). + Args: private_key: Private key whose wallet holds **Base Sepolia OPG tokens** for x402 LLM payments. @@ -86,8 +97,15 @@ def __init__( rpc_url: RPC URL for the OpenGradient Alpha Testnet. api_url: API URL for the OpenGradient API. contract_address: Inference contract address. - og_llm_server_url: OpenGradient LLM server URL. - og_llm_streaming_server_url: OpenGradient LLM streaming server URL. + og_llm_server_url: Override the LLM server URL instead of using the + registry-discovered endpoint. When set, the TLS certificate is + validated against the system CA bundle rather than the registry. + og_llm_streaming_server_url: Override the LLM streaming server URL. + Defaults to ``og_llm_server_url`` when that is provided. + tee_registry_address: Address of the TEERegistry contract used to + discover active LLM proxy endpoints and their verified TLS certs. + tee_registry_rpc_url: RPC endpoint for the chain that hosts the + TEERegistry contract. """ blockchain = Web3(Web3.HTTPProvider(rpc_url)) wallet_account = blockchain.eth.account.from_key(private_key) @@ -102,6 +120,38 @@ def __init__( if email is not None: hub_user = ModelHub._login_to_hub(email, password) + # Resolve LLM server URL and TLS certificate. + # If the caller provided explicit URLs, use those with standard CA verification. + # Otherwise, discover the endpoint and registry-verified cert from the TEE Registry. + llm_tls_cert_der: Optional[bytes] = None + tee = None + if og_llm_server_url is None: + try: + registry = TEERegistry( + rpc_url=tee_registry_rpc_url, + registry_address=tee_registry_address, + ) + tee = registry.get_llm_tee() + if tee is not None: + og_llm_server_url = tee.endpoint + og_llm_streaming_server_url = og_llm_streaming_server_url or tee.endpoint + llm_tls_cert_der = tee.tls_cert_der + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + else: + raise ValueError( + "No active LLM proxy TEE found in the registry. " + "Pass og_llm_server_url explicitly to override." + ) + except ValueError: + raise + except Exception as e: + raise RuntimeError( + f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address}): {e}. " + "Pass og_llm_server_url explicitly to override." + ) from e + else: + og_llm_streaming_server_url = og_llm_streaming_server_url or og_llm_server_url + # Create namespaces self.model_hub = ModelHub(hub_user=hub_user) self.wallet_address = wallet_account.address @@ -110,6 +160,9 @@ def __init__( wallet_account=wallet_account, og_llm_server_url=og_llm_server_url, og_llm_streaming_server_url=og_llm_streaming_server_url, + tls_cert_der=llm_tls_cert_der, + tee_id=tee.tee_id if tee is not None else None, + tee_payment_address=tee.payment_address if tee is not None else None, ) self.alpha = Alpha( diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index f95acd3..9a5db0e 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,13 +2,13 @@ import asyncio import json +import logging +import ssl import threading from queue import Queue from typing import AsyncGenerator, Dict, List, Optional, Union -import ssl -import socket -import tempfile -from urllib.parse import urlparse + +logger = logging.getLogger(__name__) import httpx from eth_account.account import LocalAccount @@ -18,9 +18,10 @@ from x402v2.mechanisms.evm.exact.register import register_exact_evm_client as register_exact_evm_clientv2 from x402v2.mechanisms.evm.upto.register import register_upto_evm_client as register_upto_evm_clientv2 -from ..types import TEE_LLM, StreamChunk, TextGenerationOutput, TextGenerationStream, x402SettlementMode +from ..types import TEE_LLM, StreamChunk, StreamChoice, StreamDelta, TextGenerationOutput, TextGenerationStream, x402SettlementMode from .exceptions import OpenGradientError from .opg_token import Permit2ApprovalResult, ensure_opg_approval +from .tee_registry import build_ssl_context_from_der X402_PROCESSING_HASH_HEADER = "x-processing-hash" X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" @@ -40,53 +41,6 @@ ) -def _fetch_tls_cert_as_ssl_context(server_url: str) -> Optional[ssl.SSLContext]: - """ - Connect to a server, retrieve its TLS certificate (TOFU), - and return an ssl.SSLContext that trusts ONLY that certificate. - - Hostname verification is disabled because the TEE server's cert - is typically issued for a hostname but we may connect via IP address. - The pinned certificate itself provides the trust anchor. - - Returns None if the server is not HTTPS or unreachable. - """ - parsed = urlparse(server_url) - if parsed.scheme != "https": - return None - - hostname = parsed.hostname - port = parsed.port or 443 - - # Connect without verification to retrieve the server's certificate - fetch_ctx = ssl.create_default_context() - fetch_ctx.check_hostname = False - fetch_ctx.verify_mode = ssl.CERT_NONE - - try: - with socket.create_connection((hostname, port), timeout=10) as sock: - with fetch_ctx.wrap_socket(sock, server_hostname=hostname) as ssock: - der_cert = ssock.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) - except Exception: - return None - - # Write PEM to a temp file so we can load it into the SSLContext - cert_file = tempfile.NamedTemporaryFile( - prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w" - ) - cert_file.write(pem_cert) - cert_file.flush() - cert_file.close() - - # Build an SSLContext that trusts ONLY this cert, with hostname check disabled - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.load_verify_locations(cert_file.name) - ctx.check_hostname = False # Cert is for a hostname, but we connect via IP - ctx.verify_mode = ssl.CERT_REQUIRED # Still verify the cert itself - return ctx - - class LLM: """ LLM inference namespace. @@ -110,17 +64,34 @@ class LLM: result = client.llm.completion(model=TEE_LLM.CLAUDE_3_5_HAIKU, prompt="Hello") """ - def __init__(self, wallet_account: LocalAccount, og_llm_server_url: str, og_llm_streaming_server_url: str): + def __init__( + self, + wallet_account: LocalAccount, + og_llm_server_url: str, + og_llm_streaming_server_url: str, + tls_cert_der: Optional[bytes] = None, + tee_id: Optional[str] = None, + tee_payment_address: Optional[str] = None, + ): self._wallet_account = wallet_account self._og_llm_server_url = og_llm_server_url self._og_llm_streaming_server_url = og_llm_streaming_server_url - self._tls_verify: Union[ssl.SSLContext, bool] = ( - _fetch_tls_cert_as_ssl_context(self._og_llm_server_url) or True - ) - self._streaming_tls_verify: Union[ssl.SSLContext, bool] = ( - _fetch_tls_cert_as_ssl_context(self._og_llm_streaming_server_url) or True - ) + # TEE metadata surfaced on every response so callers can verify/audit which + # enclave served the request. + self._tee_id = tee_id + self._tee_endpoint = og_llm_server_url + self._tee_payment_address = tee_payment_address + + if tls_cert_der: + # Use the registry-verified certificate as the sole trust anchor. + ssl_ctx = build_ssl_context_from_der(tls_cert_der) + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + self._streaming_tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + else: + # No cert from registry — fall back to default system CA verification. + self._tls_verify = True + self._streaming_tls_verify = True signer = EthAccountSignerv2(self._wallet_account) self._x402_client = x402Clientv2() @@ -283,6 +254,9 @@ async def make_request_v2(): completion_output=result.get("completion"), tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -334,6 +308,20 @@ def chat( OpenGradientError: If the inference fails. """ if stream: + if tools: + # The TEE streaming endpoint omits tool call content from SSE events. + # Fall back transparently to the non-streaming endpoint and emit a + # single final StreamChunk so callers get the complete tool call data. + return self._tee_llm_chat_tools_as_stream( + model=model.split("/")[1], + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( model=model.split("/")[1], @@ -422,6 +410,9 @@ async def make_request_v2(): chat_output=message, tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -434,6 +425,59 @@ async def make_request_v2(): except Exception as e: raise OpenGradientError(f"TEE LLM chat failed: {str(e)}") + def _tee_llm_chat_tools_as_stream( + self, + model: str, + messages: List[Dict], + max_tokens: int = 100, + stop_sequence: Optional[List[str]] = None, + temperature: float = 0.0, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, + ): + """ + Transparent non-streaming fallback for tool-call requests with stream=True. + + The TEE streaming endpoint returns an empty delta when tools are present — + tool call content is not emitted as SSE events. This method calls the + non-streaming endpoint instead and emits a single final StreamChunk that + carries the complete tool call response, preserving the streaming interface + for callers (including the CLI). + """ + result = self._tee_llm_chat( + model=model, + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) + + chat_output = result.chat_output or {} + delta = StreamDelta( + role=chat_output.get("role"), + content=chat_output.get("content"), + tool_calls=chat_output.get("tool_calls"), + ) + choice = StreamChoice( + delta=delta, + index=0, + finish_reason=result.finish_reason, + ) + yield StreamChunk( + choices=[choice], + model=model, + is_final=True, + tee_signature=result.tee_signature, + tee_timestamp=result.tee_timestamp, + tee_id=result.tee_id, + tee_endpoint=result.tee_endpoint, + tee_payment_address=result.tee_payment_address, + ) + def _tee_llm_chat_stream_sync( self, model: str, @@ -560,7 +604,12 @@ async def _parse_sse_response(response) -> AsyncGenerator[StreamChunk, None]: try: data = json.loads(data_str) - yield StreamChunk.from_sse_data(data) + chunk = StreamChunk.from_sse_data(data) + if chunk.is_final: + chunk.tee_id = self._tee_id + chunk.tee_endpoint = self._tee_endpoint + chunk.tee_payment_address = self._tee_payment_address + yield chunk except json.JSONDecodeError: continue diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py new file mode 100644 index 0000000..83e92b0 --- /dev/null +++ b/src/opengradient/client/tee_registry.py @@ -0,0 +1,152 @@ +"""TEE Registry client for fetching verified TEE endpoints and TLS certificates.""" + +import logging +import ssl +import tempfile +from dataclasses import dataclass +from typing import List, Optional + +from web3 import Web3 + +from ._utils import get_abi + +logger = logging.getLogger(__name__) + +# TEE types as defined in the registry contract +TEE_TYPE_LLM_PROXY = 0 +TEE_TYPE_VALIDATOR = 1 + + +@dataclass +class TEEEndpoint: + """A verified TEE with its endpoint URL and TLS certificate from the registry.""" + + tee_id: str + endpoint: str + tls_cert_der: bytes + payment_address: str + + +class TEERegistry: + """ + Queries the on-chain TEE Registry contract to retrieve verified TEE endpoints + and their TLS certificates. + + Instead of blindly trusting the TLS certificate presented by a TEE server + (TOFU), this class fetches the certificate that was submitted and verified + during TEE registration. Any certificate that does not match the one stored + in the registry should be rejected. + + Args: + rpc_url: RPC endpoint for the chain where the registry is deployed. + registry_address: Address of the deployed TEERegistry contract. + """ + + def __init__(self, rpc_url: str, registry_address: str): + self._web3 = Web3(Web3.HTTPProvider(rpc_url)) + abi = get_abi("TEERegistry.abi") + self._contract = self._web3.eth.contract( + address=Web3.to_checksum_address(registry_address), + abi=abi, + ) + + def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: + """ + Return all active TEEs of the given type with their endpoints and TLS certs. + + Args: + tee_type: Integer TEE type (0=LLMProxy, 1=Validator). + + Returns: + List of TEEEndpoint objects for active TEEs of that type. + """ + type_label = {TEE_TYPE_LLM_PROXY: "LLMProxy", TEE_TYPE_VALIDATOR: "Validator"}.get(tee_type, str(tee_type)) + + try: + tee_ids: List[bytes] = self._contract.functions.getTEEsByType(tee_type).call() + except Exception as e: + logger.warning("Failed to fetch TEE IDs from registry (type=%s): %s", type_label, e) + return [] + + logger.debug("Registry returned %d TEE ID(s) for type=%s", len(tee_ids), type_label) + + endpoints: List[TEEEndpoint] = [] + for tee_id in tee_ids: + tee_id_hex = tee_id.hex() + try: + info = self._contract.functions.getTEE(tee_id).call() + # TEEInfo tuple order: owner, paymentAddress, endpoint, publicKey, + # tlsCertificate, pcrHash, teeType, active, + # registeredAt, lastUpdatedAt + owner, payment_address, endpoint, _pub_key, tls_cert_der, _pcr_hash, _tee_type, active, _reg_at, _upd_at = info + if not active: + logger.debug(" teeId=%s status=inactive endpoint=%s (skipped)", tee_id_hex, endpoint) + continue + if not endpoint or not tls_cert_der: + logger.warning(" teeId=%s missing endpoint or TLS cert (skipped)", tee_id_hex) + continue + logger.info( + " teeId=%s endpoint=%s paymentAddress=%s certBytes=%d", + tee_id_hex, + endpoint, + payment_address, + len(tls_cert_der), + ) + endpoints.append( + TEEEndpoint( + tee_id=tee_id_hex, + endpoint=endpoint, + tls_cert_der=bytes(tls_cert_der), + payment_address=payment_address, + ) + ) + except Exception as e: + logger.warning("Failed to fetch TEE info for teeId=%s: %s", tee_id_hex, e) + + logger.info("Discovered %d active %s TEE(s) from registry", len(endpoints), type_label) + return endpoints + + def get_llm_tee(self) -> Optional[TEEEndpoint]: + """ + Return the first active LLM proxy TEE from the registry. + + Returns: + TEEEndpoint for an active LLM proxy TEE, or None if none are available. + """ + logger.debug("Querying TEE registry for active LLM proxy TEEs...") + tees = self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if tees: + logger.info("Selected LLM TEE: endpoint=%s teeId=%s", tees[0].endpoint, tees[0].tee_id) + else: + logger.warning("No active LLM proxy TEEs found in registry") + return tees[0] if tees else None + + +def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: + """ + Build an ssl.SSLContext that trusts *only* the given DER-encoded certificate. + + Hostname verification is disabled because TEE servers are typically addressed + by IP while the cert may be issued for a different hostname. The pinned + certificate itself is the trust anchor — only that cert is accepted. + + Args: + der_cert: DER-encoded X.509 certificate bytes as stored in the registry. + + Returns: + ssl.SSLContext configured to accept only the pinned certificate. + """ + pem = ssl.DER_cert_to_PEM_cert(der_cert) + + cert_file = tempfile.NamedTemporaryFile( + prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w" + ) + cert_file.write(pem) + cert_file.flush() + cert_file.close() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(cert_file.name) + ctx.check_hostname = False # TEE cert may be issued for a hostname; we connect via IP + ctx.verify_mode = ssl.CERT_REQUIRED + return ctx diff --git a/src/opengradient/defaults.py b/src/opengradient/defaults.py index c053225..1df851e 100644 --- a/src/opengradient/defaults.py +++ b/src/opengradient/defaults.py @@ -6,6 +6,7 @@ DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE" DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6" DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/" -# TODO (Kyle): Add a process to fetch these IPs from the TEE registry -DEFAULT_OPENGRADIENT_LLM_SERVER_URL = "https://3.15.214.21:443" -DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL = "https://3.15.214.21:443" +# TEE Registry contract on the OG EVM chain — used to discover LLM proxy endpoints +# and fetch their registry-verified TLS certificates instead of blindly trusting TOFU. +DEFAULT_TEE_REGISTRY_ADDRESS = "0x3d641a2791533b4a0000345ea8d509d01e1ec301" +DEFAULT_TEE_REGISTRY_RPC_URL = "http://13.59.43.94:8545" diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 3a13aac..1a703c4 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -241,6 +241,9 @@ class StreamChunk: is_final: Whether this is the final chunk (before [DONE]) tee_signature: RSA-PSS signature over the response, present on the final chunk tee_timestamp: ISO timestamp from the TEE at signing time, present on the final chunk + tee_id: On-chain TEE registry ID of the enclave that served this request (final chunk only) + tee_endpoint: Endpoint URL of the TEE that served this request (final chunk only) + tee_payment_address: Payment address registered for the TEE (final chunk only) """ choices: List[StreamChoice] @@ -249,6 +252,9 @@ class StreamChunk: is_final: bool = False tee_signature: Optional[str] = None tee_timestamp: Optional[str] = None + tee_id: Optional[str] = None + tee_endpoint: Optional[str] = None + tee_payment_address: Optional[str] = None @classmethod def from_sse_data(cls, data: Dict) -> "StreamChunk": @@ -263,7 +269,9 @@ def from_sse_data(cls, data: Dict) -> "StreamChunk": """ choices = [] for choice_data in data.get("choices", []): - delta_data = choice_data.get("delta", {}) + # The TEE proxy sometimes sends SSE events using the non-streaming "message" + # key instead of the standard streaming "delta" key. Fall back gracefully. + delta_data = choice_data.get("delta") or choice_data.get("message") or {} delta = StreamDelta(content=delta_data.get("content"), role=delta_data.get("role"), tool_calls=delta_data.get("tool_calls")) choice = StreamChoice(delta=delta, index=choice_data.get("index", 0), finish_reason=choice_data.get("finish_reason")) choices.append(choice) @@ -396,6 +404,15 @@ class TextGenerationOutput: tee_timestamp: Optional[str] = None """ISO timestamp from the TEE at signing time.""" + tee_id: Optional[str] = None + """On-chain TEE registry ID (keccak256 of the enclave's public key) of the TEE that served this request.""" + + tee_endpoint: Optional[str] = None + """Endpoint URL of the TEE that served this request, as registered on-chain.""" + + tee_payment_address: Optional[str] = None + """Payment address registered for the TEE that served this request.""" + @dataclass class AbiFunction: diff --git a/tests/client_test.py b/tests/client_test.py index f17283b..721f382 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -21,7 +21,10 @@ @pytest.fixture def mock_web3(): """Create a mock Web3 instance.""" - with patch("src.opengradient.client.client.Web3") as mock: + with ( + patch("src.opengradient.client.client.Web3") as mock, + patch("src.opengradient.client.client.TEERegistry") as mock_tee_registry, + ): mock_instance = MagicMock() mock.return_value = mock_instance mock.HTTPProvider.return_value = MagicMock() @@ -31,6 +34,14 @@ def mock_web3(): mock_instance.eth.gas_price = 1000000000 mock_instance.eth.contract.return_value = MagicMock() + # Return a fake active TEE endpoint so Client.__init__ doesn't need a live registry + mock_tee = MagicMock() + mock_tee.endpoint = "https://test.tee.server" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "test-tee-id" + mock_tee.payment_address = "0xTestPaymentAddress" + mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee + yield mock_instance @@ -194,7 +205,7 @@ def test_llm_completion_success(self, client): ) result = client.llm.completion( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, prompt="Hello", max_tokens=100, ) @@ -215,7 +226,7 @@ def test_llm_chat_success_non_streaming(self, client): ) result = client.llm.chat( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hello"}], stream=False, ) @@ -233,7 +244,7 @@ def test_llm_chat_streaming(self, client): mock_stream.return_value = iter(mock_chunks) result = client.llm.chat( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hello"}], stream=True, )