|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Minimal local smoke test for /v1/audio/speech stream overload behavior. |
| 3 | +
|
| 4 | +Run: |
| 5 | + uv run python scripts/speech_sse_smoke.py |
| 6 | +or: |
| 7 | + uvx --from . python scripts/speech_sse_smoke.py |
| 8 | +""" |
| 9 | + |
| 10 | +from __future__ import annotations |
| 11 | + |
| 12 | +import json |
| 13 | + |
| 14 | +import httpx |
| 15 | + |
| 16 | +from mistralai.client import Mistral, models |
| 17 | + |
| 18 | + |
| 19 | +def _sse_bytes() -> bytes: |
| 20 | + """Return two SSE events: one delta and one done.""" |
| 21 | + events = [ |
| 22 | + "event: speech.audio.delta\n" |
| 23 | + 'data: {"type":"speech.audio.delta","audio_data":"Zm9v"}\n\n', |
| 24 | + "event: speech.audio.done\n" |
| 25 | + 'data: {"type":"speech.audio.done","usage":{"prompt_tokens":1,"total_tokens":1}}\n\n', |
| 26 | + ] |
| 27 | + return "".join(events).encode("utf-8") |
| 28 | + |
| 29 | + |
| 30 | +def _handler(request: httpx.Request) -> httpx.Response: |
| 31 | + assert request.method == "POST" |
| 32 | + assert request.url.path == "/v1/audio/speech" |
| 33 | + |
| 34 | + accept = request.headers.get("accept") |
| 35 | + payload = json.loads(request.content.decode("utf-8")) |
| 36 | + stream = bool(payload.get("stream", False)) |
| 37 | + |
| 38 | + if stream: |
| 39 | + assert accept == "text/event-stream", f"unexpected accept={accept}" |
| 40 | + return httpx.Response( |
| 41 | + 200, |
| 42 | + headers={"content-type": "text/event-stream"}, |
| 43 | + content=_sse_bytes(), |
| 44 | + request=request, |
| 45 | + ) |
| 46 | + |
| 47 | + assert accept == "application/json", f"unexpected accept={accept}" |
| 48 | + return httpx.Response( |
| 49 | + 200, |
| 50 | + headers={"content-type": "application/json"}, |
| 51 | + json={"audio_data": "Zm9v"}, |
| 52 | + request=request, |
| 53 | + ) |
| 54 | + |
| 55 | + |
| 56 | +def main() -> None: |
| 57 | + transport = httpx.MockTransport(_handler) |
| 58 | + client = httpx.Client(transport=transport, base_url="https://api.mistral.ai") |
| 59 | + |
| 60 | + sdk = Mistral(api_key="dummy", client=client) |
| 61 | + |
| 62 | + non_stream = sdk.audio.speech.complete(input="hello", stream=False) |
| 63 | + assert isinstance(non_stream, models.SpeechResponse) |
| 64 | + assert isinstance(non_stream.audio_data, (bytes, bytearray)) |
| 65 | + print(f"stream=False OK: {type(non_stream).__name__}") |
| 66 | + |
| 67 | + stream_resp = sdk.audio.speech.complete(input="hello", stream=True) |
| 68 | + assert hasattr(stream_resp, "__iter__") |
| 69 | + |
| 70 | + collected: list[models.SpeechStreamEvents] = list( |
| 71 | + stream_resp # type: ignore[arg-type] |
| 72 | + ) |
| 73 | + assert len(collected) == 2 |
| 74 | + assert collected[0].event == "speech.audio.delta" |
| 75 | + assert collected[1].event == "speech.audio.done" |
| 76 | + |
| 77 | + print("stream=True OK: received events ->", [event.event for event in collected]) |
| 78 | + print("All good.") |
| 79 | + |
| 80 | + |
| 81 | +if __name__ == "__main__": |
| 82 | + main() |
0 commit comments