Skip to content

Commit c646b36

Browse files
committed
Add AgentState handling and token management to WHIP client
1 parent a3e55d2 commit c646b36

File tree

4 files changed

+51
-7
lines changed

4 files changed

+51
-7
lines changed

src/streamcore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .audio import CHANNELS, FRAME_SIZE, SAMPLE_RATE
22
from .types import (
3+
AgentState,
34
ConnectionStatus,
45
TranscriptEntry,
56
DataChannelMessage,
@@ -13,6 +14,7 @@
1314
"Client",
1415
"Config",
1516
"EventHandler",
17+
"AgentState",
1618
"ConnectionStatus",
1719
"TranscriptEntry",
1820
"DataChannelMessage",

src/streamcore/client.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from threading import Lock
77

8+
import aiohttp
89
import av
910
import numpy as np
1011
from aiortc import (
@@ -17,6 +18,7 @@
1718

1819
from .audio import SAMPLE_RATE, _SDKAudioTrack
1920
from .types import (
21+
AgentState,
2022
Config,
2123
ConnectionStatus,
2224
DataChannelMessage,
@@ -104,6 +106,7 @@ def on_dc_message(message):
104106
text=data.get("text", ""),
105107
final=data.get("final", False),
106108
message=data.get("message", ""),
109+
state=data.get("state", ""),
107110
)
108111
if self.events.on_data_channel_message:
109112
self.events.on_data_channel_message(msg)
@@ -132,8 +135,21 @@ async def on_connection_state_change():
132135
# so localDescription.sdp already contains all candidates.
133136
offer_sdp = pc.localDescription.sdp
134137

138+
# Fetch a fresh token from the token endpoint if configured.
139+
token = self.config.token
140+
if self.config.token_url:
141+
fetch_headers: dict[str, str] = {}
142+
if self.config.api_key:
143+
fetch_headers["Authorization"] = f"Bearer {self.config.api_key}"
144+
async with aiohttp.ClientSession() as http_session:
145+
async with http_session.post(self.config.token_url, headers=fetch_headers) as resp:
146+
if resp.status != 200:
147+
raise RuntimeError(f"Token request failed ({resp.status})")
148+
data = await resp.json()
149+
token = data["token"]
150+
135151
# WHIP exchange.
136-
result = await whip_offer(self.config.whip_endpoint, offer_sdp)
152+
result = await whip_offer(self.config.whip_endpoint, offer_sdp, token)
137153
self._session_url = result.session_url
138154

139155
answer = RTCSessionDescription(sdp=result.answer_sdp, type="answer")
@@ -152,7 +168,7 @@ async def disconnect(self) -> None:
152168
if self._sdk_track is not None:
153169
self._sdk_track.stop()
154170
self._sdk_track = None
155-
await whip_delete(self._session_url)
171+
await whip_delete(self._session_url, self.config.token)
156172
self._session_url = ""
157173
if self._blackhole:
158174
await self._blackhole.stop()
@@ -270,6 +286,14 @@ def _handle_data_channel_message(self, msg: DataChannelMessage) -> None:
270286
if self.events.on_error:
271287
self.events.on_error(RuntimeError(msg.message))
272288
return
289+
290+
elif msg.type == "state":
291+
if self.events.on_agent_state_change and msg.state:
292+
try:
293+
self.events.on_agent_state_change(AgentState(msg.state))
294+
except ValueError:
295+
logger.warning("Unknown agent state: %s", msg.state)
296+
return
273297
else:
274298
return
275299

src/streamcore/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,29 @@ class TranscriptEntry:
2020
partial: bool = False
2121

2222

23+
class AgentState(str, Enum):
24+
LISTENING = "listening"
25+
THINKING = "thinking"
26+
SPEAKING = "speaking"
27+
28+
2329
@dataclass
2430
class DataChannelMessage:
25-
type: str # "transcript", "response", or "error"
31+
type: str # "transcript", "response", "error", or "state"
2632
text: str = ""
2733
final: bool = False
2834
message: str = "" # for error type
35+
state: str = "" # for state type
2936

3037

3138
@dataclass
3239
class Config:
3340
"""Configuration for a StreamCoreAIClient."""
3441

3542
whip_endpoint: str = "http://localhost:8080/whip"
43+
token: str = "" # JWT token for authenticating with the WHIP endpoint
44+
token_url: str = "" # Token endpoint URL; if set, fetches a JWT before each connection (overrides token)
45+
api_key: str = "" # API key sent as Bearer header when fetching from token_url
3646
ice_servers: list[str] = field(
3747
default_factory=lambda: ["stun:stun.l.google.com:19302"]
3848
)
@@ -47,4 +57,5 @@ class EventHandler:
4757
None
4858
)
4959
on_error: Callable[[Exception], None] | None = None
60+
on_agent_state_change: Callable[[AgentState], None] | None = None
5061
on_data_channel_message: Callable[[DataChannelMessage], None] | None = None

src/streamcore/whip.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@ class WhipResult:
1212
session_url: str
1313

1414

15-
async def whip_offer(endpoint: str, offer_sdp: str) -> WhipResult:
15+
async def whip_offer(endpoint: str, offer_sdp: str, token: str = "") -> WhipResult:
1616
"""Perform a WHIP signaling exchange per RFC 9725 §4.2.
1717
1818
POST an SDP offer, receive a 201 Created with SDP answer and Location header.
1919
"""
20+
headers: dict[str, str] = {"Content-Type": "application/sdp"}
21+
if token:
22+
headers["Authorization"] = f"Bearer {token}"
23+
2024
async with aiohttp.ClientSession() as session:
2125
async with session.post(
2226
endpoint,
2327
data=offer_sdp,
24-
headers={"Content-Type": "application/sdp"},
28+
headers=headers,
2529
) as resp:
2630
if resp.status != 201:
2731
body = await resp.text()
@@ -40,17 +44,20 @@ async def whip_offer(endpoint: str, offer_sdp: str) -> WhipResult:
4044
return WhipResult(answer_sdp=answer_sdp, session_url=session_url)
4145

4246

43-
async def whip_delete(session_url: str) -> None:
47+
async def whip_delete(session_url: str, token: str = "") -> None:
4448
"""Terminate a WHIP session per RFC 9725 §4.2.
4549
4650
Send HTTP DELETE to the WHIP session URL. Best-effort; errors are ignored.
4751
"""
4852
if not session_url:
4953
return
5054
try:
55+
headers: dict[str, str] = {}
56+
if token:
57+
headers["Authorization"] = f"Bearer {token}"
5158
timeout = aiohttp.ClientTimeout(total=3)
5259
async with aiohttp.ClientSession(timeout=timeout) as session:
53-
async with session.delete(session_url):
60+
async with session.delete(session_url, headers=headers):
5461
pass
5562
except Exception:
5663
pass

0 commit comments

Comments
 (0)