Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions wavefront/server/apps/inference_app/inference_app/config.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1 @@
[aws]
model_storage_bucket=${MODEL_STORAGE_BUCKET}

[gcp]
model_storage_bucket=${MODEL_STORAGE_BUCKET}

[cloud_config]
cloud_provider=${CLOUD_PROVIDER}
Original file line number Diff line number Diff line change
@@ -1,150 +1,24 @@
import base64
from typing import Any, Dict

from common_module.common_container import CommonContainer
from common_module.log.logger import logger
from common_module.response_formatter import ResponseFormatter
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from inference_app.inference_app_container import InferenceAppContainer
from inference_app.service.image_analyser import ImageClarityService
from inference_app.service.model_inference import (
ModelInferenceService,
PreprocessingStep,
)
from inference_app.service.model_repository import ModelRepository
from inference_app.service.image_embedding import ImageEmbedding
from pydantic import BaseModel, Field


class InferencePayload(BaseModel):
    """Request body for the generic model inference endpoint."""

    # Base64-encoded input; the handler expects a data URI ("...base64,<payload>").
    data: str
    # Kind of input; the handler currently accepts only 'image'.
    payload_type: str
    # Model location/metadata dict forwarded to ModelRepository.load_model.
    model_info: dict
    # Ordered preprocessing operations forwarded to model inference.
    preprocessing_steps: list[PreprocessingStep]
    # Forwarded to ImageClarityService.laplacian_detection —
    # presumably an upper bound for the clarity variance score; confirm.
    max_expected_variance: int = Field(default=1000)
    # The following knobs are all forwarded to
    # ModelInferenceService.model_infer_score.
    resize_width: int = Field(default=224)
    resize_height: int = Field(default=224)
    gaussian_blur_kernel: int = Field(default=3)
    min_threshold: int = Field(default=50)
    max_threshold: int = Field(default=150)
    # Comma-separated per-channel values (ImageNet defaults), passed as strings.
    normalize_mean: str = Field(default='0.485,0.456,0.406')
    normalize_std: str = Field(default='0.229,0.224,0.225')


class InferenceResult(BaseModel):
    """Response model wrapping arbitrary inference outputs keyed by name."""

    results: Dict[str, Any] = Field(..., description='Dictionary of inference results')
from pydantic import BaseModel


class ImagePayload(BaseModel):
    """Request body for the single-image embedding endpoint."""

    # Base64-encoded image, optionally prefixed as a data URI.
    image_data: str


inference_app_router = APIRouter()


@inference_app_router.post('/v1/model-repository/model/{model_id}/infer')
@inject
async def generic_inference_handler(
payload: InferencePayload,
response_formatter: ResponseFormatter = Depends(
Provide[CommonContainer.response_formatter]
),
model_repository: ModelRepository = Depends(
Provide[InferenceAppContainer.model_repository]
),
image_analyser: ImageClarityService = Depends(
Provide[InferenceAppContainer.image_analyser]
),
config: dict = Depends(Provide[InferenceAppContainer.config]),
model_inference: ModelInferenceService = Depends(
Provide[InferenceAppContainer.model_inference]
),
):
try:
provider = config['cloud_config']['cloud_provider']
model_storage_bucket = (
config['gcp']['model_storage_bucket']
if provider.lower() == 'gcp'
else config['aws']['model_storage_bucket']
)

logger.info(
f'Loading model from bucket: {model_storage_bucket}, model_info: {payload.model_info}'
)
model = await model_repository.load_model(
model_info=payload.model_info, bucket_name=model_storage_bucket
)
logger.debug('Model loaded successfully for model_id')

if payload.payload_type.lower() == 'image':
base64_data_uri = payload.data
parts = base64_data_uri.split(',')
if len(parts) == 2:
base64_data = parts[1]
image_bytes = base64.b64decode(base64_data)
image_data: str # base64 encoded image data

clarity_score = image_analyser.laplacian_detection(
image_bytes, payload.max_expected_variance
)

infer_data = model_inference.model_infer_score(
model,
image_bytes,
payload.resize_width,
payload.resize_height,
payload.normalize_mean,
payload.normalize_std,
payload.gaussian_blur_kernel,
payload.min_threshold,
payload.max_threshold,
preprocessing_steps=payload.preprocessing_steps,
)
logger.debug('Model inference completed successfully for model_id')
class ImageBatchPayload(BaseModel):
    """Request body for the batch embedding endpoint."""

    image_batch: list[str]  # list of base64 encoded image data

inference_results = InferenceResult(
results={
'clarity_score': clarity_score,
'infer_data': infer_data,
'data_type': payload.payload_type.lower(),
}
)

