Skip to content

Commit 02507bc

Browse files
authored
TTS Evaluation: TTS Results signed URLs (#666)
1 parent 48b0a6b commit 02507bc

8 files changed

Lines changed: 67 additions & 44 deletions

File tree

backend/app/api/routes/stt_evaluations/dataset.py

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,9 @@
1010
from app.crud.file import get_files_by_ids
1111
from app.crud.language import get_language_by_id
1212
from app.crud.stt_evaluations import (
13+
get_samples_by_dataset_id,
1314
get_stt_dataset_by_id,
1415
list_stt_datasets,
15-
get_samples_by_dataset_id,
1616
)
1717
from app.models.stt_evaluation import (
1818
STTDatasetCreate,
@@ -168,15 +168,10 @@ def get_dataset(
168168
samples = []
169169
for s in sample_records:
170170
signed_url = None
171-
if include_signed_url and storage and s.file_id in file_map:
172-
try:
173-
signed_url = storage.get_signed_url(
174-
file_map.get(s.file_id).object_store_url
175-
)
176-
except Exception as e:
177-
logger.warning(
178-
f"[get_dataset] Failed to generate signed URL for file_id {s.file_id}: {e}"
179-
)
171+
if storage and s.file_id in file_map:
172+
signed_url = storage.get_signed_url(
173+
file_map[s.file_id].object_store_url
174+
)
180175

181176
samples.append(
182177
STTSamplePublic(

backend/app/api/routes/tts_evaluations/evaluation.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from app.api.deps import AuthContextDep, SessionDep
99
from app.api.permissions import Permission, require_permission
1010
from app.celery.utils import start_low_priority_job
11+
from app.core.cloud import get_cloud_storage
1112
from app.crud.tts_evaluations import (
1213
create_tts_run,
1314
get_results_by_run_id,
@@ -169,6 +170,9 @@ def get_tts_evaluation_run(
169170
auth_context: AuthContextDep,
170171
run_id: int,
171172
include_results: bool = Query(True, description="Include results in response"),
173+
include_signed_url: bool = Query(
174+
False, description="Include signed URLs for generated audio files"
175+
),
172176
) -> APIResponse[TTSEvaluationRunWithResults]:
173177
"""Get a TTS evaluation run with results."""
174178
run = get_tts_run_by_id(
@@ -185,11 +189,18 @@ def get_tts_evaluation_run(
185189
results_total = 0
186190

187191
if include_results:
192+
storage = None
193+
if include_signed_url:
194+
storage = get_cloud_storage(
195+
session=session, project_id=auth_context.project_.id
196+
)
197+
188198
results, results_total = get_results_by_run_id(
189199
session=session,
190200
run_id=run_id,
191201
org_id=auth_context.organization_.id,
192202
project_id=auth_context.project_.id,
203+
storage=storage,
193204
)
194205

195206
return APIResponse.success_response(

backend/app/api/routes/tts_evaluations/result.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,11 @@
22

33
import logging
44

5-
from fastapi import APIRouter, Body, Depends, HTTPException
5+
from fastapi import APIRouter, Body, Depends, HTTPException, Query
66

77
from app.api.deps import AuthContextDep, SessionDep
88
from app.api.permissions import Permission, require_permission
9+
from app.core.cloud import get_cloud_storage
910
from app.crud.tts_evaluations import (
1011
get_tts_result_by_id,
1112
update_tts_human_feedback,
@@ -92,6 +93,9 @@ def get_result(
9293
session: SessionDep,
9394
auth_context: AuthContextDep,
9495
result_id: int,
96+
include_signed_url: bool = Query(
97+
False, description="Include signed URL for generated audio file"
98+
),
9599
) -> APIResponse[TTSResultPublic]:
96100
"""Get a TTS result by ID."""
97101
result = get_tts_result_by_id(
@@ -104,4 +108,13 @@ def get_result(
104108
if not result:
105109
raise HTTPException(status_code=404, detail="Result not found")
106110

107-
return APIResponse.success_response(data=TTSResultPublic.from_model(result))
111+
signed_url = None
112+
if include_signed_url and result.object_store_url is not None:
113+
storage = get_cloud_storage(
114+
session=session, project_id=auth_context.project_.id
115+
)
116+
signed_url = storage.get_signed_url(result.object_store_url)
117+
118+
return APIResponse.success_response(
119+
data=TTSResultPublic.from_model(result, signed_url=signed_url)
120+
)

backend/app/crud/stt_evaluations/result.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,17 @@
22

33
import logging
44

5-
from sqlmodel import Session, select, func
5+
from sqlmodel import Session, func, select
66

7+
from app.core.cloud.storage import CloudStorage
78
from app.core.exception_handlers import HTTPException
89
from app.core.util import now
910
from app.models.file import File
10-
from app.core.cloud.storage import CloudStorage
1111
from app.models.stt_evaluation import (
1212
STTResult,
13+
STTResultWithSample,
1314
STTSample,
1415
STTSamplePublic,
15-
STTResultWithSample,
1616
)
1717

1818
logger = logging.getLogger(__name__)
@@ -103,14 +103,7 @@ def get_results_by_run_id(
103103
# Convert to response models
104104
results = []
105105
for result, sample, file in rows:
106-
signed_url = None
107-
if storage:
108-
try:
109-
signed_url = storage.get_signed_url(file.object_store_url)
110-
except Exception as e:
111-
logger.warning(
112-
f"[get_results_by_run_id] Failed to generate signed URL: {e}"
113-
)
106+
signed_url = storage.get_signed_url(file.object_store_url) if storage else None
114107

115108
sample_public = STTSamplePublic(
116109
id=sample.id,
@@ -226,4 +219,4 @@ def count_results_by_status(
226219

227220
rows = session.exec(statement).all()
228221

229-
return {status: count for status, count in rows}
222+
return dict(rows)

backend/app/crud/tts_evaluations/cron.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@
1313

1414
from sqlmodel import Session
1515

16-
from app.models.batch_job import BatchJobType
17-
1816
from app.celery.utils import start_low_priority_job
1917
from app.core.batch import GeminiBatchProvider
2018
from app.crud.evaluations.cron_utils import (
@@ -29,7 +27,7 @@
2927
)
3028
from app.crud.tts_evaluations.run import update_tts_run
3129
from app.models import EvaluationRun
32-
from app.models.batch_job import BatchJob
30+
from app.models.batch_job import BatchJob, BatchJobType
3331
from app.models.job import JobStatus
3432
from app.models.stt_evaluation import EvaluationType
3533

@@ -151,7 +149,9 @@ async def _on_batch_succeeded(batch_job: BatchJob, provider_name: str) -> bool:
151149
return True
152150

153151
async def _on_already_succeeded(batch_job: BatchJob, provider_name: str) -> bool:
154-
pending = get_pending_results_for_run(session, run.id, provider_name)
152+
pending = get_pending_results_for_run(
153+
session=session, run_id=run.id, provider=provider_name
154+
)
155155
if pending:
156156
logger.info(
157157
f"{log_prefix} Dispatching reprocessing for "

backend/app/crud/tts_evaluations/result.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
from sqlmodel import Session, func, select
77

8+
from app.core.cloud.storage import CloudStorage
89
from app.core.exception_handlers import HTTPException
910
from app.core.util import now
1011
from app.models.job import JobStatus
@@ -104,6 +105,7 @@ def get_results_by_run_id(
104105
run_id: int,
105106
org_id: int,
106107
project_id: int,
108+
storage: CloudStorage | None = None,
107109
) -> tuple[list[TTSResultPublic], int]:
108110
"""Get all results for an evaluation run.
109111
@@ -112,6 +114,7 @@ def get_results_by_run_id(
112114
run_id: Run ID
113115
org_id: Organization ID
114116
project_id: Project ID
117+
storage: Optional cloud storage instance for generating signed URLs
115118
116119
Returns:
117120
tuple[list[TTSResultPublic], int]: Results and total count
@@ -127,7 +130,12 @@ def get_results_by_run_id(
127130
rows = session.exec(statement).all()
128131
total = len(rows)
129132

130-
results = [TTSResultPublic.from_model(result) for result in rows]
133+
results = []
134+
for result in rows:
135+
signed_url = (
136+
storage.get_signed_url(result.object_store_url) if storage else None
137+
)
138+
results.append(TTSResultPublic.from_model(result, signed_url=signed_url))
131139

132140
return results, total
133141

@@ -294,4 +302,4 @@ def count_results_by_status(
294302

295303
rows = session.exec(statement).all()
296304

297-
return {status: count for status, count in rows}
305+
return dict(rows)

backend/app/models/tts_evaluation.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,7 @@ class TTSResultPublic(BaseModel):
209209
id: int
210210
sample_text: str
211211
object_store_url: str | None
212+
signed_url: str | None = None
212213
duration_seconds: float | None = None
213214
size_bytes: int | None = None
214215
provider: str
@@ -224,12 +225,18 @@ class TTSResultPublic(BaseModel):
224225
updated_at: datetime
225226

226227
@classmethod
227-
def from_model(cls, result: TTSResult) -> TTSResultPublic:
228+
def from_model(
229+
cls,
230+
result: TTSResult,
231+
*,
232+
signed_url: str | None = None,
233+
) -> TTSResultPublic:
228234
"""Create from a TTSResult model instance."""
229235
return cls(
230236
id=result.id,
231237
sample_text=result.sample_text,
232238
object_store_url=result.object_store_url,
239+
signed_url=signed_url,
233240
duration_seconds=(result.metadata_ or {}).get("duration_seconds"),
234241
size_bytes=(result.metadata_ or {}).get("size_bytes"),
235242
provider=result.provider,

backend/app/tests/api/routes/test_stt_evaluation.py

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,14 @@
1-
import pytest
21
from unittest.mock import MagicMock, patch
32

3+
import pytest
44
from fastapi.testclient import TestClient
55
from sqlmodel import Session
66

7-
from app.models import EvaluationDataset, EvaluationRun, File, FileType
8-
from app.models.stt_evaluation import STTSample, STTResult, EvaluationType
7+
from app.core.util import now
98
from app.crud.language import get_language_by_locale
9+
from app.models import EvaluationDataset, EvaluationRun, File, FileType
10+
from app.models.stt_evaluation import EvaluationType, STTResult, STTSample
1011
from app.tests.utils.auth import TestAuthContext
11-
from app.core.util import now
1212

1313

1414
# Helper functions
@@ -741,9 +741,7 @@ def test_get_stt_dataset_signed_url_failure(
741741
"""Test getting an STT dataset when signed URL generation fails."""
742742
# Mock cloud storage to simulate a signed URL generation failure (get_signed_url yields no URL)
743743
mock_storage = MagicMock()
744-
mock_storage.get_signed_url.side_effect = Exception(
745-
"Failed to generate signed URL"
746-
)
744+
mock_storage.get_signed_url.return_value = None
747745
mock_get_cloud_storage.return_value = mock_storage
748746

749747
dataset = create_test_stt_dataset(
@@ -1076,9 +1074,7 @@ def test_get_stt_run_signed_url_failure(
10761074
"""Test getting an STT run when signed URL generation fails."""
10771075
# Mock cloud storage to simulate a signed URL generation failure (get_signed_url yields no URL)
10781076
mock_storage = MagicMock()
1079-
mock_storage.get_signed_url.side_effect = Exception(
1080-
"Failed to generate signed URL"
1081-
)
1077+
mock_storage.get_signed_url.return_value = None
10821078
mock_get_cloud_storage.return_value = mock_storage
10831079

10841080
# Create dataset, sample, run, and result
@@ -1278,7 +1274,7 @@ def test_list_audio_files_with_signed_urls(
12781274
mock_get_cloud_storage.return_value = mock_storage
12791275

12801276
# Create test file
1281-
file = create_test_file(
1277+
_file = create_test_file(
12821278
db=db,
12831279
organization_id=user_api_key.organization_id,
12841280
project_id=user_api_key.project_id,
@@ -1314,7 +1310,7 @@ def test_list_audio_files_without_signed_urls(
13141310
) -> None:
13151311
"""Test listing audio files without signed URLs."""
13161312
# Create test file
1317-
file = create_test_file(
1313+
_file = create_test_file(
13181314
db=db,
13191315
organization_id=user_api_key.organization_id,
13201316
project_id=user_api_key.project_id,
@@ -1345,7 +1341,7 @@ def test_list_audio_files_project_isolation(
13451341
) -> None:
13461342
"""Test that audio files are isolated by project."""
13471343
# Create file in user's project
1348-
user_file = create_test_file(
1344+
_user_file = create_test_file(
13491345
db=db,
13501346
organization_id=user_api_key.organization_id,
13511347
project_id=user_api_key.project_id,

0 commit comments

Comments
 (0)