diff --git a/README.md b/README.md index 5008ca0..c939db5 100644 --- a/README.md +++ b/README.md @@ -832,6 +832,37 @@ The `IAudioInference` class supports the following parameters: - `duration`: Duration of the generated audio in seconds - `includeCost`: Whether to include cost information in the response +### Text inference streaming + +To stream text inference (e.g. LLM chat) over HTTP SSE, set `deliveryMethod="stream"`. The SDK yields content chunks (strings) and a final `IText` with usage and cost: + +```python +import asyncio +from runware import Runware, ITextInference, ITextInferenceMessage + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request = ITextInference( + model="runware:qwen3-thinking@1", + messages=[ITextInferenceMessage(role="user", content="Explain photosynthesis in one sentence.")], + deliveryMethod="stream", + includeCost=True, + ) + + stream = await runware.textInference(request) + async for chunk in stream: + if isinstance(chunk, str): + print(chunk, end="", flush=True) + else: + print(chunk) + +asyncio.run(main()) +``` + +Streaming uses the same concurrency limit as other requests (`RUNWARE_MAX_CONCURRENT_REQUESTS`). To allow longer streams, set `RUNWARE_TEXT_STREAM_TIMEOUT` (milliseconds; default 600000). + ### Model Upload To upload model using the Runware API, you can use the `uploadModel` method of the `Runware` class. 
Here are examples: @@ -1068,6 +1099,9 @@ RUNWARE_AUDIO_INFERENCE_TIMEOUT=300000 # Audio generation (default: 5 min) RUNWARE_AUDIO_POLLING_DELAY=1000 # Delay between status checks (default: 1 sec) RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts for audio inference (default: 240, ~4 min total) +# Text Operations (milliseconds) +RUNWARE_TEXT_STREAM_TIMEOUT=600000 # Text inference streaming (SSE) read timeout (default: 10 min) + # Other Operations (milliseconds) RUNWARE_PROMPT_ENHANCE_TIMEOUT=60000 # Prompt enhancement (default: 1 min) RUNWARE_WEBHOOK_TIMEOUT=30000 # Webhook acknowledgment (default: 30 sec) diff --git a/requirements.txt b/requirements.txt index 611060a..f202503 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ aiofiles==23.2.1 +httpx>=0.27.0 python-dotenv==1.0.1 websockets>=12.0 \ No newline at end of file diff --git a/runware/base.py b/runware/base.py index 481b273..bb47a72 100644 --- a/runware/base.py +++ b/runware/base.py @@ -1,5 +1,6 @@ import asyncio import inspect +import json import logging import os import re @@ -7,8 +8,9 @@ from dataclasses import asdict, is_dataclass, fields from enum import Enum from random import uniform -from typing import List, Optional, Union, Callable, Any, Dict, Tuple +from typing import List, Optional, Union, Callable, Any, Dict, Tuple, AsyncIterator +import httpx from websockets.protocol import State from .logging_config import configure_logging @@ -59,11 +61,14 @@ IUploadMediaRequest, ITextInference, IText, + ITextInferenceUsage, + ITextInputs, ) from .types import IImage, IError, SdkType, ListenerType from .utils import ( BASE_RUNWARE_URLS, getUUID, + get_http_url_from_ws_url, fileToBase64, createImageFromResponse, createImageToTextFromResponse, @@ -82,6 +87,7 @@ createAsyncTaskResponse, VIDEO_INITIAL_TIMEOUT, TEXT_INITIAL_TIMEOUT, + TEXT_STREAM_READ_TIMEOUT, VIDEO_POLLING_DELAY, WEBHOOK_TIMEOUT, IMAGE_INFERENCE_TIMEOUT, @@ -2018,7 +2024,20 @@ async def _inference3d(self, request3d: 
async def _requestTextStream(
    self, requestText: ITextInference
) -> AsyncIterator[Union[str, IText]]:
    """Stream a text-inference task over HTTP Server-Sent Events.

    Posts the task to the REST endpoint derived from the configured
    WebSocket URL and yields content chunks (``str``) as they arrive,
    then a final ``IText`` carrying the finish reason, token usage and
    (when requested) cost.

    Args:
        requestText: The text-inference request; ``taskUUID`` is
            generated when absent.

    Yields:
        ``str`` content deltas, then one terminal ``IText``.

    Raises:
        RunwareAPIError: with the structured API error payload when the
            server reports an error, or with ``{"message": ...}`` for
            transport-level failures.
    """
    requestText.taskUUID = requestText.taskUUID or getUUID()
    request_object = self._buildTextRequest(requestText)
    body = [request_object]
    http_url = get_http_url_from_ws_url(self._url or "")
    headers = {
        "Accept": "text/event-stream",
        "Authorization": f"Bearer {self._apiKey}",
        "Content-Type": "application/json",
    }
    try:
        # TEXT_STREAM_READ_TIMEOUT is in milliseconds; httpx expects seconds.
        async with httpx.AsyncClient(timeout=TEXT_STREAM_READ_TIMEOUT / 1000) as client:
            async with client.stream(
                "POST",
                http_url,
                json=body,
                headers=headers,
            ) as response:
                response.raise_for_status()
                async for raw_line in response.aiter_lines():
                    # SSE frames look like "data: {...}". Skip anything that
                    # is not a JSON object (keep-alives, blanks, "[DONE]").
                    try:
                        event = json.loads(raw_line.replace("data:", "", 1))
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(event, dict):
                        continue
                    data = event.get("data") or event
                    if not isinstance(data, dict):
                        continue
                    if data.get("error") is not None:
                        # Raise with the structured payload; re-raised
                        # untouched below so callers keep the details.
                        raise RunwareAPIError(data["error"])
                    choice = (data.get("choices") or [{}])[0]
                    delta = choice.get("delta") or {}
                    content = delta.get("content")
                    if content:
                        yield content
                    if choice.get("finish_reason") is not None:
                        # Guard: the final frame may omit "usage".
                        usage_payload = data.get("usage")
                        usage = (
                            instantiateDataclass(ITextInferenceUsage, usage_payload)
                            if usage_payload is not None
                            else None
                        )
                        yield IText(
                            taskType=ETaskType.TEXT_INFERENCE.value,
                            taskUUID=data.get("taskUUID") or "",
                            finishReason=choice.get("finish_reason"),
                            usage=usage,
                            cost=data.get("cost"),
                        )
                        return
    except RunwareAPIError:
        # Bug fix: previously the blanket handler below re-wrapped the
        # structured API error as {"message": str(e)}, losing its fields.
        raise
    except Exception as e:
        raise RunwareAPIError({"message": str(e)}) from e