logger.info('Inference request completed successfully for model_id')
return JSONResponse(
status_code=status.HTTP_201_CREATED,
content=response_formatter.buildSuccessResponse(
inference_results.dict()
),
)
else:
error_msg = (
"Input data is not in expected Data URI format (missing 'base64,')."
)
logger.error(
f"Expected Data URI format with 'base64,' prefix. "
f'Data length: {len(base64_data_uri)}'
)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(error_msg),
)
else:
error_msg = f"Invalid payload_type: {payload.payload_type}. Accepted values are 'image'"
logger.error(f'{error_msg}')
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'Invalid payload_type. Accepted values are "image"'
),
)
except Exception as e:
logger.error(f'Error in generic_inference_handler {str(e)}')
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=response_formatter.buildErrorResponse('Internal server error'),
)
inference_app_router = APIRouter()


@inference_app_router.post('/v1/query/embeddings')
Expand All @@ -159,10 +33,7 @@ async def image_embedding(
),
):
# 1. Decode Base64 string
base64_data_uri = payload.image_data
parts = base64_data_uri.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
image_data = base64.b64decode(base64_data)
image_data = extract_decoded_image_data(payload.image_data)
embeddings = image_embedding_service.query_embed(image_data)
if not embeddings:
return JSONResponse(
Expand All @@ -175,3 +46,30 @@ async def image_embedding(
status_code=status.HTTP_200_OK,
content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
)


@inference_app_router.post('/v1/query/embeddings/batch')
@inject
async def image_embedding_batch(
    payload: ImageBatchPayload,
    response_formatter: ResponseFormatter = Depends(
        Provide[CommonContainer.response_formatter]
    ),
    image_embedding_service: ImageEmbedding = Depends(
        Provide[InferenceAppContainer.image_embedding]
    ),
):
    """Compute embeddings for a batch of base64-encoded images.

    Returns HTTP 200 with the embeddings on success, HTTP 400 when the
    payload cannot be decoded or when no embeddings are produced —
    mirroring the single-image /v1/query/embeddings endpoint.
    """
    try:
        image_batch = [
            extract_decoded_image_data(image_data)
            for image_data in payload.image_batch
        ]
    except (ValueError, TypeError):
        # base64.b64decode raises binascii.Error (a ValueError subclass) on
        # malformed input; surface it as a client error instead of a 500.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=response_formatter.buildErrorResponse(
                'Invalid base64 encoded image data'
            ),
        )
    embeddings = image_embedding_service.query_embed_batch(image_batch)
    if not embeddings:
        # Consistent with the single-image endpoint's empty-result handling.
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=response_formatter.buildErrorResponse(
                'No Embedding data is present'
            ),
        )
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
    )
Comment on lines +51 to +69
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Inconsistent error handling: batch endpoint always returns HTTP 200, unlike single-image endpoint.

The single-image endpoint (lines 38-44) checks if not embeddings and returns HTTP 400. The batch endpoint does not perform this check, always returning HTTP 200 even when query_embed_batch returns an empty list due to decode failures.

Additionally, extract_decoded_image_data can raise binascii.Error on malformed base64, which is unhandled here and would result in an HTTP 500.

🛡️ Proposed fix for consistent error handling
 `@inference_app_router.post`('/v1/query/embeddings/batch')
 `@inject`
 async def image_embedding_batch(
     payload: ImageBatchPayload,
     response_formatter: ResponseFormatter = Depends(
         Provide[CommonContainer.response_formatter]
     ),
     image_embedding_service: ImageEmbedding = Depends(
         Provide[InferenceAppContainer.image_embedding]
     ),
 ):
-    image_batch = [
-        extract_decoded_image_data(image_data) for image_data in payload.image_batch
-    ]
+    try:
+        image_batch = [
+            extract_decoded_image_data(image_data) for image_data in payload.image_batch
+        ]
+    except Exception:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=response_formatter.buildErrorResponse('Invalid base64 image data'),
+        )
     embeddings = image_embedding_service.query_embed_batch(image_batch)
+    if not embeddings:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=response_formatter.buildErrorResponse(
+                'No Embedding data is present'
+            ),
+        )
     return JSONResponse(
         status_code=status.HTTP_200_OK,
         content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
     )
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 55-57: Do not perform function call Depends in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)


