Skip to content

Commit 8819512

Browse files
committed
fix: typing
1 parent c6f28b9 commit 8819512

1 file changed

Lines changed: 12 additions & 14 deletions

File tree

src/mistralai/extra/realtime/connection.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
TranscriptionStreamSegmentDelta,
3333
TranscriptionStreamTextDelta,
3434
)
35+
from mistralai.types import UNSET
3536

3637

3738
class UnknownRealtimeEvent(BaseModel):
@@ -41,6 +42,7 @@ class UnknownRealtimeEvent(BaseModel):
4142
- invalid JSON payload
4243
- schema validation failure
4344
"""
45+
4446
type: Optional[str]
4547
content: Any
4648
error: Optional[str] = None
@@ -61,7 +63,6 @@ class UnknownRealtimeEvent(BaseModel):
6163
UnknownRealtimeEvent,
6264
]
6365

64-
6566
_MESSAGE_MODELS: dict[str, Any] = {
6667
"session.created": RealtimeTranscriptionSessionCreated,
6768
"session.updated": RealtimeTranscriptionSessionUpdated,
@@ -113,7 +114,6 @@ def __init__(
113114
) -> None:
114115
self._websocket = websocket
115116
self._session = session
116-
self._audio_format = session.audio_format
117117
self._closed = False
118118
self._initial_events: Deque[RealtimeEvent] = deque(initial_events or [])
119119

@@ -127,7 +127,7 @@ def session(self) -> RealtimeTranscriptionSession:
127127

128128
@property
129129
def audio_format(self) -> AudioFormat:
130-
return self._audio_format
130+
return self._session.audio_format
131131

132132
@property
133133
def is_closed(self) -> bool:
@@ -163,16 +163,13 @@ async def update_session(
163163
if audio_format is None and target_streaming_delay_ms is None:
164164
raise ValueError("At least one session field must be provided")
165165

166-
session_update_data: dict[str, object] = {}
167-
if audio_format is not None:
168-
self._audio_format = audio_format
169-
session_update_data["audio_format"] = audio_format
170-
if target_streaming_delay_ms is not None:
171-
session_update_data["target_streaming_delay_ms"] = (
172-
target_streaming_delay_ms
173-
)
174166
message = RealtimeTranscriptionSessionUpdateMessage(
175-
session=RealtimeTranscriptionSessionUpdatePayload(**session_update_data)
167+
session=RealtimeTranscriptionSessionUpdatePayload(
168+
audio_format=audio_format if audio_format is not None else UNSET,
169+
target_streaming_delay_ms=target_streaming_delay_ms
170+
if target_streaming_delay_ms is not None
171+
else UNSET,
172+
)
176173
)
177174
await self._websocket.send(message.model_dump_json())
178175

@@ -229,6 +226,7 @@ async def events(self) -> AsyncIterator[RealtimeEvent]:
229226
await self.close()
230227

231228
def _apply_session_updates(self, ev: RealtimeEvent) -> None:
232-
if isinstance(ev, RealtimeTranscriptionSessionCreated) or isinstance(ev, RealtimeTranscriptionSessionUpdated):
229+
if isinstance(ev, RealtimeTranscriptionSessionCreated) or isinstance(
230+
ev, RealtimeTranscriptionSessionUpdated
231+
):
233232
self._session = ev.session
234-
self._audio_format = ev.session.audio_format

0 commit comments

Comments
 (0)