Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,37 @@ The `IAudioInference` class supports the following parameters:
- `duration`: Duration of the generated audio in seconds
- `includeCost`: Whether to include cost information in the response

### Text inference streaming

To stream text inference (e.g. LLM chat) over HTTP SSE, set `deliveryMethod="stream"`. The SDK yields content chunks (strings) and a final `IText` with usage and cost:

```python
import asyncio
from runware import Runware, ITextInference, ITextInferenceMessage

async def main() -> None:
runware = Runware(api_key=RUNWARE_API_KEY)
await runware.connect()

request = ITextInference(
model="runware:qwen3-thinking@1",
messages=[ITextInferenceMessage(role="user", content="Explain photosynthesis in one sentence.")],
deliveryMethod="stream",
includeCost=True,
)

stream = await runware.textInference(request)
async for chunk in stream:
if isinstance(chunk, str):
print(chunk, end="", flush=True)
else:
print(chunk)

asyncio.run(main())
```

Streaming uses the same concurrency limit as other requests (`RUNWARE_MAX_CONCURRENT_REQUESTS`). To allow longer streams, set `RUNWARE_TEXT_STREAM_TIMEOUT` (milliseconds; default 600000).

### Model Upload

To upload a model using the Runware API, you can use the `uploadModel` method of the `Runware` class. Here are examples:
Expand Down Expand Up @@ -1068,6 +1099,9 @@ RUNWARE_AUDIO_INFERENCE_TIMEOUT=300000 # Audio generation (default: 5 min)
RUNWARE_AUDIO_POLLING_DELAY=1000 # Delay between status checks (default: 1 sec)
RUNWARE_MAX_POLLS_AUDIO_GENERATION=240 # Max polling attempts for audio inference (default: 240, ~4 min total)

# Text Operations (milliseconds)
RUNWARE_TEXT_STREAM_TIMEOUT=600000 # Text inference streaming (SSE) read timeout (default: 10 min)

# Other Operations (milliseconds)
RUNWARE_PROMPT_ENHANCE_TIMEOUT=60000 # Prompt enhancement (default: 1 min)
RUNWARE_WEBHOOK_TIMEOUT=30000 # Webhook acknowledgment (default: 30 sec)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
aiofiles==23.2.1
httpx>=0.27.0
python-dotenv==1.0.1
websockets>=12.0
92 changes: 80 additions & 12 deletions runware/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import asyncio
import inspect
import json
import logging
import os
import re
from asyncio import gather
from dataclasses import asdict, is_dataclass, fields
from enum import Enum
from random import uniform
from typing import List, Optional, Union, Callable, Any, Dict, Tuple
from typing import List, Optional, Union, Callable, Any, Dict, Tuple, AsyncIterator

import httpx
from websockets.protocol import State