[warning] 58-60: Do not perform function call Depends in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py`
around lines 51 - 69, The batch endpoint image_embedding_batch currently always
returns HTTP 200 and doesn't handle decode errors; update image_embedding_batch
to (1) catch binascii.Error (raised by extract_decoded_image_data) and return an
HTTP 400 using response_formatter.buildErrorResponse with an appropriate error
response, and (2) after calling image_embedding_service.query_embed_batch, check
if embeddings is empty or falsy and return HTTP 400 (mirroring the single-image
flow) instead of always returning 200; reference extract_decoded_image_data,
query_embed_batch, image_embedding_batch, and
response_formatter.buildSuccessResponse to locate where to add the try/except
and the empty-result conditional.



def extract_decoded_image_data(image_data: str) -> bytes:
    """Decode a base64 image string, tolerating an optional data-URI prefix.

    Accepts either a bare base64 payload or a two-part data URI such as
    ``data:image/png;base64,<payload>`` and returns the decoded bytes.
    """
    before, sep, after = image_data.partition(',')
    if not sep:
        # No comma at all: the whole string is the payload.
        encoded = image_data
    elif ',' in after:
        # More than two comma-separated segments: keep only the first,
        # matching the original split()-based behaviour.
        encoded = before
    else:
        # Exactly one comma: the payload follows the data-URI header.
        encoded = after
    return base64.b64decode(encoded)
Comment on lines +72 to +75
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

extract_decoded_image_data can raise unhandled binascii.Error on malformed base64.

base64.b64decode raises binascii.Error for invalid base64 input. This function is called directly without try/except in both endpoints. For the single-image endpoint (line 36), this would cause an HTTP 500. Consider adding error handling here or at call sites.

💡 Option: Add validation with a clearer error
+import binascii
+
+class InvalidImageDataError(Exception):
+    pass
+
 def extract_decoded_image_data(image_data: str) -> bytes:
     parts = image_data.split(',')
     base64_data = parts[1] if len(parts) == 2 else parts[0]
-    return base64.b64decode(base64_data)
+    try:
+        return base64.b64decode(base64_data)
+    except binascii.Error as e:
+        raise InvalidImageDataError('Invalid base64 encoded image data') from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def extract_decoded_image_data(image_data: str) -> bytes:
parts = image_data.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
return base64.b64decode(base64_data)
import binascii
class InvalidImageDataError(Exception):
pass
def extract_decoded_image_data(image_data: str) -> bytes:
parts = image_data.split(',')
base64_data = parts[1] if len(parts) == 2 else parts[0]
try:
return base64.b64decode(base64_data)
except binascii.Error as e:
raise InvalidImageDataError('Invalid base64 encoded image data') from e
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py`
around lines 72 - 75, The helper extract_decoded_image_data currently calls
base64.b64decode which can raise binascii.Error on malformed input; wrap the
decode call in a try/except that catches binascii.Error (import binascii) and
re-raise a clear ValueError or custom exception (e.g., "Invalid base64 image
data") so caller endpoints can return a 4xx response instead of an unhandled
500; update any call sites (the single-image and multi-image endpoints) to catch
that ValueError and convert it to an appropriate HTTP error response.

Original file line number Diff line number Diff line change
@@ -1,31 +1,9 @@
from dependency_injector import containers
from dependency_injector import providers
from inference_app.service.image_analyser import ImageClarityService
from flo_cloud.cloud_storage import CloudStorageManager
from inference_app.service.model_repository import ModelRepository
from inference_app.service.model_inference import ModelInferenceService
from inference_app.service.image_embedding import ImageEmbedding


class InferenceAppContainer(containers.DeclarativeContainer):
    """Dependency-injection container wiring the inference app's services."""

    # Configuration loaded from config.ini (aws/gcp bucket names, cloud provider).
    config = providers.Configuration(ini_files=['config.ini'])
    # Supplied by the caller at container construction time.
    cache_manager = providers.Dependency()

    # Storage backend selected by the configured cloud provider.
    cloud_storage_manager = providers.Singleton(
        CloudStorageManager, provider=config.cloud_config.cloud_provider
    )

    # Loads models from cloud storage.
    model_repository = providers.Singleton(
        ModelRepository,
        cloud_storage_manager=cloud_storage_manager,
    )

    model_inference = providers.Singleton(ModelInferenceService)

    image_analyser = providers.Singleton(
        ImageClarityService,
    )

    image_embedding = providers.Singleton(
        ImageEmbedding,
    )
    # NOTE(review): duplicate assignment — rebinds image_embedding to an
    # equivalent provider; the earlier multi-line definition above is dead.
    image_embedding = providers.Singleton(ImageEmbedding)
Comment on lines 6 to +9
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify that server.py still passes cache_manager to InferenceAppContainer
rg -n 'InferenceAppContainer\(' --type py -A 3

Repository: rootflo/wavefront

Length of output: 891


🏁 Script executed:

cat -n wavefront/server/apps/inference_app/inference_app/inference_app_container.py

Repository: rootflo/wavefront

Length of output: 459


🏁 Script executed:

cat -n wavefront/server/apps/inference_app/inference_app/server.py | head -35

Repository: rootflo/wavefront

Length of output: 1286


Runtime TypeError: server.py passes cache_manager=None but this container no longer accepts it.

server.py (lines 25-27) instantiates InferenceAppContainer(cache_manager=None). Since the cache_manager provider was removed from the container, this raises:

TypeError: __init__() got an unexpected keyword argument 'cache_manager'

Either restore the cache_manager = providers.Dependency() provider, or remove the argument from the instantiation in server.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@wavefront/server/apps/inference_app/inference_app/inference_app_container.py`
around lines 6 - 9, The failure is caused by server.py calling
InferenceAppContainer(cache_manager=None) but the container no longer declares a
cache_manager provider; restore compatibility by adding a dependency provider
named cache_manager to InferenceAppContainer (e.g., add a line cache_manager =
providers.Dependency() in the class) so the existing instantiation with
cache_manager=None succeeds, or alternatively remove the cache_manager argument
from the call site in server.py; update the symbol InferenceAppContainer to
include cache_manager = providers.Dependency() if you choose the former.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import List, Dict, Any
from common_module.log.logger import logger

CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32'
DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m'

class ImageEmbedding:
CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32'
DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m'

class ImageEmbedding:
def __init__(self):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {self.device}')
Expand Down Expand Up @@ -69,3 +69,49 @@ def query_embed(self, image_content: bytes) -> List[Dict[str, List[float]]]:
results.append({name: embedding})

return results

@torch.inference_mode()
def query_embed_batch(
self, image_batch: list[bytes]
) -> List[Dict[str, List[List[float]]]]:
"""
GPU batch embedding.

Returns:
[
{"clip": [embedding_for_image_0, ..., embedding_for_image_N]},
{"dino": [embedding_for_image_0, ..., embedding_for_image_N]},
]
"""
if not image_batch:
return []

# Decode bytes -> PIL images on CPU.
# The actual model forward pass (processor->tensor + model) is batched on GPU.
images: List[Image.Image] = []
for idx, image_content in enumerate(image_batch):
try:
images.append(Image.open(io.BytesIO(image_content)).convert('RGB'))
except Exception as e:
logger.error(
f'Error opening image at index={idx}: {e}',
exc_info=True,
)
return []
Comment on lines +92 to +100
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Silent failure on batch decode returns empty list indistinguishable from empty input.

When any single image in the batch fails to decode, the method logs the error but returns []. This is the same return value as an empty input batch (line 86-87), making it impossible for callers to distinguish between "no images provided" and "decode failed."

Consider either:

  1. Raising an exception with the failing index for explicit error handling.
  2. Returning a result structure that indicates which images failed.
  3. Skipping invalid images and processing valid ones (partial success).
💡 Option 1: Raise an exception
             except Exception as e:
                 logger.error(
                     f'Error opening image at index={idx}: {e}',
                     exc_info=True,
                 )
-                return []
+                raise ValueError(f'Failed to decode image at index {idx}') from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for idx, image_content in enumerate(image_batch):
try:
images.append(Image.open(io.BytesIO(image_content)).convert('RGB'))
except Exception as e:
logger.error(
f'Error opening image at index={idx}: {e}',
exc_info=True,
)
return []
for idx, image_content in enumerate(image_batch):
try:
images.append(Image.open(io.BytesIO(image_content)).convert('RGB'))
except Exception as e:
logger.error(
f'Error opening image at index={idx}: {e}',
exc_info=True,
)
raise ValueError(f'Failed to decode image at index {idx}') from e
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 95-95: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@wavefront/server/apps/inference_app/inference_app/service/image_embedding.py`
around lines 92 - 100, The current loop that opens images (the block that calls
Image.open on items from image_batch and appends to images) silently returns []
on any decode error, making failures indistinguishable from an empty input;
change this to raise a descriptive exception instead: when an exception occurs
while opening an image (use the same except block that catches Exception as e),
log the error and then raise a ValueError (or a custom DecodeError) that
includes the failing index (idx) and the original exception (e) so callers can
distinguish a decode failure from an empty batch and handle it explicitly.


results: List[Dict[str, List[List[float]]]] = []

for name, embedder in self.embedders.items():
inputs = embedder['processor'](images=images, return_tensors='pt')
inputs = {k: v.to(self.device) for k, v in inputs.items()}

# Batched forward pass.
image_features = embedder['extractor'](inputs) # (batch, dim)

# L2-normalize per-vector.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)

embeddings = image_features.cpu().numpy().tolist() # batch x dim
results.append({name: embeddings})

return results
Loading
Loading