diff --git a/wavefront/server/apps/inference_app/inference_app/config.ini b/wavefront/server/apps/inference_app/inference_app/config.ini
index cc12d9df..8b137891 100644
--- a/wavefront/server/apps/inference_app/inference_app/config.ini
+++ b/wavefront/server/apps/inference_app/inference_app/config.ini
@@ -1,8 +1 @@
-[aws]
-model_storage_bucket=${MODEL_STORAGE_BUCKET}
-[gcp]
-model_storage_bucket=${MODEL_STORAGE_BUCKET}
-
-[cloud_config]
-cloud_provider=${CLOUD_PROVIDER}
 
diff --git a/wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py b/wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
index e9a5e82a..856bf3ab 100644
--- a/wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
+++ b/wavefront/server/apps/inference_app/inference_app/controllers/inference_controller.py
@@ -1,150 +1,24 @@
 import base64
-from typing import Any, Dict
 
 from common_module.common_container import CommonContainer
-from common_module.log.logger import logger
 from common_module.response_formatter import ResponseFormatter
 from dependency_injector.wiring import Provide, inject
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import JSONResponse
 from inference_app.inference_app_container import InferenceAppContainer
-from inference_app.service.image_analyser import ImageClarityService
-from inference_app.service.model_inference import (
-    ModelInferenceService,
-    PreprocessingStep,
-)
-from inference_app.service.model_repository import ModelRepository
 from inference_app.service.image_embedding import ImageEmbedding
-from pydantic import BaseModel, Field
-
-
-class InferencePayload(BaseModel):
-    data: str
-    payload_type: str
-    model_info: dict
-    preprocessing_steps: list[PreprocessingStep]
-    max_expected_variance: int = Field(default=1000)
-    resize_width: int = Field(default=224)
-    resize_height: int = Field(default=224)
-    gaussian_blur_kernel: int = Field(default=3)
-    min_threshold: int = Field(default=50)
-    max_threshold: int = Field(default=150)
-    normalize_mean: str = Field(default='0.485,0.456,0.406')
-    normalize_std: str = Field(default='0.229,0.224,0.225')
-
-
-class InferenceResult(BaseModel):
-    results: Dict[str, Any] = Field(..., description='Dictionary of inference results')
+from pydantic import BaseModel
 
 
 class ImagePayload(BaseModel):
-    image_data: str
-
-
-inference_app_router = APIRouter()
-
-
-@inference_app_router.post('/v1/model-repository/model/{model_id}/infer')
-@inject
-async def generic_inference_handler(
-    payload: InferencePayload,
-    response_formatter: ResponseFormatter = Depends(
-        Provide[CommonContainer.response_formatter]
-    ),
-    model_repository: ModelRepository = Depends(
-        Provide[InferenceAppContainer.model_repository]
-    ),
-    image_analyser: ImageClarityService = Depends(
-        Provide[InferenceAppContainer.image_analyser]
-    ),
-    config: dict = Depends(Provide[InferenceAppContainer.config]),
-    model_inference: ModelInferenceService = Depends(
-        Provide[InferenceAppContainer.model_inference]
-    ),
-):
-    try:
-        provider = config['cloud_config']['cloud_provider']
-        model_storage_bucket = (
-            config['gcp']['model_storage_bucket']
-            if provider.lower() == 'gcp'
-            else config['aws']['model_storage_bucket']
-        )
-
-        logger.info(
-            f'Loading model from bucket: {model_storage_bucket}, model_info: {payload.model_info}'
-        )
-        model = await model_repository.load_model(
-            model_info=payload.model_info, bucket_name=model_storage_bucket
-        )
-        logger.debug('Model loaded successfully for model_id')
-
-        if payload.payload_type.lower() == 'image':
-            base64_data_uri = payload.data
-            parts = base64_data_uri.split(',')
-            if len(parts) == 2:
-                base64_data = parts[1]
-                image_bytes = base64.b64decode(base64_data)
+    image_data: str  # base64 encoded image data
 
-                clarity_score = image_analyser.laplacian_detection(
-                    image_bytes, payload.max_expected_variance
-                )
-                infer_data = model_inference.model_infer_score(
-                    model,
-                    image_bytes,
-                    payload.resize_width,
-                    payload.resize_height,
-                    payload.normalize_mean,
-                    payload.normalize_std,
-                    payload.gaussian_blur_kernel,
-                    payload.min_threshold,
-                    payload.max_threshold,
-                    preprocessing_steps=payload.preprocessing_steps,
-                )
-                logger.debug('Model inference completed successfully for model_id')
 
+class ImageBatchPayload(BaseModel):
+    image_batch: list[str]  # list of base64 encoded image data
 
-                inference_results = InferenceResult(
-                    results={
-                        'clarity_score': clarity_score,
-                        'infer_data': infer_data,
-                        'data_type': payload.payload_type.lower(),
-                    }
-                )
-                logger.info('Inference request completed successfully for model_id')
-                return JSONResponse(
-                    status_code=status.HTTP_201_CREATED,
-                    content=response_formatter.buildSuccessResponse(
-                        inference_results.dict()
-                    ),
-                )
-            else:
-                error_msg = (
-                    "Input data is not in expected Data URI format (missing 'base64,')."
-                )
-                logger.error(
-                    f"Expected Data URI format with 'base64,' prefix. "
-                    f'Data length: {len(base64_data_uri)}'
-                )
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content=response_formatter.buildErrorResponse(error_msg),
-                )
-        else:
-            error_msg = f"Invalid payload_type: {payload.payload_type}. Accepted values are 'image'"
-            logger.error(f'{error_msg}')
-            return JSONResponse(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                content=response_formatter.buildErrorResponse(
-                    'Invalid payload_type. Accepted values are "image"'
-                ),
-            )
-    except Exception as e:
-        logger.error(f'Error in generic_inference_handler {str(e)}')
-        return JSONResponse(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            content=response_formatter.buildErrorResponse('Internal server error'),
-        )
+
+
+inference_app_router = APIRouter()
@@ -159,10 +33,7 @@ async def image_embedding(
     ),
 ):
     # 1. Decode Base64 string
-    base64_data_uri = payload.image_data
-    parts = base64_data_uri.split(',')
-    base64_data = parts[1] if len(parts) == 2 else parts[0]
-    image_data = base64.b64decode(base64_data)
+    image_data = extract_decoded_image_data(payload.image_data)
     embeddings = image_embedding_service.query_embed(image_data)
     if not embeddings:
         return JSONResponse(
@@ -175,3 +46,32 @@ async def image_embedding(
         status_code=status.HTTP_200_OK,
         content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
     )
+
+
+@inference_app_router.post('/v1/query/embeddings/batch')
+@inject
+async def image_embedding_batch(
+    payload: ImageBatchPayload,
+    response_formatter: ResponseFormatter = Depends(
+        Provide[CommonContainer.response_formatter]
+    ),
+    image_embedding_service: ImageEmbedding = Depends(
+        Provide[InferenceAppContainer.image_embedding]
+    ),
+):
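+    # Decode the whole batch up front; the embedding service batches the model forward pass.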
+    image_batch = [
+        extract_decoded_image_data(image_data) for image_data in payload.image_batch
+    ]
+    embeddings = image_embedding_service.query_embed_batch(image_batch)
+    return JSONResponse(
+        status_code=status.HTTP_200_OK,
+        content=response_formatter.buildSuccessResponse(data={'response': embeddings}),
+    )
+
+
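+# Accepts either a bare base64 string or a data URI such as 'data:image/png;base64,<payload>'.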
+def extract_decoded_image_data(image_data: str) -> bytes:
+    parts = image_data.split(',')
+    base64_data = parts[1] if len(parts) == 2 else parts[0]
+    return base64.b64decode(base64_data)
diff --git a/wavefront/server/apps/inference_app/inference_app/inference_app_container.py b/wavefront/server/apps/inference_app/inference_app/inference_app_container.py
index 8666dcbe..b0fd7ea8 100644
--- a/wavefront/server/apps/inference_app/inference_app/inference_app_container.py
+++ b/wavefront/server/apps/inference_app/inference_app/inference_app_container.py
@@ -1,31 +1,9 @@
 from dependency_injector import containers
 from dependency_injector import providers
-from inference_app.service.image_analyser import ImageClarityService
-from flo_cloud.cloud_storage import CloudStorageManager
-from inference_app.service.model_repository import ModelRepository
-from inference_app.service.model_inference import ModelInferenceService
 from inference_app.service.image_embedding import ImageEmbedding
 
 
 class InferenceAppContainer(containers.DeclarativeContainer):
     config = providers.Configuration(ini_files=['config.ini'])
-    cache_manager = providers.Dependency()
-    cloud_storage_manager = providers.Singleton(
-        CloudStorageManager, provider=config.cloud_config.cloud_provider
-    )
-
-    model_repository = providers.Singleton(
-        ModelRepository,
-        cloud_storage_manager=cloud_storage_manager,
-    )
-
-    model_inference = providers.Singleton(ModelInferenceService)
-
-    image_analyser = providers.Singleton(
-        ImageClarityService,
-    )
-
-    image_embedding = providers.Singleton(
-        ImageEmbedding,
-    )
+    image_embedding = providers.Singleton(ImageEmbedding)
diff --git a/wavefront/server/apps/inference_app/inference_app/service/image_analyser.py b/wavefront/server/apps/inference_app/inference_app/service/image_analyser.py
deleted file mode 100644
index e7a89f0a..00000000
--- a/wavefront/server/apps/inference_app/inference_app/service/image_analyser.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import cv2
-from common_module.log.logger import logger
-from inference_app.utils.image_utils import decode_image_from_bytes
-
-
-class ImageClarityService:
-    def __init__(self):
-        pass
-
-    def laplacian_detection(self, image_bytes, max_expected_variance):
-        # Decode image from bytes array
-        logger.info(
-            f'Successfully decoded Base64 string. Data length: {len(image_bytes)} bytes.'
-        )
-        images = decode_image_from_bytes(image_bytes)
-        images = cv2.resize(images, (256, 256), interpolation=cv2.INTER_AREA)
-        gray = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
-        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
-        variance = laplacian.var()
-        clamped_variance = min(variance, int(max_expected_variance))
-        score = (clamped_variance / int(max_expected_variance)) * 100
-        return int(score)
diff --git a/wavefront/server/apps/inference_app/inference_app/service/image_embedding.py b/wavefront/server/apps/inference_app/inference_app/service/image_embedding.py
index 106eba37..f4d451df 100644
--- a/wavefront/server/apps/inference_app/inference_app/service/image_embedding.py
+++ b/wavefront/server/apps/inference_app/inference_app/service/image_embedding.py
@@ -5,11 +5,12 @@ from typing import List, Dict, Any
 
 from common_module.log.logger import logger
 
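+# Hugging Face model identifiers for the two embedding backbones.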
+CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32'
+DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m'
 
-class ImageEmbedding:
-    CLIP_MODEL_NAME = 'openai/clip-vit-base-patch32'
-    DINO_MODEL_NAME = 'facebook/dinov3-vitl16-pretrain-lvd1689m'
 
+class ImageEmbedding:
     def __init__(self):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         logger.info(f'Using device: {self.device}')
@@ -69,3 +69,49 @@ def query_embed(self, image_content: bytes) -> List[Dict[str, List[float]]]:
             results.append({name: embedding})
 
         return results
+
+    @torch.inference_mode()
+    def query_embed_batch(
+        self, image_batch: list[bytes]
+    ) -> List[Dict[str, List[List[float]]]]:
+        """
+        Batch embedding on the configured device (GPU when available).
+
+        Returns:
+            [
+                {"clip": [embedding_for_image_0, ..., embedding_for_image_N]},
+                {"dino": [embedding_for_image_0, ..., embedding_for_image_N]},
+            ]
+        """
+        if not image_batch:
+            return []
+
+        # Decode bytes -> PIL images on CPU.
+        # The actual forward pass (processor -> tensor + model) is batched on the device.
+        images: List[Image.Image] = []
+        for idx, image_content in enumerate(image_batch):
+            try:
+                images.append(Image.open(io.BytesIO(image_content)).convert('RGB'))
+            except Exception as e:
+                logger.error(
+                    f'Error opening image at index={idx}: {e}',
+                    exc_info=True,
+                )
+                return []
+
+        results: List[Dict[str, List[List[float]]]] = []
+
+        for name, embedder in self.embedders.items():
+            inputs = embedder['processor'](images=images, return_tensors='pt')
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Batched forward pass.
+            image_features = embedder['extractor'](inputs)  # (batch, dim)
+
+            # L2-normalize each embedding vector.
+            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+            embeddings = image_features.cpu().numpy().tolist()  # batch x dim
+            results.append({name: embeddings})
+
+        return results
diff --git a/wavefront/server/apps/inference_app/inference_app/service/model_inference.py b/wavefront/server/apps/inference_app/inference_app/service/model_inference.py
deleted file mode 100644
index 929ccf3e..00000000
--- a/wavefront/server/apps/inference_app/inference_app/service/model_inference.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import cv2
-import torchvision.transforms as transforms
-from PIL import Image
-import torch
-import numpy as np
-from pydantic import BaseModel, Field
-from inference_app.utils.image_utils import decode_image_from_bytes
-
-
-class PreprocessingStep(BaseModel):
-    preprocess_filter: str
-    values: list = Field(default_factory=list)
-
-
-class ModelInferenceService:
-    def __init__(self):
-        self.device = torch.device('cpu')
-
-    def preprocess_image(
-        self,
-        image_bytes,
-        gaussian_blur_kernel,
-        min_threshold,
-        max_threshold,
-        preprocessing_steps: list[PreprocessingStep],
-    ):
-        """Apply preprocessing steps based on provided flags."""
-        processed_image = decode_image_from_bytes(image_bytes)
-
-        # Define available preprocessing functions
-        preprocessing_functions = {
-            'gray': lambda img, values: cv2.cvtColor(img, cv2.COLOR_BGR2GRAY),
-            'gaussian_blur': lambda img, values: cv2.GaussianBlur(
-                img, (gaussian_blur_kernel, gaussian_blur_kernel), 0
-            ),
-            'canny': lambda img, values: cv2.cvtColor(
-                cv2.Canny(img, min_threshold, max_threshold), cv2.COLOR_GRAY2RGB
-            ),
-            'kernel_sharpening': lambda img, values: cv2.filter2D(
-                img, -1, np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
-            ),
-        }
-        for step in preprocessing_steps:
-            filter_name = step.preprocess_filter
-            values = step.values
-            if filter_name and filter_name in preprocessing_functions:
-                processed_image = preprocessing_functions[filter_name](
-                    processed_image, values
-                )
-            else:
-                continue
-
-        pil_image = Image.fromarray(processed_image)
-        return pil_image
-
-    def model_infer_score(
-        self,
-        model,
-        image_bytes,
-        resize_width,
-        resize_height,
-        normalize_mean,
-        normalize_std,
-        gaussian_blur_kernel,
-        min_threshold,
-        max_threshold,
-        preprocessing_steps: list[PreprocessingStep],
-    ):
-        """
-        Predict overlap score for a single image using the same preprocessing as training
-        """
-        # Define the same transform used during validation
-        normalize_mean = [float(x) for x in normalize_mean.split(',')]
-        normalize_std = [float(x) for x in normalize_std.split(',')]
-        transform = transforms.Compose(
-            [
-                transforms.Resize((resize_width, resize_height)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=normalize_mean, std=normalize_std),
-            ]
-        )
-        # Apply the same preprocessing as during training
-        preprocessed_image = self.preprocess_image(
-            image_bytes,
-            gaussian_blur_kernel,
-            min_threshold,
-            max_threshold,
-            preprocessing_steps,
-        )
-
-        # Apply transforms
-        image_tensor = transform(preprocessed_image).unsqueeze(0).to(self.device)
-        model.to(self.device)
-        # Predict
-        model.eval()
-        with torch.no_grad():
-            response = model(image_tensor).item()
-
-        return response
diff --git a/wavefront/server/apps/inference_app/inference_app/service/model_repository.py b/wavefront/server/apps/inference_app/inference_app/service/model_repository.py
deleted file mode 100644
index 831dc8a0..00000000
--- a/wavefront/server/apps/inference_app/inference_app/service/model_repository.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from typing import Dict
-import os
-import torch
-import io
-from common_module.log.logger import logger
-from flo_cloud.cloud_storage import CloudStorageManager
-
-
-class ModelRepository:
-    def __init__(
-        self,
-        cloud_storage_manager: CloudStorageManager,
-    ):
-        self.cloud_storage_manager = cloud_storage_manager
-        self.model_storage_dir = os.getenv('MODEL_STORAGE_DIR', './models')
-        os.makedirs(self.model_storage_dir, exist_ok=True)
-        # Cache for loaded models - stores model instances in memory
-        self._model_cache: Dict[str, torch.nn.Module] = {}
-
-    def _is_model_cached_locally(
-        self, model_name: str, file_path: str, expected_local_model_dir: str
-    ) -> bool:
-        """
-        Checks if the model is available in the local persistent storage.
-        """
-        return os.path.exists(
-            expected_local_model_dir
-        ) and f'{model_name}.{file_path.split(".")[-1]}' in os.listdir(
-            expected_local_model_dir
-        )
-
-    async def load_model(self, model_info: dict, bucket_name: str):
-        model_id = model_info['model_id']
-        expected_local_model_dir = self.model_storage_dir
-        model_name = model_info['model_name']
-        file_path = model_info['model_path']
-        model_id = model_info['model_id']
-
-        local_model_filename = os.path.join(
-            expected_local_model_dir, f'{model_name}.{file_path.split(".")[1]}'
-        )
-        local_model_full_path = os.path.join(local_model_filename)
-
-        if self._is_model_cached_locally(
-            model_name, file_path, expected_local_model_dir
-        ):
-            logger.info(f'Model {model_id} found in local persistent storage, loading.')
-            if model_id in self._model_cache:
-                return self._model_cache[model_id]
-            else:
-                with open(local_model_full_path, 'rb') as f:
-                    model_bytes_data = f.read()
-                return torch.load(io.BytesIO(model_bytes_data), weights_only=False)
-        else:
-            logger.info(
-                f'Model {model_id} not found in local persistent storage, loading from cloud storage.'
-            )
-            model_bytes_data = self.cloud_storage_manager.read_file(
-                bucket_name, file_path
-            )
-            model = torch.load(io.BytesIO(model_bytes_data), weights_only=False)
-            # Save to local persistent storage after fetching from cloud
-            os.makedirs(os.path.dirname(local_model_full_path), exist_ok=True)
-            with open(local_model_full_path, 'wb') as f:
-                f.write(model_bytes_data)
-            self._model_cache[model_id] = model
-            logger.info(f'Model {model_id} loaded and cached in memory.')
-            return model