Skip to content

Commit 4252b44

Browse files
authored
Merge pull request #95 from speechmatics/feature/add-http/requested_parallel
add support for requesting parallel engines in a http batch job
2 parents 4d64231 + fb76926 commit 4252b44

4 files changed

Lines changed: 190 additions & 4 deletions

File tree

sdk/batch/speechmatics/batch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@
3232
from ._models import TranscriptFilteringConfig
3333
from ._models import TranscriptionConfig
3434
from ._models import TranslationConfig
35+
from ._transport import PROCESSING_DATA_HEADER
3536

3637
__all__ = [
3738
"AsyncClient",
39+
"PROCESSING_DATA_HEADER",
3840
"AuthBase",
3941
"AuthenticationError",
4042
"AutoChaptersConfig",

sdk/batch/speechmatics/batch/_async_client.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ._models import JobType
3232
from ._models import Transcript
3333
from ._models import TranscriptionConfig
34+
from ._transport import PROCESSING_DATA_HEADER
3435
from ._transport import Transport
3536

3637

@@ -139,6 +140,7 @@ async def submit_job(
139140
*,
140141
config: Optional[JobConfig] = None,
141142
transcription_config: Optional[TranscriptionConfig] = None,
143+
parallel_engines: Optional[int] = None,
142144
) -> JobDetails:
143145
"""
144146
Submit a new transcription job.
@@ -154,6 +156,9 @@ async def submit_job(
154156
to build a basic job configuration.
155157
transcription_config: Transcription-specific configuration. Used if config
156158
is not provided.
159+
parallel_engines: Optional number of parallel engines to request for this job.
160+
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
161+
This only applies when running the on-premises container in HTTP batch mode.
157162
158163
Returns:
159164
JobDetails object containing the job ID and initial status.
@@ -200,7 +205,7 @@ async def submit_job(
200205
assert audio_file is not None # for type checker; validated above
201206
multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict)
202207

203-
return await self._submit_and_create_job_details(multipart_data, filename, config)
208+
return await self._submit_and_create_job_details(multipart_data, filename, config, parallel_engines)
204209
except Exception as e:
205210
if isinstance(e, (AuthenticationError, BatchError)):
206211
raise
@@ -435,6 +440,7 @@ async def transcribe(
435440
transcription_config: Optional[TranscriptionConfig] = None,
436441
polling_interval: float = 5.0,
437442
timeout: Optional[float] = None,
443+
parallel_engines: Optional[int] = None,
438444
) -> Union[Transcript, str]:
439445
"""
440446
Complete transcription workflow: submit job and wait for completion.
@@ -448,6 +454,9 @@ async def transcribe(
448454
transcription_config: Transcription-specific configuration.
449455
polling_interval: Time in seconds between status checks.
450456
timeout: Maximum time in seconds to wait for completion.
457+
parallel_engines: Optional number of parallel engines to request for this job.
458+
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
459+
This only applies when running the on-premises container in HTTP batch mode.
451460
452461
Returns:
453462
Transcript object containing the transcript and metadata.
@@ -475,6 +484,7 @@ async def transcribe(
475484
audio_file,
476485
config=config,
477486
transcription_config=transcription_config,
487+
parallel_engines=parallel_engines,
478488
)
479489

480490
# Wait for completion and return result
@@ -528,10 +538,13 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi
528538
return multipart_data, filename
529539

530540
async def _submit_and_create_job_details(
531-
self, multipart_data: dict, filename: str, config: JobConfig
541+
self, multipart_data: dict, filename: str, config: JobConfig, parallel_engines: Optional[int] = None
532542
) -> JobDetails:
533543
"""Submit job and create JobDetails response."""
534-
response = await self._transport.post("/jobs", multipart_data=multipart_data)
544+
extra_headers: Optional[dict[str, Any]] = None
545+
if parallel_engines is not None:
546+
extra_headers = {PROCESSING_DATA_HEADER: {"parallel_engines": parallel_engines}}
547+
response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers)
535548
job_id = response.get("id")
536549
if not job_id:
537550
raise BatchError("No job ID returned from server")

sdk/batch/speechmatics/batch/_transport.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import asyncio
1212
import io
13+
import json as _json
1314
import sys
1415
import uuid
1516
from typing import Any
@@ -25,6 +26,8 @@
2526
from ._logging import get_logger
2627
from ._models import ConnectionConfig
2728

29+
PROCESSING_DATA_HEADER = "X-SM-Processing-Data"
30+
2831

2932
class Transport:
3033
"""
@@ -116,6 +119,7 @@ async def post(
116119
json_data: Optional[dict[str, Any]] = None,
117120
multipart_data: Optional[dict[str, Any]] = None,
118121
timeout: Optional[float] = None,
122+
extra_headers: Optional[dict[str, Any]] = None,
119123
) -> dict[str, Any]:
120124
"""
121125
Send POST request to the API.
@@ -125,6 +129,7 @@ async def post(
125129
json_data: Optional JSON data for request body
126130
multipart_data: Optional multipart form data
127131
timeout: Optional request timeout
132+
extra_headers: Optional additional headers to include in the request
128133
129134
Returns:
130135
JSON response as dictionary
@@ -133,7 +138,14 @@ async def post(
133138
AuthenticationError: If authentication fails
134139
TransportError: If request fails
135140
"""
136-
return await self._request("POST", path, json_data=json_data, multipart_data=multipart_data, timeout=timeout)
141+
return await self._request(
142+
"POST",
143+
path,
144+
json_data=json_data,
145+
multipart_data=multipart_data,
146+
timeout=timeout,
147+
extra_headers=extra_headers,
148+
)
137149

138150
async def delete(self, path: str, timeout: Optional[float] = None) -> dict[str, Any]:
139151
"""
@@ -200,6 +212,7 @@ async def _request(
200212
json_data: Optional[dict[str, Any]] = None,
201213
multipart_data: Optional[dict[str, Any]] = None,
202214
timeout: Optional[float] = None,
215+
extra_headers: Optional[dict[str, Any]] = None,
203216
) -> dict[str, Any]:
204217
"""
205218
Send HTTP request to the API.
@@ -227,6 +240,9 @@ async def _request(
227240

228241
url = f"{self._url.rstrip('/')}{path}"
229242
headers = await self._prepare_headers()
243+
if extra_headers:
244+
for k, v in extra_headers.items():
245+
headers[k] = _json.dumps(v) if isinstance(v, dict) else v
230246

231247
self._logger.debug(
232248
"Sending HTTP request %s %s (json=%s, multipart=%s)",

tests/batch/test_submit_job.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""Unit tests for AsyncClient.submit_job, focusing on the parallel_engines feature."""
2+
3+
import json
4+
from io import BytesIO
5+
from unittest.mock import AsyncMock
6+
from unittest.mock import MagicMock
7+
from unittest.mock import patch
8+
9+
from typing import Optional
10+
11+
import pytest
12+
13+
from speechmatics.batch import AsyncClient
14+
from speechmatics.batch import JobConfig
15+
from speechmatics.batch import JobStatus
16+
from speechmatics.batch import JobType
17+
from speechmatics.batch import TranscriptionConfig
18+
from speechmatics.batch import PROCESSING_DATA_HEADER
19+
20+
21+
def _make_client(api_key: str = "test-key") -> AsyncClient:
    """Build an AsyncClient configured with *api_key* for use in these tests."""
    client = AsyncClient(api_key=api_key)
    return client
23+
24+
25+
def _job_response(job_id: str = "job-123") -> dict:
26+
return {"id": job_id, "created_at": "2024-01-01T00:00:00Z"}
27+
28+
29+
# ---------------------------------------------------------------------------
30+
# Helpers
31+
# ---------------------------------------------------------------------------
32+
33+
34+
def _captured_extra_headers(mock_post: AsyncMock) -> Optional[dict]:
35+
"""Return the extra_headers kwarg from the first call to transport.post."""
36+
_, kwargs = mock_post.call_args
37+
return kwargs.get("extra_headers")
38+
39+
40+
# ---------------------------------------------------------------------------
41+
# Tests
42+
# ---------------------------------------------------------------------------
43+
44+
45+
class TestRequestedParallelHeader:
    """X-SM-Processing-Data header is set correctly based on parallel_engines."""

    @pytest.mark.asyncio
    async def test_header_sent_when_parallel_engines_provided(self):
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, parallel_engines=4)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        assert PROCESSING_DATA_HEADER in extra_headers
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload == {"parallel_engines": 4}

    @pytest.mark.asyncio
    async def test_header_not_sent_when_parallel_engines_is_none(self):
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is None

    @pytest.mark.asyncio
    async def test_header_value_is_valid_json(self):
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, parallel_engines=8)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        # The transport JSON-encodes dict header values before sending, so the
        # payload must survive a JSON round-trip. The original test claimed to
        # check this but never called json.dumps/json.loads.
        parsed = json.loads(json.dumps(extra_headers[PROCESSING_DATA_HEADER]))
        assert parsed["parallel_engines"] == 8

    @pytest.mark.asyncio
    async def test_parallel_engines_one(self):
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, parallel_engines=1)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload["parallel_engines"] == 1

    @pytest.mark.asyncio
    async def test_header_sent_with_fetch_data_config(self):
        """parallel_engines works with fetch_data submissions too."""
        client = _make_client()
        config = JobConfig(
            type=JobType.TRANSCRIPTION,
            fetch_data=MagicMock(url="https://example.com/audio.wav"),
            transcription_config=TranscriptionConfig(language="en"),
        )
        # Patch to_dict so fetch_data key is present
        config_dict = {
            "type": "transcription",
            "fetch_data": {"url": "https://example.com/audio.wav"},
            "transcription_config": {"language": "en"},
        }
        with patch.object(config, "to_dict", return_value=config_dict):
            with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = _job_response()
                await client.submit_job(None, config=config, parallel_engines=2)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload == {"parallel_engines": 2}
128+
129+
130+
class TestSubmitJobReturnValue:
    """submit_job still returns the correct JobDetails regardless of parallel_engines."""

    @pytest.mark.asyncio
    async def test_returns_job_details_with_correct_id(self):
        client = _make_client()

        with patch.object(client._transport, "post", new_callable=AsyncMock) as transport_post:
            transport_post.return_value = _job_response("abc-456")
            details = await client.submit_job(BytesIO(b"fake-audio"), parallel_engines=3)

        assert details.id == "abc-456"
        assert details.status == JobStatus.RUNNING

    @pytest.mark.asyncio
    async def test_post_called_with_jobs_path(self):
        client = _make_client()

        with patch.object(client._transport, "post", new_callable=AsyncMock) as transport_post:
            transport_post.return_value = _job_response()
            await client.submit_job(BytesIO(b"fake-audio"), parallel_engines=2)

        positional, _ = transport_post.call_args
        assert positional[0] == "/jobs"

0 commit comments

Comments
 (0)