from .logging_config import configure_logging
Expand Down Expand Up @@ -59,11 +61,14 @@
IUploadMediaRequest,
ITextInference,
IText,
ITextInferenceUsage,
ITextInputs,
)
from .types import IImage, IError, SdkType, ListenerType
from .utils import (
BASE_RUNWARE_URLS,
getUUID,
get_http_url_from_ws_url,
fileToBase64,
createImageFromResponse,
createImageToTextFromResponse,
Expand All @@ -82,6 +87,7 @@
createAsyncTaskResponse,
VIDEO_INITIAL_TIMEOUT,
TEXT_INITIAL_TIMEOUT,
TEXT_STREAM_READ_TIMEOUT,
VIDEO_POLLING_DELAY,
WEBHOOK_TIMEOUT,
IMAGE_INFERENCE_TIMEOUT,
Expand Down Expand Up @@ -2018,7 +2024,20 @@ async def _inference3d(self, request3d: I3dInference) -> Union[List[I3d], IAsync
await self.ensureConnection()
return await self._request3d(request3d)

async def textInference(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
async def textInference(
self, requestText: ITextInference
) -> Union[List[IText], IAsyncTaskResponse, AsyncIterator[Union[str, IText]]]:
delivery_method_enum = (
requestText.deliveryMethod
if isinstance(requestText.deliveryMethod, EDeliveryMethod)
else EDeliveryMethod(requestText.deliveryMethod)
)
if delivery_method_enum == EDeliveryMethod.STREAM:
async def stream_with_semaphore() -> AsyncIterator[Union[str, IText]]:
async with self._request_semaphore:
async for chunk in self._requestTextStream(requestText):
yield chunk
return stream_with_semaphore()
async with self._request_semaphore:
return await self._retry_async_with_reconnect(
self._requestText,
Expand Down Expand Up @@ -2206,26 +2225,75 @@ def _buildTextRequest(self, requestText: ITextInference) -> Dict[str, Any]:
"deliveryMethod": requestText.deliveryMethod,
"messages": [asdict(m) for m in requestText.messages],
}
if requestText.maxTokens is not None:
request_object["maxTokens"] = requestText.maxTokens
if requestText.temperature is not None:
request_object["temperature"] = requestText.temperature
if requestText.topP is not None:
request_object["topP"] = requestText.topP
if requestText.topK is not None:
request_object["topK"] = requestText.topK
if requestText.seed is not None:
request_object["seed"] = requestText.seed
if requestText.stopSequences is not None:
request_object["stopSequences"] = requestText.stopSequences
if requestText.includeCost is not None:
request_object["includeCost"] = requestText.includeCost
self._addOptionalField(request_object, requestText.settings)
self._addOptionalField(request_object, requestText.inputs)
self._addProviderSettings(request_object, requestText)
return request_object

async def _requestTextStream(
    self, requestText: ITextInference
) -> AsyncIterator[Union[str, IText]]:
    """Stream a text inference over HTTP SSE (deliveryMethod="stream").

    Sends the request to the HTTP counterpart of the configured WebSocket
    URL and yields:
      * str chunks — incremental delta content as it arrives;
      * a final IText — emitted once a finish_reason is seen, carrying
        usage and (optionally) cost, after which the stream ends.

    Raises:
        RunwareAPIError: for API-reported errors (original payload is
            preserved) and for transport/HTTP failures (wrapped message).
    """
    requestText.taskUUID = requestText.taskUUID or getUUID()
    request_object = self._buildTextRequest(requestText)
    body = [request_object]
    http_url = get_http_url_from_ws_url(self._url or "")
    headers = {
        "Accept": "text/event-stream",
        "Authorization": f"Bearer {self._apiKey}",
        "Content-Type": "application/json",
    }
    try:
        # TEXT_STREAM_READ_TIMEOUT is in milliseconds; httpx wants seconds.
        async with httpx.AsyncClient(timeout=TEXT_STREAM_READ_TIMEOUT / 1000) as client:
            async with client.stream(
                "POST",
                http_url,
                json=body,
                headers=headers,
            ) as response:
                response.raise_for_status()
                async for raw_line in response.aiter_lines():
                    try:
                        event = json.loads(raw_line.replace("data:", "", 1))
                    except json.JSONDecodeError:
                        # Skip SSE keep-alives, blank lines and non-JSON frames.
                        continue
                    if not isinstance(event, dict):
                        # Defensive: only object payloads carry chunk data.
                        continue
                    data = event.get("data") or event
                    if data.get("error") is not None:
                        raise RunwareAPIError(data["error"])
                    choice = (data.get("choices") or [{}])[0]
                    delta = choice.get("delta") or {}
                    content = delta.get("content")
                    if content:
                        yield content
                    if choice.get("finish_reason") is not None:
                        usage = instantiateDataclass(ITextInferenceUsage, data.get("usage"))
                        yield IText(
                            taskType=ETaskType.TEXT_INFERENCE.value,
                            taskUUID=data.get("taskUUID") or "",
                            finishReason=choice.get("finish_reason"),
                            usage=usage,
                            cost=data.get("cost"),
                        )
                        return
    except RunwareAPIError:
        # Preserve the structured API error payload instead of re-wrapping
        # it into a second RunwareAPIError with only a stringified message.
        raise
    except Exception as e:
        raise RunwareAPIError({"message": str(e)}) from e

async def _requestText(self, requestText: ITextInference) -> Union[List[IText], IAsyncTaskResponse]:
await self.ensureConnection()
requestText.taskUUID = requestText.taskUUID or getUUID()


if requestText.inputs:
inputs = requestText.inputs
if isinstance(inputs, dict):
inputs = ITextInputs(**inputs)
requestText.inputs = inputs

if inputs.images:
inputs.images = await process_image(inputs.images)

request_object = self._buildTextRequest(requestText)

if requestText.webhookURL:
Expand Down
43 changes: 26 additions & 17 deletions runware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class EOpenPosePreProcessor(Enum):
class EDeliveryMethod(Enum):
    """Transport used to deliver inference results to the caller."""
    SYNC = "sync"
    ASYNC = "async"
    # HTTP SSE streaming (textInference with deliveryMethod="stream"):
    # yields incremental content chunks followed by a final result object.
    STREAM = "stream"

class OperationState(Enum):
"""State machine for pending operations."""
Expand Down Expand Up @@ -817,7 +818,7 @@ def request_key(self) -> str:

@dataclass
class ISettings(SerializableMixin):
# Image
# Image / Text
temperature: Optional[float] = None
systemPrompt: Optional[str] = None
topP: Optional[float] = None
Expand Down Expand Up @@ -846,6 +847,10 @@ class ISettings(SerializableMixin):
expressiveness: Optional[str] = None
removeBackground: Optional[bool] = None
backgroundColor: Optional[str] = None
# Text
maxTokens: Optional[int] = None
topK: Optional[int] = None
stopSequences: Optional[List[str]] = None

def __post_init__(self):
if self.sparseStructure is not None and isinstance(self.sparseStructure, dict):
Expand Down Expand Up @@ -895,6 +900,15 @@ def __post_init__(self):
self.referenceImages = self.references


@dataclass
class ITextInputs(SerializableMixin):
    """Additional inputs attached to a text inference request."""
    # Image inputs (str or File entries); normalized via process_image()
    # before the request is sent.
    images: Optional[List[Union[str, File]]] = None

    @property
    def request_key(self) -> str:
        # Key under which this object is serialized into the request payload.
        return "inputs"


@dataclass
class IAudioInput(SerializableMixin):
id: Optional[str] = None
Expand Down Expand Up @@ -1337,6 +1351,7 @@ class IGoogleProviderSettings(BaseProviderSettings):
generateAudio: Optional[bool] = None
enhancePrompt: Optional[bool] = None
search: Optional[bool] = None
thinkingLevel: Optional[str] = None

@property
def provider_key(self) -> str:
Expand Down Expand Up @@ -1729,16 +1744,7 @@ class ITextInferenceUsage:
thinkingTokens: Optional[int] = None


@dataclass
class IGoogleTextProviderSettings(BaseProviderSettings):
thinkingLevel: Optional[str] = None

@property
def provider_key(self) -> str:
return "google"


TextProviderSettings = IGoogleTextProviderSettings
TextProviderSettings = IGoogleProviderSettings


@dataclass
Expand All @@ -1748,16 +1754,19 @@ class ITextInference:
taskUUID: Optional[str] = None
deliveryMethod: str = "sync"
numberResults: Optional[int] = 1
maxTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
topK: Optional[int] = None
seed: Optional[int] = None
stopSequences: Optional[List[str]] = None
Comment on lines -1751 to -1756
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backward comp?

seed: Optional[int] = None
includeCost: Optional[bool] = None
settings: Optional[Union[ISettings, Dict[str, Any]]] = None
inputs: Optional[Union[ITextInputs, Dict[str, Any]]] = None
providerSettings: Optional[TextProviderSettings] = None
webhookURL: Optional[str] = None

def __post_init__(self) -> None:
    """Coerce dict-valued fields into their dataclass equivalents."""
    # isinstance(None, dict) is False, so None values pass through untouched.
    if isinstance(self.settings, dict):
        self.settings = ISettings(**self.settings)
    if isinstance(self.inputs, dict):
        self.inputs = ITextInputs(**self.inputs)


@dataclass
class IText:
Expand Down
27 changes: 27 additions & 0 deletions runware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@
Environment.TEST: "ws://localhost:8080",
}

# HTTP REST base URL for streaming (e.g. textInference with deliveryMethod=stream)
BASE_RUNWARE_HTTP_URLS = {
Environment.PRODUCTION: "https://api.runware.ai/v1",
Environment.TEST: "http://localhost:8080",
}

# Map each WebSocket base URL to its HTTP counterpart (for streaming requests).
_WS_TO_HTTP = {
BASE_RUNWARE_URLS[Environment.PRODUCTION]: BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION],
BASE_RUNWARE_URLS[Environment.TEST]: BASE_RUNWARE_HTTP_URLS[Environment.TEST],
}


def get_http_url_from_ws_url(ws_url: str) -> str:
    """Return the HTTP base URL matching *ws_url* for streaming requests.

    Known base URLs are resolved via _WS_TO_HTTP. Any other ws(s) URL is
    translated by scheme (wss -> https, ws -> http) so that a custom
    endpoint is not silently redirected to the production API. An empty
    or non-WebSocket URL falls back to the production HTTP URL.
    """
    if not ws_url:
        return BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]
    mapped = _WS_TO_HTTP.get(ws_url)
    if mapped is not None:
        return mapped
    if ws_url.startswith("wss://"):
        return "https://" + ws_url[len("wss://"):]
    if ws_url.startswith("ws://"):
        return "http://" + ws_url[len("ws://"):]
    return BASE_RUNWARE_HTTP_URLS[Environment.PRODUCTION]


RETRY_SDK_COUNTS = {
"GLOBAL": 2,
Expand Down Expand Up @@ -125,6 +144,14 @@
30000
))

# Text streaming read timeout (milliseconds)
# Maximum time to wait for data on the SSE stream; long to avoid ReadTimeout mid-stream
# Used in: _requestTextStream() for deliveryMethod=stream
TEXT_STREAM_READ_TIMEOUT = int(os.environ.get(
"RUNWARE_TEXT_STREAM_TIMEOUT",
600000
))

# Audio generation timeout (milliseconds)
# Maximum time to wait for audio generation completion
# Used in: _waitForAudioCompletion() for single audio generation
Expand Down