-
Notifications
You must be signed in to change notification settings - Fork 30
Cu-86d2e1ka4: simplify inference module for image embedding generation only #262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1 @@ | ||
| [aws] | ||
| model_storage_bucket=${MODEL_STORAGE_BUCKET} | ||
|
|
||
| [gcp] | ||
| model_storage_bucket=${MODEL_STORAGE_BUCKET} | ||
|
|
||
| [cloud_config] | ||
| cloud_provider=${CLOUD_PROVIDER} |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,150 +1,24 @@ | ||||||||||||||||||||||||||||||||||
| import base64 | ||||||||||||||||||||||||||||||||||
| from typing import Any, Dict | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from common_module.common_container import CommonContainer | ||||||||||||||||||||||||||||||||||
| from common_module.log.logger import logger | ||||||||||||||||||||||||||||||||||
| from common_module.response_formatter import ResponseFormatter | ||||||||||||||||||||||||||||||||||
| from dependency_injector.wiring import Provide, inject | ||||||||||||||||||||||||||||||||||
| from fastapi import APIRouter, Depends, status | ||||||||||||||||||||||||||||||||||
| from fastapi.responses import JSONResponse | ||||||||||||||||||||||||||||||||||
| from inference_app.inference_app_container import InferenceAppContainer | ||||||||||||||||||||||||||||||||||
| from inference_app.service.image_analyser import ImageClarityService | ||||||||||||||||||||||||||||||||||
| from inference_app.service.model_inference import ( | ||||||||||||||||||||||||||||||||||
| ModelInferenceService, | ||||||||||||||||||||||||||||||||||
| PreprocessingStep, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| from inference_app.service.model_repository import ModelRepository | ||||||||||||||||||||||||||||||||||
| from inference_app.service.image_embedding import ImageEmbedding | ||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel, Field | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class InferencePayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| data: str | ||||||||||||||||||||||||||||||||||
| payload_type: str | ||||||||||||||||||||||||||||||||||
| model_info: dict | ||||||||||||||||||||||||||||||||||
| preprocessing_steps: list[PreprocessingStep] | ||||||||||||||||||||||||||||||||||
| max_expected_variance: int = Field(default=1000) | ||||||||||||||||||||||||||||||||||
| resize_width: int = Field(default=224) | ||||||||||||||||||||||||||||||||||
| resize_height: int = Field(default=224) | ||||||||||||||||||||||||||||||||||
| gaussian_blur_kernel: int = Field(default=3) | ||||||||||||||||||||||||||||||||||
| min_threshold: int = Field(default=50) | ||||||||||||||||||||||||||||||||||
| max_threshold: int = Field(default=150) | ||||||||||||||||||||||||||||||||||
| normalize_mean: str = Field(default='0.485,0.456,0.406') | ||||||||||||||||||||||||||||||||||
| normalize_std: str = Field(default='0.229,0.224,0.225') | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class InferenceResult(BaseModel): | ||||||||||||||||||||||||||||||||||
| results: Dict[str, Any] = Field(..., description='Dictionary of inference results') | ||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class ImagePayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| image_data: str | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| inference_app_router = APIRouter() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/model-repository/model/{model_id}/infer') | ||||||||||||||||||||||||||||||||||
| @inject | ||||||||||||||||||||||||||||||||||
| async def generic_inference_handler( | ||||||||||||||||||||||||||||||||||
| payload: InferencePayload, | ||||||||||||||||||||||||||||||||||
| response_formatter: ResponseFormatter = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[CommonContainer.response_formatter] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| model_repository: ModelRepository = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.model_repository] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| image_analyser: ImageClarityService = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.image_analyser] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| config: dict = Depends(Provide[InferenceAppContainer.config]), | ||||||||||||||||||||||||||||||||||
| model_inference: ModelInferenceService = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.model_inference] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| provider = config['cloud_config']['cloud_provider'] | ||||||||||||||||||||||||||||||||||
| model_storage_bucket = ( | ||||||||||||||||||||||||||||||||||
| config['gcp']['model_storage_bucket'] | ||||||||||||||||||||||||||||||||||
| if provider.lower() == 'gcp' | ||||||||||||||||||||||||||||||||||
| else config['aws']['model_storage_bucket'] | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||
| f'Loading model from bucket: {model_storage_bucket}, model_info: {payload.model_info}' | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| model = await model_repository.load_model( | ||||||||||||||||||||||||||||||||||
| model_info=payload.model_info, bucket_name=model_storage_bucket | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.debug('Model loaded successfully for model_id') | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if payload.payload_type.lower() == 'image': | ||||||||||||||||||||||||||||||||||
| base64_data_uri = payload.data | ||||||||||||||||||||||||||||||||||
| parts = base64_data_uri.split(',') | ||||||||||||||||||||||||||||||||||
| if len(parts) == 2: | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] | ||||||||||||||||||||||||||||||||||
| image_bytes = base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
| image_data: str # base64 encoded image data | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| clarity_score = image_analyser.laplacian_detection( | ||||||||||||||||||||||||||||||||||
| image_bytes, payload.max_expected_variance | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| infer_data = model_inference.model_infer_score( | ||||||||||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||||||||||
| image_bytes, | ||||||||||||||||||||||||||||||||||
| payload.resize_width, | ||||||||||||||||||||||||||||||||||
| payload.resize_height, | ||||||||||||||||||||||||||||||||||
| payload.normalize_mean, | ||||||||||||||||||||||||||||||||||
| payload.normalize_std, | ||||||||||||||||||||||||||||||||||
| payload.gaussian_blur_kernel, | ||||||||||||||||||||||||||||||||||
| payload.min_threshold, | ||||||||||||||||||||||||||||||||||
| payload.max_threshold, | ||||||||||||||||||||||||||||||||||
| preprocessing_steps=payload.preprocessing_steps, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.debug('Model inference completed successfully for model_id') | ||||||||||||||||||||||||||||||||||
| class ImageBatchPayload(BaseModel): | ||||||||||||||||||||||||||||||||||
| image_batch: list[str] # list of base64 encoded image data | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| inference_results = InferenceResult( | ||||||||||||||||||||||||||||||||||
| results={ | ||||||||||||||||||||||||||||||||||
| 'clarity_score': clarity_score, | ||||||||||||||||||||||||||||||||||
| 'infer_data': infer_data, | ||||||||||||||||||||||||||||||||||
| 'data_type': payload.payload_type.lower(), | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger.info('Inference request completed successfully for model_id') | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_201_CREATED, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse( | ||||||||||||||||||||||||||||||||||
| inference_results.dict() | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| error_msg = ( | ||||||||||||||||||||||||||||||||||
| "Input data is not in expected Data URI format (missing 'base64,')." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||
| f"Expected Data URI format with 'base64,' prefix. " | ||||||||||||||||||||||||||||||||||
| f'Data length: {len(base64_data_uri)}' | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse(error_msg), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| error_msg = f"Invalid payload_type: {payload.payload_type}. Accepted values are 'image'" | ||||||||||||||||||||||||||||||||||
| logger.error(f'{error_msg}') | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_400_BAD_REQUEST, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse( | ||||||||||||||||||||||||||||||||||
| 'Invalid payload_type. Accepted values are "image"' | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||
| logger.error(f'Error in generic_inference_handler {str(e)}') | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildErrorResponse('Internal server error'), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| inference_app_router = APIRouter() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/query/embeddings') | ||||||||||||||||||||||||||||||||||
|
|
@@ -159,10 +33,7 @@ async def image_embedding( | |||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| # 1. Decode Base64 string | ||||||||||||||||||||||||||||||||||
| base64_data_uri = payload.image_data | ||||||||||||||||||||||||||||||||||
| parts = base64_data_uri.split(',') | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] if len(parts) == 2 else parts[0] | ||||||||||||||||||||||||||||||||||
| image_data = base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
| image_data = extract_decoded_image_data(payload.image_data) | ||||||||||||||||||||||||||||||||||
| embeddings = image_embedding_service.query_embed(image_data) | ||||||||||||||||||||||||||||||||||
| if not embeddings: | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
|
|
@@ -175,3 +46,30 @@ async def image_embedding( | |||||||||||||||||||||||||||||||||
| status_code=status.HTTP_200_OK, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse(data={'response': embeddings}), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @inference_app_router.post('/v1/query/embeddings/batch') | ||||||||||||||||||||||||||||||||||
| @inject | ||||||||||||||||||||||||||||||||||
| async def image_embedding_batch( | ||||||||||||||||||||||||||||||||||
| payload: ImageBatchPayload, | ||||||||||||||||||||||||||||||||||
| response_formatter: ResponseFormatter = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[CommonContainer.response_formatter] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| image_embedding_service: ImageEmbedding = Depends( | ||||||||||||||||||||||||||||||||||
| Provide[InferenceAppContainer.image_embedding] | ||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| image_batch = [ | ||||||||||||||||||||||||||||||||||
| extract_decoded_image_data(image_data) for image_data in payload.image_batch | ||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||
| embeddings = image_embedding_service.query_embed_batch(image_batch) | ||||||||||||||||||||||||||||||||||
| return JSONResponse( | ||||||||||||||||||||||||||||||||||
| status_code=status.HTTP_200_OK, | ||||||||||||||||||||||||||||||||||
| content=response_formatter.buildSuccessResponse(data={'response': embeddings}), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def extract_decoded_image_data(image_data: str) -> bytes: | ||||||||||||||||||||||||||||||||||
| parts = image_data.split(',') | ||||||||||||||||||||||||||||||||||
| base64_data = parts[1] if len(parts) == 2 else parts[0] | ||||||||||||||||||||||||||||||||||
| return base64.b64decode(base64_data) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+72
to
+75
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
💡 Option: Add validation with a clearer error+import binascii
+
+class InvalidImageDataError(Exception):
+ pass
+
def extract_decoded_image_data(image_data: str) -> bytes:
parts = image_data.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
- return base64.b64decode(base64_data)
+ try:
+ return base64.b64decode(base64_data)
+ except binascii.Error as e:
+ raise InvalidImageDataError('Invalid base64 encoded image data') from e📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,31 +1,9 @@ | ||
| from dependency_injector import containers | ||
| from dependency_injector import providers | ||
| from inference_app.service.image_analyser import ImageClarityService | ||
| from flo_cloud.cloud_storage import CloudStorageManager | ||
| from inference_app.service.model_repository import ModelRepository | ||
| from inference_app.service.model_inference import ModelInferenceService | ||
| from inference_app.service.image_embedding import ImageEmbedding | ||
|
|
||
|
|
||
| class InferenceAppContainer(containers.DeclarativeContainer): | ||
| config = providers.Configuration(ini_files=['config.ini']) | ||
| cache_manager = providers.Dependency() | ||
|
|
||
| cloud_storage_manager = providers.Singleton( | ||
| CloudStorageManager, provider=config.cloud_config.cloud_provider | ||
| ) | ||
|
|
||
| model_repository = providers.Singleton( | ||
| ModelRepository, | ||
| cloud_storage_manager=cloud_storage_manager, | ||
| ) | ||
|
|
||
| model_inference = providers.Singleton(ModelInferenceService) | ||
|
|
||
| image_analyser = providers.Singleton( | ||
| ImageClarityService, | ||
| ) | ||
|
|
||
| image_embedding = providers.Singleton( | ||
| ImageEmbedding, | ||
| ) | ||
| image_embedding = providers.Singleton(ImageEmbedding) | ||
|
Comment on lines
6
to
+9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verify that server.py still passes cache_manager to InferenceAppContainer
rg -n 'InferenceAppContainer\(' --type py -A 3Repository: rootflo/wavefront Length of output: 891 🏁 Script executed: cat -n wavefront/server/apps/inference_app/inference_app/inference_app_container.pyRepository: rootflo/wavefront Length of output: 459 🏁 Script executed: cat -n wavefront/server/apps/inference_app/inference_app/server.py | head -35Repository: rootflo/wavefront Length of output: 1286 Runtime
Either restore the 🤖 Prompt for AI Agents |
||
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,11 +5,11 @@ | |||||||||||||||||||||||||||||||||||||
| from typing import List, Dict, Any | ||||||||||||||||||||||||||||||||||||||
| from common_module.log.logger import logger | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32' | ||||||||||||||||||||||||||||||||||||||
| DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m' | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class ImageEmbedding: | ||||||||||||||||||||||||||||||||||||||
| CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32' | ||||||||||||||||||||||||||||||||||||||
| DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m' | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class ImageEmbedding: | ||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||
| self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||||||||||||||||||||||||||||||||||||||
| logger.info(f'Using device: {self.device}') | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -69,3 +69,49 @@ def query_embed(self, image_content: bytes) -> List[Dict[str, List[float]]]: | |||||||||||||||||||||||||||||||||||||
| results.append({name: embedding}) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return results | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @torch.inference_mode() | ||||||||||||||||||||||||||||||||||||||
| def query_embed_batch( | ||||||||||||||||||||||||||||||||||||||
| self, image_batch: list[bytes] | ||||||||||||||||||||||||||||||||||||||
| ) -> List[Dict[str, List[List[float]]]]: | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| GPU batch embedding. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||
| {"clip": [embedding_for_image_0, ..., embedding_for_image_N]}, | ||||||||||||||||||||||||||||||||||||||
| {"dino": [embedding_for_image_0, ..., embedding_for_image_N]}, | ||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| if not image_batch: | ||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Decode bytes -> PIL images on CPU. | ||||||||||||||||||||||||||||||||||||||
| # The actual model forward pass (processor->tensor + model) is batched on GPU. | ||||||||||||||||||||||||||||||||||||||
| images: List[Image.Image] = [] | ||||||||||||||||||||||||||||||||||||||
| for idx, image_content in enumerate(image_batch): | ||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||
| images.append(Image.open(io.BytesIO(image_content)).convert('RGB')) | ||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||||||
| f'Error opening image at index={idx}: {e}', | ||||||||||||||||||||||||||||||||||||||
| exc_info=True, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+92
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Silent failure on batch decode returns empty list indistinguishable from empty input. When any single image in the batch fails to decode, the method logs the error but returns Consider either:
💡 Option 1: Raise an exception except Exception as e:
logger.error(
f'Error opening image at index={idx}: {e}',
exc_info=True,
)
- return []
+ raise ValueError(f'Failed to decode image at index {idx}') from e📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.15.7)[warning] 95-95: Do not catch blind exception: (BLE001) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| results: List[Dict[str, List[List[float]]]] = [] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for name, embedder in self.embedders.items(): | ||||||||||||||||||||||||||||||||||||||
| inputs = embedder['processor'](images=images, return_tensors='pt') | ||||||||||||||||||||||||||||||||||||||
| inputs = {k: v.to(self.device) for k, v in inputs.items()} | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Batched forward pass. | ||||||||||||||||||||||||||||||||||||||
| image_features = embedder['extractor'](inputs) # (batch, dim) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # L2-normalize per-vector. | ||||||||||||||||||||||||||||||||||||||
| image_features = image_features / image_features.norm(dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| embeddings = image_features.cpu().numpy().tolist() # batch x dim | ||||||||||||||||||||||||||||||||||||||
| results.append({name: embeddings}) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| return results | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inconsistent error handling: batch endpoint always returns HTTP 200, unlike single-image endpoint.
The single-image endpoint (lines 38-44) checks
if not embeddingsand returns HTTP 400. The batch endpoint does not perform this check, always returning HTTP 200 even whenquery_embed_batchreturns an empty list due to decode failures.Additionally,
extract_decoded_image_datacan raisebinascii.Erroron malformed base64, which is unhandled here and would result in an HTTP 500.🛡️ Proposed fix for consistent error handling
`@inference_app_router.post`('/v1/query/embeddings/batch') `@inject` async def image_embedding_batch( payload: ImageBatchPayload, response_formatter: ResponseFormatter = Depends( Provide[CommonContainer.response_formatter] ), image_embedding_service: ImageEmbedding = Depends( Provide[InferenceAppContainer.image_embedding] ), ): - image_batch = [ - extract_decoded_image_data(image_data) for image_data in payload.image_batch - ] + try: + image_batch = [ + extract_decoded_image_data(image_data) for image_data in payload.image_batch + ] + except Exception: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response_formatter.buildErrorResponse('Invalid base64 image data'), + ) embeddings = image_embedding_service.query_embed_batch(image_batch) + if not embeddings: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response_formatter.buildErrorResponse( + 'No Embedding data is present' + ), + ) return JSONResponse( status_code=status.HTTP_200_OK, content=response_formatter.buildSuccessResponse(data={'response': embeddings}), )🧰 Tools
🪛 Ruff (0.15.7)
[warning] 55-57: Do not perform function call
Dependsin argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable(B008)
[warning] 58-60: Do not perform function call
Dependsin argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable(B008)
🤖 Prompt for AI Agents