class EDeliveryMethod(Enum):
    """How inference results are delivered to the caller."""

    SYNC = "sync"      # block until the full result is ready
    ASYNC = "async"    # return a task handle; caller polls for the result
    STREAM = "stream"  # incremental results over HTTP Server-Sent Events
topP: Optional[float] = None - topK: Optional[int] = None - seed: Optional[int] = None - stopSequences: Optional[List[str]] = None + seed: Optional[int] = None includeCost: Optional[bool] = None + settings: Optional[Union[ISettings, Dict[str, Any]]] = None + inputs: Optional[Union[ITextInputs, Dict[str, Any]]] = None providerSettings: Optional[TextProviderSettings] = None webhookURL: Optional[str] = None + def __post_init__(self) -> None: + if self.settings is not None and isinstance(self.settings, dict): + self.settings = ISettings(**self.settings) + if self.inputs is not None and isinstance(self.inputs, dict): + self.inputs = ITextInputs(**self.inputs) + @dataclass class IText: diff --git a/runware/utils.py b/runware/utils.py index f590ffe..212a0f7 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -42,6 +42,25 @@ Environment.TEST: "ws://localhost:8080", } +# HTTP REST base URL for streaming (e.g. textInference with deliveryMethod=stream) +BASE_RUNWARE_HTTP_URLS = { + Environment.PRODUCTION: "https://api.runware.ai/v1", + Environment.TEST: "http://localhost:8080", +} + +# Map each WebSocket base URL to its HTTP counterpart (for streaming requests). 
def get_http_url_from_ws_url(ws_url: str) -> str:
    """Return the HTTP(S) REST base URL corresponding to *ws_url*.

    Known environment WebSocket URLs are resolved through ``_WS_TO_HTTP``.
    Any other WebSocket URL (e.g. a self-hosted or custom endpoint) is
    converted by swapping the scheme (``wss`` -> ``https``, ``ws`` ->
    ``http``) so streaming requests target the same host. Empty or
    non-WebSocket input falls back to the production HTTP URL.
    """
    if not ws_url:
        return BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]
    mapped = _WS_TO_HTTP.get(ws_url)
    if mapped is not None:
        return mapped
    # Bug fix: a custom endpoint previously resolved to the PRODUCTION
    # HTTP URL, silently sending streaming traffic to the wrong host.
    if ws_url.startswith("wss://"):
        return "https://" + ws_url[len("wss://"):]
    if ws_url.startswith("ws://"):
        return "http://" + ws_url[len("ws://"):]
    return BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]