Skip to content

Commit 2e67e0c

Browse files
authored
Add user_id as input to header too for http batch V2 API
2 parents ea666dd + a57f199 commit 2e67e0c

2 files changed

Lines changed: 109 additions & 4 deletions

File tree

sdk/batch/speechmatics/batch/_async_client.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ async def submit_job(
141141
config: Optional[JobConfig] = None,
142142
transcription_config: Optional[TranscriptionConfig] = None,
143143
parallel_engines: Optional[int] = None,
144+
user_id: Optional[str] = None,
144145
) -> JobDetails:
145146
"""
146147
Submit a new transcription job.
@@ -159,6 +160,9 @@ async def submit_job(
159160
parallel_engines: Optional number of parallel engines to request for this job.
160161
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
161162
This only applies when using the container onPrem on http batch mode.
163+
user_id: Optional user identifier to associate with this job.
164+
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
165+
This only applies when using the container onPrem on http batch mode.
162166
163167
Returns:
164168
JobDetails object containing the job ID and initial status.
@@ -205,7 +209,9 @@ async def submit_job(
205209
assert audio_file is not None # for type checker; validated above
206210
multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict)
207211

208-
return await self._submit_and_create_job_details(multipart_data, filename, config, parallel_engines)
212+
return await self._submit_and_create_job_details(
213+
multipart_data, filename, config, parallel_engines, user_id
214+
)
209215
except Exception as e:
210216
if isinstance(e, (AuthenticationError, BatchError)):
211217
raise
@@ -441,6 +447,7 @@ async def transcribe(
441447
polling_interval: float = 5.0,
442448
timeout: Optional[float] = None,
443449
parallel_engines: Optional[int] = None,
450+
user_id: Optional[str] = None,
444451
) -> Union[Transcript, str]:
445452
"""
446453
Complete transcription workflow: submit job and wait for completion.
@@ -457,6 +464,9 @@ async def transcribe(
457464
parallel_engines: Optional number of parallel engines to request for this job.
458465
Sent as ``{"parallel_engines": N}`` in the ``X-SM-Processing-Data`` header.
459466
This only applies when using the container onPrem on http batch mode.
467+
user_id: Optional user identifier to associate with this job.
468+
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
469+
This only applies when using the container onPrem on http batch mode.
460470
461471
Returns:
462472
Transcript object containing the transcript and metadata.
@@ -485,6 +495,7 @@ async def transcribe(
485495
config=config,
486496
transcription_config=transcription_config,
487497
parallel_engines=parallel_engines,
498+
user_id=user_id,
488499
)
489500

490501
# Wait for completion and return result
@@ -538,12 +549,22 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi
538549
return multipart_data, filename
539550

540551
async def _submit_and_create_job_details(
541-
self, multipart_data: dict, filename: str, config: JobConfig, parallel_engines: Optional[int] = None
552+
self,
553+
multipart_data: dict,
554+
filename: str,
555+
config: JobConfig,
556+
parallel_engines: Optional[int] = None,
557+
user_id: Optional[str] = None,
542558
) -> JobDetails:
543559
"""Submit job and create JobDetails response."""
544560
extra_headers: Optional[dict[str, Any]] = None
561+
processing_data: dict[str, Any] = {}
545562
if parallel_engines is not None:
546-
extra_headers = {PROCESSING_DATA_HEADER: {"parallel_engines": parallel_engines}}
563+
processing_data["parallel_engines"] = parallel_engines
564+
if user_id is not None:
565+
processing_data["user_id"] = user_id
566+
if processing_data:
567+
extra_headers = {PROCESSING_DATA_HEADER: processing_data}
547568
response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers)
548569
job_id = response.get("id")
549570
if not job_id:

tests/batch/test_submit_job.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Unit tests for AsyncClient.submit_job, focusing on the parallel_engines feature."""
1+
"""Unit tests for AsyncClient.submit_job, focusing on the parallel engines and user_id features."""
22

33
import json
44
from io import BytesIO
@@ -127,6 +127,90 @@ async def test_header_sent_with_fetch_data_config(self):
127127
assert payload == {"parallel_engines": 2}
128128

129129

130+
class TestUserIdHeader:
    """X-SM-Processing-Data header is set correctly based on user_id."""

    @pytest.mark.asyncio
    async def test_header_sent_when_user_id_provided(self):
        # Submitting with a user_id must place it in the processing-data header.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock:
            post_mock.return_value = _job_response()
            await sdk.submit_job(audio_stream, user_id="user-abc")

        headers = _captured_extra_headers(post_mock)
        assert headers is not None
        assert PROCESSING_DATA_HEADER in headers
        assert headers[PROCESSING_DATA_HEADER] == {"user_id": "user-abc"}

    @pytest.mark.asyncio
    async def test_header_not_sent_when_user_id_is_none(self):
        # With neither user_id nor parallel_engines, no extra headers go out.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock:
            post_mock.return_value = _job_response()
            await sdk.submit_job(audio_stream)

        assert _captured_extra_headers(post_mock) is None

    @pytest.mark.asyncio
    async def test_user_id_and_parallel_engines_sent_together(self):
        # Both fields share a single header payload when both are supplied.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock:
            post_mock.return_value = _job_response()
            await sdk.submit_job(audio_stream, parallel_engines=4, user_id="user-xyz")

        headers = _captured_extra_headers(post_mock)
        assert headers is not None
        header_payload = headers[PROCESSING_DATA_HEADER]
        assert header_payload == {"parallel_engines": 4, "user_id": "user-xyz"}

    @pytest.mark.asyncio
    async def test_user_id_does_not_appear_when_only_parallel_engines_set(self):
        # An omitted user_id must not leak into the header payload.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock:
            post_mock.return_value = _job_response()
            await sdk.submit_job(audio_stream, parallel_engines=2)

        header_payload = _captured_extra_headers(post_mock)[PROCESSING_DATA_HEADER]
        assert "user_id" not in header_payload

    @pytest.mark.asyncio
    async def test_parallel_engines_does_not_appear_when_only_user_id_set(self):
        # Symmetric check: omitted parallel_engines stays out of the payload.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock:
            post_mock.return_value = _job_response()
            await sdk.submit_job(audio_stream, user_id="u1")

        header_payload = _captured_extra_headers(post_mock)[PROCESSING_DATA_HEADER]
        assert "parallel_engines" not in header_payload

    @pytest.mark.asyncio
    async def test_user_id_forwarded_from_transcribe(self):
        # transcribe() must pass user_id through to the underlying submit.
        sdk = _make_client()
        audio_stream = BytesIO(b"fake-audio")

        with (
            patch.object(sdk._transport, "post", new_callable=AsyncMock) as post_mock,
            patch.object(sdk, "wait_for_completion", new_callable=AsyncMock) as wait_mock,
        ):
            post_mock.return_value = _job_response()
            wait_mock.return_value = MagicMock()
            await sdk.transcribe(audio_stream, user_id="transcribe-user")

        headers = _captured_extra_headers(post_mock)
        assert headers is not None
        assert headers[PROCESSING_DATA_HEADER]["user_id"] == "transcribe-user"
130214
class TestSubmitJobReturnValue:
131215
"""submit_job still returns the correct JobDetails regardless of parallel_engines."""
132216

0 commit comments

Comments (0)