From aebec7be3fed372771bf54991417f0dbcf1d552b Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 08:07:41 -0500 Subject: [PATCH 01/69] Move db providers to feature module Signed-off-by: jphillips --- .cursor/rules/task.mdc | 11 ++++++++++- servers/data_service/src/api/routes/index.ts | 2 +- servers/data_service/src/app.ts | 2 +- .../providers.ts => features/providers/controller.ts} | 2 +- .../providers.ts => features/providers/routes.ts} | 6 +++--- .../providers.ts => features/providers/schemas.ts} | 0 6 files changed, 16 insertions(+), 7 deletions(-) rename servers/data_service/src/{api/controllers/providers.ts => features/providers/controller.ts} (99%) rename servers/data_service/src/{api/routes/providers.ts => features/providers/routes.ts} (93%) rename servers/data_service/src/{api/schemas/providers.ts => features/providers/schemas.ts} (100%) diff --git a/.cursor/rules/task.mdc b/.cursor/rules/task.mdc index f03b36a7..7f1b94d7 100644 --- a/.cursor/rules/task.mdc +++ b/.cursor/rules/task.mdc @@ -4,4 +4,13 @@ globs: alwaysApply: true --- # Task -Add your task for the agent here. \ No newline at end of file +Provider configuration still relies on some file based configuration. + +### Describe the solution you'd like +- Utilize the UI / DB for provider configuration +- Remove all old provider handling code +- Resolve issues cited in discord related to provider ux + + +[provider_manager.py](mdc:servers/inference_server/graphcap/providers/provider_manager.py) +[provider_config.py](mdc:servers/inference_server/graphcap/providers/provider_config.py) diff --git a/servers/data_service/src/api/routes/index.ts b/servers/data_service/src/api/routes/index.ts index 22560994..0a151bdf 100644 --- a/servers/data_service/src/api/routes/index.ts +++ b/servers/data_service/src/api/routes/index.ts @@ -5,7 +5,7 @@ * This file exports route definitions for client consumption. */ +import { providerRoutes } from '../../features/providers/routes'; import { batchQueueRoutes } from './batch_queue'; -import { providerRoutes } from './providers'; export { providerRoutes, batchQueueRoutes }; \ No newline at end of file diff --git a/servers/data_service/src/app.ts b/servers/data_service/src/app.ts index 5faa524a..eec14188 100644 --- a/servers/data_service/src/app.ts +++ b/servers/data_service/src/app.ts @@ -14,9 +14,9 @@ import { timing } from 'hono/timing'; import { z } from 'zod'; import { batchQueueRoutes } from './api/routes/batch_queue'; -import { providerRoutes } from './api/routes/providers'; import { checkDatabaseConnection } from './db/init'; import { env } from './env'; +import { providerRoutes } from './features/providers/routes'; import { logger } from './utils/logger'; // Create OpenAPI Hono app diff --git a/servers/data_service/src/api/controllers/providers.ts b/servers/data_service/src/features/providers/controller.ts similarity index 99% rename from servers/data_service/src/api/controllers/providers.ts rename to servers/data_service/src/features/providers/controller.ts index 06ad3e07..1b4a7c23 100644 --- a/servers/data_service/src/api/controllers/providers.ts +++ b/servers/data_service/src/features/providers/controller.ts @@ -15,7 +15,7 @@ import type { ProviderApiKey, ProviderCreate, ProviderUpdate, -} from "../schemas/providers"; +} from "./schemas"; // Type for the validated parameters type ValidatedParams = { diff --git a/servers/data_service/src/api/routes/providers.ts b/servers/data_service/src/features/providers/routes.ts similarity index 93% rename from servers/data_service/src/api/routes/providers.ts rename to servers/data_service/src/features/providers/routes.ts index d2724668..22259ddd 100644 --- a/servers/data_service/src/api/routes/providers.ts +++ b/servers/data_service/src/features/providers/routes.ts @@ -7,9 +7,9 @@ import { OpenAPIHono, createRoute } from '@hono/zod-openapi'; import { z } from 'zod'; -import * as handlers from '../controllers/providers'; -import { providerSchema, providerCreateSchema, providerUpdateSchema, providerApiKeySchema } from '../schemas/providers'; -import { commonResponses, notFoundResponse, invalidRequestResponse, successResponse } from '../schemas/common'; +import { commonResponses, invalidRequestResponse, notFoundResponse, successResponse } from '../../api/schemas/common'; +import * as handlers from './controller'; +import { providerApiKeySchema, providerCreateSchema, providerSchema, providerUpdateSchema } from './schemas'; // Create a new OpenAPI router const router = new OpenAPIHono(); diff --git a/servers/data_service/src/api/schemas/providers.ts b/servers/data_service/src/features/providers/schemas.ts similarity index 100% rename from servers/data_service/src/api/schemas/providers.ts rename to servers/data_service/src/features/providers/schemas.ts From 40951aea8f57ff7b1cdd2bb10f67a0f33a7ed056 Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 08:12:16 -0500 Subject: [PATCH 02/69] Decouple inference bridge from file based provider config Signed-off-by: jphillips --- .../graphcap/providers/provider_manager.py | 110 ++++------ .../server/features/providers/models.py | 22 +- .../server/features/providers/router.py | 68 +----- .../server/features/providers/service.py | 197 ++++-------------- 4 files changed, 112 insertions(+), 285 deletions(-) diff --git a/servers/inference_server/graphcap/providers/provider_manager.py b/servers/inference_server/graphcap/providers/provider_manager.py index 8db97294..2e18ee79 100644 --- a/servers/inference_server/graphcap/providers/provider_manager.py +++ b/servers/inference_server/graphcap/providers/provider_manager.py @@ -5,8 +5,7 @@ This module handles provider lifecycle management and client initialization. Key features: -- Provider configuration loading -- Client initialization and caching +- Client initialization - Environment validation - Rate limit management @@ -14,101 +13,64 @@ ProviderManager: Main provider management class """ -from pathlib import Path -from typing import Dict +from typing import Dict, Optional from loguru import logger from .clients import BaseClient, get_client -from .provider_config import get_providers_config -from .types import ProviderConfig class ProviderManager: """Manager class for handling provider lifecycle and client initialization""" - def __init__(self, config_path: str | Path): - """Initialize provider manager with configuration file""" - logger.info(f"Initializing ProviderManager with config from: {config_path}") - self.providers = get_providers_config(config_path) + def __init__(self, _: Optional[str] = None): + """Initialize provider manager""" + logger.info("Initializing ProviderManager") self._clients: Dict[str, BaseClient] = {} - logger.info(f"Loaded {len(self.providers)} provider configurations") - for name, config in self.providers.items(): - logger.info(f"Provider '{name}' configuration:") - logger.info(f" - kind: {config.kind}") - logger.info(f" - environment: {config.environment}") - logger.info(f" - base_url: {config.base_url}") - logger.info(f" - default_model: {config.default_model}") - if config.rate_limits: - logger.info(f" - rate_limits: {config.rate_limits}") - def get_client(self, provider_name: str) -> BaseClient: - """Get or create a client for the specified provider""" - if provider_name not in self.providers: - logger.error(f"Requested unknown provider: {provider_name}") - logger.debug(f"Available providers: {', '.join(self.providers.keys())}") - raise ValueError(f"Unknown provider: {provider_name}") - - # Return cached client if available - if provider_name in self._clients: - logger.debug(f"Using cached client for provider: {provider_name}") - return self._clients[provider_name] - - # Create new client - config = self.providers[provider_name] - logger.info(f"Initializing new client for provider: {provider_name}") + def get_client( + self, + name: str, + kind: str, + environment: str, + base_url: str, + api_key: str, + default_model: Optional[str] = None, + rate_limits: Optional[dict] = None, + ) -> BaseClient: + """Initialize a client with the given configuration""" + logger.info(f"Initializing client for provider: {name}") logger.info(f"Provider config details:") - logger.info(f" - kind: {config.kind}") - logger.info(f" - environment: {config.environment}") - logger.info(f" - base_url: {config.base_url}") - logger.info(f" - default_model: {config.default_model}") + logger.info(f" - kind: {kind}") + logger.info(f" - environment: {environment}") + logger.info(f" - base_url: {base_url}") + logger.info(f" - default_model: {default_model}") try: client = get_client( - name=provider_name, - kind=config.kind, - environment=config.environment, - env_var=config.env_var, - base_url=config.base_url, - default_model=config.default_model, + name=name, + kind=kind, + environment=environment, + api_key=api_key, + base_url=base_url, + default_model=default_model, ) # Set rate limits if configured - if config.rate_limits: + if rate_limits: logger.debug( - f"Setting rate limits for {provider_name} - requests: {config.rate_limits.requests_per_minute}/min, tokens: {config.rate_limits.tokens_per_minute}/min" + f"Setting rate limits for {name} - requests: {rate_limits.get('requests_per_minute')}/min, tokens: {rate_limits.get('tokens_per_minute')}/min" ) - client.requests_per_minute = config.rate_limits.requests_per_minute - client.tokens_per_minute = config.rate_limits.tokens_per_minute + client.requests_per_minute = rate_limits.get("requests_per_minute") + client.tokens_per_minute = rate_limits.get("tokens_per_minute") - self._clients[provider_name] = client - logger.info(f"Successfully initialized client for provider: {provider_name}") return client except Exception as e: - logger.error(f"Failed to initialize client for {provider_name}: {str(e)}") + logger.error(f"Failed to initialize client for {name}: {str(e)}") logger.error(f"Provider config details:") - logger.error(f" - kind: {config.kind}") - logger.error(f" - environment: {config.environment}") - logger.error(f" - base_url: {config.base_url}") - logger.error(f" - default_model: {config.default_model}") + logger.error(f" - kind: {kind}") + logger.error(f" - environment: {environment}") + logger.error(f" - base_url: {base_url}") + logger.error(f" - default_model: {default_model}") raise - - def clients(self) -> Dict[str, BaseClient]: - """Get all initialized clients""" - logger.debug(f"Returning {len(self._clients)} initialized clients") - return self._clients.copy() - - def available_providers(self) -> list[str]: - """Get list of available provider names""" - providers = list(self.providers.keys()) - logger.debug(f"Available providers: {', '.join(providers)}") - return providers - - def get_provider_config(self, provider_name: str) -> ProviderConfig: - """Get configuration for a specific provider""" - if provider_name not in self.providers: - logger.error(f"Requested config for unknown provider: {provider_name}") - raise ValueError(f"Unknown provider: {provider_name}") - logger.debug(f"Returning config for provider: {provider_name}") - return self.providers[provider_name] diff --git a/servers/inference_server/server/server/features/providers/models.py b/servers/inference_server/server/server/features/providers/models.py index 506f7298..177b2025 100644 --- a/servers/inference_server/server/server/features/providers/models.py +++ b/servers/inference_server/server/server/features/providers/models.py @@ -5,7 +5,7 @@ Defines data models for the providers API endpoints. """ -from typing import List +from typing import List, Optional from pydantic import BaseModel, Field @@ -37,3 +37,23 @@ class ProviderModelsResponse(BaseModel): provider: str = Field(..., description="Name of the provider") models: List[ModelInfo] = Field(..., description="List of available models") + + +class ProviderConfig(BaseModel): + """Provider configuration model.""" + + name: str = Field(..., description="Unique identifier for the provider") + kind: str = Field(..., description="Type of provider (e.g., 'openai', 'anthropic', 'gemini')") + environment: str = Field(..., description="Provider environment (cloud, local)") + base_url: str = Field(..., description="Base URL for the provider API") + api_key: str = Field(..., description="API key for the provider") + default_model: Optional[str] = Field(None, description="Default model for the provider") + models: List[str] = Field(default_factory=list, description="List of available model IDs") + fetch_models: bool = Field(default=True, description="Whether to fetch models from the provider API") + rate_limits: Optional[dict] = Field(None, description="Rate limiting configuration") + + +class ProviderConfigureRequest(BaseModel): + """Request model for configuring a provider.""" + + config: ProviderConfig = Field(..., description="Provider configuration") diff --git a/servers/inference_server/server/server/features/providers/router.py b/servers/inference_server/server/server/features/providers/router.py index 4e71725a..89080fd1 100644 --- a/servers/inference_server/server/server/features/providers/router.py +++ b/servers/inference_server/server/server/features/providers/router.py @@ -5,84 +5,34 @@ Defines API routes for working with AI providers. This module provides the following endpoints: -- GET /providers/list - List all available providers -- GET /providers/check/{provider_name} - Check if a specific provider is available -- GET /providers/{provider_name}/models - List available models for a specific provider +- POST /providers/{provider_name}/models - List available models for a provider using provided configuration """ from fastapi import APIRouter, HTTPException -from .models import ProviderListResponse, ProviderModelsResponse -from .service import get_available_providers, get_provider_manager, get_provider_models +from .models import ProviderConfig, ProviderModelsResponse +from .service import get_provider_models router = APIRouter(prefix="/providers", tags=["providers"]) -@router.get("/list", response_model=ProviderListResponse) -async def list_providers() -> ProviderListResponse: +@router.post("/{provider_name}/models", response_model=ProviderModelsResponse) +async def list_provider_models(provider_name: str, config: ProviderConfig) -> ProviderModelsResponse: """ - List all available providers. - - Returns: - List of available providers - """ - providers = get_available_providers() - return ProviderListResponse(providers=providers) - - -@router.get("/check/{provider_name}") -async def check_provider(provider_name: str) -> dict: - """ - Check if a specific provider is available. - - Args: - provider_name: Name of the provider to check - - Returns: - Status of the provider - - Raises: - HTTPException: If the provider is not found - """ - provider_manager = get_provider_manager() - available_providers = provider_manager.available_providers() - - if provider_name not in available_providers: - raise HTTPException( - status_code=404, - detail=f"Provider '{provider_name}' not found. Available providers: {', '.join(available_providers)}", - ) - - # Get the provider config - provider_config = provider_manager.get_provider_config(provider_name) - - return { - "status": "available", - "provider": provider_name, - "kind": provider_config.kind, - "environment": provider_config.environment, - "default_model": provider_config.default_model or "", - } - - -@router.get("/{provider_name}/models", response_model=ProviderModelsResponse) -async def list_provider_models(provider_name: str) -> ProviderModelsResponse: - """ - List available models for a specific provider. + List available models for a specific provider using provided configuration. Args: provider_name: Name of the provider to get models for + config: Provider configuration for this request Returns: List of available models for the provider Raises: - HTTPException: If the provider is not found + HTTPException: If there is an error getting models """ try: - models = await get_provider_models(provider_name) + models = await get_provider_models(provider_name, config) return ProviderModelsResponse(provider=provider_name, models=models) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"Error getting models: {str(e)}") diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_server/server/server/features/providers/service.py index c4186d7a..ed4d7654 100644 --- a/servers/inference_server/server/server/features/providers/service.py +++ b/servers/inference_server/server/server/features/providers/service.py @@ -5,90 +5,67 @@ Provides services for working with AI providers. """ -import os -from pathlib import Path -from typing import Any, List, Optional +from typing import Any, List from graphcap.providers.factory import initialize_provider_manager from graphcap.providers.provider_manager import ProviderManager from loguru import logger -from ...config import settings -from .models import ModelInfo, ProviderInfo +from .models import ModelInfo, ProviderConfig -# Global provider manager instance -_provider_manager: Optional[ProviderManager] = None +# Global provider manager instance for handling requests +_provider_manager: ProviderManager = initialize_provider_manager(None) -def get_provider_manager() -> ProviderManager: +async def get_provider_models(provider_name: str, config: ProviderConfig) -> List[ModelInfo]: """ - Get or initialize the provider manager. - Returns: - ProviderManager: The initialized provider manager - """ - global _provider_manager - - if _provider_manager is None: - # Use the provider config path from server settings - config_path = settings.PROVIDER_CONFIG_PATH - - # Verify the config path exists - if config_path is None: - logger.warning("Provider config path is None, using default locations") - elif not os.path.exists(str(config_path)): - logger.warning(f"Provider config path does not exist: {config_path}") - # Check if the directory exists - config_dir = Path(str(config_path)).parent - if not os.path.exists(str(config_dir)): - logger.warning(f"Config directory does not exist: {config_dir}") - else: - logger.info(f"Config directory exists: {config_dir}, but provider.config.toml is missing") - # List files in the directory - files = os.listdir(str(config_dir)) - logger.info(f"Files in config directory: {files}") - else: - logger.info(f"Provider config file exists: {config_path}") - - logger.info(f"Initializing provider manager with config path: {config_path}") - _provider_manager = initialize_provider_manager(config_path) - - # Log the available providers - provider_names = _provider_manager.available_providers() - if provider_names: - logger.info(f"Available providers: {', '.join(provider_names)}") - else: - logger.warning("No providers available. Check your provider.config.toml file.") - - return _provider_manager - + Get a list of available models for a specific provider. -def get_available_providers() -> List[ProviderInfo]: - """ - Get a list of available providers. + Args: + provider_name: Name of the provider to get models for + config: Provider configuration for this request Returns: - List of provider information + List of model information """ - # Get the provider manager - provider_manager = get_provider_manager() - - # Get the list of available providers - provider_names = provider_manager.available_providers() - providers = [] - - for name in provider_names: + # Initialize client with provided configuration + client = _provider_manager.get_client( + name=provider_name, + kind=config.kind, + environment=config.environment, + base_url=config.base_url, + api_key=config.api_key, + default_model=config.default_model, + rate_limits=config.rate_limits + ) + + models = [] + + # Try to fetch models if configured + if config.fetch_models: try: - config = provider_manager.get_provider_config(name) - providers.append( - ProviderInfo( - name=name, - kind=config.kind, - default_model=config.default_model or "", - ) - ) + logger.info(f"Fetching models from provider {provider_name}") + if hasattr(client, "get_available_models"): + provider_models = await client.get_available_models() + if hasattr(provider_models, "data"): + for model in provider_models.data: + model_id = _extract_model_id(model) + models.append(_create_model_info(model_id, config.default_model or "")) + elif hasattr(client, "get_models"): + provider_models = await client.get_models() + if hasattr(provider_models, "models"): + for model in provider_models.models: + model_id = _extract_model_id(model) + models.append(_create_model_info(model_id, config.default_model or "")) except Exception as e: - logger.error(f"Error getting provider {name}: {str(e)}") + logger.error(f"Error fetching models from provider {provider_name}: {str(e)}") + logger.info(f"Falling back to configured models for provider {provider_name}") + + # Fall back to configured models if none fetched + if not models: + models = [_create_model_info(model_id, config.default_model or "") for model_id in config.models] + logger.info(f"Using {len(models)} configured models for provider {provider_name}") - return providers + return models def _create_model_info(model_id: str, default_model: str) -> ModelInfo: @@ -101,85 +78,3 @@ def _extract_model_id(model: Any) -> str: if hasattr(model, "id"): return model.id return model.name if hasattr(model, "name") else str(model) - - -async def _fetch_models_from_available_models(client: Any, default_model: str) -> List[ModelInfo]: - """Fetch models using get_available_models method.""" - models = [] - provider_models = await client.get_available_models() - - if hasattr(provider_models, "data"): - for model in provider_models.data: - model_id = _extract_model_id(model) - models.append(_create_model_info(model_id, default_model)) - - return models - - -async def _fetch_models_from_get_models(client: Any, default_model: str) -> List[ModelInfo]: - """Fetch models using get_models method.""" - models = [] - provider_models = await client.get_models() - - if hasattr(provider_models, "models"): - for model in provider_models.models: - model_id = _extract_model_id(model) - models.append(_create_model_info(model_id, default_model)) - - return models - - -def _get_configured_models(config: Any) -> List[ModelInfo]: - """Get models from configuration.""" - return [_create_model_info(model_id, config.default_model) for model_id in config.models] - - -async def _fetch_provider_models(client: Any, provider_name: str, config: Any) -> List[ModelInfo]: - """Attempt to fetch models from the provider.""" - models = [] - - try: - logger.info(f"Fetching models from provider {provider_name}") - - if hasattr(client, "get_available_models"): - models = await _fetch_models_from_available_models(client, config.default_model) - elif hasattr(client, "get_models"): - models = await _fetch_models_from_get_models(client, config.default_model) - - logger.info(f"Found {len(models)} models for provider {provider_name}") - except Exception as e: - logger.error(f"Error fetching models from provider {provider_name}: {str(e)}") - logger.info(f"Falling back to configured models for provider {provider_name}") - - return models - - -async def get_provider_models(provider_name: str) -> List[ModelInfo]: - """ - Get a list of available models for a specific provider. - - Args: - provider_name: Name of the provider to get models for - Returns: - List of model information - Raises: - ValueError: If the provider is not found - """ - provider_manager = get_provider_manager() - available_providers = provider_manager.available_providers() - - if provider_name not in available_providers: - raise ValueError(f"Provider '{provider_name}' not found. Available providers: {', '.join(available_providers)}") - - config = provider_manager.get_provider_config(provider_name) - client = provider_manager.get_client(provider_name) - models = [] - - if config.fetch_models: - models = await _fetch_provider_models(client, provider_name, config) - - if not models: - models = _get_configured_models(config) - logger.info(f"Using {len(models)} configured models for provider {provider_name}") - - return models From e24e39f92b5f53c4c3d1ee1e346be2d9a96ebb3c Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 08:36:56 -0500 Subject: [PATCH 03/69] Remove provider config files Signed-off-by: jphillips --- .../provider_tests/test_provider_factory.py | 226 ++++++++++++++++++ workspace/config/.env.template | 2 - workspace/config/provider.example.config.toml | 74 ------ 3 files changed, 226 insertions(+), 76 deletions(-) create mode 100644 test/library_tests/provider_tests/test_provider_factory.py delete mode 100644 workspace/config/provider.example.config.toml diff --git a/test/library_tests/provider_tests/test_provider_factory.py b/test/library_tests/provider_tests/test_provider_factory.py new file mode 100644 index 00000000..57bf8a00 --- /dev/null +++ b/test/library_tests/provider_tests/test_provider_factory.py @@ -0,0 +1,226 @@ +""" +# SPDX-License-Identifier: Apache-2.0 +graphcap.tests.lib.providers.test_provider_factory + +Tests for provider factory functionality. + +Key features: +- Provider client creation and caching +- Environment validation +- Client-specific configurations +""" + +import pytest +from unittest.mock import patch, MagicMock + +from graphcap.providers.factory import ( + ProviderFactory, + create_provider_client, + get_provider_factory, + clear_provider_cache +) +from graphcap.providers.types import ProviderConfig, RateLimits + + +def test_provider_factory_initialization(): + """ + GIVEN a provider factory + WHEN initializing a new instance + THEN should create an empty client cache + """ + factory = ProviderFactory() + assert hasattr(factory, '_client_cache') + assert factory._client_cache == {} + + +@pytest.mark.parametrize( + "provider_config", + [ + { + "name": "test-openai", + "kind": "openai", + "environment": "cloud", + "base_url": "https://api.openai.com/v1", + "api_key": "test-key", + "default_model": "gpt-4o-mini", + }, + { + "name": "test-gemini", + "kind": "gemini", + "environment": "cloud", + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "api_key": "test-key", + "default_model": "gemini-2.0-flash-exp", + }, + ], +) +@patch("graphcap.providers.factory.get_client") +def test_create_client(mock_get_client, provider_config): + """ + GIVEN valid provider configurations + WHEN creating a client + THEN should call get_client with correct parameters + AND should return the expected client instance + """ + # Setup mock + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create factory and client + factory = ProviderFactory() + client = factory.create_client(**provider_config) + + # Verify + mock_get_client.assert_called_once_with( + name=provider_config["name"], + kind=provider_config["kind"], + environment=provider_config["environment"], + api_key=provider_config["api_key"], + base_url=provider_config["base_url"], + default_model=provider_config["default_model"], + ) + assert client == mock_client + + +@patch("graphcap.providers.factory.get_client") +def test_client_caching(mock_get_client): + """ + GIVEN a client that has been created + WHEN creating the same client again + THEN should return the cached client + AND should not call get_client again + """ + # Setup mock + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create configuration + config = { + "name": "test-openai", + "kind": "openai", + "environment": "cloud", + "base_url": "https://test.com", + "api_key": "test-key", + "default_model": "test-model", + } + + # Create factory and client + factory = ProviderFactory() + + # First call should create the client + client1 = factory.create_client(**config, use_cache=True) + assert mock_get_client.call_count == 1 + + # Second call should use cached client + client2 = factory.create_client(**config, use_cache=True) + assert mock_get_client.call_count == 1 # Count should still be 1 + assert client1 is client2 # Should be the same instance + + # Call with use_cache=False should create a new client + client3 = factory.create_client(**config, use_cache=False) + assert mock_get_client.call_count == 2 # Count should now be 2 + assert client1 is not client3 # Should be different instances + + +def test_clear_cache(): + """ + GIVEN a factory with cached clients + WHEN clearing the cache + THEN should remove all cached clients + """ + with patch("graphcap.providers.factory.get_client") as mock_get_client: + # Setup mock + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create factory and add some clients to cache + factory = ProviderFactory() + factory.create_client( + name="test1", + kind="openai", + environment="cloud", + base_url="https://test1.com", + api_key="key1", + default_model="model1", + ) + factory.create_client( + name="test2", + kind="gemini", + environment="cloud", + base_url="https://test2.com", + api_key="key2", + default_model="model2", + ) + + # Verify cache has clients + assert len(factory._client_cache) == 2 + + # Clear cache + factory.clear_cache() + + # Verify cache is empty + assert len(factory._client_cache) == 0 + + +@patch("graphcap.providers.factory._provider_factory", None) +@patch("graphcap.providers.factory.ProviderFactory") +def test_get_provider_factory(mock_factory_class): + """ + GIVEN no existing provider factory + WHEN calling get_provider_factory + THEN should create a new factory instance + """ + # Setup mock + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + + # Call function + factory = get_provider_factory() + + # Verify + mock_factory_class.assert_called_once() + assert factory == mock_factory + + +@patch("graphcap.providers.factory.get_provider_factory") +def test_create_provider_client(mock_get_factory): + """ + GIVEN valid provider configuration + WHEN calling create_provider_client + THEN should get factory and call create_client + """ + # Setup mock + mock_factory = MagicMock() + mock_client = MagicMock() + mock_factory.create_client.return_value = mock_client + mock_get_factory.return_value = mock_factory + + # Call function + config = { + "name": "test", + "kind": "openai", + "environment": "cloud", + "base_url": "https://test.com", + "api_key": "test-key", + "default_model": "test-model", + } + client = create_provider_client(**config) + + # Verify + mock_get_factory.assert_called_once() + mock_factory.create_client.assert_called_once_with(**config) + assert client == mock_client + + +@patch("graphcap.providers.factory._provider_factory") +def test_clear_provider_cache(mock_factory): + """ + GIVEN an existing provider factory + WHEN calling clear_provider_cache + THEN should call clear_cache on the factory + """ + # Call function + clear_provider_cache() + + # Verify + mock_factory.clear_cache.assert_called_once() \ No newline at end of file diff --git a/workspace/config/.env.template b/workspace/config/.env.template index 70779218..2baa814e 100644 --- a/workspace/config/.env.template +++ b/workspace/config/.env.template @@ -17,8 +17,6 @@ POSTGRES_HOST=graphcap_postgres POSTGRES_PORT=5432 POSTGRES_DB=graphcap -# Configuration Paths -DEFAULT_PROVIDER_CONFIG="./provider.config.toml" GRAPHCAP_SERVER=http://localhost:32100 diff --git a/workspace/config/provider.example.config.toml b/workspace/config/provider.example.config.toml deleted file mode 100644 index 03417698..00000000 --- a/workspace/config/provider.example.config.toml +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# -# This is a provider configuration file that allows you to customize provider configurations. -# To use this file: -# -# 1. Uncomment the providers you want to enable -# 2. Make your desired changes to those providers -# 3. Save as 'provider.config.toml' -# 4. Run 'docker compose build' and 'docker compose up -d' as normal - -[openai] -kind = "openai" -environment = "cloud" -env_var = "OPENAI_API_KEY" -base_url = "https://api.openai.com/v1" -models = [ - "gpt-4o-mini", - "gpt-4o", -] - -[gemini] -kind = "gemini" -environment = "cloud" -env_var = "GOOGLE_API_KEY" -base_url = "https://generativelanguage.googleapis.com/v1beta" -models = [ - "gemini-2.0-flash-exp", -] -# Rate limits configuration -rate_limits.requests_per_minute = 10 -rate_limits.tokens_per_minute = 4000000 - -# [openrouter] -# kind = "openrouter" -# environment = "cloud" -# env_var = "OPENROUTER_API_KEY" -# base_url = "https://openrouter.ai/api/v1" -# models = [ -# "minimax/minimax-01", -# "qwen/qvq-72b-preview", -# "qwen/qvq-32b-preview", -# "qwen/qvq-1.5b-preview", -# "google/gemini-2.0-flash-exp:free", -# "mistralai/pixtral-large-2411", -# "meta-llama/llama-3.2-90b-vision-instruct:free", -# "qwen/qwen-2-vl-72b-instruct" -# ] - -# [custom] -# # Custom provider configuration -# # Each provider needs a unique name, env_var (or stub), and base_url - -# [ollama] -# kind = "ollama" -# environment = "local" -# env_var = "CUSTOM_PROVIDER_1_KEY" -# base_url = "http://localhost:11434" -# fetch_models = true - -# [my_provider_2] -# kind = "ollama" -# environment = "local" -# env_var = "CUSTOM_PROVIDER_2_KEY" -# base_url = "http://localhost:11435" -# fetch_models = true - - -# # Add more custom providers as needed following the same pattern: -# # [provider_name] -# # environment = "cloud" -# # kind = "vllm" -# # env_var = "API_KEY" -# # base_url = "BASE_URL" -# # models = ["model1", "model2"] From f789c3889e77736b14097f4af7adc868374ad693 Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 08:38:42 -0500 Subject: [PATCH 04/69] Clean up old provider file code. Breaking changes for offline pipelines Signed-off-by: jphillips --- .../graphcap/providers/__init__.py | 21 +- .../graphcap/providers/factory.py | 213 ++++++++++++------ .../graphcap/providers/provider_config.py | 163 -------------- .../graphcap/providers/provider_manager.py | 76 ------- .../pipelines/pipelines/providers/assets.py | 87 +++---- .../pipelines/pipelines/providers/util.py | 26 +-- .../server/features/perspectives/service.py | 2 +- .../server/features/providers/service.py | 82 +++++-- 8 files changed, 284 insertions(+), 386 deletions(-) delete mode 100644 servers/inference_server/graphcap/providers/provider_config.py delete mode 100644 servers/inference_server/graphcap/providers/provider_manager.py diff --git a/servers/inference_server/graphcap/providers/__init__.py b/servers/inference_server/graphcap/providers/__init__.py index 950cb1d5..57392b17 100644 --- a/servers/inference_server/graphcap/providers/__init__.py +++ b/servers/inference_server/graphcap/providers/__init__.py @@ -14,6 +14,23 @@ Components: clients: Provider-specific client implementations - provider_config: Configuration management - provider_manager: Provider lifecycle management + factory: Provider client factory + types: Common type definitions """ + +from .factory import ( + ProviderFactory, + clear_provider_cache, + create_provider_client, + get_provider_factory, +) +from .types import ProviderConfig, RateLimits + +__all__ = [ + "ProviderFactory", + "create_provider_client", + "get_provider_factory", + "clear_provider_cache", + "ProviderConfig", + "RateLimits", +] diff --git a/servers/inference_server/graphcap/providers/factory.py b/servers/inference_server/graphcap/providers/factory.py index efbac5ab..805f26ff 100644 --- a/servers/inference_server/graphcap/providers/factory.py +++ b/servers/inference_server/graphcap/providers/factory.py @@ -2,84 +2,171 @@ # SPDX-License-Identifier: Apache-2.0 Provider Factory Module -This module provides factory functions for creating provider clients. +This module provides factory functionality for creating provider clients. + +Key features: +- Client instantiation +- Environment validation +- Rate limit configuration +- Client caching """ -import os -import tempfile -from pathlib import Path -from typing import Optional +from typing import Dict, Optional from loguru import logger from .clients import BaseClient, get_client -from .provider_config import get_providers_config -from .provider_manager import ProviderManager - -# Global provider manager instance -_provider_manager: Optional[ProviderManager] = None - -def initialize_provider_manager(config_path: Optional[str | Path] = None) -> ProviderManager: - """Initialize the global provider manager with the given config path. - Args: - config_path: Path to the provider configuration file. If None, uses default locations. +class ProviderFactory: + """Factory class for creating provider clients with specific configurations""" + + def __init__(self): + """Initialize provider factory""" + logger.info("Initializing ProviderFactory") + self._client_cache: Dict[str, BaseClient] = {} + + def create_client( + self, + name: str, + kind: str, + environment: str, + base_url: str, + api_key: str, + default_model: Optional[str] = None, + rate_limits: Optional[dict] = None, + use_cache: bool = True, + ) -> BaseClient: + """Create a client with the given configuration. + + Args: + name: Unique identifier for the provider + kind: Type of provider (e.g., 'openai', 'anthropic', 'gemini') + environment: Provider environment (cloud, local) + base_url: Base URL for the provider API + api_key: API key for the provider + default_model: Default model for the provider + rate_limits: Rate limiting configuration + use_cache: Whether to cache and reuse client instances (default: True) + + Returns: + BaseClient: The provider client instance + + Raises: + ValueError: If client creation fails + """ + # Check cache first if enabled + cache_key = f"{name}:{kind}:{environment}:{base_url}:{api_key}" + if use_cache and cache_key in self._client_cache: + logger.debug(f"Using cached client for provider: {name}") + return self._client_cache[cache_key] + + logger.info(f"Creating new client for provider: {name}") + logger.info(f"Provider config details:") + logger.info(f" - kind: {kind}") + logger.info(f" - environment: {environment}") + logger.info(f" - base_url: {base_url}") + logger.info(f" - default_model: {default_model}") + + try: + client = get_client( + name=name, + kind=kind, + environment=environment, + api_key=api_key, + base_url=base_url, + default_model=default_model, + ) + + # Set rate limits if configured + if rate_limits: + logger.debug( + f"Setting rate limits for {name} - requests: {rate_limits.get('requests_per_minute')}/min, tokens: {rate_limits.get('tokens_per_minute')}/min" + ) + client.requests_per_minute = rate_limits.get("requests_per_minute") + client.tokens_per_minute = rate_limits.get("tokens_per_minute") + + # Cache the client if enabled + if use_cache: + self._client_cache[cache_key] = client + + return client + + except Exception as e: + logger.error(f"Failed to create client for {name}: {str(e)}") + logger.error(f"Provider config details:") + logger.error(f" - kind: {kind}") + logger.error(f" - environment: {environment}") + logger.error(f" - base_url: {base_url}") + logger.error(f" - default_model: {default_model}") + raise ValueError(f"Failed to create client for {name}: {str(e)}") + + def clear_cache(self) -> None: + """Clear the client cache""" + self._client_cache.clear() + + +# Global provider factory instance +_provider_factory: Optional[ProviderFactory] = None + + +def get_provider_factory() -> ProviderFactory: + """Get or create the global provider factory instance. Returns: - ProviderManager: The initialized provider manager + ProviderFactory: The global provider factory instance """ - global _provider_manager - - if config_path is None: - # Try to find config in standard locations - possible_paths = [ - os.environ.get("PROVIDER_CONFIG_PATH"), - "./provider.config.toml", - "./config/provider.config.toml", - "/app/provider.config.toml", - "/app/config/provider.config.toml", - ] - - for path in possible_paths: - if path and Path(path).exists(): - config_path = path - break - - if not config_path or not Path(str(config_path)).exists(): - logger.warning(f"No provider config found at {config_path}. Using empty configuration.") - # Create a temporary empty config file - with tempfile.NamedTemporaryFile(delete=False, suffix=".toml") as temp: - temp.write(b"# Empty provider config\n") - config_path = temp.name - - # At this point, config_path should not be None - _provider_manager = ProviderManager(str(config_path)) - return _provider_manager - - -def get_provider_client(provider_name: str = "default") -> BaseClient: - """Get a provider client by name. + global _provider_factory + + if _provider_factory is None: + _provider_factory = ProviderFactory() + logger.info("Created new provider factory instance") + + return _provider_factory + + +def create_provider_client( + name: str, + kind: str, + environment: str, + base_url: str, + api_key: str, + default_model: Optional[str] = None, + rate_limits: Optional[dict] = None, + use_cache: bool = True, +) -> BaseClient: + """Create a provider client with the given configuration. Args: - provider_name: Name of the provider to get. Defaults to "default". + name: Unique identifier for the provider + kind: Type of provider (e.g., 'openai', 'anthropic', 'gemini') + environment: Provider environment (cloud, local) + base_url: Base URL for the provider API + api_key: API key for the provider + default_model: Default model for the provider + rate_limits: Rate limiting configuration + use_cache: Whether to cache and reuse client instances (default: True) Returns: - BaseClient: The provider client + BaseClient: The provider client instance Raises: - ValueError: If the provider is not found + ValueError: If client creation fails """ - global _provider_manager - - if _provider_manager is None: - initialize_provider_manager() - - if _provider_manager is None: - raise ValueError("Failed to initialize provider manager") - - try: - return _provider_manager.get_client(provider_name) - except ValueError as e: - logger.error(f"Failed to get provider client: {e}") - raise + factory = get_provider_factory() + return factory.create_client( + name=name, + kind=kind, + environment=environment, + base_url=base_url, + api_key=api_key, + default_model=default_model, + rate_limits=rate_limits, + use_cache=use_cache, + ) + + +def clear_provider_cache() -> None: + """Clear the provider client cache""" + if _provider_factory is not None: + _provider_factory.clear_cache() diff --git a/servers/inference_server/graphcap/providers/provider_config.py b/servers/inference_server/graphcap/providers/provider_config.py deleted file mode 100644 index a5e4c4da..00000000 --- a/servers/inference_server/graphcap/providers/provider_config.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -# SPDX-License-Identifier: Apache-2.0 -Provider Configuration Module - -This module handles loading and validating provider configurations from TOML files. - -Key features: -- TOML configuration loading -- Provider config validation -- Default model handling -- Environment variable management - -Classes: - ProviderConfig: Configuration dataclass for providers - -Functions: - load_provider_config: Load config from TOML file - parse_provider_config: Parse config into ProviderConfig object - get_providers_config: Load and parse all provider configs - validate_config: Validate provider configurations -""" - -import tomllib -from pathlib import Path -from typing import Any - -from loguru import logger - -from .types import ProviderConfig, RateLimits - - -def _load_provider_config(config_path: str | Path) -> dict[str, ProviderConfig]: - """Load provider configuration from a TOML file.""" - - config_path = Path(config_path) - - if not config_path.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_path}") - - config_data = {} - with config_path.open("rb") as f: - config_data = tomllib.load(f) - return config_data - - -def _parse_provider_config(config_data: dict[str, Any]) -> ProviderConfig: - """Parse a provider's configuration data into a ProviderConfig object""" - # Get models list and default model - models: list[str] = config_data.get("models", []) - default_model: str = config_data.get("default_model", "") - fetch_models: bool = config_data.get("fetch_models", False) - - kind: str = config_data["kind"] - environment: str = config_data["environment"] - env_var: str = config_data.get("env_var", "") - base_url: str = config_data["base_url"] - - # If no default model specified, require one to be set - if not default_model: - if models: - default_model = models[0] - logger.debug(f"Using first model as default: {default_model}") - else: - raise ValueError("Must specify default_model when no models list is provided") - - # Parse rate limits if present - rate_limits = None - if "rate_limits" in config_data: - rate_limits_data: dict[str, int | None] = config_data["rate_limits"] - rate_limits = RateLimits( - requests_per_minute=rate_limits_data.get("requests_per_minute"), - tokens_per_minute=rate_limits_data.get("tokens_per_minute"), - ) - - return ProviderConfig( - kind=kind, - environment=environment, - env_var=env_var, - base_url=base_url, - models=models, - default_model=default_model, - fetch_models=fetch_models, - rate_limits=rate_limits, - ) - - -def get_providers_config(config_path: str | Path) -> dict[str, ProviderConfig]: - """ - Load and parse the providers configuration. - - - Args: - config_path: Path to the TOML configuration file - - Returns: - Dictionary mapping provider names to their configurations - - Example config: - [openai] - kind = "openai" - environment = "cloud" - env_var = "OPENAI_API_KEY" - base_url = "https://api.openai.com/v1" - models = ["gpt-4o", "gpt-4o-mini"] - default_model = "gpt-4o-mini" # Optional, defaults to first model in list - - [ollama] - kind = "ollama" - environment = "local" - env_var = "CUSTOM_KEY" - base_url = "http://localhost:11434" - fetch_models = true - default_model = "llama3.2" # Optional, defaults to "default" if no models - """ - config = _load_provider_config(config_path) - providers = {} - - # Parse all top-level provider configs - for name, provider_config in config.items(): - if isinstance(provider_config, dict): # Skip non-provider sections - try: - providers[name] = _parse_provider_config(provider_config) - except KeyError as e: - logger.warning(f"Skipping provider '{name}': Missing required field {e}") - provider_errors = validate_config(providers) - if provider_errors: - logger.error(f"Provider configuration errors: {provider_errors}") - raise ValueError(f"Provider configuration errors: {provider_errors}") - logger.info(f"Loaded {len(providers)} providers") - logger.debug(f"Providers: {providers}") - return providers - - -def validate_config(providers: dict[str, ProviderConfig]) -> list[str]: - """Validate the provider configuration.""" - errors: list[str] = [] - - for name, provider in providers.items(): - # Required fields - if not provider.base_url: - errors.append(f"{name}: Missing base URL") - if not provider.kind: - errors.append(f"{name}: Missing kind") - if not provider.environment: - errors.append(f"{name}: Missing environment") - if not provider.default_model: - errors.append(f"{name}: Missing default_model") - - # Environment validation - if provider.environment not in ["cloud", "local"]: - errors.append(f"{name}: Environment must be 'cloud' or 'local'") - - # URL format - if provider.base_url and not ( - provider.base_url.startswith("http://") or provider.base_url.startswith("https://") - ): - errors.append(f"{name}: Base URL must start with http:// or https://") - - # Models list when fetch_models is False - if not provider.fetch_models and not provider.models: - errors.append(f"{name}: Must specify models list when fetch_models is False") - - return errors diff --git a/servers/inference_server/graphcap/providers/provider_manager.py b/servers/inference_server/graphcap/providers/provider_manager.py deleted file mode 100644 index 2e18ee79..00000000 --- a/servers/inference_server/graphcap/providers/provider_manager.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -# SPDX-License-Identifier: Apache-2.0 -Provider Manager Module - -This module handles provider lifecycle management and client initialization. - -Key features: -- Client initialization -- Environment validation -- Rate limit management - -Classes: - ProviderManager: Main provider management class -""" - -from typing import Dict, Optional - -from loguru import logger - -from .clients import BaseClient, get_client - - -class ProviderManager: - """Manager class for handling provider lifecycle and client initialization""" - - def __init__(self, _: Optional[str] = None): - """Initialize provider manager""" - logger.info("Initializing ProviderManager") - self._clients: Dict[str, BaseClient] = {} - - def get_client( - self, - name: str, - kind: str, - environment: str, - base_url: str, - api_key: str, - default_model: Optional[str] = None, - rate_limits: Optional[dict] = None, - ) -> BaseClient: - """Initialize a client with the given configuration""" - logger.info(f"Initializing client for provider: {name}") - logger.info(f"Provider config details:") - logger.info(f" - kind: {kind}") - logger.info(f" - environment: {environment}") - logger.info(f" - base_url: {base_url}") - logger.info(f" - default_model: {default_model}") - - try: - client = get_client( - name=name, - kind=kind, - environment=environment, - api_key=api_key, - base_url=base_url, - default_model=default_model, - ) - - # Set rate limits if configured - if rate_limits: - logger.debug( - f"Setting rate limits for {name} - requests: {rate_limits.get('requests_per_minute')}/min, tokens: {rate_limits.get('tokens_per_minute')}/min" - ) - client.requests_per_minute = rate_limits.get("requests_per_minute") - client.tokens_per_minute = rate_limits.get("tokens_per_minute") - - return client - - except Exception as e: - logger.error(f"Failed to initialize client for {name}: {str(e)}") - logger.error(f"Provider config details:") - logger.error(f" - kind: {kind}") - logger.error(f" - environment: {environment}") - logger.error(f" - base_url: {base_url}") - logger.error(f" - default_model: {default_model}") - raise diff --git a/servers/inference_server/pipelines/pipelines/providers/assets.py b/servers/inference_server/pipelines/pipelines/providers/assets.py index c5952a59..4876f39c 100644 --- a/servers/inference_server/pipelines/pipelines/providers/assets.py +++ b/servers/inference_server/pipelines/pipelines/providers/assets.py @@ -2,7 +2,6 @@ """Assets for loading provider configurations.""" import dagster as dg -from graphcap.providers.provider_config import get_providers_config from graphcap.providers.types import ProviderConfig from ..common.resources import ProviderConfigFile @@ -12,56 +11,44 @@ def provider_list( context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile ) -> dict[str, ProviderConfig]: - """Loads the list of providers from the provider.config.toml file.""" - config_path = provider_config_file.provider_config - try: - providers = get_providers_config(config_path) - context.log.info(f"Loaded providers from {config_path}") - provider_info = [f"{name}: {provider.default_model}" for name, provider in providers.items()] - context.add_output_metadata( - { - "num_providers": len(providers), - "config_path": config_path, - "providers": ", ".join(provider_info), - } - ) - return providers - except FileNotFoundError: - context.log.error(f"Provider config file not found: {config_path}") - return {} - except Exception as e: - context.log.error(f"Error loading provider config: {e}") - return {} + """Loads the list of providers (now from data service API).""" + # TODO: Call data service API to get providers instead of loading from file + # For now, return an empty dictionary to avoid errors + context.log.info("Provider configuration is now managed by the data service") + + # Sample provider for testing + gemini_config = ProviderConfig( + kind="gemini", + environment="cloud", + env_var="GOOGLE_API_KEY", + base_url="https://generativelanguage.googleapis.com/v1beta", + models=["gemini-2.0-flash-exp"], + default_model="gemini-2.0-flash-exp", + fetch_models=False, + ) + + providers = {"gemini": gemini_config} + + context.add_output_metadata( + { + "num_providers": len(providers), + "providers": "gemini: gemini-2.0-flash-exp", + "note": "Provider configuration is now managed by the data service" + } + ) + return providers -# TODO: Remove this asset @dg.asset(compute_kind="python", group_name="providers") def default_provider(context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile) -> str | None: - """Loads the default provider based on the selected_provider config.""" - config_path = provider_config_file.provider_config - try: - providers = get_providers_config(config_path) - selected_provider_name = provider_config_file.default_provider - - if selected_provider_name not in providers: - context.log.warning(f"Selected provider '{selected_provider_name}' not found in config.") - return None - - selected_provider_config = providers[selected_provider_name] - - context.log.info(f"Loaded default provider: {selected_provider_name}") - context.add_output_metadata( - { - "selected_provider": selected_provider_name, - "provider_kind": selected_provider_config.kind, - "provider_environment": selected_provider_config.environment, - "provider_default_model": selected_provider_config.default_model, - } - ) - return selected_provider_name - except FileNotFoundError: - context.log.error(f"Provider config file not found: {config_path}") - return None - except Exception as e: - context.log.error(f"Error loading provider config: {e}") - return None + """Returns the default provider.""" + selected_provider_name = provider_config_file.default_provider + context.log.info(f"Using default provider: {selected_provider_name}") + + context.add_output_metadata( + { + "selected_provider": selected_provider_name, + "note": "Provider configuration is now managed by the data service" + } + ) + return selected_provider_name diff --git a/servers/inference_server/pipelines/pipelines/providers/util.py b/servers/inference_server/pipelines/pipelines/providers/util.py index 5f058ba2..2086cfd4 100644 --- a/servers/inference_server/pipelines/pipelines/providers/util.py +++ b/servers/inference_server/pipelines/pipelines/providers/util.py @@ -1,5 +1,4 @@ -from graphcap.providers.provider_config import get_providers_config -from graphcap.providers.clients import get_client +from graphcap.providers.factory import create_provider_client from ..perspectives.jobs.config import PerspectivePipelineConfig @@ -7,20 +6,21 @@ def get_provider(config_path: str, default_provider: str): """Instantiates the client based on the provider configuration. Args: - config_path (str): Path to the provider configuration file. + config_path (str): Path to the provider configuration file (deprecated). default_provider (str): The name of the default provider. Returns: The instantiated client. """ - providers = get_providers_config(config_path) - selected_provider_config = providers[default_provider] - client_args = { - "name": default_provider, - "environment": selected_provider_config.environment, - "env_var": selected_provider_config.env_var, - "base_url": selected_provider_config.base_url, - "default_model": selected_provider_config.default_model, - } - client = get_client(selected_provider_config.kind, **client_args) + # TODO: Get provider configuration from the data service API + # For now, hardcode a default configuration for Gemini + raise NotImplementedError("v2 provider configuration not implemented") + client = create_provider_client( + name=default_provider, + kind="gemini", + environment="cloud", + base_url="https://generativelanguage.googleapis.com/v1beta", + api_key="", # API key will be retrieved from environment variable + default_model="gemini-2.0-flash-exp", + ) return client diff --git a/servers/inference_server/server/server/features/perspectives/service.py b/servers/inference_server/server/server/features/perspectives/service.py index d52d969b..07931bd8 100644 --- a/servers/inference_server/server/server/features/perspectives/service.py +++ b/servers/inference_server/server/server/features/perspectives/service.py @@ -9,9 +9,9 @@ import os import socket import tempfile +from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional -from collections import defaultdict import aiohttp from fastapi import HTTPException, UploadFile diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_server/server/server/features/providers/service.py index ed4d7654..b9e8955b 100644 --- a/servers/inference_server/server/server/features/providers/service.py +++ b/servers/inference_server/server/server/features/providers/service.py @@ -5,16 +5,38 @@ Provides services for working with AI providers. """ -from typing import Any, List +from typing import Any, Dict, List, Protocol, runtime_checkable -from graphcap.providers.factory import initialize_provider_manager -from graphcap.providers.provider_manager import ProviderManager +from graphcap.providers.clients.base_client import BaseClient +from graphcap.providers.factory import create_provider_client, get_provider_factory from loguru import logger from .models import ModelInfo, ProviderConfig -# Global provider manager instance for handling requests -_provider_manager: ProviderManager = initialize_provider_manager(None) + +@runtime_checkable +class ModelProvider(Protocol): + """Protocol for model providers""" + async def get_available_models(self) -> Any: ... + async def get_models(self) -> Any: ... + + +def _extract_model_id(model: Any) -> str: + """Extract model ID from provider response""" + if hasattr(model, "id"): + return model.id + if hasattr(model, "name"): + return model.name + return str(model) + + +def _create_model_info(model_id: str, default_model: str) -> ModelInfo: + """Create a ModelInfo instance""" + return ModelInfo( + id=model_id, + name=model_id, + is_default=model_id == default_model + ) async def get_provider_models(provider_name: str, config: ProviderConfig) -> List[ModelInfo]: @@ -28,20 +50,21 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis List of model information """ # Initialize client with provided configuration - client = _provider_manager.get_client( + client = create_provider_client( name=provider_name, kind=config.kind, environment=config.environment, base_url=config.base_url, api_key=config.api_key, default_model=config.default_model, - rate_limits=config.rate_limits + rate_limits=config.rate_limits, + use_cache=True, # Cache clients for better performance ) models = [] # Try to fetch models if configured - if config.fetch_models: + if config.fetch_models and isinstance(client, ModelProvider): try: logger.info(f"Fetching models from provider {provider_name}") if hasattr(client, "get_available_models"): @@ -68,13 +91,36 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis return models -def _create_model_info(model_id: str, default_model: str) -> ModelInfo: - """Create a ModelInfo instance with the given ID and default model.""" - return ModelInfo(id=model_id, name=model_id, is_default=(model_id == default_model)) - - -def _extract_model_id(model: Any) -> str: - """Extract model ID from a model object.""" - if hasattr(model, "id"): - return model.id - return model.name if hasattr(model, "name") else str(model) +def get_provider_manager(): + """ + Get a compatible provider manager using the factory pattern. + This function creates a wrapper around the provider factory that maintains + backward compatibility with code expecting a provider manager. + + Returns: + An object with the provider manager interface + """ + factory = get_provider_factory() + + # Create a wrapper object that delegates to the factory + class ProviderManagerWrapper: + def __init__(self, factory): + self.factory = factory + self._client_cache: Dict[str, BaseClient] = {} + + def get_client(self, name: str) -> BaseClient: + """Get a client for the specified provider""" + if name in self._client_cache: + return self._client_cache[name] + + # In a real implementation, this would look up the config for the name + # For now, we're just passing through to create_provider_client + # which will fail if the provider doesn't exist + return create_provider_client(name=name, kind="", environment="", base_url="", api_key="") + + def available_providers(self) -> List[str]: + """Return a list of available provider names""" + # This is a stub - in a real implementation we would return actual providers + return ["gemini"] + + return ProviderManagerWrapper(factory) From 058148296a388fd093606ceb4ff07baf44af70c1 Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 09:46:48 -0500 Subject: [PATCH 05/69] Initial updates for v2 Inference Bridge in client Signed-off-by: jphillips --- docker-compose.override.example.yml | 2 +- docker-compose.yml | 2 +- .../src/features/inference/constants.ts | 14 +- .../inference/hooks/useModelSelection.ts | 29 ++- .../inference/hooks/useProviderForm.ts | 16 +- .../hooks/useProviderModelSelection.ts | 22 +- .../src/features/inference/providers/types.ts | 55 ++++- .../features/inference/services/providers.ts | 83 +++---- .../components/PerspectivesErrorState.tsx | 4 +- .../context/PerspectivesDataContext.tsx | 13 +- .../hooks/useGeneratePerspectiveCaption.ts | 32 +-- .../hooks/useImagePerspectives.ts | 4 +- .../hooks/usePerspectiveModules.ts | 41 ++-- .../perspectives/hooks/usePerspectives.ts | 41 ++-- .../src/features/perspectives/services/api.ts | 66 ++--- .../perspectives/services/constants.ts | 2 +- .../features/perspectives/services/utils.ts | 28 ++- .../features/server-connections/constants.ts | 6 +- .../server-connections/services/apiClients.ts | 113 +++++++++ .../server-connections/services/index.ts | 28 ++- .../server-connections/services/providers.ts | 228 ++++++++++++++++++ .../services/serverConnections.ts | 59 ++++- .../useServerConnections.ts | 14 +- test/inference_tests/ollama_graphcap_REST.py | 2 +- 24 files changed, 675 insertions(+), 229 deletions(-) create mode 100644 graphcap_studio/src/features/server-connections/services/apiClients.ts create mode 100644 graphcap_studio/src/features/server-connections/services/providers.ts diff --git a/docker-compose.override.example.yml b/docker-compose.override.example.yml index 5909d20f..2d1cec70 100644 --- a/docker-compose.override.example.yml +++ b/docker-compose.override.example.yml @@ -95,7 +95,7 @@ name: graphcap # target: /app/pnpm-lock.yaml # environment: # - NODE_ENV=${NODE_ENV:-development} - # - VITE_API_URL=http://localhost:32100/api + # - VITE_API_URL=http://localhost:32100/ # - VITE_WORKSPACE_PATH=/workspace/.local # - VITE_MEDIA_SERVER_URL=http://localhost:32400 # networks: diff --git a/docker-compose.yml b/docker-compose.yml index 97311fef..d029f1a5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -83,7 +83,7 @@ services: target: /app/pnpm-lock.yaml environment: - NODE_ENV=${NODE_ENV:-development} - - VITE_API_URL=http://localhost:32100/api + - VITE_API_URL=http://localhost:32100 - VITE_WORKSPACE_PATH=/workspace/.local - VITE_MEDIA_SERVER_URL=http://localhost:32400 - VITE_DATASETS_PATH=/workspace/datasets diff --git a/graphcap_studio/src/features/inference/constants.ts b/graphcap_studio/src/features/inference/constants.ts index 63343b2f..35737246 100644 --- a/graphcap_studio/src/features/inference/constants.ts +++ b/graphcap_studio/src/features/inference/constants.ts @@ -8,8 +8,10 @@ export const DEFAULT_PROVIDER_FORM_DATA = { kind: "", environment: "cloud" as const, baseUrl: "", - envVar: "", + apiKey: "", isEnabled: true, + defaultModel: "", + fetchModels: true, models: [], rateLimits: { requestsPerMinute: 0, @@ -21,3 +23,13 @@ export const DEFAULT_PROVIDER_FORM_DATA = { * Environment options for providers */ export const PROVIDER_ENVIRONMENTS = ["cloud", "local"] as const; + +/** + * Provider kinds + */ +export const PROVIDER_KINDS = [ + "openai", + "google", + "ollama", + "vllm", +] as const; diff --git a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts index 23be2ffe..df1ab19c 100644 --- a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts @@ -1,20 +1,23 @@ +import { useProviderModels } from "@/features/server-connections/services/providers"; // SPDX-License-Identifier: Apache-2.0 import { useCallback, useEffect, useState } from "react"; -import { useProviderModels } from "../services/providers"; +import type { Provider } from "../providers/types"; /** * Custom hook for managing model selection * - * @param providerName - Name of the provider to fetch models for + * @param provider - Provider to fetch models for * @param onModelSelect - Callback function when a model is selected * @returns Model selection state and handlers */ export function useModelSelection( - providerName: string, + provider: Provider, onModelSelect?: (providerName: string, modelId: string) => void, ) { // State for model selection - const [selectedModelId, setSelectedModelId] = useState(""); + const [selectedModelId, setSelectedModelId] = useState( + provider.defaultModel || "" + ); // Get models for the current provider const { @@ -22,24 +25,26 @@ export function useModelSelection( isLoading: isLoadingModels, isError: isModelsError, error: modelsError, - } = useProviderModels(providerName); + } = useProviderModels(provider); - // Update selected model when models are loaded + // Update selected model when models are loaded or default model changes useEffect(() => { - if (providerModelsData?.models && providerModelsData.models.length > 0) { + if (provider.defaultModel) { + setSelectedModelId(provider.defaultModel); + } else if (providerModelsData?.models && providerModelsData.models.length > 0) { const defaultModel = providerModelsData.models.find( - (model) => model.is_default, + (model) => model.is_default ); setSelectedModelId(defaultModel?.id ?? providerModelsData.models[0].id); } - }, [providerModelsData]); + }, [providerModelsData, provider.defaultModel]); // Handle model selection const handleModelSelect = useCallback(() => { - if (onModelSelect && providerName && selectedModelId) { - onModelSelect(providerName, selectedModelId); + if (onModelSelect && provider.name && selectedModelId) { + onModelSelect(provider.name, selectedModelId); } - }, [onModelSelect, providerName, selectedModelId]); + }, [onModelSelect, provider.name, selectedModelId]); return { selectedModelId, diff --git a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts b/graphcap_studio/src/features/inference/hooks/useProviderForm.ts index 241607ef..d73d2835 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderForm.ts @@ -6,7 +6,7 @@ import type { import { useCreateProvider, useUpdateProvider, -} from "@/features/inference/services/providers"; +} from "@/features/server-connections/services/providers"; import { useCallback } from "react"; import { useForm } from "react-hook-form"; @@ -18,7 +18,7 @@ type FormData = ProviderCreate | ProviderUpdate; * Custom hook for managing provider form state and operations */ export function useProviderForm(initialData: Partial = {}) { - // Initialize react-hook-form + // Initialize react-hook-form with validation const { control, handleSubmit, @@ -30,10 +30,13 @@ export function useProviderForm(initialData: Partial = {}) { ...DEFAULT_PROVIDER_FORM_DATA, ...initialData, }, + mode: "onBlur", }); - // Watch the provider name for use in UI + // Watch the provider name and other fields for use in UI const providerName = watch("name"); + const fetchModels = watch("fetchModels"); + const defaultModel = watch("defaultModel"); // Mutations const createProvider = useCreateProvider(); @@ -43,6 +46,11 @@ export function useProviderForm(initialData: Partial = {}) { const onSubmit = useCallback( async (data: FormData, isCreating: boolean, providerId?: number) => { try { + // Ensure required fields are present + if (!data.name || !data.kind || !data.environment || !data.baseUrl || !data.apiKey) { + throw new Error("Missing required fields"); + } + if (isCreating) { await createProvider.mutateAsync(data as ProviderCreate); } else if (providerId) { @@ -69,6 +77,8 @@ export function useProviderForm(initialData: Partial = {}) { errors, watch, providerName, + fetchModels, + defaultModel, reset, // Form submission diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts index a177e575..5e69e108 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts @@ -1,11 +1,12 @@ +import { useProviderModels, useProviders } from "@/features/server-connections/services/providers"; // SPDX-License-Identifier: Apache-2.0 import { useMemo } from "react"; -import { useProviderModels, useProviders } from "../services/providers"; +import type { Provider } from "../providers/types"; /** * Custom hook to handle provider and model selection logic */ -export function useProviderModelSelection(providerName: string) { +export function useProviderModelSelection(provider: Provider) { // Fetch providers from API const { data: providers = [], @@ -19,7 +20,7 @@ export function useProviderModelSelection(providerName: string) { isLoading: isLoadingModels, isError: isModelsError, error: modelsError, - } = useProviderModels(providerName); + } = useProviderModels(provider); // Memoize the available providers const availableProviders = useMemo(() => { @@ -30,15 +31,22 @@ export function useProviderModelSelection(providerName: string) { const providersWithNoModels = useMemo(() => { const noModelsSet = new Set(); - if (providerModelsData?.models?.length === 0) { - noModelsSet.add(providerName); + if (providerModelsData?.models?.length === 0 && provider.fetchModels) { + noModelsSet.add(provider.name); } return noModelsSet; - }, [providerName, providerModelsData]); + }, [provider.name, provider.fetchModels, providerModelsData]); // Get default model if available const defaultModel = useMemo(() => { + if (provider.defaultModel) { + return { + id: provider.defaultModel, + name: provider.defaultModel, + is_default: true, + }; + } if (providerModelsData?.models && providerModelsData.models.length > 0) { return ( providerModelsData.models.find((model) => model.is_default) || @@ -46,7 +54,7 @@ export function useProviderModelSelection(providerName: string) { ); } return null; - }, [providerModelsData]); + }, [provider.defaultModel, providerModelsData]); return { providers: availableProviders, diff --git a/graphcap_studio/src/features/inference/providers/types.ts b/graphcap_studio/src/features/inference/providers/types.ts index d6e2915f..6a02a6eb 100644 --- a/graphcap_studio/src/features/inference/providers/types.ts +++ b/graphcap_studio/src/features/inference/providers/types.ts @@ -5,6 +5,25 @@ * Type definitions for provider-related data. */ +/** + * Server-side provider configuration + * This is the configuration that gets sent to the inference server + */ +export interface ServerProviderConfig { + name: string; + kind: string; + environment: "cloud" | "local"; + base_url: string; + api_key: string; // Required for server requests + default_model?: string; + models: string[]; + fetch_models: boolean; + rate_limits?: { + requests_per_minute?: number; + tokens_per_minute?: number; + }; +} + /** * Provider model */ @@ -30,17 +49,18 @@ export interface RateLimits { } /** - * Provider configuration + * Provider configuration stored in data service */ export interface Provider { id: number; name: string; kind: string; environment: "cloud" | "local"; - envVar: string; baseUrl: string; - apiKey?: string; + apiKey: string; // Changed from optional to required isEnabled: boolean; + defaultModel?: string; + fetchModels: boolean; createdAt: string | Date; updatedAt: string | Date; models?: ProviderModel[]; @@ -54,10 +74,11 @@ export interface ProviderCreate { name: string; kind: string; environment: "cloud" | "local"; - envVar: string; baseUrl: string; - apiKey?: string; + apiKey: string; // Changed from optional to required isEnabled?: boolean; + defaultModel?: string; + fetchModels?: boolean; models?: Array<{ name: string; isEnabled?: boolean; @@ -75,9 +96,11 @@ export interface ProviderUpdate { name?: string; kind?: string; environment?: "cloud" | "local"; - envVar?: string; baseUrl?: string; + apiKey?: string; isEnabled?: boolean; + defaultModel?: string; + fetchModels?: boolean; models?: Array<{ id?: number; name: string; @@ -120,3 +143,23 @@ export interface ProviderModelsResponse { provider: string; models: ProviderModelInfo[]; } + +/** + * Helper function to convert Provider to ServerProviderConfig + */ +export function toServerConfig(provider: Provider): ServerProviderConfig { + return { + name: provider.name, + kind: provider.kind, + environment: provider.environment, + base_url: provider.baseUrl, + api_key: provider.apiKey, + default_model: provider.defaultModel, + models: provider.models?.map(m => m.name) || [], + fetch_models: provider.fetchModels, + rate_limits: provider.rateLimits ? { + requests_per_minute: provider.rateLimits.requestsPerMinute, + tokens_per_minute: provider.rateLimits.tokensPerMinute + } : undefined + }; +} diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index fd95a03e..b23a4ed9 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -8,32 +8,38 @@ import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import { SERVER_IDS } from "@/features/server-connections/constants"; +import { createDataServiceClient, createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; -import { hc } from "hono/client"; -import type { AppType } from "../../../../../data_service/src/app"; // TODO: Refactor import type { Provider, ProviderApiKey, ProviderCreate, ProviderModelsResponse, ProviderUpdate, + ServerProviderConfig, SuccessResponse, } from "../providers/types"; +import { toServerConfig } from "../providers/types"; // Query keys for TanStack Query export const queryKeys = { providers: ["providers"] as const, - provider: (id: number) => [...queryKeys.providers, id] as const, - providerModels: (providerName: string) => - [...queryKeys.providers, "models", providerName] as const, + provider: (id: number) => ["providers", id] as const, + providerModels: (providerName: string) => ["providers", "models", providerName] as const, }; +interface ServerConnection { + id: string; + url: string; + status: string; +} + // Define a more specific type for the client interface DataServiceClient { providers: { $get: () => Promise; $post: (options: { json: ProviderCreate }) => Promise; - [":id"]: { + ":id": { $get: (options: { param: { id: string } }) => Promise; $put: (options: { param: { id: string }; @@ -50,27 +56,10 @@ interface DataServiceClient { }; } -/** - * Get the Data Service URL from server connections context - */ -function getDataServiceUrl(connections: any[]): string { - const dataServiceConnection = connections.find( - (conn) => conn.id === SERVER_IDS.DATA_SERVICE, - ); - - return ( - dataServiceConnection?.url || - import.meta.env.VITE_DATA_SERVICE_URL || - "http://localhost:32550" - ); -} - -/** - * Create a Hono client for the Data Service - */ -function createDataServiceClient(connections: any[]): DataServiceClient { - const baseUrl = getDataServiceUrl(connections); - return hc(`${baseUrl}/api/v1`) as DataServiceClient; +interface GraphCapServerClient { + models: { + $post: (options: { json: ServerProviderConfig }) => Promise; + }; } /** @@ -243,37 +232,25 @@ export function useUpdateProviderApiKey() { } /** - * Get the GraphCap Server URL from server connections context + * Hook to get provider models + * This now uses the new server-side configuration */ -function getGraphCapServerUrl(connections: any[]): string { - const graphcapServerConnection = connections.find( - (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER, - ); - - return ( - graphcapServerConnection?.url || - import.meta.env.VITE_GRAPHCAP_SERVER_URL || - "http://localhost:32100" - ); -} - -/** - * Hook to get available models for a provider from the GraphCap server - */ -export function useProviderModels(providerName: string) { +export function useProviderModels(provider: Provider) { const { connections } = useServerConnectionsContext(); - const graphcapServerConnection = connections.find( - (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER, + const inferenceBridgeConnection = connections.find( + (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, ); - const isConnected = graphcapServerConnection?.status === "connected"; + const isConnected = inferenceBridgeConnection?.status === "connected"; return useQuery({ - queryKey: queryKeys.providerModels(providerName), + queryKey: queryKeys.providerModels(provider.name), queryFn: async () => { - const baseUrl = getGraphCapServerUrl(connections); - const response = await fetch( - `${baseUrl}/providers/${providerName}/models`, - ); + const client = createInferenceBridgeClient(connections); + const serverConfig = toServerConfig(provider); + + const response = await client.models.$post({ + json: serverConfig, + }); if (!response.ok) { throw new Error(`Failed to fetch provider models: ${response.status}`); @@ -281,7 +258,7 @@ export function useProviderModels(providerName: string) { return response.json() as Promise; }, - enabled: isConnected && !!providerName, + enabled: isConnected && provider.fetchModels, staleTime: 1000 * 60 * 5, // 5 minutes }); } diff --git a/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx b/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx index af12d179..908a091c 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx @@ -42,7 +42,7 @@ export function PerspectivesErrorState({ Server Connection Error - Unable to connect to the GraphCap server. Please check your + Unable to connect to the Inference Bridge. Please check your connection settings and try again. ) : ( - + <> + + + )} + + {/* Error Dialog */} + setIsErrorDialogOpen(false)} + error={connectionError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Success Dialog */} + setIsSuccessDialogOpen(false)} + providerName={selectedProvider?.name || "Provider"} + connectionDetails={connectionDetails} + /> ); } diff --git a/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx b/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx new file mode 100644 index 00000000..addacae9 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: Apache-2.0 +import { + Box, + Button, + Code, + Dialog, + Icon, + Portal, + Text, + VStack, +} from "@chakra-ui/react"; +import { useEffect, useRef } from "react"; +import { LuTriangleAlert } from "react-icons/lu"; + +/** + * Dialog component that displays detailed error information when a provider + * connection test fails + */ +type ErrorDetails = Record | string | null; + +type ProviderConnectionErrorDialogProps = { + isOpen: boolean; + onClose: () => void; + error: ErrorDetails; + providerName: string; +}; + +export function ProviderConnectionErrorDialog({ + isOpen, + onClose, + error, + providerName, +}: ProviderConnectionErrorDialogProps) { + // Create a reference to the dialog content + const dialogContentRef = useRef(null); + + // Prevent clicks inside the dialog from triggering outside click handlers + useEffect(() => { + function handleDialogClick(e: MouseEvent) { + // Stop event propagation for all clicks inside the dialog + e.stopPropagation(); + } + + const dialogElement = dialogContentRef.current; + if (dialogElement) { + dialogElement.addEventListener("click", handleDialogClick); + + return () => { + dialogElement.removeEventListener("click", handleDialogClick); + }; + } + }, []); // No dependencies needed as we're just setting up the event listener + + // Format error details - simplified direct approach + let formattedErrorDetails = "Unknown error occurred"; + + if (error) { + if (typeof error === "object") { + try { + formattedErrorDetails = JSON.stringify(error, null, 2); + } catch (e) { + formattedErrorDetails = `Error could not be serialized: ${String(e)}`; + } + } else { + formattedErrorDetails = String(error); + } + } + + return ( + !e.open && onClose()}> + + + + + + Connection Error: {providerName} + + + + + + + + ); +} diff --git a/graphcap_studio/src/features/inference/providers/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/components/ProviderConnectionSuccessDialog.tsx new file mode 100644 index 00000000..4348628a --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/components/ProviderConnectionSuccessDialog.tsx @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +import { + Button, + Dialog, + Icon, + Portal, + Text, + VStack +} from "@chakra-ui/react"; +import { useEffect, useRef } from "react"; +import { LuCheck } from "react-icons/lu"; + +/** + * Dialog component that displays a success message when a provider + * connection test is successful + */ +interface ConnectionDetails { + models?: unknown[]; + [key: string]: unknown; +} + +type ProviderConnectionSuccessDialogProps = { + isOpen: boolean; + onClose: () => void; + providerName: string; + connectionDetails?: ConnectionDetails | null; +}; + +export function ProviderConnectionSuccessDialog({ + isOpen, + onClose, + providerName, + connectionDetails +}: ProviderConnectionSuccessDialogProps) { + // Create a reference to the dialog content + const dialogContentRef = useRef(null); + + // Prevent clicks inside the dialog from triggering outside click handlers + useEffect(() => { + function handleDialogClick(e: MouseEvent) { + // Stop event propagation for all clicks inside the dialog + e.stopPropagation(); + } + + const dialogElement = dialogContentRef.current; + if (dialogElement) { + dialogElement.addEventListener("click", handleDialogClick); + + return () => { + dialogElement.removeEventListener("click", handleDialogClick); + }; + } + }, []); // No dependencies needed as we're just setting up the event listener + + return ( + !e.open && onClose()}> + + + + + + Connection Successful + + + + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index b23a4ed9..809b709b 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -56,10 +56,11 @@ interface DataServiceClient { }; } -interface GraphCapServerClient { - models: { - $post: (options: { json: ServerProviderConfig }) => Promise; - }; +/** + * Extended Error interface with cause property + */ +interface ErrorWithCause extends Error { + cause?: unknown; } /** @@ -232,8 +233,56 @@ export function useUpdateProviderApiKey() { } /** - * Hook to get provider models - * This now uses the new server-side configuration + * Hook to test provider connection + */ +export function useTestProviderConnection() { + const { connections } = useServerConnectionsContext(); + + return useMutation({ + mutationFn: async ({ providerName, config }: { providerName: string; config: ServerProviderConfig }) => { + const client = createInferenceBridgeClient(connections); + const response = await client.providers[":provider_name"]["test-connection"].$post({ + param: { provider_name: providerName }, + json: config, + }); + + if (!response.ok) { + const errorData = await response.json(); + + // Check if this is our enhanced error format + if (errorData.status === 'error' && errorData.details) { + // Use the structured error data with cause property + const error = new Error(errorData.message || 'Connection test failed') as ErrorWithCause; + error.cause = errorData; + throw error; + } + + // Handle different error formats + if (errorData.detail) { + throw new Error(errorData.detail); + } + + if (errorData.message) { + throw new Error(errorData.message); + } + + if (typeof errorData === 'object') { + // For raw objects, don't wrap in Error, just throw the object directly + // This prevents "[object Object]" in the error message + throw { ...errorData }; + } + + // Fallback to simple error + throw new Error(`Connection test failed: ${response.status}`); + } + + return response.json(); + }, + }); +} + +/** + * Hook to fetch provider models */ export function useProviderModels(provider: Provider) { const { connections } = useServerConnectionsContext(); @@ -253,7 +302,9 @@ export function useProviderModels(provider: Provider) { }); if (!response.ok) { - throw new Error(`Failed to fetch provider models: ${response.status}`); + throw new Error( + `Failed to fetch models for ${provider.name}: ${response.status}`, + ); } return response.json() as Promise; diff --git a/graphcap_studio/src/features/server-connections/services/apiClients.ts b/graphcap_studio/src/features/server-connections/services/apiClients.ts index 34d8e82f..a77f51a1 100644 --- a/graphcap_studio/src/features/server-connections/services/apiClients.ts +++ b/graphcap_studio/src/features/server-connections/services/apiClients.ts @@ -40,6 +40,22 @@ export interface InferenceBridgeClient { models: { $post: (options: { json: unknown }) => Promise; }; + providers: { + ":provider_name": { + "test-connection": { + $post: (options: { + param: { provider_name: string }; + json: unknown; + }) => Promise; + }; + "models": { + $post: (options: { + param: { provider_name: string }; + json: unknown; + }) => Promise; + }; + }; + }; perspectives: { list: { $get: () => Promise; diff --git a/servers/inference_server/server/server/features/providers/error_handler.py b/servers/inference_server/server/server/features/providers/error_handler.py new file mode 100644 index 00000000..caaf9265 --- /dev/null +++ b/servers/inference_server/server/server/features/providers/error_handler.py @@ -0,0 +1,190 @@ +""" +# SPDX-License-Identifier: Apache-2.0 +Provider Error Handler + +Handles provider-specific error formatting and responses. +""" + +import datetime +import traceback +from typing import Any, Dict, List, Set, Union + +from fastapi.responses import JSONResponse +from pydantic import ValidationError + +from .models import ProviderConfig +from ...utils.logger import logger + + +def format_provider_validation_error(e: ValidationError, provider_name: str) -> JSONResponse: + """ + Format a provider validation error into a standardized response. + + Args: + e: The validation error + provider_name: Name of the provider + + Returns: + A JSONResponse with detailed error information + """ + errors = e.errors() + invalid_params = {} + + # Extract field names for the error message + invalid_fields: Set[str] = set() + + for error in errors: + # Get field location + loc = error.get("loc", []) + if len(loc) > 1: + field_name = loc[1] if isinstance(loc[1], str) else str(loc[1]) + invalid_fields.add(field_name) + + # Format specific error details + field = ".".join(str(loc) for loc in error.get("loc", [])) if error.get("loc") else "" + message = error.get("msg", "Validation error") + error_type = error.get("type", "unknown_error") + + # Add context if available + context = {} + if error.get("ctx"): + for key, value in error.get("ctx", {}).items(): + if key != "expected" or not isinstance(value, list) or len(value) < 5: + context[key] = value + + invalid_params[field] = { + "message": message, + "error_type": error_type + } + + if context: + invalid_params[field]["context"] = context + + # Generate appropriate overall message + if len(invalid_fields) == 1: + field = next(iter(invalid_fields)) + message = f"Invalid provider configuration: '{field}' parameter is invalid" + elif len(invalid_fields) > 1: + field_list = "', '".join(sorted(invalid_fields)) + message = f"Invalid provider configuration: Parameters '{field_list}' are invalid" + else: + message = "Invalid provider configuration" + + # Build provider-specific suggestions + suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] + + for error in errors: + error_type = error.get("type", "") + field = ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else "" + + if error_type == "missing": + suggestions.append(f"Add the missing required parameter: '{field}'") + elif error_type == "string_type": + suggestions.append(f"Ensure '{field}' is a valid string") + elif error_type == "url_parsing": + suggestions.append(f"Use a valid URL format for '{field}'") + elif error_type and "enum" in error_type: + valid_values = error.get("ctx", {}).get("expected", []) + if valid_values: + values_str = ", ".join([f"'{v}'" for v in valid_values]) + suggestions.append(f"Choose a valid option for '{field}': {values_str}") + else: + suggestions.append(f"Choose a valid option for '{field}'") + + # Add provider-specific field suggestions + if field == "api_key": + suggestions.append("Check the API key is correct for this provider") + elif field == "base_url": + suggestions.append("Verify the base URL format matches the provider's API documentation") + elif field == "environment": + suggestions.append("Valid environment values are typically 'cloud' or 'local'") + + suggestions.append("Check server logs for more details") + + # Build the response + error_response = { + "title": "Connection failed", + "timestamp": datetime.datetime.now().isoformat(), + "message": message, + "name": "Error", + "details": "The server rejected the request due to invalid provider parameters.", + "invalid_parameters": invalid_params, + "suggestions": list(dict.fromkeys(suggestions)) + } + + return JSONResponse( + status_code=400, + content=error_response + ) + + +def format_provider_connection_error(e: Exception, provider_name: str, config: ProviderConfig) -> JSONResponse: + """ + Format a provider connection error into a standardized response. + + Args: + e: The exception that occurred + provider_name: Name of the provider + config: The provider configuration + + Returns: + A JSONResponse with detailed error information + """ + # Create a detailed error response + error_response: Dict[str, Any] = { + "title": "Connection failed", + "timestamp": datetime.datetime.now().isoformat(), + "status": "error", + "message": str(e), + "name": "Error", + "details": "Failed to connect to the provider service.", + "provider_details": { + "provider": provider_name, + "error_type": type(e).__name__, + } + } + + # Add any configuration info that might be helpful for debugging + # but exclude sensitive data like API keys + safe_config = { + "kind": config.kind, + "environment": config.environment, + "base_url": config.base_url, + "default_model": config.default_model, + "models": config.models, + "fetch_models": config.fetch_models, + } + error_response["provider_details"]["config"] = safe_config + + # Create specific suggestions for common issues + suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] + + if "authentication failed" in str(e).lower() or "unauthorized" in str(e).lower(): + error_response["provider_details"]["error_code"] = "AUTH_ERROR" + suggestions.append("Check if the API key is valid and has the necessary permissions") + elif "not found" in str(e).lower() or "404" in str(e).lower(): + error_response["provider_details"]["error_code"] = "ENDPOINT_NOT_FOUND" + suggestions.append("Verify the base URL is correct for this provider") + elif "timeout" in str(e).lower(): + error_response["provider_details"]["error_code"] = "TIMEOUT" + suggestions.append("The server took too long to respond. Check network connectivity or try again later") + elif "connection" in str(e).lower(): + error_response["provider_details"]["error_code"] = "CONNECTION_ERROR" + suggestions.append("Failed to establish connection to the provider. Check network connectivity") + elif "rate limit" in str(e).lower() or "too many requests" in str(e).lower(): + error_response["provider_details"]["error_code"] = "RATE_LIMIT" + suggestions.append("You've exceeded the provider's rate limits. Try again later") + elif "quota" in str(e).lower() or "exceeded" in str(e).lower(): + error_response["provider_details"]["error_code"] = "QUOTA_EXCEEDED" + suggestions.append("You've exceeded your provider quota. Check your usage dashboard") + else: + error_response["provider_details"]["error_code"] = "UNKNOWN_ERROR" + + suggestions.append("Check server logs for more details") + error_response["suggestions"] = suggestions + + # Return a structured error response with HTTP 400 status + return JSONResponse( + status_code=400, + content=error_response + ) \ No newline at end of file diff --git a/servers/inference_server/server/server/features/providers/router.py b/servers/inference_server/server/server/features/providers/router.py index 89080fd1..2a4c2c5a 100644 --- a/servers/inference_server/server/server/features/providers/router.py +++ b/servers/inference_server/server/server/features/providers/router.py @@ -6,18 +6,26 @@ This module provides the following endpoints: - POST /providers/{provider_name}/models - List available models for a provider using provided configuration +- POST /providers/{provider_name}/test-connection - Test connection to a provider using provided configuration """ -from fastapi import APIRouter, HTTPException +import traceback +from typing import Union +from fastapi import APIRouter +from fastapi.responses import JSONResponse +from pydantic import ValidationError + +from ...utils.logger import logger +from .error_handler import format_provider_connection_error, format_provider_validation_error from .models import ProviderConfig, ProviderModelsResponse -from .service import get_provider_models +from .service import get_provider_models, test_provider_connection router = APIRouter(prefix="/providers", tags=["providers"]) @router.post("/{provider_name}/models", response_model=ProviderModelsResponse) -async def list_provider_models(provider_name: str, config: ProviderConfig) -> ProviderModelsResponse: +async def list_provider_models(provider_name: str, config: ProviderConfig) -> Union[ProviderModelsResponse, JSONResponse]: """ List available models for a specific provider using provided configuration. @@ -26,7 +34,7 @@ async def list_provider_models(provider_name: str, config: ProviderConfig) -> Pr config: Provider configuration for this request Returns: - List of available models for the provider + List of available models for the provider or an error response Raises: HTTPException: If there is an error getting models @@ -34,5 +42,35 @@ async def list_provider_models(provider_name: str, config: ProviderConfig) -> Pr try: models = await get_provider_models(provider_name, config) return ProviderModelsResponse(provider=provider_name, models=models) + except ValidationError as e: + return format_provider_validation_error(e, provider_name) + except Exception as e: + logger.error(f"Error getting models for {provider_name}: {str(e)}") + logger.error(traceback.format_exc()) + return format_provider_connection_error(e, provider_name, config) + + +@router.post("/{provider_name}/test-connection") +async def test_connection(provider_name: str, config: ProviderConfig): + """ + Test connection to a provider using provided configuration. + + Args: + provider_name: Name of the provider to test connection for + config: Provider configuration for this request + + Returns: + A success message if connection is successful + + Raises: + HTTPException: If there is an error connecting to the provider + """ + try: + result = await test_provider_connection(provider_name, config) + return {"status": "success", "message": "Connection successful", "result": result} + except ValidationError as e: + return format_provider_validation_error(e, provider_name) except Exception as e: - raise HTTPException(status_code=500, detail=f"Error getting models: {str(e)}") + logger.error(f"Error testing connection to {provider_name}: {str(e)}") + logger.error(traceback.format_exc()) + return format_provider_connection_error(e, provider_name, config) diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_server/server/server/features/providers/service.py index b9e8955b..21778a76 100644 --- a/servers/inference_server/server/server/features/providers/service.py +++ b/servers/inference_server/server/server/features/providers/service.py @@ -6,6 +6,7 @@ """ from typing import Any, Dict, List, Protocol, runtime_checkable +import datetime from graphcap.providers.clients.base_client import BaseClient from graphcap.providers.factory import create_provider_client, get_provider_factory @@ -124,3 +125,174 @@ def available_providers(self) -> List[str]: return ["gemini"] return ProviderManagerWrapper(factory) + + +async def test_provider_connection(provider_name: str, config: ProviderConfig) -> Dict[str, Any]: + """ + Test connection to a provider by initializing the client and performing a simple operation. + + Args: + provider_name: Name of the provider to test + config: Provider configuration for this request + + Returns: + Dictionary containing test results and additional information + + Raises: + Exception: If the connection test fails + """ + result = { + "provider": provider_name, + "details": {}, + "diagnostics": { + "config_summary": { + "kind": config.kind, + "environment": config.environment, + "base_url_valid": bool(config.base_url), + "api_key_provided": bool(config.api_key), + "default_model": config.default_model, + "models_count": len(config.models), + }, + "connection_steps": [], + "warnings": [] + } + } + + try: + # Add diagnostic step + result["diagnostics"]["connection_steps"].append({ + "step": "initialize_client", + "status": "pending", + "timestamp": str(datetime.datetime.now()) + }) + + # Initialize client with provided configuration + client = create_provider_client( + name=provider_name, + kind=config.kind, + environment=config.environment, + base_url=config.base_url, + api_key=config.api_key, + default_model=config.default_model, + rate_limits=config.rate_limits, + use_cache=False, # Don't cache test clients + ) + + # Update diagnostic step + result["diagnostics"]["connection_steps"][-1]["status"] = "success" + result["client_initialized"] = True + + # Check if an empty API key was provided + if not config.api_key: + result["diagnostics"]["warnings"].append({ + "warning_type": "empty_api_key", + "message": "An empty API key was provided. This might not work with most providers." + }) + + # Check if the base URL seems valid + if not config.base_url.startswith(("http://", "https://")): + result["diagnostics"]["warnings"].append({ + "warning_type": "invalid_base_url", + "message": "The base URL doesn't start with http:// or https://" + }) + + # Try to test the connection with a lightweight operation + # First check if we can get models (most providers support this) + if isinstance(client, ModelProvider): + try: + # Add diagnostic step + result["diagnostics"]["connection_steps"].append({ + "step": "verify_connection", + "status": "pending", + "timestamp": str(datetime.datetime.now()) + }) + + if hasattr(client, "get_available_models"): + provider_models = await client.get_available_models() + result["connection_verified"] = True + result["details"]["method"] = "get_available_models" + + # Add model information if available + if hasattr(provider_models, "data"): + models_data = [] + for model in provider_models.data: + model_id = _extract_model_id(model) + models_data.append({"id": model_id}) + result["details"]["available_models"] = models_data + result["details"]["models_count"] = len(models_data) + + elif hasattr(client, "get_models"): + provider_models = await client.get_models() + result["connection_verified"] = True + result["details"]["method"] = "get_models" + + # Add model information if available + if hasattr(provider_models, "models"): + models_data = [] + for model in provider_models.models: + model_id = _extract_model_id(model) + models_data.append({"id": model_id}) + result["details"]["available_models"] = models_data + result["details"]["models_count"] = len(models_data) + + else: + # If we can't check models, just having created the client is enough + result["connection_verified"] = True + result["details"]["method"] = "client_init_only" + + # Update diagnostic step + result["diagnostics"]["connection_steps"][-1]["status"] = "success" + + except Exception as e: + logger.error(f"Error testing connection to {provider_name}: {str(e)}") + + # Update diagnostic step + result["diagnostics"]["connection_steps"][-1]["status"] = "failed" + result["diagnostics"]["connection_steps"][-1]["error"] = str(e) + result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ + + result["connection_verified"] = False + result["details"]["error"] = str(e) + result["details"]["error_type"] = type(e).__name__ + + # Add specific suggestions based on error type + if "authentication" in str(e).lower() or "unauthorized" in str(e).lower() or "auth" in str(e).lower(): + result["details"]["suggestion"] = "Check if the API key is valid and has necessary permissions" + elif "timeout" in str(e).lower(): + result["details"]["suggestion"] = "Connection timed out. Check network connectivity or server status" + elif "url" in str(e).lower() or "endpoint" in str(e).lower(): + result["details"]["suggestion"] = "Check if the base URL is correct for this provider" + + raise Exception(f"Error verifying connection: {str(e)}") + else: + # For providers that don't support models API + result["connection_verified"] = True + result["details"]["method"] = "client_init_only" + + return result + + except Exception as e: + logger.error(f"Error initializing provider client for {provider_name}: {str(e)}") + + # Update diagnostic information + if "connection_steps" in result["diagnostics"] and result["diagnostics"]["connection_steps"]: + result["diagnostics"]["connection_steps"][-1]["status"] = "failed" + result["diagnostics"]["connection_steps"][-1]["error"] = str(e) + result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ + + result["client_initialized"] = False + result["connection_verified"] = False + result["details"]["error"] = str(e) + result["details"]["error_type"] = type(e).__name__ + + # Add specific suggestions based on error type + if any(keyword in str(e).lower() for keyword in ["api key", "authentication", "auth", "credential"]): + result["details"]["suggestion"] = "Check if the API key is valid" + elif any(keyword in str(e).lower() for keyword in ["url", "endpoint", "address"]): + result["details"]["suggestion"] = "Check if the base URL is correct" + elif any(keyword in str(e).lower() for keyword in ["timeout", "connect"]): + result["details"]["suggestion"] = "Network connectivity issue. Check your internet connection" + elif "provider" in str(e).lower(): + result["details"]["suggestion"] = "Verify that the provider type is supported and correctly configured" + + raise Exception(f"Failed to initialize provider client: {str(e)}", result) diff --git a/servers/inference_server/server/server/main.py b/servers/inference_server/server/server/main.py index 6eb37a96..a5c86905 100644 --- a/servers/inference_server/server/server/main.py +++ b/servers/inference_server/server/server/main.py @@ -16,6 +16,7 @@ from .db import init_app_db from .routers import main_router from .utils.logger import logger +from .utils.middleware import setup_middlewares class GracefulExit(SystemExit): @@ -84,6 +85,9 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) +# Set up middleware +setup_middlewares(app) + # Configure CORS app.add_middleware( CORSMiddleware, diff --git a/servers/inference_server/server/server/utils/__init__.py b/servers/inference_server/server/server/utils/__init__.py index 22acaa32..46c32ec9 100644 --- a/servers/inference_server/server/server/utils/__init__.py +++ b/servers/inference_server/server/server/utils/__init__.py @@ -1,15 +1,17 @@ """ # SPDX-License-Identifier: Apache-2.0 -Utility Module Collection +Utils Module -Collection of utility functions and helpers used throughout the graphcap server. +Provides utility functions and classes for the FastAPI application. -Key features: -- Logging configuration -- JSON formatting -- Error handling -- Common utilities - -Components: - logger: Configured loguru logger with JSON formatting +Key components: +- logger: Configured loguru logger +- resizing: Image resizing utilities +- middleware: FastAPI middleware components """ + +from . import logger +from . import resizing +from . import middleware + +__all__ = ["logger", "resizing", "middleware"] diff --git a/servers/inference_server/server/server/utils/middleware.py b/servers/inference_server/server/server/utils/middleware.py new file mode 100644 index 00000000..0633b899 --- /dev/null +++ b/servers/inference_server/server/server/utils/middleware.py @@ -0,0 +1,208 @@ +""" +# SPDX-License-Identifier: Apache-2.0 +Middleware for FastAPI + +Contains middleware components for the FastAPI application. +""" + +import datetime +import json +from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from fastapi import FastAPI, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from pydantic import ValidationError + +from ..utils.logger import logger + + +class ValidationErrorMiddleware: + """ + Middleware for handling validation errors and providing detailed error messages. + + This middleware intercepts RequestValidationError exceptions and transforms them + into user-friendly error responses with specific details about what parameters + were invalid. + """ + + def __init__(self, app: FastAPI): + """Initialize the middleware with the FastAPI app.""" + self.app = app + + # Register the exception handler + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request: Request, exc: RequestValidationError): + return self.handle_validation_error(request, exc) + + # Register the pydantic ValidationError handler + @app.exception_handler(ValidationError) + async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): + return self.handle_validation_error(request, exc) + + def handle_validation_error(self, request: Request, exc: Union[RequestValidationError, ValidationError]): + """ + Handle validation errors and transform them into detailed error responses. + + Args: + request: The FastAPI request + exc: The validation exception + + Returns: + JSONResponse: A detailed error response + """ + # Extract error details from the exception + errors = exc.errors() + + # Log the error + logger.error(f"Validation error: {errors}") + + # Generate an overall message about the invalid parameters + message = self._generate_overall_message(errors) + + # Generate suggestions based on error types + suggestions = self._generate_suggestions(errors) + + # Generate field-specific error details + invalid_params = self._format_error_details(errors) + + # Build the response + error_response = { + "title": "Validation Error", + "timestamp": datetime.datetime.now().isoformat(), + "message": message, + "name": "Error", + "details": "The request was rejected due to invalid parameters.", + "invalid_parameters": invalid_params, + "suggestions": suggestions + } + + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=error_response + ) + + def _generate_overall_message(self, errors: Sequence[Dict[str, Any]]) -> str: + """ + Generate a clear overall error message summarizing what's invalid. + + Args: + errors: List of error dictionaries + + Returns: + A summary message string + """ + # Start with a default message + if not errors: + return "Invalid request parameters" + + # Count how many fields have errors + invalid_fields = set() + for error in errors: + loc = error.get("loc", []) + if len(loc) > 1: # Skip the body/query prefix + field_name = loc[1] if isinstance(loc[1], str) else str(loc[1]) + invalid_fields.add(field_name) + + if len(invalid_fields) == 1: + field = next(iter(invalid_fields)) + return f"Invalid request: '{field}' parameter is invalid" + elif len(invalid_fields) > 1: + field_list = "', '".join(sorted(invalid_fields)) + return f"Invalid request: Parameters '{field_list}' are invalid" + else: + return "Invalid request parameters" + + def _generate_suggestions(self, errors: Sequence[Dict[str, Any]]) -> List[str]: + """ + Generate helpful suggestions based on error types. + + Args: + errors: List of error dictionaries + + Returns: + List of suggestion strings + """ + suggestions = [] + + # Add specific suggestions based on error types + for error in errors: + error_type = error.get("type", "") + field = ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else "" + + if error_type == "missing": + suggestions.append(f"Add the missing required parameter: '{field}'") + elif error_type == "string_type": + suggestions.append(f"Ensure '{field}' is a valid string") + elif error_type == "url_parsing": + suggestions.append(f"Use a valid URL format for '{field}'") + elif error_type and "enum" in error_type: + valid_values = error.get("ctx", {}).get("expected", []) + if valid_values: + values_str = ", ".join([f"'{v}'" for v in valid_values]) + suggestions.append(f"Choose a valid option for '{field}': {values_str}") + else: + suggestions.append(f"Choose a valid option for '{field}'") + elif error_type == "value_error": + suggestions.append(f"Provide a valid value for '{field}'") + elif error_type == "type_error": + suggestions.append(f"Check the data type for '{field}'") + + # Add generic suggestion at the end + suggestions.append("Check the documentation for correct parameter formats") + + # Return unique suggestions + return list(dict.fromkeys(suggestions)) + + def _format_error_details(self, errors: Sequence[Dict[str, Any]]) -> Dict[str, Dict[str, str]]: + """ + Format validation errors into a structured dictionary. + + Args: + errors: List of error dictionaries + + Returns: + Dictionary of field names to error details + """ + invalid_params = {} + + for error in errors: + # Extract location (field name) + location = error.get("loc", []) + if len(location) < 2: + continue + + # Skip the first element (usually 'body' or 'query') + field_path = ".".join(str(loc) for loc in location[1:]) + + # Extract error message and type + message = error.get("msg", "Validation error") + error_type = error.get("type", "unknown_error") + + # Add any context information from the error + context = {} + if error.get("ctx"): + for key, value in error.get("ctx", {}).items(): + if key != "expected" or not isinstance(value, list) or len(value) < 5: + context[key] = value + + invalid_params[field_path] = { + "message": message, + "error_type": error_type + } + + # Add context if available + if context: + invalid_params[field_path]["context"] = context + + return invalid_params + + +def setup_middlewares(app: FastAPI) -> None: + """ + Set up all middleware for the FastAPI application. + + Args: + app: The FastAPI application instance + """ + # Initialize the validation error middleware + ValidationErrorMiddleware(app) \ No newline at end of file From ff9d8c05518adda411a4c43f17118991efd59b86 Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 11:51:06 -0500 Subject: [PATCH 08/69] Add api key field Signed-off-by: jphillips --- .../inference/providers/ProviderForm.tsx | 2 +- .../providers/form/ConnectionSection.tsx | 49 +++++++++++++++---- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx b/graphcap_studio/src/features/inference/providers/ProviderForm.tsx index 56a20bea..4d2f3325 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderForm.tsx @@ -120,8 +120,8 @@ function ProviderForm() { diff --git a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx index d424a855..d225990b 100644 --- a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx @@ -1,7 +1,9 @@ import { Switch } from "@/components/ui/buttons/Switch"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { Box, Field, Input, Text, VStack } from "@chakra-ui/react"; +import { Box, Button, Field, Flex, Input, Text, VStack } from "@chakra-ui/react"; +import { Group, InputElement } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 +import { useState } from "react"; import { Controller } from "react-hook-form"; import { useInferenceProviderContext } from "../context"; @@ -9,15 +11,18 @@ import { useInferenceProviderContext } from "../context"; * Component for displaying and editing provider connection settings */ export function ConnectionSection() { - const { control, errors, watch, isEditing } = useInferenceProviderContext(); + const { control, errors, watch, isEditing, selectedProvider } = useInferenceProviderContext(); + const [showApiKey, setShowApiKey] = useState(false); const labelColor = useColorModeValue("gray.600", "gray.300"); const textColor = useColorModeValue("gray.700", "gray.200"); // Watch form values for read-only display const baseUrl = watch("baseUrl"); - const envVar = watch("envVar"); const isEnabled = watch("isEnabled"); + // Toggle API key visibility + const toggleShowApiKey = () => setShowApiKey(!showApiKey); + if (!isEditing) { return ( @@ -30,9 +35,21 @@ export function ConnectionSection() { - Environment Variable + API Key - {envVar} + + + + + + @@ -60,13 +77,25 @@ export function ConnectionSection() { /> ( - - Environment Variable - - {errors.envVar?.message} + + API Key + + + + + + + {errors.apiKey?.message} )} /> From 0d3ec149f3cdb342782a6e69f092e543b9ef6960 Mon Sep 17 00:00:00 2001 From: jphillips Date: Tue, 25 Mar 2025 12:27:06 -0500 Subject: [PATCH 09/69] Fix toast system, improve error info from data service Signed-off-by: jphillips --- graphcap_studio/package.json | 1 - .../src/components/ui/theme/ThemeProvider.tsx | 2 + graphcap_studio/src/components/ui/toaster.tsx | 43 ++++++ .../components/CreateDatasetModal.tsx | 6 +- .../components/DeleteDatasetModal.tsx | 4 +- .../datasets/context/DatasetContext.tsx | 22 +-- .../features/datasets/hooks/useDatasets.ts | 26 ++-- .../editor/components/ImageEditor.tsx | 16 +- .../features/editor/hooks/useImageActions.ts | 8 +- .../features/editor/hooks/useImageEditor.ts | 8 +- .../image-uploader/useImageUploader.ts | 20 +-- .../inference/providers/ProviderForm.tsx | 88 ++++++++--- .../providers/form/ConnectionSection.tsx | 4 +- .../features/inference/services/providers.ts | 18 ++- graphcap_studio/src/utils/error-handler.ts | 108 +++++++++++++ graphcap_studio/src/utils/toast.ts | 133 ++++++++++++++-- servers/data_service/src/app.ts | 22 ++- .../src/features/providers/controller.ts | 68 ++++++++- .../src/features/providers/schemas.ts | 6 +- .../data_service/src/utils/error-handler.ts | 142 ++++++++++++++++++ 20 files changed, 642 insertions(+), 103 deletions(-) create mode 100644 graphcap_studio/src/components/ui/toaster.tsx create mode 100644 graphcap_studio/src/utils/error-handler.ts create mode 100644 servers/data_service/src/utils/error-handler.ts diff --git a/graphcap_studio/package.json b/graphcap_studio/package.json index d1db07f9..e54b5d1c 100644 --- a/graphcap_studio/package.json +++ b/graphcap_studio/package.json @@ -42,7 +42,6 @@ "react-icons": "^5.5.0", "react-window": "^1.8.11", "react-window-infinite-loader": "^1.0.10", - "sonner": "^1.7.4", "styled-components": "^6.1.15", "tailwindcss": "^4.0.12", "zod": "^3.24.2" diff --git a/graphcap_studio/src/components/ui/theme/ThemeProvider.tsx b/graphcap_studio/src/components/ui/theme/ThemeProvider.tsx index 85e553d4..f3790e4a 100644 --- a/graphcap_studio/src/components/ui/theme/ThemeProvider.tsx +++ b/graphcap_studio/src/components/ui/theme/ThemeProvider.tsx @@ -1,6 +1,7 @@ "use client"; import { graphcapTheme } from "@/app/theme"; +import { Toaster } from "@/components/ui/toaster"; import { ChakraProvider } from "@chakra-ui/react"; import { ColorModeProvider, type ColorModeProviderProps } from "./color-mode"; @@ -8,6 +9,7 @@ export function Provider(props: Readonly) { return ( + ); } diff --git a/graphcap_studio/src/components/ui/toaster.tsx b/graphcap_studio/src/components/ui/toaster.tsx new file mode 100644 index 00000000..df6c2c38 --- /dev/null +++ b/graphcap_studio/src/components/ui/toaster.tsx @@ -0,0 +1,43 @@ +"use client" + +import { + Toaster as ChakraToaster, + Portal, + Spinner, + Stack, + Toast, + createToaster, +} from "@chakra-ui/react" + +export const toaster = createToaster({ + placement: "bottom-end", + pauseOnPageIdle: true, +}) + +export const Toaster = () => { + return ( + + + {(toast) => ( + + {toast.type === "loading" ? ( + + ) : ( + + )} + + {toast.title && {toast.title}} + {toast.description && ( + {toast.description} + )} + + {toast.action && ( + {toast.action.label} + )} + {toast.meta?.closable && } + + )} + + + ) +} diff --git a/graphcap_studio/src/features/datasets/components/CreateDatasetModal.tsx b/graphcap_studio/src/features/datasets/components/CreateDatasetModal.tsx index c46628e9..38e5c8ed 100644 --- a/graphcap_studio/src/features/datasets/components/CreateDatasetModal.tsx +++ b/graphcap_studio/src/features/datasets/components/CreateDatasetModal.tsx @@ -1,6 +1,6 @@ +import { toast } from "@/utils/toast"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; -import { toast } from "sonner"; import { useDatasetContext } from "../context/DatasetContext"; type CreateDatasetModalProps = { @@ -50,7 +50,7 @@ export function CreateDatasetModal({ try { await createDataset(datasetName); - toast.success(`Dataset "${datasetName}" created successfully`); + toast.success({ title: `Dataset "${datasetName}" created successfully` }); onDatasetCreated(datasetName); onClose(); } catch (error) { @@ -60,7 +60,7 @@ export function CreateDatasetModal({ if (error instanceof Error && error.message.includes("409")) { // If the dataset already exists, we can still consider this a success // and notify the user that we're switching to the existing dataset - toast.info(`Dataset "${datasetName}" already exists. Switching to it.`); + toast.info({ title: `Dataset "${datasetName}" already exists. Switching to it.` }); onDatasetCreated(datasetName); onClose(); } else { diff --git a/graphcap_studio/src/features/datasets/components/DeleteDatasetModal.tsx b/graphcap_studio/src/features/datasets/components/DeleteDatasetModal.tsx index ce418dca..5d6b3fc1 100644 --- a/graphcap_studio/src/features/datasets/components/DeleteDatasetModal.tsx +++ b/graphcap_studio/src/features/datasets/components/DeleteDatasetModal.tsx @@ -1,7 +1,7 @@ import { useDeleteDataset } from "@/services/dataset"; +import { toast } from "@/utils/toast"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; -import { toast } from "sonner"; type DeleteDatasetModalProps = { readonly isOpen: boolean; @@ -32,7 +32,7 @@ export function DeleteDatasetModal({ try { await deleteDatasetMutation.mutateAsync(datasetName); - toast.success(`Dataset "${datasetName}" deleted successfully`); + toast.success({ title: `Dataset "${datasetName}" deleted successfully` }); onDatasetDeleted(); onClose(); } catch (error) { diff --git a/graphcap_studio/src/features/datasets/context/DatasetContext.tsx b/graphcap_studio/src/features/datasets/context/DatasetContext.tsx index aee859e1..17e048fc 100644 --- a/graphcap_studio/src/features/datasets/context/DatasetContext.tsx +++ b/graphcap_studio/src/features/datasets/context/DatasetContext.tsx @@ -2,6 +2,7 @@ import type { Dataset } from "@/services/dataset"; import { useAddImageToDataset, useCreateDataset } from "@/services/dataset"; import type { Image } from "@/services/images"; +import { toast } from "@/utils/toast"; import { type ReactNode, createContext, @@ -11,7 +12,6 @@ import { useMemo, useState, } from "react"; -import { toast } from "sonner"; /** * Interface for the dataset context state @@ -121,17 +121,17 @@ export function DatasetProvider({ }); if (result.success) { - toast.success( - result.message ?? - `Image added to dataset ${targetDataset} successfully`, - ); + toast.success({ + title: result.message ?? + `Image added to dataset ${targetDataset} successfully` + }); } else { - toast.error(result.message ?? "Failed to add image to dataset"); + toast.error({ title: result.message ?? "Failed to add image to dataset" }); } } catch (error) { - toast.error( - `Failed to add image to dataset: ${(error as Error).message}`, - ); + toast.error({ + title: `Failed to add image to dataset: ${(error as Error).message}` + }); console.error("Error adding image to dataset:", error); } }, @@ -149,10 +149,10 @@ export function DatasetProvider({ // Otherwise use the mutation from dataset service await createDatasetMutation.mutateAsync(name); - toast.success(`Created dataset ${name}`); + toast.success({ title: `Created dataset ${name}` }); } catch (error) { console.error("Failed to create dataset:", error); - toast.error(`Failed to create dataset: ${(error as Error).message}`); + toast.error({ title: `Failed to create dataset: ${(error as Error).message}` }); throw error; } }, diff --git a/graphcap_studio/src/features/datasets/hooks/useDatasets.ts b/graphcap_studio/src/features/datasets/hooks/useDatasets.ts index 24557fa1..dd804d7d 100644 --- a/graphcap_studio/src/features/datasets/hooks/useDatasets.ts +++ b/graphcap_studio/src/features/datasets/hooks/useDatasets.ts @@ -5,9 +5,9 @@ import { useListDatasets, } from "@/services/dataset"; import { getQueryClient } from "@/utils/queryClient"; +import { toast } from "@/utils/toast"; // SPDX-License-Identifier: Apache-2.0 import { useCallback, useEffect, useMemo, useState } from "react"; -import { toast } from "sonner"; /** * Custom hook for managing datasets @@ -96,10 +96,10 @@ export function useDatasets() { setSelectedDataset(name); setSelectedSubfolder(null); - toast.success(`Created dataset ${name}`); + toast.success({ title: `Created dataset ${name}` }); } catch (error) { console.error("Failed to create dataset:", error); - toast.error(`Failed to create dataset: ${(error as Error).message}`); + toast.error({ title: `Failed to create dataset: ${(error as Error).message}` }); throw error; } }, @@ -120,17 +120,17 @@ export function useDatasets() { }); if (result.success) { - toast.success( - result.message ?? - `Image added to dataset ${targetDataset} successfully`, - ); + toast.success({ + title: result.message ?? + `Image added to dataset ${targetDataset} successfully` + }); } else { - toast.error(result.message ?? "Failed to add image to dataset"); + toast.error({ title: result.message ?? "Failed to add image to dataset" }); } } catch (error) { - toast.error( - `Failed to add image to dataset: ${(error as Error).message}`, - ); + toast.error({ + title: `Failed to add image to dataset: ${(error as Error).message}` + }); console.error("Error adding image to dataset:", error); } }, @@ -151,11 +151,11 @@ export function useDatasets() { // After refresh, identify new images and mark them as recently uploaded if (currentDataset?.images) { const newRecentImages = new Set(recentlyUploadedImages); - currentDataset.images.forEach((image) => { + for (const image of currentDataset.images) { // Add all images from the current dataset to the recent set // In a real implementation, you might want to be more selective newRecentImages.add(image.path); - }); + } setRecentlyUploadedImages(newRecentImages); // Clear the recent uploads set after 5 minutes diff --git a/graphcap_studio/src/features/editor/components/ImageEditor.tsx b/graphcap_studio/src/features/editor/components/ImageEditor.tsx index 268f50f9..3a2bfdcf 100644 --- a/graphcap_studio/src/features/editor/components/ImageEditor.tsx +++ b/graphcap_studio/src/features/editor/components/ImageEditor.tsx @@ -1,11 +1,11 @@ import { - ImageProcessResponse, + type ImageProcessResponse, getImageUrl, useProcessImage, } from "@/services/images"; +import { toast } from "@/utils/toast"; import { useCallback, useState } from "react"; -import Cropper, { Area } from "react-easy-crop"; -import { toast } from "sonner"; +import Cropper, { type Area } from "react-easy-crop"; import { ImageViewer } from "../../gallery-viewer"; interface ImageEditorProps { @@ -36,7 +36,7 @@ export function ImageEditor({ imagePath, onSave, onCancel }: ImageEditorProps) { const handleSave = async () => { if (!croppedAreaPixels) { - toast.error("No crop area selected"); + toast.error({ title: "No crop area selected" }); return; } @@ -55,12 +55,12 @@ export function ImageEditor({ imagePath, onSave, onCancel }: ImageEditorProps) { }, }); - toast.success("Image saved successfully"); + toast.success({ title: "Image saved successfully" }); onSave?.(result); } catch (error) { - toast.error( - `Failed to save image: ${error instanceof Error ? error.message : String(error)}`, - ); + toast.error({ + title: `Failed to save image: ${error instanceof Error ? error.message : String(error)}`, + }); } finally { setIsSaving(false); } diff --git a/graphcap_studio/src/features/editor/hooks/useImageActions.ts b/graphcap_studio/src/features/editor/hooks/useImageActions.ts index 8dc99406..8d89dd4a 100644 --- a/graphcap_studio/src/features/editor/hooks/useImageActions.ts +++ b/graphcap_studio/src/features/editor/hooks/useImageActions.ts @@ -1,7 +1,7 @@ -import { Image } from "@/services/images"; +import type { Image } from "@/services/images"; +import { toast } from "@/utils/toast"; // SPDX-License-Identifier: Apache-2.0 import { useCallback } from "react"; -import { toast } from "sonner"; interface UseImageActionsProps { selectedImage: Image | null; @@ -37,7 +37,7 @@ export function useImageActions({ const handleDownload = useCallback(() => { if (selectedImage) { // Implementation for download - toast.success("Download started"); + toast.success({ title: "Download started" }); } }, [selectedImage]); @@ -45,7 +45,7 @@ export function useImageActions({ const handleDelete = useCallback(() => { if (selectedImage) { // Implementation for delete - toast.success("Image deleted"); + toast.success({ title: "Image deleted" }); } }, [selectedImage]); diff --git a/graphcap_studio/src/features/editor/hooks/useImageEditor.ts b/graphcap_studio/src/features/editor/hooks/useImageEditor.ts index d72aa487..946df122 100644 --- a/graphcap_studio/src/features/editor/hooks/useImageEditor.ts +++ b/graphcap_studio/src/features/editor/hooks/useImageEditor.ts @@ -1,9 +1,9 @@ import { queryKeys } from "@/services/dataset"; -import { Image } from "@/services/images"; +import type { Image } from "@/services/images"; +import { toast } from "@/utils/toast"; import { useQueryClient } from "@tanstack/react-query"; // SPDX-License-Identifier: Apache-2.0 import { useCallback, useState } from "react"; -import { toast } from "sonner"; interface UseImageEditorProps { selectedDataset: string | null; @@ -28,7 +28,7 @@ export function useImageEditor({ selectedDataset }: UseImageEditorProps) { if (selectedImage) { setIsEditing(true); } else { - toast.error("Please select an image to edit"); + toast.error({ title: "Please select an image to edit" }); } }, []); @@ -36,7 +36,7 @@ export function useImageEditor({ selectedDataset }: UseImageEditorProps) { * Save edited image */ const handleSave = useCallback(() => { - toast.success("Image saved successfully"); + toast.success({ title: "Image saved successfully" }); setIsEditing(false); // Invalidate cache for this dataset to refresh the images diff --git a/graphcap_studio/src/features/gallery-viewer/image-uploader/useImageUploader.ts b/graphcap_studio/src/features/gallery-viewer/image-uploader/useImageUploader.ts index 3f0fffc8..476bc672 100644 --- a/graphcap_studio/src/features/gallery-viewer/image-uploader/useImageUploader.ts +++ b/graphcap_studio/src/features/gallery-viewer/image-uploader/useImageUploader.ts @@ -1,9 +1,9 @@ import { useUploadImage } from "@/services/images"; +import { toast } from "@/utils/toast"; // SPDX-License-Identifier: Apache-2.0 // TODO: RESOLVE OLD DATASET NAME SYSTEM import { useCallback, useState } from "react"; import { useDropzone } from "react-dropzone"; -import { toast } from "sonner"; export interface UseImageUploaderProps { readonly datasetName: string; @@ -48,9 +48,9 @@ export function useImageUploader({ // Initialize progress for each file const initialProgress: Record = {}; - acceptedFiles.forEach((file) => { + for (const file of acceptedFiles) { initialProgress[file.name] = 0; - }); + } setUploadProgress(initialProgress); // Process files sequentially to avoid overwhelming the server @@ -73,7 +73,7 @@ export function useImageUploader({ })); // Show success toast for each file - toast.success(`Uploaded ${file.name}`); + toast.success({ title: `Uploaded ${file.name}` }); } catch (error) { console.error(`Error uploading ${file.name}:`, error); failedUploads.push(file.name); @@ -85,25 +85,25 @@ export function useImageUploader({ })); // Show error toast - toast.error(`Failed to upload ${file.name}`); + toast.error({ title: `Failed to upload ${file.name}` }); } } // Show summary toast if (failedUploads.length === 0) { if (totalFiles > 1) { - toast.success(`Successfully uploaded all ${totalFiles} images`); + toast.success({ title: `Successfully uploaded all ${totalFiles} images` }); } } else { - toast.error( - `Failed to upload ${failedUploads.length} of ${totalFiles} images`, - ); + toast.error({ + title: `Failed to upload ${failedUploads.length} of ${totalFiles} images`, + }); } setIsUploading(false); onUploadComplete(); }, - [datasetName, onUploadComplete], + [datasetName, onUploadComplete, uploadImageMutation.mutateAsync], ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ diff --git a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx b/graphcap_studio/src/features/inference/providers/ProviderForm.tsx index 4d2f3325..1ab8ea42 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderForm.tsx @@ -1,3 +1,4 @@ +import { handleApiError } from "@/utils/error-handler"; import { Box, Button, Flex } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo, useState } from "react"; @@ -6,7 +7,11 @@ import { FormFields } from "./FormFields"; import { ProviderConnectionErrorDialog } from "./components/ProviderConnectionErrorDialog"; import { ProviderConnectionSuccessDialog } from "./components/ProviderConnectionSuccessDialog"; import { useInferenceProviderContext } from "./context"; -import { toServerConfig } from "./types"; +import { + type ProviderCreate, + type ProviderUpdate, + toServerConfig, +} from "./types"; // Extended Error interface with cause property interface ErrorWithCause extends Error { @@ -26,11 +31,23 @@ interface ErrorWithResponse { * Component for provider form that displays fields in either view or edit mode */ function ProviderForm() { - const { handleSubmit, isSubmitting, onSubmit, onCancel, mode, setMode, selectedProvider } = - useInferenceProviderContext(); + const { + handleSubmit, + isSubmitting, + onSubmit: onSubmitProp, + onCancel, + mode, + setMode, + selectedProvider, + } = useInferenceProviderContext(); const [isTestingConnection, setIsTestingConnection] = useState(false); - const [connectionError, setConnectionError] = useState | string | null>(null); - const [connectionDetails, setConnectionDetails] = useState | null>(null); + const [connectionError, setConnectionError] = useState< + Record | string | null + >(null); + const [connectionDetails, setConnectionDetails] = useState | null>(null); const [isErrorDialogOpen, setIsErrorDialogOpen] = useState(false); const [isSuccessDialogOpen, setIsSuccessDialogOpen] = useState(false); @@ -40,15 +57,47 @@ function ProviderForm() { const isCreating = mode === "create"; const isViewMode = mode === "view"; + // Wrap the submit handler to use our error handler + const onSubmit = async (data: ProviderCreate | ProviderUpdate) => { + try { + await onSubmitProp(data); + } catch (error) { + handleApiError(error); + } + }; + const handleTestConnection = async () => { if (!selectedProvider) return; + // Validate API key is present + if (!selectedProvider.apiKey) { + setConnectionError({ + title: "Connection failed", + timestamp: new Date().toISOString(), + message: "API key is required", + name: "ValidationError", + details: "Please provide an API key in the provider configuration.", + suggestions: [ + "Edit the provider to add an API key", + "API keys should be non-empty strings", + ], + }); + setIsErrorDialogOpen(true); + return; + } + setIsTestingConnection(true); setConnectionError(null); try { const config = toServerConfig(selectedProvider); - + + // Log the config for debugging + console.log("Testing connection with provider config:", { + ...config, + api_key: config.api_key ? "[REDACTED]" : null, + }); + const result = await testConnection.mutateAsync({ providerName: selectedProvider.name, config, @@ -58,46 +107,47 @@ function ProviderForm() { setIsSuccessDialogOpen(true); } catch (error) { console.error("Connection test failed:", error); - + // Create a user-friendly error object that can be displayed directly let errorObj: Record = { title: "Connection failed", - timestamp: new Date().toISOString() + timestamp: new Date().toISOString(), }; - + // Extract error information based on type if (error instanceof Error) { // Extract useful properties from Error objects errorObj.message = error.message; errorObj.name = error.name; - + // Special case for [object Object] errors - if (error.message?.includes('[object Object]')) { + if (error.message?.includes("[object Object]")) { errorObj.message = "Invalid provider configuration"; - errorObj.details = "The server rejected the request due to invalid parameters."; + errorObj.details = + "The server rejected the request due to invalid parameters."; errorObj.suggestions = [ "Check API key and endpoint URL", "Verify the provider is correctly configured", - "Check server logs for more details" + "Check server logs for more details", ]; } - + // Check for cause object with additional details const errorWithCause = error as ErrorWithCause; - if (errorWithCause.cause && typeof errorWithCause.cause === 'object') { + if (errorWithCause.cause && typeof errorWithCause.cause === "object") { errorObj.errorDetails = errorWithCause.cause; } - } else if (typeof error === 'object' && error !== null) { + } else if (typeof error === "object" && error !== null) { // For direct object errors, merge with our error object errorObj = { ...errorObj, - ...error as Record + ...(error as Record), }; } else { // For primitive errors errorObj.message = String(error); } - + // Set the formatted error object setConnectionError(errorObj); setIsErrorDialogOpen(true); @@ -147,7 +197,7 @@ function ProviderForm() { {/* Error Dialog */} - setIsErrorDialogOpen(false)} error={connectionError} diff --git a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx index d225990b..f874cae2 100644 --- a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx @@ -88,6 +88,8 @@ export function ConnectionSection() { id="apiKey" type={showApiKey ? "text" : "password"} pe="4.5rem" + required + placeholder="Enter API key" /> - {errors.apiKey?.message} + {errors.apiKey?.message || (field.value === "" && "API key is required")} )} /> diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 809b709b..d6582722 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -161,7 +161,9 @@ export function useUpdateProvider() { }); if (!response.ok) { - throw new Error(`Failed to update provider: ${response.status}`); + const errorData = await response.json(); + console.error("Provider update error:", errorData); + throw errorData; } return response.json() as Promise; @@ -220,7 +222,9 @@ export function useUpdateProviderApiKey() { }); if (!response.ok) { - throw new Error(`Failed to update API key: ${response.status}`); + const errorData = await response.json(); + console.error("API key update error:", errorData); + throw errorData; } return response.json() as Promise; @@ -241,6 +245,15 @@ export function useTestProviderConnection() { return useMutation({ mutationFn: async ({ providerName, config }: { providerName: string; config: ServerProviderConfig }) => { const client = createInferenceBridgeClient(connections); + + // Add console logging to debug + console.log('Testing connection with config:', JSON.stringify(config)); + + // Make sure api_key is properly set and not null or undefined + if (!config.api_key) { + throw new Error("API key is required for testing provider connection"); + } + const response = await client.providers[":provider_name"]["test-connection"].$post({ param: { provider_name: providerName }, json: config, @@ -248,6 +261,7 @@ export function useTestProviderConnection() { if (!response.ok) { const errorData = await response.json(); + console.error('Error response:', errorData); // Check if this is our enhanced error format if (errorData.status === 'error' && errorData.details) { diff --git a/graphcap_studio/src/utils/error-handler.ts b/graphcap_studio/src/utils/error-handler.ts new file mode 100644 index 00000000..3ce3bd6f --- /dev/null +++ b/graphcap_studio/src/utils/error-handler.ts @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Error Handler + * + * Utilities for handling and formatting errors in the client application. + */ + +// Import from our custom toast utility +import { toast } from './toast'; + +interface ServerErrorResponse { + status?: string; + statusCode?: number; + message?: string; + timestamp?: string; + path?: string; + details?: unknown; + validationErrors?: Record; +} + +/** + * Formats a server error response into a human-readable message + */ +export function formatServerError(error: unknown): string { + // If it's already a string, just return it + if (typeof error === 'string') { + return error; + } + + // Try to handle server error response + if (error && typeof error === 'object') { + const serverError = error as ServerErrorResponse; + + // If there's a message, use it + if (serverError.message) { + return serverError.message; + } + + // If there are validation errors, format them + if (serverError.validationErrors) { + const validationMessages: string[] = []; + + for (const [field, errors] of Object.entries(serverError.validationErrors)) { + for (const errorMsg of errors) { + validationMessages.push(`${field}: ${errorMsg}`); + } + } + + if (validationMessages.length > 0) { + return `Validation errors:\n${validationMessages.join('\n')}`; + } + } + + // If there's an error property with a message (common in Axios errors) + if ('error' in serverError && typeof serverError.error === 'string') { + return serverError.error; + } + } + + // Fallback for Error instances + if (error instanceof Error) { + return error.message; + } + + // Last resort + return 'An unknown error occurred'; +} + +/** + * Shows a toast notification for server errors + */ +export function showServerError(error: unknown, title = 'Error'): void { + const message = formatServerError(error); + toast.error({ title, description: message }); +} + +/** + * Helper to extract validation errors from server responses + */ +export function getValidationErrors(error: unknown): Record | null { + if (!error || typeof error !== 'object') { + return null; + } + + const serverError = error as ServerErrorResponse; + + if (!serverError.validationErrors) { + return null; + } + + const formattedErrors: Record = {}; + + for (const [field, errors] of Object.entries(serverError.validationErrors)) { + if (errors && errors.length > 0) { + formattedErrors[field] = errors[0]; + } + } + + return Object.keys(formattedErrors).length > 0 ? formattedErrors : null; +} + +/** + * Handles common query/mutation errors + */ +export function handleApiError(error: unknown): void { + showServerError(error); + console.error('API error:', error); +} \ No newline at end of file diff --git a/graphcap_studio/src/utils/toast.ts b/graphcap_studio/src/utils/toast.ts index 81197da5..ef8b5bbe 100644 --- a/graphcap_studio/src/utils/toast.ts +++ b/graphcap_studio/src/utils/toast.ts @@ -1,49 +1,152 @@ -import { ToastT, toast } from "sonner"; +// SPDX-License-Identifier: Apache-2.0 +import { toaster } from "@/components/ui/toaster"; type ToastType = "error" | "success"; export const showToast = ( text: string, type: ToastType, - options?: Parameters[1], + options?: Omit[0], "title" | "description">, ) => { - const toastFn = type === "error" ? toast.error : toast.success; - toastFn(text, options); + const title = text; + if (type === "error") { + toaster.create({ + title, + type: "error", + ...options, + }); + } else { + toaster.create({ + title, + type: "success", + ...options, + }); + } }; export const errorToast = ( text: string, - options?: Parameters[1], + options?: Omit[0], "title" | "description" | "type">, ) => { console.error(text); if (text != null && text !== "") { - toast.error(text, options); + toaster.create({ + title: text, + type: "error", + ...options, + }); } }; export const successToast = ( text: string, - options?: Parameters[1], + options?: Omit[0], "title" | "description" | "type">, ) => { if (text != null && text !== "") { - toast.success(text, options); + toaster.create({ + title: text, + type: "success", + ...options, + }); } }; type MessageType = { - success: string | ((data: any) => string); - error?: string | ((error: any) => string); + success: string | ((data: unknown) => string); + error?: string | ((error: unknown) => string); }; export async function promiseToast( promise: Promise, message: MessageType, - options?: Parameters[1], + options?: Omit[0], "title" | "description" | "type">, ) { - return toast.promise(promise, { - loading: "Loading", - success: message.success, - error: message.error || "Error. Please try again", + // Show loading toast + const loadingToastId = toaster.create({ + title: "Loading", + type: "loading", ...options, }); + + try { + const result = await promise; + // Close loading toast + toaster.dismiss(loadingToastId); + // Show success toast + const successMessage = typeof message.success === 'function' + ? message.success(result) + : message.success; + toaster.create({ + title: successMessage, + type: "success", + ...options, + }); + return result; + } catch (error) { + // Close loading toast + toaster.dismiss(loadingToastId); + // Show error toast + const errorMessage = message.error && typeof message.error === 'function' + ? message.error(error) + : message.error || "Error. Please try again"; + toaster.create({ + title: errorMessage, + type: "error", + ...options, + }); + throw error; + } } + +/** + * Toast notification utility + */ +export const toast = { + /** + * Show a success toast + */ + success: ({ title, description, duration = 3000 }: { title: string; description?: string; duration?: number }) => { + toaster.create({ + title, + description, + duration, + type: "success", + }); + }, + + /** + * Show an error toast + */ + error: ({ title, description, duration = 5000 }: { title: string; description?: string; duration?: number }) => { + toaster.create({ + title, + description, + duration, + type: "error", + }); + }, + + /** + * Show an info toast + */ + info: ({ title, description, duration = 3000 }: { title: string; description?: string; duration?: number }) => { + toaster.create({ + title, + description, + duration, + type: "info", + }); + }, + + /** + * Show a warning toast + */ + warning: ({ title, description, duration = 4000 }: { title: string; description?: string; duration?: number }) => { + toaster.create({ + title, + description, + duration, + type: "warning", + }); + } +}; diff --git a/servers/data_service/src/app.ts b/servers/data_service/src/app.ts index eec14188..ebff4278 100644 --- a/servers/data_service/src/app.ts +++ b/servers/data_service/src/app.ts @@ -17,11 +17,15 @@ import { batchQueueRoutes } from './api/routes/batch_queue'; import { checkDatabaseConnection } from './db/init'; import { env } from './env'; import { providerRoutes } from './features/providers/routes'; +import { errorHandlerMiddleware, notFoundHandler } from './utils/error-handler'; import { logger } from './utils/logger'; // Create OpenAPI Hono app const app = new OpenAPIHono(); +// Add error handling middleware first so it can catch errors from other middleware +app.use('*', errorHandlerMiddleware({ logErrors: true })); + // Add middleware app.use('*', cors()); app.use('*', prettyJSON()); @@ -164,12 +168,24 @@ app.get('/docs', apiReference({ layout: 'modern', })); -// Error handling +// Error handling - replace existing onError handler app.onError((err, c) => { - logger.error({ err, path: c.req.path }, 'Unhandled error'); - return c.json({ error: 'Internal server error' }, 500); + // The middleware should handle most errors, + // but this is a fallback for errors that somehow bypass the middleware + logger.error({ err, path: c.req.path }, 'Unhandled error in onError handler'); + + return c.json({ + status: 'error', + statusCode: 500, + message: 'Internal server error', + timestamp: new Date().toISOString(), + path: c.req.path + }, 500); }); +// Add not found handler +app.notFound(notFoundHandler); + // Export the app export default app; diff --git a/servers/data_service/src/features/providers/controller.ts b/servers/data_service/src/features/providers/controller.ts index 1b4a7c23..c3bf06c9 100644 --- a/servers/data_service/src/features/providers/controller.ts +++ b/servers/data_service/src/features/providers/controller.ts @@ -183,7 +183,13 @@ export const updateProvider = async (c: Context) => { if (!existingProvider) { logger.debug({ id }, "Provider not found for update"); - return c.json({ error: "Provider not found" }, 404); + return c.json({ + status: "error", + statusCode: 404, + message: "Provider not found", + timestamp: new Date().toISOString(), + path: c.req.path + }, 404); } // Extract models and rate limits if provided @@ -261,8 +267,21 @@ export const updateProvider = async (c: Context) => { logger.debug({ id }, "Provider updated successfully"); return c.json(result); } catch (error) { - logger.error({ error }, "Error updating provider"); - return c.json({ error: "Failed to update provider" }, 500); + logger.error({ + error, + message: error instanceof Error ? error.message : "Unknown error", + stack: error instanceof Error ? error.stack : undefined, + }, "Error updating provider"); + + // Return detailed error response + return c.json({ + status: "error", + statusCode: 500, + message: error instanceof Error ? error.message : "Failed to update provider", + timestamp: new Date().toISOString(), + path: c.req.path, + details: error instanceof Error ? { name: error.name } : undefined + }, 500); } }; @@ -310,6 +329,21 @@ export const updateProviderApiKey = async (c: Context) => { const { apiKey } = c.req.valid("json") as ProviderApiKey; logger.debug({ id }, "Updating provider API key"); + // Validate API key + if (!apiKey || apiKey.trim() === '') { + logger.debug({ id }, "Empty API key provided"); + return c.json({ + status: "error", + statusCode: 400, + message: "API key cannot be empty", + timestamp: new Date().toISOString(), + path: c.req.path, + validationErrors: { + "apiKey": ["API key cannot be empty"] + } + }, 400); + } + // Check if provider exists const existingProvider = await db.query.providers.findFirst({ where: eq(providers.id, Number.parseInt(id)), @@ -317,7 +351,13 @@ export const updateProviderApiKey = async (c: Context) => { if (!existingProvider) { logger.debug({ id }, "Provider not found for API key update"); - return c.json({ error: "Provider not found" }, 404); + return c.json({ + status: "error", + statusCode: 404, + message: "Provider not found", + timestamp: new Date().toISOString(), + path: c.req.path + }, 404); } // Encrypt the API key @@ -336,9 +376,25 @@ export const updateProviderApiKey = async (c: Context) => { return c.json({ success: true, message: "API key updated successfully", + timestamp: new Date().toISOString() }); } catch (error) { - logger.error({ error }, "Error updating provider API key"); - return c.json({ error: "Failed to update API key" }, 500); + const providerId = c.req.param('id'); + logger.error({ + error, + message: error instanceof Error ? error.message : "Unknown error", + stack: error instanceof Error ? error.stack : undefined, + providerId + }, "Error updating provider API key"); + + // Return detailed error response + return c.json({ + status: "error", + statusCode: 500, + message: error instanceof Error ? error.message : "Failed to update API key", + timestamp: new Date().toISOString(), + path: c.req.path, + details: error instanceof Error ? { name: error.name } : undefined + }, 500); } }; diff --git a/servers/data_service/src/features/providers/schemas.ts b/servers/data_service/src/features/providers/schemas.ts index 89f2856f..03427433 100644 --- a/servers/data_service/src/features/providers/schemas.ts +++ b/servers/data_service/src/features/providers/schemas.ts @@ -83,7 +83,11 @@ export const providerUpdateSchema = z.object({ // Schema for updating a provider's API key export const providerApiKeySchema = z.object({ - apiKey: z.string().min(1, 'API key is required'), + apiKey: z.string() + .min(1, { message: 'API key is required and cannot be empty' }) + .refine(val => val.trim() !== '', { + message: 'API key cannot be just whitespace' + }), }); // Export types diff --git a/servers/data_service/src/utils/error-handler.ts b/servers/data_service/src/utils/error-handler.ts new file mode 100644 index 00000000..726e63a3 --- /dev/null +++ b/servers/data_service/src/utils/error-handler.ts @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Error Handler + * + * Utility for handling and formatting errors in a consistent way. + */ + +import type { Context } from "hono"; +import { ZodError } from "zod"; +import { logger } from "./logger"; + +interface ErrorResponse { + status: "error"; + statusCode: number; + message: string; + timestamp: string; + path?: string; + details?: unknown; + validationErrors?: Record; +} + +/** + * Creates a standardized error response object + */ +export function createErrorResponse( + message: string, + statusCode = 400, + details?: unknown, + path?: string, +): ErrorResponse { + return { + status: "error", + statusCode, + message, + timestamp: new Date().toISOString(), + path, + details, + }; +} + +/** + * Handles validation errors from Zod + */ +export function handleValidationError(error: ZodError, c: Context): Response { + const validationErrors: Record = {}; + + error.errors.forEach((err) => { + const path = err.path.join("."); + if (!validationErrors[path]) { + validationErrors[path] = []; + } + validationErrors[path].push(err.message); + }); + + const response = createErrorResponse( + "Validation error", + 400, + undefined, + c.req.path, + ); + + response.validationErrors = validationErrors; + + logger.debug({ validationErrors }, "Validation errors"); + + return c.json(response, 400); +} + +/** + * Handles general application errors + */ +export function handleApplicationError(error: unknown, c: Context): Response { + if (error instanceof ZodError) { + return handleValidationError(error, c); + } + + const statusCode = 500; + let message = "Internal server error"; + let details = undefined; + + if (error instanceof Error) { + message = error.message; + details = { + name: error.name, + stack: process.env.NODE_ENV === "development" ? error.stack : undefined, + }; + } else if (typeof error === "string") { + message = error; + } else if (typeof error === "object" && error !== null) { + message = "Application error"; + details = error; + } + + logger.error({ error, path: c.req.path }, message); + + const response = createErrorResponse( + message, + statusCode, + details, + c.req.path, + ); + return c.json(response, statusCode); +} + +/** + * Error handling middleware for Hono + */ +export function errorHandlerMiddleware(options: { logErrors?: boolean } = {}) { + return async (c: Context, next: () => Promise) => { + try { + await next(); + } catch (error) { + if (options.logErrors !== false) { + logger.error( + { + error, + path: c.req.path, + method: c.req.method, + headers: Object.fromEntries(c.req.headers.entries()), + }, + "Error caught in middleware", + ); + } + + return handleApplicationError(error, c); + } + }; +} + +/** + * Not found error handler for Hono + */ +export function notFoundHandler(c: Context) { + const response = createErrorResponse( + `Route not found: ${c.req.method} ${c.req.path}`, + 404, + undefined, + c.req.path, + ); + + return c.json(response, 404); +} From 8e31424fdf9ae370b8c86098d05b7125eb00f47c Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 07:48:47 -0500 Subject: [PATCH 10/69] Improved provider form error ux. Decompose into module Signed-off-by: jphillips --- graphcap_studio/pnpm-lock.yaml | 14 -- .../inference/hooks/useProviderForm.ts | 59 ++++- .../FormFields.module.css | 0 .../ProviderConnection/ProviderForm.tsx | 63 +++++ .../ProviderFormActions.tsx} | 5 +- .../ProviderFormTabs.tsx} | 4 +- .../components/ProviderConnectionActions.tsx | 41 ++++ .../ProviderConnectionErrorDialog.tsx | 134 +++++++++++ .../ProviderConnectionSuccessDialog.tsx | 0 .../components}/form/BasicInfoSection.tsx | 0 .../components}/form/ConnectionSection.tsx | 50 ++-- .../components}/form/EnvironmentSelect.tsx | 4 +- .../form}/ModelSelectionSection.tsx | 8 +- .../components}/form/ModelSelector.tsx | 0 .../components/form/ProviderFormView.tsx | 138 +++++++++++ .../components}/form/ProviderSelect.tsx | 2 +- .../components}/form/RateLimitsSection.tsx | 8 +- .../components}/form/index.ts | 4 +- .../context/useProviderForm.ts | 164 +++++++++++++ .../hooks/useProviderConnection.ts | 149 ++++++++++++ .../providers/ProviderConnection/index.ts | 4 + .../inference/providers/ProviderForm.tsx | 218 ------------------ .../inference/providers/ProvidersList.tsx | 14 +- .../inference/providers/ProvidersPanel.tsx | 34 ++- .../ProviderConnectionErrorDialog.tsx | 133 ----------- .../context/InferenceProviderContext.tsx | 59 ++++- .../src/features/inference/providers/index.ts | 8 +- .../server-connections/services/providers.ts | 64 ++++- .../src/features/providers/controller.ts | 142 ++++++++++-- 29 files changed, 1052 insertions(+), 471 deletions(-) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection}/FormFields.module.css (100%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx rename graphcap_studio/src/features/inference/providers/{FormActions.tsx => ProviderConnection/ProviderFormActions.tsx} (94%) rename graphcap_studio/src/features/inference/providers/{FormFields.tsx => ProviderConnection/ProviderFormTabs.tsx} (93%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection}/components/ProviderConnectionSuccessDialog.tsx (100%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/BasicInfoSection.tsx (100%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/ConnectionSection.tsx (74%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/EnvironmentSelect.tsx (92%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components/form}/ModelSelectionSection.tsx (86%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/ModelSelector.tsx (100%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormView.tsx rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/ProviderSelect.tsx (96%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/RateLimitsSection.tsx (93%) rename graphcap_studio/src/features/inference/providers/{ => ProviderConnection/components}/form/index.ts (65%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderForm.tsx delete mode 100644 graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx diff --git a/graphcap_studio/pnpm-lock.yaml b/graphcap_studio/pnpm-lock.yaml index b7056229..d6bd07e5 100644 --- a/graphcap_studio/pnpm-lock.yaml +++ b/graphcap_studio/pnpm-lock.yaml @@ -86,9 +86,6 @@ importers: react-window-infinite-loader: specifier: ^1.0.10 version: 1.0.10(react-dom@19.0.0(react@19.0.0))(react@19.0.0) - sonner: - specifier: ^1.7.4 - version: 1.7.4(react-dom@19.0.0(react@19.0.0))(react@19.0.0) styled-components: specifier: ^6.1.15 version: 6.1.15(react-dom@19.0.0(react@19.0.0))(react@19.0.0) @@ -2666,12 +2663,6 @@ packages: resolution: {integrity: sha512-94Bdh3cC2PKrbgSOUqTiGPWVZeSiXfKOVZNJniWoqrWrRkB1CJzBU3NEbiTsPcYy1lDsANA/THzS+9WBiy5nfQ==} engines: {node: '>= 10'} - sonner@1.7.4: - resolution: {integrity: sha512-DIS8z4PfJRbIyfVFDVnK9rO3eYDtse4Omcm6bt0oEr5/jtLgysmjuBl1frJ9E/EQZrFmKx2A8m/s5s9CRXIzhw==} - peerDependencies: - react: ^18.0.0 || ^19.0.0 || ^19.0.0-rc - react-dom: ^18.0.0 || ^19.0.0 || ^19.0.0-rc - source-map-js@1.2.1: resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} engines: {node: '>=0.10.0'} @@ -5790,11 +5781,6 @@ snapshots: mrmime: 2.0.1 totalist: 3.0.1 - sonner@1.7.4(react-dom@19.0.0(react@19.0.0))(react@19.0.0): - dependencies: - react: 19.0.0 - react-dom: 19.0.0(react@19.0.0) - source-map-js@1.2.1: {} source-map@0.5.7: {} diff --git a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts b/graphcap_studio/src/features/inference/hooks/useProviderForm.ts index d73d2835..9c6a3970 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderForm.ts @@ -6,6 +6,7 @@ import type { import { useCreateProvider, useUpdateProvider, + useUpdateProviderApiKey, } from "@/features/server-connections/services/providers"; import { useCallback } from "react"; @@ -29,6 +30,8 @@ export function useProviderForm(initialData: Partial = {}) { defaultValues: { ...DEFAULT_PROVIDER_FORM_DATA, ...initialData, + // Ensure apiKey is always a string, never undefined + apiKey: initialData.apiKey || '', }, mode: "onBlur", }); @@ -41,33 +44,66 @@ export function useProviderForm(initialData: Partial = {}) { // Mutations const createProvider = useCreateProvider(); const updateProvider = useUpdateProvider(); + const updateApiKeyMutation = useUpdateProviderApiKey(); + + // Update API key separately (needed because the server has a separate endpoint) + const updateApiKey = useCallback(async (providerId: number, apiKey: string) => { + if (!apiKey.trim()) { + console.warn("Attempted to update with empty API key, skipping"); + return { success: false, error: "API key cannot be empty" }; + } + + try { + await updateApiKeyMutation.mutateAsync({ id: providerId, apiKey }); + return { success: true }; + } catch (error) { + console.error("Error updating API key:", error); + return { + success: false, + error: error instanceof Error ? error.message : "Unknown error", + }; + } + }, [updateApiKeyMutation]); // Handle form submission const onSubmit = useCallback( async (data: FormData, isCreating: boolean, providerId?: number) => { try { - // Ensure required fields are present - if (!data.name || !data.kind || !data.environment || !data.baseUrl || !data.apiKey) { - throw new Error("Missing required fields"); - } - + // For create, we need all required fields if (isCreating) { + // Ensure required fields are present + if (!data.name || !data.kind || !data.environment || !data.baseUrl) { + throw new Error("Missing required fields"); + } + + // For create, we need the API key in the initial request await createProvider.mutateAsync(data as ProviderCreate); } else if (providerId) { + // For update, we only need the fields that changed + // apiKey should be handled separately + const { apiKey, ...updateData } = data; + await updateProvider.mutateAsync({ id: providerId, - data: data as ProviderUpdate, + data: updateData as ProviderUpdate, }); + + // If apiKey is provided and not empty, update it separately + if (apiKey && apiKey.trim() !== '') { + await updateApiKey(providerId, apiKey); + } } + reset(DEFAULT_PROVIDER_FORM_DATA); return { success: true }; } catch (error) { - return { - error: error instanceof Error ? error.message : "Unknown error", - }; + console.error("Error submitting provider form:", error); + + // Propagate the error so it can be handled by the UI + throw error; } }, - [createProvider, updateProvider, reset], + [createProvider, updateProvider, updateApiKey, reset], ); return { @@ -83,8 +119,9 @@ export function useProviderForm(initialData: Partial = {}) { // Form submission onSubmit, + updateApiKey, // Loading state - isSubmitting: createProvider.isPending || updateProvider.isPending, + isSubmitting: createProvider.isPending || updateProvider.isPending || updateApiKeyMutation.isPending, }; } diff --git a/graphcap_studio/src/features/inference/providers/FormFields.module.css b/graphcap_studio/src/features/inference/providers/ProviderConnection/FormFields.module.css similarity index 100% rename from graphcap_studio/src/features/inference/providers/FormFields.module.css rename to graphcap_studio/src/features/inference/providers/ProviderConnection/FormFields.module.css diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx new file mode 100644 index 00000000..8d444674 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx @@ -0,0 +1,63 @@ +import { Box, Flex } from "@chakra-ui/react"; +// SPDX-License-Identifier: Apache-2.0 +import { memo } from "react"; +import { useInferenceProviderContext } from "../context"; +import { ProviderFormActions } from "./ProviderFormActions"; +import { ProviderFormTabs } from "./ProviderFormTabs"; +import { ProviderConnectionActions } from "./components/ProviderConnectionActions"; +import { ProviderConnectionErrorDialog } from "./components/ProviderConnectionErrorDialog"; +import { ProviderConnectionSuccessDialog } from "./components/ProviderConnectionSuccessDialog"; +import { useProviderConnection } from "./hooks/useProviderConnection"; + +/** + * Container component for the provider form that handles business logic and state management + */ +function ProviderForm() { + const { selectedProvider, mode } = useInferenceProviderContext(); + const { + isTestingConnection, + connectionError, + connectionDetails, + dialogs, + handleTestConnection, + closeDialog + } = useProviderConnection(selectedProvider); + + const isEditing = mode === "edit"; + const isCreating = mode === "create"; + + return ( + + + + {/* Actions */} + + + {(isEditing || isCreating) && } + + + {/* Connection Error Dialog */} + closeDialog('error')} + error={connectionError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Success Dialog */} + closeDialog('success')} + providerName={selectedProvider?.name || "Provider"} + connectionDetails={connectionDetails} + /> + + ); +} + +export default memo(ProviderForm); diff --git a/graphcap_studio/src/features/inference/providers/FormActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx similarity index 94% rename from graphcap_studio/src/features/inference/providers/FormActions.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx index c7392309..1cc108ba 100644 --- a/graphcap_studio/src/features/inference/providers/FormActions.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx @@ -1,12 +1,12 @@ import { useColorMode } from "@/components/ui/theme/color-mode"; import { Button, Flex, HStack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { useInferenceProviderContext } from "./context"; +import { useInferenceProviderContext } from "../context"; /** * Component for rendering form action buttons with Chakra UI styling */ -export function FormActions() { +export function ProviderFormActions() { const { isSubmitting, isCreating, onCancel } = useInferenceProviderContext(); const { colorMode } = useColorMode(); @@ -32,6 +32,7 @@ export function FormActions() { + {showEditButton && ( + + )} + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx new file mode 100644 index 00000000..4ec2cc59 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: Apache-2.0 +import { + Box, + Button, + Code, + Dialog, + Grid, + GridItem, + Icon, + Portal, + Text, +} from "@chakra-ui/react"; +import { useEffect, useRef } from "react"; +import { LuTriangleAlert } from "react-icons/lu"; + +type ErrorDetails = { + message?: string; + name?: string; + details?: string; + suggestions?: string[]; + requestDetails?: { + provider: string; + config: Record; + }; +} | string | null; + +type ProviderConnectionErrorDialogProps = { + readonly isOpen: boolean; + readonly onClose: () => void; + readonly error: ErrorDetails; + readonly providerName: string; +}; + +export function ProviderConnectionErrorDialog({ + isOpen, + onClose, + error, + providerName, +}: ProviderConnectionErrorDialogProps) { + const dialogContentRef = useRef(null); + + useEffect(() => { + function handleDialogClick(e: MouseEvent) { + e.stopPropagation(); + } + + const dialogElement = dialogContentRef.current; + if (dialogElement) { + dialogElement.addEventListener("click", handleDialogClick); + return () => { + dialogElement.removeEventListener("click", handleDialogClick); + }; + } + }, []); + + const errorObj = typeof error === 'string' ? { message: error } : error; + + return ( + !e.open && onClose()}> + + + + + + + Error: {providerName} + + + + + + + + + ); +} diff --git a/graphcap_studio/src/features/inference/providers/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx similarity index 100% rename from graphcap_studio/src/features/inference/providers/components/ProviderConnectionSuccessDialog.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx diff --git a/graphcap_studio/src/features/inference/providers/form/BasicInfoSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx similarity index 100% rename from graphcap_studio/src/features/inference/providers/form/BasicInfoSection.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx diff --git a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx similarity index 74% rename from graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx index f874cae2..1c35e96a 100644 --- a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx @@ -5,7 +5,7 @@ import { Group, InputElement } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; import { Controller } from "react-hook-form"; -import { useInferenceProviderContext } from "../context"; +import { useInferenceProviderContext } from "../../../context"; /** * Component for displaying and editing provider connection settings @@ -79,27 +79,33 @@ export function ConnectionSection() { ( - - API Key - - - - - - - {errors.apiKey?.message || (field.value === "" && "API key is required")} - - )} + render={({ field }) => { + // Ensure we always have a defined string value + const value = field.value || ""; + return ( + + API Key + + field.onChange(e.target.value)} + /> + + + + + {errors.apiKey?.message || (value === "" && "API key is required")} + + ); + }} /> | string | null; + connectionDetails: Record | null; + dialogs: { + error: boolean; + success: boolean; + formError: boolean; + }; + onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; + handleSubmit: (handler: (data: ProviderCreate | ProviderUpdate) => Promise) => (e: React.FormEvent) => void; + handleTestConnection: () => Promise; + setMode: (mode: 'view' | 'edit' | 'create') => void; + closeDialog: (dialog: 'error' | 'success' | 'formError') => void; +} + +/** + * Presentational component for the provider form + */ +export function ProviderFormView({ + mode, + isSubmitting, + saveSuccess, + isTestingConnection, + selectedProvider, + formError, + connectionError, + connectionDetails, + dialogs, + onSubmit, + handleSubmit, + handleTestConnection, + setMode, + closeDialog, +}: ProviderFormViewProps) { + const isEditing = mode === "edit"; + const isCreating = mode === "create"; + + return ( +
+ + + + {/* Loading/Success Message */} + + {isSubmitting && ( + + + Saving changes... + + )} + + {!isSubmitting && saveSuccess && ( + + Provider saved successfully! + + )} + + + {/* Actions */} + + {isEditing || isCreating ? ( + + ) : ( + <> + + + + )} + + + {/* Form Error Dialog */} + closeDialog("formError")} + error={formError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Connection Error Dialog */} + closeDialog("error")} + error={connectionError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Success Dialog */} + closeDialog("success")} + providerName={selectedProvider?.name || "Provider"} + connectionDetails={connectionDetails} + /> + +
+ ); +} diff --git a/graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderSelect.tsx similarity index 96% rename from graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderSelect.tsx index b92ef9d5..d4ab407e 100644 --- a/graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderSelect.tsx @@ -7,7 +7,7 @@ import { SelectValueText, } from "@/components/ui/select"; import { createListCollection } from "@chakra-ui/react"; -import { useInferenceProviderContext } from "../context"; +import { useInferenceProviderContext } from "../../context"; type ProviderSelectProps = { readonly className?: string; diff --git a/graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx similarity index 93% rename from graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx index 0c6f8089..3dde8c3d 100644 --- a/graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx @@ -9,9 +9,9 @@ import { VStack, } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { ChangeEvent } from "react"; +import type { ChangeEvent } from "react"; import { Controller } from "react-hook-form"; -import { useInferenceProviderContext } from "../context"; +import { useInferenceProviderContext } from "../../context"; /** * Component for displaying and editing provider rate limits @@ -74,7 +74,7 @@ export function RateLimitsSection() { type="number" value={value ?? 0} onChange={(e: ChangeEvent) => - onChange(parseInt(e.target.value) || 0) + onChange(Number.parseInt(e.target.value) || 0) } min={0} /> @@ -101,7 +101,7 @@ export function RateLimitsSection() { type="number" value={value ?? 0} onChange={(e: ChangeEvent) => - onChange(parseInt(e.target.value) || 0) + onChange(Number.parseInt(e.target.value) || 0) } min={0} /> diff --git a/graphcap_studio/src/features/inference/providers/form/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts similarity index 65% rename from graphcap_studio/src/features/inference/providers/form/index.ts rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts index 74080701..fc3f7fee 100644 --- a/graphcap_studio/src/features/inference/providers/form/index.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts @@ -4,6 +4,6 @@ export * from "./ConnectionSection"; export * from "./RateLimitsSection"; export * from "./EnvironmentSelect"; export * from "./ProviderSelect"; -export * from "../../../../components/ui/status/StatusMessage"; +export * from "../../../../../components/ui/status/StatusMessage"; export * from "./ModelSelector"; -export * from "../../../../components/ui/buttons/ActionButton"; +export * from "../../../../../components/ui/buttons/ActionButton"; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts new file mode 100644 index 00000000..12662fbd --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +import { useState } from "react"; +import { useTestProviderConnection } from "../../../services/providers"; +import { useInferenceProviderContext } from "../../context"; +import { type Provider, type ProviderCreate, type ProviderUpdate, toServerConfig } from "../../types"; + +interface UseProviderFormResult { + mode: 'view' | 'edit' | 'create'; + isSubmitting: boolean; + saveSuccess: boolean; + isTestingConnection: boolean; + selectedProvider?: Provider | null; + formError: unknown; + connectionError: Record | string | null; + connectionDetails: Record | null; + dialogs: { + error: boolean; + success: boolean; + formError: boolean; + }; + onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; + handleSubmit: (handler: (data: ProviderCreate | ProviderUpdate) => Promise) => (e: React.FormEvent) => void; + handleTestConnection: () => Promise; + setMode: (mode: 'view' | 'edit' | 'create') => void; + closeDialog: (dialog: 'error' | 'success' | 'formError') => void; +} + +/** + * Custom hook that manages provider form state and logic + */ +export function useProviderForm(selectedProvider: Provider | null): UseProviderFormResult { + const { + handleSubmit, + isSubmitting, + onSubmit: onSubmitProp, + mode, + setMode, + } = useInferenceProviderContext(); + + const [isTestingConnection, setIsTestingConnection] = useState(false); + const [connectionError, setConnectionError] = useState | string | null>(null); + const [connectionDetails, setConnectionDetails] = useState | null>(null); + const [formError, setFormError] = useState(null); + const [saveSuccess, setSaveSuccess] = useState(false); + const [dialogs, setDialogs] = useState({ + error: false, + success: false, + formError: false + }); + + const testConnection = useTestProviderConnection(); + + const closeDialog = (dialog: keyof typeof dialogs) => { + setDialogs(prev => ({ ...prev, [dialog]: false })); + }; + + const onSubmit = async (data: ProviderCreate | ProviderUpdate) => { + try { + setFormError(null); + setSaveSuccess(false); + await onSubmitProp(data); + setSaveSuccess(true); + + // Reset success message after 3 seconds + setTimeout(() => { + setSaveSuccess(false); + }, 3000); + } catch (error) { + console.error("Provider form submission error:", error); + setFormError(error); + setDialogs(prev => ({ ...prev, formError: true })); + } + }; + + const handleTestConnection = async () => { + if (!selectedProvider) return; + + // Validate API key is present + if (!selectedProvider.apiKey) { + setConnectionError({ + title: "Connection failed", + timestamp: new Date().toISOString(), + message: "API key is required", + name: "ValidationError", + details: "Please provide an API key in the provider configuration.", + suggestions: [ + "Edit the provider to add an API key", + "API keys should be non-empty strings", + ], + }); + setDialogs(prev => ({ ...prev, error: true })); + return; + } + + setIsTestingConnection(true); + setConnectionError(null); + + try { + const config = toServerConfig(selectedProvider); + const result = await testConnection.mutateAsync({ + providerName: selectedProvider.name, + config, + }); + + setConnectionDetails(result); + setDialogs(prev => ({ ...prev, success: true })); + } catch (error) { + console.error("Connection test failed:", error); + + let errorObj: Record = { + title: "Connection failed", + timestamp: new Date().toISOString(), + }; + + if (error instanceof Error) { + errorObj.message = error.message; + errorObj.name = error.name; + + if (error.message?.includes("[object Object]")) { + errorObj.message = "Invalid provider configuration"; + errorObj.details = "The server rejected the request due to invalid parameters."; + errorObj.suggestions = [ + "Check API key and endpoint URL", + "Verify the provider is correctly configured", + "Check server logs for more details", + ]; + } + + if ('cause' in error && typeof error.cause === 'object') { + errorObj.errorDetails = error.cause; + } + } else if (typeof error === "object" && error !== null) { + errorObj = { + ...errorObj, + ...(error as Record), + }; + } else { + errorObj.message = String(error); + } + + setConnectionError(errorObj); + setDialogs(prev => ({ ...prev, error: true })); + } finally { + setIsTestingConnection(false); + } + }; + + return { + mode, + isSubmitting, + saveSuccess, + isTestingConnection, + selectedProvider, + formError, + connectionError, + connectionDetails, + dialogs, + onSubmit, + handleSubmit, + handleTestConnection, + setMode, + closeDialog, + }; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts new file mode 100644 index 00000000..138bbceb --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: Apache-2.0 +import { useState } from "react"; +import { useTestProviderConnection } from "../../../services/providers"; +import { useInferenceProviderContext } from "../../context"; +import { type Provider, type ProviderCreate, toServerConfig } from "../../types"; + +interface UseProviderConnectionResult { + isTestingConnection: boolean; + connectionError: Record | string | null; + connectionDetails: Record | null; + dialogs: { + error: boolean; + success: boolean; + }; + handleTestConnection: () => Promise; + closeDialog: (dialog: 'error' | 'success') => void; +} + +/** + * Hook for managing provider connection testing + */ +export function useProviderConnection(selectedProvider: Provider | null): UseProviderConnectionResult { + const { watch } = useInferenceProviderContext(); + const [isTestingConnection, setIsTestingConnection] = useState(false); + const [connectionError, setConnectionError] = useState | string | null>(null); + const [connectionDetails, setConnectionDetails] = useState | null>(null); + const [dialogs, setDialogs] = useState({ + error: false, + success: false + }); + + const testConnection = useTestProviderConnection(); + + const closeDialog = (dialog: keyof typeof dialogs) => { + setDialogs(prev => ({ ...prev, [dialog]: false })); + }; + + const handleTestConnection = async () => { + // Get current form values + const currentFormValues = { + ...selectedProvider, // Base values from saved provider + name: watch('name'), + apiKey: watch('apiKey'), + baseUrl: watch('baseUrl'), + kind: watch('kind'), + environment: watch('environment'), + // Add other necessary fields from the form + } as Provider; + + if (!currentFormValues.apiKey) { + setConnectionError({ + title: "Connection failed", + timestamp: new Date().toISOString(), + message: "API key is required", + name: "ValidationError", + details: "Please provide an API key in the provider configuration.", + suggestions: [ + "Enter an API key in the form", + "API keys should be non-empty strings", + ], + requestDetails: { + provider: currentFormValues.name, + config: { + ...toServerConfig(currentFormValues), + api_key: '[MISSING]' + } + } + }); + setDialogs(prev => ({ ...prev, error: true })); + return; + } + + setIsTestingConnection(true); + setConnectionError(null); + + try { + const config = toServerConfig(currentFormValues); + const requestDetails = { + provider: currentFormValues.name, + config: { + ...config, + api_key: config.api_key ? '[REDACTED]' : undefined + } + }; + + const result = await testConnection.mutateAsync({ + providerName: currentFormValues.name, + config, + }); + + setConnectionDetails(result); + setDialogs(prev => ({ ...prev, success: true })); + } catch (error) { + console.error("Connection test failed:", error); + + let errorObj: Record = { + title: "Connection failed", + timestamp: new Date().toISOString(), + requestDetails: { + provider: currentFormValues.name, + config: { + ...toServerConfig(currentFormValues), + api_key: '[REDACTED]' + } + } + }; + + if (error instanceof Error) { + errorObj.message = error.message; + errorObj.name = error.name; + + if (error.message?.includes("[object Object]")) { + errorObj.message = "Invalid provider configuration"; + errorObj.details = "The server rejected the request due to invalid parameters."; + errorObj.suggestions = [ + "Check API key and endpoint URL in the form", + "Verify the provider configuration is correct", + "Check server logs for more details", + ]; + } + + if ('cause' in error && typeof error.cause === 'object') { + errorObj.errorDetails = error.cause; + } + } else if (typeof error === "object" && error !== null) { + errorObj = { + ...errorObj, + ...(error as Record), + }; + } else { + errorObj.message = String(error); + } + + setConnectionError(errorObj); + setDialogs(prev => ({ ...prev, error: true })); + } finally { + setIsTestingConnection(false); + } + }; + + return { + isTestingConnection, + connectionError, + connectionDetails, + dialogs, + handleTestConnection, + closeDialog, + }; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts new file mode 100644 index 00000000..eb9a0f8a --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: Apache-2.0 +export { default as ProviderForm } from './ProviderForm'; +export { useProviderForm } from './context/useProviderForm'; +export { ProviderFormView } from './components/form/ProviderFormView'; \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx b/graphcap_studio/src/features/inference/providers/ProviderForm.tsx deleted file mode 100644 index 1ab8ea42..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderForm.tsx +++ /dev/null @@ -1,218 +0,0 @@ -import { handleApiError } from "@/utils/error-handler"; -import { Box, Button, Flex } from "@chakra-ui/react"; -// SPDX-License-Identifier: Apache-2.0 -import { memo, useState } from "react"; -import { useTestProviderConnection } from "../services/providers"; -import { FormFields } from "./FormFields"; -import { ProviderConnectionErrorDialog } from "./components/ProviderConnectionErrorDialog"; -import { ProviderConnectionSuccessDialog } from "./components/ProviderConnectionSuccessDialog"; -import { useInferenceProviderContext } from "./context"; -import { - type ProviderCreate, - type ProviderUpdate, - toServerConfig, -} from "./types"; - -// Extended Error interface with cause property -interface ErrorWithCause extends Error { - cause?: unknown; -} - -// Add this interface below the ErrorWithCause interface -interface ErrorWithResponse { - response?: { - data?: unknown; - }; - error?: string; - message?: string; -} - -/** - * Component for provider form that displays fields in either view or edit mode - */ -function ProviderForm() { - const { - handleSubmit, - isSubmitting, - onSubmit: onSubmitProp, - onCancel, - mode, - setMode, - selectedProvider, - } = useInferenceProviderContext(); - const [isTestingConnection, setIsTestingConnection] = useState(false); - const [connectionError, setConnectionError] = useState< - Record | string | null - >(null); - const [connectionDetails, setConnectionDetails] = useState | null>(null); - const [isErrorDialogOpen, setIsErrorDialogOpen] = useState(false); - const [isSuccessDialogOpen, setIsSuccessDialogOpen] = useState(false); - - const testConnection = useTestProviderConnection(); - - const isEditing = mode === "edit"; - const isCreating = mode === "create"; - const isViewMode = mode === "view"; - - // Wrap the submit handler to use our error handler - const onSubmit = async (data: ProviderCreate | ProviderUpdate) => { - try { - await onSubmitProp(data); - } catch (error) { - handleApiError(error); - } - }; - - const handleTestConnection = async () => { - if (!selectedProvider) return; - - // Validate API key is present - if (!selectedProvider.apiKey) { - setConnectionError({ - title: "Connection failed", - timestamp: new Date().toISOString(), - message: "API key is required", - name: "ValidationError", - details: "Please provide an API key in the provider configuration.", - suggestions: [ - "Edit the provider to add an API key", - "API keys should be non-empty strings", - ], - }); - setIsErrorDialogOpen(true); - return; - } - - setIsTestingConnection(true); - setConnectionError(null); - - try { - const config = toServerConfig(selectedProvider); - - // Log the config for debugging - console.log("Testing connection with provider config:", { - ...config, - api_key: config.api_key ? "[REDACTED]" : null, - }); - - const result = await testConnection.mutateAsync({ - providerName: selectedProvider.name, - config, - }); - - setConnectionDetails(result); - setIsSuccessDialogOpen(true); - } catch (error) { - console.error("Connection test failed:", error); - - // Create a user-friendly error object that can be displayed directly - let errorObj: Record = { - title: "Connection failed", - timestamp: new Date().toISOString(), - }; - - // Extract error information based on type - if (error instanceof Error) { - // Extract useful properties from Error objects - errorObj.message = error.message; - errorObj.name = error.name; - - // Special case for [object Object] errors - if (error.message?.includes("[object Object]")) { - errorObj.message = "Invalid provider configuration"; - errorObj.details = - "The server rejected the request due to invalid parameters."; - errorObj.suggestions = [ - "Check API key and endpoint URL", - "Verify the provider is correctly configured", - "Check server logs for more details", - ]; - } - - // Check for cause object with additional details - const errorWithCause = error as ErrorWithCause; - if (errorWithCause.cause && typeof errorWithCause.cause === "object") { - errorObj.errorDetails = errorWithCause.cause; - } - } else if (typeof error === "object" && error !== null) { - // For direct object errors, merge with our error object - errorObj = { - ...errorObj, - ...(error as Record), - }; - } else { - // For primitive errors - errorObj.message = String(error); - } - - // Set the formatted error object - setConnectionError(errorObj); - setIsErrorDialogOpen(true); - } finally { - setIsTestingConnection(false); - } - }; - - return ( - - - - {/* Actions */} - - {isEditing || isCreating ? ( - <> - - - - ) : ( - <> - - - - )} - - - {/* Error Dialog */} - setIsErrorDialogOpen(false)} - error={connectionError} - providerName={selectedProvider?.name || "Provider"} - /> - - {/* Success Dialog */} - setIsSuccessDialogOpen(false)} - providerName={selectedProvider?.name || "Provider"} - connectionDetails={connectionDetails} - /> - - ); -} - -export default memo(ProviderForm); diff --git a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx index c38bb2dd..a4efaa96 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx @@ -1,17 +1,12 @@ -import { useProviderFormContext } from "./context"; // SPDX-License-Identifier: Apache-2.0 -import { ProviderSelect } from "./form"; +import { ProviderSelect } from "./ProviderConnection/form"; +import { useProviderFormContext } from "./context"; -type ProvidersListProps = { - readonly onSelectProvider: (id: number) => void; -}; /** * Component for displaying a list of providers as a dropdown */ -export default function ProvidersList({ - onSelectProvider, -}: ProvidersListProps) { +export default function ProvidersList() { const { providers } = useProviderFormContext(); if (providers.length === 0) { @@ -21,10 +16,9 @@ export default function ProvidersList({ ); } - return (
- +
); } diff --git a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx index 5711a3ce..2a8a2b18 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx @@ -1,14 +1,15 @@ import { useColorMode } from "@/components/ui/theme/color-mode"; import { Box, Button, Center, Flex, Text, VStack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { useMemo } from "react"; +import { useCallback, useMemo, useState } from "react"; +import type { ProviderCreate, ProviderUpdate } from "../providers/types"; import { useProviders } from "../services/providers"; -import ProviderForm from "./ProviderForm"; +import ProviderForm from "./ProviderConnection/ProviderForm"; +import { ProviderSelect } from "./ProviderConnection/form"; import { InferenceProviderProvider, useInferenceProviderContext, } from "./context"; -import { ProviderSelect } from "./form"; /** * Panel content that requires context @@ -78,8 +79,12 @@ export function ProvidersPanel() { isLoading, isError, error, + refetch } = useProviders(); + // State to track form submission + const [isSubmitting, setIsSubmitting] = useState(false); + // Set the initial selected provider to the first one in the list const initialSelectedProvider = useMemo(() => { return providersData.length > 0 ? providersData[0] : null; @@ -88,6 +93,25 @@ export function ProvidersPanel() { const { colorMode } = useColorMode(); const textColor = colorMode === "light" ? "gray.600" : "gray.300"; + // Handle form submission + const handleSubmit = useCallback(async (data: ProviderCreate | ProviderUpdate) => { + try { + console.log("Provider form submitted:", data); + setIsSubmitting(true); + + // Simulate a delay for demonstration purposes + await new Promise(resolve => setTimeout(resolve, 1000)); + + // Refetch providers to get updated data + await refetch(); + console.log("Provider updated successfully:", data); + } catch (error) { + console.error("Error updating provider:", error); + } finally { + setIsSubmitting(false); + } + }, [refetch]); + // Loading state if (isLoading) { return ( @@ -111,9 +135,9 @@ export function ProvidersPanel() { providers={providersData} selectedProvider={initialSelectedProvider} isCreating={false} - onSubmit={() => {}} + onSubmit={handleSubmit} onCancel={() => {}} - isSubmitting={false} + isSubmitting={isSubmitting} > diff --git a/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx b/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx deleted file mode 100644 index addacae9..00000000 --- a/graphcap_studio/src/features/inference/providers/components/ProviderConnectionErrorDialog.tsx +++ /dev/null @@ -1,133 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -import { - Box, - Button, - Code, - Dialog, - Icon, - Portal, - Text, - VStack, -} from "@chakra-ui/react"; -import { useEffect, useRef } from "react"; -import { LuTriangleAlert } from "react-icons/lu"; - -/** - * Dialog component that displays detailed error information when a provider - * connection test fails - */ -type ErrorDetails = Record | string | null; - -type ProviderConnectionErrorDialogProps = { - isOpen: boolean; - onClose: () => void; - error: ErrorDetails; - providerName: string; -}; - -export function ProviderConnectionErrorDialog({ - isOpen, - onClose, - error, - providerName, -}: ProviderConnectionErrorDialogProps) { - // Create a reference to the dialog content - const dialogContentRef = useRef(null); - - // Prevent clicks inside the dialog from triggering outside click handlers - useEffect(() => { - function handleDialogClick(e: MouseEvent) { - // Stop event propagation for all clicks inside the dialog - e.stopPropagation(); - } - - const dialogElement = dialogContentRef.current; - if (dialogElement) { - dialogElement.addEventListener("click", handleDialogClick); - - return () => { - dialogElement.removeEventListener("click", handleDialogClick); - }; - } - }, []); // No dependencies needed as we're just setting up the event listener - - // Format error details - simplified direct approach - let formattedErrorDetails = "Unknown error occurred"; - - if (error) { - if (typeof error === "object") { - try { - formattedErrorDetails = JSON.stringify(error, null, 2); - } catch (e) { - formattedErrorDetails = `Error could not be serialized: ${String(e)}`; - } - } else { - formattedErrorDetails = String(error); - } - } - - return ( - !e.open && onClose()}> - - - - - - Connection Error: {providerName} - - - - - - - - ); -} diff --git a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx index 7401a316..16dd21ab 100644 --- a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx @@ -247,7 +247,21 @@ export function InferenceProviderProvider({ onSubmit: onSubmitForm, watch, reset, - } = useProviderForm(initialData); + updateApiKey, + } = useProviderForm({ + ...initialData, + // If we're editing an existing provider, ensure its data is properly passed + ...(selectedProvider && mode === "edit" ? { + name: selectedProvider.name, + kind: selectedProvider.kind, + environment: selectedProvider.environment, + baseUrl: selectedProvider.baseUrl, + apiKey: selectedProvider.apiKey || '', + isEnabled: selectedProvider.isEnabled, + defaultModel: selectedProvider.defaultModel || '', + fetchModels: selectedProvider.fetchModels, + } : {}), + }); // Reset form data when selected provider changes useEffect(() => { @@ -257,8 +271,12 @@ export function InferenceProviderProvider({ kind: selectedProvider.kind, environment: selectedProvider.environment, baseUrl: selectedProvider.baseUrl, - envVar: selectedProvider.envVar, isEnabled: selectedProvider.isEnabled, + // If apiKey is null from the server, use an empty string to avoid React controlled/uncontrolled issues + // Don't include apiKey if in view mode to prevent showing empty field + ...(mode === "edit" ? { apiKey: selectedProvider.apiKey || "" } : {}), + defaultModel: selectedProvider.defaultModel || "", + fetchModels: selectedProvider.fetchModels, rateLimits: selectedProvider.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0, @@ -288,13 +306,37 @@ export function InferenceProviderProvider({ // Create a memoized version of onSubmit that calls both form and prop handlers const onSubmitHandler = useCallback( async (data: FormData) => { - const result = await onSubmitForm(data, isCreating, selectedProvider?.id); - if (result.success) { - onSubmitProp(data); - setMode("view"); + try { + console.log("InferenceProviderContext onSubmitHandler called with data:", data); + console.log("isSubmitting state:", isSubmitting); + + // Extract apiKey if present - we need to update it separately + const { apiKey, ...providerData } = data; + + // First update the provider without the API key + const result = await onSubmitForm( + providerData, + isCreating, + selectedProvider?.id + ); + + if (result.success) { + // If we're editing and have a new API key, update it separately + if (!isCreating && selectedProvider?.id && apiKey) { + console.log("Updating API key separately"); + await updateApiKey(selectedProvider.id, apiKey); + } + + // Notify the parent component that we've submitted successfully + onSubmitProp(data); + setMode("view"); + } + } catch (error) { + console.error("Error updating provider:", error); + throw error; // Re-throw the error so it can be caught by the form's error handler } }, - [onSubmitForm, onSubmitProp, setMode, isCreating, selectedProvider?.id], + [onSubmitForm, onSubmitProp, setMode, isCreating, selectedProvider?.id, updateApiKey, isSubmitting], ); // Create a memoized version of onCancel that resets mode @@ -372,6 +414,3 @@ export function InferenceProviderProvider({ ); } - -// For backward compatibility -export const ProviderFormProvider = InferenceProviderProvider; diff --git a/graphcap_studio/src/features/inference/providers/index.ts b/graphcap_studio/src/features/inference/providers/index.ts index 351881e1..0a6ed1da 100644 --- a/graphcap_studio/src/features/inference/providers/index.ts +++ b/graphcap_studio/src/features/inference/providers/index.ts @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -export { default as ProviderForm } from "./ProviderForm"; +export { default as ProviderForm } from "./ProviderConnection/ProviderForm"; export { ProvidersPanel } from "./ProvidersPanel"; export { default as ProvidersList } from "./ProvidersList"; -export { ModelSelectionSection } from "./ModelSelectionSection"; -export { FormFields } from "./FormFields"; -export { FormActions } from "./FormActions"; +export { ModelSelectionSection } from "./ProviderConnection/form/ModelSelectionSection"; +export { ProviderFormTabs } from "./ProviderConnection/ProviderFormTabs"; +export { ProviderFormActions } from "./ProviderConnection/ProviderFormActions"; export * from "../hooks"; export * from "./context"; diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index 6478b483..7db9cda1 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -118,14 +118,49 @@ export function useUpdateProvider() { return useMutation({ mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { + // Filter out null values and skip apiKey property completely + const updateData = Object.entries(data).reduce((acc, [key, value]) => { + // Skip apiKey completely - it has its own endpoint + if (key === 'apiKey') return acc; + + // Only include defined values + if (value !== null && value !== undefined) { + acc[key] = value; + } + return acc; + }, {} as Record); + const client = createDataServiceClient(connections); const response = await client.providers[":id"].$put({ param: { id: id.toString() }, - json: data, + json: updateData, }); if (!response.ok) { - throw new Error(`Failed to update provider: ${response.status}`); + // Try to get detailed error information + try { + const errorData = await response.json(); + console.error("Provider update error:", errorData); + + // Check if we have a structured error response + if (errorData.status === 'error' || errorData.validationErrors) { + throw errorData; + } + + // Simple error with a message + if (errorData.message) { + throw new Error(errorData.message); + } + + // Fallback error + throw new Error(`Failed to update provider: ${response.status}`); + } catch (parseError) { + // If we can't parse the error as JSON, throw a general error + if (parseError instanceof Error && parseError.message !== 'Failed to update provider') { + throw parseError; + } + throw new Error(`Failed to update provider: ${response.status}`); + } } return response.json() as Promise; @@ -184,7 +219,30 @@ export function useUpdateProviderApiKey() { }); if (!response.ok) { - throw new Error(`Failed to update API key: ${response.status}`); + // Try to get detailed error information + try { + const errorData = await response.json(); + console.error("API key update error:", errorData); + + // Check if we have a structured error response + if (errorData.status === 'error' || errorData.validationErrors) { + throw errorData; + } + + // Simple error with a message + if (errorData.message) { + throw new Error(errorData.message); + } + + // Fallback error + throw new Error(`Failed to update API key: ${response.status}`); + } catch (parseError) { + // If we can't parse the error as JSON, throw a general error + if (parseError instanceof Error && parseError.message !== 'Failed to update API key') { + throw parseError; + } + throw new Error(`Failed to update API key: ${response.status}`); + } } return response.json() as Promise; diff --git a/servers/data_service/src/features/providers/controller.ts b/servers/data_service/src/features/providers/controller.ts index c3bf06c9..72e3c0a1 100644 --- a/servers/data_service/src/features/providers/controller.ts +++ b/servers/data_service/src/features/providers/controller.ts @@ -187,11 +187,102 @@ export const updateProvider = async (c: Context) => { status: "error", statusCode: 404, message: "Provider not found", - timestamp: new Date().toISOString(), - path: c.req.path + providerId: id }, 404); } + // Enhanced validation with detailed error messages + const validationErrors: Record = {}; + + // Name validation + if (data.name !== undefined) { + if (!data.name) { + validationErrors.name = ['Provider name cannot be empty']; + } else if (data.name.trim() === '') { + validationErrors.name = ['Provider name cannot be just whitespace']; + } else if (data.name.length < 3) { + validationErrors.name = ['Provider name must be at least 3 characters long']; + } + } + + // Kind validation + if (data.kind !== undefined) { + if (!data.kind) { + validationErrors.kind = ['Provider kind cannot be empty']; + } else if (data.kind.trim() === '') { + validationErrors.kind = ['Provider kind cannot be just whitespace']; + } else if (!['openai', 'anthropic', 'google', 'custom'].includes(data.kind.toLowerCase())) { + validationErrors.kind = ['Provider kind must be one of: openai, anthropic, google, custom']; + } + } + + // Base URL validation + if (data.baseUrl !== undefined) { + if (!data.baseUrl) { + validationErrors.baseUrl = ['Base URL cannot be empty']; + } else { + try { + const url = new URL(data.baseUrl); + if (!['http:', 'https:'].includes(url.protocol)) { + validationErrors.baseUrl = ['Base URL must use HTTP or HTTPS protocol']; + } + } catch (e) { + validationErrors.baseUrl = ['Base URL must be a valid URL']; + } + } + } + + // Environment validation + if (data.environment !== undefined) { + if (!data.environment) { + validationErrors.environment = ['Environment cannot be empty']; + } else if (!['production', 'development', 'staging', 'test'].includes(data.environment.toLowerCase())) { + validationErrors.environment = ['Environment must be one of: production, development, staging, test']; + } + } + + // Models validation if provided + if (data.models !== undefined) { + const modelErrors: string[] = []; + data.models.forEach((model, index) => { + if (!model.name) { + modelErrors.push(`Model at index ${index} must have a name`); + } + if (typeof model.isEnabled !== 'boolean') { + modelErrors.push(`Model ${model.name || `at index ${index}`} must have a boolean isEnabled field`); + } + }); + if (modelErrors.length > 0) { + validationErrors.models = modelErrors; + } + } + + // Rate limits validation if provided + if (data.rateLimits !== undefined) { + const rateLimitErrors: string[] = []; + if (typeof data.rateLimits.requestsPerMinute !== 'number' || data.rateLimits.requestsPerMinute < 0) { + rateLimitErrors.push('requestsPerMinute must be a non-negative number'); + } + if (typeof data.rateLimits.tokensPerMinute !== 'number' || data.rateLimits.tokensPerMinute < 0) { + rateLimitErrors.push('tokensPerMinute must be a non-negative number'); + } + if (rateLimitErrors.length > 0) { + validationErrors.rateLimits = rateLimitErrors; + } + } + + // If there are validation errors, return them + if (Object.keys(validationErrors).length > 0) { + logger.debug({ validationErrors }, "Validation errors in provider update"); + return c.json({ + status: "error", + statusCode: 400, + message: "Validation failed", + providerId: id, + validationErrors + }, 400); + } + // Extract models and rate limits if provided const { models, rateLimits, ...providerData } = data; @@ -271,16 +362,15 @@ export const updateProvider = async (c: Context) => { error, message: error instanceof Error ? error.message : "Unknown error", stack: error instanceof Error ? error.stack : undefined, + providerId: c.req.param('id') }, "Error updating provider"); - // Return detailed error response + // Return error response return c.json({ status: "error", statusCode: 500, message: error instanceof Error ? error.message : "Failed to update provider", - timestamp: new Date().toISOString(), - path: c.req.path, - details: error instanceof Error ? { name: error.name } : undefined + errorType: error instanceof Error ? error.name : 'Unknown' }, 500); } }; @@ -329,21 +419,6 @@ export const updateProviderApiKey = async (c: Context) => { const { apiKey } = c.req.valid("json") as ProviderApiKey; logger.debug({ id }, "Updating provider API key"); - // Validate API key - if (!apiKey || apiKey.trim() === '') { - logger.debug({ id }, "Empty API key provided"); - return c.json({ - status: "error", - statusCode: 400, - message: "API key cannot be empty", - timestamp: new Date().toISOString(), - path: c.req.path, - validationErrors: { - "apiKey": ["API key cannot be empty"] - } - }, 400); - } - // Check if provider exists const existingProvider = await db.query.providers.findFirst({ where: eq(providers.id, Number.parseInt(id)), @@ -360,10 +435,30 @@ export const updateProviderApiKey = async (c: Context) => { }, 404); } - // Encrypt the API key + // Validate API key + const validationErrors: Record = {}; + + if (!apiKey || apiKey.trim() === '') { + validationErrors.apiKey = ['API key cannot be empty']; + } + + // If there are validation errors, return them + if (Object.keys(validationErrors).length > 0) { + logger.debug({ validationErrors }, "API key validation errors"); + return c.json({ + status: "error", + statusCode: 400, + message: "Validation failed", + timestamp: new Date().toISOString(), + path: c.req.path, + validationErrors + }, 400); + } + + // Encrypt API key const encryptedApiKey = await encryptApiKey(apiKey); - // Update the provider's API key + // Update API key await db .update(providers) .set({ @@ -376,7 +471,6 @@ export const updateProviderApiKey = async (c: Context) => { return c.json({ success: true, message: "API key updated successfully", - timestamp: new Date().toISOString() }); } catch (error) { const providerId = c.req.param('id'); From e5f5b561229c4f10e4f4746b279e4c41923a4b67 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 07:55:53 -0500 Subject: [PATCH 11/69] Update provider client env var to api key Signed-off-by: jphillips --- .../graphcap/providers/clients/base_client.py | 11 +---------- .../graphcap/providers/clients/gemini_client.py | 4 ++-- .../graphcap/providers/clients/ollama_client.py | 4 ++-- .../graphcap/providers/clients/openai_client.py | 4 ++-- .../graphcap/providers/clients/openrouter_client.py | 4 ++-- .../graphcap/providers/clients/vllm_client.py | 4 ++-- 6 files changed, 11 insertions(+), 20 deletions(-) diff --git a/servers/inference_server/graphcap/providers/clients/base_client.py b/servers/inference_server/graphcap/providers/clients/base_client.py index f2401492..a32ea8bb 100644 --- a/servers/inference_server/graphcap/providers/clients/base_client.py +++ b/servers/inference_server/graphcap/providers/clients/base_client.py @@ -38,15 +38,7 @@ class BaseClient(AsyncOpenAI, ABC): """Abstract base class for all provider clients""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): - # Check for required environment variable - if env_var and env_var != "NONE": - api_key = os.getenv(env_var) - if api_key is None: - raise ValueError(f"Environment variable {env_var} is not set") - else: - api_key = "stub_key" - + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): # Initialize OpenAI client super().__init__(api_key=api_key, base_url=base_url) @@ -54,7 +46,6 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur self.name = name self.kind = kind self.environment = environment - self.env_var = env_var self.base_url = base_url self.default_model = default_model diff --git a/servers/inference_server/graphcap/providers/clients/gemini_client.py b/servers/inference_server/graphcap/providers/clients/gemini_client.py index 4e232c56..670fd8f2 100644 --- a/servers/inference_server/graphcap/providers/clients/gemini_client.py +++ b/servers/inference_server/graphcap/providers/clients/gemini_client.py @@ -27,15 +27,15 @@ class GeminiClient(BaseClient): """Client for Google's Gemini API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): logger.info(f"GeminiClient initialized with base_url: {base_url}") super().__init__( name=name, kind=kind, environment=environment, - env_var=env_var, base_url=base_url.rstrip("/"), default_model=default_model, + api_key=api_key, ) def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]: diff --git a/servers/inference_server/graphcap/providers/clients/ollama_client.py b/servers/inference_server/graphcap/providers/clients/ollama_client.py index fe42c180..56167622 100644 --- a/servers/inference_server/graphcap/providers/clients/ollama_client.py +++ b/servers/inference_server/graphcap/providers/clients/ollama_client.py @@ -26,7 +26,7 @@ class OllamaClient(BaseClient): """Client for Ollama API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str = "stub_key"): logger.info("Initializing OllamaClient:") logger.info(f" - name: {name}") logger.info(f" - kind: {kind}") @@ -56,9 +56,9 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur name=name, kind=kind, environment=environment, - env_var=env_var, base_url=openai_base_url, default_model=default_model, + api_key=api_key, ) logger.debug(f"OllamaClient initialized with environment: {environment}, kind: {kind}") logger.debug(f"Using base URL {self._raw_base_url} for Ollama endpoints") diff --git a/servers/inference_server/graphcap/providers/clients/openai_client.py b/servers/inference_server/graphcap/providers/clients/openai_client.py index 46eebda0..c1da80d7 100644 --- a/servers/inference_server/graphcap/providers/clients/openai_client.py +++ b/servers/inference_server/graphcap/providers/clients/openai_client.py @@ -29,15 +29,15 @@ class OpenAIClient(BaseClient): """Client for OpenAI API""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): logger.info(f"OpenAIClient initialized with base_url: {base_url}") super().__init__( name=name, kind=kind, environment=environment, - env_var=env_var, base_url=base_url.rstrip("/"), default_model=default_model, + api_key=api_key, ) def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]: diff --git a/servers/inference_server/graphcap/providers/clients/openrouter_client.py b/servers/inference_server/graphcap/providers/clients/openrouter_client.py index 031a9e51..8681ebf5 100644 --- a/servers/inference_server/graphcap/providers/clients/openrouter_client.py +++ b/servers/inference_server/graphcap/providers/clients/openrouter_client.py @@ -27,15 +27,15 @@ class OpenRouterClient(BaseClient): """Client for OpenRouter API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str): logger.info(f"OpenRouterClient initialized with base_url: {base_url}") super().__init__( name=name, kind=kind, environment=environment, - env_var=env_var, base_url=base_url.rstrip("/"), default_model=default_model, + api_key=api_key, ) def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]: diff --git a/servers/inference_server/graphcap/providers/clients/vllm_client.py b/servers/inference_server/graphcap/providers/clients/vllm_client.py index 26de0d9f..63dcb9f8 100644 --- a/servers/inference_server/graphcap/providers/clients/vllm_client.py +++ b/servers/inference_server/graphcap/providers/clients/vllm_client.py @@ -28,7 +28,7 @@ class VLLMClient(BaseClient): """Client for VLLM API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str = "stub_key"): # If base_url doesn't include /v1, append it if not base_url.endswith("/v1"): base_url = f"{base_url}/v1" @@ -38,9 +38,9 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur name=name, kind=kind, environment=environment, - env_var=env_var, base_url=base_url.rstrip("/"), default_model=default_model, + api_key=api_key, ) def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]: From a82099b197586846c279584e164ccacb232220a3 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 08:17:24 -0500 Subject: [PATCH 12/69] Add in connections tests to provider setup Signed-off-by: jphillips --- .../components/ConnectionSteps.tsx | 70 +++++ .../ProviderConnectionSuccessDialog.tsx | 261 +++++++++++++----- .../server/features/providers/service.py | 75 +++-- 3 files changed, 312 insertions(+), 94 deletions(-) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx new file mode 100644 index 00000000..02b158f9 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Box, HStack, Icon, Text, VStack } from "@chakra-ui/react"; +import { LuCheck, LuCircleAlert, LuSkipForward } from "react-icons/lu"; + +/** + * Component that displays connection test steps and their results + */ +interface ConnectionStep { + step: string; + status: "success" | "failed" | "skipped" | "pending"; + timestamp: string; + error?: string; + message?: string; +} + +interface ConnectionStepsProps { + readonly steps: ConnectionStep[]; + readonly stepLabels?: Record; +} + +function StepIcon({ status }: { status: ConnectionStep["status"] }) { + switch (status) { + case "success": + return ; + case "skipped": + return ; + case "failed": + return ; + default: + return null; + } +} + +function ConnectionStepResult({ step, labels }: { step: ConnectionStep; labels?: Record }) { + const stepLabel = labels?.[step.step] || step.step; + + return ( + + + + + + {stepLabel} + {step.message && ( + {step.message} + )} + {step.error && ( + {step.error} + )} + + + ); +} + +export function ConnectionSteps({ steps, stepLabels = {} }: ConnectionStepsProps) { + return ( + + Connection Test Results: + {steps.map((step) => ( + + ))} + + ); +} + +export type { ConnectionStep }; \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx index 4348628a..e1297b8a 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx @@ -1,88 +1,199 @@ // SPDX-License-Identifier: Apache-2.0 -import { - Button, - Dialog, - Icon, - Portal, - Text, - VStack +import { + Button, + Dialog, + Icon, + Portal, + Separator, + Text, + VStack, } from "@chakra-ui/react"; import { useEffect, useRef } from "react"; -import { LuCheck } from "react-icons/lu"; +import { LuCheck, LuCircleAlert } from "react-icons/lu"; +import { type ConnectionStep, ConnectionSteps } from "./ConnectionSteps"; /** - * Dialog component that displays a success message when a provider - * connection test is successful + * Dialog component that displays connection test results */ interface ConnectionDetails { - models?: unknown[]; - [key: string]: unknown; + result: { + provider: string; + details: { + method?: string; + models_count?: number; + chat_completion_test?: "success"; + test_model?: string; + }; + diagnostics: { + connection_steps: ConnectionStep[]; + warnings: Array<{ + warning_type: string; + message: string; + }>; + }; + }; } type ProviderConnectionSuccessDialogProps = { - isOpen: boolean; - onClose: () => void; - providerName: string; - connectionDetails?: ConnectionDetails | null; + readonly isOpen: boolean; + readonly onClose: () => void; + readonly providerName: string; + readonly connectionDetails: ConnectionDetails; }; -export function ProviderConnectionSuccessDialog({ - isOpen, - onClose, - providerName, - connectionDetails +const STEP_LABELS: Record = { + initialize_client: "Initialize Client", + list_models: "List Available Models", + test_chat_completion: "Test Chat Completion", +}; + +export function ProviderConnectionSuccessDialog({ + isOpen, + onClose, + providerName, + connectionDetails, }: ProviderConnectionSuccessDialogProps) { - // Create a reference to the dialog content - const dialogContentRef = useRef(null); - - // Prevent clicks inside the dialog from triggering outside click handlers - useEffect(() => { - function handleDialogClick(e: MouseEvent) { - // Stop event propagation for all clicks inside the dialog - e.stopPropagation(); - } - - const dialogElement = dialogContentRef.current; - if (dialogElement) { - dialogElement.addEventListener("click", handleDialogClick); - - return () => { - dialogElement.removeEventListener("click", handleDialogClick); - }; - } - }, []); // No dependencies needed as we're just setting up the event listener - - return ( - !e.open && onClose()}> - - - - - - Connection Successful - - - - - - - - ); -} \ No newline at end of file + // Create a reference to the dialog content + const dialogContentRef = useRef(null); + + // Prevent clicks inside the dialog from triggering outside click handlers + useEffect(() => { + function handleDialogClick(e: MouseEvent) { + // Stop event propagation for all clicks inside the dialog + e.stopPropagation(); + } + + const dialogElement = dialogContentRef.current; + if (dialogElement) { + dialogElement.addEventListener("click", handleDialogClick); + + return () => { + dialogElement.removeEventListener("click", handleDialogClick); + }; + } + }, []); // No dependencies needed as we're just setting up the event listener + + const { result } = connectionDetails; + const steps = result.diagnostics.connection_steps; + const warnings = result.diagnostics.warnings; + const details = result.details; + + // Check if any required steps were skipped or failed + const hasSkippedSteps = steps.some((step) => step.status === "skipped"); + const hasFailedSteps = steps.some((step) => step.status === "failed"); + const allStepsSuccessful = steps.every((step) => step.status === "success"); + + // Determine the overall status + const getStatusInfo = () => { + if (hasFailedSteps) { + return { + title: "Connection Failed", + icon: LuCircleAlert, + color: "red.500", + message: "Connection test failed. Please check the details below.", + }; + } + if (hasSkippedSteps) { + return { + title: "Connection Partial", + icon: LuCircleAlert, + color: "yellow.500", + message: + "Connected with limited functionality. Some tests were skipped.", + }; + } + return { + title: "Connection Successful", + icon: LuCheck, + color: "green.500", + message: `Successfully connected to ${providerName}!`, + }; + }; + + const status = getStatusInfo(); + + return ( + !e.open && onClose()}> + + + + + + {status.title} + + + + + + + + ); +} diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_server/server/server/features/providers/service.py index 21778a76..ef48fceb 100644 --- a/servers/inference_server/server/server/features/providers/service.py +++ b/servers/inference_server/server/server/features/providers/service.py @@ -200,16 +200,15 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - # First check if we can get models (most providers support this) if isinstance(client, ModelProvider): try: - # Add diagnostic step + # Add diagnostic step for model list result["diagnostics"]["connection_steps"].append({ - "step": "verify_connection", + "step": "list_models", "status": "pending", "timestamp": str(datetime.datetime.now()) }) if hasattr(client, "get_available_models"): provider_models = await client.get_available_models() - result["connection_verified"] = True result["details"]["method"] = "get_available_models" # Add model information if available @@ -223,7 +222,6 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - elif hasattr(client, "get_models"): provider_models = await client.get_models() - result["connection_verified"] = True result["details"]["method"] = "get_models" # Add model information if available @@ -234,23 +232,58 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - models_data.append({"id": model_id}) result["details"]["available_models"] = models_data result["details"]["models_count"] = len(models_data) - - else: - # If we can't check models, just having created the client is enough - result["connection_verified"] = True - result["details"]["method"] = "client_init_only" # Update diagnostic step result["diagnostics"]["connection_steps"][-1]["status"] = "success" except Exception as e: - logger.error(f"Error testing connection to {provider_name}: {str(e)}") + logger.warning(f"Could not list models for {provider_name}: {str(e)}") + result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" + result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed" + + # Try a simple chat completion as a more thorough test + try: + # Add diagnostic step for chat completion + result["diagnostics"]["connection_steps"].append({ + "step": "test_chat_completion", + "status": "pending", + "timestamp": str(datetime.datetime.now()) + }) + + # Use the first available model or default model + test_model = None + if result.get("details", {}).get("available_models"): + test_model = result["details"]["available_models"][0]["id"] + elif config.default_model: + test_model = config.default_model + + if test_model: + # Simple test message + test_messages = [{"role": "user", "content": "Hello, this is a test message. Please respond with 'OK' if you can process this request."}] - # Update diagnostic step - result["diagnostics"]["connection_steps"][-1]["status"] = "failed" - result["diagnostics"]["connection_steps"][-1]["error"] = str(e) - result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ + completion = await client.chat.completions.create( + model=test_model, + messages=test_messages, + max_tokens=10, # Keep it minimal + temperature=0, # Deterministic + ) + + result["connection_verified"] = True + result["details"]["chat_completion_test"] = "success" + result["details"]["test_model"] = test_model + result["diagnostics"]["connection_steps"][-1]["status"] = "success" + else: + result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" + result["diagnostics"]["connection_steps"][-1]["message"] = "No suitable model found for testing" + except Exception as e: + logger.error(f"Chat completion test failed for {provider_name}: {str(e)}") + result["diagnostics"]["connection_steps"][-1]["status"] = "failed" + result["diagnostics"]["connection_steps"][-1]["error"] = str(e) + result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ + + # Only mark connection as failed if we couldn't list models either + if not result.get("details", {}).get("available_models"): result["connection_verified"] = False result["details"]["error"] = str(e) result["details"]["error_type"] = type(e).__name__ @@ -264,11 +297,15 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - result["details"]["suggestion"] = "Check if the base URL is correct for this provider" raise Exception(f"Error verifying connection: {str(e)}") - else: - # For providers that don't support models API - result["connection_verified"] = True - result["details"]["method"] = "client_init_only" - + else: + # If we could list models but chat completion failed, just warn + result["diagnostics"]["warnings"].append({ + "warning_type": "chat_completion_failed", + "message": f"Chat completion test failed but model listing succeeded. Provider may have limited functionality. Error: {str(e)}" + }) + result["connection_verified"] = True + result["details"]["method"] = "list_models_only" + return result except Exception as e: From aefbcea391c12e7d3876566c79e755369aa88ec7 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 08:32:09 -0500 Subject: [PATCH 13/69] Update graphcap_server to inference_brdige Signed-off-by: jphillips --- .cursor/rules/task.mdc | 4 +- Taskfile.yml | 2 +- docker-compose.override.example.yml | 8 +- docker-compose.yml | 18 +- .../ProviderConnection/ProviderFormTabs.tsx | 4 +- .../ProviderConnectionSuccessDialog.tsx | 10 +- .../components/{form => }/ProviderSelect.tsx | 0 .../components/form/ConnectionSection.tsx | 234 +++++++++--------- .../components/form/EnvironmentSelect.tsx | 4 +- .../components/form/RateLimitsSection.tsx | 2 +- .../components/form/index.ts | 3 - .../inference/providers/ProvidersList.tsx | 2 +- .../inference/providers/ProvidersPanel.tsx | 2 +- .../src/features/inference/providers/index.ts | 1 - pyproject.toml | 2 +- .../README.md | 0 .../__init__.py | 0 .../graphcap/__init__.py | 0 .../graphcap/perspectives/__init__.py | 0 .../graphcap/perspectives/base.py | 0 .../graphcap/perspectives/base_caption.py | 0 .../graphcap/perspectives/constants.py | 0 .../graphcap/perspectives/loaders/__init__.py | 0 .../perspectives/loaders/directory.py | 0 .../perspectives/loaders/json_file.py | 0 .../graphcap/perspectives/loaders/modules.py | 0 .../graphcap/perspectives/loaders/settings.py | 0 .../graphcap/perspectives/models.py | 0 .../graphcap/perspectives/module.py | 0 .../perspectives/perspective_loader.py | 0 .../graphcap/perspectives/processor.py | 0 .../graphcap/perspectives/types.py | 0 .../graphcap/providers/README.md | 0 .../graphcap/providers/__init__.py | 0 .../graphcap/providers/clients/__init__.py | 0 .../graphcap/providers/clients/base_client.py | 0 .../providers/clients/gemini_client.py | 0 .../providers/clients/ollama_client.py | 0 .../providers/clients/openai_client.py | 0 .../providers/clients/openrouter_client.py | 0 .../graphcap/providers/clients/vllm_client.py | 0 .../graphcap/providers/factory.py | 0 .../graphcap/providers/types.py | 0 .../pipelines/.dep_hash | 0 .../pipelines/.dockerignore | 0 .../pipelines/Dockerfile.pipelines.dev | 0 .../pipelines/README.md | 0 .../pipelines/Taskfile.pipelines.yml | 0 .../pipelines/_scripts/pipeline_entrypoint.sh | 0 .../pipelines/dagster.example.yml | 0 .../pipelines/pipelines/__init__.py | 0 .../pipelines/pipelines/assets.py | 0 .../pipelines/pipelines/common/__init__.py | 0 .../pipelines/pipelines/common/constants.py | 0 .../pipelines/pipelines/common/io.py | 0 .../pipelines/pipelines/common/logging.py | 0 .../pipelines/pipelines/common/resources.py | 0 .../pipelines/pipelines/common/utils.py | 0 .../pipelines/pipelines/common/workspace.py | 0 .../pipelines/pipelines/definitions.py | 0 .../pipelines/huggingface/__init__.py | 0 .../pipelines/pipelines/huggingface/client.py | 0 .../pipelines/huggingface/dataset_export.py | 0 .../pipelines/huggingface/dataset_import.py | 0 .../pipelines/huggingface/dataset_manifest.py | 0 .../pipelines/huggingface/dataset_prep.py | 0 .../pipelines/huggingface/dataset_readme.py | 0 .../huggingface/perspective_export.py | 0 .../pipelines/pipelines/huggingface/types.py | 0 .../pipelines/pipelines/io/__init__.py | 0 .../pipelines/pipelines/io/image/__init__.py | 0 .../io/image/image_metadata/__init__.py | 0 .../image_metadata/common_formats/__init__.py | 0 .../common_formats/iptc_metadata.py | 0 .../common_formats/xmp_metadata.py | 0 .../io/image/image_metadata/extract_exif.py | 0 .../pipelines/io/image/load_images.py | 0 .../pipelines/pipelines/io/image/types.py | 0 .../pipelines/pipelines/jobs/__init__.py | 0 .../pipelines/jobs/dataset_import_job.py | 0 .../pipelines/jobs/image_metadata.py | 0 .../pipelines/pipelines/jobs/omi.py | 0 .../pipelines/perspectives/__init__.py | 0 .../pipelines/perspectives/assets.py | 0 .../pipelines/perspectives/jobs/__init__.py | 0 .../jobs/basic_perspective_pipeline.py | 0 .../pipelines/perspectives/jobs/config.py | 0 .../pipelines/pipelines/perspectives/types.py | 0 .../pipelines/pipelines/providers/__init__.py | 0 .../pipelines/pipelines/providers/assets.py | 0 .../pipelines/pipelines/providers/util.py | 0 .../pipelines/pipelines/start.py | 0 .../pipelines/pipelines_tests/__init__.py | 0 .../pipelines/pipelines_tests/test_assets.py | 0 .../pipelines/pyproject.toml | 0 .../pipelines/setup.cfg | 0 .../pipelines/setup.py | 0 .../pipelines/uv.lock | 0 .../pyproject.toml | 0 .../pytest.ini | 0 .../scripts/__init__.py | 0 .../scripts/__main__.py | 0 .../scripts/config_writer.py | 0 .../scripts/setup.py | 0 .../server/.dep_hash | 0 .../server/.dockerignore | 0 .../server/.env.local.template | 0 .../server/Dockerfile.server.dev | 0 .../server/README.md | 0 .../server/Taskfile.inference.yml | 12 +- .../server/__init__.py | 0 .../server/_scripts/endpoints-entrypoint.sh | 0 .../server/_scripts/gunicorn.conf.py | 0 .../server/pyproject.toml | 0 .../server/server/__init__.py | 0 .../server/server/config.py | 0 .../server/server/config/router.py | 0 .../server/server/db.py | 0 .../server/server/dependencies.py | 0 .../server/features/perspectives/__init__.py | 0 .../server/features/perspectives/models.py | 0 .../server/features/perspectives/router.py | 0 .../server/features/perspectives/service.py | 0 .../server/features/providers/__init__.py | 0 .../features/providers/error_handler.py | 0 .../server/features/providers/models.py | 0 .../server/features/providers/router.py | 0 .../server/features/providers/service.py | 0 .../server/features/repositories/types.py | 0 .../server/server/main.py | 0 .../server/server/models.py | 0 .../server/server/pipelines/__init__py | 0 .../server/server/pipelines/dagster_client.py | 0 .../server/server/routers.py | 0 .../server/server/utils/__init__.py | 0 .../server/server/utils/logger.py | 0 .../server/server/utils/middleware.py | 0 .../server/server/utils/resizing.py | 0 .../server/uv.lock | 0 .../tests/test_perspective_modules.py | 0 .../uv.lock | 0 141 files changed, 161 insertions(+), 147 deletions(-) rename graphcap_studio/src/features/inference/providers/ProviderConnection/components/{form => }/ProviderSelect.tsx (100%) rename servers/{inference_server => inference_bridge}/README.md (100%) rename servers/{inference_server => inference_bridge}/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/base.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/base_caption.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/constants.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/loaders/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/loaders/directory.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/loaders/json_file.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/loaders/modules.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/loaders/settings.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/models.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/module.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/perspective_loader.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/processor.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/perspectives/types.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/README.md (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/__init__.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/base_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/gemini_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/ollama_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/openai_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/openrouter_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/clients/vllm_client.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/factory.py (100%) rename servers/{inference_server => inference_bridge}/graphcap/providers/types.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/.dep_hash (100%) rename servers/{inference_server => inference_bridge}/pipelines/.dockerignore (100%) rename servers/{inference_server => inference_bridge}/pipelines/Dockerfile.pipelines.dev (100%) rename servers/{inference_server => inference_bridge}/pipelines/README.md (100%) rename servers/{inference_server => inference_bridge}/pipelines/Taskfile.pipelines.yml (100%) rename servers/{inference_server => inference_bridge}/pipelines/_scripts/pipeline_entrypoint.sh (100%) rename servers/{inference_server => inference_bridge}/pipelines/dagster.example.yml (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/assets.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/constants.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/io.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/logging.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/resources.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/utils.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/common/workspace.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/definitions.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/client.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/dataset_export.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/dataset_import.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/dataset_manifest.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/dataset_prep.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/dataset_readme.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/perspective_export.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/huggingface/types.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/image_metadata/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/image_metadata/extract_exif.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/load_images.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/io/image/types.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/jobs/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/jobs/dataset_import_job.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/jobs/image_metadata.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/jobs/omi.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/assets.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/jobs/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/jobs/config.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/perspectives/types.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/providers/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/providers/assets.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/providers/util.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines/start.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines_tests/__init__.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pipelines_tests/test_assets.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/pyproject.toml (100%) rename servers/{inference_server => inference_bridge}/pipelines/setup.cfg (100%) rename servers/{inference_server => inference_bridge}/pipelines/setup.py (100%) rename servers/{inference_server => inference_bridge}/pipelines/uv.lock (100%) rename servers/{inference_server => inference_bridge}/pyproject.toml (100%) rename servers/{inference_server => inference_bridge}/pytest.ini (100%) rename servers/{inference_server => inference_bridge}/scripts/__init__.py (100%) rename servers/{inference_server => inference_bridge}/scripts/__main__.py (100%) rename servers/{inference_server => inference_bridge}/scripts/config_writer.py (100%) rename servers/{inference_server => inference_bridge}/scripts/setup.py (100%) rename servers/{inference_server => inference_bridge}/server/.dep_hash (100%) rename servers/{inference_server => inference_bridge}/server/.dockerignore (100%) rename servers/{inference_server => inference_bridge}/server/.env.local.template (100%) rename servers/{inference_server => inference_bridge}/server/Dockerfile.server.dev (100%) rename servers/{inference_server => inference_bridge}/server/README.md (100%) rename servers/{inference_server => inference_bridge}/server/Taskfile.inference.yml (69%) rename servers/{inference_server => inference_bridge}/server/__init__.py (100%) rename servers/{inference_server => inference_bridge}/server/_scripts/endpoints-entrypoint.sh (100%) rename servers/{inference_server => inference_bridge}/server/_scripts/gunicorn.conf.py (100%) rename servers/{inference_server => inference_bridge}/server/pyproject.toml (100%) rename servers/{inference_server => inference_bridge}/server/server/__init__.py (100%) rename servers/{inference_server => inference_bridge}/server/server/config.py (100%) rename servers/{inference_server => inference_bridge}/server/server/config/router.py (100%) rename servers/{inference_server => inference_bridge}/server/server/db.py (100%) rename servers/{inference_server => inference_bridge}/server/server/dependencies.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/perspectives/__init__.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/perspectives/models.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/perspectives/router.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/perspectives/service.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/providers/__init__.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/providers/error_handler.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/providers/models.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/providers/router.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/providers/service.py (100%) rename servers/{inference_server => inference_bridge}/server/server/features/repositories/types.py (100%) rename servers/{inference_server => inference_bridge}/server/server/main.py (100%) rename servers/{inference_server => inference_bridge}/server/server/models.py (100%) rename servers/{inference_server => inference_bridge}/server/server/pipelines/__init__py (100%) rename servers/{inference_server => inference_bridge}/server/server/pipelines/dagster_client.py (100%) rename servers/{inference_server => inference_bridge}/server/server/routers.py (100%) rename servers/{inference_server => inference_bridge}/server/server/utils/__init__.py (100%) rename servers/{inference_server => inference_bridge}/server/server/utils/logger.py (100%) rename servers/{inference_server => inference_bridge}/server/server/utils/middleware.py (100%) rename servers/{inference_server => inference_bridge}/server/server/utils/resizing.py (100%) rename servers/{inference_server => inference_bridge}/server/uv.lock (100%) rename servers/{inference_server => inference_bridge}/tests/test_perspective_modules.py (100%) rename servers/{inference_server => inference_bridge}/uv.lock (100%) diff --git a/.cursor/rules/task.mdc b/.cursor/rules/task.mdc index 7f1b94d7..6b3bb799 100644 --- a/.cursor/rules/task.mdc +++ b/.cursor/rules/task.mdc @@ -12,5 +12,5 @@ Provider configuration still relies on some file based configuration. - Resolve issues cited in discord related to provider ux -[provider_manager.py](mdc:servers/inference_server/graphcap/providers/provider_manager.py) -[provider_config.py](mdc:servers/inference_server/graphcap/providers/provider_config.py) +[provider_manager.py](mdc:servers/inference_bridge/graphcap/providers/provider_manager.py) +[provider_config.py](mdc:servers/inference_bridge/graphcap/providers/provider_config.py) diff --git a/Taskfile.yml b/Taskfile.yml index ef95fa0b..cbd8b251 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -3,7 +3,7 @@ dotenv: ['.env', '{{.ENV}}/.env', '{{.HOME}}/.env'] includes: data: ./servers/data_service/Taskfile.data.yml - inference: ./servers/inference_server/server/Taskfile.inference.yml + inference: ./servers/inference_bridge/server/Taskfile.inference.yml studio: ./graphcap_studio/Taskfile.studio.yml tasks: diff --git a/docker-compose.override.example.yml b/docker-compose.override.example.yml index 2d1cec70..91a7d635 100644 --- a/docker-compose.override.example.yml +++ b/docker-compose.override.example.yml @@ -14,8 +14,8 @@ # SPDX-License-Identifier: Apache-2.0 name: graphcap # services: -# graphcap_server: -# container_name: graphcap_server +# inference_bridge: +# container_name: inference_bridge # build: # context: ./src # dockerfile: ./server/Dockerfile.server @@ -45,7 +45,7 @@ name: graphcap # ports: # - "35433:5432" # volumes: -# - graphcap_server_db:/var/lib/postgresql/data +# - inference_bridge_db:/var/lib/postgresql/data # healthcheck: # test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"] # interval: 5s @@ -181,7 +181,7 @@ name: graphcap # volumes: -# graphcap_server_db: +# inference_bridge_db: # networks: # graphcap: diff --git a/docker-compose.yml b/docker-compose.yml index d029f1a5..ddf589c1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 name: graphcap services: - graphcap_server: - container_name: graphcap_server + inference_bridge: + container_name: inference_bridge build: - context: ./servers/inference_server + context: ./servers/inference_bridge dockerfile: ./server/Dockerfile.server.dev ports: - 32100:32100 volumes: - - ./servers/inference_server/graphcap:/app/graphcap - - ./servers/inference_server/server/server:/app/server/server + - ./servers/inference_bridge/graphcap:/app/graphcap + - ./servers/inference_bridge/server/server:/app/server/server - ./workspace:/workspace environment: - HOST_PLATFORM=${HOST_PLATFORM:-linux} @@ -33,7 +33,7 @@ services: ports: - "35433:5432" volumes: - - graphcap_server_db:/var/lib/postgresql/data + - inference_bridge_db:/var/lib/postgresql/data healthcheck: test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER}"] interval: 5s @@ -128,12 +128,12 @@ services: graphcap_pipelines: build: - context: ./servers/inference_server + context: ./servers/inference_bridge dockerfile: pipelines/Dockerfile.pipelines.dev container_name: graphcap_pipelines volumes: - ./workspace:/workspace - - ./servers/inference_server:/app + - ./servers/inference_bridge:/app environment: - DAGSTER_HOME=/workspace/.local/.dagster - DAGSTER_PORT=32300 @@ -182,7 +182,7 @@ services: - ./workspace/config/.env volumes: - graphcap_server_db: + inference_bridge_db: networks: graphcap: diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx index d2f85146..131e4261 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 import { Tabs } from "@chakra-ui/react"; import styles from "./FormFields.module.css"; -import { BasicInfoSection, ConnectionSection, RateLimitsSection } from "./form"; -import { ModelSelectionSection } from "./form/ModelSelectionSection"; +import { BasicInfoSection, ConnectionSection, RateLimitsSection } from "./components/form"; +import { ModelSelectionSection } from "./components/form/ModelSelectionSection"; /** * Component for rendering provider form fields in either view or edit mode diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx index e1297b8a..d3596619 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx @@ -38,7 +38,7 @@ type ProviderConnectionSuccessDialogProps = { readonly isOpen: boolean; readonly onClose: () => void; readonly providerName: string; - readonly connectionDetails: ConnectionDetails; + readonly connectionDetails: ConnectionDetails | null; }; const STEP_LABELS: Record = { @@ -53,6 +53,7 @@ export function ProviderConnectionSuccessDialog({ providerName, connectionDetails, }: ProviderConnectionSuccessDialogProps) { + // Create a reference to the dialog content const dialogContentRef = useRef(null); @@ -71,7 +72,12 @@ export function ProviderConnectionSuccessDialog({ dialogElement.removeEventListener("click", handleDialogClick); }; } - }, []); // No dependencies needed as we're just setting up the event listener + }, []); + + // Return early if connectionDetails is null + if (!connectionDetails) { + return null; + } const { result } = connectionDetails; const steps = result.diagnostics.connection_steps; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderSelect.tsx similarity index 100% rename from graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderSelect.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderSelect.tsx diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx index 1c35e96a..fc0a7e7d 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx @@ -1,7 +1,15 @@ import { Switch } from "@/components/ui/buttons/Switch"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { Box, Button, Field, Flex, Input, Text, VStack } from "@chakra-ui/react"; -import { Group, InputElement } from "@chakra-ui/react"; +import { + Box, + Button, + Field, + Group, + Input, + InputElement, + Text, + VStack, +} from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; import { Controller } from "react-hook-form"; @@ -11,121 +19,125 @@ import { useInferenceProviderContext } from "../../../context"; * Component for displaying and editing provider connection settings */ export function ConnectionSection() { - const { control, errors, watch, isEditing, selectedProvider } = useInferenceProviderContext(); - const [showApiKey, setShowApiKey] = useState(false); - const labelColor = useColorModeValue("gray.600", "gray.300"); - const textColor = useColorModeValue("gray.700", "gray.200"); + const { control, errors, watch, isEditing, selectedProvider } = + useInferenceProviderContext(); + const [showApiKey, setShowApiKey] = useState(false); + const labelColor = useColorModeValue("gray.600", "gray.300"); + const textColor = useColorModeValue("gray.700", "gray.200"); - // Watch form values for read-only display - const baseUrl = watch("baseUrl"); - const isEnabled = watch("isEnabled"); + // Watch form values for read-only display + const baseUrl = watch("baseUrl"); + const isEnabled = watch("isEnabled"); - // Toggle API key visibility - const toggleShowApiKey = () => setShowApiKey(!showApiKey); + // Toggle API key visibility + const toggleShowApiKey = () => setShowApiKey(!showApiKey); - if (!isEditing) { - return ( - - - - Base URL - - {baseUrl} - + if (!isEditing) { + return ( + + + + Base URL + + {baseUrl} + - - - API Key - - - - - - - - + + + API Key + + + + + + + + - - - Status - - {isEnabled ? "Enabled" : "Disabled"} - - - ); - } + + + Status + + {isEnabled ? "Enabled" : "Disabled"} + + + ); + } - return ( - - ( - - Base URL - - {errors.baseUrl?.message} - - )} - /> + return ( + + ( + + Base URL + + {errors.baseUrl?.message} + + )} + /> - { - // Ensure we always have a defined string value - const value = field.value || ""; - return ( - - API Key - - field.onChange(e.target.value)} - /> - - - - - {errors.apiKey?.message || (value === "" && "API key is required")} - - ); - }} - /> + { + // Ensure we always have a defined string value + const value = field.value || ""; + return ( + + API Key + + field.onChange(e.target.value)} + /> + + + + + + {errors.apiKey?.message || + (value === "" && "API key is required")} + + + ); + }} + /> - ( - - - - Enabled - - - - )} - /> - - ); + ( + + + + Enabled + + + + )} + /> + + ); } diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx index bae1a4ab..fec1c6b9 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx @@ -8,8 +8,8 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { Field, createListCollection } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { Controller } from "react-hook-form"; -import { PROVIDER_ENVIRONMENTS } from "../../../constants"; -import { useInferenceProviderContext } from "../../context"; +import { PROVIDER_ENVIRONMENTS } from "../../../../constants"; +import { useInferenceProviderContext } from "../../../context"; export function EnvironmentSelect() { const { control, errors } = useInferenceProviderContext(); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx index 3dde8c3d..981ce855 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx @@ -11,7 +11,7 @@ import { // SPDX-License-Identifier: Apache-2.0 import type { ChangeEvent } from "react"; import { Controller } from "react-hook-form"; -import { useInferenceProviderContext } from "../../context"; +import { useInferenceProviderContext } from "../../../context"; /** * Component for displaying and editing provider rate limits diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts index fc3f7fee..0bbcc269 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts @@ -3,7 +3,4 @@ export * from "./BasicInfoSection"; export * from "./ConnectionSection"; export * from "./RateLimitsSection"; export * from "./EnvironmentSelect"; -export * from "./ProviderSelect"; -export * from "../../../../../components/ui/status/StatusMessage"; export * from "./ModelSelector"; -export * from "../../../../../components/ui/buttons/ActionButton"; diff --git a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx index a4efaa96..3aa08713 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -import { ProviderSelect } from "./ProviderConnection/form"; +import { ProviderSelect } from "./ProviderConnection/components/ProviderSelect"; import { useProviderFormContext } from "./context"; diff --git a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx index 2a8a2b18..b6678318 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx @@ -5,7 +5,7 @@ import { useCallback, useMemo, useState } from "react"; import type { ProviderCreate, ProviderUpdate } from "../providers/types"; import { useProviders } from "../services/providers"; import ProviderForm from "./ProviderConnection/ProviderForm"; -import { ProviderSelect } from "./ProviderConnection/form"; +import { ProviderSelect } from "./ProviderConnection/components/ProviderSelect"; import { InferenceProviderProvider, useInferenceProviderContext, diff --git a/graphcap_studio/src/features/inference/providers/index.ts b/graphcap_studio/src/features/inference/providers/index.ts index 0a6ed1da..398098f7 100644 --- a/graphcap_studio/src/features/inference/providers/index.ts +++ b/graphcap_studio/src/features/inference/providers/index.ts @@ -3,7 +3,6 @@ export { default as ProviderForm } from "./ProviderConnection/ProviderForm"; export { ProvidersPanel } from "./ProvidersPanel"; export { default as ProvidersList } from "./ProvidersList"; -export { ModelSelectionSection } from "./ProviderConnection/form/ModelSelectionSection"; export { ProviderFormTabs } from "./ProviderConnection/ProviderFormTabs"; export { ProviderFormActions } from "./ProviderConnection/ProviderFormActions"; export * from "../hooks"; diff --git a/pyproject.toml b/pyproject.toml index 99491de5..2057d0dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ exclude = [ ] # Source directories -src = ["./servers/inference_server"] +src = ["./servers/inference_bridge"] # Same as Black line-length = 120 diff --git a/servers/inference_server/README.md b/servers/inference_bridge/README.md similarity index 100% rename from servers/inference_server/README.md rename to servers/inference_bridge/README.md diff --git a/servers/inference_server/__init__.py b/servers/inference_bridge/__init__.py similarity index 100% rename from servers/inference_server/__init__.py rename to servers/inference_bridge/__init__.py diff --git a/servers/inference_server/graphcap/__init__.py b/servers/inference_bridge/graphcap/__init__.py similarity index 100% rename from servers/inference_server/graphcap/__init__.py rename to servers/inference_bridge/graphcap/__init__.py diff --git a/servers/inference_server/graphcap/perspectives/__init__.py b/servers/inference_bridge/graphcap/perspectives/__init__.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/__init__.py rename to servers/inference_bridge/graphcap/perspectives/__init__.py diff --git a/servers/inference_server/graphcap/perspectives/base.py b/servers/inference_bridge/graphcap/perspectives/base.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/base.py rename to servers/inference_bridge/graphcap/perspectives/base.py diff --git a/servers/inference_server/graphcap/perspectives/base_caption.py b/servers/inference_bridge/graphcap/perspectives/base_caption.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/base_caption.py rename to servers/inference_bridge/graphcap/perspectives/base_caption.py diff --git a/servers/inference_server/graphcap/perspectives/constants.py b/servers/inference_bridge/graphcap/perspectives/constants.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/constants.py rename to servers/inference_bridge/graphcap/perspectives/constants.py diff --git a/servers/inference_server/graphcap/perspectives/loaders/__init__.py b/servers/inference_bridge/graphcap/perspectives/loaders/__init__.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/loaders/__init__.py rename to servers/inference_bridge/graphcap/perspectives/loaders/__init__.py diff --git a/servers/inference_server/graphcap/perspectives/loaders/directory.py b/servers/inference_bridge/graphcap/perspectives/loaders/directory.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/loaders/directory.py rename to servers/inference_bridge/graphcap/perspectives/loaders/directory.py diff --git a/servers/inference_server/graphcap/perspectives/loaders/json_file.py b/servers/inference_bridge/graphcap/perspectives/loaders/json_file.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/loaders/json_file.py rename to servers/inference_bridge/graphcap/perspectives/loaders/json_file.py diff --git a/servers/inference_server/graphcap/perspectives/loaders/modules.py b/servers/inference_bridge/graphcap/perspectives/loaders/modules.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/loaders/modules.py rename to servers/inference_bridge/graphcap/perspectives/loaders/modules.py diff --git a/servers/inference_server/graphcap/perspectives/loaders/settings.py b/servers/inference_bridge/graphcap/perspectives/loaders/settings.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/loaders/settings.py rename to servers/inference_bridge/graphcap/perspectives/loaders/settings.py diff --git a/servers/inference_server/graphcap/perspectives/models.py b/servers/inference_bridge/graphcap/perspectives/models.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/models.py rename to servers/inference_bridge/graphcap/perspectives/models.py diff --git a/servers/inference_server/graphcap/perspectives/module.py b/servers/inference_bridge/graphcap/perspectives/module.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/module.py rename to servers/inference_bridge/graphcap/perspectives/module.py diff --git a/servers/inference_server/graphcap/perspectives/perspective_loader.py b/servers/inference_bridge/graphcap/perspectives/perspective_loader.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/perspective_loader.py rename to servers/inference_bridge/graphcap/perspectives/perspective_loader.py diff --git a/servers/inference_server/graphcap/perspectives/processor.py b/servers/inference_bridge/graphcap/perspectives/processor.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/processor.py rename to servers/inference_bridge/graphcap/perspectives/processor.py diff --git a/servers/inference_server/graphcap/perspectives/types.py b/servers/inference_bridge/graphcap/perspectives/types.py similarity index 100% rename from servers/inference_server/graphcap/perspectives/types.py rename to servers/inference_bridge/graphcap/perspectives/types.py diff --git a/servers/inference_server/graphcap/providers/README.md b/servers/inference_bridge/graphcap/providers/README.md similarity index 100% rename from servers/inference_server/graphcap/providers/README.md rename to servers/inference_bridge/graphcap/providers/README.md diff --git a/servers/inference_server/graphcap/providers/__init__.py b/servers/inference_bridge/graphcap/providers/__init__.py similarity index 100% rename from servers/inference_server/graphcap/providers/__init__.py rename to servers/inference_bridge/graphcap/providers/__init__.py diff --git a/servers/inference_server/graphcap/providers/clients/__init__.py b/servers/inference_bridge/graphcap/providers/clients/__init__.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/__init__.py rename to servers/inference_bridge/graphcap/providers/clients/__init__.py diff --git a/servers/inference_server/graphcap/providers/clients/base_client.py b/servers/inference_bridge/graphcap/providers/clients/base_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/base_client.py rename to servers/inference_bridge/graphcap/providers/clients/base_client.py diff --git a/servers/inference_server/graphcap/providers/clients/gemini_client.py b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/gemini_client.py rename to servers/inference_bridge/graphcap/providers/clients/gemini_client.py diff --git a/servers/inference_server/graphcap/providers/clients/ollama_client.py b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/ollama_client.py rename to servers/inference_bridge/graphcap/providers/clients/ollama_client.py diff --git a/servers/inference_server/graphcap/providers/clients/openai_client.py b/servers/inference_bridge/graphcap/providers/clients/openai_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/openai_client.py rename to servers/inference_bridge/graphcap/providers/clients/openai_client.py diff --git a/servers/inference_server/graphcap/providers/clients/openrouter_client.py b/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/openrouter_client.py rename to servers/inference_bridge/graphcap/providers/clients/openrouter_client.py diff --git a/servers/inference_server/graphcap/providers/clients/vllm_client.py b/servers/inference_bridge/graphcap/providers/clients/vllm_client.py similarity index 100% rename from servers/inference_server/graphcap/providers/clients/vllm_client.py rename to servers/inference_bridge/graphcap/providers/clients/vllm_client.py diff --git a/servers/inference_server/graphcap/providers/factory.py b/servers/inference_bridge/graphcap/providers/factory.py similarity index 100% rename from servers/inference_server/graphcap/providers/factory.py rename to servers/inference_bridge/graphcap/providers/factory.py diff --git a/servers/inference_server/graphcap/providers/types.py b/servers/inference_bridge/graphcap/providers/types.py similarity index 100% rename from servers/inference_server/graphcap/providers/types.py rename to servers/inference_bridge/graphcap/providers/types.py diff --git a/servers/inference_server/pipelines/.dep_hash b/servers/inference_bridge/pipelines/.dep_hash similarity index 100% rename from servers/inference_server/pipelines/.dep_hash rename to servers/inference_bridge/pipelines/.dep_hash diff --git a/servers/inference_server/pipelines/.dockerignore b/servers/inference_bridge/pipelines/.dockerignore similarity index 100% rename from servers/inference_server/pipelines/.dockerignore rename to servers/inference_bridge/pipelines/.dockerignore diff --git a/servers/inference_server/pipelines/Dockerfile.pipelines.dev b/servers/inference_bridge/pipelines/Dockerfile.pipelines.dev similarity index 100% rename from servers/inference_server/pipelines/Dockerfile.pipelines.dev rename to servers/inference_bridge/pipelines/Dockerfile.pipelines.dev diff --git a/servers/inference_server/pipelines/README.md b/servers/inference_bridge/pipelines/README.md similarity index 100% rename from servers/inference_server/pipelines/README.md rename to servers/inference_bridge/pipelines/README.md diff --git a/servers/inference_server/pipelines/Taskfile.pipelines.yml b/servers/inference_bridge/pipelines/Taskfile.pipelines.yml similarity index 100% rename from servers/inference_server/pipelines/Taskfile.pipelines.yml rename to servers/inference_bridge/pipelines/Taskfile.pipelines.yml diff --git a/servers/inference_server/pipelines/_scripts/pipeline_entrypoint.sh b/servers/inference_bridge/pipelines/_scripts/pipeline_entrypoint.sh similarity index 100% rename from servers/inference_server/pipelines/_scripts/pipeline_entrypoint.sh rename to servers/inference_bridge/pipelines/_scripts/pipeline_entrypoint.sh diff --git a/servers/inference_server/pipelines/dagster.example.yml b/servers/inference_bridge/pipelines/dagster.example.yml similarity index 100% rename from servers/inference_server/pipelines/dagster.example.yml rename to servers/inference_bridge/pipelines/dagster.example.yml diff --git a/servers/inference_server/pipelines/pipelines/__init__.py b/servers/inference_bridge/pipelines/pipelines/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/__init__.py rename to servers/inference_bridge/pipelines/pipelines/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/assets.py b/servers/inference_bridge/pipelines/pipelines/assets.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/assets.py rename to servers/inference_bridge/pipelines/pipelines/assets.py diff --git a/servers/inference_server/pipelines/pipelines/common/__init__.py b/servers/inference_bridge/pipelines/pipelines/common/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/__init__.py rename to servers/inference_bridge/pipelines/pipelines/common/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/common/constants.py b/servers/inference_bridge/pipelines/pipelines/common/constants.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/constants.py rename to servers/inference_bridge/pipelines/pipelines/common/constants.py diff --git a/servers/inference_server/pipelines/pipelines/common/io.py b/servers/inference_bridge/pipelines/pipelines/common/io.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/io.py rename to servers/inference_bridge/pipelines/pipelines/common/io.py diff --git a/servers/inference_server/pipelines/pipelines/common/logging.py b/servers/inference_bridge/pipelines/pipelines/common/logging.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/logging.py rename to servers/inference_bridge/pipelines/pipelines/common/logging.py diff --git a/servers/inference_server/pipelines/pipelines/common/resources.py b/servers/inference_bridge/pipelines/pipelines/common/resources.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/resources.py rename to servers/inference_bridge/pipelines/pipelines/common/resources.py diff --git a/servers/inference_server/pipelines/pipelines/common/utils.py b/servers/inference_bridge/pipelines/pipelines/common/utils.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/utils.py rename to servers/inference_bridge/pipelines/pipelines/common/utils.py diff --git a/servers/inference_server/pipelines/pipelines/common/workspace.py b/servers/inference_bridge/pipelines/pipelines/common/workspace.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/common/workspace.py rename to servers/inference_bridge/pipelines/pipelines/common/workspace.py diff --git a/servers/inference_server/pipelines/pipelines/definitions.py b/servers/inference_bridge/pipelines/pipelines/definitions.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/definitions.py rename to servers/inference_bridge/pipelines/pipelines/definitions.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/__init__.py b/servers/inference_bridge/pipelines/pipelines/huggingface/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/__init__.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/client.py b/servers/inference_bridge/pipelines/pipelines/huggingface/client.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/client.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/client.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_export.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_export.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_export.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_export.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_import.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_import.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_manifest.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_manifest.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_manifest.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_manifest.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_prep.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_prep.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_prep.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_prep.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_readme.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_readme.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_readme.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_readme.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/perspective_export.py b/servers/inference_bridge/pipelines/pipelines/huggingface/perspective_export.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/perspective_export.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/perspective_export.py diff --git a/servers/inference_server/pipelines/pipelines/huggingface/types.py b/servers/inference_bridge/pipelines/pipelines/huggingface/types.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/huggingface/types.py rename to servers/inference_bridge/pipelines/pipelines/huggingface/types.py diff --git a/servers/inference_server/pipelines/pipelines/io/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/__init__.py rename to servers/inference_bridge/pipelines/pipelines/io/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/__init__.py rename to servers/inference_bridge/pipelines/pipelines/io/image/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/__init__.py rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/extract_exif.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/extract_exif.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/extract_exif.py rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/extract_exif.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/load_images.py b/servers/inference_bridge/pipelines/pipelines/io/image/load_images.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/load_images.py rename to servers/inference_bridge/pipelines/pipelines/io/image/load_images.py diff --git a/servers/inference_server/pipelines/pipelines/io/image/types.py b/servers/inference_bridge/pipelines/pipelines/io/image/types.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/io/image/types.py rename to servers/inference_bridge/pipelines/pipelines/io/image/types.py diff --git a/servers/inference_server/pipelines/pipelines/jobs/__init__.py b/servers/inference_bridge/pipelines/pipelines/jobs/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/jobs/__init__.py rename to servers/inference_bridge/pipelines/pipelines/jobs/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/jobs/dataset_import_job.py b/servers/inference_bridge/pipelines/pipelines/jobs/dataset_import_job.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/jobs/dataset_import_job.py rename to servers/inference_bridge/pipelines/pipelines/jobs/dataset_import_job.py diff --git a/servers/inference_server/pipelines/pipelines/jobs/image_metadata.py b/servers/inference_bridge/pipelines/pipelines/jobs/image_metadata.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/jobs/image_metadata.py rename to servers/inference_bridge/pipelines/pipelines/jobs/image_metadata.py diff --git a/servers/inference_server/pipelines/pipelines/jobs/omi.py b/servers/inference_bridge/pipelines/pipelines/jobs/omi.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/jobs/omi.py rename to servers/inference_bridge/pipelines/pipelines/jobs/omi.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/__init__.py b/servers/inference_bridge/pipelines/pipelines/perspectives/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/__init__.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/assets.py b/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/assets.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/assets.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/__init__.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/__init__.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/config.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/config.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/config.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/config.py diff --git a/servers/inference_server/pipelines/pipelines/perspectives/types.py b/servers/inference_bridge/pipelines/pipelines/perspectives/types.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/perspectives/types.py rename to servers/inference_bridge/pipelines/pipelines/perspectives/types.py diff --git a/servers/inference_server/pipelines/pipelines/providers/__init__.py b/servers/inference_bridge/pipelines/pipelines/providers/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/providers/__init__.py rename to servers/inference_bridge/pipelines/pipelines/providers/__init__.py diff --git a/servers/inference_server/pipelines/pipelines/providers/assets.py b/servers/inference_bridge/pipelines/pipelines/providers/assets.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/providers/assets.py rename to servers/inference_bridge/pipelines/pipelines/providers/assets.py diff --git a/servers/inference_server/pipelines/pipelines/providers/util.py b/servers/inference_bridge/pipelines/pipelines/providers/util.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/providers/util.py rename to servers/inference_bridge/pipelines/pipelines/providers/util.py diff --git a/servers/inference_server/pipelines/pipelines/start.py b/servers/inference_bridge/pipelines/pipelines/start.py similarity index 100% rename from servers/inference_server/pipelines/pipelines/start.py rename to servers/inference_bridge/pipelines/pipelines/start.py diff --git a/servers/inference_server/pipelines/pipelines_tests/__init__.py b/servers/inference_bridge/pipelines/pipelines_tests/__init__.py similarity index 100% rename from servers/inference_server/pipelines/pipelines_tests/__init__.py rename to servers/inference_bridge/pipelines/pipelines_tests/__init__.py diff --git a/servers/inference_server/pipelines/pipelines_tests/test_assets.py b/servers/inference_bridge/pipelines/pipelines_tests/test_assets.py similarity index 100% rename from servers/inference_server/pipelines/pipelines_tests/test_assets.py rename to servers/inference_bridge/pipelines/pipelines_tests/test_assets.py diff --git a/servers/inference_server/pipelines/pyproject.toml b/servers/inference_bridge/pipelines/pyproject.toml similarity index 100% rename from servers/inference_server/pipelines/pyproject.toml rename to servers/inference_bridge/pipelines/pyproject.toml diff --git a/servers/inference_server/pipelines/setup.cfg b/servers/inference_bridge/pipelines/setup.cfg similarity index 100% rename from servers/inference_server/pipelines/setup.cfg rename to servers/inference_bridge/pipelines/setup.cfg diff --git a/servers/inference_server/pipelines/setup.py b/servers/inference_bridge/pipelines/setup.py similarity index 100% rename from servers/inference_server/pipelines/setup.py rename to servers/inference_bridge/pipelines/setup.py diff --git a/servers/inference_server/pipelines/uv.lock b/servers/inference_bridge/pipelines/uv.lock similarity index 100% rename from servers/inference_server/pipelines/uv.lock rename to servers/inference_bridge/pipelines/uv.lock diff --git a/servers/inference_server/pyproject.toml b/servers/inference_bridge/pyproject.toml similarity index 100% rename from servers/inference_server/pyproject.toml rename to servers/inference_bridge/pyproject.toml diff --git a/servers/inference_server/pytest.ini b/servers/inference_bridge/pytest.ini similarity index 100% rename from servers/inference_server/pytest.ini rename to servers/inference_bridge/pytest.ini diff --git a/servers/inference_server/scripts/__init__.py b/servers/inference_bridge/scripts/__init__.py similarity index 100% rename from servers/inference_server/scripts/__init__.py rename to servers/inference_bridge/scripts/__init__.py diff --git a/servers/inference_server/scripts/__main__.py b/servers/inference_bridge/scripts/__main__.py similarity index 100% rename from servers/inference_server/scripts/__main__.py rename to servers/inference_bridge/scripts/__main__.py diff --git a/servers/inference_server/scripts/config_writer.py b/servers/inference_bridge/scripts/config_writer.py similarity index 100% rename from servers/inference_server/scripts/config_writer.py rename to servers/inference_bridge/scripts/config_writer.py diff --git a/servers/inference_server/scripts/setup.py b/servers/inference_bridge/scripts/setup.py similarity index 100% rename from servers/inference_server/scripts/setup.py rename to servers/inference_bridge/scripts/setup.py diff --git a/servers/inference_server/server/.dep_hash b/servers/inference_bridge/server/.dep_hash similarity index 100% rename from servers/inference_server/server/.dep_hash rename to servers/inference_bridge/server/.dep_hash diff --git a/servers/inference_server/server/.dockerignore b/servers/inference_bridge/server/.dockerignore similarity index 100% rename from servers/inference_server/server/.dockerignore rename to servers/inference_bridge/server/.dockerignore diff --git a/servers/inference_server/server/.env.local.template b/servers/inference_bridge/server/.env.local.template similarity index 100% rename from servers/inference_server/server/.env.local.template rename to servers/inference_bridge/server/.env.local.template diff --git a/servers/inference_server/server/Dockerfile.server.dev b/servers/inference_bridge/server/Dockerfile.server.dev similarity index 100% rename from servers/inference_server/server/Dockerfile.server.dev rename to servers/inference_bridge/server/Dockerfile.server.dev diff --git a/servers/inference_server/server/README.md b/servers/inference_bridge/server/README.md similarity index 100% rename from servers/inference_server/server/README.md rename to servers/inference_bridge/server/README.md diff --git a/servers/inference_server/server/Taskfile.inference.yml b/servers/inference_bridge/server/Taskfile.inference.yml similarity index 69% rename from servers/inference_server/server/Taskfile.inference.yml rename to servers/inference_bridge/server/Taskfile.inference.yml index a2c64d9b..66ca0b6b 100644 --- a/servers/inference_server/server/Taskfile.inference.yml +++ b/servers/inference_bridge/server/Taskfile.inference.yml @@ -4,26 +4,26 @@ tasks: dev: desc: Start the inference server container with watch mode cmds: - - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up graphcap_server --watch --build + - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up inference_bridge --watch --build start: desc: Start the inference server container cmds: - - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d graphcap_server + - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d inference_bridge stop: desc: Stop the inference server container cmds: - - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env stop graphcap_server + - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env stop inference_bridge logs: desc: View logs for the inference server container cmds: - - docker compose -f ./docker-compose.yml logs -f graphcap_server + - docker compose -f ./docker-compose.yml logs -f inference_bridge rebuild: desc: Rebuild and restart the inference server container cmds: - - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env build graphcap_server - - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d --force-recreate graphcap_server + - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env build inference_bridge + - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d --force-recreate inference_bridge diff --git a/servers/inference_server/server/__init__.py b/servers/inference_bridge/server/__init__.py similarity index 100% rename from servers/inference_server/server/__init__.py rename to servers/inference_bridge/server/__init__.py diff --git a/servers/inference_server/server/_scripts/endpoints-entrypoint.sh b/servers/inference_bridge/server/_scripts/endpoints-entrypoint.sh similarity index 100% rename from servers/inference_server/server/_scripts/endpoints-entrypoint.sh rename to servers/inference_bridge/server/_scripts/endpoints-entrypoint.sh diff --git a/servers/inference_server/server/_scripts/gunicorn.conf.py b/servers/inference_bridge/server/_scripts/gunicorn.conf.py similarity index 100% rename from servers/inference_server/server/_scripts/gunicorn.conf.py rename to servers/inference_bridge/server/_scripts/gunicorn.conf.py diff --git a/servers/inference_server/server/pyproject.toml b/servers/inference_bridge/server/pyproject.toml similarity index 100% rename from servers/inference_server/server/pyproject.toml rename to servers/inference_bridge/server/pyproject.toml diff --git a/servers/inference_server/server/server/__init__.py b/servers/inference_bridge/server/server/__init__.py similarity index 100% rename from servers/inference_server/server/server/__init__.py rename to servers/inference_bridge/server/server/__init__.py diff --git a/servers/inference_server/server/server/config.py b/servers/inference_bridge/server/server/config.py similarity index 100% rename from servers/inference_server/server/server/config.py rename to servers/inference_bridge/server/server/config.py diff --git a/servers/inference_server/server/server/config/router.py b/servers/inference_bridge/server/server/config/router.py similarity index 100% rename from servers/inference_server/server/server/config/router.py rename to servers/inference_bridge/server/server/config/router.py diff --git a/servers/inference_server/server/server/db.py b/servers/inference_bridge/server/server/db.py similarity index 100% rename from servers/inference_server/server/server/db.py rename to servers/inference_bridge/server/server/db.py diff --git a/servers/inference_server/server/server/dependencies.py b/servers/inference_bridge/server/server/dependencies.py similarity index 100% rename from servers/inference_server/server/server/dependencies.py rename to servers/inference_bridge/server/server/dependencies.py diff --git a/servers/inference_server/server/server/features/perspectives/__init__.py b/servers/inference_bridge/server/server/features/perspectives/__init__.py similarity index 100% rename from servers/inference_server/server/server/features/perspectives/__init__.py rename to servers/inference_bridge/server/server/features/perspectives/__init__.py diff --git a/servers/inference_server/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py similarity index 100% rename from servers/inference_server/server/server/features/perspectives/models.py rename to servers/inference_bridge/server/server/features/perspectives/models.py diff --git a/servers/inference_server/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py similarity index 100% rename from servers/inference_server/server/server/features/perspectives/router.py rename to servers/inference_bridge/server/server/features/perspectives/router.py diff --git a/servers/inference_server/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py similarity index 100% rename from servers/inference_server/server/server/features/perspectives/service.py rename to servers/inference_bridge/server/server/features/perspectives/service.py diff --git a/servers/inference_server/server/server/features/providers/__init__.py b/servers/inference_bridge/server/server/features/providers/__init__.py similarity index 100% rename from servers/inference_server/server/server/features/providers/__init__.py rename to servers/inference_bridge/server/server/features/providers/__init__.py diff --git a/servers/inference_server/server/server/features/providers/error_handler.py b/servers/inference_bridge/server/server/features/providers/error_handler.py similarity index 100% rename from servers/inference_server/server/server/features/providers/error_handler.py rename to servers/inference_bridge/server/server/features/providers/error_handler.py diff --git a/servers/inference_server/server/server/features/providers/models.py b/servers/inference_bridge/server/server/features/providers/models.py similarity index 100% rename from servers/inference_server/server/server/features/providers/models.py rename to servers/inference_bridge/server/server/features/providers/models.py diff --git a/servers/inference_server/server/server/features/providers/router.py b/servers/inference_bridge/server/server/features/providers/router.py similarity index 100% rename from servers/inference_server/server/server/features/providers/router.py rename to servers/inference_bridge/server/server/features/providers/router.py diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py similarity index 100% rename from servers/inference_server/server/server/features/providers/service.py rename to servers/inference_bridge/server/server/features/providers/service.py diff --git a/servers/inference_server/server/server/features/repositories/types.py b/servers/inference_bridge/server/server/features/repositories/types.py similarity index 100% rename from servers/inference_server/server/server/features/repositories/types.py rename to servers/inference_bridge/server/server/features/repositories/types.py diff --git a/servers/inference_server/server/server/main.py b/servers/inference_bridge/server/server/main.py similarity index 100% rename from servers/inference_server/server/server/main.py rename to servers/inference_bridge/server/server/main.py diff --git a/servers/inference_server/server/server/models.py b/servers/inference_bridge/server/server/models.py similarity index 100% rename from servers/inference_server/server/server/models.py rename to servers/inference_bridge/server/server/models.py diff --git a/servers/inference_server/server/server/pipelines/__init__py b/servers/inference_bridge/server/server/pipelines/__init__py similarity index 100% rename from servers/inference_server/server/server/pipelines/__init__py rename to servers/inference_bridge/server/server/pipelines/__init__py diff --git a/servers/inference_server/server/server/pipelines/dagster_client.py b/servers/inference_bridge/server/server/pipelines/dagster_client.py similarity index 100% rename from servers/inference_server/server/server/pipelines/dagster_client.py rename to servers/inference_bridge/server/server/pipelines/dagster_client.py diff --git a/servers/inference_server/server/server/routers.py b/servers/inference_bridge/server/server/routers.py similarity index 100% rename from servers/inference_server/server/server/routers.py rename to servers/inference_bridge/server/server/routers.py diff --git a/servers/inference_server/server/server/utils/__init__.py b/servers/inference_bridge/server/server/utils/__init__.py similarity index 100% rename from servers/inference_server/server/server/utils/__init__.py rename to servers/inference_bridge/server/server/utils/__init__.py diff --git a/servers/inference_server/server/server/utils/logger.py b/servers/inference_bridge/server/server/utils/logger.py similarity index 100% rename from servers/inference_server/server/server/utils/logger.py rename to servers/inference_bridge/server/server/utils/logger.py diff --git a/servers/inference_server/server/server/utils/middleware.py b/servers/inference_bridge/server/server/utils/middleware.py similarity index 100% rename from servers/inference_server/server/server/utils/middleware.py rename to servers/inference_bridge/server/server/utils/middleware.py diff --git a/servers/inference_server/server/server/utils/resizing.py b/servers/inference_bridge/server/server/utils/resizing.py similarity index 100% rename from servers/inference_server/server/server/utils/resizing.py rename to servers/inference_bridge/server/server/utils/resizing.py diff --git a/servers/inference_server/server/uv.lock b/servers/inference_bridge/server/uv.lock similarity index 100% rename from servers/inference_server/server/uv.lock rename to servers/inference_bridge/server/uv.lock diff --git a/servers/inference_server/tests/test_perspective_modules.py b/servers/inference_bridge/tests/test_perspective_modules.py similarity index 100% rename from servers/inference_server/tests/test_perspective_modules.py rename to servers/inference_bridge/tests/test_perspective_modules.py diff --git a/servers/inference_server/uv.lock b/servers/inference_bridge/uv.lock similarity index 100% rename from servers/inference_server/uv.lock rename to servers/inference_bridge/uv.lock From ae07a78ac5f03b655ee705416751df483fa53c22 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 08:41:33 -0500 Subject: [PATCH 14/69] move db schema and provider seed to feature module Signed-off-by: jphillips --- servers/data_service/src/api/routes/index.ts | 2 +- servers/data_service/src/app.ts | 2 +- servers/data_service/src/db/schema/index.ts | 2 +- servers/data_service/src/db/seed/index.ts | 2 +- .../src/features/{providers => provider_config}/controller.ts | 0 .../providers.ts => features/provider_config/db_providers.ts} | 0 .../src/features/{providers => provider_config}/routes.ts | 0 .../src/features/{providers => provider_config}/schemas.ts | 0 .../provider_config/seed_providers.ts} | 4 ++-- 9 files changed, 6 insertions(+), 6 deletions(-) rename servers/data_service/src/features/{providers => provider_config}/controller.ts (100%) rename servers/data_service/src/{db/schema/providers.ts => features/provider_config/db_providers.ts} (100%) rename servers/data_service/src/features/{providers => provider_config}/routes.ts (100%) rename servers/data_service/src/features/{providers => provider_config}/schemas.ts (100%) rename servers/data_service/src/{db/seed/providers.ts => features/provider_config/seed_providers.ts} (97%) diff --git a/servers/data_service/src/api/routes/index.ts b/servers/data_service/src/api/routes/index.ts index 0a151bdf..fac8b6d3 100644 --- a/servers/data_service/src/api/routes/index.ts +++ b/servers/data_service/src/api/routes/index.ts @@ -5,7 +5,7 @@ * This file exports route definitions for client consumption. */ -import { providerRoutes } from '../../features/providers/routes'; +import { providerRoutes } from '../../features/provider_config/routes'; import { batchQueueRoutes } from './batch_queue'; export { providerRoutes, batchQueueRoutes }; \ No newline at end of file diff --git a/servers/data_service/src/app.ts b/servers/data_service/src/app.ts index ebff4278..bbca188c 100644 --- a/servers/data_service/src/app.ts +++ b/servers/data_service/src/app.ts @@ -16,7 +16,7 @@ import { z } from 'zod'; import { batchQueueRoutes } from './api/routes/batch_queue'; import { checkDatabaseConnection } from './db/init'; import { env } from './env'; -import { providerRoutes } from './features/providers/routes'; +import { providerRoutes } from './features/provider_config/routes'; import { errorHandlerMiddleware, notFoundHandler } from './utils/error-handler'; import { logger } from './utils/logger'; diff --git a/servers/data_service/src/db/schema/index.ts b/servers/data_service/src/db/schema/index.ts index 2003e966..7ca32784 100644 --- a/servers/data_service/src/db/schema/index.ts +++ b/servers/data_service/src/db/schema/index.ts @@ -5,5 +5,5 @@ * This file exports all database schema definitions for use with Drizzle ORM. */ -export * from './providers'; +export * from '../../features/provider_config/db_providers'; export * from './batch_queue'; \ No newline at end of file diff --git a/servers/data_service/src/db/seed/index.ts b/servers/data_service/src/db/seed/index.ts index d6cf6a9c..7fbe0124 100644 --- a/servers/data_service/src/db/seed/index.ts +++ b/servers/data_service/src/db/seed/index.ts @@ -6,8 +6,8 @@ * Add new seed operations here in the desired order. */ +import { seedProviders } from '../../features/provider_config/seed_providers'; import { logger } from '../../utils/logger'; -import { seedProviders } from './providers'; /** * Main seed function that orchestrates all seeding operations diff --git a/servers/data_service/src/features/providers/controller.ts b/servers/data_service/src/features/provider_config/controller.ts similarity index 100% rename from servers/data_service/src/features/providers/controller.ts rename to servers/data_service/src/features/provider_config/controller.ts diff --git a/servers/data_service/src/db/schema/providers.ts b/servers/data_service/src/features/provider_config/db_providers.ts similarity index 100% rename from servers/data_service/src/db/schema/providers.ts rename to servers/data_service/src/features/provider_config/db_providers.ts diff --git a/servers/data_service/src/features/providers/routes.ts b/servers/data_service/src/features/provider_config/routes.ts similarity index 100% rename from servers/data_service/src/features/providers/routes.ts rename to servers/data_service/src/features/provider_config/routes.ts diff --git a/servers/data_service/src/features/providers/schemas.ts b/servers/data_service/src/features/provider_config/schemas.ts similarity index 100% rename from servers/data_service/src/features/providers/schemas.ts rename to servers/data_service/src/features/provider_config/schemas.ts diff --git a/servers/data_service/src/db/seed/providers.ts b/servers/data_service/src/features/provider_config/seed_providers.ts similarity index 97% rename from servers/data_service/src/db/seed/providers.ts rename to servers/data_service/src/features/provider_config/seed_providers.ts index 2361d817..b3981faf 100644 --- a/servers/data_service/src/db/seed/providers.ts +++ b/servers/data_service/src/features/provider_config/seed_providers.ts @@ -5,9 +5,9 @@ * This script seeds the database with predefined provider configurations. */ -import { db } from '../index'; -import { providers, providerModels} from '../schema'; import { eq } from 'drizzle-orm'; +import { db } from '../../db/index'; +import { providerModels, providers } from '../../db/schema'; import { logger } from '../../utils/logger'; // Define interfaces for provider configurations From 3dd8a99390e2f7f97525cdf2831a22deba1b9412 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 10:24:56 -0500 Subject: [PATCH 15/69] Split inference app and form context, clean up action flow Signed-off-by: jphillips --- .../inference/hooks/useProviderForm.ts | 127 -------- .../ProviderConnection/ProviderForm.tsx | 63 ---- .../ProviderFormActions.tsx | 84 ----- .../ProviderConnection/component.tsx | 31 ++ .../{ => components}/FormFields.module.css | 0 .../components/LoadingMessage.tsx | 45 +++ .../components/ProviderActions.tsx | 31 ++ .../components/ProviderConnectionActions.tsx | 8 +- .../ProviderConnectionErrorDialog.tsx | 7 +- .../ProviderConnectionSuccessDialog.tsx | 5 +- .../{ => components}/ProviderFormTabs.tsx | 4 +- .../components/ProviderFormView.tsx | 65 ++++ .../components/actions/CancelButton.tsx | 40 +++ .../components/actions/EditButton.tsx | 19 ++ .../components/actions/ProviderSaveDialog.tsx | 223 +++++++++++++ .../actions/TestConnectionButton.tsx | 22 ++ .../components/actions/index.ts | 5 + .../components/form/BasicInfoSection.tsx | 7 +- .../components/form/ConnectionSection.tsx | 7 +- .../components/form/ModelSelectionSection.tsx | 12 +- .../components/form/ProviderFormView.tsx | 138 -------- .../components/form/RateLimitsSection.tsx | 5 +- .../containers/ProviderFormContainer.tsx | 294 ++++++++++++++++++ .../hooks/useProviderConnection.ts | 2 +- .../providers/ProviderConnection/index.ts | 6 +- .../inference/providers/ProvidersList.tsx | 2 +- .../inference/providers/ProvidersPanel.tsx | 43 +-- .../context/InferenceProviderContext.tsx | 4 +- .../providers/context/ProviderFormContext.tsx | 79 +++++ .../src/features/inference/providers/index.ts | 3 - .../src/features/inference/providers/types.ts | 12 + .../features/inference/services/providers.ts | 37 --- .../src/features/provider_config/routes.ts | 36 +-- 33 files changed, 916 insertions(+), 550 deletions(-) delete mode 100644 graphcap_studio/src/features/inference/hooks/useProviderForm.ts delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx rename graphcap_studio/src/features/inference/providers/ProviderConnection/{ => components}/FormFields.module.css (100%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx rename graphcap_studio/src/features/inference/providers/ProviderConnection/{ => components}/ProviderFormTabs.tsx (93%) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/EditButton.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/index.ts delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormView.tsx create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx create mode 100644 graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx diff --git a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts b/graphcap_studio/src/features/inference/hooks/useProviderForm.ts deleted file mode 100644 index 9c6a3970..00000000 --- a/graphcap_studio/src/features/inference/hooks/useProviderForm.ts +++ /dev/null @@ -1,127 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -import type { - ProviderCreate, - ProviderUpdate, -} from "@/features/inference/providers/types"; -import { - useCreateProvider, - useUpdateProvider, - useUpdateProviderApiKey, -} from "@/features/server-connections/services/providers"; - -import { useCallback } from "react"; -import { useForm } from "react-hook-form"; -import { DEFAULT_PROVIDER_FORM_DATA } from "../constants"; - -type FormData = ProviderCreate | ProviderUpdate; - -/** - * Custom hook for managing provider form state and operations - */ -export function useProviderForm(initialData: Partial = {}) { - // Initialize react-hook-form with validation - const { - control, - handleSubmit, - reset, - formState: { errors }, - watch, - } = useForm({ - defaultValues: { - ...DEFAULT_PROVIDER_FORM_DATA, - ...initialData, - // Ensure apiKey is always a string, never undefined - apiKey: initialData.apiKey || '', - }, - mode: "onBlur", - }); - - // Watch the provider name and other fields for use in UI - const providerName = watch("name"); - const fetchModels = watch("fetchModels"); - const defaultModel = watch("defaultModel"); - - // Mutations - const createProvider = useCreateProvider(); - const updateProvider = useUpdateProvider(); - const updateApiKeyMutation = useUpdateProviderApiKey(); - - // Update API key separately (needed because the server has a separate endpoint) - const updateApiKey = useCallback(async (providerId: number, apiKey: string) => { - if (!apiKey.trim()) { - console.warn("Attempted to update with empty API key, skipping"); - return { success: false, error: "API key cannot be empty" }; - } - - try { - await updateApiKeyMutation.mutateAsync({ id: providerId, apiKey }); - return { success: true }; - } catch (error) { - console.error("Error updating API key:", error); - return { - success: false, - error: error instanceof Error ? error.message : "Unknown error", - }; - } - }, [updateApiKeyMutation]); - - // Handle form submission - const onSubmit = useCallback( - async (data: FormData, isCreating: boolean, providerId?: number) => { - try { - // For create, we need all required fields - if (isCreating) { - // Ensure required fields are present - if (!data.name || !data.kind || !data.environment || !data.baseUrl) { - throw new Error("Missing required fields"); - } - - // For create, we need the API key in the initial request - await createProvider.mutateAsync(data as ProviderCreate); - } else if (providerId) { - // For update, we only need the fields that changed - // apiKey should be handled separately - const { apiKey, ...updateData } = data; - - await updateProvider.mutateAsync({ - id: providerId, - data: updateData as ProviderUpdate, - }); - - // If apiKey is provided and not empty, update it separately - if (apiKey && apiKey.trim() !== '') { - await updateApiKey(providerId, apiKey); - } - } - - reset(DEFAULT_PROVIDER_FORM_DATA); - return { success: true }; - } catch (error) { - console.error("Error submitting provider form:", error); - - // Propagate the error so it can be handled by the UI - throw error; - } - }, - [createProvider, updateProvider, updateApiKey, reset], - ); - - return { - // Form state - control, - handleSubmit, - errors, - watch, - providerName, - fetchModels, - defaultModel, - reset, - - // Form submission - onSubmit, - updateApiKey, - - // Loading state - isSubmitting: createProvider.isPending || updateProvider.isPending || updateApiKeyMutation.isPending, - }; -} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx deleted file mode 100644 index 8d444674..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderForm.tsx +++ /dev/null @@ -1,63 +0,0 @@ -import { Box, Flex } from "@chakra-ui/react"; -// SPDX-License-Identifier: Apache-2.0 -import { memo } from "react"; -import { useInferenceProviderContext } from "../context"; -import { ProviderFormActions } from "./ProviderFormActions"; -import { ProviderFormTabs } from "./ProviderFormTabs"; -import { ProviderConnectionActions } from "./components/ProviderConnectionActions"; -import { ProviderConnectionErrorDialog } from "./components/ProviderConnectionErrorDialog"; -import { ProviderConnectionSuccessDialog } from "./components/ProviderConnectionSuccessDialog"; -import { useProviderConnection } from "./hooks/useProviderConnection"; - -/** - * Container component for the provider form that handles business logic and state management - */ -function ProviderForm() { - const { selectedProvider, mode } = useInferenceProviderContext(); - const { - isTestingConnection, - connectionError, - connectionDetails, - dialogs, - handleTestConnection, - closeDialog - } = useProviderConnection(selectedProvider); - - const isEditing = mode === "edit"; - const isCreating = mode === "create"; - - return ( - - - - {/* Actions */} - - - {(isEditing || isCreating) && } - - - {/* Connection Error Dialog */} - closeDialog('error')} - error={connectionError} - providerName={selectedProvider?.name || "Provider"} - /> - - {/* Success Dialog */} - closeDialog('success')} - providerName={selectedProvider?.name || "Provider"} - connectionDetails={connectionDetails} - /> - - ); -} - -export default memo(ProviderForm); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx deleted file mode 100644 index 1cc108ba..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormActions.tsx +++ /dev/null @@ -1,84 +0,0 @@ -import { useColorMode } from "@/components/ui/theme/color-mode"; -import { Button, Flex, HStack } from "@chakra-ui/react"; -// SPDX-License-Identifier: Apache-2.0 -import { useInferenceProviderContext } from "../context"; - -/** - * Component for rendering form action buttons with Chakra UI styling - */ -export function ProviderFormActions() { - const { isSubmitting, isCreating, onCancel } = useInferenceProviderContext(); - - const { colorMode } = useColorMode(); - const isDark = colorMode === "dark"; - - // Determine the button text based on form state - let buttonText = "Save"; - if (isSubmitting) { - buttonText = "Saving..."; - } else if (isCreating) { - buttonText = "Create"; - } - - // Theme-based colors - const cancelBg = isDark ? "gray.700" : "gray.100"; - const cancelHoverBg = isDark ? "gray.600" : "gray.200"; - const cancelColor = isDark ? "gray.200" : "gray.800"; - - const primaryBg = "blue.500"; - const primaryHoverBg = "blue.600"; - - return ( - - - - - - - ); -} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx new file mode 100644 index 00000000..237e56bd --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +import type { ReactNode } from "react"; +import { useInferenceProviderContext } from "../context"; +import type { ProviderCreate, ProviderUpdate } from "../types"; +import { ProviderFormView } from "./components/ProviderFormView"; +import { ProviderFormContainer } from "./containers/ProviderFormContainer"; + +interface ProviderConnectionProps { + initialData?: Partial; + onSubmit?: (data: ProviderCreate | ProviderUpdate) => Promise; +} + +/** + * Main provider connection component that handles form state and submission + */ +export function ProviderConnection({ + initialData, + onSubmit, +}: ProviderConnectionProps) { + // Get onSubmit from context if not provided as prop + const context = useInferenceProviderContext(); + + // Use provided onSubmit if available, otherwise use the one from context + const handleSubmit = onSubmit || context.onSubmit; + + return ( + + + + ); +} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/FormFields.module.css b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/FormFields.module.css similarity index 100% rename from graphcap_studio/src/features/inference/providers/ProviderConnection/FormFields.module.css rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/FormFields.module.css diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx new file mode 100644 index 00000000..7b843cbb --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Box, Spinner, Text } from "@chakra-ui/react"; + +interface LoadingMessageProps { + isSubmitting: boolean; + saveSuccess: boolean; +} + +export function LoadingMessage({ isSubmitting, saveSuccess }: LoadingMessageProps) { + if (!isSubmitting && !saveSuccess) return null; + + return ( + + {isSubmitting && ( + + + Saving changes... + + )} + + {!isSubmitting && saveSuccess && ( + + Provider saved successfully! + + )} + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx new file mode 100644 index 00000000..82b64822 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Flex, HStack } from "@chakra-ui/react"; +import { useProviderFormContext } from "../../context/ProviderFormContext"; +import { CancelButton, EditButton, SaveButton, TestConnectionButton } from "./actions"; + +/** + * Component for rendering provider form actions based on current mode + */ +export function ProviderActions() { + const { mode } = useProviderFormContext(); + const isEditing = mode === "edit"; + const isCreating = mode === "create"; + + if (isEditing || isCreating) { + return ( + + + + + + + ); + } + + return ( + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx index 05d31d7d..592613e4 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx @@ -3,10 +3,10 @@ import { Button, Flex } from "@chakra-ui/react"; import { useInferenceProviderContext } from "../../context"; interface ProviderConnectionActionsProps { - isTestingConnection: boolean; - onTest: () => Promise; - disabled?: boolean; - showEditButton?: boolean; + readonly isTestingConnection: boolean; + readonly onTest: () => Promise; + readonly disabled?: boolean; + readonly showEditButton?: boolean; } /** diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx index 4ec2cc59..249957f9 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx @@ -12,11 +12,13 @@ import { } from "@chakra-ui/react"; import { useEffect, useRef } from "react"; import { LuTriangleAlert } from "react-icons/lu"; +import type { ErrorDetails as ContextErrorDetails } from "../../types"; type ErrorDetails = { message?: string; name?: string; - details?: string; + details?: string | Record; + code?: string; suggestions?: string[]; requestDetails?: { provider: string; @@ -27,7 +29,7 @@ type ErrorDetails = { type ProviderConnectionErrorDialogProps = { readonly isOpen: boolean; readonly onClose: () => void; - readonly error: ErrorDetails; + readonly error: ErrorDetails | ContextErrorDetails | null; readonly providerName: string; }; @@ -112,6 +114,7 @@ export function ProviderConnectionErrorDialog({ name: errorObj.name, message: errorObj.message, details: errorObj.details, + code: errorObj.code, suggestions: errorObj.suggestions }, null, 2)} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx index d3596619..f80d645b 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx @@ -10,6 +10,7 @@ import { } from "@chakra-ui/react"; import { useEffect, useRef } from "react"; import { LuCheck, LuCircleAlert } from "react-icons/lu"; +import type { ConnectionDetails as ContextConnectionDetails } from "../../types"; import { type ConnectionStep, ConnectionSteps } from "./ConnectionSteps"; /** @@ -31,14 +32,14 @@ interface ConnectionDetails { message: string; }>; }; - }; + } | boolean; } type ProviderConnectionSuccessDialogProps = { readonly isOpen: boolean; readonly onClose: () => void; readonly providerName: string; - readonly connectionDetails: ConnectionDetails | null; + readonly connectionDetails: ConnectionDetails | ContextConnectionDetails | null; }; const STEP_LABELS: Record = { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx similarity index 93% rename from graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx rename to graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx index 131e4261..d2f85146 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/ProviderFormTabs.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 import { Tabs } from "@chakra-ui/react"; import styles from "./FormFields.module.css"; -import { BasicInfoSection, ConnectionSection, RateLimitsSection } from "./components/form"; -import { ModelSelectionSection } from "./components/form/ModelSelectionSection"; +import { BasicInfoSection, ConnectionSection, RateLimitsSection } from "./form"; +import { ModelSelectionSection } from "./form/ModelSelectionSection"; /** * Component for rendering provider form fields in either view or edit mode diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx new file mode 100644 index 00000000..dbce865a --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Box } from "@chakra-ui/react"; +import { useProviderFormContext } from "../../context/ProviderFormContext"; +import { LoadingMessage } from "./LoadingMessage"; +import { ProviderActions } from "./ProviderActions"; +import { ProviderConnectionErrorDialog } from "./ProviderConnectionErrorDialog"; +import { ProviderConnectionSuccessDialog } from "./ProviderConnectionSuccessDialog"; +import { ProviderFormTabs } from "./ProviderFormTabs"; + +/** + * Presentational component for the provider form + */ +export function ProviderFormView() { + const { + onSubmit, + handleSubmit, + isSubmitting, + saveSuccess, + dialogs, + closeDialog, + formError, + connectionError, + connectionDetails, + selectedProvider + } = useProviderFormContext(); + + return ( +
+ + + + + + + + {/* Form Error Dialog */} + closeDialog("formError")} + error={formError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Connection Error Dialog */} + closeDialog("error")} + error={connectionError} + providerName={selectedProvider?.name || "Provider"} + /> + + {/* Success Dialog */} + closeDialog("success")} + providerName={selectedProvider?.name || "Provider"} + connectionDetails={connectionDetails} + /> + +
+ ); +} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx new file mode 100644 index 00000000..57860c6a --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +import { useColorMode } from "@/components/ui/theme/color-mode"; +import { Button } from "@chakra-ui/react"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; + +/** + * Button component for canceling provider form changes + */ +export function CancelButton() { + const { onCancel } = useProviderFormContext(); + const { colorMode } = useColorMode(); + const isDark = colorMode === "dark"; + + // Theme-based colors + const cancelBg = isDark ? "gray.700" : "gray.100"; + const cancelHoverBg = isDark ? "gray.600" : "gray.200"; + const cancelColor = isDark ? "gray.200" : "gray.800"; + + return ( + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/EditButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/EditButton.tsx new file mode 100644 index 00000000..c022263c --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/EditButton.tsx @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Button } from "@chakra-ui/react"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; + +/** + * Button component for editing provider + */ +export function EditButton() { + const { setMode } = useProviderFormContext(); + + return ( + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx new file mode 100644 index 00000000..6a9279a0 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx @@ -0,0 +1,223 @@ +import { + Box, + Button, + CloseButton, + Dialog, + Portal, + Spinner, + Text, + VStack, +} from "@chakra-ui/react"; +// SPDX-License-Identifier: Apache-2.0 +import { useEffect, useState } from "react"; +import { useCreateProvider, useUpdateProvider } from "../../../../services/providers"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; +import type { Provider, ProviderCreate, ProviderUpdate } from "../../../types"; + +// Define error type with message property +interface ErrorWithMessage { + message: string; + [key: string]: unknown; +} + +/** + * Unified component that combines the save button and save dialog functionality + */ +export function SaveButton() { + const { + isSubmitting: isContextSubmitting, + isCreating, + handleSubmit, + selectedProvider, + saveError: contextSaveError + } = useProviderFormContext(); + + // API mutation hooks + const createProvider = useCreateProvider(); + const updateProvider = useUpdateProvider(); + + // Local state for dialog visibility and save state + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [isSaving, setIsSaving] = useState(false); + const [saveComplete, setSaveComplete] = useState(false); + const [savingProvider, setSavingProvider] = useState(null); + const [saveError, setSaveError] = useState(contextSaveError); + + // Determine if form is submitting + const isSubmitting = isContextSubmitting || isSaving; + + // Determine the button text based on form state + const buttonText = isSubmitting + ? "Saving..." + : isCreating + ? "Create" + : "Save"; + + // Function to close the dialog + const closeDialog = () => { + setIsDialogOpen(false); + setSaveComplete(false); + setSavingProvider(null); + setSaveError(undefined); + }; + + // Custom submit handler that shows the dialog and directly calls API + const handleFormSubmit = async (e: React.FormEvent) => { + try { + setIsSaving(true); + setIsDialogOpen(true); + setSaveError(undefined); + + // Use the form's handleSubmit to get validated data + await handleSubmit(async (data: ProviderCreate | ProviderUpdate) => { + try { + console.log("Provider form submitted:", data); + + // Determine if we're creating or updating based on presence of id + let result: Provider; + if ('id' in data && data.id) { + // Update existing provider + const id = data.id as number; + console.log(`Updating provider with id ${id}`); + result = await updateProvider.mutateAsync({ id, data }); + console.log("Provider updated successfully:", result); + } else { + // Create new provider + console.log("Creating new provider"); + result = await createProvider.mutateAsync(data as ProviderCreate); + console.log("Provider created successfully:", result); + } + + // Success - store the provider details and mark as complete + setSavingProvider(result); + setSaveComplete(true); + } catch (error) { + console.error("Error saving provider:", error); + if (error instanceof Error) { + setSaveError(error.message); + } else if (typeof error === 'object' && error !== null && 'message' in error) { + const errorWithMsg = error as ErrorWithMessage; + setSaveError(errorWithMsg.message); + } else { + setSaveError("An unknown error occurred"); + } + } + })(e); + } catch (error) { + console.error("Form submission error:", error); + if (error instanceof Error) { + setSaveError(error.message); + } else { + setSaveError("Form validation failed"); + } + } finally { + setIsSaving(false); + } + }; + + // Get the current provider to display + const displayProvider = savingProvider || selectedProvider; + + return ( + <> + + + {/* Provider Save Dialog */} + !isSaving && setIsDialogOpen(e.open)} + > + + + + + + + {saveError ? "Error Saving Provider" : isSaving ? "Saving Provider..." : saveComplete ? "Provider Saved" : "Processing..."} + + + + + + + + {saveError ? ( + + {saveError || "An unknown error occurred"} + + ) : isSaving ? ( + + + Saving provider configuration to server... + Please wait while we process your request + + ) : saveComplete && displayProvider ? ( + + + + Name: {displayProvider.name} + Kind: {displayProvider.kind} + Environment: {displayProvider.environment} + Base URL: {displayProvider.baseUrl} + {displayProvider.fetchModels && ( + Default Model: {displayProvider.defaultModel || "Not set"} + )} + + + + ) : ( + + Initializing save process... + + )} + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx new file mode 100644 index 00000000..7d07a319 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +import { Button } from "@chakra-ui/react"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; + +/** + * Button component for testing provider connection + */ +export function TestConnectionButton() { + const { isTestingConnection, handleTestConnection, selectedProvider } = useProviderFormContext(); + + return ( + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/index.ts new file mode 100644 index 00000000..db0723b8 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/index.ts @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +export * from './CancelButton'; +export * from './EditButton'; +export * from './ProviderSaveDialog'; +export * from './TestConnectionButton'; \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx index 7af6526d..5d0e952b 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx @@ -1,5 +1,5 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { useInferenceProviderContext } from "@/features/inference/providers/context"; +// SPDX-License-Identifier: Apache-2.0 import { Box, Field, @@ -9,15 +9,16 @@ import { Text, VStack, } from "@chakra-ui/react"; -// SPDX-License-Identifier: Apache-2.0 import { Controller } from "react-hook-form"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; import { EnvironmentSelect } from "./EnvironmentSelect"; /** * Component for displaying and editing basic provider information */ export function BasicInfoSection() { - const { control, errors, watch, isEditing } = useInferenceProviderContext(); + const { control, errors, watch, mode } = useProviderFormContext(); + const isEditing = mode === "edit" || mode === "create"; const labelColor = useColorModeValue("gray.600", "gray.300"); const textColor = useColorModeValue("gray.700", "gray.200"); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx index fc0a7e7d..40cee055 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx @@ -13,14 +13,15 @@ import { // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; import { Controller } from "react-hook-form"; -import { useInferenceProviderContext } from "../../../context"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; /** * Component for displaying and editing provider connection settings */ export function ConnectionSection() { - const { control, errors, watch, isEditing, selectedProvider } = - useInferenceProviderContext(); + const { control, errors, watch, mode, selectedProvider } = + useProviderFormContext(); + const isEditing = mode === "edit" || mode === "create"; const [showApiKey, setShowApiKey] = useState(false); const labelColor = useColorModeValue("gray.600", "gray.300"); const textColor = useColorModeValue("gray.700", "gray.200"); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx index 9cc1d18b..4a59a1d9 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx @@ -2,7 +2,7 @@ import { ActionButton } from "@/components/ui/buttons/ActionButton"; import { StatusMessage } from "@/components/ui/status/StatusMessage"; // SPDX-License-Identifier: Apache-2.0 import { Box } from "@chakra-ui/react"; -import { useInferenceProviderContext } from "../../../context"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; import { ModelSelector } from "./ModelSelector"; // Define the model type @@ -17,16 +17,18 @@ export interface ProviderModel { */ export function ModelSelectionSection() { const { - providerName, - selectedModelId, - setSelectedModelId, + selectedProvider, providerModelsData, isLoadingModels, isModelsError, modelsError, + selectedModelId, + setSelectedModelId, handleModelSelect, isSubmitting, - } = useInferenceProviderContext(); + } = useProviderFormContext(); + + const providerName = selectedProvider?.name; // Handle different states if (!providerName) { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormView.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormView.tsx deleted file mode 100644 index 48e54f6f..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormView.tsx +++ /dev/null @@ -1,138 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -import { Box, Button, Flex, Spinner, Text } from "@chakra-ui/react"; -import type { Provider, ProviderCreate, ProviderUpdate } from "../../../types"; -import { ProviderFormActions } from "../../ProviderFormActions"; -import { ProviderFormTabs } from "../../ProviderFormTabs"; -import { ProviderConnectionErrorDialog } from "../ProviderConnectionErrorDialog"; -import { ProviderConnectionSuccessDialog } from "../ProviderConnectionSuccessDialog"; - -interface ProviderFormViewProps { - mode: 'view' | 'edit' | 'create'; - isSubmitting: boolean; - saveSuccess: boolean; - isTestingConnection: boolean; - selectedProvider: Provider | null; - formError: unknown; - connectionError: Record | string | null; - connectionDetails: Record | null; - dialogs: { - error: boolean; - success: boolean; - formError: boolean; - }; - onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; - handleSubmit: (handler: (data: ProviderCreate | ProviderUpdate) => Promise) => (e: React.FormEvent) => void; - handleTestConnection: () => Promise; - setMode: (mode: 'view' | 'edit' | 'create') => void; - closeDialog: (dialog: 'error' | 'success' | 'formError') => void; -} - -/** - * Presentational component for the provider form - */ -export function ProviderFormView({ - mode, - isSubmitting, - saveSuccess, - isTestingConnection, - selectedProvider, - formError, - connectionError, - connectionDetails, - dialogs, - onSubmit, - handleSubmit, - handleTestConnection, - setMode, - closeDialog, -}: ProviderFormViewProps) { - const isEditing = mode === "edit"; - const isCreating = mode === "create"; - - return ( -
- - - - {/* Loading/Success Message */} - - {isSubmitting && ( - - - Saving changes... - - )} - - {!isSubmitting && saveSuccess && ( - - Provider saved successfully! - - )} - - - {/* Actions */} - - {isEditing || isCreating ? ( - - ) : ( - <> - - - - )} - - - {/* Form Error Dialog */} - closeDialog("formError")} - error={formError} - providerName={selectedProvider?.name || "Provider"} - /> - - {/* Connection Error Dialog */} - closeDialog("error")} - error={connectionError} - providerName={selectedProvider?.name || "Provider"} - /> - - {/* Success Dialog */} - closeDialog("success")} - providerName={selectedProvider?.name || "Provider"} - connectionDetails={connectionDetails} - /> - -
- ); -} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx index 981ce855..377f00d3 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx @@ -11,13 +11,14 @@ import { // SPDX-License-Identifier: Apache-2.0 import type { ChangeEvent } from "react"; import { Controller } from "react-hook-form"; -import { useInferenceProviderContext } from "../../../context"; +import { useProviderFormContext } from "../../../context/ProviderFormContext"; /** * Component for displaying and editing provider rate limits */ export function RateLimitsSection() { - const { control, errors, watch, isEditing } = useInferenceProviderContext(); + const { control, errors, watch, mode } = useProviderFormContext(); + const isEditing = mode === "edit" || mode === "create"; const labelColor = useColorModeValue("gray.600", "gray.300"); const textColor = useColorModeValue("gray.700", "gray.200"); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx new file mode 100644 index 00000000..a765b2b7 --- /dev/null +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: Apache-2.0 +import { useState } from "react"; +import type { ReactNode } from "react"; +import { useModelSelection } from "../../../hooks/useModelSelection"; +import { useProviderForm } from "../../../hooks/useProviderForm"; +import { useProviders } from "../../../services/providers"; +import { useTestProviderConnection } from "../../../services/providers"; +import { ProviderFormProvider } from "../../context/ProviderFormContext"; +import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderModelsResponse, ProviderUpdate } from "../../types"; +import { toServerConfig } from "../../types"; + +// Extended Error interface with cause property +interface ErrorWithCause extends Error { + cause?: unknown; +} + +// Model data type definition matching the context type exactly +interface ProviderModelData { + models: Array<{ id: string; name: string; is_default?: boolean }>; +} + +interface ProviderFormContainerProps { + children: ReactNode; + initialData?: Partial; + onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; +} + +/** + * Container component that provides the ProviderFormContext + */ +export function ProviderFormContainer({ + children, + initialData, + onSubmit: onSubmitProp, +}: ProviderFormContainerProps) { + // Form state from useProviderForm + const { + control, + handleSubmit: hookHandleSubmit, + errors, + watch, + providerName, + fetchModels, + defaultModel, + reset, + dialogs: formDialogs, + closeDialog: closeFormDialog, + savedProvider, + saveError, + onSubmit: hookOnSubmit, + updateApiKey, + isSubmitting, + } = useProviderForm(initialData); + + // Local state for the provider form + const [mode, setMode] = useState<"view" | "edit" | "create">("view"); + const [saveSuccess, setSaveSuccess] = useState(false); + const [isTestingConnection, setIsTestingConnection] = useState(false); + const [selectedProvider, setSelectedProvider] = useState( + initialData as Provider || null + ); + const [formError, setFormError] = useState(null); + const [connectionError, setConnectionError] = useState(null); + const [connectionDetails, setConnectionDetails] = useState(null); + const [dialogs, setDialogs] = useState({ + error: false, + success: false, + formError: false, + save: formDialogs.save || false, + }); + + // Fetch providers + const { data: providers = [] } = useProviders(); + + // Model selection - handle special case to fix type issues + // Extract values first but handle type casting + const modelSelectionHook = selectedProvider ? useModelSelection(selectedProvider) : null; + + // Now create properly typed values + const selectedModelId = modelSelectionHook ? modelSelectionHook.selectedModelId : null; + + // Explicitly type the model data + const providerModelsData: ProviderModelData | null = modelSelectionHook?.providerModelsData + ? { models: modelSelectionHook.providerModelsData.models || [] } + : null; + + // Type-safe setter function + const setSelectedModelId = (id: string | null) => { + if (modelSelectionHook?.setSelectedModelId && typeof id === 'string') { + modelSelectionHook.setSelectedModelId(id); + } + }; + + // Other model selection properties + const isLoadingModels = !!modelSelectionHook?.isLoadingModels; + const isModelsError = !!modelSelectionHook?.isModelsError; + const modelsError = modelSelectionHook?.modelsError || null; + const handleModelSelect = modelSelectionHook?.handleModelSelect || (() => {}); + + // API connection test hook + const testConnection = useTestProviderConnection(); + + // Function to close any dialog + const closeDialog = (dialog: "error" | "success" | "formError" | "save") => { + if (dialog === "save") { + closeFormDialog("save"); // Handle this through the form hook + } else { + setDialogs(prev => ({ ...prev, [dialog]: false })); + } + }; + + + // Handle form submission + const onSubmit = async (data: ProviderCreate | ProviderUpdate) => { + try { + setFormError(null); + setSaveSuccess(false); + await onSubmitProp(data); + setSaveSuccess(true); + + // Show the save dialog + setDialogs(prev => ({ ...prev, save: true })); + + // Reset success message after 3 seconds + setTimeout(() => { + setSaveSuccess(false); + }, 3000); + } catch (error) { + console.error("Provider form submission error:", error); + + // Convert error to ErrorDetails format + let errorObj: ErrorDetails; + if (error instanceof Error) { + errorObj = { + message: error.message, + code: error.name, + details: { + error: error.toString() + } + }; + + // Try to extract cause if it exists + const errorWithCause = error as ErrorWithCause; + if ('cause' in error && errorWithCause.cause !== undefined) { + errorObj.details = { + ...errorObj.details, + cause: errorWithCause.cause + }; + } + } else if (typeof error === 'object' && error !== null) { + errorObj = error as ErrorDetails; + } else { + errorObj = { + message: String(error), + details: { error } + }; + } + + setFormError(errorObj); + setDialogs(prev => ({ ...prev, formError: true })); + } + }; + + // Form submission handler + const handleSubmit = (handler: (data: ProviderCreate | ProviderUpdate) => Promise) => { + return hookHandleSubmit(async (data) => { + try { + await handler(data); + } catch (error) { + console.error("Form submission error:", error); + } + }); + }; + + // Test connection handler + const handleTestConnection = async () => { + if (!selectedProvider) return; + + // Validate API key is present + if (!selectedProvider.apiKey) { + setConnectionError({ + message: "API key is required", + name: "ValidationError", + details: { + reason: "Please provide an API key in the provider configuration." + }, + suggestions: [ + "Edit the provider to add an API key", + "API keys should be non-empty strings", + ], + } as ErrorDetails); + setDialogs(prev => ({ ...prev, error: true })); + return; + } + + setIsTestingConnection(true); + setConnectionError(null); + + try { + const config = toServerConfig(selectedProvider); + const result = await testConnection.mutateAsync({ + providerName: selectedProvider.name, + config, + }); + + setConnectionDetails(result); + setDialogs(prev => ({ ...prev, success: true })); + } catch (error) { + console.error("Connection test failed:", error); + + // Convert error to ErrorDetails format + let errorObj: ErrorDetails; + if (error instanceof Error) { + errorObj = { + message: error.message, + code: error.name, + details: { + error: error.toString() + } + }; + + // Try to extract cause if it exists + const errorWithCause = error as ErrorWithCause; + if ('cause' in error && errorWithCause.cause !== undefined) { + errorObj.details = { + ...errorObj.details, + cause: errorWithCause.cause + }; + } + } else if (typeof error === 'object' && error !== null) { + errorObj = error as ErrorDetails; + } else { + errorObj = { + message: String(error), + details: { error } + }; + } + + setConnectionError(errorObj); + setDialogs(prev => ({ ...prev, error: true })); + } finally { + setIsTestingConnection(false); + } + }; + + return ( + { + reset(); + setMode("view"); + }, + handleSubmit, + handleTestConnection, + setMode, + closeDialog, + }} + > + {children} + + ); +} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts index 138bbceb..afe84a05 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts @@ -2,7 +2,7 @@ import { useState } from "react"; import { useTestProviderConnection } from "../../../services/providers"; import { useInferenceProviderContext } from "../../context"; -import { type Provider, type ProviderCreate, toServerConfig } from "../../types"; +import { type Provider, toServerConfig } from "../../types"; interface UseProviderConnectionResult { isTestingConnection: boolean; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts index eb9a0f8a..989bcb24 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/index.ts @@ -1,4 +1,4 @@ // SPDX-License-Identifier: Apache-2.0 -export { default as ProviderForm } from './ProviderForm'; -export { useProviderForm } from './context/useProviderForm'; -export { ProviderFormView } from './components/form/ProviderFormView'; \ No newline at end of file +export { useProviderForm } from '../../hooks/useProviderForm'; +export { ProviderFormView } from './components/ProviderFormView'; +export { ProviderConnection } from './component'; \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx index 3aa08713..e5b9a0f9 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 import { ProviderSelect } from "./ProviderConnection/components/ProviderSelect"; -import { useProviderFormContext } from "./context"; +import { useProviderFormContext } from "./context/ProviderFormContext"; /** diff --git a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx index b6678318..481dd0d0 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx @@ -1,10 +1,10 @@ import { useColorMode } from "@/components/ui/theme/color-mode"; import { Box, Button, Center, Flex, Text, VStack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { useCallback, useMemo, useState } from "react"; -import type { ProviderCreate, ProviderUpdate } from "../providers/types"; +import { useMemo } from "react"; +import type { Provider } from "../providers/types"; import { useProviders } from "../services/providers"; -import ProviderForm from "./ProviderConnection/ProviderForm"; +import { ProviderConnection } from "./ProviderConnection"; import { ProviderSelect } from "./ProviderConnection/components/ProviderSelect"; import { InferenceProviderProvider, @@ -15,9 +15,7 @@ import { * Panel content that requires context */ function PanelContent() { - const { setMode, providers } = - useInferenceProviderContext(); - + const { setMode, providers, selectedProvider } = useInferenceProviderContext(); const { colorMode } = useColorMode(); const textColor = colorMode === "light" ? "gray.600" : "gray.300"; const borderColor = colorMode === "light" ? "gray.200" : "gray.700"; @@ -30,7 +28,7 @@ function PanelContent() { + +
+ + {/* Provider Form Tabs */} { + setMode("create"); + }; + + return ( + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx index 6a9279a0..b5f01f31 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx @@ -9,7 +9,7 @@ import { VStack, } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { useEffect, useState } from "react"; +import { useState } from "react"; import { useCreateProvider, useUpdateProvider } from "../../../../services/providers"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; import type { Provider, ProviderCreate, ProviderUpdate } from "../../../types"; @@ -26,16 +26,13 @@ interface ErrorWithMessage { export function SaveButton() { const { isSubmitting: isContextSubmitting, - isCreating, + isCreating, + mode, handleSubmit, selectedProvider, saveError: contextSaveError } = useProviderFormContext(); - // API mutation hooks - const createProvider = useCreateProvider(); - const updateProvider = useUpdateProvider(); - // Local state for dialog visibility and save state const [isDialogOpen, setIsDialogOpen] = useState(false); const [isSaving, setIsSaving] = useState(false); @@ -43,15 +40,20 @@ export function SaveButton() { const [savingProvider, setSavingProvider] = useState(null); const [saveError, setSaveError] = useState(contextSaveError); + // Get provider service functions + const { mutateAsync: createProviderAsync, isPending: isCreatingProvider } = useCreateProvider(); + const { mutateAsync: updateProviderAsync, isPending: isUpdatingProvider } = useUpdateProvider(); + // Determine if form is submitting - const isSubmitting = isContextSubmitting || isSaving; + const isSubmitting = isContextSubmitting || isSaving || isCreatingProvider || isUpdatingProvider; // Determine the button text based on form state - const buttonText = isSubmitting - ? "Saving..." - : isCreating - ? "Create" - : "Save"; + let buttonText = "Save"; + if (isSubmitting) { + buttonText = "Saving..."; + } else if (isCreating) { + buttonText = "Create"; + } // Function to close the dialog const closeDialog = () => { @@ -61,48 +63,119 @@ export function SaveButton() { setSaveError(undefined); }; - // Custom submit handler that shows the dialog and directly calls API + // Helper function to normalize rate limits + const normalizeRateLimits = (data: ProviderCreate | ProviderUpdate) => { + if (!data.rateLimits) return; + + if (Array.isArray(data.rateLimits) || typeof data.rateLimits !== 'object') { + const currentRateLimits = data.rateLimits as unknown; + const requestsPerMinute = typeof currentRateLimits === 'object' && currentRateLimits !== null + ? (currentRateLimits as Record).requestsPerMinute as number ?? 0 + : 0; + const tokensPerMinute = typeof currentRateLimits === 'object' && currentRateLimits !== null + ? (currentRateLimits as Record).tokensPerMinute as number ?? 0 + : 0; + + data.rateLimits = { requestsPerMinute, tokensPerMinute }; + } + }; + + // Handle form submission errors + const handleSaveError = (error: unknown) => { + console.error("Error saving provider:", error); + if (error instanceof Error) { + setSaveError(error.message); + } else if (typeof error === 'object' && error !== null && 'message' in error) { + const errorWithMsg = error as ErrorWithMessage; + setSaveError(errorWithMsg.message); + } else { + setSaveError("An unknown error occurred"); + } + }; + + // Save the provider using the appropriate service function + const saveProvider = async (data: ProviderCreate | ProviderUpdate): Promise => { + // Make sure API key is included in update requests + if (!data.apiKey && selectedProvider?.apiKey && mode === 'edit') { + console.log("Including existing API key in update"); + data.apiKey = selectedProvider.apiKey; + } + + // Log the full data we're sending (redact the actual API key) + console.log("Sending to server:", { + ...data, + apiKey: data.apiKey ? "[PRESENT]" : "[MISSING]", + mode + }); + + try { + // Edit mode with selected provider + if (mode === "edit" && selectedProvider?.id) { + console.log(`Updating provider with id ${selectedProvider.id}`); + const result = await updateProviderAsync({ + id: selectedProvider.id, + data: data as ProviderUpdate + }); + console.log("Provider updated successfully:", result); + return result; + } + + // Create mode + if (mode === "create") { + console.log("Creating new provider"); + const result = await createProviderAsync(data as ProviderCreate); + console.log("Provider created successfully:", result); + return result; + } + + // Fallback path - has ID in data + if ('id' in data && data.id) { + const id = data.id as number; + console.log(`Updating provider with id ${id}`); + const result = await updateProviderAsync({ id, data }); + console.log("Provider updated successfully:", result); + return result; + } + + // Default create path + console.log("Creating new provider (fallback path)"); + const result = await createProviderAsync(data as ProviderCreate); + console.log("Provider created successfully:", result); + return result; + } catch (error) { + handleSaveError(error); + throw error; + } + }; + + // Custom submit handler that shows the dialog and processes the form const handleFormSubmit = async (e: React.FormEvent) => { try { setIsSaving(true); setIsDialogOpen(true); setSaveError(undefined); - // Use the form's handleSubmit to get validated data - await handleSubmit(async (data: ProviderCreate | ProviderUpdate) => { + // Process the form submission through the form's handleSubmit + const formHandler = handleSubmit(async (data: ProviderCreate | ProviderUpdate) => { try { console.log("Provider form submitted:", data); - // Determine if we're creating or updating based on presence of id - let result: Provider; - if ('id' in data && data.id) { - // Update existing provider - const id = data.id as number; - console.log(`Updating provider with id ${id}`); - result = await updateProvider.mutateAsync({ id, data }); - console.log("Provider updated successfully:", result); - } else { - // Create new provider - console.log("Creating new provider"); - result = await createProvider.mutateAsync(data as ProviderCreate); - console.log("Provider created successfully:", result); - } + // Normalize rateLimits - ensure it's an object, not an array + normalizeRateLimits(data); + + // Save the provider + const result = await saveProvider(data); // Success - store the provider details and mark as complete setSavingProvider(result); setSaveComplete(true); } catch (error) { - console.error("Error saving provider:", error); - if (error instanceof Error) { - setSaveError(error.message); - } else if (typeof error === 'object' && error !== null && 'message' in error) { - const errorWithMsg = error as ErrorWithMessage; - setSaveError(errorWithMsg.message); - } else { - setSaveError("An unknown error occurred"); - } + handleSaveError(error); } - })(e); + }); + + // Execute the form handler + formHandler(e); } catch (error) { console.error("Form submission error:", error); if (error instanceof Error) { @@ -118,6 +191,61 @@ export function SaveButton() { // Get the current provider to display const displayProvider = savingProvider || selectedProvider; + // Determine dialog title + let dialogTitle = "Processing..."; + if (saveError) { + dialogTitle = "Error Saving Provider"; + } else if (isSaving) { + dialogTitle = "Saving Provider..."; + } else if (saveComplete) { + dialogTitle = "Provider Saved"; + } + + // Render dialog body content based on state + const renderDialogBody = () => { + if (saveError) { + return ( + + {saveError || "An unknown error occurred"} + + ); + } + + if (isSaving) { + return ( + + + Saving provider configuration to server... + Please wait while we process your request + + ); + } + + if (saveComplete && displayProvider) { + return ( + + + + Name: {displayProvider.name} + Kind: {displayProvider.kind} + Environment: {displayProvider.environment} + Base URL: {displayProvider.baseUrl} + {displayProvider.fetchModels && ( + Default Model: {displayProvider.defaultModel ?? "Not set"} + )} + + + + ); + } + + return ( + + Initializing save process... + + ); + }; + return ( <> -
- ); - } - return ( - {/* Header */} - - {/* Provider Selection Dropdown */} - - - - - - {/* Content */} - + ); @@ -70,8 +29,8 @@ function PanelContent() { /** * Providers Panel Component * - * This component displays a list of providers and allows viewing and editing - * provider configurations. + * This component displays provider configurations in a panel. + * It acts as a container for the provider connection form. */ export function ProvidersPanel() { const { @@ -112,9 +71,7 @@ export function ProvidersPanel() { providers={providersData} selectedProvider={initialSelectedProvider} isCreating={false} - onSubmit={() => {}} onCancel={() => {}} - isSubmitting={false} > diff --git a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx index 753cbe48..5e814dc5 100644 --- a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx @@ -10,6 +10,7 @@ interface ProviderFormContextType { saveSuccess: boolean; isTestingConnection: boolean; selectedProvider: Provider | null; + setSelectedProvider: (provider: Provider | null) => void; formError: ErrorDetails | null; connectionError: ErrorDetails | null; connectionDetails: ConnectionDetails | null; diff --git a/graphcap_studio/src/features/inference/providers/index.ts b/graphcap_studio/src/features/inference/providers/index.ts index 188f1e7f..d20e2a33 100644 --- a/graphcap_studio/src/features/inference/providers/index.ts +++ b/graphcap_studio/src/features/inference/providers/index.ts @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -export { ProvidersPanel } from "./ProvidersPanel"; -export { default as ProvidersList } from "./ProvidersList"; -export * from "../hooks"; -export * from "./context"; +export * from './ProviderConnection'; +export * from './ProvidersPanel'; +export * from './context'; +export * from './types'; diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 9c1434ab..f59176c1 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -126,7 +126,30 @@ export function useCreateProvider() { }); if (!response.ok) { - throw new Error(`Failed to create provider: ${response.status}`); + // Try to get detailed error information + try { + const errorData = await response.json(); + console.error("Provider creation error:", errorData); + + // Check if we have a structured error response + if (errorData.status === 'error' || errorData.validationErrors) { + throw errorData; + } + + // Simple error with a message + if (errorData.message) { + throw new Error(errorData.message); + } + + // Fallback error + throw new Error(`Failed to create provider: ${response.status}`); + } catch (parseError) { + // If we can't parse the error as JSON, throw a general error + if (parseError instanceof Error && parseError.message !== 'Failed to create provider') { + throw parseError; + } + throw new Error(`Failed to create provider: ${response.status}`); + } } return response.json() as Promise; @@ -147,6 +170,7 @@ export function useUpdateProvider() { return useMutation({ mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { + console.log("Updating provider with data:", data); const client = createDataServiceClient(connections); const response = await client.providers[":id"].$put({ param: { id: id.toString() }, diff --git a/graphcap_studio/src/features/server-connections/services/apiClients.ts b/graphcap_studio/src/features/server-connections/services/apiClients.ts index a77f51a1..ce83c056 100644 --- a/graphcap_studio/src/features/server-connections/services/apiClients.ts +++ b/graphcap_studio/src/features/server-connections/services/apiClients.ts @@ -23,12 +23,6 @@ export interface DataServiceClient { json: unknown; }) => Promise; $delete: (options: { param: { id: string } }) => Promise; - "api-key": { - $put: (options: { - param: { id: string }; - json: unknown; - }) => Promise; - }; }; }; } diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index 7db9cda1..f5a98b6d 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -97,7 +97,30 @@ export function useCreateProvider() { }); if (!response.ok) { - throw new Error(`Failed to create provider: ${response.status}`); + // Try to get detailed error information + try { + const errorData = await response.json(); + console.error("Provider creation error:", errorData); + + // Check if we have a structured error response + if (errorData.status === 'error' || errorData.validationErrors) { + throw errorData; + } + + // Simple error with a message + if (errorData.message) { + throw new Error(errorData.message); + } + + // Fallback error + throw new Error(`Failed to create provider: ${response.status}`); + } catch (parseError) { + // If we can't parse the error as JSON, throw a general error + if (parseError instanceof Error && parseError.message !== 'Failed to create provider') { + throw parseError; + } + throw new Error(`Failed to create provider: ${response.status}`); + } } return response.json() as Promise; @@ -118,18 +141,15 @@ export function useUpdateProvider() { return useMutation({ mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { - // Filter out null values and skip apiKey property completely + console.log("Updating provider with data:", data); const updateData = Object.entries(data).reduce((acc, [key, value]) => { - // Skip apiKey completely - it has its own endpoint - if (key === 'apiKey') return acc; - // Only include defined values if (value !== null && value !== undefined) { acc[key] = value; } return acc; }, {} as Record); - + console.log("updateData", updateData); const client = createDataServiceClient(connections); const response = await client.providers[":id"].$put({ param: { id: id.toString() }, diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index d211219a..ad161a79 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -9,7 +9,7 @@ import { eq } from "drizzle-orm"; import type { Context } from "hono"; import { db } from "../../db"; import { providerModels, providerRateLimits, providers } from "../../db/schema"; -import { encryptApiKey } from "../../utils/encryption"; +import { decryptApiKey, encryptApiKey } from "../../utils/encryption"; import { logger } from "../../utils/logger"; import type { ProviderApiKey, @@ -37,6 +37,23 @@ export const getProviders = async (c: Context) => { }, }); + // Decrypt API keys before returning to client + for (const provider of allProviders) { + if (provider.apiKey) { + logger.debug({ providerId: provider.id }, "Decrypting API key for provider"); + provider.apiKey = await decryptApiKey(provider.apiKey); + + // Log whether API key is present after decryption (without showing the actual key) + logger.debug({ + providerId: provider.id, + apiKeyPresent: provider.apiKey ? true : false, + apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 + }, "Provider API key decryption result"); + } else { + logger.debug({ providerId: provider.id }, "No API key to decrypt for provider"); + } + } + logger.info( { count: allProviders.length }, "Providers fetched successfully", @@ -90,12 +107,32 @@ export const getProvider = async (c: Context) => { return c.json({ error: "Provider not found" }, 404); } + // Decrypt API key before returning to client + if (provider.apiKey) { + logger.debug({ + providerId: id, + encryptedKeyLength: provider.apiKey.length + }, "Decrypting API key for provider"); + + provider.apiKey = await decryptApiKey(provider.apiKey); + + // Log the result of decryption (without showing the actual key) + logger.debug({ + providerId: id, + apiKeyPresent: provider.apiKey ? true : false, + apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 + }, "Provider API key decryption result"); + } else { + logger.debug({ providerId: id }, "No API key to decrypt for provider"); + } + logger.info({ providerId: id }, "Provider fetched successfully"); return c.json(provider); - } else { - logger.warn({ providerId: id }, "Provider not found"); - return c.json({ error: "Provider not found" }, 404); } + + // If ID mismatch, return not found (removed else clause) + logger.warn({ providerId: id }, "Provider not found"); + return c.json({ error: "Provider not found" }, 404); } catch (error) { logger.error({ error, providerId: id }, "Error fetching provider"); return c.json({ error: "Failed to fetch provider" }, 500); @@ -289,11 +326,21 @@ export const updateProvider = async (c: Context) => { try { // @ts-ignore - Hono OpenAPI validation types are not properly recognized const data = c.req.valid("json") as ProviderUpdate; - logger.debug({ id, data }, "Updating provider"); + logger.debug({ + id, + data: { + ...data, + apiKey: data.apiKey !== undefined ? '[PRESENT]' : '[MISSING]' + } + }, "Updating provider"); // Check if provider exists const existingProvider = await db.query.providers.findFirst({ where: eq(providers.id, Number.parseInt(id)), + with: { + models: true, + rateLimits: true, + }, }); if (!existingProvider) { @@ -325,13 +372,118 @@ export const updateProvider = async (c: Context) => { // Extract models and rate limits if provided const { models, rateLimits, ...providerData } = data; + // LOG API KEY STATUS FOR DEBUGGING + logger.debug({ + providerId: id, + original_apiKey_present: existingProvider.apiKey !== null, + update_apiKey_present: 'apiKey' in providerData, + update_apiKey_value_present: providerData.apiKey !== undefined && providerData.apiKey !== null + }, "API key update status"); + + // Log what fields are being updated + const updatedFields: Record = {}; + + // Compare each field being updated with existing values + for (const [key, value] of Object.entries(providerData)) { + const existingValue = (existingProvider as Record)[key]; + // Only log if the value is actually changing + if (existingValue !== value && value !== undefined) { + // Special handling for API key to avoid logging actual values + if (key === 'apiKey') { + updatedFields[key] = { + from: existingValue ? '[ENCRYPTED]' : '[EMPTY]', + to: value ? '[NEW_VALUE]' : '[EMPTY]' + }; + logger.info( + { providerId: id }, + `Updating API key from ${existingValue ? 'existing value' : 'empty'} to ${value ? 'new value' : 'empty'}` + ); + } else { + updatedFields[key] = { from: existingValue, to: value }; + } + } + } + + // Log all field changes + if (Object.keys(updatedFields).length > 0) { + logger.info({ + providerId: id, + provider: existingProvider.name, + updatedFields + }, "Provider fields being updated"); + } + + // Log model changes if applicable + if (models && models.length > 0) { + // Need to query specifically for models since they might not be included in existingProvider + const existingModelsQuery = await db.query.providerModels.findMany({ + where: eq(providerModels.providerId, Number.parseInt(id)) + }); + + logger.info({ + providerId: id, + provider: existingProvider.name, + existingModelsCount: existingModelsQuery.length, + newModelsCount: models.length, + modelNames: models.map(m => m.name) + }, "Updating provider models"); + } + + // Log rate limit changes if applicable + if (rateLimits) { + // Query for existing rate limits + const existingRateLimitsQuery = await db.query.providerRateLimits.findFirst({ + where: eq(providerRateLimits.providerId, Number.parseInt(id)) + }); + + logger.info({ + providerId: id, + provider: existingProvider.name, + existingRateLimits: existingRateLimitsQuery + ? { + requestsPerMinute: existingRateLimitsQuery.requestsPerMinute, + tokensPerMinute: existingRateLimitsQuery.tokensPerMinute + } + : { requestsPerMinute: null, tokensPerMinute: null }, + newRateLimits: rateLimits + }, "Updating provider rate limits"); + } + // Start a transaction const result = await db.transaction(async (tx) => { - // Update provider + // Get the current provider from the database to ensure we have the latest data + const currentProvider = await tx.query.providers.findFirst({ + where: eq(providers.id, Number.parseInt(id)), + }); + + if (!currentProvider) { + throw new Error(`Provider not found with id ${id}`); + } + + // CRITICAL FIX: Handle API key specially to avoid losing it + let apiKeyToUse = currentProvider.apiKey; // Default to keeping existing key + + // Only update the API key if it's explicitly included in the update data + if ('apiKey' in providerData && providerData.apiKey !== undefined) { + if (providerData.apiKey) { + logger.debug({ providerId: id }, "Encrypting new API key for provider update"); + apiKeyToUse = await encryptApiKey(providerData.apiKey as string); + logger.info({ providerId: id }, "API key encrypted for update"); + } else { + // If apiKey is explicitly set to an empty value, only then clear it + logger.debug({ providerId: id }, "API key explicitly cleared in update"); + apiKeyToUse = null; + } + } else { + logger.debug({ providerId: id }, "Keeping existing API key - no change requested"); + } + + // Update provider with the appropriate API key await tx .update(providers) .set({ ...providerData, + apiKey: apiKeyToUse, // Use the properly determined API key updatedAt: new Date(), }) .where(eq(providers.id, Number.parseInt(id))); diff --git a/servers/data_service/src/features/provider_config/schemas.ts b/servers/data_service/src/features/provider_config/schemas.ts index 5b0dbcf6..00c82e42 100644 --- a/servers/data_service/src/features/provider_config/schemas.ts +++ b/servers/data_service/src/features/provider_config/schemas.ts @@ -64,6 +64,7 @@ export const providerUpdateSchema = z.object({ kind: z.string().min(1, 'Kind is required').optional(), environment: z.enum(['cloud', 'local']).optional(), baseUrl: z.string().url('Must be a valid URL').optional(), + apiKey: z.string().optional(), isEnabled: z.boolean().optional(), models: z.array( z.object({ @@ -78,17 +79,8 @@ export const providerUpdateSchema = z.object({ }).optional(), }); -// Schema for updating a provider's API key -export const providerApiKeySchema = z.object({ - apiKey: z.string() - .min(1, { message: 'API key is required and cannot be empty' }) - .refine(val => val.trim() !== '', { - message: 'API key cannot be just whitespace' - }), -}); // Export types export type Provider = z.infer; export type ProviderCreate = z.infer; -export type ProviderUpdate = z.infer; -export type ProviderApiKey = z.infer; \ No newline at end of file +export type ProviderUpdate = z.infer; \ No newline at end of file From 2e4f5b9e64dcccda380eb19330a4e45cca4f133c Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 13:26:20 -0500 Subject: [PATCH 20/69] Provider refactor, cleanup 1 Signed-off-by: jphillips --- .../components/ProviderActions.tsx | 1 + .../components/form/EnvironmentSelect.tsx | 13 +- .../inference/providers/ProvidersList.tsx | 2 - .../inference/providers/ProvidersPanel.tsx | 2 - .../components/ConnectionActionButton.tsx | 2 +- .../components/ConnectionStatusIndicator.tsx | 2 +- .../components/ServerConnectionsPanel.tsx | 2 +- .../server-connections/services/providers.ts | 53 ---- .../data_service/src/utils/pino-middleware.ts | 227 +++++++++++------- 9 files changed, 151 insertions(+), 153 deletions(-) diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx index 82b64822..d44f8482 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderActions.tsx @@ -15,6 +15,7 @@ export function ProviderActions() { return ( + diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx index 419ba708..9c5b8355 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/EnvironmentSelect.tsx @@ -3,6 +3,7 @@ import { SelectItem, SelectRoot, SelectTrigger, + SelectValueText, } from "@/components/ui/select"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { Field, createListCollection } from "@chakra-ui/react"; @@ -34,10 +35,18 @@ export function EnvironmentSelect() { field.onChange(value)} + onValueChange={(details) => { + if (details.value && details.value.length > 0) { + field.onChange(details.value[0]); + } else { + field.onChange(""); + } + }} collection={collection} > - {field.value || "Select environment"} + + + {environmentItems.map((item) => ( diff --git a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx index fd9161f3..6bf48ddd 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx @@ -2,7 +2,6 @@ import { useProviders } from "../services/providers"; import { ProviderFormSelect } from "./ProviderConnection/components/form/ProviderFormSelect"; import { ProviderFormContainer } from "./ProviderConnection/containers/ProviderFormContainer"; -import { useInferenceProviderContext } from "./context/InferenceProviderContext"; import type { ProviderCreate, ProviderUpdate } from "./types"; /** @@ -11,7 +10,6 @@ import type { ProviderCreate, ProviderUpdate } from "./types"; */ export default function ProvidersList() { const { data: providers = [], isLoading } = useProviders(); - const { setSelectedProvider } = useInferenceProviderContext(); // No need to check context.providers as we fetch directly here if (providers.length === 0 && !isLoading) { diff --git a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx index 0944a975..eba6e435 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx @@ -2,7 +2,6 @@ import { useColorMode } from "@/components/ui/theme/color-mode"; import { Box, Center, Flex, Text } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { useMemo } from "react"; -import type { Provider } from "../providers/types"; import { useProviders } from "../services/providers"; import { ProviderConnection } from "./ProviderConnection"; import { @@ -14,7 +13,6 @@ import { */ function PanelContent() { const { colorMode } = useColorMode(); - const borderColor = colorMode === "light" ? "gray.200" : "gray.700"; return ( diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx index 9db05a2f..6e55fe11 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx @@ -1,5 +1,5 @@ import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import { ConnectionActionButtonProps } from "@/features/server-connections/types"; +import type { ConnectionActionButtonProps } from "@/features/server-connections/types"; import { Button } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx index b746d85b..a91b5680 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx @@ -1,6 +1,6 @@ import { Status } from "@/components/ui/status"; import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import { ConnectionStatusIndicatorProps } from "@/features/server-connections/types"; +import type { ConnectionStatusIndicatorProps } from "@/features/server-connections/types"; // SPDX-License-Identifier: Apache-2.0 import { memo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx index 87d14ddb..9dd57e36 100644 --- a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx +++ b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx @@ -1,7 +1,7 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import { ServerConnectionsPanelProps } from "@/features/server-connections/types"; +import type { ServerConnectionsPanelProps } from "@/features/server-connections/types"; import { Box, Button, Flex, Heading, Spinner, Stack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo, useMemo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index f5a98b6d..05187118 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -9,11 +9,9 @@ import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import type { Provider, - ProviderApiKey, ProviderCreate, ProviderModelsResponse, ProviderUpdate, - ServerProviderConfig, SuccessResponse, } from "@/features/inference/providers/types"; import { toServerConfig } from "@/features/inference/providers/types"; @@ -223,57 +221,6 @@ export function useDeleteProvider() { }); } -/** - * Hook to update a provider's API key - */ -export function useUpdateProviderApiKey() { - const queryClient = useQueryClient(); - const { connections } = useServerConnectionsContext(); - - return useMutation({ - mutationFn: async ({ id, apiKey }: { id: number; apiKey: string }) => { - const client = createDataServiceClient(connections); - const response = await client.providers[":id"]["api-key"].$put({ - param: { id: id.toString() }, - json: { apiKey } as ProviderApiKey, - }); - - if (!response.ok) { - // Try to get detailed error information - try { - const errorData = await response.json(); - console.error("API key update error:", errorData); - - // Check if we have a structured error response - if (errorData.status === 'error' || errorData.validationErrors) { - throw errorData; - } - - // Simple error with a message - if (errorData.message) { - throw new Error(errorData.message); - } - - // Fallback error - throw new Error(`Failed to update API key: ${response.status}`); - } catch (parseError) { - // If we can't parse the error as JSON, throw a general error - if (parseError instanceof Error && parseError.message !== 'Failed to update API key') { - throw parseError; - } - throw new Error(`Failed to update API key: ${response.status}`); - } - } - - return response.json() as Promise; - }, - onSuccess: (_, { id }) => { - // Invalidate specific provider query - queryClient.invalidateQueries({ queryKey: queryKeys.provider(id) }); - }, - }); -} - /** * Hook to get available models for a provider */ diff --git a/servers/data_service/src/utils/pino-middleware.ts b/servers/data_service/src/utils/pino-middleware.ts index cfd242c4..0f615716 100644 --- a/servers/data_service/src/utils/pino-middleware.ts +++ b/servers/data_service/src/utils/pino-middleware.ts @@ -30,6 +30,134 @@ export const createPinoLoggerMiddleware = () => { }); }; +/** + * Get safe query parameters with error handling + */ +const getSafeQueryParams = (c: Context): Record => { + try { + return c.req.query(); + } catch (e) { + logger.debug({ error: e }, "Failed to get query parameters"); + return {}; // Fallback to empty object if query() throws an error + } +}; + +/** + * Extract and parse request body based on content type + */ +const parseRequestBody = async (clonedReq: Request, contentType: string): Promise => { + if (contentType.includes("application/json")) { + try { + return await clonedReq.json(); + } catch (jsonError) { + logger.debug({ error: jsonError }, "Failed to parse JSON body"); + return "[Unparseable JSON]"; + } + } + + if (contentType.includes("multipart/form-data")) { + return "[Multipart form data]"; + } + + if (contentType.includes("application/x-www-form-urlencoded")) { + try { + return Object.fromEntries(await clonedReq.formData()); + } catch (formError) { + logger.debug({ error: formError }, "Failed to parse form data"); + return "[Unparseable form data]"; + } + } + + // Default to text handling + try { + const textBody = await clonedReq.text(); + return textBody.length > 1000 ? `${textBody.substring(0, 1000)}... [truncated]` : textBody; + } catch (textError) { + logger.debug({ error: textError }, "Failed to get text body"); + return "[Unreadable text body]"; + } +}; + +/** + * Extract request body with proper error handling + */ +const getRequestBody = async (c: Context, method: string): Promise<[unknown, boolean]> => { + // Skip for GET and HEAD requests + if (method === "GET" || method === "HEAD") { + return [null, true]; + } + + try { + // Check if the request can be cloned + if (c.req.raw.clone && typeof c.req.raw.clone === 'function') { + const clonedReq = c.req.raw.clone(); + const contentType = c.req.header("content-type") || ""; + const body = await parseRequestBody(clonedReq, contentType); + return [body, true]; + } + + // If request cloning is not supported + logger.debug("Request body logging skipped - Request.clone() not supported"); + return ["[Body logging disabled - clone not supported]", false]; + } catch (e) { + logger.debug({ error: e }, "Error while attempting to read request body"); + return ["[Error reading request body]", true]; + } +}; + +/** + * Log request information + */ +const logRequest = ( + method: string, + url: string, + path: string, + queryParams: Record, + headers: Record, + body: unknown, + bodyReadable: boolean +) => { + logger.info( + { + type: "request", + method, + url, + path, + query: queryParams, + headers, + body, + bodyReadable, + }, + "API Request", + ); +}; + +/** + * Log response information + */ +const logResponse = ( + method: string, + url: string, + path: string, + status: number | undefined, + headers: Headers | undefined, + responseTime: number +) => { + logger.info( + { + type: "response", + method, + url, + path, + status, + headers, + responseTime, + body: "[Response body not captured]", + }, + "API Response", + ); +}; + /** * Detailed logging middleware * @@ -38,107 +166,24 @@ export const createPinoLoggerMiddleware = () => { */ export const createDetailedLoggingMiddleware = () => { return async (c: Context, next: () => Promise) => { - // Log request details before processing + // Extract basic request information const { method } = c.req; const url = c.req.url; const path = c.req.path; - // Safely get query parameters - let queryParams: Record; - try { - queryParams = c.req.query(); - } catch (e) { - queryParams = {}; // Fallback to empty object if query() throws an error - logger.debug({ error: e }, "Failed to get query parameters"); - } - - // Try to get the request body if not a GET or HEAD request - let reqBody: unknown = null; - let bodyReadable = true; - if (method !== "GET" && method !== "HEAD") { - try { - // Check if the request can be cloned (only works in certain environments) - if (c.req.raw.clone && typeof c.req.raw.clone === 'function') { - // Clone the request to read the body without consuming it - const clonedReq = c.req.raw.clone(); - const contentType = c.req.header("content-type") || ""; - - if (contentType.includes("application/json")) { - try { - reqBody = await clonedReq.json(); - } catch (jsonError) { - logger.debug({ error: jsonError }, "Failed to parse JSON body"); - reqBody = "[Unparseable JSON]"; - } - } else if (contentType.includes("multipart/form-data")) { - reqBody = "[Multipart form data]"; - } else if (contentType.includes("application/x-www-form-urlencoded")) { - try { - reqBody = Object.fromEntries(await clonedReq.formData()); - } catch (formError) { - logger.debug({ error: formError }, "Failed to parse form data"); - reqBody = "[Unparseable form data]"; - } - } else { - try { - const textBody = await clonedReq.text(); - reqBody = textBody.length > 1000 ? `${textBody.substring(0, 1000)}... [truncated]` : textBody; - } catch (textError) { - logger.debug({ error: textError }, "Failed to get text body"); - reqBody = "[Unreadable text body]"; - } - } - } else { - // If request cloning is not supported, don't attempt to read the body - bodyReadable = false; - reqBody = "[Body logging disabled - clone not supported]"; - logger.debug("Request body logging skipped - Request.clone() not supported"); - } - } catch (e) { - logger.debug({ error: e }, "Error while attempting to read request body"); - reqBody = "[Error reading request body]"; - } - } - + // Get request components + const queryParams = getSafeQueryParams(c); + const [reqBody, bodyReadable] = await getRequestBody(c, method); + // Log the request - logger.info( - { - type: "request", - method, - url, - path, - query: queryParams, - headers: c.req.header(), - body: reqBody, - bodyReadable, - }, - "API Request", - ); + logRequest(method, url, path, queryParams, c.req.header(), reqBody, bodyReadable); - // Process the request + // Process the request and measure response time const startTime = Date.now(); await next(); const responseTime = Date.now() - startTime; // Log response details - // We don't attempt to capture the response body to avoid interference - // Response body capture requires special handling at the route level - const resStatus = c.res?.status; - const resHeaders = c.res?.headers; - - // Log the response - logger.info( - { - type: "response", - method, - url, - path, - status: resStatus, - headers: resHeaders, - responseTime, - body: "[Response body not captured]", - }, - "API Response", - ); + logResponse(method, url, path, c.res?.status, c.res?.headers, responseTime); }; }; From be3d69ff3cb46afc1522555038cd49ba9b0745bc Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 13:41:48 -0500 Subject: [PATCH 21/69] Fix models call in provider form Signed-off-by: jphillips --- .../src/features/inference/services/providers.ts | 3 ++- .../features/server-connections/services/apiClients.ts | 3 --- .../features/server-connections/services/providers.ts | 10 ++++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index f59176c1..33db33e4 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -298,7 +298,8 @@ export function useProviderModels(provider: Provider) { const client = createInferenceBridgeClient(connections); const serverConfig = toServerConfig(provider); - const response = await client.models.$post({ + const response = await client.providers[":provider_name"]["models"].$post({ + param: { provider_name: provider.name }, json: serverConfig, }); diff --git a/graphcap_studio/src/features/server-connections/services/apiClients.ts b/graphcap_studio/src/features/server-connections/services/apiClients.ts index ce83c056..3a857981 100644 --- a/graphcap_studio/src/features/server-connections/services/apiClients.ts +++ b/graphcap_studio/src/features/server-connections/services/apiClients.ts @@ -31,9 +31,6 @@ export interface DataServiceClient { * Interface for the Inference Bridge client */ export interface InferenceBridgeClient { - models: { - $post: (options: { json: unknown }) => Promise; - }; providers: { ":provider_name": { "test-connection": { diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index 05187118..cf4dccf8 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -16,6 +16,7 @@ import type { } from "@/features/inference/providers/types"; import { toServerConfig } from "@/features/inference/providers/types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { SERVER_IDS } from "../constants"; import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; // Query keys for TanStack Query @@ -31,7 +32,7 @@ export const queryKeys = { export function useProviders() { const { connections } = useServerConnectionsContext(); const dataServiceConnection = connections.find( - (conn) => conn.id === "data-service", + (conn) => conn.id === SERVER_IDS.DATA_SERVICE, ); const isConnected = dataServiceConnection?.status === "connected"; @@ -58,7 +59,7 @@ export function useProviders() { export function useProvider(id: number) { const { connections } = useServerConnectionsContext(); const dataServiceConnection = connections.find( - (conn) => conn.id === "data-service", + (conn) => conn.id === SERVER_IDS.DATA_SERVICE, ); const isConnected = dataServiceConnection?.status === "connected"; @@ -227,7 +228,7 @@ export function useDeleteProvider() { export function useProviderModels(provider: Provider) { const { connections } = useServerConnectionsContext(); const inferenceBridgeConnection = connections.find( - (conn) => conn.id === "inference-bridge", + (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, ); const isConnected = inferenceBridgeConnection?.status === "connected"; @@ -237,7 +238,8 @@ export function useProviderModels(provider: Provider) { const client = createInferenceBridgeClient(connections); const serverConfig = toServerConfig(provider); - const response = await client.models.$post({ + const response = await client.providers[":provider_name"].models.$post({ + param: { provider_name: provider.name }, json: serverConfig, }); From 984fe706d04526cb8a2bf043c7bb3b8989c42a5b Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 15:09:50 -0500 Subject: [PATCH 22/69] Provider->Bridge Flow init Signed-off-by: jphillips --- .../PerspectiveActions/PerspectivesFooter.tsx | 10 +- .../context/PerspectivesDataContext.tsx | 19 +- .../hooks/useGeneratePerspectiveCaption.ts | 18 +- .../hooks/useImagePerspectives.ts | 22 +- .../server-connections/services/apiClients.ts | 132 ++--------- .../services/dataServiceClient.ts | 54 +++++ .../server-connections/services/index.ts | 5 +- .../services/inferenceBridgeClient.ts | 110 +++++++++ .../services/serverConnections.ts | 107 ++++++++- .../server/features/perspectives/models.py | 1 + .../server/features/perspectives/router.py | 30 +++ .../server/features/perspectives/service.py | 48 ++-- .../server/features/providers/service.py | 213 ++++++------------ 13 files changed, 443 insertions(+), 326 deletions(-) create mode 100644 graphcap_studio/src/features/server-connections/services/dataServiceClient.ts create mode 100644 graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx index 8435416d..b13c976f 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx @@ -150,10 +150,17 @@ export function PerspectivesFooter() { try { console.log("Calling generatePerspective..."); + // Find the provider object from the available providers + const providerObject = availableProviders.find(p => p.name === selectedProvider); + + if (!providerObject) { + throw new Error(`Provider "${selectedProvider}" not found in available providers`); + } + await generatePerspective( activeSchemaName!, currentImage!.path, - selectedProvider, + providerObject, // Pass the full provider object effectiveOptions, ); showMessage( @@ -172,6 +179,7 @@ export function PerspectivesFooter() { }, [ activeSchemaName, selectedProvider, + availableProviders, // Add availableProviders to the dependencies generatePerspective, captionOptions, showMessage, diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index d23a8de1..c1bebc91 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -123,7 +123,7 @@ interface PerspectivesDataContextType { generatePerspective: ( schemaName: string, imagePath: string, - provider_name?: string, + provider?: Provider, options?: CaptionOptions, ) => Promise; @@ -369,7 +369,7 @@ export function PerspectivesDataProvider({ async ( schemaName: string, imagePath: string, - provider_name?: string, + provider?: Provider, options?: CaptionOptions, ) => { if (!isServerConnected) { @@ -384,8 +384,12 @@ export function PerspectivesDataProvider({ // Add to generating list setGeneratingPerspectives((prev) => [...prev, schemaName]); - // Use provided provider or selected provider - const effectiveProvider = provider_name ?? selectedProvider; + // Use provided provider or get the selected provider by name from available providers + let effectiveProvider = provider; + if (!effectiveProvider && selectedProvider) { + // Find the provider object by name + effectiveProvider = availableProviders.find(p => p.name === selectedProvider); + } if (!effectiveProvider) { throw new Error("No provider selected for caption generation"); @@ -403,7 +407,7 @@ export function PerspectivesDataProvider({ const result = await generateCaptionMutation.mutateAsync({ perspective: schemaName, imagePath, - provider_name: effectiveProvider, + provider: effectiveProvider, options: options ?? captionOptions, }); @@ -438,7 +442,7 @@ export function PerspectivesDataProvider({ ); return "MISSING_MODEL"; })(), - provider: effectiveProvider, + provider: effectiveProvider.name, content: result.result || {}, options: options || captionOptions, }; @@ -456,7 +460,7 @@ export function PerspectivesDataProvider({ }, metadata: { captioned_at: new Date().toISOString(), - provider: effectiveProvider, + provider: effectiveProvider.name, model: result.metadata?.model ?? "unknown", }, }; @@ -480,6 +484,7 @@ export function PerspectivesDataProvider({ isServerConnected, currentImage, selectedProvider, + availableProviders, captionOptions, generateCaptionMutation, ], diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 263ae356..b045b10a 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -6,6 +6,10 @@ */ import { useServerConnectionsContext } from "@/context"; +import { + type Provider, + toServerConfig, +} from "@/features/inference/providers/types"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; import type { ServerConnection } from "@/features/server-connections/types"; @@ -29,11 +33,11 @@ export function useGeneratePerspectiveCaption() { { perspective: string; imagePath: string; - provider_name: string; + provider: Provider; // Use Provider type from types.ts options?: CaptionOptions; } >({ - mutationFn: async ({ perspective, imagePath, provider_name, options }) => { + mutationFn: async ({ perspective, imagePath, provider, options }) => { const graphcapServerConnection = connections.find( (conn: ServerConnection) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, ); @@ -49,10 +53,13 @@ export function useGeneratePerspectiveCaption() { // Use the inference bridge client instead of direct fetch const client = createInferenceBridgeClient(connections); - + // Normalize the image path to ensure it starts with /workspace const normalizedImagePath = ensureWorkspacePath(imagePath); + // Convert provider to server config + const providerConfig = toServerConfig(provider); + console.log( `Generating caption for image: ${normalizedImagePath} using perspective: ${perspective}`, ); @@ -61,7 +68,8 @@ export function useGeneratePerspectiveCaption() { const requestBody = { perspective, image_path: normalizedImagePath, - provider: provider_name, + provider: provider.name, + provider_config: providerConfig, // Include the full provider configuration max_tokens: options.max_tokens, temperature: options.temperature, top_p: options.top_p, @@ -75,7 +83,7 @@ export function useGeneratePerspectiveCaption() { console.log("Sending caption generation request using API client", { perspective, image_path: normalizedImagePath, - provider: provider_name, + provider: provider.name, options: { max_tokens: requestBody.max_tokens, temperature: requestBody.temperature, diff --git a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts index 67bcc797..7c51a921 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts @@ -6,6 +6,7 @@ */ import { useServerConnectionsContext } from "@/context"; +import type { Provider } from "@/features/inference/providers/types"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { useProviders } from "@/features/server-connections/services/providers"; import type { Image } from "@/services/images"; @@ -104,12 +105,13 @@ export function useImagePerspectives( } // Find the provider by ID if provided - let provider_name: string | undefined; + let providerObject: Provider | undefined; if (providerId && providersData) { - const provider = providersData.find((p) => p.id === providerId); - if (provider) { - provider_name = provider.name; - console.debug(`Using provider: ${provider_name} (ID: ${providerId})`); + providerObject = providersData.find((p) => p.id === providerId); + if (providerObject) { + console.debug( + `Using provider: ${providerObject.name} (ID: ${providerId})`, + ); } else { console.warn(`Provider with ID ${providerId} not found`); setError(`Provider with ID ${providerId} not found`); @@ -123,7 +125,7 @@ export function useImagePerspectives( console.log(`Generating perspective: ${perspective}`, { imagePath: image.path, - provider_name, + provider: providerObject.name, options, }); @@ -137,7 +139,7 @@ export function useImagePerspectives( const result = await generateCaption.mutateAsync({ imagePath: image.path, perspective, - provider_name, + provider: providerObject, options, }); @@ -153,7 +155,7 @@ export function useImagePerspectives( config_name: perspective, version: "1.0", model: "api-generated", - provider: provider_name, + provider: providerObject.name, content: result.content || result.result || {}, options: options, }; @@ -170,7 +172,7 @@ export function useImagePerspectives( }, metadata: { captioned_at: new Date().toISOString(), - provider: provider_name, + provider: providerObject.name, model: "api-generated", }, }; @@ -187,7 +189,7 @@ export function useImagePerspectives( metadata: { ...prevCaptions.metadata, captioned_at: new Date().toISOString(), - provider: provider_name, + provider: providerObject.name, model: "api-generated", }, }; diff --git a/graphcap_studio/src/features/server-connections/services/apiClients.ts b/graphcap_studio/src/features/server-connections/services/apiClients.ts index 3a857981..df344072 100644 --- a/graphcap_studio/src/features/server-connections/services/apiClients.ts +++ b/graphcap_studio/src/features/server-connections/services/apiClients.ts @@ -2,119 +2,23 @@ /** * API Clients Service * - * This module provides centralized client functions for interacting with various server APIs. + * This module re-exports client functions for interacting with various server APIs. */ -import { hc } from "hono/client"; -import { DEFAULT_URLS, SERVER_IDS } from "../constants"; -import type { ServerConnection } from "../types"; - -/** - * Interface for the Data Service client - */ -export interface DataServiceClient { - providers: { - $get: () => Promise; - $post: (options: { json: unknown }) => Promise; - ":id": { - $get: (options: { param: { id: string } }) => Promise; - $put: (options: { - param: { id: string }; - json: unknown; - }) => Promise; - $delete: (options: { param: { id: string } }) => Promise; - }; - }; -} - -/** - * Interface for the Inference Bridge client - */ -export interface InferenceBridgeClient { - providers: { - ":provider_name": { - "test-connection": { - $post: (options: { - param: { provider_name: string }; - json: unknown; - }) => Promise; - }; - "models": { - $post: (options: { - param: { provider_name: string }; - json: unknown; - }) => Promise; - }; - }; - }; - perspectives: { - list: { - $get: () => Promise; - }; - modules: { - $get: () => Promise; - ":moduleName": { - $get: (options: { param: { moduleName: string } }) => Promise; - }; - }; - "caption-from-path": { - $post: (options: { json: unknown }) => Promise; - }; - ":name": { - $post: (options: { - param: { name: string }; - json: unknown; - formData?: FormData; - }) => Promise; - }; - }; -} - -/** - * Get the Data Service URL from server connections - */ -export function getDataServiceUrl(connections: ServerConnection[]): string { - const dataServiceConnection = connections.find( - (conn) => conn.id === SERVER_IDS.DATA_SERVICE, - ); - - return ( - dataServiceConnection?.url || - import.meta.env.VITE_DATA_SERVICE_URL || - DEFAULT_URLS[SERVER_IDS.DATA_SERVICE] - ); -} - -/** - * Create a Hono client for the Data Service - */ -export function createDataServiceClient(connections: ServerConnection[]): DataServiceClient { - const baseUrl = getDataServiceUrl(connections); - return hc(`${baseUrl}/api/v1`) as unknown as DataServiceClient; -} - -/** - * Get the Inference Bridge URL from server connections - */ -export function getInferenceBridgeUrl(connections: ServerConnection[]): string { - const inferenceBridgeConnection = connections.find( - (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, - ); - - return ( - inferenceBridgeConnection?.url || - import.meta.env.VITE_INFERENCE_BRIDGE_URL || - DEFAULT_URLS[SERVER_IDS.INFERENCE_BRIDGE] - ); -} - -/** - * Create a Hono client for the Inference Bridge - * Automatically appends /api/v1 to the base URL - */ -export function createInferenceBridgeClient(connections: ServerConnection[]): InferenceBridgeClient { - const baseUrl = getInferenceBridgeUrl(connections); - // Ensure the URL doesn't already have /api/v1 - const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`; - return hc(apiUrl) as unknown as InferenceBridgeClient; -} \ No newline at end of file +// Re-export everything from dataServiceClient.ts +export { + type DataServiceClient, + getDataServiceUrl, + createDataServiceClient, +} from "./dataServiceClient"; + +// Re-export everything from inferenceBridgeClient.ts +export { + type InferenceBridgeClient, + type ProviderClient, + type PerspectivesClient, + getInferenceBridgeUrl, + createInferenceBridgeClient, + createProviderClient, + createPerspectivesClient, +} from "./inferenceBridgeClient"; \ No newline at end of file diff --git a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts new file mode 100644 index 00000000..f201a5b3 --- /dev/null +++ b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Data Service API Client + * + * This module provides client functions for interacting with the Data Service API. + */ + +import { hc } from "hono/client"; +import { DEFAULT_URLS, SERVER_IDS } from "../constants"; +import type { ServerConnection } from "../types"; + +/** + * Interface for the Data Service client + */ +export interface DataServiceClient { + providers: { + $get: () => Promise; + $post: (options: { json: unknown }) => Promise; + ":id": { + $get: (options: { param: { id: string } }) => Promise; + $put: (options: { + param: { id: string }; + json: unknown; + }) => Promise; + $delete: (options: { param: { id: string } }) => Promise; + }; + }; + health: { + $get: () => Promise; + }; +} + +/** + * Get the Data Service URL from server connections + */ +export function getDataServiceUrl(connections: ServerConnection[]): string { + const dataServiceConnection = connections.find( + (conn) => conn.id === SERVER_IDS.DATA_SERVICE, + ); + + return ( + dataServiceConnection?.url || + import.meta.env.VITE_DATA_SERVICE_URL || + DEFAULT_URLS[SERVER_IDS.DATA_SERVICE] + ); +} + +/** + * Create a Hono client for the Data Service + */ +export function createDataServiceClient(connections: ServerConnection[]): DataServiceClient { + const baseUrl = getDataServiceUrl(connections); + return hc(`${baseUrl}/api/v1`) as unknown as DataServiceClient; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/server-connections/services/index.ts b/graphcap_studio/src/features/server-connections/services/index.ts index 9ecd8cec..b4f8b761 100644 --- a/graphcap_studio/src/features/server-connections/services/index.ts +++ b/graphcap_studio/src/features/server-connections/services/index.ts @@ -10,6 +10,8 @@ export { export type { DataServiceClient, InferenceBridgeClient, + ProviderClient, + PerspectivesClient, } from "./apiClients"; export { @@ -17,6 +19,8 @@ export { createDataServiceClient, getInferenceBridgeUrl, createInferenceBridgeClient, + createProviderClient, + createPerspectivesClient, } from "./apiClients"; // Provider services @@ -27,6 +31,5 @@ export { useCreateProvider, useUpdateProvider, useDeleteProvider, - useUpdateProviderApiKey, useProviderModels, } from "./providers"; diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts new file mode 100644 index 00000000..6433fc5b --- /dev/null +++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Inference Bridge API Client + * + * This module provides client functions for interacting with the Inference Bridge API. + */ + +import { hc } from "hono/client"; +import { DEFAULT_URLS, SERVER_IDS } from "../constants"; +import type { ServerConnection } from "../types"; + +/** + * Interface for the Inference Bridge Provider operations + */ +export interface ProviderClient { + ":provider_name": { + "test-connection": { + $post: (options: { + param: { provider_name: string }; + json: unknown; + }) => Promise; + }; + "models": { + $post: (options: { + param: { provider_name: string }; + json: unknown; + }) => Promise; + }; + }; +} + +/** + * Interface for the Inference Bridge Perspectives operations + */ +export interface PerspectivesClient { + list: { + $get: () => Promise; + }; + modules: { + $get: () => Promise; + ":moduleName": { + $get: (options: { param: { moduleName: string } }) => Promise; + }; + }; + "caption-from-path": { + $post: (options: { json: unknown }) => Promise; + }; + ":name": { + $post: (options: { + param: { name: string }; + json: unknown; + formData?: FormData; + }) => Promise; + }; +} + +/** + * Interface for the Inference Bridge client - combines provider and perspectives APIs + */ +export interface InferenceBridgeClient { + providers: ProviderClient; + perspectives: PerspectivesClient; + health: { + $get: () => Promise; + }; +} + +/** + * Get the Inference Bridge URL from server connections + */ +export function getInferenceBridgeUrl(connections: ServerConnection[]): string { + const inferenceBridgeConnection = connections.find( + (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, + ); + + return ( + inferenceBridgeConnection?.url || + import.meta.env.VITE_INFERENCE_BRIDGE_URL || + DEFAULT_URLS[SERVER_IDS.INFERENCE_BRIDGE] + ); +} + +/** + * Create a Hono client for the Inference Bridge + * Automatically appends /api/v1 to the base URL + */ +export function createInferenceBridgeClient(connections: ServerConnection[]): InferenceBridgeClient { + const baseUrl = getInferenceBridgeUrl(connections); + // Ensure the URL doesn't already have /api/v1 + const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`; + return hc(apiUrl) as unknown as InferenceBridgeClient; +} + +/** + * Create a client for provider operations only + */ +export function createProviderClient(connections: ServerConnection[]): ProviderClient { + const baseUrl = getInferenceBridgeUrl(connections); + const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`; + return (hc(apiUrl) as unknown as InferenceBridgeClient).providers; +} + +/** + * Create a client for perspectives operations only + */ +export function createPerspectivesClient(connections: ServerConnection[]): PerspectivesClient { + const baseUrl = getInferenceBridgeUrl(connections); + const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`; + return (hc(apiUrl) as unknown as InferenceBridgeClient).perspectives; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/server-connections/services/serverConnections.ts b/graphcap_studio/src/features/server-connections/services/serverConnections.ts index deb38cbf..cbc51a61 100644 --- a/graphcap_studio/src/features/server-connections/services/serverConnections.ts +++ b/graphcap_studio/src/features/server-connections/services/serverConnections.ts @@ -6,7 +6,9 @@ * such as the Media Server and Inference Bridge. */ -import { SERVER_IDS } from "../constants"; +import { CONNECTION_STATUS, SERVER_IDS } from "../constants"; +import type { ServerConnection } from "../types"; +import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; /** * Interface for health check response @@ -69,12 +71,37 @@ export async function checkMediaServerHealth(url: string): Promise { */ export async function checkInferenceBridgeHealth(url: string): Promise { try { - // Normalize URL by removing trailing slash if present - const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url; + // Create mock connection array with the URL + const mockConnection: ServerConnection[] = [ + { + id: SERVER_IDS.INFERENCE_BRIDGE, + name: "Inference Bridge", + status: CONNECTION_STATUS.DISCONNECTED, + url, + }, + ]; + + // Create client with the URL + const client = createInferenceBridgeClient(mockConnection); + + // First try the /api/v1/health endpoint using the client + try { + const response = await client.health.$get(); + + if (response.ok) { + const data = (await response.json()) as HealthCheckResponse; + // Check if the response contains a valid status + return data.status === "ok" || data.status === "healthy"; + } + } catch (apiError) { + console.warn("Error checking Inference Bridge at /api/v1/health, trying direct health endpoint next:", apiError); + } - // First try the /api/v1/health endpoint + // Try direct /api/v1/health endpoint try { - const response = await fetch(`${normalizedUrl}/api/v1/health`, { + // Normalize URL by removing trailing slash if present + const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url; + const apiResponse = await fetch(`${normalizedUrl}/api/v1/health`, { method: "GET", headers: { Accept: "application/json", @@ -83,16 +110,18 @@ export async function checkInferenceBridgeHealth(url: string): Promise signal: AbortSignal.timeout(3000), }); - if (response.ok) { - const data = (await response.json()) as HealthCheckResponse; + if (apiResponse.ok) { + const data = (await apiResponse.json()) as HealthCheckResponse; // Check if the response contains a valid status return data.status === "ok" || data.status === "healthy"; } - } catch (apiError) { - console.warn("Error checking Inference Bridge at /api/v1/health, trying /health next:", apiError); + } catch (directApiError) { + console.warn("Error checking Inference Bridge at direct /api/v1/health, trying /health next:", directApiError); } - // Fallback to the /health endpoint + // Fallback to the legacy /health endpoint with direct fetch as last resort + // Normalize URL by removing trailing slash if present + const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url; const fallbackResponse = await fetch(`${normalizedUrl}/health`, { method: "GET", headers: { @@ -103,7 +132,7 @@ export async function checkInferenceBridgeHealth(url: string): Promise }); if (!fallbackResponse.ok) { - console.error(`Both health check endpoints failed. Last status: ${fallbackResponse.status}`); + console.error(`All health check endpoints failed. Last status: ${fallbackResponse.status}`); return false; } @@ -123,7 +152,61 @@ export async function checkInferenceBridgeHealth(url: string): Promise * @returns A promise that resolves to a boolean indicating if the server is healthy */ export async function checkDataServiceHealth(url: string): Promise { - return checkServerHealth(url); + try { + // Create mock connection array with the URL + const mockConnection: ServerConnection[] = [ + { + id: SERVER_IDS.DATA_SERVICE, + name: "Data Service", + status: CONNECTION_STATUS.DISCONNECTED, + url, + }, + ]; + + // Create client with the URL + const client = createDataServiceClient(mockConnection); + + // Try the /api/v1/health endpoint using the client + try { + const response = await client.health.$get(); + + if (response.ok) { + const data = (await response.json()) as HealthCheckResponse; + // Check if the response contains a valid status + return data.status === "ok" || data.status === "healthy"; + } + } catch (apiError) { + console.warn("Error checking Data Service at /api/v1/health, trying direct endpoint next:", apiError); + } + + // Try direct /api/v1/health endpoint + try { + // Normalize URL by removing trailing slash if present + const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url; + const apiResponse = await fetch(`${normalizedUrl}/api/v1/health`, { + method: "GET", + headers: { + Accept: "application/json", + }, + // Set a timeout to prevent long-hanging requests + signal: AbortSignal.timeout(3000), + }); + + if (apiResponse.ok) { + const data = (await apiResponse.json()) as HealthCheckResponse; + // Check if the response contains a valid status + return data.status === "ok" || data.status === "healthy"; + } + } catch (directApiError) { + console.warn("Error checking Data Service at direct /api/v1/health, trying /health next:", directApiError); + } + + // Fallback to the direct /health endpoint check as last resort + return checkServerHealth(url); + } catch (error) { + console.error("Error checking Data Service health:", error); + return false; + } } /** diff --git a/servers/inference_bridge/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py index f975e49c..03a34469 100644 --- a/servers/inference_bridge/server/server/features/perspectives/models.py +++ b/servers/inference_bridge/server/server/features/perspectives/models.py @@ -190,6 +190,7 @@ class CaptionPathRequest(BaseModel): perspective: str = Field(..., description=DESC_PERSPECTIVE_NAME) image_path: str = Field(..., description="Path to the image file in the workspace") provider: str = Field("gemini", description="Name of the provider to use") + provider_config: dict = Field(..., description="Provider configuration") max_tokens: Optional[int] = Field(4096, description=DESC_MAX_TOKENS) temperature: Optional[float] = Field(0.8, description=DESC_TEMPERATURE) top_p: Optional[float] = Field(0.9, description=DESC_TOP_P) diff --git a/servers/inference_bridge/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py index 3671273e..280d6aa1 100644 --- a/servers/inference_bridge/server/server/features/perspectives/router.py +++ b/servers/inference_bridge/server/server/features/perspectives/router.py @@ -59,6 +59,7 @@ async def create_caption( file: UploadFile = File(..., description="Image file to upload"), perspective: str = Form(..., description="Name of the perspective to use"), provider: str = Form("gemini", description="Name of the provider to use"), + provider_config: Optional[str] = Form(None, description="Provider configuration as JSON string"), max_tokens: Optional[int] = Form(4096, description="Maximum number of tokens"), temperature: Optional[float] = Form(0.8, description="Temperature for generation"), top_p: Optional[float] = Form(0.9, description="Top-p sampling parameter"), @@ -77,6 +78,7 @@ async def create_caption( file: Image file to upload (required) perspective: Name of the perspective to use (required) provider: Name of the provider to use (optional, default: "default") + provider_config: Provider configuration as JSON string (optional) max_tokens: Maximum number of tokens (optional, default: 4096) temperature: Temperature for generation (optional, default: 0.8) top_p: Top-p sampling parameter (optional, default: 0.9) @@ -95,6 +97,16 @@ async def create_caption( # Parse context from JSON string if provided parsed_context = _parse_context(context) + # Parse provider_config from JSON string if provided + parsed_provider_config = None + if provider_config: + try: + parsed_provider_config = json.loads(provider_config) + logger.info(f"Parsed provider configuration for {provider}") + except json.JSONDecodeError as e: + logger.error(f"Invalid provider configuration JSON: {e}") + raise HTTPException(status_code=400, detail=f"Invalid provider configuration JSON: {str(e)}") + # Process the uploaded file image_path = await save_uploaded_file(file) @@ -136,6 +148,14 @@ async def create_caption( # Add cleanup task background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None) + # Validate provider configuration + if not parsed_provider_config: + logger.error(f"No provider configuration provided for {provider}") + raise HTTPException( + status_code=400, + detail=f"Provider configuration not provided for '{provider}'. Please include provider_config in the request." + ) + # Generate the caption caption_data = await generate_caption( perspective_name=perspective, @@ -147,6 +167,7 @@ async def create_caption( context=parsed_context, global_context=global_context, provider_name=provider, + provider_config=parsed_provider_config, ) # Log the caption data for debugging @@ -204,6 +225,14 @@ async def create_caption_from_path( # Process context context = _process_context(request.context) + # Validate that provider_config is present + if not hasattr(request, 'provider_config') or not request.provider_config: + logger.error(f"No provider configuration provided for {request.provider}") + raise HTTPException( + status_code=400, + detail=f"Provider configuration not provided for '{request.provider}'. Please include provider_config in the request." + ) + # Generate the caption caption_data = await generate_caption( perspective_name=request.perspective, @@ -215,6 +244,7 @@ async def create_caption_from_path( context=context, global_context=request.global_context, provider_name=request.provider, + provider_config=request.provider_config, ) # Clean up temporary file if we created one diff --git a/servers/inference_bridge/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py index 07931bd8..19a3155e 100644 --- a/servers/inference_bridge/server/server/features/perspectives/service.py +++ b/servers/inference_bridge/server/server/features/perspectives/service.py @@ -22,7 +22,7 @@ from graphcap.providers.clients.base_client import BaseClient from loguru import logger -from ..providers.service import get_provider_manager +from ..providers.service import create_provider_client_from_config from .models import ModuleInfo, PerspectiveInfo, PerspectiveSchema, SchemaField, TableColumn @@ -270,6 +270,7 @@ async def generate_caption( context: Optional[List[str]] = None, global_context: Optional[str] = None, provider_name: str = "gemini", + provider_config: Optional[dict] = None, ) -> Dict: """ Generate a caption for an image using a perspective. @@ -284,6 +285,7 @@ async def generate_caption( context: Additional context for the caption global_context: Global context for the caption provider_name: Name of the provider to use (default: "gemini") + provider_config: Full provider configuration if available Returns: Caption data @@ -295,35 +297,23 @@ async def generate_caption( # Get the perspective perspective = get_perspective(perspective_name) - # Get the provider client from the provider manager - provider_manager = get_provider_manager() - - # Debug: Log available providers - available_providers = provider_manager.available_providers() - logger.debug(f"Available providers: {available_providers}") - - # Debug: Try to resolve host.docker.internal - try: - host_ip = socket.gethostbyname("host.docker.internal") - logger.debug(f"host.docker.internal resolves to: {host_ip}") - except socket.gaierror as e: - logger.warning(f"Could not resolve host.docker.internal: {e}") - - try: - provider: BaseClient = provider_manager.get_client(provider_name) - # Debug: Log provider details - logger.debug("Provider details:") - logger.debug(f" - Name: {provider_name}") - logger.debug(f" - Kind: {provider.kind}") - logger.debug(f" - Environment: {provider.environment}") - logger.debug(f" - Base URL: {provider.base_url}") - logger.debug(f" - Default Model: {provider.default_model}") - except ValueError as e: - logger.error(f"Provider '{provider_name}' not found: {str(e)}") + # Create a provider client using the config if provided + if provider_config: + from ..providers.models import ProviderConfig + from ..providers.service import create_provider_client_from_config + + # Convert dict to ProviderConfig + config = ProviderConfig(**provider_config) + provider = create_provider_client_from_config(config) + logger.info(f"Created provider client from provided config for {provider_name}") + else: + # Legacy path - will likely fail as no provider manager exists + logger.error(f"No provider configuration provided for {provider_name}. Caption generation will likely fail.") + logger.error("Provider configuration must be provided in the request.") raise HTTPException( - status_code=404, - detail=f"""Provider '{provider_name}' not found. - Available providers: {', '.join(provider_manager.available_providers())}""", + status_code=400, + detail=f"""Provider configuration not provided for '{provider_name}'. + Provider configuration must be included in the request.""", ) # Create a temporary output directory diff --git a/servers/inference_bridge/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py index ef48fceb..14aaac6d 100644 --- a/servers/inference_bridge/server/server/features/providers/service.py +++ b/servers/inference_bridge/server/server/features/providers/service.py @@ -65,7 +65,7 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis models = [] # Try to fetch models if configured - if config.fetch_models and isinstance(client, ModelProvider): + if config.fetch_models: try: logger.info(f"Fetching models from provider {provider_name}") if hasattr(client, "get_available_models"): @@ -92,39 +92,30 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis return models -def get_provider_manager(): +def create_provider_client_from_config(config: ProviderConfig) -> BaseClient: """ - Get a compatible provider manager using the factory pattern. - This function creates a wrapper around the provider factory that maintains - backward compatibility with code expecting a provider manager. + Create a provider client from a configuration. - Returns: - An object with the provider manager interface - """ - factory = get_provider_factory() - - # Create a wrapper object that delegates to the factory - class ProviderManagerWrapper: - def __init__(self, factory): - self.factory = factory - self._client_cache: Dict[str, BaseClient] = {} + Args: + config: Provider configuration - def get_client(self, name: str) -> BaseClient: - """Get a client for the specified provider""" - if name in self._client_cache: - return self._client_cache[name] - - # In a real implementation, this would look up the config for the name - # For now, we're just passing through to create_provider_client - # which will fail if the provider doesn't exist - return create_provider_client(name=name, kind="", environment="", base_url="", api_key="") + Returns: + Provider client - def available_providers(self) -> List[str]: - """Return a list of available provider names""" - # This is a stub - in a real implementation we would return actual providers - return ["gemini"] - - return ProviderManagerWrapper(factory) + Raises: + ValueError: If client creation fails + """ + logger.info(f"Creating provider client from config for {config.name}") + return create_provider_client( + name=config.name, + kind=config.kind, + environment=config.environment, + base_url=config.base_url, + api_key=config.api_key, + default_model=config.default_model, + rate_limits=config.rate_limits, + use_cache=True, + ) async def test_provider_connection(provider_name: str, config: ProviderConfig) -> Dict[str, Any]: @@ -198,138 +189,66 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - # Try to test the connection with a lightweight operation # First check if we can get models (most providers support this) - if isinstance(client, ModelProvider): - try: - # Add diagnostic step for model list - result["diagnostics"]["connection_steps"].append({ - "step": "list_models", - "status": "pending", - "timestamp": str(datetime.datetime.now()) - }) - - if hasattr(client, "get_available_models"): - provider_models = await client.get_available_models() - result["details"]["method"] = "get_available_models" - - # Add model information if available - if hasattr(provider_models, "data"): - models_data = [] - for model in provider_models.data: - model_id = _extract_model_id(model) - models_data.append({"id": model_id}) - result["details"]["available_models"] = models_data - result["details"]["models_count"] = len(models_data) - - elif hasattr(client, "get_models"): - provider_models = await client.get_models() - result["details"]["method"] = "get_models" - - # Add model information if available - if hasattr(provider_models, "models"): - models_data = [] - for model in provider_models.models: - model_id = _extract_model_id(model) - models_data.append({"id": model_id}) - result["details"]["available_models"] = models_data - result["details"]["models_count"] = len(models_data) - - # Update diagnostic step - result["diagnostics"]["connection_steps"][-1]["status"] = "success" - - except Exception as e: - logger.warning(f"Could not list models for {provider_name}: {str(e)}") - result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" - result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed" - - # Try a simple chat completion as a more thorough test try: - # Add diagnostic step for chat completion + # Add diagnostic step for model list result["diagnostics"]["connection_steps"].append({ - "step": "test_chat_completion", + "step": "list_models", "status": "pending", "timestamp": str(datetime.datetime.now()) }) - # Use the first available model or default model - test_model = None - if result.get("details", {}).get("available_models"): - test_model = result["details"]["available_models"][0]["id"] - elif config.default_model: - test_model = config.default_model - - if test_model: - # Simple test message - test_messages = [{"role": "user", "content": "Hello, this is a test message. Please respond with 'OK' if you can process this request."}] - - completion = await client.chat.completions.create( - model=test_model, - messages=test_messages, - max_tokens=10, # Keep it minimal - temperature=0, # Deterministic - ) + if hasattr(client, "get_available_models"): + provider_models = await client.get_available_models() + result["details"]["method"] = "get_available_models" - result["connection_verified"] = True - result["details"]["chat_completion_test"] = "success" - result["details"]["test_model"] = test_model - result["diagnostics"]["connection_steps"][-1]["status"] = "success" - else: - result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" - result["diagnostics"]["connection_steps"][-1]["message"] = "No suitable model found for testing" + # Add model information if available + if hasattr(provider_models, "data"): + models_data = [] + for model in provider_models.data: + model_id = _extract_model_id(model) + models_data.append({"id": model_id}) + result["details"]["available_models"] = models_data + result["details"]["models_count"] = len(models_data) + + elif hasattr(client, "get_models"): + provider_models = await client.get_models() + result["details"]["method"] = "get_models" + # Add model information if available + if hasattr(provider_models, "models"): + models_data = [] + for model in provider_models.models: + model_id = _extract_model_id(model) + models_data.append({"id": model_id}) + result["details"]["available_models"] = models_data + result["details"]["models_count"] = len(models_data) + + # Update diagnostic step + result["diagnostics"]["connection_steps"][-1]["status"] = "success" + except Exception as e: - logger.error(f"Chat completion test failed for {provider_name}: {str(e)}") - result["diagnostics"]["connection_steps"][-1]["status"] = "failed" - result["diagnostics"]["connection_steps"][-1]["error"] = str(e) - result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ + logger.warning(f"Could not list models for {provider_name}: {str(e)}") + result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" + result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed" - # Only mark connection as failed if we couldn't list models either - if not result.get("details", {}).get("available_models"): - result["connection_verified"] = False - result["details"]["error"] = str(e) - result["details"]["error_type"] = type(e).__name__ - - # Add specific suggestions based on error type - if "authentication" in str(e).lower() or "unauthorized" in str(e).lower() or "auth" in str(e).lower(): - result["details"]["suggestion"] = "Check if the API key is valid and has necessary permissions" - elif "timeout" in str(e).lower(): - result["details"]["suggestion"] = "Connection timed out. Check network connectivity or server status" - elif "url" in str(e).lower() or "endpoint" in str(e).lower(): - result["details"]["suggestion"] = "Check if the base URL is correct for this provider" - - raise Exception(f"Error verifying connection: {str(e)}") - else: - # If we could list models but chat completion failed, just warn - result["diagnostics"]["warnings"].append({ - "warning_type": "chat_completion_failed", - "message": f"Chat completion test failed but model listing succeeded. Provider may have limited functionality. Error: {str(e)}" - }) - result["connection_verified"] = True - result["details"]["method"] = "list_models_only" + # Connection test successful + result["connected"] = True + result["success"] = True + result["message"] = f"Successfully connected to {provider_name}" return result - except Exception as e: - logger.error(f"Error initializing provider client for {provider_name}: {str(e)}") + logger.error(f"Error testing connection to {provider_name}: {str(e)}") - # Update diagnostic information - if "connection_steps" in result["diagnostics"] and result["diagnostics"]["connection_steps"]: + # Update the last diagnostic step if it's pending + if result["diagnostics"]["connection_steps"] and result["diagnostics"]["connection_steps"][-1]["status"] == "pending": result["diagnostics"]["connection_steps"][-1]["status"] = "failed" result["diagnostics"]["connection_steps"][-1]["error"] = str(e) - result["diagnostics"]["connection_steps"][-1]["error_type"] = type(e).__name__ - result["client_initialized"] = False - result["connection_verified"] = False - result["details"]["error"] = str(e) - result["details"]["error_type"] = type(e).__name__ + # Add overall failure information + result["connected"] = False + result["success"] = False + result["message"] = f"Failed to connect to {provider_name}: {str(e)}" + result["error"] = str(e) - # Add specific suggestions based on error type - if any(keyword in str(e).lower() for keyword in ["api key", "authentication", "auth", "credential"]): - result["details"]["suggestion"] = "Check if the API key is valid" - elif any(keyword in str(e).lower() for keyword in ["url", "endpoint", "address"]): - result["details"]["suggestion"] = "Check if the base URL is correct" - elif any(keyword in str(e).lower() for keyword in ["timeout", "connect"]): - result["details"]["suggestion"] = "Network connectivity issue. Check your internet connection" - elif "provider" in str(e).lower(): - result["details"]["suggestion"] = "Verify that the provider type is supported and correctly configured" - - raise Exception(f"Failed to initialize provider client: {str(e)}", result) + return result From 86412b389f542d0a704a7e1f6ffd79535e3af6e3 Mon Sep 17 00:00:00 2001 From: jphillips Date: Wed, 26 Mar 2025 15:55:02 -0500 Subject: [PATCH 23/69] Remove default model concept from inference bridge Signed-off-by: jphillips --- .../graphcap/providers/clients/base_client.py | 4 +--- .../providers/clients/gemini_client.py | 3 +-- .../providers/clients/ollama_client.py | 4 +--- .../providers/clients/openai_client.py | 3 +-- .../providers/clients/openrouter_client.py | 11 +++++++--- .../graphcap/providers/clients/vllm_client.py | 5 ++--- .../graphcap/providers/factory.py | 8 ------- .../graphcap/providers/types.py | 1 - .../server/features/perspectives/models.py | 12 ++++++++-- .../server/features/perspectives/router.py | 22 +++++++++---------- .../server/features/providers/models.py | 2 -- .../server/features/providers/service.py | 14 +++++------- 12 files changed, 39 insertions(+), 50 deletions(-) diff --git a/servers/inference_bridge/graphcap/providers/clients/base_client.py b/servers/inference_bridge/graphcap/providers/clients/base_client.py index a32ea8bb..5e372f3a 100644 --- a/servers/inference_bridge/graphcap/providers/clients/base_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/base_client.py @@ -19,7 +19,6 @@ environment (str): Deployment environment env_var (str): Environment variable for API key base_url (str): Base API URL - default_model (str): Default model identifier """ import asyncio @@ -38,7 +37,7 @@ class BaseClient(AsyncOpenAI, ABC): """Abstract base class for all provider clients""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str): # Initialize OpenAI client super().__init__(api_key=api_key, base_url=base_url) @@ -47,7 +46,6 @@ def __init__(self, name: str, kind: str, environment: str, base_url: str, defaul self.kind = kind self.environment = environment self.base_url = base_url - self.default_model = default_model # Rate limiting state self._request_times: list[float] = [] diff --git a/servers/inference_bridge/graphcap/providers/clients/gemini_client.py b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py index 670fd8f2..e7ee0b43 100644 --- a/servers/inference_bridge/graphcap/providers/clients/gemini_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py @@ -27,14 +27,13 @@ class GeminiClient(BaseClient): """Client for Google's Gemini API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str): logger.info(f"GeminiClient initialized with base_url: {base_url}") super().__init__( name=name, kind=kind, environment=environment, base_url=base_url.rstrip("/"), - default_model=default_model, api_key=api_key, ) diff --git a/servers/inference_bridge/graphcap/providers/clients/ollama_client.py b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py index 56167622..432894cf 100644 --- a/servers/inference_bridge/graphcap/providers/clients/ollama_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py @@ -26,13 +26,12 @@ class OllamaClient(BaseClient): """Client for Ollama API with OpenAI compatibility layer""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str = "stub_key"): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str = "stub_key"): logger.info("Initializing OllamaClient:") logger.info(f" - name: {name}") logger.info(f" - kind: {kind}") logger.info(f" - environment: {environment}") logger.info(f" - base_url: {base_url}") - logger.info(f" - default_model: {default_model}") # Store the raw base URL for Ollama-specific endpoints base_url = base_url.rstrip("/") @@ -57,7 +56,6 @@ def __init__(self, name: str, kind: str, environment: str, base_url: str, defaul kind=kind, environment=environment, base_url=openai_base_url, - default_model=default_model, api_key=api_key, ) logger.debug(f"OllamaClient initialized with environment: {environment}, kind: {kind}") diff --git a/servers/inference_bridge/graphcap/providers/clients/openai_client.py b/servers/inference_bridge/graphcap/providers/clients/openai_client.py index c1da80d7..5459b02f 100644 --- a/servers/inference_bridge/graphcap/providers/clients/openai_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/openai_client.py @@ -29,14 +29,13 @@ class OpenAIClient(BaseClient): """Client for OpenAI API""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str): logger.info(f"OpenAIClient initialized with base_url: {base_url}") super().__init__( name=name, kind=kind, environment=environment, base_url=base_url.rstrip("/"), - default_model=default_model, api_key=api_key, ) diff --git a/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py b/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py index 8681ebf5..b6c220f3 100644 --- a/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py @@ -25,16 +25,21 @@ class OpenRouterClient(BaseClient): - """Client for OpenRouter API with OpenAI compatibility layer""" + """Client for OpenRouter API""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str): logger.info(f"OpenRouterClient initialized with base_url: {base_url}") + + # Base URL handling for OpenRouter + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + logger.info(f"Added /v1 to base URL: {base_url}") + super().__init__( name=name, kind=kind, environment=environment, base_url=base_url.rstrip("/"), - default_model=default_model, api_key=api_key, ) diff --git a/servers/inference_bridge/graphcap/providers/clients/vllm_client.py b/servers/inference_bridge/graphcap/providers/clients/vllm_client.py index 63dcb9f8..6789783d 100644 --- a/servers/inference_bridge/graphcap/providers/clients/vllm_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/vllm_client.py @@ -26,9 +26,9 @@ class VLLMClient(BaseClient): - """Client for VLLM API with OpenAI compatibility layer""" + """Client for vLLM API""" - def __init__(self, name: str, kind: str, environment: str, base_url: str, default_model: str, api_key: str = "stub_key"): + def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str = "stub_key"): # If base_url doesn't include /v1, append it if not base_url.endswith("/v1"): base_url = f"{base_url}/v1" @@ -39,7 +39,6 @@ def __init__(self, name: str, kind: str, environment: str, base_url: str, defaul kind=kind, environment=environment, base_url=base_url.rstrip("/"), - default_model=default_model, api_key=api_key, ) diff --git a/servers/inference_bridge/graphcap/providers/factory.py b/servers/inference_bridge/graphcap/providers/factory.py index 805f26ff..afd7f56e 100644 --- a/servers/inference_bridge/graphcap/providers/factory.py +++ b/servers/inference_bridge/graphcap/providers/factory.py @@ -33,7 +33,6 @@ def create_client( environment: str, base_url: str, api_key: str, - default_model: Optional[str] = None, rate_limits: Optional[dict] = None, use_cache: bool = True, ) -> BaseClient: @@ -45,7 +44,6 @@ def create_client( environment: Provider environment (cloud, local) base_url: Base URL for the provider API api_key: API key for the provider - default_model: Default model for the provider rate_limits: Rate limiting configuration use_cache: Whether to cache and reuse client instances (default: True) @@ -66,7 +64,6 @@ def create_client( logger.info(f" - kind: {kind}") logger.info(f" - environment: {environment}") logger.info(f" - base_url: {base_url}") - logger.info(f" - default_model: {default_model}") try: client = get_client( @@ -75,7 +72,6 @@ def create_client( environment=environment, api_key=api_key, base_url=base_url, - default_model=default_model, ) # Set rate limits if configured @@ -98,7 +94,6 @@ def create_client( logger.error(f" - kind: {kind}") logger.error(f" - environment: {environment}") logger.error(f" - base_url: {base_url}") - logger.error(f" - default_model: {default_model}") raise ValueError(f"Failed to create client for {name}: {str(e)}") def clear_cache(self) -> None: @@ -131,7 +126,6 @@ def create_provider_client( environment: str, base_url: str, api_key: str, - default_model: Optional[str] = None, rate_limits: Optional[dict] = None, use_cache: bool = True, ) -> BaseClient: @@ -143,7 +137,6 @@ def create_provider_client( environment: Provider environment (cloud, local) base_url: Base URL for the provider API api_key: API key for the provider - default_model: Default model for the provider rate_limits: Rate limiting configuration use_cache: Whether to cache and reuse client instances (default: True) @@ -160,7 +153,6 @@ def create_provider_client( environment=environment, base_url=base_url, api_key=api_key, - default_model=default_model, rate_limits=rate_limits, use_cache=use_cache, ) diff --git a/servers/inference_bridge/graphcap/providers/types.py b/servers/inference_bridge/graphcap/providers/types.py index 4ef7922d..d7dce188 100644 --- a/servers/inference_bridge/graphcap/providers/types.py +++ b/servers/inference_bridge/graphcap/providers/types.py @@ -22,6 +22,5 @@ class ProviderConfig: env_var: str base_url: str models: list[str] - default_model: str fetch_models: bool = False rate_limits: Optional[RateLimits] = None diff --git a/servers/inference_bridge/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py index 03a34469..ed0196c4 100644 --- a/servers/inference_bridge/server/server/features/perspectives/models.py +++ b/servers/inference_bridge/server/server/features/perspectives/models.py @@ -138,7 +138,7 @@ class CaptionResponse(BaseModel): """Response model for a generated caption.""" perspective: str = Field(..., description="Name of the perspective used") - provider: str = Field("gemini", description="Name of the provider used") + provider: str = Field(..., description="Name of the provider used") result: dict = Field(..., description="Structured caption result") raw_text: Optional[str] = Field(None, description="Raw text response from the model") @@ -189,7 +189,7 @@ class CaptionPathRequest(BaseModel): perspective: str = Field(..., description=DESC_PERSPECTIVE_NAME) image_path: str = Field(..., description="Path to the image file in the workspace") - provider: str = Field("gemini", description="Name of the provider to use") + provider: str = Field(..., description="Name of the provider to use") provider_config: dict = Field(..., description="Provider configuration") max_tokens: Optional[int] = Field(4096, description=DESC_MAX_TOKENS) temperature: Optional[float] = Field(0.8, description=DESC_TEMPERATURE) @@ -205,6 +205,14 @@ class Config: "perspective": "custom_caption", "image_path": "/workspace/datasets/example.jpg", "provider": "gemini", + "provider_config": { + "name": "gemini", + "kind": "gemini", + "environment": "cloud", + "api_key": "your_api_key_here", + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "models": ["gemini-pro-vision"] + }, "max_tokens": 4096, "temperature": 0.8, "resize_resolution": "HD_720P", diff --git a/servers/inference_bridge/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py index 280d6aa1..053bd282 100644 --- a/servers/inference_bridge/server/server/features/perspectives/router.py +++ b/servers/inference_bridge/server/server/features/perspectives/router.py @@ -58,8 +58,8 @@ async def create_caption( background_tasks: BackgroundTasks, file: UploadFile = File(..., description="Image file to upload"), perspective: str = Form(..., description="Name of the perspective to use"), - provider: str = Form("gemini", description="Name of the provider to use"), - provider_config: Optional[str] = Form(None, description="Provider configuration as JSON string"), + provider: str = Form(..., description="Name of the provider to use"), + provider_config: str = Form(..., description="Provider configuration as JSON string"), max_tokens: Optional[int] = Form(4096, description="Maximum number of tokens"), temperature: Optional[float] = Form(0.8, description="Temperature for generation"), top_p: Optional[float] = Form(0.9, description="Top-p sampling parameter"), @@ -77,8 +77,8 @@ async def create_caption( background_tasks: Background tasks for cleanup file: Image file to upload (required) perspective: Name of the perspective to use (required) - provider: Name of the provider to use (optional, default: "default") - provider_config: Provider configuration as JSON string (optional) + provider: Name of the provider to use (required) + provider_config: Provider configuration as JSON string (required) max_tokens: Maximum number of tokens (optional, default: 4096) temperature: Temperature for generation (optional, default: 0.8) top_p: Top-p sampling parameter (optional, default: 0.9) @@ -98,14 +98,12 @@ async def create_caption( parsed_context = _parse_context(context) # Parse provider_config from JSON string if provided - parsed_provider_config = None - if provider_config: - try: - parsed_provider_config = json.loads(provider_config) - logger.info(f"Parsed provider configuration for {provider}") - except json.JSONDecodeError as e: - logger.error(f"Invalid provider configuration JSON: {e}") - raise HTTPException(status_code=400, detail=f"Invalid provider configuration JSON: {str(e)}") + try: + parsed_provider_config = json.loads(provider_config) + logger.info(f"Parsed provider configuration for {provider}") + except json.JSONDecodeError as e: + logger.error(f"Invalid provider configuration JSON: {e}") + raise HTTPException(status_code=400, detail=f"Invalid provider configuration JSON: {str(e)}") # Process the uploaded file image_path = await save_uploaded_file(file) diff --git a/servers/inference_bridge/server/server/features/providers/models.py b/servers/inference_bridge/server/server/features/providers/models.py index 177b2025..928bcd82 100644 --- a/servers/inference_bridge/server/server/features/providers/models.py +++ b/servers/inference_bridge/server/server/features/providers/models.py @@ -15,7 +15,6 @@ class ProviderInfo(BaseModel): name: str = Field(..., description="Unique identifier for the provider") kind: str = Field(..., description="Type of provider (e.g., 'openai', 'anthropic', 'gemini')") - default_model: str = Field("", description="Default model used by the provider") class ProviderListResponse(BaseModel): @@ -47,7 +46,6 @@ class ProviderConfig(BaseModel): environment: str = Field(..., description="Provider environment (cloud, local)") base_url: str = Field(..., description="Base URL for the provider API") api_key: str = Field(..., description="API key for the provider") - default_model: Optional[str] = Field(None, description="Default model for the provider") models: List[str] = Field(default_factory=list, description="List of available model IDs") fetch_models: bool = Field(default=True, description="Whether to fetch models from the provider API") rate_limits: Optional[dict] = Field(None, description="Rate limiting configuration") diff --git a/servers/inference_bridge/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py index 14aaac6d..94db9d03 100644 --- a/servers/inference_bridge/server/server/features/providers/service.py +++ b/servers/inference_bridge/server/server/features/providers/service.py @@ -31,12 +31,12 @@ def _extract_model_id(model: Any) -> str: return str(model) -def _create_model_info(model_id: str, default_model: str) -> ModelInfo: +def _create_model_info(model_id: str) -> ModelInfo: """Create a ModelInfo instance""" return ModelInfo( id=model_id, name=model_id, - is_default=model_id == default_model + is_default=False ) @@ -57,7 +57,6 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis environment=config.environment, base_url=config.base_url, api_key=config.api_key, - default_model=config.default_model, rate_limits=config.rate_limits, use_cache=True, # Cache clients for better performance ) @@ -73,20 +72,20 @@ async def get_provider_models(provider_name: str, config: ProviderConfig) -> Lis if hasattr(provider_models, "data"): for model in provider_models.data: model_id = _extract_model_id(model) - models.append(_create_model_info(model_id, config.default_model or "")) + models.append(_create_model_info(model_id)) elif hasattr(client, "get_models"): provider_models = await client.get_models() if hasattr(provider_models, "models"): for model in provider_models.models: model_id = _extract_model_id(model) - models.append(_create_model_info(model_id, config.default_model or "")) + models.append(_create_model_info(model_id)) except Exception as e: logger.error(f"Error fetching models from provider {provider_name}: {str(e)}") logger.info(f"Falling back to configured models for provider {provider_name}") # Fall back to configured models if none fetched if not models: - models = [_create_model_info(model_id, config.default_model or "") for model_id in config.models] + models = [_create_model_info(model_id) for model_id in config.models] logger.info(f"Using {len(models)} configured models for provider {provider_name}") return models @@ -112,7 +111,6 @@ def create_provider_client_from_config(config: ProviderConfig) -> BaseClient: environment=config.environment, base_url=config.base_url, api_key=config.api_key, - default_model=config.default_model, rate_limits=config.rate_limits, use_cache=True, ) @@ -141,7 +139,6 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - "environment": config.environment, "base_url_valid": bool(config.base_url), "api_key_provided": bool(config.api_key), - "default_model": config.default_model, "models_count": len(config.models), }, "connection_steps": [], @@ -164,7 +161,6 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - environment=config.environment, base_url=config.base_url, api_key=config.api_key, - default_model=config.default_model, rate_limits=config.rate_limits, use_cache=False, # Don't cache test clients ) From 5d850e2f1f14b33ecd431a175fcfc6375f1da40e Mon Sep 17 00:00:00 2001 From: jphillips Date: Thu, 27 Mar 2025 06:22:15 -0500 Subject: [PATCH 24/69] Add generation options to Action Drawers Signed-off-by: jphillips --- .../src/app/layout/RootLeftActionPanel.tsx | 8 + .../src/components/icons/index.tsx | 24 +++ .../src/context/AppContextProvider.tsx | 10 +- .../components/GenerationOptionsButton.tsx | 16 +- .../components/GenerationOptionsDialog.tsx | 165 ++++++++++++++++ .../components/GenerationOptionsPanel.tsx | 52 +++++ .../components/GenerationOptionsPopover.tsx | 128 ------------ .../components/fields/GlobalContextField.tsx | 11 +- .../components/fields/ModelSelectorField.tsx | 183 ++++++++++++++++++ .../components/fields/index.ts | 1 + .../generation-options/components/index.ts | 5 +- .../context/GenerationOptionsContext.tsx | 48 +++-- .../inference/generation-options/schema.ts | 6 + .../inference/hooks/useModelSelection.ts | 29 +-- .../hooks/useProviderModelSelection.ts | 10 +- .../context/InferenceProviderContext.tsx | 35 ++-- .../server-connections/services/providers.ts | 10 +- 17 files changed, 540 insertions(+), 201 deletions(-) create mode 100644 graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx create mode 100644 graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx delete mode 100644 graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPopover.tsx create mode 100644 graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx diff --git a/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx b/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx index bc6948c4..5d396b69 100644 --- a/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx +++ b/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx @@ -8,6 +8,7 @@ import { DatasetIcon, FlagIcon, + GenerationOptionsIcon, PerspectiveLayersIcon, ProviderIcon, SettingsIcon, @@ -15,6 +16,7 @@ import { import { SettingsPanel } from "@/features/app-settings"; import { FeatureFlagsPanel } from "@/features/app-settings/feature-flags"; import { DatasetPanel } from "@/features/datasets"; +import { GenerationOptionsPanel } from "@/features/inference/generation-options/components/GenerationOptionsPanel"; import { ProvidersPanel } from "@/features/inference/providers"; import { PerspectiveManagementPanel } from "@/features/perspectives/components/PerspectiveManagement/PerspectiveManagementPanel"; import { ActionPanel } from "./ActionPanel"; @@ -35,6 +37,12 @@ export function RootLeftActionPanel() { icon: , content: , }, + { + id: "generation-options", + title: "Generation Options", + icon: , + content: , + }, { id: "datasets", title: "Datasets", diff --git a/graphcap_studio/src/components/icons/index.tsx b/graphcap_studio/src/components/icons/index.tsx index 9a37668f..ed9b1ac0 100644 --- a/graphcap_studio/src/components/icons/index.tsx +++ b/graphcap_studio/src/components/icons/index.tsx @@ -187,5 +187,29 @@ export function DatasetIcon({ className = "" }: Readonly) { ); } +/** + * Generation Options icon + */ +export function GenerationOptionsIcon({ className = "" }: Readonly) { + return ( + + Generation Options + + + ); +} + // Export for PerspectiveLayersIcon export { PerspectiveLayersIcon } from "./PerspectiveLayersIcon"; diff --git a/graphcap_studio/src/context/AppContextProvider.tsx b/graphcap_studio/src/context/AppContextProvider.tsx index 771b790e..35f80f43 100644 --- a/graphcap_studio/src/context/AppContextProvider.tsx +++ b/graphcap_studio/src/context/AppContextProvider.tsx @@ -1,7 +1,9 @@ import { DatasetInitializer } from "@/features/datasets"; +import { GenerationOptionsProvider } from "@/features/inference/generation-options"; +import { InferenceProviderProvider } from "@/features/inference/providers/context"; import { PerspectivesProvider } from "@/features/perspectives/context"; // SPDX-License-Identifier: Apache-2.0 -import { ReactNode } from "react"; +import type { ReactNode } from "react"; import { ServerConnectionsProvider } from "."; import { FeatureFlagProvider } from "../features/app-settings/feature-flags/FeatureFlagProvider"; @@ -28,7 +30,11 @@ export function AppContextProvider({ children }: AppContextProviderProps) { - {children} + + + {children} + + diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx index 44e25613..1e875ed5 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx @@ -2,13 +2,13 @@ /** * Generation Options Button * - * This component provides a button that triggers the generation options popover. + * This component provides a button that triggers the generation options dialog. */ import { Button } from "@/components/ui"; -import React from "react"; +import type React from "react"; import { useGenerationOptions } from "../context"; -import { GenerationOptionsPopover } from "./GenerationOptionsPopover"; +import { GenerationOptionsDialog } from "./GenerationOptionsDialog"; interface GenerationOptionsButtonProps { readonly label?: React.ReactNode; @@ -17,25 +17,25 @@ interface GenerationOptionsButtonProps { } /** - * Button component for triggering generation options popover + * Button component for triggering generation options dialog */ export function GenerationOptionsButton({ label = "Options", size = "sm", variant = "outline", }: GenerationOptionsButtonProps) { - const { togglePopover, isGenerating } = useGenerationOptions(); + const { toggleDialog, isGenerating } = useGenerationOptions(); return ( - + - + ); } diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx new file mode 100644 index 00000000..874944dc --- /dev/null +++ b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Generation Options Dialog + * + * This component displays a dialog with generation options form. + */ + +import { Button } from "@/components/ui"; +import { useColorModeValue } from "@/components/ui/theme/color-mode"; +import { Box, CloseButton, Dialog, Fieldset, Flex, HStack, Portal } from "@chakra-ui/react"; +import type React from "react"; +import { useGenerationOptions } from "../context"; +import { + GlobalContextField, + MaxTokensField, + ModelSelectorField, + RepetitionPenaltyField, + ResizeResolutionField, + TemperatureField, + TopPField, +} from "./fields"; + +interface GenerationOptionsDialogProps { + readonly children: React.ReactNode; +} + +/** + * Dialog component for generation options + */ +export function GenerationOptionsDialog({ + children, +}: GenerationOptionsDialogProps) { + const { isDialogOpen, closeDialog, resetOptions, isGenerating } = + useGenerationOptions(); + + // Colors for theming + const bgColor = useColorModeValue("white", "gray.700"); + const borderColor = useColorModeValue("gray.200", "gray.600"); + const headerColor = useColorModeValue("gray.800", "white"); + + return ( + (e.open ? null : closeDialog())} + size="lg" + > + {children} + + + + + + + Generation Options + + + + + + + + + {/* First column: Model selection */} + + + Provider & Model + + + + + + + {/* Second column: Generation Parameters */} + + + Generation Parameters + + + + + + + + + + + + + + + + + + + + {/* Third column: Image Processing */} + + + Image Processing + + + + + + + + {/* Global Context Field (spans full width on second row) */} + + Context Settings + + + + + + + + + + + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx new file mode 100644 index 00000000..9e721514 --- /dev/null +++ b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Generation Options Panel + * + * This component displays generation options in the left action drawer. + */ + +import { Box, Button, VStack } from "@chakra-ui/react"; +import { useGenerationOptions } from "../context"; +import { + GlobalContextField, + MaxTokensField, + ModelSelectorField, + RepetitionPenaltyField, + ResizeResolutionField, + TemperatureField, + TopPField, +} from "./fields"; + +/** + * Panel component for generation options in the left action drawer + */ +export function GenerationOptionsPanel() { + const { resetOptions, isGenerating } = useGenerationOptions(); + + return ( + + + + + + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPopover.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPopover.tsx deleted file mode 100644 index 7fc12e30..00000000 --- a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPopover.tsx +++ /dev/null @@ -1,128 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Generation Options Popover - * - * This component displays a popover with generation options form. - */ - -import { Button } from "@/components/ui"; -import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { Box, Flex, HStack, Popover, Portal } from "@chakra-ui/react"; -import React from "react"; -import { useGenerationOptions } from "../context"; -import { - GlobalContextField, - MaxTokensField, - RepetitionPenaltyField, - ResizeResolutionField, - TemperatureField, - TopPField, -} from "./fields"; - -interface GenerationOptionsPopoverProps { - readonly children: React.ReactNode; -} - -/** - * Popover component for generation options - */ -export function GenerationOptionsPopover({ - children, -}: GenerationOptionsPopoverProps) { - const { isPopoverOpen, closePopover, resetOptions, isGenerating } = - useGenerationOptions(); - - // Colors for theming - const bgColor = useColorModeValue("white", "gray.700"); - const borderColor = useColorModeValue("gray.200", "gray.600"); - const headerColor = useColorModeValue("gray.800", "white"); - - return ( - (e.open ? null : closePopover())} - > - {children} - - - - - - Generation Options - - - ✕ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ); -} diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx index 70f84901..92b0a8f0 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx @@ -7,7 +7,7 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { Box, Textarea } from "@chakra-ui/react"; -import { ChangeEvent, useCallback, useEffect, useState } from "react"; +import { type ChangeEvent, useCallback, useEffect, useState } from "react"; import { useGenerationOptions } from "../../context"; /** @@ -32,7 +32,9 @@ export function GlobalContextField() { debounce((value: string) => { updateOption("global_context", value); }, 500), - [updateOption], + // updateOption is from context and won't change during component's lifecycle + // eslint-disable-next-line react-hooks/exhaustive-deps + [], ); const handleChange = (e: ChangeEvent) => { @@ -63,13 +65,14 @@ export function GlobalContextField() { } // Debounce utility function -function debounce any>( +function debounce) => ReturnType>( func: T, wait: number, ): (...args: Parameters) => void { let timeout: NodeJS.Timeout | null = null; - return function (...args: Parameters) { + // Using arrow function to avoid "this function expression can be turned into an arrow function" lint error + return (...args: Parameters) => { if (timeout) clearTimeout(timeout); timeout = setTimeout(() => func(...args), wait); }; diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx new file mode 100644 index 00000000..ba5e7e08 --- /dev/null +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Model Selector Field Component + * + * This component provides controls for selecting a provider and model. + */ + +import { Field } from "@/components/ui/field"; +import { useColorModeValue } from "@/components/ui/theme/color-mode"; +import { useProviderModelSelection } from "@/features/inference/hooks"; +import { useInferenceProviderContext } from "@/features/inference/providers/context/InferenceProviderContext"; +import { Box } from "@chakra-ui/react"; +import { Portal, Select, createListCollection } from "@chakra-ui/react"; +import { useCallback, useEffect, useMemo } from "react"; +import { useGenerationOptions } from "../../context"; + +/** + * Field component for selecting model and provider + */ +export function ModelSelectorField() { + const { options, updateOption, isGenerating } = useGenerationOptions(); + const { providers: contextProviders } = useInferenceProviderContext(); + + // Extract provider from context + const currentProvider = useMemo(() => { + if (!options.provider_id) return null; + const providerId = Number.parseInt(options.provider_id, 10); + return contextProviders.find(p => p.id === providerId) || null; + }, [contextProviders, options.provider_id]); + + // Use the hook to get providers and models + const { + providers, + models, + isLoading, + } = useProviderModelSelection(currentProvider); + + // Color values f or theming + const labelColor = useColorModeValue("gray.700", "gray.300"); + const helperTextColor = useColorModeValue("gray.500", "gray.400"); + + // Initialize provider if needed + useEffect(() => { + if (providers.length > 0 && !options.provider_id) { + const provider = providers[0]; + updateOption("provider_id", provider.id.toString()); + } + }, [providers, options.provider_id, updateOption]); + + // Update model when provider changes or when models are loaded + useEffect(() => { + if (models.length > 0 && !options.model_id) { + updateOption("model_id", models[0]?.name || ""); + } + }, [models, options.model_id, updateOption]); + + // Create collections for selects - always include at least one item + const providerCollection = useMemo(() => { + const items = providers.length > 0 + ? providers.map((provider) => ({ + label: provider.name, + value: provider.id.toString(), + disabled: false, + })) + : [{ label: "No providers available", value: "none", disabled: false }]; + + return createListCollection({ items }); + }, [providers]); + + const modelCollection = useMemo(() => { + const items = models.length > 0 + ? models.map((model) => ({ + label: model.name, + value: model.id, + disabled: false, + })) + : [{ label: "No models available", value: "none", disabled: false }]; + + return createListCollection({ items }); + }, [models]); + + // Handle provider change + const handleProviderChange = useCallback((newValue: string[]) => { + if (newValue.length > 0 && newValue[0] !== "none") { + const providerId = newValue[0]; + updateOption("provider_id", providerId); + updateOption("model_id", ""); + } + }, [updateOption]); + + // Handle model change + const handleModelChange = useCallback((newValue: string[]) => { + if (newValue.length > 0 && newValue[0] !== "none") { + updateOption("model_id", newValue[0]); + } + }, [updateOption]); + + console.log('Providers:', providers); + console.log('IsLoading:', isLoading); + console.log('CurrentProvider:', currentProvider); + + return ( + + + Provider & Model + + + + + handleProviderChange(e.value)} + disabled={false} // Never disable this + size="sm" + > + + + + + + + + + + + + + + {providerCollection.items.map((provider) => ( + + {provider.label} + + + ))} + + + + + + + + + + handleModelChange(e.value)} + disabled={false} // Never disable this + size="sm" + > + + + + + + + + + + + + + + {modelCollection.items.map((model) => ( + + {model.label} + + + ))} + + + + + + + + + Click on the dropdowns to select provider and model + + + ); +} diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/index.ts b/graphcap_studio/src/features/inference/generation-options/components/fields/index.ts index 2abc98cd..2d5d59ca 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/index.ts +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/index.ts @@ -12,3 +12,4 @@ export * from "./TopPField"; export * from "./RepetitionPenaltyField"; export * from "./GlobalContextField"; export * from "./ResizeResolutionField"; +export * from "./ModelSelectorField"; diff --git a/graphcap_studio/src/features/inference/generation-options/components/index.ts b/graphcap_studio/src/features/inference/generation-options/components/index.ts index 4cae6d1c..7e99765c 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/index.ts +++ b/graphcap_studio/src/features/inference/generation-options/components/index.ts @@ -5,6 +5,9 @@ * This module exports all components for generation options. */ -export * from "./GenerationOptionsPopover"; +export * from "./fields"; +export * from "./GenerationOptionForm"; export * from "./GenerationOptionsButton"; +export * from "./GenerationOptionsDialog"; +export * from "./GenerationOptionsPanel"; export * from "./ProviderSelector"; diff --git a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx index 6311bdf3..d0745c16 100644 --- a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx +++ b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx @@ -5,18 +5,19 @@ * This module provides a context for managing generation options state. */ -import React, { +import type React from "react"; +import { createContext, - useContext, - useState, useCallback, - useMemo, + useContext, useEffect, + useMemo, + useState, } from "react"; import { usePersistGenerationOptions } from "../persist-generation-options"; import { DEFAULT_OPTIONS, - GenerationOptions, + type GenerationOptions, GenerationOptionsSchema, } from "../schema"; @@ -24,7 +25,7 @@ import { interface GenerationOptionsContextValue { // State options: GenerationOptions; - isPopoverOpen: boolean; + isDialogOpen: boolean; isGenerating: boolean; // Actions @@ -34,9 +35,9 @@ interface GenerationOptionsContextValue { ) => void; resetOptions: () => void; setOptions: (options: Partial) => void; - openPopover: () => void; - closePopover: () => void; - togglePopover: () => void; + openDialog: () => void; + closeDialog: () => void; + toggleDialog: () => void; setIsGenerating: (isGenerating: boolean) => void; } @@ -78,7 +79,7 @@ export function GenerationOptionsProvider({ // State const [options, setOptions] = useState(defaultOptions); - const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [isDialogOpen, setIsDialogOpen] = useState(false); const [isGenerating, setIsGenerating] = useState(initialGenerating); // Save options to localStorage when they change @@ -131,41 +132,38 @@ export function GenerationOptionsProvider({ [onOptionsChange], ); - // Popover controls - const openPopover = useCallback(() => setIsPopoverOpen(true), []); - const closePopover = useCallback(() => setIsPopoverOpen(false), []); - const togglePopover = useCallback( - () => setIsPopoverOpen((prev) => !prev), - [], - ); + // Dialog controls + const openDialog = useCallback(() => setIsDialogOpen(true), []); + const closeDialog = useCallback(() => setIsDialogOpen(false), []); + const toggleDialog = useCallback(() => setIsDialogOpen((prev) => !prev), []); // Context value const value = useMemo( () => ({ // State options, - isPopoverOpen, + isDialogOpen, isGenerating, // Actions updateOption, resetOptions, setOptions: mergeOptions, - openPopover, - closePopover, - togglePopover, + openDialog, + closeDialog, + toggleDialog, setIsGenerating, }), [ options, - isPopoverOpen, + isDialogOpen, isGenerating, updateOption, resetOptions, mergeOptions, - openPopover, - closePopover, - togglePopover, + openDialog, + closeDialog, + toggleDialog, ], ); diff --git a/graphcap_studio/src/features/inference/generation-options/schema.ts b/graphcap_studio/src/features/inference/generation-options/schema.ts index 593045c7..1036e5e1 100644 --- a/graphcap_studio/src/features/inference/generation-options/schema.ts +++ b/graphcap_studio/src/features/inference/generation-options/schema.ts @@ -35,6 +35,8 @@ export const DEFAULT_OPTIONS = { repetition_penalty: 1.1, resize_resolution: "NONE", // Default to no resize global_context: "You are a visual captioning perspective.", + provider_id: "", // Default to empty (will be populated later) + model_id: "", // Default to empty (will be populated later) } as const; // Schema for generation options @@ -67,6 +69,10 @@ export const GenerationOptionsSchema = z.object({ resize_resolution: z.string().default(DEFAULT_OPTIONS.resize_resolution), global_context: z.string().default(DEFAULT_OPTIONS.global_context), + + provider_id: z.string().default(DEFAULT_OPTIONS.provider_id), + + model_id: z.string().default(DEFAULT_OPTIONS.model_id), }); // Type for generation options diff --git a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts index df1ab19c..2a96ff54 100644 --- a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts @@ -6,18 +6,25 @@ import type { Provider } from "../providers/types"; /** * Custom hook for managing model selection * - * @param provider - Provider to fetch models for + * @param provider - Provider to fetch models for, can be null or undefined * @param onModelSelect - Callback function when a model is selected * @returns Model selection state and handlers */ export function useModelSelection( - provider: Provider, + provider: Provider | null | undefined, onModelSelect?: (providerName: string, modelId: string) => void, ) { // State for model selection - const [selectedModelId, setSelectedModelId] = useState( - provider.defaultModel || "" - ); + const [selectedModelId, setSelectedModelId] = useState(""); + + // Update selected model ID when provider changes + useEffect(() => { + if (provider?.defaultModel) { + setSelectedModelId(provider.defaultModel); + } else { + setSelectedModelId(""); + } + }, [provider]); // Get models for the current provider const { @@ -27,24 +34,22 @@ export function useModelSelection( error: modelsError, } = useProviderModels(provider); - // Update selected model when models are loaded or default model changes + // Update selected model when models are loaded useEffect(() => { - if (provider.defaultModel) { - setSelectedModelId(provider.defaultModel); - } else if (providerModelsData?.models && providerModelsData.models.length > 0) { + if (!selectedModelId && providerModelsData?.models && providerModelsData.models.length > 0) { const defaultModel = providerModelsData.models.find( (model) => model.is_default ); setSelectedModelId(defaultModel?.id ?? providerModelsData.models[0].id); } - }, [providerModelsData, provider.defaultModel]); + }, [providerModelsData, selectedModelId]); // Handle model selection const handleModelSelect = useCallback(() => { - if (onModelSelect && provider.name && selectedModelId) { + if (onModelSelect && provider?.name && selectedModelId) { onModelSelect(provider.name, selectedModelId); } - }, [onModelSelect, provider.name, selectedModelId]); + }, [onModelSelect, provider, selectedModelId]); return { selectedModelId, diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts index 5e69e108..d305245b 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts @@ -6,7 +6,7 @@ import type { Provider } from "../providers/types"; /** * Custom hook to handle provider and model selection logic */ -export function useProviderModelSelection(provider: Provider) { +export function useProviderModelSelection(provider: Provider | null | undefined) { // Fetch providers from API const { data: providers = [], @@ -31,16 +31,16 @@ export function useProviderModelSelection(provider: Provider) { const providersWithNoModels = useMemo(() => { const noModelsSet = new Set(); - if (providerModelsData?.models?.length === 0 && provider.fetchModels) { + if (providerModelsData?.models?.length === 0 && provider?.fetchModels && provider?.name) { noModelsSet.add(provider.name); } return noModelsSet; - }, [provider.name, provider.fetchModels, providerModelsData]); + }, [provider?.name, provider?.fetchModels, providerModelsData]); // Get default model if available const defaultModel = useMemo(() => { - if (provider.defaultModel) { + if (provider?.defaultModel) { return { id: provider.defaultModel, name: provider.defaultModel, @@ -54,7 +54,7 @@ export function useProviderModelSelection(provider: Provider) { ); } return null; - }, [provider.defaultModel, providerModelsData]); + }, [provider?.defaultModel, providerModelsData]); return { providers: availableProviders, diff --git a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx index f0047499..f823ed8c 100644 --- a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx @@ -155,8 +155,8 @@ const loadProviderFromStorage = (): Provider | null => { */ type InferenceProviderProviderProps = { readonly children: ReactNode; - readonly isCreating: boolean; - readonly onCancel: () => void; + readonly isCreating?: boolean; + readonly onCancel?: () => void; readonly onModelSelect?: (providerName: string, modelId: string) => void; readonly selectedProvider?: Provider | null; readonly providers?: Provider[]; @@ -177,8 +177,8 @@ type InferenceProviderProviderProps = { */ export function InferenceProviderProvider({ children, - isCreating, - onCancel, + isCreating = false, + onCancel = () => {}, onModelSelect, selectedProvider: selectedProviderProp, providers: providersProp = [], @@ -195,20 +195,27 @@ export function InferenceProviderProvider({ // Update selected provider when prop changes useEffect(() => { - if (selectedProviderProp) { + if (selectedProviderProp && JSON.stringify(selectedProviderProp) !== JSON.stringify(selectedProvider)) { setSelectedProvider(selectedProviderProp); } - }, [selectedProviderProp]); + }, [selectedProviderProp, selectedProvider]); // Save selected provider to localStorage when it changes useEffect(() => { - saveProviderToStorage(selectedProvider); + if (selectedProvider) { + saveProviderToStorage(selectedProvider); + } }, [selectedProvider]); - // Update providers when prop changes + // Update providers when prop changes - only if we have providers and they're different useEffect(() => { - setProviders(providersProp); - }, [providersProp]); + const hasProviders = Array.isArray(providersProp) && providersProp.length > 0; + const providersChanged = JSON.stringify(providersProp) !== JSON.stringify(providers); + + if (hasProviders && providersChanged) { + setProviders(providersProp); + } + }, [providersProp, providers]); // Use the model selection hook with selectedProvider const { @@ -219,12 +226,14 @@ export function InferenceProviderProvider({ isModelsError, modelsError, handleModelSelect: handleModelSelectBase, - } = useModelSelection(selectedProvider as Provider, onModelSelect); + } = useModelSelection(selectedProvider, onModelSelect); // Create a memoized version of handleModelSelect const handleModelSelect = useCallback(() => { - handleModelSelectBase(); - }, [handleModelSelectBase]); + if (selectedProvider) { + handleModelSelectBase(); + } + }, [handleModelSelectBase, selectedProvider]); // Create a memoized version of onCancel that resets mode const onCancelHandler = useCallback(() => { diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index cf4dccf8..d39c4f1b 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -225,7 +225,7 @@ export function useDeleteProvider() { /** * Hook to get available models for a provider */ -export function useProviderModels(provider: Provider) { +export function useProviderModels(provider: Provider | null | undefined) { const { connections } = useServerConnectionsContext(); const inferenceBridgeConnection = connections.find( (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, @@ -233,8 +233,12 @@ export function useProviderModels(provider: Provider) { const isConnected = inferenceBridgeConnection?.status === "connected"; return useQuery({ - queryKey: queryKeys.providerModels(provider.name), + queryKey: queryKeys.providerModels(provider?.name ?? 'unknown'), queryFn: async () => { + if (!provider) { + throw new Error("Provider is null or undefined"); + } + const client = createInferenceBridgeClient(connections); const serverConfig = toServerConfig(provider); @@ -249,7 +253,7 @@ export function useProviderModels(provider: Provider) { return response.json() as Promise; }, - enabled: isConnected && !!provider && provider.fetchModels, + enabled: isConnected && !!provider && !!provider.fetchModels, staleTime: 1000 * 60 * 10, // 10 minutes }); } \ No newline at end of file From a228a89845145c00bbdd21b1c05377cb5accac29 Mon Sep 17 00:00:00 2001 From: jphillips Date: Thu, 27 Mar 2025 06:33:13 -0500 Subject: [PATCH 25/69] Style tweaks Signed-off-by: jphillips --- graphcap_studio/src/app/layout/RootLeftActionPanel.tsx | 2 +- graphcap_studio/src/app/theme/global-css.ts | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx b/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx index 5d396b69..79b4de30 100644 --- a/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx +++ b/graphcap_studio/src/app/layout/RootLeftActionPanel.tsx @@ -29,7 +29,7 @@ export function RootLeftActionPanel() { Date: Sat, 29 Mar 2025 05:57:09 -0500 Subject: [PATCH 26/69] Require model on generation request, clean up default model reqs Signed-off-by: jphillips --- .../common_inference/ModelSelector.tsx | 113 +++++++++++++++++ .../common_inference/ProviderSelector.tsx | 117 ++++++++++++++++++ graphcap_studio/src/components/ui/index.ts | 1 + .../components/ProviderSelector.tsx | 69 +++-------- .../components/form/ModelSelector.tsx | 42 ++----- .../components/form/ProviderFormSelect.tsx | 50 ++------ .../hooks/useGeneratePerspectiveCaption.ts | 12 ++ .../perspectives/types/perspectivesTypes.ts | 2 + graphcap_studio/src/utils/toast.ts | 38 ++++-- .../graphcap/perspectives/base_caption.py | 13 +- .../providers/clients/ollama_client.py | 2 - .../pipelines/pipelines/providers/assets.py | 2 +- .../pipelines/pipelines/providers/util.py | 5 +- .../server/features/perspectives/models.py | 2 + .../server/features/perspectives/router.py | 47 ++++--- .../server/features/perspectives/service.py | 20 +-- .../features/providers/error_handler.py | 7 +- 17 files changed, 367 insertions(+), 175 deletions(-) create mode 100644 graphcap_studio/src/components/common_inference/ModelSelector.tsx create mode 100644 graphcap_studio/src/components/common_inference/ProviderSelector.tsx diff --git a/graphcap_studio/src/components/common_inference/ModelSelector.tsx b/graphcap_studio/src/components/common_inference/ModelSelector.tsx new file mode 100644 index 00000000..e8c5ad82 --- /dev/null +++ b/graphcap_studio/src/components/common_inference/ModelSelector.tsx @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Model Selector Component + * + * A reusable component for selecting a model from a list of options. + */ + +import { Box, createListCollection } from "@chakra-ui/react"; +import { Field } from "../ui/field"; +import { + SelectContent, + SelectItem, + SelectRoot, + SelectTrigger, + SelectValueText, +} from "../ui/select"; + +export interface ModelOption { + label: string; + value: string; +} + +export interface ModelSelectorProps { + readonly options: ModelOption[]; + readonly value: string | null | undefined; + readonly onChange: (value: string) => void; + readonly isDisabled?: boolean; + readonly maxWidth?: string | number; + readonly minWidth?: string | number; + readonly width?: string | number; + readonly size?: "xs" | "sm" | "md" | "lg"; + readonly placeholder?: string; + readonly className?: string; + readonly showLabel?: boolean; + readonly label?: string; + readonly helperText?: string; +} + +/** + * A reusable component for selecting a model from a list of options. + */ +export function ModelSelector({ + options, + value, + onChange, + isDisabled = false, + maxWidth = undefined, + minWidth = undefined, + width = undefined, + size = "sm", + placeholder = "Select model", + className, + showLabel = false, + label = "Model", + helperText, +}: ModelSelectorProps) { + // Create collection for SelectRoot + const modelCollection = createListCollection({ + items: options, + }); + + // Convert value to string array format required by SelectRoot + const selectValue = value ? [value] : []; + + const handleValueChange = (details: { value: string[] }) => { + if (details.value && details.value.length > 0) { + onChange(details.value[0]); + } else { + onChange(""); + } + }; + + const boxProps = { + ...(maxWidth ? { maxWidth } : {}), + ...(minWidth ? { minWidth } : {}), + ...(width ? { width } : {}), + className, + }; + + const selector = ( + + + + + + {options.map((option) => ( + + {option.label} + + ))} + + + ); + + return ( + + {showLabel ? ( + + {selector} + + ) : ( + selector + )} + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/components/common_inference/ProviderSelector.tsx b/graphcap_studio/src/components/common_inference/ProviderSelector.tsx new file mode 100644 index 00000000..0daa99b2 --- /dev/null +++ b/graphcap_studio/src/components/common_inference/ProviderSelector.tsx @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Provider Selector Component + * + * A reusable component for selecting a provider from a list of options. + */ + +import { Box, createListCollection } from "@chakra-ui/react"; +import { Field } from "../ui/field"; +import { + SelectContent, + SelectItem, + SelectRoot, + SelectTrigger, + SelectValueText, +} from "../ui/select"; + +export interface ProviderOption { + label: string; + value: string; + id?: number; +} + +export interface ProviderSelectorProps { + readonly options: ProviderOption[]; + readonly value: string | number | null | undefined; + readonly onChange: (value: string) => void; + readonly isDisabled?: boolean; + readonly maxWidth?: string | number; + readonly minWidth?: string | number; + readonly width?: string | number; + readonly size?: "xs" | "sm" | "md" | "lg"; + readonly placeholder?: string; + readonly className?: string; + readonly showLabel?: boolean; + readonly label?: string; + readonly helperText?: string; +} + +/** + * A reusable component for selecting a provider from a list of options. + */ +export function ProviderSelector({ + options, + value, + onChange, + isDisabled = false, + maxWidth = undefined, + minWidth = undefined, + width = undefined, + size = "sm", + placeholder = "Select provider", + className, + showLabel = false, + label = "Provider", + helperText, +}: ProviderSelectorProps) { + // Create collection for SelectRoot + const providerCollection = createListCollection({ + items: options, + }); + + // Convert value to string array format required by SelectRoot + // Handle both string and number values + const stringValue = value !== null && value !== undefined ? String(value) : null; + const selectValue = stringValue ? [stringValue] : []; + + const handleValueChange = (details: { value: string[] }) => { + if (details.value && details.value.length > 0) { + onChange(details.value[0]); + } else { + onChange(""); + } + }; + + const boxProps = { + ...(maxWidth ? { maxWidth } : {}), + ...(minWidth ? { minWidth } : {}), + ...(width ? { width } : {}), + className, + }; + + const selector = ( + + + + + + {options.map((option) => ( + + {option.label} + + ))} + + + ); + + return ( + + {showLabel ? ( + + {selector} + + ) : ( + selector + )} + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/components/ui/index.ts b/graphcap_studio/src/components/ui/index.ts index 670c9ad6..6aa6afb8 100644 --- a/graphcap_studio/src/components/ui/index.ts +++ b/graphcap_studio/src/components/ui/index.ts @@ -3,3 +3,4 @@ export * from "./ImageCounter"; export * from "./buttons"; export * from "./status"; +export * from "../common_inference/ModelSelector"; diff --git a/graphcap_studio/src/features/inference/generation-options/components/ProviderSelector.tsx b/graphcap_studio/src/features/inference/generation-options/components/ProviderSelector.tsx index b40fcf65..4d3603c4 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/ProviderSelector.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/ProviderSelector.tsx @@ -5,14 +5,7 @@ * A component for selecting an inference provider. */ -import { - SelectContent, - SelectItem, - SelectRoot, - SelectTrigger, - SelectValueText, -} from "@/components/ui/select"; -import { Box, createListCollection } from "@chakra-ui/react"; +import { ProviderSelector as CommonProviderSelector, type ProviderOption } from "@/components/common_inference/ProviderSelector"; export interface Provider { id: number; @@ -44,55 +37,25 @@ export function ProviderSelector({ placeholder = "Select provider", className, }: ProviderSelectorProps) { - // Convert providers to the format expected by SelectRoot - const providerItems = providers.map((provider) => ({ + // Convert providers to ProviderOption format + const providerOptions: ProviderOption[] = providers.map((provider) => ({ label: provider.name, value: provider.name, + id: provider.id, })); - const providerCollection = createListCollection({ - items: providerItems, - }); - - // Convert selectedProvider to string array format required by SelectRoot - const value = selectedProvider ? [selectedProvider] : []; - - const handleValueChange = (details: any) => { - if (details.value && details.value.length > 0) { - onChange(details.value[0]); - } else { - onChange(""); - } - }; - - const boxProps = { - ...(maxWidth ? { maxWidth } : {}), - ...(minWidth ? { minWidth } : {}), - ...(width ? { width } : {}), - className, - }; - return ( - - - - - - - {providerItems.map((item) => ( - - {item.label} - - ))} - - - + ); } diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx index e535ee60..b3e5bc27 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx @@ -1,16 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 -import { Field } from "@/components/ui/field"; -import { - SelectContent, - SelectItem, - SelectRoot, - SelectTrigger, - SelectValueText, -} from "@/components/ui/select"; +import { ModelSelector as GenericModelSelector, ModelOption } from "@/components/common_inference/ModelSelector"; import { useColorMode } from "@/components/ui/theme/color-mode"; -import { Box, Heading, Text, createListCollection } from "@chakra-ui/react"; +import { Box, Heading, Text } from "@chakra-ui/react"; -// Define the model item type for the select component export interface ModelItem { label: string; value: string; @@ -38,13 +30,6 @@ export function ModelSelector({ const headingColor = isDark ? "gray.100" : "gray.700"; const labelColor = isDark ? "gray.300" : "gray.600"; - const modelCollection = createListCollection({ - items: modelItems, - }); - - // Convert selectedModelId to string array format - const value = selectedModelId ? [selectedModelId] : []; - return ( Select a model to use with this provider - - setSelectedModelId(details.value[0])} - > - - - - - {modelItems.map((item: ModelItem) => ( - - {item.label} - - ))} - - - + ); } diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx index 8dc5faa5..c4b93e6f 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx @@ -1,12 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -import { - SelectContent, - SelectItem, - SelectRoot, - SelectTrigger, - SelectValueText, -} from "@/components/ui/select"; -import { createListCollection } from "@chakra-ui/react"; +import { ProviderSelector, type ProviderOption } from "@/components/common_inference/ProviderSelector"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; type ProviderFormSelectProps = { @@ -27,23 +20,17 @@ export function ProviderFormSelect({ const selectedProviderId = selectedProvider?.id ?? null; - // Convert providers to the format expected by SelectRoot - const providerItems = providers.map((provider) => ({ + // Convert providers to the format expected by ProviderSelector + const providerOptions: ProviderOption[] = providers.map((provider) => ({ label: provider.name, value: String(provider.id), + id: provider.id, })); - const providerCollection = createListCollection({ - items: providerItems, - }); - - // Convert selectedProviderId to string array format - const value = selectedProviderId ? [String(selectedProviderId)] : []; - - const handleProviderChange = (details: { value: string[] }) => { - if (!details.value.length) return; + const handleProviderChange = (value: string) => { + if (!value) return; - const id = Number(details.value[0]); + const id = Number(value); const provider = providers.find((p) => p.id === id); if (provider) { // Call the context's setSelectedProvider function @@ -53,23 +40,12 @@ export function ProviderFormSelect({ }; return ( - - - - - - {providerItems.map((item) => ( - - {item.label} - - ))} - - + /> ); } \ No newline at end of file diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index b045b10a..86c10c62 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -13,6 +13,7 @@ import { import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; import type { ServerConnection } from "@/features/server-connections/types"; +import { toast } from "@/utils/toast"; import { useMutation, useQueryClient } from "@tanstack/react-query"; import { perspectivesQueryKeys } from "../services/constants"; import { ensureWorkspacePath, handleApiError } from "../services/utils"; @@ -51,6 +52,11 @@ export function useGeneratePerspectiveCaption() { throw new Error("Caption generation options are required"); } + // Check if a model is specified in the options + if (!options.model) { + throw new Error("A model must be specified in the options"); + } + // Use the inference bridge client instead of direct fetch const client = createInferenceBridgeClient(connections); @@ -69,6 +75,7 @@ export function useGeneratePerspectiveCaption() { perspective, image_path: normalizedImagePath, provider: provider.name, + model: options.model, // Use the model from options provider_config: providerConfig, // Include the full provider configuration max_tokens: options.max_tokens, temperature: options.temperature, @@ -84,6 +91,7 @@ export function useGeneratePerspectiveCaption() { perspective, image_path: normalizedImagePath, provider: provider.name, + model: options.model, // Log the model from options options: { max_tokens: requestBody.max_tokens, temperature: requestBody.temperature, @@ -136,6 +144,10 @@ export function useGeneratePerspectiveCaption() { }, onError: (error) => { console.error("Caption generation failed", error); + toast.error({ + title: "Caption generation failed", + description: error.message, + }); }, }); } diff --git a/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts b/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts index 2a1804a2..8df840f8 100644 --- a/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts +++ b/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts @@ -101,6 +101,7 @@ export const CaptionRequestSchema = z.object({ provider: z.string().optional(), // For backward compatibility options: z .object({ + model: z.string(), // Required model name to use for processing max_tokens: z.number().optional(), temperature: z.number().optional(), top_p: z.number().optional(), @@ -190,6 +191,7 @@ export type ServerConnection = { * Specifies options for generating captions. */ export type CaptionOptions = { + model: string; // Required model name to use for processing max_tokens?: number; temperature?: number; top_p?: number; diff --git a/graphcap_studio/src/utils/toast.ts b/graphcap_studio/src/utils/toast.ts index ef8b5bbe..ed918de2 100644 --- a/graphcap_studio/src/utils/toast.ts +++ b/graphcap_studio/src/utils/toast.ts @@ -105,8 +105,8 @@ export const toast = { /** * Show a success toast */ - success: ({ title, description, duration = 3000 }: { title: string; description?: string; duration?: number }) => { - toaster.create({ + success: ({ title, description, duration = 1000 }: { title: string; description?: string; duration?: number }) => { + return toaster.create({ title, description, duration, @@ -117,8 +117,8 @@ export const toast = { /** * Show an error toast */ - error: ({ title, description, duration = 5000 }: { title: string; description?: string; duration?: number }) => { - toaster.create({ + error: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => { + return toaster.create({ title, description, duration, @@ -129,8 +129,8 @@ export const toast = { /** * Show an info toast */ - info: ({ title, description, duration = 3000 }: { title: string; description?: string; duration?: number }) => { - toaster.create({ + info: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => { + return toaster.create({ title, description, duration, @@ -141,12 +141,34 @@ export const toast = { /** * Show a warning toast */ - warning: ({ title, description, duration = 4000 }: { title: string; description?: string; duration?: number }) => { - toaster.create({ + warning: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => { + return toaster.create({ title, description, duration, type: "warning", }); + }, + + /** + * Dismiss a toast by its ID + * If no ID is provided, all toasts will be dismissed + */ + dismiss: (id?: string) => { + toaster.dismiss(id); + }, + + /** + * Pause a toast by its ID to prevent it from timing out + */ + pause: (id: string) => { + toaster.pause(id); + }, + + /** + * Resume a paused toast, re-enabling the timeout with the remaining duration + */ + resume: (id: string) => { + toaster.resume(id); } }; diff --git a/servers/inference_bridge/graphcap/perspectives/base_caption.py b/servers/inference_bridge/graphcap/perspectives/base_caption.py index 737e09a3..2d1643c7 100644 --- a/servers/inference_bridge/graphcap/perspectives/base_caption.py +++ b/servers/inference_bridge/graphcap/perspectives/base_caption.py @@ -112,6 +112,7 @@ async def process_single( self, provider: BaseClient, image_path: Path, + model: str, max_tokens: Optional[int] = 4096, temperature: Optional[float] = 0.8, top_p: Optional[float] = 0.9, @@ -125,6 +126,7 @@ async def process_single( Args: provider: Vision AI provider client instance image_path: Path to the image file + model: Model name to use for processing max_tokens: Maximum tokens for model response temperature: Sampling temperature top_p: Nucleus sampling parameter @@ -151,7 +153,7 @@ async def process_single( prompt=prompt, image=image_path, schema=self.vision_config.schema, - model=provider.default_model, + model=model, max_tokens=max_tokens, temperature=temperature, top_p=top_p, @@ -186,6 +188,7 @@ async def process_batch( self, provider: BaseClient, image_paths: List[Path], + model: str, max_tokens: Optional[int] = 4096, temperature: Optional[float] = 0.8, top_p: Optional[float] = 0.9, @@ -205,6 +208,7 @@ async def process_batch( Args: provider: Vision AI provider client instance image_paths: List of paths to image files + model: Model name to use for processing max_tokens: Maximum tokens for model response temperature: Sampling temperature top_p: Nucleus sampling parameter @@ -247,7 +251,7 @@ async def process_batch( job_info_data = { "started_at": timestamp, "provider": provider.name, - "model": provider.default_model, + "model": model, "config_name": self.vision_config.config_name, "version": self.vision_config.version, "total_images": len(image_paths), @@ -308,6 +312,7 @@ async def process_with_semaphore(path: Path) -> Dict[str, Any]: result = await self.process_single( provider=provider, image_path=path, + model=model, max_tokens=max_tokens, temperature=temperature, top_p=top_p, @@ -324,7 +329,7 @@ async def process_with_semaphore(path: Path) -> Dict[str, Any]: "filename": f"./{path.name}", "config_name": self.vision_config.config_name, "version": self.vision_config.version, - "model": provider.default_model, + "model": model, "provider": provider.name, "parsed": result, } @@ -354,7 +359,7 @@ async def process_with_semaphore(path: Path) -> Dict[str, Any]: "filename": f"./{path.name}", "config_name": self.vision_config.config_name, "version": self.vision_config.version, - "model": provider.default_model, + "model": model, "provider": provider.name, "parsed": {"error": str(e)}, } diff --git a/servers/inference_bridge/graphcap/providers/clients/ollama_client.py b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py index 432894cf..3a76c3ff 100644 --- a/servers/inference_bridge/graphcap/providers/clients/ollama_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py @@ -75,7 +75,6 @@ async def get_models(self): try: logger.info("Fetching models from Ollama:") logger.info(f" - URL: {self._raw_base_url}/models") - logger.info(f" - Default model: {self.default_model}") async with httpx.AsyncClient() as client: response = await client.get(f"{self._raw_base_url}/models") @@ -88,7 +87,6 @@ async def get_models(self): logger.error("Connection error while fetching models from Ollama:") logger.error(f" - Error: {str(e)}") logger.error(f" - URL: {self._raw_base_url}/models") - logger.error(f" - Default model: {self.default_model}") raise except Exception as e: logger.error(f"Failed to get models from Ollama: {str(e)}") diff --git a/servers/inference_bridge/pipelines/pipelines/providers/assets.py b/servers/inference_bridge/pipelines/pipelines/providers/assets.py index 4876f39c..748e1ebb 100644 --- a/servers/inference_bridge/pipelines/pipelines/providers/assets.py +++ b/servers/inference_bridge/pipelines/pipelines/providers/assets.py @@ -2,6 +2,7 @@ """Assets for loading provider configurations.""" import dagster as dg + from graphcap.providers.types import ProviderConfig from ..common.resources import ProviderConfigFile @@ -23,7 +24,6 @@ def provider_list( env_var="GOOGLE_API_KEY", base_url="https://generativelanguage.googleapis.com/v1beta", models=["gemini-2.0-flash-exp"], - default_model="gemini-2.0-flash-exp", fetch_models=False, ) diff --git a/servers/inference_bridge/pipelines/pipelines/providers/util.py b/servers/inference_bridge/pipelines/pipelines/providers/util.py index 2086cfd4..278d2534 100644 --- a/servers/inference_bridge/pipelines/pipelines/providers/util.py +++ b/servers/inference_bridge/pipelines/pipelines/providers/util.py @@ -1,5 +1,4 @@ from graphcap.providers.factory import create_provider_client -from ..perspectives.jobs.config import PerspectivePipelineConfig def get_provider(config_path: str, default_provider: str): @@ -18,9 +17,9 @@ def get_provider(config_path: str, default_provider: str): client = create_provider_client( name=default_provider, kind="gemini", - environment="cloud", + environment="cloud", base_url="https://generativelanguage.googleapis.com/v1beta", api_key="", # API key will be retrieved from environment variable - default_model="gemini-2.0-flash-exp", + models=["gemini-2.0-flash-exp"], # Specify models explicitly ) return client diff --git a/servers/inference_bridge/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py index ed0196c4..30a62a69 100644 --- a/servers/inference_bridge/server/server/features/perspectives/models.py +++ b/servers/inference_bridge/server/server/features/perspectives/models.py @@ -191,6 +191,7 @@ class CaptionPathRequest(BaseModel): image_path: str = Field(..., description="Path to the image file in the workspace") provider: str = Field(..., description="Name of the provider to use") provider_config: dict = Field(..., description="Provider configuration") + model: str = Field(..., description="Model name to use for processing") max_tokens: Optional[int] = Field(4096, description=DESC_MAX_TOKENS) temperature: Optional[float] = Field(0.8, description=DESC_TEMPERATURE) top_p: Optional[float] = Field(0.9, description=DESC_TOP_P) @@ -205,6 +206,7 @@ class Config: "perspective": "custom_caption", "image_path": "/workspace/datasets/example.jpg", "provider": "gemini", + "model": "gemini-pro-vision", "provider_config": { "name": "gemini", "kind": "gemini", diff --git a/servers/inference_bridge/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py index 053bd282..f177ab30 100644 --- a/servers/inference_bridge/server/server/features/perspectives/router.py +++ b/servers/inference_bridge/server/server/features/perspectives/router.py @@ -19,24 +19,17 @@ from pathlib import Path from typing import List, Optional -from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status +from fastapi import (APIRouter, BackgroundTasks, File, Form, HTTPException, + UploadFile, status) from loguru import logger -from ...utils.resizing import ResolutionPreset, log_resize_options, resize_image -from .models import ( - CaptionPathRequest, - CaptionResponse, - ModuleListResponse, - ModulePerspectivesResponse, - PerspectiveListResponse, -) -from .service import ( - generate_caption, - get_available_modules, - get_available_perspectives, - get_perspectives_by_module, - save_uploaded_file, -) +from ...utils.resizing import (ResolutionPreset, log_resize_options, + resize_image) +from .models import (CaptionPathRequest, CaptionResponse, ModuleListResponse, + ModulePerspectivesResponse, PerspectiveListResponse) +from .service import (generate_caption, get_available_modules, + get_available_perspectives, get_perspectives_by_module, + save_uploaded_file) router = APIRouter(prefix="/perspectives", tags=["perspectives"]) @@ -60,6 +53,7 @@ async def create_caption( perspective: str = Form(..., description="Name of the perspective to use"), provider: str = Form(..., description="Name of the provider to use"), provider_config: str = Form(..., description="Provider configuration as JSON string"), + model: str = Form(..., description="Model name to use for processing"), max_tokens: Optional[int] = Form(4096, description="Maximum number of tokens"), temperature: Optional[float] = Form(0.8, description="Temperature for generation"), top_p: Optional[float] = Form(0.9, description="Top-p sampling parameter"), @@ -79,6 +73,7 @@ async def create_caption( perspective: Name of the perspective to use (required) provider: Name of the provider to use (required) provider_config: Provider configuration as JSON string (required) + model: Model name to use for processing (required) max_tokens: Maximum number of tokens (optional, default: 4096) temperature: Temperature for generation (optional, default: 0.8) top_p: Top-p sampling parameter (optional, default: 0.9) @@ -150,14 +145,23 @@ async def create_caption( if not parsed_provider_config: logger.error(f"No provider configuration provided for {provider}") raise HTTPException( - status_code=400, + status_code=400, detail=f"Provider configuration not provided for '{provider}'. Please include provider_config in the request." ) + # Validate model is provided + if not model: + logger.error(f"No model specified for {provider}") + raise HTTPException( + status_code=400, + detail=f"Model name not provided for '{provider}'. Please include model in the request." + ) + # Generate the caption caption_data = await generate_caption( perspective_name=perspective, image_path=image_path, + model=model, max_tokens=max_tokens, temperature=temperature, top_p=top_p, @@ -231,10 +235,19 @@ async def create_caption_from_path( detail=f"Provider configuration not provided for '{request.provider}'. Please include provider_config in the request." ) + # Validate that model is provided + if not hasattr(request, 'model') or not request.model: + logger.error(f"No model specified for {request.provider}") + raise HTTPException( + status_code=400, + detail=f"Model name not provided for '{request.provider}'. Please include model in the request." + ) + # Generate the caption caption_data = await generate_caption( perspective_name=request.perspective, image_path=image_path, + model=request.model, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, diff --git a/servers/inference_bridge/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py index 19a3155e..9af3cab9 100644 --- a/servers/inference_bridge/server/server/features/perspectives/service.py +++ b/servers/inference_bridge/server/server/features/perspectives/service.py @@ -7,7 +7,6 @@ import base64 import os -import socket import tempfile from collections import defaultdict from pathlib import Path @@ -15,15 +14,12 @@ import aiohttp from fastapi import HTTPException, UploadFile -from graphcap.perspectives import ( - get_perspective, - get_perspective_list, -) -from graphcap.providers.clients.base_client import BaseClient from loguru import logger -from ..providers.service import create_provider_client_from_config -from .models import ModuleInfo, PerspectiveInfo, PerspectiveSchema, SchemaField, TableColumn +from graphcap.perspectives import get_perspective, get_perspective_list + +from .models import (ModuleInfo, PerspectiveInfo, PerspectiveSchema, + SchemaField, TableColumn) async def download_image(url: str) -> Path: @@ -108,7 +104,7 @@ def load_perspective_schema(perspective_name: str) -> Optional[PerspectiveSchema try: # Import perspective function from graphcap.perspectives import get_perspective - + # Get the perspective processor perspective = get_perspective(perspective_name) if perspective and hasattr(perspective, 'config'): @@ -263,6 +259,7 @@ def get_perspectives_by_module(module_name: str) -> List[PerspectiveInfo]: async def generate_caption( perspective_name: str, image_path: Path, + model: str, max_tokens: Optional[int] = 4096, temperature: Optional[float] = 0.8, top_p: Optional[float] = 0.9, @@ -278,6 +275,7 @@ async def generate_caption( Args: perspective_name: Name of the perspective to use image_path: Path to the image file + model: Model name to use for processing max_tokens: Maximum number of tokens in the response temperature: Temperature for generation top_p: Top-p sampling parameter @@ -301,7 +299,7 @@ async def generate_caption( if provider_config: from ..providers.models import ProviderConfig from ..providers.service import create_provider_client_from_config - + # Convert dict to ProviderConfig config = ProviderConfig(**provider_config) provider = create_provider_client_from_config(config) @@ -332,6 +330,7 @@ async def generate_caption( caption_data_list = await perspective.process_batch( provider=provider, image_paths=[image_path], + model=model, output_dir=output_dir, max_tokens=max_tokens, temperature=temperature, @@ -353,6 +352,7 @@ async def generate_caption( caption_data = await perspective.process_single( provider=provider, image_path=image_path, + model=model, max_tokens=max_tokens, temperature=temperature, top_p=top_p, diff --git a/servers/inference_bridge/server/server/features/providers/error_handler.py b/servers/inference_bridge/server/server/features/providers/error_handler.py index caaf9265..ae943988 100644 --- a/servers/inference_bridge/server/server/features/providers/error_handler.py +++ b/servers/inference_bridge/server/server/features/providers/error_handler.py @@ -6,14 +6,12 @@ """ import datetime -import traceback -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, Set from fastapi.responses import JSONResponse from pydantic import ValidationError from .models import ProviderConfig -from ...utils.logger import logger def format_provider_validation_error(e: ValidationError, provider_name: str) -> JSONResponse: @@ -150,7 +148,6 @@ def format_provider_connection_error(e: Exception, provider_name: str, config: P "kind": config.kind, "environment": config.environment, "base_url": config.base_url, - "default_model": config.default_model, "models": config.models, "fetch_models": config.fetch_models, } @@ -187,4 +184,4 @@ def format_provider_connection_error(e: Exception, provider_name: str, config: P return JSONResponse( status_code=400, content=error_response - ) \ No newline at end of file + ) From 996bf8b52f7d20e6598ac695648ca2087bb163fc Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 06:58:57 -0500 Subject: [PATCH 27/69] Move GenOptions to root types Signed-off-by: jphillips --- .../components/fields/OptionField.tsx | 8 +++---- .../fields/ResizeResolutionField.tsx | 2 +- .../context/GenerationOptionsContext.tsx | 10 ++++---- .../inference/generation-options/index.ts | 2 +- .../persist-generation-options.ts | 5 +++- .../containers/ProviderFormContainer.tsx | 6 ++--- .../PerspectiveActions/PerspectivesFooter.tsx | 24 +++++++++---------- .../generation-option-types.ts} | 0 graphcap_studio/src/types/index.ts | 0 9 files changed, 29 insertions(+), 28 deletions(-) rename graphcap_studio/src/{features/inference/generation-options/schema.ts => types/generation-option-types.ts} (100%) create mode 100644 graphcap_studio/src/types/index.ts diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/OptionField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/OptionField.tsx index bf82efbf..7e3fb06e 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/OptionField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/OptionField.tsx @@ -7,9 +7,9 @@ import { Slider } from "@/components/ui/slider"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; +import { OPTION_CONFIGS } from "@/types/generation-option-types"; import { Box, HStack, Input } from "@chakra-ui/react"; -import { ChangeEvent } from "react"; -import { OPTION_CONFIGS } from "../../schema"; +import type { ChangeEvent } from "react"; export type OptionFieldKey = keyof typeof OPTION_CONFIGS; @@ -46,8 +46,8 @@ export function OptionField({ // Handle direct input changes const handleInputChange = (e: ChangeEvent) => { - const valueAsNumber = parseFloat(e.target.value); - if (isNaN(valueAsNumber)) return; + const valueAsNumber = Number.parseFloat(e.target.value); + if (Number.isNaN(valueAsNumber)) return; // Ensure value is within bounds const boundedValue = Math.max( diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx index f70ce65a..4a1a7d7e 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx @@ -6,9 +6,9 @@ */ import { useColorModeValue } from "@/components/ui/theme/color-mode"; +import { RESOLUTION_PRESETS } from "@/types/generation-option-types"; import { Box, HStack } from "@chakra-ui/react"; import { useGenerationOptions } from "../../context"; -import { RESOLUTION_PRESETS } from "../../schema"; /** * Field component for adjusting image resize resolution diff --git a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx index d0745c16..30d3400c 100644 --- a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx +++ b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx @@ -5,6 +5,11 @@ * This module provides a context for managing generation options state. */ +import { + DEFAULT_OPTIONS, + type GenerationOptions, + GenerationOptionsSchema, +} from "@/types/generation-option-types"; import type React from "react"; import { createContext, @@ -15,11 +20,6 @@ import { useState, } from "react"; import { usePersistGenerationOptions } from "../persist-generation-options"; -import { - DEFAULT_OPTIONS, - type GenerationOptions, - GenerationOptionsSchema, -} from "../schema"; // Define the context interface interface GenerationOptionsContextValue { diff --git a/graphcap_studio/src/features/inference/generation-options/index.ts b/graphcap_studio/src/features/inference/generation-options/index.ts index e6b81267..aa028147 100644 --- a/graphcap_studio/src/features/inference/generation-options/index.ts +++ b/graphcap_studio/src/features/inference/generation-options/index.ts @@ -7,5 +7,5 @@ export * from "./components"; export * from "./context"; -export * from "./schema"; export * from "./persist-generation-options"; + diff --git a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts index 10954d49..e1e66895 100644 --- a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts +++ b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts @@ -5,7 +5,10 @@ * This module provides utilities for persisting generation options to localStorage. */ -import { GenerationOptions, GenerationOptionsSchema } from "./schema"; +import { + type GenerationOptions, + GenerationOptionsSchema, +} from "@/types/generation-option-types"; /** * Storage key for saving generation options in localStorage diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index fd8339d1..004a9d30 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -import { useCallback, useState } from "react"; import type { ReactNode } from "react"; +import { useCallback, useState } from "react"; import { useForm } from "react-hook-form"; import { useCreateProvider, useProviders, useTestProviderConnection, useUpdateProvider } from "../../../services/providers"; import { useInferenceProviderContext } from "../../context/InferenceProviderContext"; @@ -25,7 +25,6 @@ interface ProviderFormContainerProps { export function ProviderFormContainer({ children, initialData, - onSubmit: onSubmitProp, }: ProviderFormContainerProps) { // Get model selection and provider state from the InferenceProviderContext const { @@ -140,8 +139,7 @@ export function ProviderFormContainer({ // Create new provider await createProvider.mutateAsync(data as ProviderCreate); } else { - // This is the custom submit handler from parent (if any) - await onSubmitProp(data); + } setSaveSuccess(true); diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx index b13c976f..a3b3bafc 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx @@ -7,23 +7,23 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { - GenerationOptionsButton, - GenerationOptionsProvider, - ProviderSelector, + GenerationOptionsButton, + GenerationOptionsProvider, + ProviderSelector, } from "@/features/inference/generation-options"; -import { DEFAULT_OPTIONS } from "@/features/inference/generation-options/schema"; import { - usePerspectiveUI, - usePerspectivesData, + usePerspectiveUI, + usePerspectivesData, } from "@/features/perspectives/context"; import type { CaptionOptions } from "@/features/perspectives/types"; +import { DEFAULT_OPTIONS } from "@/types/generation-option-types"; import { - Box, - Button, - Flex, - HStack, - Icon, - useBreakpointValue, + Box, + Button, + Flex, + HStack, + Icon, + useBreakpointValue, } from "@chakra-ui/react"; import { useCallback, useEffect } from "react"; import { LuRefreshCw, LuSettings } from "react-icons/lu"; diff --git a/graphcap_studio/src/features/inference/generation-options/schema.ts b/graphcap_studio/src/types/generation-option-types.ts similarity index 100% rename from graphcap_studio/src/features/inference/generation-options/schema.ts rename to graphcap_studio/src/types/generation-option-types.ts diff --git a/graphcap_studio/src/types/index.ts b/graphcap_studio/src/types/index.ts new file mode 100644 index 00000000..e69de29b From e77c5559d97f409bafaeb7632ec7f2e2c0d73e8c Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 07:04:50 -0500 Subject: [PATCH 28/69] Move Server Connection and Perspective types to type root Signed-off-by: jphillips --- graphcap_studio/src/context/ServerConnectionsContext.tsx | 4 ++-- .../perspectives/hooks/useGeneratePerspectiveCaption.ts | 2 +- .../src/features/perspectives/hooks/usePerspectives.ts | 7 +++---- graphcap_studio/src/features/perspectives/types/index.ts | 5 +++-- .../components/ConnectionActionButton.tsx | 2 +- .../server-connections/components/ConnectionCard.tsx | 2 +- .../components/ConnectionStatusIndicator.tsx | 2 +- .../server-connections/components/ConnectionUrlInput.tsx | 4 ++-- .../components/ServerConnectionsPanel.tsx | 2 +- graphcap_studio/src/features/server-connections/index.ts | 3 ++- .../server-connections/services/dataServiceClient.ts | 2 +- .../server-connections/services/inferenceBridgeClient.ts | 2 +- .../server-connections/services/serverConnections.ts | 2 +- .../features/server-connections/useServerConnections.ts | 2 +- .../perspectives => }/types/perspectiveModuleTypes.ts | 4 ++-- .../{features/perspectives => }/types/perspectivesTypes.ts | 0 .../types.ts => types/server-connection-types.ts} | 0 17 files changed, 23 insertions(+), 22 deletions(-) rename graphcap_studio/src/{features/perspectives => }/types/perspectiveModuleTypes.ts (92%) rename graphcap_studio/src/{features/perspectives => }/types/perspectivesTypes.ts (100%) rename graphcap_studio/src/{features/server-connections/types.ts => types/server-connection-types.ts} (100%) diff --git a/graphcap_studio/src/context/ServerConnectionsContext.tsx b/graphcap_studio/src/context/ServerConnectionsContext.tsx index 047bbe75..02c0a1b9 100644 --- a/graphcap_studio/src/context/ServerConnectionsContext.tsx +++ b/graphcap_studio/src/context/ServerConnectionsContext.tsx @@ -1,7 +1,7 @@ import { useServerConnections } from "@/features/server-connections"; -import { ServerConnection } from "@/features/server-connections/types"; +import type { ServerConnection } from "@/types/server-connection-types"; // SPDX-License-Identifier: Apache-2.0 -import { ReactNode, createContext, useContext } from "react"; +import { type ReactNode, createContext, useContext } from "react"; /** * Interface for the ServerConnectionsContext value diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 86c10c62..75d51732 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -12,7 +12,7 @@ import { } from "@/features/inference/providers/types"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; -import type { ServerConnection } from "@/features/server-connections/types"; +import type { ServerConnection } from "@/types/server-connection-types"; import { toast } from "@/utils/toast"; import { useMutation, useQueryClient } from "@tanstack/react-query"; import { perspectivesQueryKeys } from "../services/constants"; diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts index 492749f9..f9b9a69e 100644 --- a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts +++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts @@ -9,15 +9,14 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; -import type { ServerConnection } from "@/features/server-connections/types"; +import type { ServerConnection } from "@/types/server-connection-types"; import { useQuery } from "@tanstack/react-query"; import { useEffect } from "react"; import { - API_ENDPOINTS, CACHE_TIMES, - perspectivesQueryKeys, + perspectivesQueryKeys } from "../services/constants"; -import { getGraphCapServerUrl, handleApiError } from "../services/utils"; +import { handleApiError } from "../services/utils"; import type { Perspective, PerspectiveListResponse } from "../types"; /** diff --git a/graphcap_studio/src/features/perspectives/types/index.ts b/graphcap_studio/src/features/perspectives/types/index.ts index 6b6024f2..ab5f5687 100644 --- a/graphcap_studio/src/features/perspectives/types/index.ts +++ b/graphcap_studio/src/features/perspectives/types/index.ts @@ -6,5 +6,6 @@ * Type definitions are consolidated in their respective files. */ -export * from "./perspectivesTypes"; -export * from "./perspectiveModuleTypes"; +export * from "@/types/perspectiveModuleTypes"; +export * from "@/types/perspectivesTypes"; + diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx index 6e55fe11..48f908ad 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx @@ -1,5 +1,5 @@ import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import type { ConnectionActionButtonProps } from "@/features/server-connections/types"; +import type { ConnectionActionButtonProps } from "@/types/server-connection-types"; import { Button } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx index 2603dd55..1e46f435 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx @@ -1,5 +1,5 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { ConnectionCardProps } from "@/features/server-connections/types"; +import type { ConnectionCardProps } from "@/types/server-connection-types"; import { Box, Flex, Heading, Stack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx index a91b5680..df872334 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx @@ -1,6 +1,6 @@ import { Status } from "@/components/ui/status"; import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import type { ConnectionStatusIndicatorProps } from "@/features/server-connections/types"; +import type { ConnectionStatusIndicatorProps } from "@/types/server-connection-types"; // SPDX-License-Identifier: Apache-2.0 import { memo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx index 855c773d..996ba3ad 100644 --- a/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx +++ b/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx @@ -1,8 +1,8 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { ConnectionUrlInputProps } from "@/features/server-connections/types"; +import type { ConnectionUrlInputProps } from "@/types/server-connection-types"; import { Input } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 -import { ChangeEvent, memo } from "react"; +import { type ChangeEvent, memo } from "react"; /** * ConnectionUrlInput component diff --git a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx index 9dd57e36..7c118b94 100644 --- a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx +++ b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx @@ -1,7 +1,7 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import { CONNECTION_STATUS } from "@/features/server-connections/constants"; -import type { ServerConnectionsPanelProps } from "@/features/server-connections/types"; +import type { ServerConnectionsPanelProps } from "@/types/server-connection-types"; import { Box, Button, Flex, Heading, Spinner, Stack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { memo, useMemo } from "react"; diff --git a/graphcap_studio/src/features/server-connections/index.ts b/graphcap_studio/src/features/server-connections/index.ts index 51a3eb3a..afe983f5 100644 --- a/graphcap_studio/src/features/server-connections/index.ts +++ b/graphcap_studio/src/features/server-connections/index.ts @@ -1,4 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -export * from "./types"; +export * from "@/types/server-connection-types"; export * from "./constants"; export * from "./useServerConnections"; + diff --git a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts index f201a5b3..99927776 100644 --- a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts +++ b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts @@ -5,9 +5,9 @@ * This module provides client functions for interacting with the Data Service API. */ +import type { ServerConnection } from "@/types/server-connection-types"; import { hc } from "hono/client"; import { DEFAULT_URLS, SERVER_IDS } from "../constants"; -import type { ServerConnection } from "../types"; /** * Interface for the Data Service client diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts index 6433fc5b..f7f114dc 100644 --- a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts +++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts @@ -5,9 +5,9 @@ * This module provides client functions for interacting with the Inference Bridge API. */ +import type { ServerConnection } from "@/types/server-connection-types"; import { hc } from "hono/client"; import { DEFAULT_URLS, SERVER_IDS } from "../constants"; -import type { ServerConnection } from "../types"; /** * Interface for the Inference Bridge Provider operations diff --git a/graphcap_studio/src/features/server-connections/services/serverConnections.ts b/graphcap_studio/src/features/server-connections/services/serverConnections.ts index cbc51a61..d1b8503b 100644 --- a/graphcap_studio/src/features/server-connections/services/serverConnections.ts +++ b/graphcap_studio/src/features/server-connections/services/serverConnections.ts @@ -6,8 +6,8 @@ * such as the Media Server and Inference Bridge. */ +import type { ServerConnection } from "@/types/server-connection-types"; import { CONNECTION_STATUS, SERVER_IDS } from "../constants"; -import type { ServerConnection } from "../types"; import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; /** diff --git a/graphcap_studio/src/features/server-connections/useServerConnections.ts b/graphcap_studio/src/features/server-connections/useServerConnections.ts index 564a69f4..ecd98566 100644 --- a/graphcap_studio/src/features/server-connections/useServerConnections.ts +++ b/graphcap_studio/src/features/server-connections/useServerConnections.ts @@ -1,3 +1,4 @@ +import type { ServerConnection } from "@/types/server-connection-types"; // SPDX-License-Identifier: Apache-2.0 import { useCallback, useEffect, useRef, useState } from "react"; import { @@ -7,7 +8,6 @@ import { SERVER_NAMES, } from "./constants"; import { checkServerHealthById } from "./services/serverConnections"; -import type { ServerConnection } from "./types"; // Local storage keys const STORAGE_KEY = "inference-bridge-connections"; diff --git a/graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts b/graphcap_studio/src/types/perspectiveModuleTypes.ts similarity index 92% rename from graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts rename to graphcap_studio/src/types/perspectiveModuleTypes.ts index 3084b7ab..a4f186ca 100644 --- a/graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts +++ b/graphcap_studio/src/types/perspectiveModuleTypes.ts @@ -5,9 +5,9 @@ * This module defines types related to perspective modules and management. */ +import type { Perspective } from "@/types/perspectivesTypes"; +import { PerspectiveSchema } from "@/types/perspectivesTypes"; import { z } from "zod"; -import { PerspectiveSchema } from "./perspectivesTypes"; -import type { Perspective } from "./perspectivesTypes"; /** * Schema for module information diff --git a/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts b/graphcap_studio/src/types/perspectivesTypes.ts similarity index 100% rename from graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts rename to graphcap_studio/src/types/perspectivesTypes.ts diff --git a/graphcap_studio/src/features/server-connections/types.ts b/graphcap_studio/src/types/server-connection-types.ts similarity index 100% rename from graphcap_studio/src/features/server-connections/types.ts rename to graphcap_studio/src/types/server-connection-types.ts From 2a1564095d6b71606fe1ec7961074462b615a3fe Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 07:10:48 -0500 Subject: [PATCH 29/69] Move Provider config types to common types Signed-off-by: jphillips --- .../inference/hooks/useModelSelection.ts | 2 +- .../hooks/useProviderModelSelection.ts | 2 +- .../ProviderConnection/component.tsx | 4 --- .../ProviderConnectionErrorDialog.tsx | 2 +- .../ProviderConnectionSuccessDialog.tsx | 2 +- .../components/actions/ProviderSaveDialog.tsx | 2 +- .../containers/ProviderFormContainer.tsx | 4 +-- .../context/useProviderForm.ts | 4 +-- .../hooks/useProviderConnection.ts | 9 +------ .../inference/providers/ProvidersList.tsx | 2 +- .../context/InferenceProviderContext.tsx | 2 +- .../providers/context/ProviderFormContext.tsx | 2 +- .../src/features/inference/providers/index.ts | 5 ++-- .../features/inference/services/providers.ts | 26 +++---------------- .../hooks/useGeneratePerspectiveCaption.ts | 6 ++--- .../hooks/useImagePerspectives.ts | 2 +- .../server-connections/services/providers.ts | 4 +-- .../provider-config-types.ts} | 0 18 files changed, 25 insertions(+), 55 deletions(-) rename graphcap_studio/src/{features/inference/providers/types.ts => types/provider-config-types.ts} (100%) diff --git a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts index 2a96ff54..c71c86a7 100644 --- a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts @@ -1,7 +1,7 @@ import { useProviderModels } from "@/features/server-connections/services/providers"; +import type { Provider } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { useCallback, useEffect, useState } from "react"; -import type { Provider } from "../providers/types"; /** * Custom hook for managing model selection diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts index d305245b..89e5d931 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts @@ -1,7 +1,7 @@ import { useProviderModels, useProviders } from "@/features/server-connections/services/providers"; +import type { Provider } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { useMemo } from "react"; -import type { Provider } from "../providers/types"; /** * Custom hook to handle provider and model selection logic diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx index 3abe5707..8dcc4669 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/component.tsx @@ -1,8 +1,4 @@ // SPDX-License-Identifier: Apache-2.0 -import type { ReactNode } from "react"; -import { useState } from "react"; -import { useInferenceProviderContext } from "../context"; -import type { ProviderCreate, ProviderUpdate } from "../types"; import { ProviderFormView } from "./components/ProviderFormView"; import { ProviderFormContainer } from "./containers/ProviderFormContainer"; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx index 249957f9..486b770e 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionErrorDialog.tsx @@ -1,3 +1,4 @@ +import type { ErrorDetails as ContextErrorDetails } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { Box, @@ -12,7 +13,6 @@ import { } from "@chakra-ui/react"; import { useEffect, useRef } from "react"; import { LuTriangleAlert } from "react-icons/lu"; -import type { ErrorDetails as ContextErrorDetails } from "../../types"; type ErrorDetails = { message?: string; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx index f80d645b..4c8b1390 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx @@ -1,3 +1,4 @@ +import type { ConnectionDetails as ContextConnectionDetails } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { Button, @@ -10,7 +11,6 @@ import { } from "@chakra-ui/react"; import { useEffect, useRef } from "react"; import { LuCheck, LuCircleAlert } from "react-icons/lu"; -import type { ConnectionDetails as ContextConnectionDetails } from "../../types"; import { type ConnectionStep, ConnectionSteps } from "./ConnectionSteps"; /** diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx index b5f01f31..ee182871 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx @@ -1,3 +1,4 @@ +import type { Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; import { Box, Button, @@ -12,7 +13,6 @@ import { import { useState } from "react"; import { useCreateProvider, useUpdateProvider } from "../../../../services/providers"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; -import type { Provider, ProviderCreate, ProviderUpdate } from "../../../types"; // Define error type with message property interface ErrorWithMessage { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index 004a9d30..db00d006 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -1,3 +1,5 @@ +import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; +import { toServerConfig } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import type { ReactNode } from "react"; import { useCallback, useState } from "react"; @@ -5,8 +7,6 @@ import { useForm } from "react-hook-form"; import { useCreateProvider, useProviders, useTestProviderConnection, useUpdateProvider } from "../../../services/providers"; import { useInferenceProviderContext } from "../../context/InferenceProviderContext"; import { ProviderFormProvider } from "../../context/ProviderFormContext"; -import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "../../types"; -import { toServerConfig } from "../../types"; // Extended Error interface with cause property interface ErrorWithCause extends Error { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts index 127ba4f6..46e14394 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/context/useProviderForm.ts @@ -1,9 +1,9 @@ +import { type Provider, type ProviderCreate, type ProviderUpdate, toServerConfig } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; import { type Control, type FieldErrors, type UseFormHandleSubmit, type UseFormReset, type UseFormWatch, useForm } from "react-hook-form"; import { useTestProviderConnection } from "../../../services/providers"; import { useInferenceProviderContext } from "../../context/InferenceProviderContext"; -import { type Provider, type ProviderCreate, type ProviderUpdate, toServerConfig } from "../../types"; interface UseProviderFormResult { mode: 'view' | 'edit' | 'create'; @@ -185,7 +185,7 @@ export function useProviderForm(initialData: Partial Promise.resolve(), // Placeholder to maintain compatibility + onSubmit: () => Promise.resolve(), // Placeholder to maintain compatibility handleTestConnection, setMode, closeDialog, diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts index afe84a05..38290f91 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/hooks/useProviderConnection.ts @@ -1,8 +1,8 @@ +import { type Provider, toServerConfig } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; import { useTestProviderConnection } from "../../../services/providers"; import { useInferenceProviderContext } from "../../context"; -import { type Provider, toServerConfig } from "../../types"; interface UseProviderConnectionResult { isTestingConnection: boolean; @@ -75,13 +75,6 @@ export function useProviderConnection(selectedProvider: Provider | null): UsePro try { const config = toServerConfig(currentFormValues); - const requestDetails = { - provider: currentFormValues.name, - config: { - ...config, - api_key: config.api_key ? '[REDACTED]' : undefined - } - }; const result = await testConnection.mutateAsync({ providerName: currentFormValues.name, diff --git a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx index 6bf48ddd..e79434df 100644 --- a/graphcap_studio/src/features/inference/providers/ProvidersList.tsx +++ b/graphcap_studio/src/features/inference/providers/ProvidersList.tsx @@ -1,8 +1,8 @@ +import type { ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { useProviders } from "../services/providers"; import { ProviderFormSelect } from "./ProviderConnection/components/form/ProviderFormSelect"; import { ProviderFormContainer } from "./ProviderConnection/containers/ProviderFormContainer"; -import type { ProviderCreate, ProviderUpdate } from "./types"; /** * Component for displaying a list of providers as a dropdown diff --git a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx index f823ed8c..9a4132ed 100644 --- a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx @@ -1,3 +1,4 @@ +import type { Provider } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 /** * Inference Provider Context @@ -21,7 +22,6 @@ import { useState, } from "react"; import { useModelSelection } from "../../hooks"; -import type { Provider } from "../types"; // Local storage key for selected provider const SELECTED_PROVIDER_STORAGE_KEY = "graphcap-selected-provider"; diff --git a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx index 5e814dc5..5b4e3eaf 100644 --- a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx @@ -1,7 +1,7 @@ +import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import { type ReactNode, createContext, useContext } from "react"; import type { Control, FieldErrors, UseFormWatch } from "react-hook-form"; -import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "../types"; interface ProviderFormContextType { mode: "view" | "edit" | "create"; diff --git a/graphcap_studio/src/features/inference/providers/index.ts b/graphcap_studio/src/features/inference/providers/index.ts index d20e2a33..386cead9 100644 --- a/graphcap_studio/src/features/inference/providers/index.ts +++ b/graphcap_studio/src/features/inference/providers/index.ts @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 +export * from '@/types/provider-config-types'; +export * from './context'; export * from './ProviderConnection'; export * from './ProvidersPanel'; -export * from './context'; -export * from './types'; + diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 33db33e4..0b0bcc9f 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -9,7 +9,6 @@ import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createDataServiceClient, createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; -import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import type { Provider, ProviderCreate, @@ -17,8 +16,9 @@ import type { ProviderUpdate, ServerProviderConfig, SuccessResponse, -} from "../providers/types"; -import { toServerConfig } from "../providers/types"; +} from "@/types/provider-config-types"; +import { toServerConfig } from "@/types/provider-config-types"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; // Query keys for TanStack Query export const queryKeys = { @@ -27,27 +27,7 @@ export const queryKeys = { providerModels: (providerName: string) => ["providers", "models", providerName] as const, }; -interface ServerConnection { - id: string; - url: string; - status: string; -} -// Define a more specific type for the client -interface DataServiceClient { - providers: { - $get: () => Promise; - $post: (options: { json: ProviderCreate }) => Promise; - ":id": { - $get: (options: { param: { id: string } }) => Promise; - $put: (options: { - param: { id: string }; - json: ProviderUpdate; - }) => Promise; - $delete: (options: { param: { id: string } }) => Promise; - }; - }; -} /** * Extended Error interface with cause property diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 75d51732..2ce6670b 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -6,12 +6,12 @@ */ import { useServerConnectionsContext } from "@/context"; +import { SERVER_IDS } from "@/features/server-connections/constants"; +import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; import { type Provider, toServerConfig, -} from "@/features/inference/providers/types"; -import { SERVER_IDS } from "@/features/server-connections/constants"; -import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; +} from "@/types/provider-config-types"; import type { ServerConnection } from "@/types/server-connection-types"; import { toast } from "@/utils/toast"; import { useMutation, useQueryClient } from "@tanstack/react-query"; diff --git a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts index 7c51a921..e11b9c19 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts @@ -6,10 +6,10 @@ */ import { useServerConnectionsContext } from "@/context"; -import type { Provider } from "@/features/inference/providers/types"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { useProviders } from "@/features/server-connections/services/providers"; import type { Image } from "@/services/images"; +import type { Provider } from "@/types/provider-config-types"; import { useCallback, useEffect, useState } from "react"; import type { diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index d39c4f1b..e2530755 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -13,8 +13,8 @@ import type { ProviderModelsResponse, ProviderUpdate, SuccessResponse, -} from "@/features/inference/providers/types"; -import { toServerConfig } from "@/features/inference/providers/types"; +} from "@/types/provider-config-types"; +import { toServerConfig } from "@/types/provider-config-types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { SERVER_IDS } from "../constants"; import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; diff --git a/graphcap_studio/src/features/inference/providers/types.ts b/graphcap_studio/src/types/provider-config-types.ts similarity index 100% rename from graphcap_studio/src/features/inference/providers/types.ts rename to graphcap_studio/src/types/provider-config-types.ts From bc2479a6a68d07ab855ef6965582319bf7996db3 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 07:23:06 -0500 Subject: [PATCH 30/69] Update imports, Add explicit type imports Signed-off-by: jphillips --- .../hooks/useDatasetNavigation.ts | 2 +- .../components/dataset-tree/hooks/useTree.ts | 2 +- .../dataset-tree/hooks/useTreeActions.ts | 2 +- .../PerspectiveCard/PerspectiveCardTabbed.tsx | 2 +- .../PerspectiveCard/PerspectiveDebug.tsx | 2 +- .../PerspectiveCard/SchemaView.tsx | 7 +++---- .../context/PerspectiveUIContext.tsx | 11 +++++----- .../context/PerspectivesDataContext.tsx | 20 +++++++++---------- .../context/PerspectivesProvider.tsx | 6 +++--- .../hooks/useGeneratePerspectiveCaption.ts | 2 +- .../hooks/useImagePerspectives.ts | 2 +- .../hooks/usePerspectiveModules.ts | 9 ++------- .../perspectives/hooks/usePerspectives.ts | 2 +- .../src/features/perspectives/services/api.ts | 19 +++++++++--------- .../src/features/perspectives/types/index.ts | 4 ++-- .../persist-perspective-caption.test.ts | 2 +- .../utils/persist-perspective-caption.ts | 2 +- graphcap_studio/src/types/index.ts | 6 ++++++ ...leTypes.ts => perspective-module-types.ts} | 4 ++-- ...spectivesTypes.ts => perspective-types.ts} | 0 20 files changed, 53 insertions(+), 53 deletions(-) rename graphcap_studio/src/types/{perspectiveModuleTypes.ts => perspective-module-types.ts} (92%) rename graphcap_studio/src/types/{perspectivesTypes.ts => perspective-types.ts} (100%) diff --git a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useDatasetNavigation.ts b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useDatasetNavigation.ts index 548a83f9..809b40a0 100644 --- a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useDatasetNavigation.ts +++ b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useDatasetNavigation.ts @@ -3,7 +3,7 @@ import { useNavigate } from "@tanstack/react-router"; // SPDX-License-Identifier: Apache-2.0 import { useCallback } from "react"; -import { TreeItemData } from "../types"; +import type { TreeItemData } from "../types"; /** * Custom hook for dataset navigation. diff --git a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTree.ts b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTree.ts index abe47671..4ab5286a 100644 --- a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTree.ts +++ b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTree.ts @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import { useCallback } from "react"; import { useTreeContext } from "../TreeContext"; -import { TreeItemData } from "../types"; +import type { TreeItemData } from "../types"; /** * Custom hook for tree-specific logic. diff --git a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTreeActions.ts b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTreeActions.ts index 7c4ac604..129ea5e7 100644 --- a/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTreeActions.ts +++ b/graphcap_studio/src/features/datasets/components/dataset-tree/hooks/useTreeActions.ts @@ -3,7 +3,7 @@ import { useNavigate } from "@tanstack/react-router"; // SPDX-License-Identifier: Apache-2.0 import { useCallback } from "react"; -import { TreeContextMenuAction, TreeItemData } from "../types"; +import type { TreeContextMenuAction, TreeItemData } from "../types"; /** * Custom hook for common tree actions like deletion, navigation, etc. diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx index c6453bdc..e1e23185 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx @@ -1,5 +1,6 @@ import { ClipboardButton } from "@/components/ui/buttons"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; +import type { PerspectiveSchema } from "@/types/perspective-types"; // SPDX-License-Identifier: Apache-2.0 /** * PerspectiveCardTabbed Component @@ -8,7 +9,6 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; * This component uses Chakra UI tabs for the tabbed interface. */ import { Box, Card, Stack, Tabs, Text } from "@chakra-ui/react"; -import type { PerspectiveSchema } from "../../../types"; import { PerspectiveDebug } from "./PerspectiveDebug"; import { SchemaView } from "./SchemaView"; import { CaptionRenderer } from "./schema-fields"; diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx index 15f54073..bc772628 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx @@ -1,4 +1,5 @@ import { ClipboardButton } from "@/components/ui/buttons"; +import type { PerspectiveData, PerspectiveSchema } from "@/types"; import { Box, Stack } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 /** @@ -8,7 +9,6 @@ import { Box, Stack } from "@chakra-ui/react"; * including its data, options, and metadata. */ import { useEffect } from "react"; -import type { PerspectiveData, PerspectiveSchema } from "../../../types"; import { DataStatistics, MetadataSection, diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx index 34b72a5b..b0cbb982 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx @@ -5,13 +5,12 @@ * This component displays the schema information for a perspective. */ -import React from "react"; -import type { PerspectiveSchema } from "../../../types"; +import type { PerspectiveSchema } from "@/types"; import { SchemaFieldFactory } from "./schema-fields"; interface SchemaViewProps { - schema: PerspectiveSchema; - className?: string; + readonly schema: PerspectiveSchema; + readonly className?: string; } export function SchemaView({ schema, className = "" }: SchemaViewProps) { diff --git a/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx index 58f7e9f3..9d2d5e17 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx @@ -6,17 +6,18 @@ * It follows the Context API best practices and focuses exclusively on UI concerns. */ -import React, { +import type { PerspectiveSchema } from "@/types"; +import type React from "react"; +import { + type ReactNode, createContext, + useCallback, useContext, - ReactNode, useEffect, useMemo, - useState, useRef, - useCallback, + useState, } from "react"; -import { PerspectiveSchema } from "../types"; import { getSelectedPerspective, saveSelectedPerspective, diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index c1bebc91..2d79fbe7 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -13,24 +13,24 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { useProviders } from "@/features/server-connections/services/providers"; import type { Image } from "@/services/images"; +import type { + CaptionOptions, + Perspective, + PerspectiveData, + PerspectiveSchema, + Provider, +} from "@/types"; import React, { createContext, - useContext, - type ReactNode, - useState, useCallback, + useContext, useEffect, useMemo, + useState, + type ReactNode, } from "react"; import { useGeneratePerspectiveCaption } from "../hooks/useGeneratePerspectiveCaption"; import { usePerspectives } from "../hooks/usePerspectives"; -import type { - CaptionOptions, - Perspective, - PerspectiveData, - PerspectiveSchema, - Provider, -} from "../types"; import { getAllPerspectiveCaptions, loadHiddenPerspectives, diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx index a22e1d1e..d5909b1f 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx @@ -6,9 +6,9 @@ * to simplify usage in component trees. */ -import { Image } from "@/services/images"; -import { ReactNode } from "react"; -import { Provider } from "../types"; +import type { Image } from "@/services/images"; +import type { Provider } from "@/types/provider-config-types"; +import type { ReactNode } from "react"; import { PerspectiveUIProvider } from "./PerspectiveUIContext"; import { PerspectivesDataProvider } from "./PerspectivesDataContext"; diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 2ce6670b..7be27499 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -8,6 +8,7 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; +import type { CaptionOptions, CaptionResponse } from "@/types"; import { type Provider, toServerConfig, @@ -17,7 +18,6 @@ import { toast } from "@/utils/toast"; import { useMutation, useQueryClient } from "@tanstack/react-query"; import { perspectivesQueryKeys } from "../services/constants"; import { ensureWorkspacePath, handleApiError } from "../services/utils"; -import type { CaptionOptions, CaptionResponse } from "../types"; /** * Hook to generate a perspective caption for an image diff --git a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts index e11b9c19..c9c3dcbe 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts @@ -18,7 +18,7 @@ import type { ImagePerspectivesResult, PerspectiveData, PerspectiveType, -} from "../types"; +} from "@/types/perspective-types"; import { useGeneratePerspectiveCaption } from "./useGeneratePerspectiveCaption"; import { usePerspectives } from "./usePerspectives"; diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts index 99580924..3ea188b5 100644 --- a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts +++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts @@ -9,15 +9,10 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; +import type { ModuleInfo, ModuleListResponse, Perspective, PerspectiveModule } from "@/types"; import { useQuery } from "@tanstack/react-query"; import { useEffect, useMemo } from "react"; -import { - API_ENDPOINTS, - CACHE_TIMES, - perspectivesQueryKeys, -} from "../services/constants"; -import { getGraphCapServerUrl, handleApiError } from "../services/utils"; -import type { ModuleInfo, ModuleListResponse, Perspective, PerspectiveModule } from "../types"; +import { handleApiError } from "../services/utils"; import { PerspectiveError } from "./usePerspectives"; type ModuleQueryResult = { diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts index f9b9a69e..a149afcc 100644 --- a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts +++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts @@ -9,6 +9,7 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; +import type { Perspective, PerspectiveListResponse } from "@/types"; import type { ServerConnection } from "@/types/server-connection-types"; import { useQuery } from "@tanstack/react-query"; import { useEffect } from "react"; @@ -17,7 +18,6 @@ import { perspectivesQueryKeys } from "../services/constants"; import { handleApiError } from "../services/utils"; -import type { Perspective, PerspectiveListResponse } from "../types"; /** * Custom error class for perspective fetching errors diff --git a/graphcap_studio/src/features/perspectives/services/api.ts b/graphcap_studio/src/features/perspectives/services/api.ts index 64b0c13c..0877284d 100644 --- a/graphcap_studio/src/features/perspectives/services/api.ts +++ b/graphcap_studio/src/features/perspectives/services/api.ts @@ -6,22 +6,21 @@ */ import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; -import { API_ENDPOINTS } from "../constants/index"; -import { - CaptionRequestSchema, - CaptionResponseSchema, - ModuleListResponseSchema, - ModulePerspectivesResponseSchema, - PerspectiveListResponseSchema, -} from "../types"; import type { CaptionRequest, CaptionResponse, ModuleListResponse, ModulePerspectivesResponse, Perspective, -} from "../types"; -import { ensureWorkspacePath, getGraphCapServerUrl, getInferenceBridgeApiUrl, handleApiError } from "./utils"; +} from "@/types"; +import { + CaptionRequestSchema, + CaptionResponseSchema, + ModuleListResponseSchema, + ModulePerspectivesResponseSchema, + PerspectiveListResponseSchema, +} from "@/types"; +import { ensureWorkspacePath, handleApiError } from "./utils"; /** * Get server connections from local storage diff --git a/graphcap_studio/src/features/perspectives/types/index.ts b/graphcap_studio/src/features/perspectives/types/index.ts index ab5f5687..5a5d7e94 100644 --- a/graphcap_studio/src/features/perspectives/types/index.ts +++ b/graphcap_studio/src/features/perspectives/types/index.ts @@ -6,6 +6,6 @@ * Type definitions are consolidated in their respective files. */ -export * from "@/types/perspectiveModuleTypes"; -export * from "@/types/perspectivesTypes"; +export * from "@/types/perspective-module-types"; +export * from "@/types/perspective-types"; diff --git a/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts b/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts index 34bf9d88..638bf6f3 100644 --- a/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts +++ b/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts @@ -3,8 +3,8 @@ * Unit tests for perspective caption persistence utilities */ +import type { PerspectiveData } from "@/types/perspective-types"; import { afterAll, beforeEach, describe, expect, it } from "vitest"; -import { PerspectiveData } from "../../types"; import { clearAllPerspectiveCaptions, deletePerspectiveCaption, diff --git a/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts b/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts index 9adfe5b4..d0d47863 100644 --- a/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts +++ b/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts @@ -7,7 +7,7 @@ * and perspective name. */ -import { PerspectiveData } from "../types"; +import type { PerspectiveData } from "@/types/perspective-types"; /** * Storage key prefix for saving perspective captions in localStorage diff --git a/graphcap_studio/src/types/index.ts b/graphcap_studio/src/types/index.ts index e69de29b..653e5f72 100644 --- a/graphcap_studio/src/types/index.ts +++ b/graphcap_studio/src/types/index.ts @@ -0,0 +1,6 @@ +export * from "./generation-option-types"; +export * from "./perspective-module-types"; +export * from "./perspective-types"; +export * from "./provider-config-types"; +export * from "./server-connection-types"; + diff --git a/graphcap_studio/src/types/perspectiveModuleTypes.ts b/graphcap_studio/src/types/perspective-module-types.ts similarity index 92% rename from graphcap_studio/src/types/perspectiveModuleTypes.ts rename to graphcap_studio/src/types/perspective-module-types.ts index a4f186ca..0d265d1c 100644 --- a/graphcap_studio/src/types/perspectiveModuleTypes.ts +++ b/graphcap_studio/src/types/perspective-module-types.ts @@ -5,8 +5,8 @@ * This module defines types related to perspective modules and management. */ -import type { Perspective } from "@/types/perspectivesTypes"; -import { PerspectiveSchema } from "@/types/perspectivesTypes"; +import type { Perspective } from "@/types/perspective-types"; +import { PerspectiveSchema } from "@/types/perspective-types"; import { z } from "zod"; /** diff --git a/graphcap_studio/src/types/perspectivesTypes.ts b/graphcap_studio/src/types/perspective-types.ts similarity index 100% rename from graphcap_studio/src/types/perspectivesTypes.ts rename to graphcap_studio/src/types/perspective-types.ts From 672a19afd00434500ebcb01a83149b3dfd335146 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 08:44:42 -0500 Subject: [PATCH 31/69] Consolidate caption options to use central generation options Signed-off-by: jphillips --- .../components/fields/ModelSelectorField.tsx | 116 +++--- .../context/GenerationOptionsContext.tsx | 152 ++++++-- .../persist-generation-options.ts | 16 +- .../hooks/useProviderModelOptions.ts | 68 ++++ .../PerspectiveActions/PerspectivesFooter.tsx | 179 ++++------ .../context/PerspectivesDataContext.tsx | 256 ++++++------- .../hooks/useGeneratePerspectiveCaption.ts | 39 +- .../perspectives/utils/api-adapters.ts | 106 ++++++ .../services/inferenceBridgeClient.ts | 7 + .../services/providerAdapters.ts | 167 +++++++++ .../server-connections/services/providers.ts | 285 ++++++++++----- .../src/types/generation-option-types.ts | 48 ++- .../src/types/provider-config-types.ts | 335 +++++++++++------- 13 files changed, 1157 insertions(+), 617 deletions(-) create mode 100644 graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts create mode 100644 graphcap_studio/src/features/perspectives/utils/api-adapters.ts create mode 100644 graphcap_studio/src/features/server-connections/services/providerAdapters.ts diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx index ba5e7e08..a48f5a0e 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx @@ -7,97 +7,67 @@ import { Field } from "@/components/ui/field"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { useProviderModelSelection } from "@/features/inference/hooks"; -import { useInferenceProviderContext } from "@/features/inference/providers/context/InferenceProviderContext"; -import { Box } from "@chakra-ui/react"; -import { Portal, Select, createListCollection } from "@chakra-ui/react"; -import { useCallback, useEffect, useMemo } from "react"; +import { Box, Portal, Select, createListCollection } from "@chakra-ui/react"; import { useGenerationOptions } from "../../context"; /** * Field component for selecting model and provider */ export function ModelSelectorField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); - const { providers: contextProviders } = useInferenceProviderContext(); - - // Extract provider from context - const currentProvider = useMemo(() => { - if (!options.provider_id) return null; - const providerId = Number.parseInt(options.provider_id, 10); - return contextProviders.find(p => p.id === providerId) || null; - }, [contextProviders, options.provider_id]); - - // Use the hook to get providers and models - const { - providers, + const { + options, + providers, models, - isLoading, - } = useProviderModelSelection(currentProvider); + actions + } = useGenerationOptions(); - // Color values f or theming + // Color values for theming const labelColor = useColorModeValue("gray.700", "gray.300"); const helperTextColor = useColorModeValue("gray.500", "gray.400"); - // Initialize provider if needed - useEffect(() => { - if (providers.length > 0 && !options.provider_id) { - const provider = providers[0]; - updateOption("provider_id", provider.id.toString()); - } - }, [providers, options.provider_id, updateOption]); - - // Update model when provider changes or when models are loaded - useEffect(() => { - if (models.length > 0 && !options.model_id) { - updateOption("model_id", models[0]?.name || ""); - } - }, [models, options.model_id, updateOption]); - // Create collections for selects - always include at least one item - const providerCollection = useMemo(() => { - const items = providers.length > 0 - ? providers.map((provider) => ({ + const providerCollection = createListCollection({ + items: providers.items.length > 0 + ? providers.items.map((provider) => ({ label: provider.name, - value: provider.id.toString(), + value: provider.id, disabled: false, })) - : [{ label: "No providers available", value: "none", disabled: false }]; + : [{ label: "No providers available", value: "none", disabled: false }] + }); - return createListCollection({ items }); - }, [providers]); - - const modelCollection = useMemo(() => { - const items = models.length > 0 - ? models.map((model) => ({ + const modelCollection = createListCollection({ + items: models.items.length > 0 + ? models.items.map((model) => ({ label: model.name, value: model.id, disabled: false, })) - : [{ label: "No models available", value: "none", disabled: false }]; - - return createListCollection({ items }); - }, [models]); + : [{ label: "No models available", value: "none", disabled: false }] + }); // Handle provider change - const handleProviderChange = useCallback((newValue: string[]) => { - if (newValue.length > 0 && newValue[0] !== "none") { - const providerId = newValue[0]; - updateOption("provider_id", providerId); - updateOption("model_id", ""); + const handleProviderChange = (details: { value: string[] }) => { + if (details.value.length > 0 && details.value[0] !== "none") { + actions.selectProvider(details.value[0]); } - }, [updateOption]); + }; // Handle model change - const handleModelChange = useCallback((newValue: string[]) => { - if (newValue.length > 0 && newValue[0] !== "none") { - updateOption("model_id", newValue[0]); + const handleModelChange = (details: { value: string[] }) => { + if (details.value.length > 0 && details.value[0] !== "none") { + actions.selectModel(details.value[0]); } - }, [updateOption]); + }; - console.log('Providers:', providers); - console.log('IsLoading:', isLoading); - console.log('CurrentProvider:', currentProvider); + // Check if any providers are available + const hasProviders = providers.items.length > 0; + + // Check if any models are available for the selected provider + + // Loading state + const isProvidersLoading = providers.isLoading; + const isModelsLoading = models.isLoading; return ( @@ -105,19 +75,19 @@ export function ModelSelectorField() { Provider & Model - + handleProviderChange(e.value)} - disabled={false} // Never disable this + onValueChange={handleProviderChange} + disabled={isProvidersLoading} size="sm" > - + @@ -140,19 +110,19 @@ export function ModelSelectorField() { - + handleModelChange(e.value)} - disabled={false} // Never disable this + onValueChange={handleModelChange} + disabled={isModelsLoading || !hasProviders} size="sm" > - + @@ -176,7 +146,7 @@ export function ModelSelectorField() { - Click on the dropdowns to select provider and model + Select provider and model for generation ); diff --git a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx index 30d3400c..b85233a9 100644 --- a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx +++ b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx @@ -2,7 +2,8 @@ /** * Generation Options Context * - * This module provides a context for managing generation options state. + * This module provides a context for managing generation options state, + * including provider and model selection. */ import { @@ -10,6 +11,7 @@ import { type GenerationOptions, GenerationOptionsSchema, } from "@/types/generation-option-types"; +import type { Provider, ProviderModelInfo } from "@/types/provider-config-types"; import type React from "react"; import { createContext, @@ -19,29 +21,47 @@ import { useMemo, useState, } from "react"; +import { useProviderModelOptions } from "../../hooks/useProviderModelOptions"; import { usePersistGenerationOptions } from "../persist-generation-options"; // Define the context interface interface GenerationOptionsContextValue { - // State + // State groups options: GenerationOptions; - isDialogOpen: boolean; - isGenerating: boolean; - - // Actions - updateOption: ( - key: K, - value: GenerationOptions[K], - ) => void; - resetOptions: () => void; - setOptions: (options: Partial) => void; - openDialog: () => void; - closeDialog: () => void; - toggleDialog: () => void; - setIsGenerating: (isGenerating: boolean) => void; + providers: { + items: Provider[]; + selected: Provider | null; + isLoading: boolean; + error: unknown; + }; + models: { + items: ProviderModelInfo[]; + defaultModel: ProviderModelInfo | null; + isLoading: boolean; + error: unknown; + }; + uiState: { + isDialogOpen: boolean; + isGenerating: boolean; + }; + + // Action groups + actions: { + updateOption: (key: K, value: GenerationOptions[K]) => void; + resetOptions: () => void; + setOptions: (options: Partial) => void; + selectProvider: (providerId: string) => void; + selectModel: (modelId: string) => void; + }; + uiActions: { + openDialog: () => void; + closeDialog: () => void; + toggleDialog: () => void; + setIsGenerating: (isGenerating: boolean) => void; + }; } -// Create the context with a default value +// Create the context with undefined default const GenerationOptionsContext = createContext< GenerationOptionsContextValue | undefined >(undefined); @@ -82,6 +102,18 @@ export function GenerationOptionsProvider({ const [isDialogOpen, setIsDialogOpen] = useState(false); const [isGenerating, setIsGenerating] = useState(initialGenerating); + // Provider and model data + const { + providers, + selectedProvider, + isLoadingProviders, + providersError, + models, + defaultModel, + isLoadingModels, + modelsError + } = useProviderModelOptions(options.provider_id); + // Save options to localStorage when they change useEffect(() => { saveOptions(options); @@ -92,6 +124,24 @@ export function GenerationOptionsProvider({ setIsGenerating(initialGenerating); }, [initialGenerating]); + // Initialize provider if available and not already set + useEffect(() => { + if (providers.length > 0 && !options.provider_id) { + const firstProvider = providers[0]; + updateOption("provider_id", firstProvider.id); + } + }, [providers, options.provider_id]); + + // Initialize model if available and not already set + useEffect(() => { + // If we have a provider but no model, and models are available + if (options.provider_id && !options.model_id && models.length > 0) { + // Try to use default model first, otherwise use first available model + const modelToUse = defaultModel || models[0]; + updateOption("model_id", modelToUse.id); + } + }, [options.provider_id, options.model_id, models, defaultModel]); + // Update a single option const updateOption = useCallback( ( @@ -132,38 +182,80 @@ export function GenerationOptionsProvider({ [onOptionsChange], ); + // Provider selection + const selectProvider = useCallback((providerId: string) => { + updateOption("provider_id", providerId); + // Clear model when provider changes + updateOption("model_id", ""); + }, [updateOption]); + + // Model selection + const selectModel = useCallback((modelId: string) => { + updateOption("model_id", modelId); + }, [updateOption]); + // Dialog controls const openDialog = useCallback(() => setIsDialogOpen(true), []); const closeDialog = useCallback(() => setIsDialogOpen(false), []); const toggleDialog = useCallback(() => setIsDialogOpen((prev) => !prev), []); - // Context value + // Context value using grouped structure const value = useMemo( () => ({ - // State + // State groups options, - isDialogOpen, - isGenerating, + providers: { + items: providers, + selected: selectedProvider, + isLoading: isLoadingProviders, + error: providersError + }, + models: { + items: models, + defaultModel, + isLoading: isLoadingModels, + error: modelsError + }, + uiState: { + isDialogOpen, + isGenerating + }, - // Actions - updateOption, - resetOptions, - setOptions: mergeOptions, - openDialog, - closeDialog, - toggleDialog, - setIsGenerating, + // Action groups + actions: { + updateOption, + resetOptions, + setOptions: mergeOptions, + selectProvider, + selectModel + }, + uiActions: { + openDialog, + closeDialog, + toggleDialog, + setIsGenerating + } }), [ options, + providers, + selectedProvider, + isLoadingProviders, + providersError, + models, + defaultModel, + isLoadingModels, + modelsError, isDialogOpen, isGenerating, updateOption, resetOptions, mergeOptions, + selectProvider, + selectModel, openDialog, closeDialog, - toggleDialog, + toggleDialog ], ); diff --git a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts index e1e66895..7421fa62 100644 --- a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts +++ b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts @@ -20,7 +20,15 @@ const STORAGE_KEY = "graphcap:generation-options"; */ export function saveGenerationOptions(options: GenerationOptions): void { try { - const serialized = JSON.stringify(options); + // Create a copy to ensure we don't modify the original + const optionsToSave = { ...options }; + + // Ensure provider_id is stored as a string + if (optionsToSave.provider_id !== undefined) { + optionsToSave.provider_id = String(optionsToSave.provider_id); + } + + const serialized = JSON.stringify(optionsToSave); localStorage.setItem(STORAGE_KEY, serialized); } catch (error) { console.error("Failed to save generation options to localStorage:", error); @@ -38,6 +46,12 @@ export function loadGenerationOptions(): GenerationOptions | null { if (!serialized) return null; const parsed = JSON.parse(serialized); + + // If provider_id exists and is a number, convert it to string + if (parsed.provider_id !== undefined && typeof parsed.provider_id === 'number') { + parsed.provider_id = parsed.provider_id.toString(); + } + // Validate the loaded data against the schema return GenerationOptionsSchema.parse(parsed); } catch (error) { diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts new file mode 100644 index 00000000..3f16b3db --- /dev/null +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Provider Model Options Hook + * + * This hook provides access to providers and models data with support for selection. + * It consolidates provider and model data loading in a single hook. + */ + +import { useProviderModels, useProviders } from "@/features/server-connections/services/providers"; +import type { Provider, ProviderModelInfo } from "@/types/provider-config-types"; +import { useMemo } from "react"; + +/** + * Hook for accessing provider and model selection options + * + * @param providerId - The selected provider ID + * @returns Provider and model data with loading states + */ +export function useProviderModelOptions(providerId?: string) { + // Fetch all providers + const { + data: providers = [], + isLoading: isLoadingProviders, + error: providersError + } = useProviders(); + + // Find the selected provider object + const selectedProvider = useMemo(() => { + if (!providerId) return null; + return providers.find((p: Provider) => p.id === providerId) || null; + }, [providers, providerId]); + + // Fetch models for the selected provider + const { + data: modelData, + isLoading: isLoadingModels, + error: modelsError + } = useProviderModels(selectedProvider?.name || ""); + + // Process models data + const models = useMemo(() => { + if (!modelData?.models) return []; + return modelData.models; + }, [modelData]); + + // Check for default model + const defaultModel = useMemo(() => { + return models.find(model => model.is_default === true) || null; + }, [models]); + + return { + // Providers data + providers, + selectedProvider, + isLoadingProviders, + providersError, + + // Models data + models, + defaultModel, + isLoadingModels, + modelsError, + + // Helper for status checking + isLoading: isLoadingProviders || isLoadingModels, + hasError: !!providersError || !!modelsError + }; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx index a3b3bafc..14ce9e7c 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx @@ -7,26 +7,19 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode"; import { - GenerationOptionsButton, - GenerationOptionsProvider, - ProviderSelector, -} from "@/features/inference/generation-options"; -import { - usePerspectiveUI, - usePerspectivesData, + usePerspectiveUI, + usePerspectivesData, } from "@/features/perspectives/context"; -import type { CaptionOptions } from "@/features/perspectives/types"; -import { DEFAULT_OPTIONS } from "@/types/generation-option-types"; import { - Box, - Button, - Flex, - HStack, - Icon, - useBreakpointValue, + Box, + Button, + Flex, + Icon, + Text, + chakra, } from "@chakra-ui/react"; import { useCallback, useEffect } from "react"; -import { LuRefreshCw, LuSettings } from "react-icons/lu"; +import { LuRefreshCw } from "react-icons/lu"; /** * Helper function to determine button title text @@ -64,10 +57,8 @@ export function PerspectivesFooter() { generatePerspective, isGenerating, currentImage, - captionOptions, - setCaptionOptions, + generationOptions, selectedProvider, - handleProviderChange, } = usePerspectivesData(); // Use UI context @@ -83,13 +74,26 @@ export function PerspectivesFooter() { const bgColor = useColorModeValue("white", "gray.800"); const borderColor = useColorModeValue("gray.200", "gray.700"); - - // Use responsive selector width based on screen size - const selectorWidth = useBreakpointValue({ - base: "100%", - sm: "12rem", - md: "16rem", - }); + const infoTextColor = useColorModeValue("gray.600", "gray.400"); + + // Get provider and model names + // Log information for debugging + console.log("GenerationOptions:", generationOptions); + console.log("Available providers:", availableProviders); + + // Try to find provider by id first, if that fails look for a name match + const providerObj = availableProviders.find(p => + // Try matching by ID + p.id.toString() === generationOptions.provider_id || + // Or by name + (selectedProvider && p.name === selectedProvider) + ); + + const providerName = providerObj?.name || + // If we have a provider_id but couldn't find a match, show that + (generationOptions.provider_id ? `ID: ${generationOptions.provider_id}` : "None"); + + const modelId = generationOptions.model_id || "None"; // Fetch providers on component mount useEffect(() => { @@ -135,14 +139,7 @@ export function PerspectivesFooter() { console.log("Generate button clicked"); console.log("Active schema:", activeSchemaName); console.log("Selected provider:", selectedProvider); - - // Ensure we have valid options by applying defaults if needed - const effectiveOptions = - Object.keys(captionOptions).length === 0 - ? DEFAULT_OPTIONS - : captionOptions; - - console.log("Using caption options:", effectiveOptions); + console.log("Using generation options:", generationOptions); if (!validateGeneration()) { return; @@ -158,11 +155,12 @@ export function PerspectivesFooter() { } await generatePerspective( - activeSchemaName!, - currentImage!.path, - providerObject, // Pass the full provider object - effectiveOptions, + activeSchemaName as string, + currentImage?.path as string, + providerObject, + generationOptions ); + showMessage( "Generation started", `Generating ${activeSchemaName} perspective`, @@ -179,9 +177,9 @@ export function PerspectivesFooter() { }, [ activeSchemaName, selectedProvider, - availableProviders, // Add availableProviders to the dependencies + availableProviders, generatePerspective, - captionOptions, + generationOptions, showMessage, currentImage, validateGeneration, @@ -202,27 +200,6 @@ export function PerspectivesFooter() { isGenerated, ); - // Handle options change - const handleOptionsChange = useCallback( - (newOptions: CaptionOptions) => { - setCaptionOptions(newOptions); - }, - [setCaptionOptions], - ); - - // Create a handler for the new ProviderSelector component - const handleProviderSelection = useCallback( - (provider: string) => { - // Create a synthetic event to pass to the original handler - const syntheticEvent = { - target: { value: provider }, - } as React.ChangeEvent; - - handleProviderChange(syntheticEvent); - }, - [handleProviderChange], - ); - return ( - {/* Provider Selection */} - {availableProviders.length > 0 ? ( - - ) : ( - - )} - - - {/* Options Button with Popover */} - - - - Options - - } - size="sm" - variant="ghost" + {/* Provider and Model Info */} + + Using: {providerName} / {modelId} + + + {/* Generate/Regenerate Button */} + - + )} + {isGenerated && !isGenerating && } + {isGenerated ? "Regenerate" : "Generate"} + ); diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index 2d79fbe7..ac279287 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -10,22 +10,22 @@ */ import { useServerConnectionsContext } from "@/context"; +import { useGenerationOptions } from "@/features/inference/generation-options/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { useProviders } from "@/features/server-connections/services/providers"; import type { Image } from "@/services/images"; import type { - CaptionOptions, Perspective, PerspectiveData, - PerspectiveSchema, - Provider, + PerspectiveSchema } from "@/types"; +import type { GenerationOptions } from "@/types/generation-option-types"; +import type { Provider } from "@/types/provider-config-types"; import React, { createContext, useCallback, useContext, useEffect, - useMemo, useState, type ReactNode, } from "react"; @@ -106,14 +106,13 @@ interface PerspectivesDataContextType { refetchPerspectives: () => Promise; // Captions data - captions: Record; + captions: Record; generatedPerspectives: string[]; isGenerating: boolean; isServerConnected: boolean; - // Caption options - captionOptions: CaptionOptions; - setCaptionOptions: (options: CaptionOptions) => void; + // Generation options from the GenerationOptions context + generationOptions: GenerationOptions; // Current image currentImage: Image | null; @@ -124,8 +123,8 @@ interface PerspectivesDataContextType { schemaName: string, imagePath: string, provider?: Provider, - options?: CaptionOptions, - ) => Promise; + options?: GenerationOptions, + ) => Promise; // Status helpers isPerspectiveGenerated: (schemaName: string) => boolean; @@ -162,7 +161,6 @@ interface PerspectivesDataProviderProps { readonly image: Image | null; readonly initialProvider?: string; readonly initialProviders?: Provider[]; - readonly initialCaptionOptions?: CaptionOptions; } /** @@ -174,7 +172,6 @@ export function PerspectivesDataProvider({ image: initialImage, initialProvider, initialProviders = [], - initialCaptionOptions = {}, }: PerspectivesDataProviderProps) { // Server connection state const { connections } = useServerConnectionsContext(); @@ -183,6 +180,9 @@ export function PerspectivesDataProvider({ ); const isServerConnected = graphcapServerConnection?.status === "connected"; + // Get generation options from context + const generationOptions = useGenerationOptions(); + // Current image state const [currentImage, setCurrentImage] = useState(initialImage); @@ -194,13 +194,8 @@ export function PerspectivesDataProvider({ useState(initialProviders); const [isGeneratingAll, setIsGeneratingAll] = useState(false); - // Caption options state - const [captionOptions, setCaptionOptions] = useState( - initialCaptionOptions, - ); - // Captions state - const [captions, setCaptions] = useState({}); + const [captions, setCaptions] = useState>({}); // Generation state const [generatingPerspectives, setGeneratingPerspectives] = useState< @@ -289,10 +284,9 @@ export function PerspectivesDataProvider({ if (prev.includes(perspectiveName)) { // If already hidden, make it visible (remove from hidden list) return prev.filter((name) => name !== perspectiveName); - } else { - // If visible, hide it (add to hidden list) - return [...prev, perspectiveName]; } + // If visible, hide it (add to hidden list) + return [...prev, perspectiveName]; }); }, []); @@ -360,8 +354,9 @@ export function PerspectivesDataProvider({ // Get generated perspectives based on captions const generatedPerspectives = React.useMemo(() => { - if (!captions.perspectives) return []; - return Object.keys(captions.perspectives); + const perspectives = captions.perspectives as Record | undefined; + if (!perspectives) return []; + return Object.keys(perspectives); }, [captions]); // Generate a perspective caption and save to localStorage @@ -370,7 +365,7 @@ export function PerspectivesDataProvider({ schemaName: string, imagePath: string, provider?: Provider, - options?: CaptionOptions, + options?: GenerationOptions, ) => { if (!isServerConnected) { throw new Error( @@ -395,11 +390,14 @@ export function PerspectivesDataProvider({ throw new Error("No provider selected for caption generation"); } + // Get current options from GenerationOptions context if not provided + const effectiveOptions = options || generationOptions.options; + // Log the options to ensure they're being passed correctly console.debug(`Generating perspective "${schemaName}" with options:`, { providedOptions: options, - contextOptions: captionOptions, - finalOptions: options ?? captionOptions ?? {}, + contextOptions: generationOptions.options, + finalOptions: effectiveOptions, provider: effectiveProvider, }); @@ -408,7 +406,7 @@ export function PerspectivesDataProvider({ perspective: schemaName, imagePath, provider: effectiveProvider, - options: options ?? captionOptions, + options: effectiveOptions, }); // Validate required data @@ -418,49 +416,40 @@ export function PerspectivesDataProvider({ ); } - if (!effectiveProvider) { - console.error( - `ERROR: Missing provider information for perspective ${schemaName}`, - ); - } - - if (!options && !captionOptions) { - console.error( - `ERROR: Missing generation options for perspective ${schemaName}`, - ); - } - - // Format the data as PerspectiveData object - no defaults! - const perspectiveData: PerspectiveData = { + // Format the data as PerspectiveData object + const perspectiveData = { config_name: schemaName, version: "1.0", - model: - result.metadata?.model ?? - (() => { - console.error( - `CRITICAL ERROR: Missing model information in API response for perspective ${schemaName}`, - ); - return "MISSING_MODEL"; - })(), + model: result.metadata?.model ?? "MISSING_MODEL", provider: effectiveProvider.name, content: result.result || {}, - options: options || captionOptions, + options: { + model: effectiveOptions.model_id, // Map to expected model property + max_tokens: effectiveOptions.max_tokens, + temperature: effectiveOptions.temperature, + top_p: effectiveOptions.top_p, + repetition_penalty: effectiveOptions.repetition_penalty, + global_context: effectiveOptions.global_context, + context: effectiveOptions.context, + resize_resolution: effectiveOptions.resize_resolution + } }; // Save the perspective directly to localStorage savePerspectiveCaption(imagePath, schemaName, perspectiveData); // Update captions state with this new perspective data - setCaptions((prev: Record) => { + setCaptions((prev) => { + const prevPerspectives = (prev.perspectives || {}) as Record; const newCaptions = { ...prev, perspectives: { - ...prev.perspectives, + ...prevPerspectives, [schemaName]: perspectiveData, }, metadata: { captioned_at: new Date().toISOString(), - provider: effectiveProvider.name, + provider: effectiveProvider?.name || "", model: result.metadata?.model ?? "unknown", }, }; @@ -485,7 +474,7 @@ export function PerspectivesDataProvider({ currentImage, selectedProvider, availableProviders, - captionOptions, + generationOptions.options, generateCaptionMutation, ], ); @@ -493,7 +482,8 @@ export function PerspectivesDataProvider({ // Helper to check if a perspective is generated const isPerspectiveGenerated = useCallback( (schemaName: string) => { - return !!captions.perspectives?.[schemaName]; + const perspectives = captions.perspectives as Record | undefined; + return !!perspectives?.[schemaName]; }, [captions], ); @@ -510,116 +500,72 @@ export function PerspectivesDataProvider({ const getPerspectiveData = useCallback( (schemaName: string) => { // Try to get data from our in-memory state - const perspectiveData = captions.perspectives?.[schemaName]; + const perspectives = captions.perspectives as Record | undefined; + const perspectiveData = perspectives?.[schemaName]; console.debug("getPerspectiveData for", schemaName, perspectiveData); // Always return the complete perspective data object // to preserve options and metadata - return perspectiveData; + return perspectiveData as Record | null; }, [captions], ); - // Get the server URL from connections - const graphcapServerUrl = useMemo(() => { - const serverConn = connections.find( - (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, - ); - return serverConn?.url || ""; - }, [connections]); - // Create consolidated context value - const value: PerspectivesDataContextType = useMemo( - () => ({ - // Provider state - selectedProvider, - availableProviders, - isGeneratingAll, - - // Provider actions - setSelectedProvider, - setAvailableProviders, - setIsGeneratingAll, - handleProviderChange, - - // Data fetching - providers - fetchProviders, - isLoadingProviders, - providerError, - - // Perspectives data - perspectives: perspectivesData || [], - schemas, - isLoadingPerspectives, - perspectivesError, - refetchPerspectives, - - // Captions data - captions, - generatedPerspectives, - isGenerating: generatingPerspectives.length > 0, - isServerConnected, - - // Caption options - captionOptions, - setCaptionOptions, - - // Current image - currentImage, - setCurrentImage, - - // Generation operations - generatePerspective, - - // Status helpers - isPerspectiveGenerated, - isPerspectiveGenerating, - - // Data helpers - getPerspectiveData, - - // Perspective visibility - hiddenPerspectives, - togglePerspectiveVisibility, - isPerspectiveVisible, - setAllPerspectivesVisible, - }), - [ - selectedProvider, - availableProviders, - isGeneratingAll, - setSelectedProvider, - setAvailableProviders, - setIsGeneratingAll, - handleProviderChange, - fetchProviders, - isLoadingProviders, - providerError, - perspectivesData, - schemas, - isLoadingPerspectives, - perspectivesError, - refetchPerspectives, - captions, - generatedPerspectives, - generatingPerspectives, - isServerConnected, - captionOptions, - setCaptionOptions, - currentImage, - setCurrentImage, - generatePerspective, - isPerspectiveGenerated, - isPerspectiveGenerating, - getPerspectiveData, - hiddenPerspectives, - togglePerspectiveVisibility, - isPerspectiveVisible, - setAllPerspectivesVisible, - graphcapServerUrl, - ], - ); + const value: PerspectivesDataContextType = { + // Provider state + selectedProvider, + availableProviders, + isGeneratingAll, + + // Provider actions + setSelectedProvider, + setAvailableProviders, + setIsGeneratingAll, + handleProviderChange, + + // Data fetching - providers + fetchProviders, + isLoadingProviders, + providerError, + + // Perspectives data + perspectives: perspectivesData || [], + schemas, + isLoadingPerspectives, + perspectivesError, + refetchPerspectives, + + // Captions data + captions, + generatedPerspectives, + isGenerating: generatingPerspectives.length > 0, + isServerConnected, + + // Generation options from context + generationOptions: generationOptions.options, + + // Current image + currentImage, + setCurrentImage, + + // Generation operations + generatePerspective, + + // Status helpers + isPerspectiveGenerated, + isPerspectiveGenerating, + + // Data helpers + getPerspectiveData, + + // Perspective visibility + hiddenPerspectives, + togglePerspectiveVisibility, + isPerspectiveVisible, + setAllPerspectivesVisible, + }; return ( diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 7be27499..8d83dd69 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -8,7 +8,11 @@ import { useServerConnectionsContext } from "@/context"; import { SERVER_IDS } from "@/features/server-connections/constants"; import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; -import type { CaptionOptions, CaptionResponse } from "@/types"; +import type { CaptionResponse } from "@/types"; +import { + type GenerationOptions, + formatApiOptions +} from "@/types/generation-option-types"; import { type Provider, toServerConfig, @@ -34,8 +38,8 @@ export function useGeneratePerspectiveCaption() { { perspective: string; imagePath: string; - provider: Provider; // Use Provider type from types.ts - options?: CaptionOptions; + provider: Provider; + options: GenerationOptions; } >({ mutationFn: async ({ perspective, imagePath, provider, options }) => { @@ -53,7 +57,7 @@ export function useGeneratePerspectiveCaption() { } // Check if a model is specified in the options - if (!options.model) { + if (!options.model_id) { throw new Error("A model must be specified in the options"); } @@ -70,38 +74,25 @@ export function useGeneratePerspectiveCaption() { `Generating caption for image: ${normalizedImagePath} using perspective: ${perspective}`, ); + // Format options for API request + const apiOptions = formatApiOptions(options); + // Prepare the request body according to the server's expected format const requestBody = { perspective, image_path: normalizedImagePath, provider: provider.name, - model: options.model, // Use the model from options + model: options.model_id, // Use model_id from GenerationOptions provider_config: providerConfig, // Include the full provider configuration - max_tokens: options.max_tokens, - temperature: options.temperature, - top_p: options.top_p, - repetition_penalty: options.repetition_penalty, - context: options.context || [], - global_context: options.global_context ?? "", - resize: options.resize ?? false, - resize_resolution: options.resize_resolution ?? "HD_720P", + ...apiOptions, // Spread the formatted API options }; console.log("Sending caption generation request using API client", { perspective, image_path: normalizedImagePath, provider: provider.name, - model: options.model, // Log the model from options - options: { - max_tokens: requestBody.max_tokens, - temperature: requestBody.temperature, - top_p: requestBody.top_p, - repetition_penalty: requestBody.repetition_penalty, - context: requestBody.context, - global_context: requestBody.global_context, - resize: requestBody.resize, - resize_resolution: requestBody.resize_resolution, - }, + model: options.model_id, // Log the model_id from options + options: apiOptions, }); const response = await client.perspectives["caption-from-path"].$post({ diff --git a/graphcap_studio/src/features/perspectives/utils/api-adapters.ts b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts new file mode 100644 index 00000000..9cfa576f --- /dev/null +++ b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Perspectives API Adapters + * + * This module provides adapter functions for converting between API and application + * types for the perspectives feature, including caption request formatting. + */ + +import type { GenerationOptions } from "@/types/generation-option-types"; +import { formatApiOptions } from "@/types/generation-option-types"; +import type { Provider } from "@/types/provider-config-types"; +import { denormalizeProviderId } from "@/types/provider-config-types"; + +// Legacy caption options interface for migration purposes +interface LegacyCaptionOptions { + model: string; + max_tokens?: number; + temperature?: number; + top_p?: number; + repetition_penalty?: number; + global_context?: string; + context?: string[]; + resize?: boolean; + resize_resolution?: string; +} + +/** + * Format a caption generation request + * Converts from application types to the API request format + */ +export function formatCaptionRequest( + imagePath: string, + perspective: string, + provider: Provider, + options: GenerationOptions +): { + image_path: string; + perspective: string; + provider_id: number; + options: Record; +} { + return { + image_path: imagePath, + perspective, + provider_id: denormalizeProviderId(provider.id), + options: formatApiOptions(options) + }; +} + +/** + * Convert from CaptionOptions format to GenerationOptions format + * Used during the migration from CaptionOptions to GenerationOptions + */ +export function legacyCaptionToGenerationOptions( + captionOptions: LegacyCaptionOptions, + providerId: string +): GenerationOptions { + return { + model_id: captionOptions.model, + max_tokens: captionOptions.max_tokens ?? 4096, + temperature: captionOptions.temperature ?? 0.7, + top_p: captionOptions.top_p ?? 0.95, + repetition_penalty: captionOptions.repetition_penalty ?? 1.1, + global_context: captionOptions.global_context ?? "You are a visual captioning perspective.", + context: captionOptions.context ?? [], + resize_resolution: captionOptions.resize_resolution ?? "NONE", + provider_id: providerId + }; +} + +// Interface for perspective data structure +interface PerspectiveDataWithOptions { + model: string; + provider: string; + options?: LegacyCaptionOptions; +} + +/** + * Convert from PerspectiveData to GenerationOptions format + * Used for loading saved perspective settings + */ +export function perspectiveDataToGenerationOptions( + perspectiveData: PerspectiveDataWithOptions, + providerIdMap: Record +): GenerationOptions { + // If we have structured options, use those + if (perspectiveData.options) { + return legacyCaptionToGenerationOptions( + perspectiveData.options, + providerIdMap[perspectiveData.provider] || "" + ); + } + + // Otherwise create minimal options + return { + model_id: perspectiveData.model, + provider_id: providerIdMap[perspectiveData.provider] || "", + max_tokens: 4096, + temperature: 0.7, + top_p: 0.95, + repetition_penalty: 1.1, + global_context: "You are a visual captioning perspective.", + context: [], + resize_resolution: "NONE" + }; +} \ No newline at end of file diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts index f7f114dc..054ca024 100644 --- a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts +++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts @@ -27,6 +27,13 @@ export interface ProviderClient { }) => Promise; }; }; + ":provider": { + "models": { + $get: (options: { + param: { provider: string }; + }) => Promise; + }; + }; } /** diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts new file mode 100644 index 00000000..f8cea16c --- /dev/null +++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Provider API Adapters + * + * This module provides adapter functions for converting between API and application + * provider types, handling the conversion between numeric and string IDs. + */ + +import type { Provider, ProviderModel, ProviderModelInfo } from "@/types/provider-config-types"; +import { normalizeProviderId } from "@/types/provider-config-types"; + +// Type for raw API provider data +interface ApiProvider { + id: number; + name: string; + kind: string; + environment: "cloud" | "local"; + baseUrl: string; + apiKey?: string; + isEnabled: boolean; + defaultModel?: string; + fetchModels: boolean; + createdAt: string | Date; + updatedAt: string | Date; + models?: ApiProviderModel[]; + rateLimits?: ApiRateLimits; +} + +// Type for raw API provider model data +interface ApiProviderModel { + id: number; + providerId: number; + name: string; + isEnabled: boolean; + createdAt: string | Date; + updatedAt: string | Date; +} + +// Type for raw API rate limits data +interface ApiRateLimits { + id: number; + providerId: number; + requestsPerMinute?: number; + tokensPerMinute?: number; + createdAt: string | Date; + updatedAt: string | Date; +} + +// Type for raw API model info +interface ApiModelInfo { + id: string; + name: string; + is_default?: boolean; +} + +/** + * Convert API provider to application Provider type + * This handles ID conversion from number to string + */ +export function fromApiProvider(apiProvider: ApiProvider): Provider { + return { + id: normalizeProviderId(apiProvider.id), + name: apiProvider.name, + kind: apiProvider.kind, + environment: apiProvider.environment, + baseUrl: apiProvider.baseUrl, + apiKey: apiProvider.apiKey, + isEnabled: apiProvider.isEnabled, + defaultModel: apiProvider.defaultModel, + fetchModels: apiProvider.fetchModels, + createdAt: apiProvider.createdAt, + updatedAt: apiProvider.updatedAt, + + // Convert nested models + models: apiProvider.models?.map((model: ApiProviderModel) => ({ + id: normalizeProviderId(model.id), + providerId: normalizeProviderId(model.providerId), + name: model.name, + isEnabled: model.isEnabled, + createdAt: model.createdAt, + updatedAt: model.updatedAt, + })), + + // Convert nested rate limits + rateLimits: apiProvider.rateLimits + ? { + id: normalizeProviderId(apiProvider.rateLimits.id), + providerId: normalizeProviderId(apiProvider.rateLimits.providerId), + requestsPerMinute: apiProvider.rateLimits.requestsPerMinute, + tokensPerMinute: apiProvider.rateLimits.tokensPerMinute, + createdAt: apiProvider.rateLimits.createdAt, + updatedAt: apiProvider.rateLimits.updatedAt, + } + : undefined, + }; +} + +/** + * Convert application Provider to API provider + * This handles ID conversion from string to number + */ +export function toApiProvider(provider: Provider): ApiProvider { + return { + id: Number.parseInt(provider.id, 10), + name: provider.name, + kind: provider.kind, + environment: provider.environment, + baseUrl: provider.baseUrl, + apiKey: provider.apiKey, + isEnabled: provider.isEnabled, + defaultModel: provider.defaultModel, + fetchModels: provider.fetchModels, + createdAt: provider.createdAt, + updatedAt: provider.updatedAt, + + // Convert models back to numeric IDs + models: provider.models?.map((model) => ({ + id: Number.parseInt(model.id, 10), + providerId: Number.parseInt(model.providerId, 10), + name: model.name, + isEnabled: model.isEnabled, + createdAt: model.createdAt, + updatedAt: model.updatedAt, + })), + + // Convert rate limits back to numeric IDs + rateLimits: provider.rateLimits + ? { + id: Number.parseInt(provider.rateLimits.id, 10), + providerId: Number.parseInt(provider.rateLimits.providerId, 10), + requestsPerMinute: provider.rateLimits.requestsPerMinute, + tokensPerMinute: provider.rateLimits.tokensPerMinute, + createdAt: provider.rateLimits.createdAt, + updatedAt: provider.rateLimits.updatedAt, + } + : undefined, + }; +} + +/** + * Convert API model info to application ProviderModelInfo + */ +export function fromApiModelInfo(apiModel: ApiModelInfo): ProviderModelInfo { + return { + id: apiModel.id, + name: apiModel.name, + is_default: apiModel.is_default, + }; +} + +/** + * Create a provider model with defaults + */ +export function createProviderModel( + providerId: string, + name: string, + id?: string, +): ProviderModel { + return { + id: id || crypto.randomUUID(), // Generate UUID if no ID provided + providerId, + name, + isEnabled: true, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + }; +} diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index e2530755..2fe3eec2 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -18,11 +18,12 @@ import { toServerConfig } from "@/types/provider-config-types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { SERVER_IDS } from "../constants"; import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; +import { fromApiProvider, toApiProvider } from "./providerAdapters"; // Query keys for TanStack Query export const queryKeys = { providers: ["providers"] as const, - provider: (id: number) => ["providers", id] as const, + provider: (id: string) => ["providers", id] as const, providerModels: (providerName: string) => ["providers", "models", providerName] as const, }; @@ -39,14 +40,23 @@ export function useProviders() { return useQuery({ queryKey: queryKeys.providers, queryFn: async () => { + console.log("📡 Fetching all providers"); + const client = createDataServiceClient(connections); const response = await client.providers.$get(); if (!response.ok) { - throw new Error(`Failed to fetch providers: ${response.status}`); + const errorMsg = `Failed to fetch providers: ${response.status}`; + console.error(`❌ ${errorMsg}`); + throw new Error(errorMsg); } - return response.json() as Promise; + // Convert API response to application types + const apiProviders = await response.json(); + const providers = apiProviders.map(fromApiProvider); + + console.log(`✅ Fetched ${providers.length} providers:`, providers); + return providers; }, enabled: isConnected, staleTime: 1000 * 60 * 5, // 5 minutes @@ -56,7 +66,7 @@ export function useProviders() { /** * Hook to get a provider by ID */ -export function useProvider(id: number) { +export function useProvider(id: string) { const { connections } = useServerConnectionsContext(); const dataServiceConnection = connections.find( (conn) => conn.id === SERVER_IDS.DATA_SERVICE, @@ -66,16 +76,25 @@ export function useProvider(id: number) { return useQuery({ queryKey: queryKeys.provider(id), queryFn: async () => { + console.log(`📡 Fetching provider with ID: ${id}`); + const client = createDataServiceClient(connections); const response = await client.providers[":id"].$get({ - param: { id: id.toString() }, + param: { id }, }); if (!response.ok) { - throw new Error(`Failed to fetch provider: ${response.status}`); + const errorMsg = `Failed to fetch provider: ${response.status}`; + console.error(`❌ ${errorMsg}`); + throw new Error(errorMsg); } - return response.json() as Promise; + // Convert API response to application types + const apiProvider = await response.json(); + const provider = fromApiProvider(apiProvider); + + console.log("✅ Fetched provider:", provider); + return provider; }, enabled: isConnected && !!id, }); @@ -85,49 +104,42 @@ export function useProvider(id: number) { * Hook to create a provider */ export function useCreateProvider() { - const queryClient = useQueryClient(); const { connections } = useServerConnectionsContext(); + const queryClient = useQueryClient(); return useMutation({ - mutationFn: async (provider: ProviderCreate) => { + mutationFn: async (data: ProviderCreate) => { + console.log("📡 Creating provider:", data); + const client = createDataServiceClient(connections); + // Convert application data to API format + const apiData = toApiProvider(data as Provider); + console.log("📤 API request data:", apiData); + const response = await client.providers.$post({ - json: provider, + json: apiData, }); if (!response.ok) { - // Try to get detailed error information - try { - const errorData = await response.json(); - console.error("Provider creation error:", errorData); - - // Check if we have a structured error response - if (errorData.status === 'error' || errorData.validationErrors) { - throw errorData; - } - - // Simple error with a message - if (errorData.message) { - throw new Error(errorData.message); - } - - // Fallback error - throw new Error(`Failed to create provider: ${response.status}`); - } catch (parseError) { - // If we can't parse the error as JSON, throw a general error - if (parseError instanceof Error && parseError.message !== 'Failed to create provider') { - throw parseError; - } - throw new Error(`Failed to create provider: ${response.status}`); - } + const errorMsg = `Failed to create provider: ${response.status}`; + console.error(`❌ ${errorMsg}`); + throw new Error(errorMsg); } - return response.json() as Promise; + // Convert API response to application types + const apiProvider = await response.json(); + const provider = fromApiProvider(apiProvider); + + console.log("✅ Provider created:", provider); + return provider; }, onSuccess: () => { - // Invalidate providers query to refetch the list + console.log("🔄 Invalidating providers cache after create"); queryClient.invalidateQueries({ queryKey: queryKeys.providers }); }, + onError: (error: Error) => { + console.error("❌ Error in useCreateProvider:", error); + }, }); } @@ -135,60 +147,46 @@ export function useCreateProvider() { * Hook to update a provider */ export function useUpdateProvider() { - const queryClient = useQueryClient(); const { connections } = useServerConnectionsContext(); + const queryClient = useQueryClient(); return useMutation({ - mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { - console.log("Updating provider with data:", data); - const updateData = Object.entries(data).reduce((acc, [key, value]) => { - // Only include defined values - if (value !== null && value !== undefined) { - acc[key] = value; - } - return acc; - }, {} as Record); - console.log("updateData", updateData); + mutationFn: async ({ id, data }: { id: string; data: ProviderUpdate }) => { + console.log(`📡 Updating provider with ID ${id}:`, data); + const client = createDataServiceClient(connections); + // Convert application data to API format + const apiData = toApiProvider(data as Provider); + // Create a new object without the ID + const { id: _, ...apiDataWithoutId } = apiData; + + console.log("📤 API request data:", apiDataWithoutId); + const response = await client.providers[":id"].$put({ - param: { id: id.toString() }, - json: updateData, + param: { id }, + json: apiDataWithoutId, }); if (!response.ok) { - // Try to get detailed error information - try { - const errorData = await response.json(); - console.error("Provider update error:", errorData); - - // Check if we have a structured error response - if (errorData.status === 'error' || errorData.validationErrors) { - throw errorData; - } - - // Simple error with a message - if (errorData.message) { - throw new Error(errorData.message); - } - - // Fallback error - throw new Error(`Failed to update provider: ${response.status}`); - } catch (parseError) { - // If we can't parse the error as JSON, throw a general error - if (parseError instanceof Error && parseError.message !== 'Failed to update provider') { - throw parseError; - } - throw new Error(`Failed to update provider: ${response.status}`); - } + const errorMsg = `Failed to update provider: ${response.status}`; + console.error(`❌ ${errorMsg}`); + throw new Error(errorMsg); } - return response.json() as Promise; + // Convert API response to application types + const apiProvider = await response.json(); + const provider = fromApiProvider(apiProvider); + + console.log("✅ Provider updated:", provider); + return provider; }, - onSuccess: (data) => { - // Invalidate specific provider query - queryClient.invalidateQueries({ queryKey: queryKeys.provider(data.id) }); - // Invalidate providers list + onSuccess: (_data, variables) => { + console.log(`🔄 Invalidating providers cache after update for ID ${variables.id}`); queryClient.invalidateQueries({ queryKey: queryKeys.providers }); + queryClient.invalidateQueries({ queryKey: queryKeys.provider(variables.id) }); + }, + onError: (error: Error, variables) => { + console.error(`❌ Error in useUpdateProvider for ID ${variables.id}:`, error); }, }); } @@ -197,50 +195,135 @@ export function useUpdateProvider() { * Hook to delete a provider */ export function useDeleteProvider() { - const queryClient = useQueryClient(); const { connections } = useServerConnectionsContext(); + const queryClient = useQueryClient(); return useMutation({ - mutationFn: async (id: number) => { + mutationFn: async (id: string) => { + console.log(`📡 Deleting provider with ID: ${id}`); + const client = createDataServiceClient(connections); + const response = await client.providers[":id"].$delete({ - param: { id: id.toString() }, + param: { id }, }); if (!response.ok) { - throw new Error(`Failed to delete provider: ${response.status}`); + const errorMsg = `Failed to delete provider: ${response.status}`; + console.error(`❌ ${errorMsg}`); + throw new Error(errorMsg); } - return response.json() as Promise; + const result = await response.json() as SuccessResponse; + console.log("✅ Provider deleted:", result); + return result; }, - onSuccess: (_, id) => { - // Invalidate specific provider query - queryClient.invalidateQueries({ queryKey: queryKeys.provider(id) }); - // Invalidate providers list + onSuccess: (_data, id) => { + console.log(`🔄 Invalidating providers cache after delete for ID ${id}`); queryClient.invalidateQueries({ queryKey: queryKeys.providers }); + queryClient.invalidateQueries({ queryKey: queryKeys.provider(id) }); + }, + onError: (error: Error, id) => { + console.error(`❌ Error in useDeleteProvider for ID ${id}:`, error); }, }); } /** - * Hook to get available models for a provider + * Hook to get provider models */ -export function useProviderModels(provider: Provider | null | undefined) { +export function useProviderModels(providerName: string | Provider) { const { connections } = useServerConnectionsContext(); - const inferenceBridgeConnection = connections.find( + const graphcapServerConnection = connections.find( (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, ); - const isConnected = inferenceBridgeConnection?.status === "connected"; + const isConnected = graphcapServerConnection?.status === "connected"; + + // Extract the provider name and data if an object was passed + const isProviderObject = typeof providerName === 'object' && providerName !== null; + const name = isProviderObject ? providerName.name : providerName; + const provider = isProviderObject ? providerName : null; return useQuery({ - queryKey: queryKeys.providerModels(provider?.name ?? 'unknown'), + queryKey: queryKeys.providerModels(name), queryFn: async () => { - if (!provider) { - throw new Error("Provider is null or undefined"); + console.log(`📡 Fetching models for provider: ${name}`); + + try { + const client = createInferenceBridgeClient(connections); + + // If we have the full provider object, use it to create a server config + // Otherwise, use a minimal configuration + const config = provider + ? toServerConfig(provider) + : { + name, + kind: "unknown", + environment: "cloud" as const, + base_url: "", + api_key: "", + models: [], + fetch_models: true + }; + + console.log("📤 API request data:", config); + + // Use the POST endpoint with provider_name param and config body + const response = await client.providers[":provider_name"].models.$post({ + param: { provider_name: name }, + json: config, + }); + + if (!response.ok) { + throw new Error(`Failed to fetch provider models from API: ${response.status}`); + } + + const models = await response.json() as ProviderModelsResponse; + console.log(`✅ Fetched ${models.models.length} models from API for provider ${name}:`, models); + return models; + } catch (error) { + // If we have a provider object with models, use them as fallback + if (provider?.models && provider.models.length > 0) { + const modelCount = provider.models.length; + console.log(`⚠️ API request failed. Using ${modelCount} saved models from provider`); + + // Convert the provider models to the expected ProviderModelsResponse format + const fallbackModels: ProviderModelsResponse = { + provider: name, + models: provider.models.map(model => ({ + id: model.id ? (typeof model.id === 'string' ? model.id : String(model.id)) : String(model.name), + name: model.name, + is_default: model.name === provider.defaultModel + })) + }; + + return fallbackModels; + } + + // If no fallback is available, re-throw the error + console.error("❌ Failed to fetch models and no fallback available:", error); + throw error; } + }, + enabled: isConnected && !!name, + }); +} + +/** + * Hook to test a provider connection + */ +export function useTestProviderConnection() { + const { connections } = useServerConnectionsContext(); + + return useMutation({ + mutationFn: async (provider: Provider) => { + console.log(`📡 Testing connection for provider: ${provider.name}`, provider); const client = createInferenceBridgeClient(connections); + + // Convert to server config format const serverConfig = toServerConfig(provider); + console.log("📤 API request data:", serverConfig); const response = await client.providers[":provider_name"].models.$post({ param: { provider_name: provider.name }, @@ -248,12 +331,18 @@ export function useProviderModels(provider: Provider | null | undefined) { }); if (!response.ok) { - throw new Error(`Failed to fetch models: ${response.status}`); + const errorData = await response.json(); + const errorMsg = errorData.message || `Failed to test provider connection: ${response.status}`; + console.error(`❌ ${errorMsg}`, errorData); + throw new Error(errorMsg); } - return response.json() as Promise; + const result = await response.json(); + console.log("✅ Provider connection test successful:", result); + return result; + }, + onError: (error: Error, provider) => { + console.error(`❌ Error in useTestProviderConnection for provider ${provider.name}:`, error); }, - enabled: isConnected && !!provider && !!provider.fetchModels, - staleTime: 1000 * 60 * 10, // 10 minutes }); } \ No newline at end of file diff --git a/graphcap_studio/src/types/generation-option-types.ts b/graphcap_studio/src/types/generation-option-types.ts index 1036e5e1..ec4b06a2 100644 --- a/graphcap_studio/src/types/generation-option-types.ts +++ b/graphcap_studio/src/types/generation-option-types.ts @@ -2,7 +2,8 @@ /** * Generation Options Schema * - * This module defines the validation schema for caption generation options. + * This module defines the validation schema for generation options, + * replacing the legacy CaptionOptions with a consolidated schema. */ import { z } from "zod"; @@ -27,7 +28,7 @@ export const RESOLUTION_PRESETS = { UHD_8K: { label: "8K UHD", value: "UHD_8K" }, } as const; -// Default options for caption generation +// Default options for generation export const DEFAULT_OPTIONS = { temperature: 0.7, max_tokens: 4096, @@ -35,6 +36,7 @@ export const DEFAULT_OPTIONS = { repetition_penalty: 1.1, resize_resolution: "NONE", // Default to no resize global_context: "You are a visual captioning perspective.", + context: [] as string[], // Default to empty context array provider_id: "", // Default to empty (will be populated later) model_id: "", // Default to empty (will be populated later) } as const; @@ -70,6 +72,10 @@ export const GenerationOptionsSchema = z.object({ global_context: z.string().default(DEFAULT_OPTIONS.global_context), + // Added context array (was in CaptionOptions) + context: z.array(z.string()).default([]), + + // Provider and model selection provider_id: z.string().default(DEFAULT_OPTIONS.provider_id), model_id: z.string().default(DEFAULT_OPTIONS.model_id), @@ -77,3 +83,41 @@ export const GenerationOptionsSchema = z.object({ // Type for generation options export type GenerationOptions = z.infer; + +/** + * Format generation options for API requests + * This transforms the frontend GenerationOptions to the format expected by the API + */ +export function formatApiOptions(options: GenerationOptions): Record { + return { + model: options.model_id, // API expects 'model' instead of model_id + temperature: options.temperature, + max_tokens: options.max_tokens, + top_p: options.top_p, + repetition_penalty: options.repetition_penalty, + global_context: options.global_context, + context: options.context, + resize_resolution: options.resize_resolution, + }; +} + +/** + * Format a complete caption request + */ +export function formatCaptionRequest( + imagePath: string, + perspective: string, + options: GenerationOptions +): { + image_path: string; + perspective: string; + provider_id: string; + options: Record; +} { + return { + image_path: imagePath, + perspective: perspective, + provider_id: options.provider_id, + options: formatApiOptions(options), + }; +} diff --git a/graphcap_studio/src/types/provider-config-types.ts b/graphcap_studio/src/types/provider-config-types.ts index b78e58ee..b2774c68 100644 --- a/graphcap_studio/src/types/provider-config-types.ts +++ b/graphcap_studio/src/types/provider-config-types.ts @@ -2,146 +2,251 @@ /** * Provider Types * - * Type definitions for provider-related data. + * Type definitions for provider-related data with Zod validation. */ +import { z } from "zod"; + +// ============================================================================ +// SECTION A - ZOD SCHEMAS +// ============================================================================ + /** - * Server-side provider configuration - * This is the configuration that gets sent to the inference server + * Base provider schema */ -export interface ServerProviderConfig { - name: string; - kind: string; - environment: "cloud" | "local"; - base_url: string; - api_key: string; // Required for server requests - default_model?: string; - models: string[]; - fetch_models: boolean; - rate_limits?: { - requests_per_minute?: number; - tokens_per_minute?: number; - }; -} +export const BaseProviderSchema = z.object({ + id: z.string(), + name: z.string().min(1, "Name is required"), + isEnabled: z.boolean().default(true), +}); + +/** + * Provider model schema + */ +export const ProviderModelSchema = z.object({ + id: z.string(), + providerId: z.string(), + name: z.string().min(1, "Model name is required"), + isEnabled: z.boolean().default(true), + createdAt: z.string().or(z.date()), + updatedAt: z.string().or(z.date()), +}); + +/** + * Rate limits schema + */ +export const RateLimitsSchema = z.object({ + id: z.string(), + providerId: z.string(), + requestsPerMinute: z.number().optional(), + tokensPerMinute: z.number().optional(), + createdAt: z.string().or(z.date()), + updatedAt: z.string().or(z.date()), +}); + +/** + * Complete provider schema + */ +export const ProviderSchema = BaseProviderSchema.extend({ + kind: z.string().min(1, "Kind is required"), + environment: z.enum(["cloud", "local"]), + baseUrl: z.string().url("Must be a valid URL"), + apiKey: z.string().optional(), + defaultModel: z.string().optional(), + fetchModels: z.boolean().default(true), + createdAt: z.string().or(z.date()), + updatedAt: z.string().or(z.date()), + models: z.array(ProviderModelSchema).optional(), + rateLimits: RateLimitsSchema.optional(), +}); + +// Provider creation schema +export const ProviderCreateSchema = z.object({ + name: z.string().min(1, "Name is required"), + kind: z.string().min(1, "Kind is required"), + environment: z.enum(["cloud", "local"]), + baseUrl: z.string().url("Must be a valid URL"), + apiKey: z.string().optional(), + isEnabled: z.boolean().default(true), + defaultModel: z.string().optional(), + fetchModels: z.boolean().default(true), + models: z + .array( + z.object({ + name: z.string().min(1, "Model name is required"), + isEnabled: z.boolean().default(true), + }), + ) + .optional(), + rateLimits: z + .object({ + requestsPerMinute: z.number().optional(), + tokensPerMinute: z.number().optional(), + }) + .optional(), +}); + +// Provider update schema +export const ProviderUpdateSchema = z.object({ + name: z.string().min(1, "Name is required").optional(), + kind: z.string().min(1, "Kind is required").optional(), + environment: z.enum(["cloud", "local"]).optional(), + baseUrl: z.string().url("Must be a valid URL").optional(), + apiKey: z.string().optional(), + isEnabled: z.boolean().optional(), + defaultModel: z.string().optional(), + fetchModels: z.boolean().optional(), + models: z + .array( + z.object({ + id: z.string().optional(), + name: z.string().min(1, "Model name is required"), + isEnabled: z.boolean().default(true), + }), + ) + .optional(), + rateLimits: z + .object({ + requestsPerMinute: z.number().optional(), + tokensPerMinute: z.number().optional(), + }) + .optional(), +}); + +// Provider model info schema +export const ProviderModelInfoSchema = z.object({ + id: z.string(), + name: z.string(), + is_default: z.boolean().optional(), +}); + +// Provider models response schema +export const ProviderModelsResponseSchema = z.object({ + provider: z.string(), + models: z.array(ProviderModelInfoSchema), +}); + +// Success response schema +export const SuccessResponseSchema = z.object({ + success: z.boolean(), + message: z.string(), +}); + +// Error details schema +export const ErrorDetailsSchema = z.object({ + message: z.string(), + code: z.string().optional(), + details: z.record(z.unknown()).optional(), +}); + +// Connection details schema +export const ConnectionDetailsSchema = z.object({ + result: z.boolean(), + details: z.record(z.unknown()).optional(), + message: z.string().optional(), +}); + +// Server provider config schema +export const ServerProviderConfigSchema = z.object({ + name: z.string(), + kind: z.string(), + environment: z.enum(["cloud", "local"]), + base_url: z.string(), + api_key: z.string(), + default_model: z.string().optional(), + models: z.array(z.string()), + fetch_models: z.boolean(), + rate_limits: z + .object({ + requests_per_minute: z.number().optional(), + tokens_per_minute: z.number().optional(), + }) + .optional(), +}); + +// ============================================================================ +// SECTION B - INFERRED TYPES +// ============================================================================ + +/** + * Base provider interface for selection + */ +export type BaseProvider = z.infer; /** * Provider model */ -export interface ProviderModel { - id: number; - providerId: number; - name: string; - isEnabled: boolean; - createdAt: string | Date; - updatedAt: string | Date; -} +export type ProviderModel = z.infer; /** * Rate limits configuration */ -export interface RateLimits { - id: number; - providerId: number; - requestsPerMinute?: number; - tokensPerMinute?: number; - createdAt: string | Date; - updatedAt: string | Date; -} +export type RateLimits = z.infer; /** * Provider configuration stored in data service */ -export interface Provider { - id: number; - name: string; - kind: string; - environment: "cloud" | "local"; - baseUrl: string; - apiKey: string; // Changed from optional to required - isEnabled: boolean; - defaultModel?: string; - fetchModels: boolean; - createdAt: string | Date; - updatedAt: string | Date; - models?: ProviderModel[]; - rateLimits?: RateLimits; -} +export type Provider = z.infer; /** * Provider creation payload */ -export interface ProviderCreate { - name: string; - kind: string; - environment: "cloud" | "local"; - baseUrl: string; - apiKey: string; // Changed from optional to required - isEnabled?: boolean; - defaultModel?: string; - fetchModels?: boolean; - models?: Array<{ - name: string; - isEnabled?: boolean; - }>; - rateLimits?: { - requestsPerMinute?: number; - tokensPerMinute?: number; - }; -} +export type ProviderCreate = z.infer; /** * Provider update payload */ -export interface ProviderUpdate { - name?: string; - kind?: string; - environment?: "cloud" | "local"; - baseUrl?: string; - apiKey?: string; - isEnabled?: boolean; - defaultModel?: string; - fetchModels?: boolean; - models?: Array<{ - id?: number; - name: string; - isEnabled?: boolean; - }>; - rateLimits?: { - requestsPerMinute?: number; - tokensPerMinute?: number; - }; -} +export type ProviderUpdate = z.infer; /** - * Provider API key update payload + * Provider model info from GraphCap server */ -export interface ProviderApiKey { - apiKey: string; -} +export type ProviderModelInfo = z.infer; + +/** + * Provider models response from GraphCap server + */ +export type ProviderModelsResponse = z.infer< + typeof ProviderModelsResponseSchema +>; /** * Success response */ -export interface SuccessResponse { - success: boolean; - message: string; -} +export type SuccessResponse = z.infer; /** - * Provider model info from GraphCap server + * Error details + */ +export type ErrorDetails = z.infer; + +/** + * Connection details */ -export interface ProviderModelInfo { - id: string; - name: string; - is_default: boolean; +export type ConnectionDetails = z.infer; + +/** + * Server-side provider configuration + * This is the configuration that gets sent to the inference server + */ +export type ServerProviderConfig = z.infer; + +// ============================================================================ +// SECTION C - UTILITY FUNCTIONS +// ============================================================================ + +/** + * Convert string ID to number for API calls + */ +export function denormalizeProviderId(id: string): number { + return Number.parseInt(id, 10); } /** - * Provider models response from GraphCap server + * Convert number ID to string for frontend use */ -export interface ProviderModelsResponse { - provider: string; - models: ProviderModelInfo[]; +export function normalizeProviderId(id: number | string): string { + return id.toString(); } /** @@ -153,25 +258,15 @@ export function toServerConfig(provider: Provider): ServerProviderConfig { kind: provider.kind, environment: provider.environment, base_url: provider.baseUrl, - api_key: provider.apiKey, + api_key: provider.apiKey || "", default_model: provider.defaultModel, - models: provider.models?.map(m => m.name) || [], + models: provider.models?.map((m) => m.name) || [], fetch_models: provider.fetchModels, - rate_limits: provider.rateLimits ? { - requests_per_minute: provider.rateLimits.requestsPerMinute, - tokens_per_minute: provider.rateLimits.tokensPerMinute - } : undefined + rate_limits: provider.rateLimits + ? { + requests_per_minute: provider.rateLimits.requestsPerMinute, + tokens_per_minute: provider.rateLimits.tokensPerMinute, + } + : undefined, }; } - -export interface ErrorDetails { - message: string; - code?: string; - details?: Record; -} - -export interface ConnectionDetails { - result: boolean; - details?: Record; - message?: string; -} From 1200ae12c74b6fc2695d989f3962578cc9b3f3d7 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 10:59:12 -0500 Subject: [PATCH 32/69] Fix persist and save of new gen option. Clean up old ui components Signed-off-by: jphillips --- .../inference/generation-options/README.md | 163 ----------------- .../components/GenerationOptionsButton.tsx | 41 ----- .../components/GenerationOptionsDialog.tsx | 165 ------------------ .../components/GenerationOptionsPanel.tsx | 4 +- .../components/fields/GlobalContextField.tsx | 4 +- .../components/fields/MaxTokensField.tsx | 9 +- .../components/fields/ModelSelectorField.tsx | 16 +- .../fields/RepetitionPenaltyField.tsx | 4 +- .../fields/ResizeResolutionField.tsx | 4 +- .../components/fields/TemperatureField.tsx | 4 +- .../components/fields/TopPField.tsx | 6 +- .../generation-options/components/index.ts | 4 +- .../src/features/inference/hooks/index.ts | 4 +- .../hooks/useProviderModelSelection.ts | 76 -------- 14 files changed, 37 insertions(+), 467 deletions(-) delete mode 100644 graphcap_studio/src/features/inference/generation-options/README.md delete mode 100644 graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx delete mode 100644 graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx delete mode 100644 graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts diff --git a/graphcap_studio/src/features/inference/generation-options/README.md b/graphcap_studio/src/features/inference/generation-options/README.md deleted file mode 100644 index f805d793..00000000 --- a/graphcap_studio/src/features/inference/generation-options/README.md +++ /dev/null @@ -1,163 +0,0 @@ -# Generation Options Module - -This module provides components and context for managing model generation options such as temperature, max tokens, top_p, and repetition penalty. - -## Features - -- React Context API for state management -- Zod schema for validation -- Chakra UI Popover for form display -- Individual field components for easy reuse - -## Usage - -### Basic Setup - -Wrap your component tree with the provider: - -```tsx -import { GenerationOptionsProvider } from '@/features/inference/generation-options'; - -function App() { - return ( - - - - ); -} -``` - -### Using the Button - -The simplest way to add generation options to your UI: - -```tsx -import { GenerationOptionsButton } from '@/features/inference/generation-options'; - -function YourComponent() { - return ( -
- {/* Other UI elements */} - -
- ); -} -``` - -### Custom Trigger - -You can use your own trigger element instead of the default button: - -```tsx -import { GenerationOptionsPopover, useGenerationOptions } from '@/features/inference/generation-options'; -import { IconButton } from '@/components/ui'; - -function YourComponent() { - const { togglePopover } = useGenerationOptions(); - - return ( -
- {/* Other UI elements */} - - - -
- ); -} -``` - -### Accessing Options State - -You can access the current options state anywhere in your component tree: - -```tsx -import { useGenerationOptions } from '@/features/inference/generation-options'; - -function YourComponent() { - const { options, updateOption, resetOptions } = useGenerationOptions(); - - // Example: Get the current temperature value - console.log('Current temperature:', options.temperature); - - // Example: Update an option - const handleTemperatureChange = (newValue) => { - updateOption('temperature', newValue); - }; - - return ( -
- {/* Your UI using options */} -
- ); -} -``` - -### Getting Notified of Changes - -You can pass an `onOptionsChange` callback to the provider: - -```tsx -import { GenerationOptionsProvider } from '@/features/inference/generation-options'; -import type { GenerationOptions } from '@/features/inference/generation-options'; - -function App() { - const handleOptionsChange = (options: GenerationOptions) => { - console.log('Options changed:', options); - // Do something with the updated options - }; - - return ( - - - - ); -} -``` - -## Components - -- `GenerationOptionsProvider`: Context provider -- `GenerationOptionsButton`: Button that opens the options popover -- `GenerationOptionsPopover`: Popover container for custom triggers -- Field Components: - - `TemperatureField`: Controls the temperature option - - `MaxTokensField`: Controls the max_tokens option - - `TopPField`: Controls the top_p option - - `RepetitionPenaltyField`: Controls the repetition_penalty option - -## API - -### GenerationOptionsProvider Props - -| Prop | Type | Description | -|------|------|-------------| -| `children` | `React.ReactNode` | Child components | -| `initialOptions` | `Partial` | Initial values for options | -| `onOptionsChange` | `(options: GenerationOptions) => void` | Callback when options change | - -### GenerationOptionsButton Props - -| Prop | Type | Default | Description | -|------|------|---------|-------------| -| `label` | `string` | `'Options'` | Button text | -| `size` | `'xs' \| 'sm' \| 'md' \| 'lg'` | `'sm'` | Button size | -| `variant` | `'solid' \| 'outline' \| 'ghost'` | `'outline'` | Button variant | - -### useGenerationOptions Hook - -The hook returns an object with the following properties: - -- `options`: Current options state -- `isPopoverOpen`: Whether the popover is open -- `isGenerating`: Whether generation is in progress -- `updateOption`: Function to update a single option -- `resetOptions`: Function to reset options to defaults -- `setOptions`: Function to update multiple options -- `openPopover`: Function to open the popover -- `closePopover`: Function to close the popover -- `togglePopover`: Function to toggle the popover -- `setIsGenerating`: Function to update the isGenerating state \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx deleted file mode 100644 index 1e875ed5..00000000 --- a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsButton.tsx +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Generation Options Button - * - * This component provides a button that triggers the generation options dialog. - */ - -import { Button } from "@/components/ui"; -import type React from "react"; -import { useGenerationOptions } from "../context"; -import { GenerationOptionsDialog } from "./GenerationOptionsDialog"; - -interface GenerationOptionsButtonProps { - readonly label?: React.ReactNode; - readonly size?: "xs" | "sm" | "md" | "lg"; - readonly variant?: "solid" | "outline" | "ghost"; -} - -/** - * Button component for triggering generation options dialog - */ -export function GenerationOptionsButton({ - label = "Options", - size = "sm", - variant = "outline", -}: GenerationOptionsButtonProps) { - const { toggleDialog, isGenerating } = useGenerationOptions(); - - return ( - - - - ); -} diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx deleted file mode 100644 index 874944dc..00000000 --- a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsDialog.tsx +++ /dev/null @@ -1,165 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Generation Options Dialog - * - * This component displays a dialog with generation options form. - */ - -import { Button } from "@/components/ui"; -import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { Box, CloseButton, Dialog, Fieldset, Flex, HStack, Portal } from "@chakra-ui/react"; -import type React from "react"; -import { useGenerationOptions } from "../context"; -import { - GlobalContextField, - MaxTokensField, - ModelSelectorField, - RepetitionPenaltyField, - ResizeResolutionField, - TemperatureField, - TopPField, -} from "./fields"; - -interface GenerationOptionsDialogProps { - readonly children: React.ReactNode; -} - -/** - * Dialog component for generation options - */ -export function GenerationOptionsDialog({ - children, -}: GenerationOptionsDialogProps) { - const { isDialogOpen, closeDialog, resetOptions, isGenerating } = - useGenerationOptions(); - - // Colors for theming - const bgColor = useColorModeValue("white", "gray.700"); - const borderColor = useColorModeValue("gray.200", "gray.600"); - const headerColor = useColorModeValue("gray.800", "white"); - - return ( - (e.open ? null : closeDialog())} - size="lg" - > - {children} - - - - - - - Generation Options - - - - - - - - - {/* First column: Model selection */} - - - Provider & Model - - - - - - - {/* Second column: Generation Parameters */} - - - Generation Parameters - - - - - - - - - - - - - - - - - - - - {/* Third column: Image Processing */} - - - Image Processing - - - - - - - - {/* Global Context Field (spans full width on second row) */} - - Context Settings - - - - - - - - - - - - - - - - - - - - - - ); -} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx index 9e721514..a01b22e9 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/GenerationOptionsPanel.tsx @@ -21,7 +21,9 @@ import { * Panel component for generation options in the left action drawer */ export function GenerationOptionsPanel() { - const { resetOptions, isGenerating } = useGenerationOptions(); + const { actions, uiState } = useGenerationOptions(); + const { resetOptions } = actions; + const { isGenerating } = uiState; return ( diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx index 92b0a8f0..3ead3ae8 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/GlobalContextField.tsx @@ -14,7 +14,9 @@ import { useGenerationOptions } from "../../context"; * Global context control field component */ export function GlobalContextField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; const [localValue, setLocalValue] = useState(options.global_context); // Color values for theming diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/MaxTokensField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/MaxTokensField.tsx index 1e3e8d42..f242561d 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/MaxTokensField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/MaxTokensField.tsx @@ -2,7 +2,7 @@ /** * Max Tokens Field Component * - * This component renders the max tokens option field. + * This component renders the max_tokens option field. */ import { useGenerationOptions } from "../../context"; @@ -12,11 +12,12 @@ import { OptionField } from "./OptionField"; * Max tokens control field component */ export function MaxTokensField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; const handleChange = (value: number) => { - // Ensure we're using an integer - updateOption("max_tokens", Math.round(value)); + updateOption("max_tokens", value); }; return ( diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx index a48f5a0e..9d20d04d 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx @@ -18,8 +18,12 @@ export function ModelSelectorField() { options, providers, models, - actions + actions, + uiState } = useGenerationOptions(); + + const { selectProvider, selectModel } = actions; + const { isGenerating } = uiState; // Color values for theming const labelColor = useColorModeValue("gray.700", "gray.300"); @@ -49,14 +53,14 @@ export function ModelSelectorField() { // Handle provider change const handleProviderChange = (details: { value: string[] }) => { if (details.value.length > 0 && details.value[0] !== "none") { - actions.selectProvider(details.value[0]); + selectProvider(details.value[0]); } }; // Handle model change const handleModelChange = (details: { value: string[] }) => { if (details.value.length > 0 && details.value[0] !== "none") { - actions.selectModel(details.value[0]); + selectModel(details.value[0]); } }; @@ -68,7 +72,7 @@ export function ModelSelectorField() { // Loading state const isProvidersLoading = providers.isLoading; const isModelsLoading = models.isLoading; - + return ( @@ -81,7 +85,7 @@ export function ModelSelectorField() { collection={providerCollection} value={options.provider_id ? [options.provider_id] : []} onValueChange={handleProviderChange} - disabled={isProvidersLoading} + disabled={isProvidersLoading || isGenerating} size="sm" > @@ -116,7 +120,7 @@ export function ModelSelectorField() { collection={modelCollection} value={options.model_id ? [options.model_id] : []} onValueChange={handleModelChange} - disabled={isModelsLoading || !hasProviders} + disabled={isModelsLoading || !hasProviders || isGenerating} size="sm" > diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/RepetitionPenaltyField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/RepetitionPenaltyField.tsx index ad0410cf..fd770f9b 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/RepetitionPenaltyField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/RepetitionPenaltyField.tsx @@ -12,7 +12,9 @@ import { OptionField } from "./OptionField"; * Repetition penalty control field component */ export function RepetitionPenaltyField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; const handleChange = (value: number) => { updateOption("repetition_penalty", value); diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx index 4a1a7d7e..e7b370e3 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ResizeResolutionField.tsx @@ -14,7 +14,9 @@ import { useGenerationOptions } from "../../context"; * Field component for adjusting image resize resolution */ export function ResizeResolutionField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; // Color values for theming const labelColor = useColorModeValue("gray.700", "gray.300"); diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/TemperatureField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/TemperatureField.tsx index 7c56a5b0..c6536029 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/TemperatureField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/TemperatureField.tsx @@ -12,7 +12,9 @@ import { OptionField } from "./OptionField"; * Temperature control field component */ export function TemperatureField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; const handleChange = (value: number) => { updateOption("temperature", value); diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/TopPField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/TopPField.tsx index 1500fc41..84ce0a81 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/TopPField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/TopPField.tsx @@ -9,10 +9,12 @@ import { useGenerationOptions } from "../../context"; import { OptionField } from "./OptionField"; /** - * Top P control field component + * Top-P (nucleus sampling) control field component */ export function TopPField() { - const { options, updateOption, isGenerating } = useGenerationOptions(); + const { options, actions, uiState } = useGenerationOptions(); + const { updateOption } = actions; + const { isGenerating } = uiState; const handleChange = (value: number) => { updateOption("top_p", value); diff --git a/graphcap_studio/src/features/inference/generation-options/components/index.ts b/graphcap_studio/src/features/inference/generation-options/components/index.ts index 7e99765c..d8371081 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/index.ts +++ b/graphcap_studio/src/features/inference/generation-options/components/index.ts @@ -6,8 +6,6 @@ */ export * from "./fields"; -export * from "./GenerationOptionForm"; -export * from "./GenerationOptionsButton"; -export * from "./GenerationOptionsDialog"; export * from "./GenerationOptionsPanel"; export * from "./ProviderSelector"; + diff --git a/graphcap_studio/src/features/inference/hooks/index.ts b/graphcap_studio/src/features/inference/hooks/index.ts index 88b23e05..83285847 100644 --- a/graphcap_studio/src/features/inference/hooks/index.ts +++ b/graphcap_studio/src/features/inference/hooks/index.ts @@ -1,4 +1,4 @@ // SPDX-License-Identifier: Apache-2.0 -export * from "./useModelSelection"; export * from "./useDatabaseHealth"; -export * from "./useProviderModelSelection"; +export * from "./useModelSelection"; + diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts deleted file mode 100644 index 89e5d931..00000000 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelSelection.ts +++ /dev/null @@ -1,76 +0,0 @@ -import { useProviderModels, useProviders } from "@/features/server-connections/services/providers"; -import type { Provider } from "@/types/provider-config-types"; -// SPDX-License-Identifier: Apache-2.0 -import { useMemo } from "react"; - -/** - * Custom hook to handle provider and model selection logic - */ -export function useProviderModelSelection(provider: Provider | null | undefined) { - // Fetch providers from API - const { - data: providers = [], - isLoading: isLoadingProviders, - isError: isProvidersError, - } = useProviders(); - - // Fetch models for the selected provider - const { - data: providerModelsData, - isLoading: isLoadingModels, - isError: isModelsError, - error: modelsError, - } = useProviderModels(provider); - - // Memoize the available providers - const availableProviders = useMemo(() => { - return providers.filter((provider) => provider.isEnabled); - }, [providers]); - - // Determine providers with no models - const providersWithNoModels = useMemo(() => { - const noModelsSet = new Set(); - - if (providerModelsData?.models?.length === 0 && provider?.fetchModels && provider?.name) { - noModelsSet.add(provider.name); - } - - return noModelsSet; - }, [provider?.name, provider?.fetchModels, providerModelsData]); - - // Get default model if available - const defaultModel = useMemo(() => { - if (provider?.defaultModel) { - return { - id: provider.defaultModel, - name: provider.defaultModel, - is_default: true, - }; - } - if (providerModelsData?.models && providerModelsData.models.length > 0) { - return ( - providerModelsData.models.find((model) => model.is_default) || - providerModelsData.models[0] - ); - } - return null; - }, [provider?.defaultModel, providerModelsData]); - - return { - providers: availableProviders, - models: providerModelsData?.models || [], - defaultModel, - providersWithNoModels, - isLoading: { - providers: isLoadingProviders, - models: isLoadingModels, - }, - isError: { - providers: isProvidersError, - models: isModelsError, - }, - error: { - models: modelsError, - }, - }; -} From 37f4fe4e53d2bcc019cffc39b901edfde4cf1832 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 11:33:12 -0500 Subject: [PATCH 33/69] Fix issue with loading providers with no rate limits Signed-off-by: jphillips --- .../features/server-connections/services/providerAdapters.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts index f8cea16c..4ad90b7b 100644 --- a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts +++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts @@ -81,8 +81,8 @@ export function fromApiProvider(apiProvider: ApiProvider): Provider { updatedAt: model.updatedAt, })), - // Convert nested rate limits - rateLimits: apiProvider.rateLimits + // Convert nested rate limits if defined + rateLimits: apiProvider.rateLimits?.id && apiProvider.rateLimits?.providerId ? { id: normalizeProviderId(apiProvider.rateLimits.id), providerId: normalizeProviderId(apiProvider.rateLimits.providerId), From 51e90e17c596504a8c1e686d7dfbbf0f07267550 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 12:52:01 -0500 Subject: [PATCH 34/69] First split between user defined and api models in provider config Signed-off-by: jphillips --- .../common_inference/ModelSelector.tsx | 2 + .../common_inference/ProviderSelector.tsx | 4 +- .../components/ProviderFormView.tsx | 39 +- .../components/actions/CancelButton.tsx | 4 +- .../actions/ProviderModelActions.tsx | 89 +++ .../components/actions/ProviderSaveDialog.tsx | 533 ++++++++---------- .../actions/TestConnectionButton.tsx | 8 +- .../components/form/ModelSelectionSection.tsx | 109 +++- .../components/form/ModelSelector.tsx | 18 +- .../components/form/ProviderFormSelect.tsx | 36 +- .../components/form/RateLimitsSection.tsx | 81 ++- .../components/form/index.ts | 4 +- .../containers/ProviderFormContainer.tsx | 384 +++++-------- .../providers/context/ProviderFormContext.tsx | 74 +-- .../features/inference/services/providers.ts | 90 +-- .../features/provider_config/controller.ts | 73 ++- 16 files changed, 775 insertions(+), 773 deletions(-) create mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderModelActions.tsx diff --git a/graphcap_studio/src/components/common_inference/ModelSelector.tsx b/graphcap_studio/src/components/common_inference/ModelSelector.tsx index e8c5ad82..abbd5584 100644 --- a/graphcap_studio/src/components/common_inference/ModelSelector.tsx +++ b/graphcap_studio/src/components/common_inference/ModelSelector.tsx @@ -18,6 +18,8 @@ import { export interface ModelOption { label: string; value: string; + id: string; + is_default?: boolean; } export interface ModelSelectorProps { diff --git a/graphcap_studio/src/components/common_inference/ProviderSelector.tsx b/graphcap_studio/src/components/common_inference/ProviderSelector.tsx index 0daa99b2..1135463e 100644 --- a/graphcap_studio/src/components/common_inference/ProviderSelector.tsx +++ b/graphcap_studio/src/components/common_inference/ProviderSelector.tsx @@ -18,12 +18,12 @@ import { export interface ProviderOption { label: string; value: string; - id?: number; + id: string; } export interface ProviderSelectorProps { readonly options: ProviderOption[]; - readonly value: string | number | null | undefined; + readonly value: string | null | undefined; readonly onChange: (value: string) => void; readonly isDisabled?: boolean; readonly maxWidth?: string | number; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx index 7a2b4217..64d6aa64 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormView.tsx @@ -13,26 +13,21 @@ import { ProviderFormSelect } from "./form/ProviderFormSelect"; */ export function ProviderFormView() { const { - onSubmit, handleSubmit, isSubmitting, - saveSuccess, - dialogs, + dialog, closeDialog, - formError, - connectionError, + error, connectionDetails, - selectedProvider, - setMode, - mode - } = useProviderFormContext(); + provider, + setMode } = useProviderFormContext(); const handleAddProvider = () => { setMode("create"); }; return ( - + {/* Provider Selection Section */} @@ -56,32 +51,32 @@ export function ProviderFormView() { {/* Form Error Dialog */} closeDialog("formError")} - error={formError} - providerName={selectedProvider?.name || "Provider"} + isOpen={dialog === "formError"} + onClose={() => closeDialog()} + error={error} + providerName={provider?.name || "Provider"} /> {/* Connection Error Dialog */} closeDialog("error")} - error={connectionError} - providerName={selectedProvider?.name || "Provider"} + isOpen={dialog === "error"} + onClose={() => closeDialog()} + error={error} + providerName={provider?.name || "Provider"} /> {/* Success Dialog */} closeDialog("success")} - providerName={selectedProvider?.name || "Provider"} + isOpen={dialog === "success"} + onClose={() => closeDialog()} + providerName={provider?.name || "Provider"} connectionDetails={connectionDetails} /> diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx index 57860c6a..0c996b5c 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/CancelButton.tsx @@ -7,7 +7,7 @@ import { useProviderFormContext } from "../../../context/ProviderFormContext"; * Button component for canceling provider form changes */ export function CancelButton() { - const { onCancel } = useProviderFormContext(); + const { cancelEdit } = useProviderFormContext(); const { colorMode } = useColorMode(); const isDark = colorMode === "dark"; @@ -19,7 +19,7 @@ export function CancelButton() { return ( +
+ + {/* Show current models */} + {currentModels.length > 0 && ( + + Current Models + {currentModels.map((model, index) => ( + + {model.name} + + + ))} + + )} +
+ ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx index ee182871..d3b17de3 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx @@ -1,4 +1,6 @@ -import type { Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; +import type { + Provider +} from "@/types/provider-config-types"; import { Box, Button, @@ -11,311 +13,238 @@ import { } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import { useState } from "react"; -import { useCreateProvider, useUpdateProvider } from "../../../../services/providers"; +import { + useCreateProvider, + useUpdateProvider, +} from "../../../../services/providers"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; -// Define error type with message property -interface ErrorWithMessage { - message: string; - [key: string]: unknown; -} /** * Unified component that combines the save button and save dialog functionality */ export function SaveButton() { - const { - isSubmitting: isContextSubmitting, - isCreating, - mode, - handleSubmit, - selectedProvider, - saveError: contextSaveError - } = useProviderFormContext(); - - // Local state for dialog visibility and save state - const [isDialogOpen, setIsDialogOpen] = useState(false); - const [isSaving, setIsSaving] = useState(false); - const [saveComplete, setSaveComplete] = useState(false); - const [savingProvider, setSavingProvider] = useState(null); - const [saveError, setSaveError] = useState(contextSaveError); - - // Get provider service functions - const { mutateAsync: createProviderAsync, isPending: isCreatingProvider } = useCreateProvider(); - const { mutateAsync: updateProviderAsync, isPending: isUpdatingProvider } = useUpdateProvider(); - - // Determine if form is submitting - const isSubmitting = isContextSubmitting || isSaving || isCreatingProvider || isUpdatingProvider; - - // Determine the button text based on form state - let buttonText = "Save"; - if (isSubmitting) { - buttonText = "Saving..."; - } else if (isCreating) { - buttonText = "Create"; - } - - // Function to close the dialog - const closeDialog = () => { - setIsDialogOpen(false); - setSaveComplete(false); - setSavingProvider(null); - setSaveError(undefined); - }; - - // Helper function to normalize rate limits - const normalizeRateLimits = (data: ProviderCreate | ProviderUpdate) => { - if (!data.rateLimits) return; - - if (Array.isArray(data.rateLimits) || typeof data.rateLimits !== 'object') { - const currentRateLimits = data.rateLimits as unknown; - const requestsPerMinute = typeof currentRateLimits === 'object' && currentRateLimits !== null - ? (currentRateLimits as Record).requestsPerMinute as number ?? 0 - : 0; - const tokensPerMinute = typeof currentRateLimits === 'object' && currentRateLimits !== null - ? (currentRateLimits as Record).tokensPerMinute as number ?? 0 - : 0; - - data.rateLimits = { requestsPerMinute, tokensPerMinute }; - } - }; - - // Handle form submission errors - const handleSaveError = (error: unknown) => { - console.error("Error saving provider:", error); - if (error instanceof Error) { - setSaveError(error.message); - } else if (typeof error === 'object' && error !== null && 'message' in error) { - const errorWithMsg = error as ErrorWithMessage; - setSaveError(errorWithMsg.message); - } else { - setSaveError("An unknown error occurred"); - } - }; - - // Save the provider using the appropriate service function - const saveProvider = async (data: ProviderCreate | ProviderUpdate): Promise => { - // Make sure API key is included in update requests - if (!data.apiKey && selectedProvider?.apiKey && mode === 'edit') { - console.log("Including existing API key in update"); - data.apiKey = selectedProvider.apiKey; - } - - // Log the full data we're sending (redact the actual API key) - console.log("Sending to server:", { - ...data, - apiKey: data.apiKey ? "[PRESENT]" : "[MISSING]", - mode - }); - - try { - // Edit mode with selected provider - if (mode === "edit" && selectedProvider?.id) { - console.log(`Updating provider with id ${selectedProvider.id}`); - const result = await updateProviderAsync({ - id: selectedProvider.id, - data: data as ProviderUpdate - }); - console.log("Provider updated successfully:", result); - return result; - } - - // Create mode - if (mode === "create") { - console.log("Creating new provider"); - const result = await createProviderAsync(data as ProviderCreate); - console.log("Provider created successfully:", result); - return result; - } - - // Fallback path - has ID in data - if ('id' in data && data.id) { - const id = data.id as number; - console.log(`Updating provider with id ${id}`); - const result = await updateProviderAsync({ id, data }); - console.log("Provider updated successfully:", result); - return result; - } - - // Default create path - console.log("Creating new provider (fallback path)"); - const result = await createProviderAsync(data as ProviderCreate); - console.log("Provider created successfully:", result); - return result; - } catch (error) { - handleSaveError(error); - throw error; - } - }; - - // Custom submit handler that shows the dialog and processes the form - const handleFormSubmit = async (e: React.FormEvent) => { - try { - setIsSaving(true); - setIsDialogOpen(true); - setSaveError(undefined); - - // Process the form submission through the form's handleSubmit - const formHandler = handleSubmit(async (data: ProviderCreate | ProviderUpdate) => { - try { - console.log("Provider form submitted:", data); - - // Normalize rateLimits - ensure it's an object, not an array - normalizeRateLimits(data); - - // Save the provider - const result = await saveProvider(data); - - // Success - store the provider details and mark as complete - setSavingProvider(result); - setSaveComplete(true); - } catch (error) { - handleSaveError(error); - } - }); - - // Execute the form handler - formHandler(e); - } catch (error) { - console.error("Form submission error:", error); - if (error instanceof Error) { - setSaveError(error.message); - } else { - setSaveError("Form validation failed"); - } - } finally { - setIsSaving(false); - } - }; - - // Get the current provider to display - const displayProvider = savingProvider || selectedProvider; - - // Determine dialog title - let dialogTitle = "Processing..."; - if (saveError) { - dialogTitle = "Error Saving Provider"; - } else if (isSaving) { - dialogTitle = "Saving Provider..."; - } else if (saveComplete) { - dialogTitle = "Provider Saved"; - } - - // Render dialog body content based on state - const renderDialogBody = () => { - if (saveError) { - return ( - - {saveError || "An unknown error occurred"} - - ); - } - - if (isSaving) { - return ( - - - Saving provider configuration to server... - Please wait while we process your request - - ); - } - - if (saveComplete && displayProvider) { - return ( - - - - Name: {displayProvider.name} - Kind: {displayProvider.kind} - Environment: {displayProvider.environment} - Base URL: {displayProvider.baseUrl} - {displayProvider.fetchModels && ( - Default Model: {displayProvider.defaultModel ?? "Not set"} - )} - - - - ); - } - - return ( - - Initializing save process... - - ); - }; - - return ( - <> - - - {/* Provider Save Dialog */} - !isSaving && setIsDialogOpen(e.open)} - > - - - - - - {dialogTitle} - - - - - - - {renderDialogBody()} - - - - - - - - - - - ); -} \ No newline at end of file + const { + isSubmitting: isContextSubmitting, + mode, + provider: selectedProvider, + error: contextSaveError, + handleSubmit, + } = useProviderFormContext(); + + // Local state for dialog visibility and save state + const [isDialogOpen, setIsDialogOpen] = useState(false); + const [isSaving, setIsSaving] = useState(false); + const [saveComplete, setSaveComplete] = useState(false); + const [savingProvider, setSavingProvider] = useState(null); + const [saveError, setSaveError] = useState( + contextSaveError?.message, + ); + + // Get provider service functions + const { isPending: isCreatingProvider } = + useCreateProvider(); + const { mutateAsync: updateProviderAsync, isPending: isUpdatingProvider } = + useUpdateProvider(); + + // Determine if form is submitting + const isSubmitting = + isContextSubmitting || isSaving || isCreatingProvider || isUpdatingProvider; + + // Determine the button text based on form state + let buttonText = "Save"; + if (isSubmitting) { + buttonText = "Saving..."; + } else if (mode === "create") { + buttonText = "Create"; + } + + // Function to close the dialog + const closeDialog = () => { + setIsDialogOpen(false); + setSaveComplete(false); + setSavingProvider(null); + setSaveError(undefined); + }; + + // Handle form submission errors + + // Save the provider using the appropriate service function + + // Custom submit handler that shows the dialog and processes the form + const handleFormSubmit = async (e: React.FormEvent) => { + try { + setIsSaving(true); + setIsDialogOpen(true); + setSaveError(undefined); + + // Try to save the provider using the context handleSubmit + await handleSubmit(e as React.BaseSyntheticEvent); + + // If we get here without errors, it means the form was submitted successfully + setSaveComplete(true); + + } catch (error) { + console.error("Form submission error:", error); + if (error instanceof Error) { + setSaveError(error.message); + } else { + setSaveError("Form validation failed"); + } + } finally { + setIsSaving(false); + } + }; + + // Get the current provider to display + const displayProvider = savingProvider || selectedProvider; + + // Determine dialog title + let dialogTitle = "Processing..."; + if (saveError) { + dialogTitle = "Error Saving Provider"; + } else if (isSaving) { + dialogTitle = "Saving Provider..."; + } else if (saveComplete) { + dialogTitle = "Provider Saved"; + } + + // Render dialog body content based on state + const renderDialogBody = () => { + if (saveError) { + return ( + + {saveError || "An unknown error occurred"} + + ); + } + + if (isSaving) { + return ( + + + Saving provider configuration to server... + + Please wait while we process your request + + + ); + } + + if (saveComplete && displayProvider) { + return ( + + + + + Name: {displayProvider.name} + + + Kind: {displayProvider.kind} + + + Environment: {displayProvider.environment} + + + Base URL: {displayProvider.baseUrl} + + {displayProvider.fetchModels && ( + + Default Model:{" "} + {displayProvider.defaultModel ?? "Not set"} + + )} + + + + ); + } + + return ( + + Initializing save process... + + ); + }; + + return ( + <> + + + {/* Provider Save Dialog */} + !isSaving && setIsDialogOpen(e.open)} + > + + + + + + {dialogTitle} + + + + + + {renderDialogBody()} + + + + + + + + + + ); +} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx index 7d07a319..6b71a71d 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx @@ -6,17 +6,17 @@ import { useProviderFormContext } from "../../../context/ProviderFormContext"; * Button component for testing provider connection */ export function TestConnectionButton() { - const { isTestingConnection, handleTestConnection, selectedProvider } = useProviderFormContext(); + const { isSubmitting, testConnection, provider } = useProviderFormContext(); return ( ); } \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx index 4a59a1d9..bc9c2df8 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx @@ -1,8 +1,9 @@ import { ActionButton } from "@/components/ui/buttons/ActionButton"; import { StatusMessage } from "@/components/ui/status/StatusMessage"; // SPDX-License-Identifier: Apache-2.0 -import { Box } from "@chakra-ui/react"; +import { Box, Heading, VStack } from "@chakra-ui/react"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; +import { ProviderModelActions } from "../actions/ProviderModelActions"; import { ModelSelector } from "./ModelSelector"; // Define the model type @@ -17,18 +18,53 @@ export interface ProviderModel { */ export function ModelSelectionSection() { const { - selectedProvider, - providerModelsData, + provider, + providerModels, isLoadingModels, - isModelsError, - modelsError, selectedModelId, setSelectedModelId, - handleModelSelect, isSubmitting, + mode, + watch, } = useProviderFormContext(); - const providerName = selectedProvider?.name; + const providerName = provider?.name; + const isEditMode = mode === "edit" || mode === "create"; + const customModels = watch("models") || []; + const fetchModels = provider?.fetchModels || watch("fetchModels"); + + // Prepare an array with all models to display + const allModels = []; + + // Always add custom/user-defined models + if (customModels && customModels.length > 0) { + // Map custom models to the format expected by the model selector + for (const model of customModels) { + allModels.push({ + // Generate a stable ID for custom models + id: typeof model.id === 'string' ? model.id : `custom-${model.name}`, + name: model.name, + is_default: provider?.defaultModel === model.name, + isCustom: true + }); + } + } + + // Add API-fetched models if fetchModels is true + if (fetchModels && providerModels && providerModels.length > 0) { + // Map API models to the format expected by the model selector + for (const model of providerModels) { + // Only add if not already included in custom models + if (!customModels.some(m => m.name === model.name)) { + allModels.push({ + id: model.id, + name: model.name, + is_default: model.is_default, + isApiModel: true + }); + } + } + } // Handle different states if (!providerName) { @@ -40,36 +76,57 @@ export function ModelSelectionSection() { ); } - if (isLoadingModels) { - return ; - } - - if (isModelsError) { + // When in edit mode, show model management section + if (isEditMode) { return ( - + + + Model Configuration + + + + {fetchModels && isLoadingModels && ( + + )} + + {allModels.length > 0 && ( + + Default Model Selection + ({ + label: `${model.name}${model.is_default ? " (Default)" : ""}${model.isCustom ? " (Custom)" : ""}${model.isApiModel ? " (API)" : ""}`, + value: model.id, + id: model.id, + }))} + selectedModelId={selectedModelId} + setSelectedModelId={setSelectedModelId} + /> + + )} + ); } - if (!providerModelsData?.models || providerModelsData.models.length === 0) { + // View mode + if (fetchModels && isLoadingModels) { + return ; + } + + if (allModels.length === 0) { return ( ); } - // Convert models to the format expected by SelectRoot - const modelItems = providerModelsData.models.map((model: ProviderModel) => ({ - label: `${model.name}${model.is_default ? " (Default)" : ""}`, + // Convert all models to the format expected by SelectRoot + const modelItems = allModels.map(model => ({ + label: `${model.name}${model.is_default ? " (Default)" : ""}${model.isCustom ? " (Custom)" : ""}${model.isApiModel ? " (API)" : ""}`, value: model.id, + id: model.id, })); return ( @@ -81,7 +138,7 @@ export function ModelSelectionSection() { /> console.log("Selected model:", selectedModelId)} disabled={!selectedModelId} isLoading={isSubmitting} /> diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx index b3e5bc27..3864ee95 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx @@ -1,17 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 -import { ModelSelector as GenericModelSelector, ModelOption } from "@/components/common_inference/ModelSelector"; +import { + ModelSelector as GenericModelSelector, + type ModelOption, +} from "@/components/common_inference/ModelSelector"; import { useColorMode } from "@/components/ui/theme/color-mode"; import { Box, Heading, Text } from "@chakra-ui/react"; -export interface ModelItem { - label: string; - value: string; -} +export type ModelItem = ModelOption; export interface ModelSelectorProps { modelItems: ModelItem[]; selectedModelId: string | null; - setSelectedModelId: (id: string) => void; + setSelectedModelId: (id: string | null) => void; } /** @@ -30,6 +30,10 @@ export function ModelSelector({ const headingColor = isDark ? "gray.100" : "gray.700"; const labelColor = isDark ? "gray.300" : "gray.600"; + const handleModelChange = (value: string) => { + setSelectedModelId(value || null); + }; + return ( ); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx index c4b93e6f..fb47a2af 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx @@ -1,5 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 -import { ProviderSelector, type ProviderOption } from "@/components/common_inference/ProviderSelector"; +import { type ProviderOption, ProviderSelector } from "@/components/common_inference/ProviderSelector"; +import type { Provider } from "@/types/provider-config-types"; +import { useProviders } from "../../../../services/providers"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; type ProviderFormSelectProps = { @@ -15,37 +17,37 @@ export function ProviderFormSelect({ className, "aria-label": ariaLabel = "Select Provider", }: ProviderFormSelectProps) { - // Get providers and selection functions from the form context - const { providers, selectedProvider, setSelectedProvider } = useProviderFormContext(); - - const selectedProviderId = selectedProvider?.id ?? null; - + // Get provider data from context + const { provider, setProvider } = useProviderFormContext(); + + // Fetch providers directly + const { data: providers = [] } = useProviders(); + // Convert providers to the format expected by ProviderSelector - const providerOptions: ProviderOption[] = providers.map((provider) => ({ - label: provider.name, - value: String(provider.id), - id: provider.id, + const providerOptions: ProviderOption[] = providers.map((p: Provider) => ({ + label: p.name, + value: String(p.id), + id: String(p.id), })); const handleProviderChange = (value: string) => { if (!value) return; - const id = Number(value); - const provider = providers.find((p) => p.id === id); - if (provider) { - // Call the context's setSelectedProvider function - // This will update the global context and reset the form - setSelectedProvider(provider); + // Find the selected provider from the providers list + const selectedProvider = providers.find((p: Provider) => String(p.id) === value); + if (selectedProvider) { + setProvider(selectedProvider); } }; return ( ); } \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx index 3fe9fcc7..82db53ca 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx @@ -10,7 +10,7 @@ import { } from "@chakra-ui/react"; // SPDX-License-Identifier: Apache-2.0 import type { ChangeEvent } from "react"; -import { Controller, useController } from "react-hook-form"; +import { Controller } from "react-hook-form"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; /** @@ -23,17 +23,8 @@ export function RateLimitsSection() { const textColor = useColorModeValue("gray.700", "gray.200"); // Watch form values for read-only display - const rateLimits = watch("rateLimits"); - - // Ensure rateLimits object exists in the form - useController({ - name: "rateLimits", - control, - defaultValue: { - requestsPerMinute: 0, - tokensPerMinute: 0, - }, - }); + const formValues = watch(); + const rateLimits = formValues.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 }; if (!isEditing) { return ( @@ -48,14 +39,14 @@ export function RateLimitsSection() { Requests per minute - {rateLimits?.requestsPerMinute ?? 0} + {rateLimits.requestsPerMinute ?? 0} Tokens per minute - {rateLimits?.tokensPerMinute ?? 0} + {rateLimits.tokensPerMinute ?? 0} @@ -69,63 +60,65 @@ export function RateLimitsSection() { Rate Limits - - - ( + + {/* Use a single Controller for the entire rateLimits object + This ensures we always have an object structure */} + ( + + Requests per minute ) => - onChange(Number.parseInt(e.target.value) || 0) - } + value={field.value?.requestsPerMinute ?? 0} + onChange={(e: ChangeEvent) => { + const value = Number.parseInt(e.target.value) || 0; + field.onChange({ + ...field.value, + requestsPerMinute: value + }); + }} min={0} /> {errors.rateLimits?.requestsPerMinute?.message} - )} - /> - + - - ( + Tokens per minute ) => - onChange(Number.parseInt(e.target.value) || 0) - } + value={field.value?.tokensPerMinute ?? 0} + onChange={(e: ChangeEvent) => { + const value = Number.parseInt(e.target.value) || 0; + field.onChange({ + ...field.value, + tokensPerMinute: value + }); + }} min={0} /> {errors.rateLimits?.tokensPerMinute?.message} - )} - /> - - + + + )} + /> ); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts index bddf3d50..1c698340 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts @@ -1,7 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 export * from "./BasicInfoSection"; export * from "./ConnectionSection"; -export * from "./RateLimitsSection"; export * from "./EnvironmentSelect"; +export * from "./ModelSelectionSection"; export * from "./ModelSelector"; export * from "./ProviderFormSelect"; +export * from "./RateLimitsSection"; + diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index db00d006..b4fa86c1 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -1,5 +1,5 @@ import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; -import { toServerConfig } from "@/types/provider-config-types"; +import { denormalizeProviderId, toServerConfig } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 import type { ReactNode } from "react"; import { useCallback, useState } from "react"; @@ -8,336 +8,222 @@ import { useCreateProvider, useProviders, useTestProviderConnection, useUpdatePr import { useInferenceProviderContext } from "../../context/InferenceProviderContext"; import { ProviderFormProvider } from "../../context/ProviderFormContext"; -// Extended Error interface with cause property -interface ErrorWithCause extends Error { - cause?: unknown; -} +// Simplified dialog state type +type DialogType = null | "error" | "success" | "formError" | "save"; interface ProviderFormContainerProps { children: ReactNode; initialData?: Partial; - onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; } -/** - * Container component that provides the ProviderFormContext - */ export function ProviderFormContainer({ children, initialData, }: ProviderFormContainerProps) { - // Get model selection and provider state from the InferenceProviderContext + // Get required context from parent const { - mode, - setMode, - selectedProvider, - setSelectedProvider, - providers: contextProviders, + mode: contextMode, + setMode: setContextMode, + selectedProvider: contextSelectedProvider, selectedModelId, - setSelectedModelId, + setSelectedModelId, providerModelsData, isLoadingModels, - isModelsError, - modelsError, - handleModelSelect, onCancel: onContextCancel, } = useInferenceProviderContext(); - // Setup react-hook-form for provider form + // State for the provider form + const [mode, setMode] = useState(contextMode); + const [provider, setProvider] = useState(contextSelectedProvider); + const [isSubmitting, setIsSubmitting] = useState(false); + const [dialog, setDialog] = useState(null); + const [error, setError] = useState(null); + const [connectionDetails, setConnectionDetails] = useState(null); + + // Form setup const { control, - handleSubmit: hookHandleSubmit, + handleSubmit: formHandleSubmit, formState: { errors }, watch, reset } = useForm({ - defaultValues: initialData, + defaultValues: initialData || provider || {}, }); - // Local state for the provider form - const [isSubmitting, setIsSubmitting] = useState(false); - const [saveSuccess, setSaveSuccess] = useState(false); - const [isTestingConnection, setIsTestingConnection] = useState(false); - const [formError, setFormError] = useState(null); - const [connectionError, setConnectionError] = useState(null); - const [connectionDetails, setConnectionDetails] = useState(null); - const [dialogs, setDialogs] = useState({ - error: false, - success: false, - formError: false, - save: false, - }); - - // Fetch providers - const { data: providers = [] } = useProviders(); - - // API connection test hook + // API hooks + useProviders(); const testConnection = useTestProviderConnection(); - // Add hooks for creating and updating providers const createProvider = useCreateProvider(); const updateProvider = useUpdateProvider(); - const handleSelectProvider = (provider: Provider | null) => { - setSelectedProvider(provider); + // Use either API providers or context providers + + // Handle provider selection + const handleProviderSelect = useCallback((newProvider: Provider | null) => { + setProvider(newProvider); - // If a provider is selected, reset the form with its data - if (provider) { - // Debug log to check if API key is present - console.log("Provider selected for edit:", { - ...provider, - apiKey: provider.apiKey ? "[PRESENT]" : "[MISSING]" - }); - - // Reset the form with the selected provider's data - const providerData: ProviderUpdate = { - name: provider.name, - kind: provider.kind, - environment: provider.environment, - baseUrl: provider.baseUrl, - apiKey: provider.apiKey || "", // Ensure apiKey is included and not null/undefined - isEnabled: provider.isEnabled, - defaultModel: provider.defaultModel, - fetchModels: provider.fetchModels, - models: provider.models, - rateLimits: provider.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 } - }; - - // Log the data being used to reset the form - console.log("Resetting form with:", { - ...providerData, - apiKey: providerData.apiKey ? "[PRESENT]" : "[MISSING]" + if (newProvider) { + // Reset form with provider data + reset({ + name: newProvider.name, + kind: newProvider.kind, + environment: newProvider.environment, + baseUrl: newProvider.baseUrl, + apiKey: newProvider.apiKey || "", + isEnabled: newProvider.isEnabled, + defaultModel: newProvider.defaultModel, + fetchModels: newProvider.fetchModels, + models: newProvider.models, + rateLimits: newProvider.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 } }); - - reset(providerData); } - }; + }, [reset]); - // Function to close any dialog - const closeDialog = (dialog: "error" | "success" | "formError" | "save") => { - setDialogs(prev => ({ ...prev, [dialog]: false })); - }; + // Dialog handlers + const openDialog = useCallback((type: DialogType, newError?: ErrorDetails) => { + setDialog(type); + if (newError) setError(newError); + }, []); - // Function to open the save dialog - const openSaveDialog = useCallback(() => { - setDialogs(prev => ({ ...prev, save: true })); + const closeDialog = useCallback(() => { + setDialog(null); }, []); - // Handle form submission - const onSubmit = async (data: ProviderCreate | ProviderUpdate) => { + // Form submission + const handleSubmit = async (e?: React.BaseSyntheticEvent) => { + // If an event was provided, prevent default + if (e) { + e.preventDefault(); + } + try { setIsSubmitting(true); - setFormError(null); - setSaveSuccess(false); + setError(null); - if (mode === "edit" && selectedProvider?.id) { - // Update existing provider + // Use formHandleSubmit to get data from the form + const formData = await new Promise((resolve) => { + formHandleSubmit((data) => { + resolve(data); + })(e); + }); + + if (mode === "edit" && provider?.id) { await updateProvider.mutateAsync({ - id: selectedProvider.id, - data: data as ProviderUpdate + id: denormalizeProviderId(provider.id), + data: formData as ProviderUpdate }); } else if (mode === "create") { - // Create new provider - await createProvider.mutateAsync(data as ProviderCreate); - } else { - + await createProvider.mutateAsync(formData as ProviderCreate); } - setSaveSuccess(true); - openSaveDialog(); - - // Switch back to view mode after successful save + openDialog("success"); setMode("view"); + setContextMode("view"); - // Reset success message after 3 seconds - setTimeout(() => { - setSaveSuccess(false); - }, 3000); - } catch (error) { - console.error("Provider form submission error:", error); + } catch (err) { + console.error("Provider form submission error:", err); + const errorDetails: ErrorDetails = err instanceof Error + ? { message: err.message, code: err.name, details: { error: err.toString() } } + : { message: String(err), details: { error } }; - // Convert error to ErrorDetails format - let errorObj: ErrorDetails; - if (error instanceof Error) { - errorObj = { - message: error.message, - code: error.name, - details: { - error: error.toString() - } - }; - - // Try to extract cause if it exists - const errorWithCause = error as ErrorWithCause; - if ('cause' in error && errorWithCause.cause !== undefined) { - errorObj.details = { - ...errorObj.details, - cause: errorWithCause.cause - }; - } - } else if (typeof error === 'object' && error !== null) { - errorObj = error as ErrorDetails; - } else { - errorObj = { - message: String(error), - details: { error } - }; - } + setError(errorDetails); + openDialog("formError"); - setFormError(errorObj); - setDialogs(prev => ({ ...prev, formError: true })); + // Re-throw the error so the caller knows something went wrong + throw err; } finally { setIsSubmitting(false); } }; - // Handle cancel - use context's cancel handler - const onCancel = useCallback(() => { - onContextCancel(); - }, [onContextCancel]); - - // Handle test connection - const handleTestConnection = async () => { - if (!selectedProvider) return; - - // Validate API key is present - if (!selectedProvider.apiKey) { - setConnectionError({ + // Connection test + const testProviderConnection = async () => { + if (!provider) return; + + if (!provider.apiKey) { + setError({ message: "API key is required", code: "ValidationError", - details: { - title: "Connection failed", - timestamp: new Date().toISOString(), - message: "API key is required", - name: "ValidationError", - details: "Please provide an API key in the provider configuration.", - suggestions: [ - "Edit the provider to add an API key", - "API keys should be non-empty strings", - ], - } + details: { message: "API key is required" } }); - setDialogs(prev => ({ ...prev, error: true })); + openDialog("error"); return; } - setIsTestingConnection(true); - setConnectionError(null); - try { - const config = toServerConfig(selectedProvider); + setIsSubmitting(true); + setError(null); + + const config = toServerConfig(provider); const result = await testConnection.mutateAsync({ - providerName: selectedProvider.name, + providerName: provider.name, config, }); setConnectionDetails(result); - setDialogs(prev => ({ ...prev, success: true })); - } catch (error) { - console.error("Connection test failed:", error); - - let errorObj: ErrorDetails = { - message: "Connection failed", - code: "ConnectionError", - details: { - title: "Connection failed", - timestamp: new Date().toISOString(), - } - }; - - if (error instanceof Error) { - errorObj.message = error.message; - errorObj.code = error.name; - - if (error.message?.includes("[object Object]")) { - errorObj.message = "Invalid provider configuration"; - errorObj.details = { - ...errorObj.details, - details: "The server rejected the request due to invalid parameters.", - suggestions: [ - "Check API key and endpoint URL", - "Verify the provider is correctly configured", - "Check server logs for more details", - ] - }; - } - - const errorWithCause = error as ErrorWithCause; - if ('cause' in error && errorWithCause.cause !== undefined) { - errorObj.details = { - ...errorObj.details, - cause: errorWithCause.cause - }; - } - } else if (typeof error === "object" && error !== null) { - errorObj = { - ...errorObj, - ...(error as ErrorDetails), - }; - } else { - errorObj.message = String(error); - } - - setConnectionError(errorObj); - setDialogs(prev => ({ ...prev, error: true })); + openDialog("success"); + } catch (err) { + console.error("Connection test failed:", err); + + const errorDetails: ErrorDetails = err instanceof Error + ? { message: err.message, code: err.name, details: { error: err.toString() } } + : { message: String(err), details: { error } }; + + setError(errorDetails); + openDialog("error"); } finally { - setIsTestingConnection(false); + setIsSubmitting(false); } }; - // Form submission handler - const handleSubmit = (handler: (data: ProviderCreate | ProviderUpdate) => Promise) => { - return hookHandleSubmit(async (data) => { - try { - await handler(data); - } catch (error) { - console.error("Form submission error:", error); - } - }); - }; + // Update mode in both local and context state + const handleSetMode = useCallback((newMode: "view" | "edit" | "create") => { + setMode(newMode); + setContextMode(newMode); + }, [setContextMode]); + + // Handle model selection with proper type handling + const handleSetSelectedModelId = useCallback((id: string | null) => { + if (id !== null) { + setSelectedModelId(id); + } + }, [setSelectedModelId]); return ( 0 ? providers : contextProviders, - // Form related properties + // Form state control, errors, watch, - // Model selection properties - providerModelsData, + // UI state + isSubmitting, + dialog, + error, + connectionDetails, + + // Selected model state + selectedModelId, + providerModels: providerModelsData?.models || null, isLoadingModels, - isModelsError, - modelsError, - selectedModelId: selectedModelId || null, - setSelectedModelId: (id: string | null) => { - if (id !== null) { - setSelectedModelId(id); - } - }, - handleModelSelect, - onSubmit, - onCancel, - handleSubmit, - handleTestConnection, - setMode, + // Actions + setProvider: handleProviderSelect, + setMode: handleSetMode, + setSelectedModelId: handleSetSelectedModelId, + openDialog, closeDialog, - openSaveDialog, + + // Form actions + handleSubmit, + cancelEdit: onContextCancel, + testConnection: testProviderConnection, }} > {children} diff --git a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx index 5b4e3eaf..d1ff7614 100644 --- a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx @@ -1,64 +1,51 @@ -import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 +import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types"; import { type ReactNode, createContext, useContext } from "react"; import type { Control, FieldErrors, UseFormWatch } from "react-hook-form"; +// Simplified dialog state type +type DialogType = null | "error" | "success" | "formError" | "save"; + interface ProviderFormContextType { + // Core state + provider: Provider | null; mode: "view" | "edit" | "create"; - isSubmitting: boolean; - isCreating: boolean; - saveSuccess: boolean; - isTestingConnection: boolean; - selectedProvider: Provider | null; - setSelectedProvider: (provider: Provider | null) => void; - formError: ErrorDetails | null; - connectionError: ErrorDetails | null; - connectionDetails: ConnectionDetails | null; - dialogs: { - error: boolean; - success: boolean; - formError: boolean; - save: boolean; - }; - saveError?: string; - savedProvider: Provider | null; - providers: Provider[]; - // Form handling properties + // Form state control: Control; errors: FieldErrors; watch: UseFormWatch; - // Model selection properties - providerModelsData: { models: Array<{ id: string; name: string; is_default?: boolean }> } | null; - isLoadingModels: boolean; - isModelsError: boolean; - modelsError: Error | null; + // UI state + isSubmitting: boolean; + dialog: DialogType; + error: ErrorDetails | null; + connectionDetails: ConnectionDetails | null; + + // Selected model state selectedModelId: string | null; - setSelectedModelId: (id: string | null) => void; - handleModelSelect: () => void; + providerModels: Array<{ id: string; name: string; is_default?: boolean }> | null; + isLoadingModels: boolean; - onSubmit: (data: ProviderCreate | ProviderUpdate) => Promise; - onCancel: () => void; - handleSubmit: ( - handler: (data: ProviderCreate | ProviderUpdate) => Promise, - ) => (e: React.FormEvent) => void; - handleTestConnection: () => Promise; + // Actions + setProvider: (provider: Provider | null) => void; setMode: (mode: "view" | "edit" | "create") => void; - closeDialog: (dialog: "error" | "success" | "formError" | "save") => void; - openSaveDialog: () => void; + setSelectedModelId: (id: string | null) => void; + openDialog: (type: DialogType, error?: ErrorDetails) => void; + closeDialog: () => void; + + // Form actions + handleSubmit: (e?: React.BaseSyntheticEvent) => Promise; + cancelEdit: () => void; + testConnection: () => Promise; } -const ProviderFormContext = createContext( - undefined, -); +const ProviderFormContext = createContext(undefined); export function useProviderFormContext() { const context = useContext(ProviderFormContext); if (context === undefined) { - throw new Error( - "useProviderFormContext must be used within a ProviderFormProvider", - ); + throw new Error("useProviderFormContext must be used within a ProviderFormProvider"); } return context; } @@ -68,10 +55,7 @@ interface ProviderFormProviderProps { value: ProviderFormContextType; } -export function ProviderFormProvider({ - children, - value, -}: ProviderFormProviderProps) { +export function ProviderFormProvider({ children, value }: ProviderFormProviderProps) { return ( {children} diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 0b0bcc9f..95acfc2a 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -8,7 +8,10 @@ import { useServerConnectionsContext } from "@/context/ServerConnectionsContext"; import { SERVER_IDS } from "@/features/server-connections/constants"; -import { createDataServiceClient, createInferenceBridgeClient } from "@/features/server-connections/services/apiClients"; +import { + createDataServiceClient, + createInferenceBridgeClient, +} from "@/features/server-connections/services/apiClients"; import type { Provider, ProviderCreate, @@ -24,11 +27,10 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; export const queryKeys = { providers: ["providers"] as const, provider: (id: number) => ["providers", id] as const, - providerModels: (providerName: string) => ["providers", "models", providerName] as const, + providerModels: (providerName: string) => + ["providers", "models", providerName] as const, }; - - /** * Extended Error interface with cause property */ @@ -110,22 +112,25 @@ export function useCreateProvider() { try { const errorData = await response.json(); console.error("Provider creation error:", errorData); - + // Check if we have a structured error response - if (errorData.status === 'error' || errorData.validationErrors) { + if (errorData.status === "error" || errorData.validationErrors) { throw errorData; } - + // Simple error with a message if (errorData.message) { throw new Error(errorData.message); } - + // Fallback error throw new Error(`Failed to create provider: ${response.status}`); } catch (parseError) { // If we can't parse the error as JSON, throw a general error - if (parseError instanceof Error && parseError.message !== 'Failed to create provider') { + if ( + parseError instanceof Error && + parseError.message !== "Failed to create provider" + ) { throw parseError; } throw new Error(`Failed to create provider: ${response.status}`); @@ -151,10 +156,19 @@ export function useUpdateProvider() { return useMutation({ mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { console.log("Updating provider with data:", data); + + // For update operations, we only need to send simple model objects + // Backend will handle ID generation and association + const apiData = { ...data }; + + // Keep models simple, backend will handle IDs + // No ID conversion needed, only sending name and isEnabled + // Backend will handle the rest + const client = createDataServiceClient(connections); const response = await client.providers[":id"].$put({ param: { id: id.toString() }, - json: data, + json: apiData, }); if (!response.ok) { @@ -166,8 +180,11 @@ export function useUpdateProvider() { return response.json() as Promise; }, onSuccess: (data) => { + // Convert string ID to number for query invalidation + const numericId = typeof data.id === 'string' ? Number.parseInt(data.id, 10) : data.id; + // Invalidate specific provider query - queryClient.invalidateQueries({ queryKey: queryKeys.provider(data.id) }); + queryClient.invalidateQueries({ queryKey: queryKeys.provider(numericId) }); // Invalidate providers list queryClient.invalidateQueries({ queryKey: queryKeys.providers }); }, @@ -208,51 +225,58 @@ export function useDeleteProvider() { */ export function useTestProviderConnection() { const { connections } = useServerConnectionsContext(); - + return useMutation({ - mutationFn: async ({ providerName, config }: { providerName: string; config: ServerProviderConfig }) => { + mutationFn: async ({ + providerName, + config, + }: { providerName: string; config: ServerProviderConfig }) => { const client = createInferenceBridgeClient(connections); - + // Add console logging to debug - console.log('Testing connection with config:', JSON.stringify(config)); - + console.log("Testing connection with config:", JSON.stringify(config)); + // Make sure api_key is properly set and not null or undefined if (!config.api_key) { throw new Error("API key is required for testing provider connection"); } - - const response = await client.providers[":provider_name"]["test-connection"].$post({ + + const response = await client.providers[":provider_name"][ + "test-connection" + ].$post({ param: { provider_name: providerName }, json: config, }); if (!response.ok) { const errorData = await response.json(); - console.error('Error response:', errorData); - + console.error("Error response:", errorData); + // Check if this is our enhanced error format - if (errorData.status === 'error' && errorData.details) { + if (errorData.status === "error" && errorData.details) { // Use the structured error data with cause property - const error = new Error(errorData.message || 'Connection test failed') as ErrorWithCause; + const error = new Error( + errorData.message || "Connection test failed", + ) as ErrorWithCause; error.cause = errorData; throw error; } - + // Handle different error formats if (errorData.detail) { throw new Error(errorData.detail); } - + if (errorData.message) { throw new Error(errorData.message); } - - if (typeof errorData === 'object') { + + if (typeof errorData === "object") { // For raw objects, don't wrap in Error, just throw the object directly // This prevents "[object Object]" in the error message throw { ...errorData }; } - + // Fallback to simple error throw new Error(`Connection test failed: ${response.status}`); } @@ -277,11 +301,13 @@ export function useProviderModels(provider: Provider) { queryFn: async () => { const client = createInferenceBridgeClient(connections); const serverConfig = toServerConfig(provider); - - const response = await client.providers[":provider_name"]["models"].$post({ - param: { provider_name: provider.name }, - json: serverConfig, - }); + + const response = await client.providers[":provider_name"]["models"].$post( + { + param: { provider_name: provider.name }, + json: serverConfig, + }, + ); if (!response.ok) { throw new Error( diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index ad161a79..24ad1ddc 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -415,18 +415,36 @@ export const updateProvider = async (c: Context) => { // Log model changes if applicable if (models && models.length > 0) { - // Need to query specifically for models since they might not be included in existingProvider - const existingModelsQuery = await db.query.providerModels.findMany({ - where: eq(providerModels.providerId, Number.parseInt(id)) - }); - - logger.info({ - providerId: id, - provider: existingProvider.name, - existingModelsCount: existingModelsQuery.length, - newModelsCount: models.length, - modelNames: models.map(m => m.name) - }, "Updating provider models"); + // First, delete existing models + await db + .delete(providerModels) + .where(eq(providerModels.providerId, Number.parseInt(id))); + + // Then insert new models - handle both full model objects and simple name+isEnabled objects + await db.insert(providerModels).values( + models.map((model) => { + // If model already has numeric ID, use it + // Otherwise, generate a new one (auto-increment by database) + const modelData = { + providerId: Number.parseInt(id), + name: model.name, + isEnabled: model.isEnabled ?? true, + createdAt: new Date(), + updatedAt: new Date(), + }; + + // Only include ID if it exists and is a number + if (model.id !== undefined && typeof model.id === 'number') { + return { + ...modelData, + id: model.id + }; + } + + // Let database auto-generate ID + return modelData; + }), + ); } // Log rate limit changes if applicable @@ -495,15 +513,30 @@ export const updateProvider = async (c: Context) => { .delete(providerModels) .where(eq(providerModels.providerId, Number.parseInt(id))); - // Then insert new models + // Then insert new models - handle both full model objects and simple name+isEnabled objects await tx.insert(providerModels).values( - models.map((model) => ({ - providerId: Number.parseInt(id), - name: model.name, - isEnabled: model.isEnabled, - createdAt: new Date(), - updatedAt: new Date(), - })), + models.map((model) => { + // If model already has numeric ID, use it + // Otherwise, generate a new one (auto-increment by database) + const modelData = { + providerId: Number.parseInt(id), + name: model.name, + isEnabled: model.isEnabled ?? true, + createdAt: new Date(), + updatedAt: new Date(), + }; + + // Only include ID if it exists and is a number + if (model.id !== undefined && typeof model.id === 'number') { + return { + ...modelData, + id: model.id + }; + } + + // Let database auto-generate ID + return modelData; + }), ); } From 18f98e5a87b5293b4a865d363acc5d9a47739279 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 13:30:12 -0500 Subject: [PATCH 35/69] Remove fetchModels Signed-off-by: jphillips --- .../src/features/inference/constants.ts | 3 +- .../components/actions/ProviderSaveDialog.tsx | 12 +- .../components/form/ModelSelectionSection.tsx | 11 +- .../containers/ProviderFormContainer.tsx | 1 - .../features/inference/services/providers.ts | 40 +------ .../services/providerAdapters.ts | 3 - .../server-connections/services/providers.ts | 106 ++++++------------ .../src/types/provider-config-types.ts | 5 - .../features/provider_config/controller.ts | 6 +- .../src/features/provider_config/schemas.ts | 2 +- 10 files changed, 52 insertions(+), 137 deletions(-) diff --git a/graphcap_studio/src/features/inference/constants.ts b/graphcap_studio/src/features/inference/constants.ts index 35737246..67dc9469 100644 --- a/graphcap_studio/src/features/inference/constants.ts +++ b/graphcap_studio/src/features/inference/constants.ts @@ -9,9 +9,8 @@ export const DEFAULT_PROVIDER_FORM_DATA = { environment: "cloud" as const, baseUrl: "", apiKey: "", - isEnabled: true, + isEnabled: false, defaultModel: "", - fetchModels: true, models: [], rateLimits: { requestsPerMinute: 0, diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx index d3b17de3..b9cfd4d6 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/ProviderSaveDialog.tsx @@ -44,7 +44,7 @@ export function SaveButton() { // Get provider service functions const { isPending: isCreatingProvider } = useCreateProvider(); - const { mutateAsync: updateProviderAsync, isPending: isUpdatingProvider } = + const { isPending: isUpdatingProvider } = useUpdateProvider(); // Determine if form is submitting @@ -158,12 +158,10 @@ export function SaveButton() { Base URL: {displayProvider.baseUrl} - {displayProvider.fetchModels && ( - - Default Model:{" "} - {displayProvider.defaultModel ?? "Not set"} - - )} + + Default Model:{" "} + {displayProvider.defaultModel ?? "Not set"} + diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx index bc9c2df8..7bc877b5 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx @@ -31,7 +31,6 @@ export function ModelSelectionSection() { const providerName = provider?.name; const isEditMode = mode === "edit" || mode === "create"; const customModels = watch("models") || []; - const fetchModels = provider?.fetchModels || watch("fetchModels"); // Prepare an array with all models to display const allModels = []; @@ -50,8 +49,8 @@ export function ModelSelectionSection() { } } - // Add API-fetched models if fetchModels is true - if (fetchModels && providerModels && providerModels.length > 0) { + // Add API-fetched models + if (providerModels && providerModels.length > 0) { // Map API models to the format expected by the model selector for (const model of providerModels) { // Only add if not already included in custom models @@ -85,7 +84,7 @@ export function ModelSelectionSection() { - {fetchModels && isLoadingModels && ( + {isLoadingModels && ( )} @@ -108,7 +107,7 @@ export function ModelSelectionSection() { } // View mode - if (fetchModels && isLoadingModels) { + if (isLoadingModels) { return ; } @@ -117,7 +116,7 @@ export function ModelSelectionSection() { ); } diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index b4fa86c1..0b92e9d5 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -73,7 +73,6 @@ export function ProviderFormContainer({ apiKey: newProvider.apiKey || "", isEnabled: newProvider.isEnabled, defaultModel: newProvider.defaultModel, - fetchModels: newProvider.fetchModels, models: newProvider.models, rateLimits: newProvider.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 } }); diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 95acfc2a..6cbf2c9e 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -15,12 +15,10 @@ import { import type { Provider, ProviderCreate, - ProviderModelsResponse, ProviderUpdate, ServerProviderConfig, - SuccessResponse, + SuccessResponse } from "@/types/provider-config-types"; -import { toServerConfig } from "@/types/provider-config-types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; // Query keys for TanStack Query @@ -285,39 +283,3 @@ export function useTestProviderConnection() { }, }); } - -/** - * Hook to fetch provider models - */ -export function useProviderModels(provider: Provider) { - const { connections } = useServerConnectionsContext(); - const inferenceBridgeConnection = connections.find( - (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, - ); - const isConnected = inferenceBridgeConnection?.status === "connected"; - - return useQuery({ - queryKey: queryKeys.providerModels(provider.name), - queryFn: async () => { - const client = createInferenceBridgeClient(connections); - const serverConfig = toServerConfig(provider); - - const response = await client.providers[":provider_name"]["models"].$post( - { - param: { provider_name: provider.name }, - json: serverConfig, - }, - ); - - if (!response.ok) { - throw new Error( - `Failed to fetch models for ${provider.name}: ${response.status}`, - ); - } - - return response.json() as Promise; - }, - enabled: isConnected && provider.fetchModels, - staleTime: 1000 * 60 * 5, // 5 minutes - }); -} diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts index 4ad90b7b..ff11973e 100644 --- a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts +++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts @@ -19,7 +19,6 @@ interface ApiProvider { apiKey?: string; isEnabled: boolean; defaultModel?: string; - fetchModels: boolean; createdAt: string | Date; updatedAt: string | Date; models?: ApiProviderModel[]; @@ -67,7 +66,6 @@ export function fromApiProvider(apiProvider: ApiProvider): Provider { apiKey: apiProvider.apiKey, isEnabled: apiProvider.isEnabled, defaultModel: apiProvider.defaultModel, - fetchModels: apiProvider.fetchModels, createdAt: apiProvider.createdAt, updatedAt: apiProvider.updatedAt, @@ -109,7 +107,6 @@ export function toApiProvider(provider: Provider): ApiProvider { apiKey: provider.apiKey, isEnabled: provider.isEnabled, defaultModel: provider.defaultModel, - fetchModels: provider.fetchModels, createdAt: provider.createdAt, updatedAt: provider.updatedAt, diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index 2fe3eec2..f3f4b4fd 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -12,9 +12,9 @@ import type { ProviderCreate, ProviderModelsResponse, ProviderUpdate, + ServerProviderConfig, SuccessResponse, } from "@/types/provider-config-types"; -import { toServerConfig } from "@/types/provider-config-types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { SERVER_IDS } from "../constants"; import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients"; @@ -233,11 +233,7 @@ export function useDeleteProvider() { * Hook to get provider models */ export function useProviderModels(providerName: string | Provider) { - const { connections } = useServerConnectionsContext(); - const graphcapServerConnection = connections.find( - (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, - ); - const isConnected = graphcapServerConnection?.status === "connected"; + useServerConnectionsContext(); // Extract the provider name and data if an object was passed const isProviderObject = typeof providerName === 'object' && providerName !== null; @@ -247,65 +243,39 @@ export function useProviderModels(providerName: string | Provider) { return useQuery({ queryKey: queryKeys.providerModels(name), queryFn: async () => { - console.log(`📡 Fetching models for provider: ${name}`); + console.log(`📡 Processing models for provider: ${name}`); - try { - const client = createInferenceBridgeClient(connections); - - // If we have the full provider object, use it to create a server config - // Otherwise, use a minimal configuration - const config = provider - ? toServerConfig(provider) - : { - name, - kind: "unknown", - environment: "cloud" as const, - base_url: "", - api_key: "", - models: [], - fetch_models: true - }; - - console.log("📤 API request data:", config); + // Debug provider object to better understand structure + if (provider) { + console.log('Provider object passed:', JSON.stringify(provider, null, 2)); + } + + // If we have a provider object with models, use those directly + if (provider && Array.isArray(provider.models) && provider.models.length > 0) { + const modelCount = provider.models.length; + console.log(`📝 Using ${modelCount} models from provider object`, provider.models); - // Use the POST endpoint with provider_name param and config body - const response = await client.providers[":provider_name"].models.$post({ - param: { provider_name: name }, - json: config, - }); - - if (!response.ok) { - throw new Error(`Failed to fetch provider models from API: ${response.status}`); - } - - const models = await response.json() as ProviderModelsResponse; - console.log(`✅ Fetched ${models.models.length} models from API for provider ${name}:`, models); - return models; - } catch (error) { - // If we have a provider object with models, use them as fallback - if (provider?.models && provider.models.length > 0) { - const modelCount = provider.models.length; - console.log(`⚠️ API request failed. Using ${modelCount} saved models from provider`); - - // Convert the provider models to the expected ProviderModelsResponse format - const fallbackModels: ProviderModelsResponse = { - provider: name, - models: provider.models.map(model => ({ - id: model.id ? (typeof model.id === 'string' ? model.id : String(model.id)) : String(model.name), - name: model.name, - is_default: model.name === provider.defaultModel - })) - }; - - return fallbackModels; - } + // Convert the provider models to the expected ProviderModelsResponse format + const configuredModels: ProviderModelsResponse = { + provider: name, + models: provider.models.map(model => ({ + id: model.id ? (typeof model.id === 'string' ? model.id : String(model.id)) : String(model.name), + name: model.name, + is_default: model.name === provider.defaultModel + })) + }; - // If no fallback is available, re-throw the error - console.error("❌ Failed to fetch models and no fallback available:", error); - throw error; + return configuredModels; } + + // If no provider object or no models, return empty array + console.log(`📝 No models available for provider: ${name}`, provider ? 'Has provider object but no models array or empty array' : 'No provider object'); + return { + provider: name, + models: [] + }; }, - enabled: isConnected && !!name, + enabled: !!name, }); } @@ -316,18 +286,16 @@ export function useTestProviderConnection() { const { connections } = useServerConnectionsContext(); return useMutation({ - mutationFn: async (provider: Provider) => { - console.log(`📡 Testing connection for provider: ${provider.name}`, provider); + mutationFn: async ({ providerName, config }: { providerName: string, config: ServerProviderConfig }) => { + console.log(`📡 Testing connection for provider: ${providerName}`, config); const client = createInferenceBridgeClient(connections); - // Convert to server config format - const serverConfig = toServerConfig(provider); - console.log("📤 API request data:", serverConfig); + console.log("📤 API request data:", config); const response = await client.providers[":provider_name"].models.$post({ - param: { provider_name: provider.name }, - json: serverConfig, + param: { provider_name: providerName }, + json: config, }); if (!response.ok) { @@ -341,8 +309,8 @@ export function useTestProviderConnection() { console.log("✅ Provider connection test successful:", result); return result; }, - onError: (error: Error, provider) => { - console.error(`❌ Error in useTestProviderConnection for provider ${provider.name}:`, error); + onError: (error: Error, variables) => { + console.error(`❌ Error in useTestProviderConnection for provider ${variables.providerName}:`, error); }, }); } \ No newline at end of file diff --git a/graphcap_studio/src/types/provider-config-types.ts b/graphcap_studio/src/types/provider-config-types.ts index b2774c68..08f33090 100644 --- a/graphcap_studio/src/types/provider-config-types.ts +++ b/graphcap_studio/src/types/provider-config-types.ts @@ -53,7 +53,6 @@ export const ProviderSchema = BaseProviderSchema.extend({ baseUrl: z.string().url("Must be a valid URL"), apiKey: z.string().optional(), defaultModel: z.string().optional(), - fetchModels: z.boolean().default(true), createdAt: z.string().or(z.date()), updatedAt: z.string().or(z.date()), models: z.array(ProviderModelSchema).optional(), @@ -69,7 +68,6 @@ export const ProviderCreateSchema = z.object({ apiKey: z.string().optional(), isEnabled: z.boolean().default(true), defaultModel: z.string().optional(), - fetchModels: z.boolean().default(true), models: z .array( z.object({ @@ -95,7 +93,6 @@ export const ProviderUpdateSchema = z.object({ apiKey: z.string().optional(), isEnabled: z.boolean().optional(), defaultModel: z.string().optional(), - fetchModels: z.boolean().optional(), models: z .array( z.object({ @@ -155,7 +152,6 @@ export const ServerProviderConfigSchema = z.object({ api_key: z.string(), default_model: z.string().optional(), models: z.array(z.string()), - fetch_models: z.boolean(), rate_limits: z .object({ requests_per_minute: z.number().optional(), @@ -261,7 +257,6 @@ export function toServerConfig(provider: Provider): ServerProviderConfig { api_key: provider.apiKey || "", default_model: provider.defaultModel, models: provider.models?.map((m) => m.name) || [], - fetch_models: provider.fetchModels, rate_limits: provider.rateLimits ? { requests_per_minute: provider.rateLimits.requestsPerMinute, diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index 24ad1ddc..02a3b0a1 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -423,8 +423,7 @@ export const updateProvider = async (c: Context) => { // Then insert new models - handle both full model objects and simple name+isEnabled objects await db.insert(providerModels).values( models.map((model) => { - // If model already has numeric ID, use it - // Otherwise, generate a new one (auto-increment by database) + // Create base model data object const modelData = { providerId: Number.parseInt(id), name: model.name, @@ -516,8 +515,7 @@ export const updateProvider = async (c: Context) => { // Then insert new models - handle both full model objects and simple name+isEnabled objects await tx.insert(providerModels).values( models.map((model) => { - // If model already has numeric ID, use it - // Otherwise, generate a new one (auto-increment by database) + // Create base model data object const modelData = { providerId: Number.parseInt(id), name: model.name, diff --git a/servers/data_service/src/features/provider_config/schemas.ts b/servers/data_service/src/features/provider_config/schemas.ts index 00c82e42..c832715c 100644 --- a/servers/data_service/src/features/provider_config/schemas.ts +++ b/servers/data_service/src/features/provider_config/schemas.ts @@ -68,7 +68,7 @@ export const providerUpdateSchema = z.object({ isEnabled: z.boolean().optional(), models: z.array( z.object({ - id: z.number().optional(), + id: z.number().or(z.string()).optional(), name: z.string().min(1, 'Model name is required'), isEnabled: z.boolean().default(true), }) From 0bfb84ec24a888216907da2302528c5ea671c21e Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 14:35:56 -0500 Subject: [PATCH 36/69] Provider model setup Signed-off-by: jphillips --- .../src/features/inference/constants.ts | 4 - .../context/GenerationOptionsContext.tsx | 11 +- .../persist-generation-options.ts | 11 ++ .../inference/hooks/useModelSelection.ts | 54 ++++---- .../hooks/useProviderModelOptions.ts | 29 ++-- .../components/ProviderFormTabs.tsx | 6 +- .../components/form/RateLimitsSection.tsx | 125 ------------------ .../components/form/index.ts | 5 +- .../containers/ProviderFormContainer.tsx | 3 +- .../PerspectiveActions/PerspectivesFooter.tsx | 67 ++++++---- .../features/perspectives/constants/index.ts | 3 +- .../context/PerspectivesDataContext.tsx | 2 +- .../server-connections/services/index.ts | 25 +--- .../services/providerAdapters.ts | 35 ----- .../server-connections/services/providers.ts | 51 +------ .../src/types/provider-config-types.ts | 42 ------ .../providers/clients/gemini_client.py | 3 - 17 files changed, 113 insertions(+), 363 deletions(-) delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx diff --git a/graphcap_studio/src/features/inference/constants.ts b/graphcap_studio/src/features/inference/constants.ts index 67dc9469..cce1fc1c 100644 --- a/graphcap_studio/src/features/inference/constants.ts +++ b/graphcap_studio/src/features/inference/constants.ts @@ -12,10 +12,6 @@ export const DEFAULT_PROVIDER_FORM_DATA = { isEnabled: false, defaultModel: "", models: [], - rateLimits: { - requestsPerMinute: 0, - tokensPerMinute: 0, - }, }; /** diff --git a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx index b85233a9..5fae359f 100644 --- a/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx +++ b/graphcap_studio/src/features/inference/generation-options/context/GenerationOptionsContext.tsx @@ -138,7 +138,7 @@ export function GenerationOptionsProvider({ if (options.provider_id && !options.model_id && models.length > 0) { // Try to use default model first, otherwise use first available model const modelToUse = defaultModel || models[0]; - updateOption("model_id", modelToUse.id); + updateOption("model_id", modelToUse.name); } }, [options.provider_id, options.model_id, models, defaultModel]); @@ -191,8 +191,13 @@ export function GenerationOptionsProvider({ // Model selection const selectModel = useCallback((modelId: string) => { - updateOption("model_id", modelId); - }, [updateOption]); + // Find the model by ID to get its name + const model = models.find((m: ProviderModelInfo) => m.id === modelId); + if (!model) { + throw new Error(`Model with ID ${modelId} not found`); + } + updateOption("model_id", model.name); + }, [updateOption, models]); // Dialog controls const openDialog = useCallback(() => setIsDialogOpen(true), []); diff --git a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts index 7421fa62..c4825009 100644 --- a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts +++ b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts @@ -28,6 +28,11 @@ export function saveGenerationOptions(options: GenerationOptions): void { optionsToSave.provider_id = String(optionsToSave.provider_id); } + // Verify we have a model name, not an ID + if (optionsToSave.model_id && /^\d+$/.test(optionsToSave.model_id)) { + throw new Error('model_id must be a model name, not a numeric ID'); + } + const serialized = JSON.stringify(optionsToSave); localStorage.setItem(STORAGE_KEY, serialized); } catch (error) { @@ -52,6 +57,12 @@ export function loadGenerationOptions(): GenerationOptions | null { parsed.provider_id = parsed.provider_id.toString(); } + // Check if model_id appears to be a numeric ID and not a name + if (parsed.model_id && /^\d+$/.test(parsed.model_id)) { + console.error('Invalid model_id format: Must be a model name, not a numeric ID'); + throw new Error('model_id must be a model name, not a numeric ID'); + } + // Validate the loaded data against the schema return GenerationOptionsSchema.parse(parsed); } catch (error) { diff --git a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts index c71c86a7..037daf4c 100644 --- a/graphcap_studio/src/features/inference/hooks/useModelSelection.ts +++ b/graphcap_studio/src/features/inference/hooks/useModelSelection.ts @@ -1,12 +1,11 @@ -import { useProviderModels } from "@/features/server-connections/services/providers"; -import type { Provider } from "@/types/provider-config-types"; +import type { Provider, ProviderModelInfo } from "@/types/provider-config-types"; // SPDX-License-Identifier: Apache-2.0 -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useMemo, useState } from "react"; /** * Custom hook for managing model selection * - * @param provider - Provider to fetch models for, can be null or undefined + * @param provider - Provider to use models from, can be null or undefined * @param onModelSelect - Callback function when a model is selected * @returns Model selection state and handlers */ @@ -17,33 +16,37 @@ export function useModelSelection( // State for model selection const [selectedModelId, setSelectedModelId] = useState(""); + // Process provider models + const models = useMemo(() => { + if (!provider?.models?.length) return []; + + // Map provider models to ProviderModelInfo format + return provider.models.map(model => ({ + id: model.id, + name: model.name, + is_default: model.name === provider.defaultModel + })); + }, [provider]); + // Update selected model ID when provider changes useEffect(() => { if (provider?.defaultModel) { - setSelectedModelId(provider.defaultModel); + // Try to find the model with the default name + const defaultModel = provider.models?.find(m => m.name === provider.defaultModel); + if (defaultModel) { + setSelectedModelId(defaultModel.id); + return; + } + } + + // If no default or default not found, use first model or reset + if (provider?.models?.length) { + setSelectedModelId(provider.models[0].id); } else { setSelectedModelId(""); } }, [provider]); - // Get models for the current provider - const { - data: providerModelsData, - isLoading: isLoadingModels, - isError: isModelsError, - error: modelsError, - } = useProviderModels(provider); - - // Update selected model when models are loaded - useEffect(() => { - if (!selectedModelId && providerModelsData?.models && providerModelsData.models.length > 0) { - const defaultModel = providerModelsData.models.find( - (model) => model.is_default - ); - setSelectedModelId(defaultModel?.id ?? providerModelsData.models[0].id); - } - }, [providerModelsData, selectedModelId]); - // Handle model selection const handleModelSelect = useCallback(() => { if (onModelSelect && provider?.name && selectedModelId) { @@ -54,10 +57,7 @@ export function useModelSelection( return { selectedModelId, setSelectedModelId, - providerModelsData, - isLoadingModels, - isModelsError, - modelsError, + models, handleModelSelect, }; } diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts index 3f16b3db..d58bcfe1 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts @@ -6,7 +6,7 @@ * It consolidates provider and model data loading in a single hook. */ -import { useProviderModels, useProviders } from "@/features/server-connections/services/providers"; +import { useProviders } from "@/features/server-connections/services/providers"; import type { Provider, ProviderModelInfo } from "@/types/provider-config-types"; import { useMemo } from "react"; @@ -30,18 +30,17 @@ export function useProviderModelOptions(providerId?: string) { return providers.find((p: Provider) => p.id === providerId) || null; }, [providers, providerId]); - // Fetch models for the selected provider - const { - data: modelData, - isLoading: isLoadingModels, - error: modelsError - } = useProviderModels(selectedProvider?.name || ""); - - // Process models data + // Process models data directly from the provider const models = useMemo(() => { - if (!modelData?.models) return []; - return modelData.models; - }, [modelData]); + if (!selectedProvider?.models?.length) return []; + + // Map provider models to ProviderModelInfo format + return selectedProvider.models.map((model: { id: string; name: string }) => ({ + id: model.id, + name: model.name, + is_default: model.name === selectedProvider.defaultModel + })); + }, [selectedProvider]); // Check for default model const defaultModel = useMemo(() => { @@ -58,11 +57,9 @@ export function useProviderModelOptions(providerId?: string) { // Models data models, defaultModel, - isLoadingModels, - modelsError, // Helper for status checking - isLoading: isLoadingProviders || isLoadingModels, - hasError: !!providersError || !!modelsError + isLoading: isLoadingProviders, + hasError: !!providersError }; } \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx index d2f85146..c6d154e5 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderFormTabs.tsx @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 import { Tabs } from "@chakra-ui/react"; import styles from "./FormFields.module.css"; -import { BasicInfoSection, ConnectionSection, RateLimitsSection } from "./form"; +import { BasicInfoSection, ConnectionSection } from "./form"; import { ModelSelectionSection } from "./form/ModelSelectionSection"; /** @@ -29,7 +29,6 @@ export function ProviderFormTabs() { > Basic Info Connection - Rate Limits Model @@ -44,9 +43,6 @@ export function ProviderFormTabs() { - - - ); diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx deleted file mode 100644 index 82db53ca..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/RateLimitsSection.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import { useColorModeValue } from "@/components/ui/theme/color-mode"; -import { - Box, - Field, - Grid, - GridItem, - Input, - Text, - VStack, -} from "@chakra-ui/react"; -// SPDX-License-Identifier: Apache-2.0 -import type { ChangeEvent } from "react"; -import { Controller } from "react-hook-form"; -import { useProviderFormContext } from "../../../context/ProviderFormContext"; - -/** - * Component for displaying and editing provider rate limits - */ -export function RateLimitsSection() { - const { control, errors, watch, mode } = useProviderFormContext(); - const isEditing = mode === "edit" || mode === "create"; - const labelColor = useColorModeValue("gray.600", "gray.300"); - const textColor = useColorModeValue("gray.700", "gray.200"); - - // Watch form values for read-only display - const formValues = watch(); - const rateLimits = formValues.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 }; - - if (!isEditing) { - return ( - - - - Rate Limits - - - - - Requests per minute - - - {rateLimits.requestsPerMinute ?? 0} - - - - - Tokens per minute - - {rateLimits.tokensPerMinute ?? 0} - - - - - ); - } - - return ( - - - - Rate Limits - - - {/* Use a single Controller for the entire rateLimits object - This ensures we always have an object structure */} - ( - - - - - Requests per minute - - ) => { - const value = Number.parseInt(e.target.value) || 0; - field.onChange({ - ...field.value, - requestsPerMinute: value - }); - }} - min={0} - /> - - {errors.rateLimits?.requestsPerMinute?.message} - - - - - - - - Tokens per minute - - ) => { - const value = Number.parseInt(e.target.value) || 0; - field.onChange({ - ...field.value, - tokensPerMinute: value - }); - }} - min={0} - /> - - {errors.rateLimits?.tokensPerMinute?.message} - - - - - )} - /> - - - ); -} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts index 1c698340..0eec7288 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/index.ts @@ -1,9 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 +/** + * Form component exports + */ + export * from "./BasicInfoSection"; export * from "./ConnectionSection"; export * from "./EnvironmentSelect"; export * from "./ModelSelectionSection"; export * from "./ModelSelector"; export * from "./ProviderFormSelect"; -export * from "./RateLimitsSection"; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index 0b92e9d5..0ea32893 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -73,8 +73,7 @@ export function ProviderFormContainer({ apiKey: newProvider.apiKey || "", isEnabled: newProvider.isEnabled, defaultModel: newProvider.defaultModel, - models: newProvider.models, - rateLimits: newProvider.rateLimits || { requestsPerMinute: 0, tokensPerMinute: 0 } + models: newProvider.models }); } }, [reset]); diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx index 14ce9e7c..2ef6d50f 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx @@ -18,7 +18,7 @@ import { Text, chakra, } from "@chakra-ui/react"; -import { useCallback, useEffect } from "react"; +import { useCallback, useEffect, useMemo } from "react"; import { LuRefreshCw } from "react-icons/lu"; /** @@ -58,7 +58,6 @@ export function PerspectivesFooter() { isGenerating, currentImage, generationOptions, - selectedProvider, } = usePerspectivesData(); // Use UI context @@ -81,19 +80,35 @@ export function PerspectivesFooter() { console.log("GenerationOptions:", generationOptions); console.log("Available providers:", availableProviders); - // Try to find provider by id first, if that fails look for a name match - const providerObj = availableProviders.find(p => - // Try matching by ID - p.id.toString() === generationOptions.provider_id || - // Or by name - (selectedProvider && p.name === selectedProvider) - ); - - const providerName = providerObj?.name || - // If we have a provider_id but couldn't find a match, show that - (generationOptions.provider_id ? `ID: ${generationOptions.provider_id}` : "None"); + // Get provider information safely (without throwing during render) + const providerInfo = useMemo(() => { + // Don't attempt to find providers if the list is empty or loading + if (!availableProviders.length) { + return { providerName: "Loading...", modelName: "Loading..." }; + } + + // Try to find provider by ID + const providerObj = availableProviders.find(p => + p.id.toString() === generationOptions.provider_id + ); + + // If provider not found, return a placeholder but don't throw + if (!providerObj) { + console.warn(`Provider with ID ${generationOptions.provider_id} not found yet`); + return { + providerName: `ID: ${generationOptions.provider_id}`, + modelName: generationOptions.model_id || "None" + }; + } + + // Provider found, return proper info + return { + providerName: providerObj.name, + modelName: generationOptions.model_id || "None" + }; + }, [availableProviders, generationOptions.provider_id, generationOptions.model_id]); - const modelId = generationOptions.model_id || "None"; + const { providerName, modelName } = providerInfo; // Fetch providers on component mount useEffect(() => { @@ -113,7 +128,7 @@ export function PerspectivesFooter() { return false; } - if (!selectedProvider) { + if (!generationOptions.provider_id) { showMessage( "No provider selected", "Please select an inference provider", @@ -132,14 +147,13 @@ export function PerspectivesFooter() { } return true; - }, [activeSchemaName, selectedProvider, currentImage, showMessage]); + }, [activeSchemaName, generationOptions.provider_id, currentImage, showMessage]); // Handle generate button click const handleGenerate = useCallback(async () => { console.log("Generate button clicked"); console.log("Active schema:", activeSchemaName); - console.log("Selected provider:", selectedProvider); - console.log("Using generation options:", generationOptions); + console.log("Generation options:", generationOptions); if (!validateGeneration()) { return; @@ -147,11 +161,11 @@ export function PerspectivesFooter() { try { console.log("Calling generatePerspective..."); - // Find the provider object from the available providers - const providerObject = availableProviders.find(p => p.name === selectedProvider); + // Find the provider object from the available providers using the provider_id + const providerObject = availableProviders.find(p => p.id.toString() === generationOptions.provider_id); if (!providerObject) { - throw new Error(`Provider "${selectedProvider}" not found in available providers`); + throw new Error(`Provider with ID "${generationOptions.provider_id}" not found in available providers`); } await generatePerspective( @@ -176,7 +190,6 @@ export function PerspectivesFooter() { } }, [ activeSchemaName, - selectedProvider, availableProviders, generatePerspective, generationOptions, @@ -188,13 +201,13 @@ export function PerspectivesFooter() { // Combine loading states const isProcessing = isLoading || isGenerating; - // Check if button should be disabled + // Check if button should be disabled - use generationOptions.provider_id instead of selectedProvider const isGenerateDisabled = - isProcessing || !activeSchemaName || !selectedProvider; + isProcessing || !activeSchemaName || !generationOptions.provider_id; - // Get title for the generate button + // Get title for the generate button - also use generationOptions.provider_id const buttonTitle = getButtonTitle( - selectedProvider, + generationOptions.provider_id ? providerName : undefined, activeSchemaName, isProcessing, isGenerated, @@ -220,7 +233,7 @@ export function PerspectivesFooter() { color={infoTextColor} title="Current provider and model from global settings" > - Using: {providerName} / {modelId} + Using: {providerName} / {modelName} {/* Generate/Regenerate Button */} diff --git a/graphcap_studio/src/features/perspectives/constants/index.ts b/graphcap_studio/src/features/perspectives/constants/index.ts index ef436ffe..ab668128 100644 --- a/graphcap_studio/src/features/perspectives/constants/index.ts +++ b/graphcap_studio/src/features/perspectives/constants/index.ts @@ -72,6 +72,5 @@ export const CACHE_TIMES = { // Default values export const DEFAULTS = { SERVER_URL: "http://localhost:32100", - PROVIDER: "gemini", - DEFAULT_FILENAME: "image.jpg", + }; diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index ac279287..d40e5767 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -424,7 +424,7 @@ export function PerspectivesDataProvider({ provider: effectiveProvider.name, content: result.result || {}, options: { - model: effectiveOptions.model_id, // Map to expected model property + model: effectiveOptions.model_id, // model_id now contains the name max_tokens: effectiveOptions.max_tokens, temperature: effectiveOptions.temperature, top_p: effectiveOptions.top_p, diff --git a/graphcap_studio/src/features/server-connections/services/index.ts b/graphcap_studio/src/features/server-connections/services/index.ts index b4f8b761..2321ada0 100644 --- a/graphcap_studio/src/features/server-connections/services/index.ts +++ b/graphcap_studio/src/features/server-connections/services/index.ts @@ -1,35 +1,20 @@ // Server health checks export { - checkServerHealth, - checkMediaServerHealth, - checkInferenceBridgeHealth, - checkServerHealthById, + checkInferenceBridgeHealth, checkMediaServerHealth, checkServerHealth, checkServerHealthById } from "./serverConnections"; // API clients export type { DataServiceClient, - InferenceBridgeClient, - ProviderClient, - PerspectivesClient, + InferenceBridgeClient, PerspectivesClient, ProviderClient } from "./apiClients"; export { - getDataServiceUrl, - createDataServiceClient, - getInferenceBridgeUrl, - createInferenceBridgeClient, - createProviderClient, - createPerspectivesClient, + createDataServiceClient, createInferenceBridgeClient, createPerspectivesClient, createProviderClient, getDataServiceUrl, getInferenceBridgeUrl } from "./apiClients"; // Provider services export { - queryKeys as providerQueryKeys, - useProviders, - useProvider, - useCreateProvider, - useUpdateProvider, - useDeleteProvider, - useProviderModels, + queryKeys as providerQueryKeys, useCreateProvider, useDeleteProvider, useProvider, useProviders, useUpdateProvider } from "./providers"; + diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts index ff11973e..2f0af535 100644 --- a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts +++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts @@ -22,7 +22,6 @@ interface ApiProvider { createdAt: string | Date; updatedAt: string | Date; models?: ApiProviderModel[]; - rateLimits?: ApiRateLimits; } // Type for raw API provider model data @@ -35,16 +34,6 @@ interface ApiProviderModel { updatedAt: string | Date; } -// Type for raw API rate limits data -interface ApiRateLimits { - id: number; - providerId: number; - requestsPerMinute?: number; - tokensPerMinute?: number; - createdAt: string | Date; - updatedAt: string | Date; -} - // Type for raw API model info interface ApiModelInfo { id: string; @@ -78,18 +67,6 @@ export function fromApiProvider(apiProvider: ApiProvider): Provider { createdAt: model.createdAt, updatedAt: model.updatedAt, })), - - // Convert nested rate limits if defined - rateLimits: apiProvider.rateLimits?.id && apiProvider.rateLimits?.providerId - ? { - id: normalizeProviderId(apiProvider.rateLimits.id), - providerId: normalizeProviderId(apiProvider.rateLimits.providerId), - requestsPerMinute: apiProvider.rateLimits.requestsPerMinute, - tokensPerMinute: apiProvider.rateLimits.tokensPerMinute, - createdAt: apiProvider.rateLimits.createdAt, - updatedAt: apiProvider.rateLimits.updatedAt, - } - : undefined, }; } @@ -119,18 +96,6 @@ export function toApiProvider(provider: Provider): ApiProvider { createdAt: model.createdAt, updatedAt: model.updatedAt, })), - - // Convert rate limits back to numeric IDs - rateLimits: provider.rateLimits - ? { - id: Number.parseInt(provider.rateLimits.id, 10), - providerId: Number.parseInt(provider.rateLimits.providerId, 10), - requestsPerMinute: provider.rateLimits.requestsPerMinute, - tokensPerMinute: provider.rateLimits.tokensPerMinute, - createdAt: provider.rateLimits.createdAt, - updatedAt: provider.rateLimits.updatedAt, - } - : undefined, }; } diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index f3f4b4fd..25882225 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -10,10 +10,9 @@ import { useServerConnectionsContext } from "@/context/ServerConnectionsContext" import type { Provider, ProviderCreate, - ProviderModelsResponse, ProviderUpdate, ServerProviderConfig, - SuccessResponse, + SuccessResponse } from "@/types/provider-config-types"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { SERVER_IDS } from "../constants"; @@ -229,55 +228,7 @@ export function useDeleteProvider() { }); } -/** - * Hook to get provider models - */ -export function useProviderModels(providerName: string | Provider) { - useServerConnectionsContext(); - - // Extract the provider name and data if an object was passed - const isProviderObject = typeof providerName === 'object' && providerName !== null; - const name = isProviderObject ? providerName.name : providerName; - const provider = isProviderObject ? providerName : null; - return useQuery({ - queryKey: queryKeys.providerModels(name), - queryFn: async () => { - console.log(`📡 Processing models for provider: ${name}`); - - // Debug provider object to better understand structure - if (provider) { - console.log('Provider object passed:', JSON.stringify(provider, null, 2)); - } - - // If we have a provider object with models, use those directly - if (provider && Array.isArray(provider.models) && provider.models.length > 0) { - const modelCount = provider.models.length; - console.log(`📝 Using ${modelCount} models from provider object`, provider.models); - - // Convert the provider models to the expected ProviderModelsResponse format - const configuredModels: ProviderModelsResponse = { - provider: name, - models: provider.models.map(model => ({ - id: model.id ? (typeof model.id === 'string' ? model.id : String(model.id)) : String(model.name), - name: model.name, - is_default: model.name === provider.defaultModel - })) - }; - - return configuredModels; - } - - // If no provider object or no models, return empty array - console.log(`📝 No models available for provider: ${name}`, provider ? 'Has provider object but no models array or empty array' : 'No provider object'); - return { - provider: name, - models: [] - }; - }, - enabled: !!name, - }); -} /** * Hook to test a provider connection diff --git a/graphcap_studio/src/types/provider-config-types.ts b/graphcap_studio/src/types/provider-config-types.ts index 08f33090..f73d4718 100644 --- a/graphcap_studio/src/types/provider-config-types.ts +++ b/graphcap_studio/src/types/provider-config-types.ts @@ -32,18 +32,6 @@ export const ProviderModelSchema = z.object({ updatedAt: z.string().or(z.date()), }); -/** - * Rate limits schema - */ -export const RateLimitsSchema = z.object({ - id: z.string(), - providerId: z.string(), - requestsPerMinute: z.number().optional(), - tokensPerMinute: z.number().optional(), - createdAt: z.string().or(z.date()), - updatedAt: z.string().or(z.date()), -}); - /** * Complete provider schema */ @@ -56,7 +44,6 @@ export const ProviderSchema = BaseProviderSchema.extend({ createdAt: z.string().or(z.date()), updatedAt: z.string().or(z.date()), models: z.array(ProviderModelSchema).optional(), - rateLimits: RateLimitsSchema.optional(), }); // Provider creation schema @@ -76,12 +63,6 @@ export const ProviderCreateSchema = z.object({ }), ) .optional(), - rateLimits: z - .object({ - requestsPerMinute: z.number().optional(), - tokensPerMinute: z.number().optional(), - }) - .optional(), }); // Provider update schema @@ -102,12 +83,6 @@ export const ProviderUpdateSchema = z.object({ }), ) .optional(), - rateLimits: z - .object({ - requestsPerMinute: z.number().optional(), - tokensPerMinute: z.number().optional(), - }) - .optional(), }); // Provider model info schema @@ -152,12 +127,6 @@ export const ServerProviderConfigSchema = z.object({ api_key: z.string(), default_model: z.string().optional(), models: z.array(z.string()), - rate_limits: z - .object({ - requests_per_minute: z.number().optional(), - tokens_per_minute: z.number().optional(), - }) - .optional(), }); // ============================================================================ @@ -174,11 +143,6 @@ export type BaseProvider = z.infer; */ export type ProviderModel = z.infer; -/** - * Rate limits configuration - */ -export type RateLimits = z.infer; - /** * Provider configuration stored in data service */ @@ -257,11 +221,5 @@ export function toServerConfig(provider: Provider): ServerProviderConfig { api_key: provider.apiKey || "", default_model: provider.defaultModel, models: provider.models?.map((m) => m.name) || [], - rate_limits: provider.rateLimits - ? { - requests_per_minute: provider.rateLimits.requestsPerMinute, - tokens_per_minute: provider.rateLimits.tokensPerMinute, - } - : undefined, }; } diff --git a/servers/inference_bridge/graphcap/providers/clients/gemini_client.py b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py index e7ee0b43..16aa6dc3 100644 --- a/servers/inference_bridge/graphcap/providers/clients/gemini_client.py +++ b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py @@ -15,7 +15,6 @@ GeminiClient: Gemini API client implementation """ -import time from typing import Any from loguru import logger @@ -40,8 +39,6 @@ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_ke def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]: """Format vision content for Gemini API""" # TODO: Add feature flag to handle gemini free tier rate limits instead of this hack - logger.info("Sleeping for 3 seconds to avoid rate limits") - time.sleep(3) return [ {"type": "text", "text": text}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}}, From cbd92b31a2aa3374d45017d77b0177d706f85f57 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 14:46:37 -0500 Subject: [PATCH 37/69] Fix cache key for perspectives panel Signed-off-by: jphillips --- .../PerspectiveFilterPanel.tsx | 103 ------------------ .../PerspectiveModuleFilter.tsx | 4 +- .../components/PerspectiveManagement/index.ts | 12 +- .../features/perspectives/components/index.ts | 8 +- .../hooks/usePerspectiveModules.ts | 1 + 5 files changed, 13 insertions(+), 115 deletions(-) delete mode 100644 graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx deleted file mode 100644 index 177d6f2c..00000000 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx +++ /dev/null @@ -1,103 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * Perspective Filter Panel - * - * This component provides a UI for toggling the visibility of different perspectives. - */ - -import { Checkbox } from "@/components/ui/checkbox"; -import { - Box, - Button, - Flex, - HStack, - Heading, - Text, - VStack, -} from "@chakra-ui/react"; -import { useMemo } from "react"; -import { usePerspectivesData } from "../../context/PerspectivesDataContext"; - -/** - * Component for filtering which perspectives are visible in the UI - */ -export function PerspectiveFilterPanel() { - const { - perspectives, - hiddenPerspectives, - togglePerspectiveVisibility, - isPerspectiveVisible, - setAllPerspectivesVisible, - } = usePerspectivesData(); - - // Count how many perspectives are visible/hidden - const counts = useMemo(() => { - const totalCount = perspectives.length; - const hiddenCount = hiddenPerspectives.length; - const visibleCount = totalCount - hiddenCount; - - return { totalCount, hiddenCount, visibleCount }; - }, [perspectives, hiddenPerspectives]); - - return ( - - - - Perspective Visibility - - {counts.visibleCount} of {counts.totalCount} visible - - - - - - - {perspectives.map((perspective) => ( - togglePerspectiveVisibility(perspective.name)} - colorScheme="blue" - size="sm" - > - - {perspective.display_name || perspective.name} - - - ))} - - - - - - - - - - - ); -} diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx index 808298cc..762df226 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx @@ -170,7 +170,7 @@ export function PerspectiveModuleFilter({ alignItems="center" justifyContent="center" color={buttonColor} - width="20px" + width="60px" height="20px" borderWidth="1px" borderColor="currentColor" @@ -178,7 +178,7 @@ export function PerspectiveModuleFilter({ ml={1} _hover={{ bg: hoverBgColor }} > - → + View → diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts index 2efe2459..e6565208 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts @@ -1,10 +1,10 @@ -export { ModuleList } from './PerspectiveModules/ModuleList'; -export { ModuleInfo } from './PerspectiveModules/ModuleInfo'; -export { NotFound } from './NotFound'; export { ErrorDisplay } from './ErrorDisplay'; export { LoadingDisplay } from './LoadingDisplay'; +export { NotFound } from './NotFound'; export { PerspectiveEditor } from './PerspectiveEditor/PerspectiveEditor'; -export { SchemaValidationError } from './SchemaValidationError'; -export { PerspectiveModuleFilter } from './PerspectiveModuleFilter'; export { PerspectiveManagementPanel } from './PerspectiveManagementPanel'; -export { PerspectiveFilterPanel } from './PerspectiveFilterPanel'; +export { PerspectiveModuleFilter } from './PerspectiveModuleFilter'; +export { ModuleInfo } from './PerspectiveModules/ModuleInfo'; +export { ModuleList } from './PerspectiveModules/ModuleList'; +export { SchemaValidationError } from './SchemaValidationError'; + diff --git a/graphcap_studio/src/features/perspectives/components/index.ts b/graphcap_studio/src/features/perspectives/components/index.ts index 07d89562..da4e9e35 100644 --- a/graphcap_studio/src/features/perspectives/components/index.ts +++ b/graphcap_studio/src/features/perspectives/components/index.ts @@ -6,11 +6,11 @@ */ export * from "./PerspectiveCaption/EmptyPerspectives"; -export * from "./PerspectivesErrorState"; -export * from "./PerspectiveManagement/PerspectiveFilterPanel"; export * from "./PerspectiveCaption/ErrorMessage"; +export * from "./PerspectiveCaption/PerspectiveActions"; +export { MetadataDisplay } from "./PerspectiveCaption/PerspectiveCard/MetadataDisplay"; export { PerspectiveHeader } from "./PerspectiveCaption/PerspectiveNavigation/PerspectiveHeader"; export { PerspectivesPager } from "./PerspectiveCaption/PerspectiveNavigation/PerspectivesPager"; -export { MetadataDisplay } from "./PerspectiveCaption/PerspectiveCard/MetadataDisplay"; -export * from "./PerspectiveCaption/PerspectiveActions"; export * from "./PerspectiveManagement/PerspectiveManagementPanel"; +export * from "./PerspectivesErrorState"; + diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts index 3ea188b5..01355eeb 100644 --- a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts +++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts @@ -12,6 +12,7 @@ import { createInferenceBridgeClient } from "@/features/server-connections/servi import type { ModuleInfo, ModuleListResponse, Perspective, PerspectiveModule } from "@/types"; import { useQuery } from "@tanstack/react-query"; import { useEffect, useMemo } from "react"; +import { CACHE_TIMES, perspectivesQueryKeys } from "../services/constants"; import { handleApiError } from "../services/utils"; import { PerspectiveError } from "./usePerspectives"; From bfd9b872385b97a9b729dd9ce09cbf4ed5294705 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 15:34:19 -0500 Subject: [PATCH 38/69] Split up large fn, move id to model name Signed-off-by: jphillips --- .cursor/rules/task.mdc | 11 +- .../components/form/ModelSelectionSection.tsx | 5 +- graphcap_studio/src/utils/error-handler.ts | 68 +++--- .../features/providers/error_handler.py | 195 +++++++++++------- 4 files changed, 173 insertions(+), 106 deletions(-) diff --git a/.cursor/rules/task.mdc b/.cursor/rules/task.mdc index 6b3bb799..e009d5c8 100644 --- a/.cursor/rules/task.mdc +++ b/.cursor/rules/task.mdc @@ -4,13 +4,4 @@ globs: alwaysApply: true --- # Task -Provider configuration still relies on some file based configuration. - -### Describe the solution you'd like -- Utilize the UI / DB for provider configuration -- Remove all old provider handling code -- Resolve issues cited in discord related to provider ux - - -[provider_manager.py](mdc:servers/inference_bridge/graphcap/providers/provider_manager.py) -[provider_config.py](mdc:servers/inference_bridge/graphcap/providers/provider_config.py) +Place your task for the agent here. \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx index 7bc877b5..a61502e7 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelectionSection.tsx @@ -39,9 +39,10 @@ export function ModelSelectionSection() { if (customModels && customModels.length > 0) { // Map custom models to the format expected by the model selector for (const model of customModels) { + // Generate a stable ID for custom models allModels.push({ // Generate a stable ID for custom models - id: typeof model.id === 'string' ? model.id : `custom-${model.name}`, + id: `${model.name}`, name: model.name, is_default: provider?.defaultModel === model.name, isCustom: true @@ -78,7 +79,7 @@ export function ModelSelectionSection() { // When in edit mode, show model management section if (isEditMode) { return ( - + Model Configuration diff --git a/graphcap_studio/src/utils/error-handler.ts b/graphcap_studio/src/utils/error-handler.ts index 3ce3bd6f..9f475a78 100644 --- a/graphcap_studio/src/utils/error-handler.ts +++ b/graphcap_studio/src/utils/error-handler.ts @@ -18,6 +18,47 @@ interface ServerErrorResponse { validationErrors?: Record; } +/** + * Extracts a message from a validation error object + */ +function extractValidationErrorMessage(validationErrors: Record): string | null { + const validationMessages: string[] = []; + + for (const [field, errors] of Object.entries(validationErrors)) { + for (const errorMsg of errors) { + validationMessages.push(`${field}: ${errorMsg}`); + } + } + + if (validationMessages.length > 0) { + return `Validation errors:\n${validationMessages.join('\n')}`; + } + + return null; +} + +/** + * Extracts message from a server error object + */ +function extractServerErrorMessage(serverError: ServerErrorResponse): string | null { + // If there's a message, use it + if (serverError.message) { + return serverError.message; + } + + // If there are validation errors, format them + if (serverError.validationErrors) { + return extractValidationErrorMessage(serverError.validationErrors); + } + + // If there's an error property with a message (common in Axios errors) + if ('error' in serverError && typeof serverError.error === 'string') { + return serverError.error; + } + + return null; +} + /** * Formats a server error response into a human-readable message */ @@ -30,30 +71,9 @@ export function formatServerError(error: unknown): string { // Try to handle server error response if (error && typeof error === 'object') { const serverError = error as ServerErrorResponse; - - // If there's a message, use it - if (serverError.message) { - return serverError.message; - } - - // If there are validation errors, format them - if (serverError.validationErrors) { - const validationMessages: string[] = []; - - for (const [field, errors] of Object.entries(serverError.validationErrors)) { - for (const errorMsg of errors) { - validationMessages.push(`${field}: ${errorMsg}`); - } - } - - if (validationMessages.length > 0) { - return `Validation errors:\n${validationMessages.join('\n')}`; - } - } - - // If there's an error property with a message (common in Axios errors) - if ('error' in serverError && typeof serverError.error === 'string') { - return serverError.error; + const message = extractServerErrorMessage(serverError); + if (message) { + return message; } } diff --git a/servers/inference_bridge/server/server/features/providers/error_handler.py b/servers/inference_bridge/server/server/features/providers/error_handler.py index ae943988..0b466658 100644 --- a/servers/inference_bridge/server/server/features/providers/error_handler.py +++ b/servers/inference_bridge/server/server/features/providers/error_handler.py @@ -14,31 +14,26 @@ from .models import ProviderConfig -def format_provider_validation_error(e: ValidationError, provider_name: str) -> JSONResponse: - """ - Format a provider validation error into a standardized response. - - Args: - e: The validation error - provider_name: Name of the provider - - Returns: - A JSONResponse with detailed error information - """ - errors = e.errors() - invalid_params = {} - - # Extract field names for the error message +def _extract_invalid_fields(errors) -> Set[str]: + """Extract the set of invalid field names from validation errors.""" invalid_fields: Set[str] = set() for error in errors: - # Get field location loc = error.get("loc", []) if len(loc) > 1: field_name = loc[1] if isinstance(loc[1], str) else str(loc[1]) invalid_fields.add(field_name) - # Format specific error details + return invalid_fields + + +def _build_invalid_params(errors) -> Dict[str, Dict]: + """Build dictionary mapping fields to their error details.""" + invalid_params = {} + + for error in errors: + # Get field location + loc = error.get("loc", []) field = ".".join(str(loc) for loc in error.get("loc", [])) if error.get("loc") else "" message = error.get("msg", "Validation error") error_type = error.get("type", "unknown_error") @@ -57,18 +52,24 @@ def format_provider_validation_error(e: ValidationError, provider_name: str) -> if context: invalid_params[field]["context"] = context - - # Generate appropriate overall message + + return invalid_params + + +def _generate_error_message(invalid_fields: Set[str]) -> str: + """Generate an appropriate error message based on invalid fields.""" if len(invalid_fields) == 1: field = next(iter(invalid_fields)) - message = f"Invalid provider configuration: '{field}' parameter is invalid" + return f"Invalid provider configuration: '{field}' parameter is invalid" elif len(invalid_fields) > 1: field_list = "', '".join(sorted(invalid_fields)) - message = f"Invalid provider configuration: Parameters '{field_list}' are invalid" + return f"Invalid provider configuration: Parameters '{field_list}' are invalid" else: - message = "Invalid provider configuration" - - # Build provider-specific suggestions + return "Invalid provider configuration" + + +def _generate_suggestions(errors) -> list: + """Generate helpful suggestions based on validation errors.""" suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] for error in errors: @@ -98,6 +99,28 @@ def format_provider_validation_error(e: ValidationError, provider_name: str) -> suggestions.append("Valid environment values are typically 'cloud' or 'local'") suggestions.append("Check server logs for more details") + return list(dict.fromkeys(suggestions)) # Remove duplicates while preserving order + + +def format_provider_validation_error(e: ValidationError) -> JSONResponse: + """ + Format a provider validation error into a standardized response. + + Args: + e: The validation error + + Returns: + A JSONResponse with detailed error information + """ + errors = e.errors() + + # Extract field names and build error details + invalid_fields = _extract_invalid_fields(errors) + invalid_params = _build_invalid_params(errors) + + # Generate appropriate message and suggestions + message = _generate_error_message(invalid_fields) + suggestions = _generate_suggestions(errors) # Build the response error_response = { @@ -107,7 +130,7 @@ def format_provider_validation_error(e: ValidationError, provider_name: str) -> "name": "Error", "details": "The server rejected the request due to invalid provider parameters.", "invalid_parameters": invalid_params, - "suggestions": list(dict.fromkeys(suggestions)) + "suggestions": suggestions } return JSONResponse( @@ -116,6 +139,58 @@ def format_provider_validation_error(e: ValidationError, provider_name: str) -> ) +def _create_safe_config(config: ProviderConfig) -> Dict[str, Any]: + """Create a copy of the config without sensitive information.""" + return { + "kind": config.kind, + "environment": config.environment, + "base_url": config.base_url, + "models": config.models, + "fetch_models": config.fetch_models, + } + + +def _determine_error_code(error_message: str) -> str: + """Determine the error code based on the error message.""" + error_message = error_message.lower() + + if "authentication failed" in error_message or "unauthorized" in error_message: + return "AUTH_ERROR" + elif "not found" in error_message or "404" in error_message: + return "ENDPOINT_NOT_FOUND" + elif "timeout" in error_message: + return "TIMEOUT" + elif "connection" in error_message: + return "CONNECTION_ERROR" + elif "rate limit" in error_message or "too many requests" in error_message: + return "RATE_LIMIT" + elif "quota" in error_message or "exceeded" in error_message: + return "QUOTA_EXCEEDED" + else: + return "UNKNOWN_ERROR" + + +def _generate_connection_suggestions(error_code: str) -> list: + """Generate suggestions based on the error code.""" + suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] + + if error_code == "AUTH_ERROR": + suggestions.append("Check if the API key is valid and has the necessary permissions") + elif error_code == "ENDPOINT_NOT_FOUND": + suggestions.append("Verify the base URL is correct for this provider") + elif error_code == "TIMEOUT": + suggestions.append("The server took too long to respond. Check network connectivity or try again later") + elif error_code == "CONNECTION_ERROR": + suggestions.append("Failed to establish connection to the provider. Check network connectivity") + elif error_code == "RATE_LIMIT": + suggestions.append("You've exceeded the provider's rate limits. Try again later") + elif error_code == "QUOTA_EXCEEDED": + suggestions.append("You've exceeded your provider quota. Check your usage dashboard") + + suggestions.append("Check server logs for more details") + return suggestions + + def format_provider_connection_error(e: Exception, provider_name: str, config: ProviderConfig) -> JSONResponse: """ Format a provider connection error into a standardized response. @@ -128,58 +203,38 @@ def format_provider_connection_error(e: Exception, provider_name: str, config: P Returns: A JSONResponse with detailed error information """ - # Create a detailed error response - error_response: Dict[str, Any] = { + # Get the error message as string + error_message = str(e) + + # Determine error code + error_code = _determine_error_code(error_message) + + # Create safe configuration without sensitive data + safe_config = _create_safe_config(config) + + # Generate provider details + provider_details = { + "provider": provider_name, + "error_type": type(e).__name__, + "error_code": error_code, + "config": safe_config + } + + # Generate suggestions + suggestions = _generate_connection_suggestions(error_code) + + # Create the error response + error_response = { "title": "Connection failed", "timestamp": datetime.datetime.now().isoformat(), "status": "error", - "message": str(e), + "message": error_message, "name": "Error", "details": "Failed to connect to the provider service.", - "provider_details": { - "provider": provider_name, - "error_type": type(e).__name__, - } - } - - # Add any configuration info that might be helpful for debugging - # but exclude sensitive data like API keys - safe_config = { - "kind": config.kind, - "environment": config.environment, - "base_url": config.base_url, - "models": config.models, - "fetch_models": config.fetch_models, + "provider_details": provider_details, + "suggestions": suggestions } - error_response["provider_details"]["config"] = safe_config - # Create specific suggestions for common issues - suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] - - if "authentication failed" in str(e).lower() or "unauthorized" in str(e).lower(): - error_response["provider_details"]["error_code"] = "AUTH_ERROR" - suggestions.append("Check if the API key is valid and has the necessary permissions") - elif "not found" in str(e).lower() or "404" in str(e).lower(): - error_response["provider_details"]["error_code"] = "ENDPOINT_NOT_FOUND" - suggestions.append("Verify the base URL is correct for this provider") - elif "timeout" in str(e).lower(): - error_response["provider_details"]["error_code"] = "TIMEOUT" - suggestions.append("The server took too long to respond. Check network connectivity or try again later") - elif "connection" in str(e).lower(): - error_response["provider_details"]["error_code"] = "CONNECTION_ERROR" - suggestions.append("Failed to establish connection to the provider. Check network connectivity") - elif "rate limit" in str(e).lower() or "too many requests" in str(e).lower(): - error_response["provider_details"]["error_code"] = "RATE_LIMIT" - suggestions.append("You've exceeded the provider's rate limits. Try again later") - elif "quota" in str(e).lower() or "exceeded" in str(e).lower(): - error_response["provider_details"]["error_code"] = "QUOTA_EXCEEDED" - suggestions.append("You've exceeded your provider quota. Check your usage dashboard") - else: - error_response["provider_details"]["error_code"] = "UNKNOWN_ERROR" - - suggestions.append("Check server logs for more details") - error_response["suggestions"] = suggestions - # Return a structured error response with HTTP 400 status return JSONResponse( status_code=400, From 57333acf0aa0013dae1e5a1a45d8034d9f305309 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 16:18:48 -0500 Subject: [PATCH 39/69] Extract api key logic from controller Signed-off-by: jphillips --- .../provider_config/api-key-manager.ts | 71 +++++++++++ .../features/provider_config/controller.ts | 111 +----------------- .../src/features/provider_config/schemas.ts | 1 - 3 files changed, 76 insertions(+), 107 deletions(-) create mode 100644 servers/data_service/src/features/provider_config/api-key-manager.ts diff --git a/servers/data_service/src/features/provider_config/api-key-manager.ts b/servers/data_service/src/features/provider_config/api-key-manager.ts new file mode 100644 index 00000000..af233074 --- /dev/null +++ b/servers/data_service/src/features/provider_config/api-key-manager.ts @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * API Key Manager + * + * This module handles API key encryption, decryption, and management operations. + */ + +import { decryptApiKey, encryptApiKey } from "../../utils/encryption"; +import { logger } from "../../utils/logger"; +import type { Provider } from "./schemas"; + +// Simple type for objects with an optional API key property +export type WithApiKey = { + apiKey: string | null | undefined; +}; + +/** + * Processes an API key for update operations + * Determines if we should use existing key, encrypt a new key, or clear the key + */ +export const processApiKeyForUpdate = async ( + currentProvider: WithApiKey, + newApiKeyValue: string | undefined | null +): Promise => { + // CASE 1: Keep existing key if no new value provided + if (newApiKeyValue === undefined) { + logger.debug("Keeping existing API key - no change requested"); + return currentProvider.apiKey ?? null; + } + + // CASE 2: Explicitly clear the key + if (newApiKeyValue === null || newApiKeyValue === "") { + logger.debug("API key explicitly cleared in update"); + return null; + } + + // CASE 3: Encrypt new key value + logger.debug("Encrypting new API key for provider update"); + const encryptedKey = await encryptApiKey(newApiKeyValue); + logger.debug("API key encrypted for update"); + return encryptedKey; +}; + +/** + * Safely decrypts a provider's API key for client response + */ +export const decryptProviderApiKey = async ( + provider: Provider +): Promise => { + const providerCopy = { ...provider }; + + if (providerCopy.apiKey) { + logger.debug({ + providerId: provider.id, + encryptedKeyLength: providerCopy.apiKey.length + }, "Decrypting API key for provider"); + + providerCopy.apiKey = await decryptApiKey(providerCopy.apiKey); + + // Log the result of decryption (without showing the actual key) + logger.debug({ + providerId: provider.id, + apiKeyPresent: Boolean(providerCopy.apiKey), + apiKeyLength: providerCopy.apiKey ? providerCopy.apiKey.length : 0 + }, "Provider API key decryption result"); + } else { + logger.debug({ providerId: provider.id }, "No API key to decrypt for provider"); + } + + return providerCopy; +}; \ No newline at end of file diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index 02a3b0a1..67c0aa2e 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -10,9 +10,8 @@ import type { Context } from "hono"; import { db } from "../../db"; import { providerModels, providerRateLimits, providers } from "../../db/schema"; import { decryptApiKey, encryptApiKey } from "../../utils/encryption"; -import { logger } from "../../utils/logger"; +import { processApiKeyForUpdate } from "./api-key-manager"; import type { - ProviderApiKey, ProviderCreate, ProviderUpdate, } from "./schemas"; @@ -46,7 +45,7 @@ export const getProviders = async (c: Context) => { // Log whether API key is present after decryption (without showing the actual key) logger.debug({ providerId: provider.id, - apiKeyPresent: provider.apiKey ? true : false, + apiKeyPresent: Boolean(provider.apiKey), apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 }, "Provider API key decryption result"); } else { @@ -119,7 +118,7 @@ export const getProvider = async (c: Context) => { // Log the result of decryption (without showing the actual key) logger.debug({ providerId: id, - apiKeyPresent: provider.apiKey ? true : false, + apiKeyPresent: Boolean(provider.apiKey), apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 }, "Provider API key decryption result"); } else { @@ -477,23 +476,8 @@ export const updateProvider = async (c: Context) => { throw new Error(`Provider not found with id ${id}`); } - // CRITICAL FIX: Handle API key specially to avoid losing it - let apiKeyToUse = currentProvider.apiKey; // Default to keeping existing key - - // Only update the API key if it's explicitly included in the update data - if ('apiKey' in providerData && providerData.apiKey !== undefined) { - if (providerData.apiKey) { - logger.debug({ providerId: id }, "Encrypting new API key for provider update"); - apiKeyToUse = await encryptApiKey(providerData.apiKey as string); - logger.info({ providerId: id }, "API key encrypted for update"); - } else { - // If apiKey is explicitly set to an empty value, only then clear it - logger.debug({ providerId: id }, "API key explicitly cleared in update"); - apiKeyToUse = null; - } - } else { - logger.debug({ providerId: id }, "Keeping existing API key - no change requested"); - } + // Use the API key manager to handle API key updates + const apiKeyToUse = await processApiKeyForUpdate(currentProvider, providerData.apiKey); // Update provider with the appropriate API key await tx @@ -636,88 +620,3 @@ export const deleteProvider = async (c: Context) => { return c.json({ error: "Failed to delete provider" }, 500); } }; - -/** - * Update a provider's API key - */ -export const updateProviderApiKey = async (c: Context) => { - try { - // @ts-ignore - Hono OpenAPI validation types are not properly recognized - const { id } = c.req.valid("param") as ValidatedParams; - // @ts-ignore - Hono OpenAPI validation types are not properly recognized - const { apiKey } = c.req.valid("json") as ProviderApiKey; - logger.debug({ id }, "Updating provider API key"); - - // Check if provider exists - const existingProvider = await db.query.providers.findFirst({ - where: eq(providers.id, Number.parseInt(id)), - }); - - if (!existingProvider) { - logger.debug({ id }, "Provider not found for API key update"); - return c.json({ - status: "error", - statusCode: 404, - message: "Provider not found", - timestamp: new Date().toISOString(), - path: c.req.path - }, 404); - } - - // Validate API key - const validationErrors: Record = {}; - - if (!apiKey || apiKey.trim() === '') { - validationErrors.apiKey = ['API key cannot be empty']; - } - - // If there are validation errors, return them - if (Object.keys(validationErrors).length > 0) { - logger.debug({ validationErrors }, "API key validation errors"); - return c.json({ - status: "error", - statusCode: 400, - message: "Validation failed", - timestamp: new Date().toISOString(), - path: c.req.path, - validationErrors - }, 400); - } - - // Encrypt API key - const encryptedApiKey = await encryptApiKey(apiKey); - - // Update API key - await db - .update(providers) - .set({ - apiKey: encryptedApiKey, - updatedAt: new Date(), - }) - .where(eq(providers.id, Number.parseInt(id))); - - logger.debug({ id }, "Provider API key updated successfully"); - return c.json({ - success: true, - message: "API key updated successfully", - }); - } catch (error) { - const providerId = c.req.param('id'); - logger.error({ - error, - message: error instanceof Error ? error.message : "Unknown error", - stack: error instanceof Error ? error.stack : undefined, - providerId - }, "Error updating provider API key"); - - // Return detailed error response - return c.json({ - status: "error", - statusCode: 500, - message: error instanceof Error ? error.message : "Failed to update API key", - timestamp: new Date().toISOString(), - path: c.req.path, - details: error instanceof Error ? { name: error.name } : undefined - }, 500); - } -}; diff --git a/servers/data_service/src/features/provider_config/schemas.ts b/servers/data_service/src/features/provider_config/schemas.ts index c832715c..0435e744 100644 --- a/servers/data_service/src/features/provider_config/schemas.ts +++ b/servers/data_service/src/features/provider_config/schemas.ts @@ -79,7 +79,6 @@ export const providerUpdateSchema = z.object({ }).optional(), }); - // Export types export type Provider = z.infer; export type ProviderCreate = z.infer; From c3982a8815ca0d9c86bc23a1fb15012b2931a400 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 17:08:58 -0500 Subject: [PATCH 40/69] Default resize image to skip Signed-off-by: jphillips --- .../server/features/perspectives/router.py | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/servers/inference_bridge/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py index f177ab30..84fd1ebb 100644 --- a/servers/inference_bridge/server/server/features/perspectives/router.py +++ b/servers/inference_bridge/server/server/features/perspectives/router.py @@ -114,25 +114,29 @@ async def create_caption( try: resolution = ResolutionPreset[resize_resolution] except (KeyError, ValueError): - logger.warning(f"Invalid resolution: {resize_resolution}. Using HD_720P.") - resolution = ResolutionPreset.HD_720P - - # Create temporary file for resized image - suffix = os.path.splitext(image_path)[1] - fd, resized_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - - # Resize the image - logger.info(f"Resizing image to {resolution.name} ({resolution.value})") - resized_img = resize_image(image_path, resolution) - resized_img.save(resized_path) - - # Add cleanup task for original image - background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None) - - # Use the resized image - image_path = Path(resized_path) - logger.info(f"Image resized successfully to {resolution.name}") + logger.warning(f"Invalid resolution: {resize_resolution}. Skipping resize.") + # Skip resizing and continue with original image + continue_resize = False + else: + continue_resize = True + + if continue_resize: + # Create temporary file for resized image + suffix = os.path.splitext(image_path)[1] + fd, resized_path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + + # Resize the image + logger.info(f"Resizing image to {resolution.name} ({resolution.value})") + resized_img = resize_image(image_path, resolution) + resized_img.save(resized_path) + + # Add cleanup task for original image + background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None) + + # Use the resized image + image_path = Path(resized_path) + logger.info(f"Image resized successfully to {resolution.name}") except Exception as e: logger.error(f"Error resizing image: {str(e)}") logger.warning("Using original image instead") @@ -294,9 +298,9 @@ async def _resize_image_if_needed(image_path: Path, resize_resolution: Optional[ try: resolution = ResolutionPreset[resize_resolution] except (KeyError, ValueError): - logger.warning(f"Invalid resolution: {resize_resolution}. Using HD_720P.") - resolution = ResolutionPreset.HD_720P - + logger.warning(f"Invalid resolution: {resize_resolution}. Skipping resize.") + return image_path, temp_path + # Create temporary file for resized image suffix = os.path.splitext(str(image_path))[1] fd, resized_path = tempfile.mkstemp(suffix=suffix) From fa95f8c0b470841cbcc8c6be8ea6b56a6f376efd Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 18:49:07 -0500 Subject: [PATCH 41/69] Resolve SQ issue batch Signed-off-by: jphillips --- .../features/inference/services/providers.ts | 20 +++---------------- .../pipelines/pipelines/providers/assets.py | 7 ++++--- .../server/features/providers/router.py | 7 ++++--- .../inference_bridge/server/server/main.py | 2 +- 4 files changed, 12 insertions(+), 24 deletions(-) diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts index 6cbf2c9e..9d5ce231 100644 --- a/graphcap_studio/src/features/inference/services/providers.ts +++ b/graphcap_studio/src/features/inference/services/providers.ts @@ -155,14 +155,8 @@ export function useUpdateProvider() { mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => { console.log("Updating provider with data:", data); - // For update operations, we only need to send simple model objects - // Backend will handle ID generation and association const apiData = { ...data }; - // Keep models simple, backend will handle IDs - // No ID conversion needed, only sending name and isEnabled - // Backend will handle the rest - const client = createDataServiceClient(connections); const response = await client.providers[":id"].$put({ param: { id: id.toString() }, @@ -231,14 +225,8 @@ export function useTestProviderConnection() { }: { providerName: string; config: ServerProviderConfig }) => { const client = createInferenceBridgeClient(connections); - // Add console logging to debug console.log("Testing connection with config:", JSON.stringify(config)); - // Make sure api_key is properly set and not null or undefined - if (!config.api_key) { - throw new Error("API key is required for testing provider connection"); - } - const response = await client.providers[":provider_name"][ "test-connection" ].$post({ @@ -250,9 +238,7 @@ export function useTestProviderConnection() { const errorData = await response.json(); console.error("Error response:", errorData); - // Check if this is our enhanced error format if (errorData.status === "error" && errorData.details) { - // Use the structured error data with cause property const error = new Error( errorData.message || "Connection test failed", ) as ErrorWithCause; @@ -270,9 +256,9 @@ export function useTestProviderConnection() { } if (typeof errorData === "object") { - // For raw objects, don't wrap in Error, just throw the object directly - // This prevents "[object Object]" in the error message - throw { ...errorData }; + const error = new Error("Connection test failed") as ErrorWithCause; + error.cause = errorData; + throw error; } // Fallback to simple error diff --git a/servers/inference_bridge/pipelines/pipelines/providers/assets.py b/servers/inference_bridge/pipelines/pipelines/providers/assets.py index 748e1ebb..f3a606c0 100644 --- a/servers/inference_bridge/pipelines/pipelines/providers/assets.py +++ b/servers/inference_bridge/pipelines/pipelines/providers/assets.py @@ -7,6 +7,7 @@ from ..common.resources import ProviderConfigFile +deprecation_msg = "Provider configuration is now managed by the data service" @dg.asset(compute_kind="python", group_name="providers") def provider_list( @@ -15,7 +16,7 @@ def provider_list( """Loads the list of providers (now from data service API).""" # TODO: Call data service API to get providers instead of loading from file # For now, return an empty dictionary to avoid errors - context.log.info("Provider configuration is now managed by the data service") + context.log.info(deprecation_msg) # Sample provider for testing gemini_config = ProviderConfig( @@ -33,7 +34,7 @@ def provider_list( { "num_providers": len(providers), "providers": "gemini: gemini-2.0-flash-exp", - "note": "Provider configuration is now managed by the data service" + "note": deprecation_msg } ) return providers @@ -48,7 +49,7 @@ def default_provider(context: dg.AssetExecutionContext, provider_config_file: Pr context.add_output_metadata( { "selected_provider": selected_provider_name, - "note": "Provider configuration is now managed by the data service" + "note": deprecation_msg } ) return selected_provider_name diff --git a/servers/inference_bridge/server/server/features/providers/router.py b/servers/inference_bridge/server/server/features/providers/router.py index 2a4c2c5a..d267feb5 100644 --- a/servers/inference_bridge/server/server/features/providers/router.py +++ b/servers/inference_bridge/server/server/features/providers/router.py @@ -17,7 +17,8 @@ from pydantic import ValidationError from ...utils.logger import logger -from .error_handler import format_provider_connection_error, format_provider_validation_error +from .error_handler import (format_provider_connection_error, + format_provider_validation_error) from .models import ProviderConfig, ProviderModelsResponse from .service import get_provider_models, test_provider_connection @@ -43,7 +44,7 @@ async def list_provider_models(provider_name: str, config: ProviderConfig) -> Un models = await get_provider_models(provider_name, config) return ProviderModelsResponse(provider=provider_name, models=models) except ValidationError as e: - return format_provider_validation_error(e, provider_name) + return format_provider_validation_error(e) except Exception as e: logger.error(f"Error getting models for {provider_name}: {str(e)}") logger.error(traceback.format_exc()) @@ -69,7 +70,7 @@ async def test_connection(provider_name: str, config: ProviderConfig): result = await test_provider_connection(provider_name, config) return {"status": "success", "message": "Connection successful", "result": result} except ValidationError as e: - return format_provider_validation_error(e, provider_name) + return format_provider_validation_error(e) except Exception as e: logger.error(f"Error testing connection to {provider_name}: {str(e)}") logger.error(traceback.format_exc()) diff --git a/servers/inference_bridge/server/server/main.py b/servers/inference_bridge/server/server/main.py index a5c86905..a7b41725 100644 --- a/servers/inference_bridge/server/server/main.py +++ b/servers/inference_bridge/server/server/main.py @@ -19,7 +19,7 @@ from .utils.middleware import setup_middlewares -class GracefulExit(SystemExit): +class GracefulExit(Exception): """Custom exception for graceful shutdown.""" pass From ff270cc320438c87c5947ef19c23c2b853e6978f Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 19:01:30 -0500 Subject: [PATCH 42/69] Remove file upload caption route Signed-off-by: jphillips --- .../features/perspectives/constants/index.ts | 1 - .../server/features/perspectives/models.py | 40 ---- .../server/features/perspectives/router.py | 175 +----------------- .../server/features/perspectives/service.py | 34 +--- 4 files changed, 3 insertions(+), 247 deletions(-) diff --git a/graphcap_studio/src/features/perspectives/constants/index.ts b/graphcap_studio/src/features/perspectives/constants/index.ts index ab668128..35a8f247 100644 --- a/graphcap_studio/src/features/perspectives/constants/index.ts +++ b/graphcap_studio/src/features/perspectives/constants/index.ts @@ -54,7 +54,6 @@ export const perspectivesQueryKeys = { // Constants for API endpoints export const API_ENDPOINTS = { LIST_PERSPECTIVES: "/perspectives/list", - GENERATE_CAPTION: "/perspectives/caption", VIEW_IMAGE: "/images/view", REST_LIST_PERSPECTIVES: "/perspectives/list", REST_GENERATE_CAPTION: "/perspectives/caption-from-path", diff --git a/servers/inference_bridge/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py index 30a62a69..777a4350 100644 --- a/servers/inference_bridge/server/server/features/perspectives/models.py +++ b/servers/inference_bridge/server/server/features/perspectives/models.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional, Union -from fastapi import File, Form, UploadFile from pydantic import BaseModel, Field # Field description constants @@ -143,45 +142,6 @@ class CaptionResponse(BaseModel): raw_text: Optional[str] = Field(None, description="Raw text response from the model") -# Form data model for multipart/form-data requests with file uploads -class CaptionFormRequest: - """Form request model for generating a caption with a perspective using file upload.""" - - def __init__( - self, - perspective: str = Form(..., description=DESC_PERSPECTIVE_NAME), - file: Optional[UploadFile] = File(None, description="Image file to caption"), - url: Optional[str] = Form(None, description="URL of the image to caption"), - base64: Optional[str] = Form(None, description="Base64-encoded image data"), - max_tokens: Optional[int] = Form(4096, description=DESC_MAX_TOKENS), - temperature: Optional[float] = Form(0.8, description=DESC_TEMPERATURE), - top_p: Optional[float] = Form(0.9, description=DESC_TOP_P), - repetition_penalty: Optional[float] = Form(1.15, description=DESC_REPETITION_PENALTY), - global_context: Optional[str] = Form(None, description=DESC_GLOBAL_CONTEXT), - context: Optional[str] = Form(None, description="Additional context for the caption (JSON array string)"), - resize_resolution: Optional[str] = Form(None, description=DESC_RESIZE_RESOLUTION), - ): - self.perspective = perspective - self.file = file - self.url = url - self.base64 = base64 - self.max_tokens = max_tokens - self.temperature = temperature - self.top_p = top_p - self.repetition_penalty = repetition_penalty - self.global_context = global_context - self.resize_resolution = resize_resolution - - # Parse context from JSON string if provided - self.context = None - if context: - import json - - try: - self.context = json.loads(context) - except json.JSONDecodeError: - # If not valid JSON array, treat as a single context item - self.context = [context] class CaptionPathRequest(BaseModel): diff --git a/servers/inference_bridge/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py index 84fd1ebb..58d7df3d 100644 --- a/servers/inference_bridge/server/server/features/perspectives/router.py +++ b/servers/inference_bridge/server/server/features/perspectives/router.py @@ -19,8 +19,7 @@ from pathlib import Path from typing import List, Optional -from fastapi import (APIRouter, BackgroundTasks, File, Form, HTTPException, - UploadFile, status) +from fastapi import APIRouter, HTTPException, status from loguru import logger from ...utils.resizing import (ResolutionPreset, log_resize_options, @@ -28,8 +27,7 @@ from .models import (CaptionPathRequest, CaptionResponse, ModuleListResponse, ModulePerspectivesResponse, PerspectiveListResponse) from .service import (generate_caption, get_available_modules, - get_available_perspectives, get_perspectives_by_module, - save_uploaded_file) + get_available_perspectives, get_perspectives_by_module) router = APIRouter(prefix="/perspectives", tags=["perspectives"]) @@ -46,161 +44,6 @@ async def list_perspectives() -> PerspectiveListResponse: return PerspectiveListResponse(perspectives=perspectives) -@router.post("/caption", response_model=CaptionResponse, status_code=status.HTTP_200_OK) -async def create_caption( - background_tasks: BackgroundTasks, - file: UploadFile = File(..., description="Image file to upload"), - perspective: str = Form(..., description="Name of the perspective to use"), - provider: str = Form(..., description="Name of the provider to use"), - provider_config: str = Form(..., description="Provider configuration as JSON string"), - model: str = Form(..., description="Model name to use for processing"), - max_tokens: Optional[int] = Form(4096, description="Maximum number of tokens"), - temperature: Optional[float] = Form(0.8, description="Temperature for generation"), - top_p: Optional[float] = Form(0.9, description="Top-p sampling parameter"), - repetition_penalty: Optional[float] = Form(1.15, description="Repetition penalty"), - global_context: Optional[str] = Form(None, description="Global context for the caption"), - context: Optional[str] = Form(None, description="Additional context for the caption as JSON array string"), - resize_resolution: Optional[str] = Form(None, description="Resolution to resize to (None to disable resizing)"), -) -> CaptionResponse: - """ - Generate a caption for an image using a perspective. - - This endpoint supports file uploads only. - - Args: - background_tasks: Background tasks for cleanup - file: Image file to upload (required) - perspective: Name of the perspective to use (required) - provider: Name of the provider to use (required) - provider_config: Provider configuration as JSON string (required) - model: Model name to use for processing (required) - max_tokens: Maximum number of tokens (optional, default: 4096) - temperature: Temperature for generation (optional, default: 0.8) - top_p: Top-p sampling parameter (optional, default: 0.9) - repetition_penalty: Repetition penalty (optional, default: 1.15) - context: JSON array string of context items (optional) - global_context: Global context string (optional) - resize_resolution: Resolution to resize to (optional, default: None - no resizing) - - Returns: - Generated caption with structured result and optional raw text - - Raises: - HTTPException: If the request is invalid or processing fails - """ - try: - # Parse context from JSON string if provided - parsed_context = _parse_context(context) - - # Parse provider_config from JSON string if provided - try: - parsed_provider_config = json.loads(provider_config) - logger.info(f"Parsed provider configuration for {provider}") - except json.JSONDecodeError as e: - logger.error(f"Invalid provider configuration JSON: {e}") - raise HTTPException(status_code=400, detail=f"Invalid provider configuration JSON: {str(e)}") - - # Process the uploaded file - image_path = await save_uploaded_file(file) - - # Log resize options - options = {"resize_resolution": resize_resolution} - log_resize_options(options) - - # Resize the image if resize_resolution is provided - if resize_resolution: - try: - # Get the resolution enum value - try: - resolution = ResolutionPreset[resize_resolution] - except (KeyError, ValueError): - logger.warning(f"Invalid resolution: {resize_resolution}. Skipping resize.") - # Skip resizing and continue with original image - continue_resize = False - else: - continue_resize = True - - if continue_resize: - # Create temporary file for resized image - suffix = os.path.splitext(image_path)[1] - fd, resized_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - - # Resize the image - logger.info(f"Resizing image to {resolution.name} ({resolution.value})") - resized_img = resize_image(image_path, resolution) - resized_img.save(resized_path) - - # Add cleanup task for original image - background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None) - - # Use the resized image - image_path = Path(resized_path) - logger.info(f"Image resized successfully to {resolution.name}") - except Exception as e: - logger.error(f"Error resizing image: {str(e)}") - logger.warning("Using original image instead") - # Continue with original image if resizing fails - - # Add cleanup task - background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None) - - # Validate provider configuration - if not parsed_provider_config: - logger.error(f"No provider configuration provided for {provider}") - raise HTTPException( - status_code=400, - detail=f"Provider configuration not provided for '{provider}'. Please include provider_config in the request." - ) - - # Validate model is provided - if not model: - logger.error(f"No model specified for {provider}") - raise HTTPException( - status_code=400, - detail=f"Model name not provided for '{provider}'. Please include model in the request." - ) - - # Generate the caption - caption_data = await generate_caption( - perspective_name=perspective, - image_path=image_path, - model=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - context=parsed_context, - global_context=global_context, - provider_name=provider, - provider_config=parsed_provider_config, - ) - - # Log the caption data for debugging - logger.debug(f"Caption data: {caption_data}") - - # Extract the parsed result and raw text - parsed_result = caption_data.get("parsed", {}) - raw_text = caption_data.get("raw_text") - - # If parsed result is empty but raw_text exists, try to create a basic result - if not parsed_result and raw_text: - logger.warning("Parsed result is empty but raw_text exists. Creating basic result.") - parsed_result = {"text": raw_text} - - # Return the response - return CaptionResponse( - perspective=perspective, - provider=provider, - result=parsed_result, - raw_text=raw_text, - ) - except Exception as e: - logger.error(f"Error creating caption: {str(e)}") - if isinstance(e, HTTPException): - raise - raise HTTPException(status_code=500, detail=f"Error creating caption: {str(e)}") - @router.post("/caption-from-path", response_model=CaptionResponse, status_code=status.HTTP_200_OK) async def create_caption_from_path( @@ -375,20 +218,6 @@ def _prepare_caption_response(caption_data: dict, perspective: str, provider: st ) -def _parse_context(context_str) -> Optional[List[str]]: - """Parse context from a JSON string.""" - if not context_str or not isinstance(context_str, str): - return None - - try: - context = json.loads(context_str) - if isinstance(context, list): - return context - return [context_str] - except json.JSONDecodeError: - return [context_str] - - @router.get("/modules", response_model=ModuleListResponse) diff --git a/servers/inference_bridge/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py index 9af3cab9..64f5788d 100644 --- a/servers/inference_bridge/server/server/features/perspectives/service.py +++ b/servers/inference_bridge/server/server/features/perspectives/service.py @@ -13,7 +13,7 @@ from typing import Dict, List, Optional import aiohttp -from fastapi import HTTPException, UploadFile +from fastapi import HTTPException from loguru import logger from graphcap.perspectives import get_perspective, get_perspective_list @@ -373,35 +373,3 @@ async def generate_caption( raise HTTPException(status_code=500, detail=f"Error generating caption: {str(e)}") -async def save_uploaded_file(file: UploadFile) -> Path: - """ - Save an uploaded file to a temporary location. - - Args: - file: Uploaded file object - - Returns: - Path to the saved file - - Raises: - HTTPException: If the file cannot be saved - """ - try: - # Create a temporary file with appropriate extension - suffix = os.path.splitext(file.filename)[1] if file.filename else ".jpg" - fd, temp_path = tempfile.mkstemp(suffix=suffix) - os.close(fd) - temp_file = Path(temp_path) - - # Save the uploaded file - content = await file.read() - with open(temp_file, "wb") as f: - f.write(content) - - # Reset file pointer for potential future reads - await file.seek(0) - - return temp_file - except Exception as e: - logger.error(f"Error saving uploaded file: {str(e)}") - raise HTTPException(status_code=400, detail=f"Error saving uploaded file: {str(e)}") From c689b9b74540782417e210945cb102a6ad7c4d9e Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 19:14:03 -0500 Subject: [PATCH 43/69] Deprecate old batch processing method in gcap core Signed-off-by: jphillips --- .../graphcap/perspectives/base_caption.py | 221 +----------------- .../pipelines/perspectives/assets.py | 142 ++++++++++- .../server/features/perspectives/service.py | 57 ++--- 3 files changed, 162 insertions(+), 258 deletions(-) diff --git a/servers/inference_bridge/graphcap/perspectives/base_caption.py b/servers/inference_bridge/graphcap/perspectives/base_caption.py index 2d1643c7..15a5d761 100644 --- a/servers/inference_bridge/graphcap/perspectives/base_caption.py +++ b/servers/inference_bridge/graphcap/perspectives/base_caption.py @@ -5,20 +5,15 @@ Provides base classes and shared functionality for different caption types. """ -import asyncio import json -import shutil from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from loguru import logger from pydantic import BaseModel from rich.console import Console from rich.table import Table -from tenacity import retry, stop_after_attempt, wait_exponential -from tqdm.asyncio import tqdm_asyncio from ..providers.clients.base_client import BaseClient from .types import StructuredVisionConfig @@ -184,218 +179,8 @@ async def process_single( except Exception as e: raise Exception(f"Error processing {image_path}: {str(e)}") - async def process_batch( - self, - provider: BaseClient, - image_paths: List[Path], - model: str, - max_tokens: Optional[int] = 4096, - temperature: Optional[float] = 0.8, - top_p: Optional[float] = 0.9, - max_concurrent: Optional[int] = 1, - repetition_penalty: Optional[float] = 1.15, - output_dir: Optional[Path] = None, - store_logs: bool = False, - formats: Optional[List[str]] = None, - copy_images: bool = False, - global_context: str | None = None, - contexts: dict[str, list[str]] | None = None, - name: str | None = None, - ) -> List[Dict[str, Any]]: - """ - Process multiple images and return their captions. - - Args: - provider: Vision AI provider client instance - image_paths: List of paths to image files - model: Model name to use for processing - max_tokens: Maximum tokens for model response - temperature: Sampling temperature - top_p: Nucleus sampling parameter - max_concurrent: Maximum number of concurrent API requests - output_dir: Directory to store incremental results and job info - store_logs: Whether to store logs in the output directory - formats: List of additional formats to write caption data - copy_images: Whether to copy images to the output directory - contexts: Additional context for the vision model based on image paths - Returns: - List[Dict[str, Any]]: List of caption results with metadata - """ - # Create job directory with timestamp if output_dir provided - job_dir = None - job_output = None - job_info = None - log_file = None - - if output_dir: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - job_dir = output_dir / f"batch_{name or timestamp}" - job_dir.mkdir(parents=True, exist_ok=True) - - # Create output file and job info - job_output = job_dir / "captions.jsonl" - job_info = job_dir / "job_info.json" - - # Configure logging if requested - if store_logs: - log_file = job_dir / "process.log" - # Add file logger while keeping console output - logger.add( - log_file, - format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", - level="INFO", - rotation="100 MB", - ) - - # Write initial job info - job_info_data = { - "started_at": timestamp, - "provider": provider.name, - "model": model, - "config_name": self.vision_config.config_name, - "version": self.vision_config.version, - "total_images": len(image_paths), - "sampling": { - "original_count": getattr(image_paths, "original_count", len(image_paths)), - "sample_size": getattr(image_paths, "sample_size", len(image_paths)), - "sample_method": getattr(image_paths, "sample_method", "all"), - }, - "params": { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "max_concurrent": max_concurrent, - "repetition_penalty": repetition_penalty, - }, - "log_file": str(log_file.relative_to(job_dir)) if log_file else None, - "formats": formats or [], - "copy_images": copy_images, - "global_context": global_context, - } - job_info.write_text(json.dumps(job_info_data, indent=2)) - - # Copy images if requested - if copy_images: - images_dir = job_dir / "images" - images_dir.mkdir(exist_ok=True) - for path in image_paths: - try: - shutil.copy2(path, images_dir / path.name) - except Exception as e: - logger.error(f"Failed to copy image {path}: {e}") - - logger.info(f"Processing {len(image_paths)} images with {provider.name} provider") - logger.info(f"Using max concurrency of {max_concurrent} requests") - if job_dir: - logger.info(f"Writing results to {job_dir}") - if log_file: - logger.info(f"Logging to {log_file}") - - semaphore = asyncio.Semaphore(max_concurrent) - active_requests = 0 - processed_count = 0 - failed_count = 0 - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - reraise=True, - ) - async def process_with_semaphore(path: Path) -> Dict[str, Any]: - nonlocal active_requests, processed_count, failed_count - - async with semaphore: - try: - active_requests += 1 - logger.info(f"Starting request for {path.name} (Active requests: {active_requests})") - - result = await self.process_single( - provider=provider, - image_path=path, - model=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - context=contexts.get(path.name) if contexts else None, - global_context=global_context, - ) - - active_requests -= 1 - processed_count += 1 - logger.info(f"Completed request for {path.name} (Active requests: {active_requests})") - - caption_data = { - "filename": f"./{path.name}", - "config_name": self.vision_config.config_name, - "version": self.vision_config.version, - "model": model, - "provider": provider.name, - "parsed": result, - } - - # Write result incrementally if output file exists - if job_output: - with job_output.open("a") as f: - f.write(json.dumps(caption_data) + "\n") - - # Update job info - job_info_data["processed_count"] = processed_count - job_info_data["failed_count"] = failed_count - job_info_data["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S") - job_info.write_text(json.dumps(job_info_data, indent=2)) - - # Create and display Rich table - console.print(f"\n[bold cyan]Processed {path.name}:[/bold cyan]") - table = self.create_rich_table(caption_data) - console.print(table) - - return caption_data - except Exception as e: - active_requests -= 1 - failed_count += 1 - logger.error(f"Failed request for {path.name} (Active requests: {active_requests})") - error_data = { - "filename": f"./{path.name}", - "config_name": self.vision_config.config_name, - "version": self.vision_config.version, - "model": model, - "provider": provider.name, - "parsed": {"error": str(e)}, - } - - # Write error result if output file exists - if job_output: - with job_output.open("a") as f: - f.write(json.dumps(error_data) + "\n") - - # Update job info - job_info_data["processed_count"] = processed_count - job_info_data["failed_count"] = failed_count - job_info_data["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S") - job_info.write_text(json.dumps(job_info_data, indent=2)) - - console.print(f"\n[bold red]Failed to process {path.name}:[/bold red] {str(e)}") - return error_data - - results = await tqdm_asyncio.gather( - *[process_with_semaphore(path) for path in image_paths], - desc=f"Processing images with {provider.name}", - ) - - # Log summary with Rich - success_count = sum(1 for r in results if "error" not in r["parsed"]) - summary_table = Table(title="Processing Summary", show_header=False) - summary_table.add_column("Metric", style="cyan") - summary_table.add_column("Value", style="green") - summary_table.add_row("Total Images", str(len(results))) - summary_table.add_row("Successful", str(success_count)) - summary_table.add_row("Failed", str(len(results) - success_count)) - - console.print("\n") - console.print(summary_table) - - return results + # Note: process_batch has been removed as batch processing is being migrated to Kafka. + # Batch processing functionality should now be implemented in Kafka-based pipeline components. @abstractmethod def to_table(self, caption_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py b/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py index ef1f4a2e..1b4fef19 100644 --- a/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py +++ b/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py @@ -1,18 +1,143 @@ # SPDX-License-Identifier: Apache-2.0 """Assets and ops for basic text captioning.""" +import asyncio +import json from datetime import datetime from pathlib import Path from typing import Any, Dict, List import dagster as dg import pandas as pd +from loguru import logger +from tqdm.asyncio import tqdm_asyncio + from graphcap.perspectives import get_perspective, get_synthesizer from ..common.logging import write_caption_results from ..perspectives.jobs.config import PerspectivePipelineConfig from ..providers.util import get_provider +# File constants +JOB_INFO_FILENAME = "job_info.json" +CAPTIONS_FILENAME = "captions.jsonl" + +# Temporary batch processing function to replace BaseCaptionProcessor.process_batch +# This will be replaced with Kafka-based processing in the future +async def process_images_in_batch( + processor, + provider, + image_paths, + model="gemini-2.0-flash-exp", + max_tokens=4096, + temperature=0.8, + top_p=0.9, + repetition_penalty=1.15, + max_concurrent=3, + output_dir=None, + global_context=None, + contexts=None, + name=None, +): + """ + Temporary batch processing function to replace BaseCaptionProcessor.process_batch. + Will be replaced with Kafka-based processing in the future. + + Processes multiple images by calling process_single for each in parallel. + """ + logger.info(f"[DEPRECATED] Processing {len(image_paths)} images with {provider.name}") + logger.info(f"Using max concurrency of {max_concurrent} requests") + + # Create job directory for output if requested + job_dir = None + if output_dir: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + job_dir = output_dir / f"batch_{name or timestamp}" + job_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Writing results to {job_dir}") + + # Create captions.jsonl file and job_info.json with basic info + with open(job_dir / JOB_INFO_FILENAME, "w") as f: + job_info = { + "started_at": timestamp, + "provider": provider.name, + "model": model, + "config_name": getattr(processor, "config_name", name), + "version": getattr(processor, "version", "1.0"), + "total_images": len(image_paths), + "global_context": global_context, + "note": "This is a temporary implementation of batch processing until Kafka-based processing is implemented." + } + json.dump(job_info, f, indent=2) + + # Process images in parallel with limited concurrency + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_image(path): + async with semaphore: + try: + result = await processor.process_single( + provider=provider, + image_path=path, + model=model, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + context=contexts.get(path.name) if contexts else None, + global_context=global_context, + ) + + caption_data = { + "filename": f"./{path.name}", + "config_name": getattr(processor, "config_name", name), + "version": getattr(processor, "version", "1.0"), + "model": model, + "provider": provider.name, + "parsed": result, + } + + # Write to captions.jsonl if job_dir exists + if job_dir: + with open(job_dir / CAPTIONS_FILENAME, "a") as f: + f.write(json.dumps(caption_data) + "\n") + + return caption_data + except Exception as e: + logger.error(f"Error processing {path}: {e}") + error_data = { + "filename": f"./{path.name}", + "config_name": getattr(processor, "config_name", name), + "version": getattr(processor, "version", "1.0"), + "model": model, + "provider": provider.name, + "parsed": {"error": str(e)}, + } + + # Write error to captions.jsonl if job_dir exists + if job_dir: + with open(job_dir / CAPTIONS_FILENAME, "a") as f: + f.write(json.dumps(error_data) + "\n") + + return error_data + + tasks = [process_image(path) for path in image_paths] + results = await tqdm_asyncio.gather(*tasks, desc=f"Processing images with {provider.name}") + + # Update job_info.json with completion info + if job_dir: + with open(job_dir / JOB_INFO_FILENAME, "r") as f: + job_info = json.load(f) + + job_info["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S") + job_info["success_count"] = sum(1 for r in results if "error" not in r["parsed"]) + job_info["failed_count"] = sum(1 for r in results if "error" in r["parsed"]) + + with open(job_dir / JOB_INFO_FILENAME, "w") as f: + json.dump(job_info, f, indent=2) + + return results + @dg.asset( group_name="perspectives", @@ -45,11 +170,13 @@ async def perspective_caption( processor = get_perspective(perspective) try: - # Process images in batch + # Process images using the temporary batch processing function image_paths = [Path(image) for image in perspective_image_list] - caption_data_list = await processor.process_batch( + caption_data_list = await process_images_in_batch( + processor, client, image_paths, + model=getattr(provider_config, "model", "gemini-2.0-flash-exp"), output_dir=Path(io_config.run_dir), global_context=perspective_config.global_context, name=perspective, @@ -128,8 +255,15 @@ async def synthesizer_caption( image_dir = Path(io_config.output_dir) / "images" paths = [image_dir / path for path in caption_contexts.keys()] - results = await synthesizer.process_batch( - client, paths, output_dir=Path(io_config.run_dir), contexts=caption_contexts, name="synthesized_caption" + + # Use the temporary batch processing function + results = await process_images_in_batch( + synthesizer, + client, + paths, + output_dir=Path(io_config.run_dir), + contexts=caption_contexts, + name="synthesized_caption" ) # Format the results to match the perspective_caption output diff --git a/servers/inference_bridge/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py index 64f5788d..217486df 100644 --- a/servers/inference_bridge/server/server/features/perspectives/service.py +++ b/servers/inference_bridge/server/server/features/perspectives/service.py @@ -323,43 +323,28 @@ async def generate_caption( f"Generating caption for {image_path} using {perspective_name} perspective and {provider_name} provider" ) - # Check if the perspective has process_batch method - if hasattr(perspective, "process_batch"): - logger.info(f"Using process_batch method for {perspective_name}") - # Use process_batch with a single image to match the pipeline implementation - caption_data_list = await perspective.process_batch( - provider=provider, - image_paths=[image_path], - model=model, - output_dir=output_dir, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - global_context=global_context, - name=perspective_name, - ) + # Use process_single directly as process_batch has been deprecated + logger.info(f"Using process_single for {perspective_name}") + caption_data = await perspective.process_single( + provider=provider, + image_path=image_path, + model=model, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + context=context, + global_context=global_context, + ) - # Get the first (and only) result - if not caption_data_list or len(caption_data_list) == 0: - logger.error(f"No caption data returned for {image_path}") - raise HTTPException(status_code=500, detail="No caption data returned") - - caption_data = caption_data_list[0] - else: - # Fallback to process_single if process_batch is not available - logger.info(f"Falling back to process_single method for {perspective_name}") - caption_data = await perspective.process_single( - provider=provider, - image_path=image_path, - model=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - context=context, - global_context=global_context, - ) + caption_data = { + "filename": f"./{image_path.name}", + "config_name": getattr(perspective, 'config_name', perspective_name), + "version": getattr(perspective, 'version', '1.0'), + "model": model, + "provider": provider.name, + "parsed": caption_data, + } # Log the result logger.info(f"Caption generated successfully: {caption_data.keys() if caption_data else 'None'}") From 521c195c1e0afd2e8975965bab8899099a45c032 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 19:19:07 -0500 Subject: [PATCH 44/69] Split import function Signed-off-by: jphillips --- .../pipelines/huggingface/dataset_import.py | 204 +++++++++++------- 1 file changed, 132 insertions(+), 72 deletions(-) diff --git a/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py index d3b979ea..cba5f393 100644 --- a/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py +++ b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py @@ -27,7 +27,8 @@ from huggingface_hub import hf_hub_download from tqdm import tqdm -from .types import DatasetImportConfig, DatasetParquetUrlDownloadConfig, DatasetParseConfig +from .types import (DatasetImportConfig, DatasetParquetUrlDownloadConfig, + DatasetParseConfig) def _clone_with_git_lfs( @@ -242,93 +243,152 @@ def dataset_download_urls( dataset_download: Path to the downloaded dataset config: Configuration for URL downloading """ + input_dir, output_dir = _setup_directories(context, dataset_download, config) + parquet_files = _find_parquet_files(context, input_dir) + + successful_downloads = 0 + failed_downloads = 0 + total_urls = 0 + + for parquet_file in parquet_files: + df = _load_parquet_file(context, parquet_file, config) + download_results = _process_dataframe(context, df, output_dir, config) + + successful_downloads += download_results["successful"] + failed_downloads += download_results["failed"] + total_urls += download_results["total"] + + # Log summary + context.log.info( + f"Download complete. Successful: {successful_downloads}, Failed: {failed_downloads}, Total: {total_urls}" + ) + + +def _setup_directories( + context: dg.AssetExecutionContext, dataset_download: str, config: DatasetParquetUrlDownloadConfig +) -> tuple[Path, Path]: + """Setup input and output directories for URL downloads.""" input_dir = Path(dataset_download) / config.parquet_dir if not input_dir.exists(): raise ValueError(f"Parquet directory not found at {input_dir}") + + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + return input_dir, output_dir - # Find all parquet files + +def _find_parquet_files(context: dg.AssetExecutionContext, input_dir: Path) -> list[Path]: + """Find parquet files in the input directory.""" parquet_files = list(input_dir.glob("*.parquet")) if not parquet_files: raise ValueError(f"No parquet files found in {input_dir}") - + context.log.info(f"Found {len(parquet_files)} parquet files") + return parquet_files + + +def _load_parquet_file( + context: dg.AssetExecutionContext, parquet_file: Path, config: DatasetParquetUrlDownloadConfig +) -> pd.DataFrame: + """Load and validate a parquet file.""" + context.log.info(f"Processing {parquet_file}") + df = pd.read_parquet(parquet_file) + + context.log.info(f"Loaded parquet file with {len(df)} rows") + context.log.info(f"Columns: {df.columns.tolist()}") + + if config.url_column not in df.columns: + raise ValueError( + f"URL column '{config.url_column}' not found in {parquet_file}. " + f"Available columns: {df.columns.tolist()}" + ) + + return df - # Create output directory - output_dir = Path(config.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - # Track progress +def _process_dataframe( + context: dg.AssetExecutionContext, + df: pd.DataFrame, + output_dir: Path, + config: DatasetParquetUrlDownloadConfig, +) -> dict: + """Process a dataframe and download URLs.""" successful_downloads = 0 failed_downloads = 0 total_urls = 0 - - # Process each parquet file - for parquet_file in parquet_files: - context.log.info(f"Processing {parquet_file}") - df = pd.read_parquet(parquet_file) - - context.log.info(f"Loaded parquet file with {len(df)} rows") - context.log.info(f"Columns: {df.columns.tolist()}") - - if config.url_column not in df.columns: - raise ValueError( - f"URL column '{config.url_column}' not found in {parquet_file}. " - f"Available columns: {df.columns.tolist()}" + + with ThreadPoolExecutor(max_workers=config.max_workers) as executor: + future_to_url = {} + + for idx, row in df.iterrows(): + if idx % 1000 == 0: + context.log.info(f"Processing row {idx}") + + download_batch = _prepare_download_batch(context, row, output_dir, config) + if not download_batch: + continue + + for url, output_path in download_batch: + if len(future_to_url) >= config.max_workers * 2: + # Process current batch before adding more + results = _process_completed_downloads( + future_to_url, context, successful_downloads, failed_downloads + ) + successful_downloads, failed_downloads = results + future_to_url = {} + + # Submit download task + future = executor.submit(_download_url, url, output_path, context) + future_to_url[future] = (url, output_path) + total_urls += 1 + + # Process remaining downloads + if future_to_url: + results = _process_completed_downloads( + future_to_url, context, successful_downloads, failed_downloads ) + successful_downloads, failed_downloads = results + + return { + "successful": successful_downloads, + "failed": failed_downloads, + "total": total_urls + } - # Process each row - with ThreadPoolExecutor(max_workers=config.max_workers) as executor: - future_to_url = {} - - for idx, row in df.iterrows(): - if idx % 1000 == 0: - context.log.info(f"Processing row {idx}") - - # Extract URLs using helper function - urls = _extract_urls(row[config.url_column]) - if not urls: - context.log.debug(f"Skipping row {idx}: no valid URLs") - continue - - # Use row ID as base filename - base_filename = row["id"] - if not base_filename: - context.log.warning(f"Skipping row {idx}: no ID") - continue - - for i, url in enumerate(urls): - # Generate unique filename for each URL - filename = f"{base_filename}_{i}.{config.default_extension}" - output_path = output_dir / filename - - # Skip if file exists and no overwrite - if output_path.exists() and not config.overwrite_existing: - context.log.debug(f"Skipping existing file: {output_path}") - continue - - # Limit batch size for rate limiting - if len(future_to_url) >= config.max_workers * 2: - # Wait for some downloads to complete before adding more - successful_downloads, failed_downloads = _process_completed_downloads( - future_to_url, context, successful_downloads, failed_downloads - ) - future_to_url = {} - - # Submit download task - future = executor.submit(_download_url, url, output_path, context) - future_to_url[future] = (url, output_path) - total_urls += 1 - - # Process remaining downloads - if future_to_url: - successful_downloads, failed_downloads = _process_completed_downloads( - future_to_url, context, successful_downloads, failed_downloads - ) - # Log summary - context.log.info( - f"Download complete. Successful: {successful_downloads}, Failed: {failed_downloads}, Total: {total_urls}" - ) +def _prepare_download_batch( + context: dg.AssetExecutionContext, + row: pd.Series, + output_dir: Path, + config: DatasetParquetUrlDownloadConfig, +) -> list[tuple[str, Path]]: + """Prepare a batch of URLs to download for a single row.""" + # Extract URLs + urls = _extract_urls(row[config.url_column]) + if not urls: + return [] + + # Use row ID as base filename + base_filename = row.get("id") + if not base_filename: + context.log.warning("Skipping row: no ID") + return [] + + download_batch = [] + for i, url in enumerate(urls): + # Generate unique filename for each URL + filename = f"{base_filename}_{i}.{config.default_extension}" + output_path = output_dir / filename + + # Skip if file exists and no overwrite + if output_path.exists() and not config.overwrite_existing: + context.log.debug(f"Skipping existing file: {output_path}") + continue + + download_batch.append((url, output_path)) + + return download_batch def _process_completed_downloads( From 5d86abba151528c3e04682101213d1d229652b0f Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 19:35:55 -0500 Subject: [PATCH 45/69] Remove model list endpoint from inference bridge server and client Signed-off-by: jphillips --- .../services/inferenceBridgeClient.ts | 7 --- .../server/features/providers/router.py | 34 +---------- .../server/features/providers/service.py | 58 ++----------------- 3 files changed, 6 insertions(+), 93 deletions(-) diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts index 054ca024..f7f114dc 100644 --- a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts +++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts @@ -27,13 +27,6 @@ export interface ProviderClient { }) => Promise; }; }; - ":provider": { - "models": { - $get: (options: { - param: { provider: string }; - }) => Promise; - }; - }; } /** diff --git a/servers/inference_bridge/server/server/features/providers/router.py b/servers/inference_bridge/server/server/features/providers/router.py index d267feb5..05aac4e2 100644 --- a/servers/inference_bridge/server/server/features/providers/router.py +++ b/servers/inference_bridge/server/server/features/providers/router.py @@ -5,52 +5,22 @@ Defines API routes for working with AI providers. This module provides the following endpoints: -- POST /providers/{provider_name}/models - List available models for a provider using provided configuration - POST /providers/{provider_name}/test-connection - Test connection to a provider using provided configuration """ import traceback -from typing import Union from fastapi import APIRouter -from fastapi.responses import JSONResponse from pydantic import ValidationError from ...utils.logger import logger from .error_handler import (format_provider_connection_error, format_provider_validation_error) -from .models import ProviderConfig, ProviderModelsResponse -from .service import get_provider_models, test_provider_connection +from .models import ProviderConfig +from .service import test_provider_connection router = APIRouter(prefix="/providers", tags=["providers"]) - -@router.post("/{provider_name}/models", response_model=ProviderModelsResponse) -async def list_provider_models(provider_name: str, config: ProviderConfig) -> Union[ProviderModelsResponse, JSONResponse]: - """ - List available models for a specific provider using provided configuration. - - Args: - provider_name: Name of the provider to get models for - config: Provider configuration for this request - - Returns: - List of available models for the provider or an error response - - Raises: - HTTPException: If there is an error getting models - """ - try: - models = await get_provider_models(provider_name, config) - return ProviderModelsResponse(provider=provider_name, models=models) - except ValidationError as e: - return format_provider_validation_error(e) - except Exception as e: - logger.error(f"Error getting models for {provider_name}: {str(e)}") - logger.error(traceback.format_exc()) - return format_provider_connection_error(e, provider_name, config) - - @router.post("/{provider_name}/test-connection") async def test_connection(provider_name: str, config: ProviderConfig): """ diff --git a/servers/inference_bridge/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py index 94db9d03..5c3c204b 100644 --- a/servers/inference_bridge/server/server/features/providers/service.py +++ b/servers/inference_bridge/server/server/features/providers/service.py @@ -5,13 +5,14 @@ Provides services for working with AI providers. """ -from typing import Any, Dict, List, Protocol, runtime_checkable import datetime +from typing import Any, Dict, Protocol, runtime_checkable -from graphcap.providers.clients.base_client import BaseClient -from graphcap.providers.factory import create_provider_client, get_provider_factory from loguru import logger +from graphcap.providers.clients.base_client import BaseClient +from graphcap.providers.factory import create_provider_client + from .models import ModelInfo, ProviderConfig @@ -40,57 +41,6 @@ def _create_model_info(model_id: str) -> ModelInfo: ) -async def get_provider_models(provider_name: str, config: ProviderConfig) -> List[ModelInfo]: - """ - Get a list of available models for a specific provider. - - Args: - provider_name: Name of the provider to get models for - config: Provider configuration for this request - Returns: - List of model information - """ - # Initialize client with provided configuration - client = create_provider_client( - name=provider_name, - kind=config.kind, - environment=config.environment, - base_url=config.base_url, - api_key=config.api_key, - rate_limits=config.rate_limits, - use_cache=True, # Cache clients for better performance - ) - - models = [] - - # Try to fetch models if configured - if config.fetch_models: - try: - logger.info(f"Fetching models from provider {provider_name}") - if hasattr(client, "get_available_models"): - provider_models = await client.get_available_models() - if hasattr(provider_models, "data"): - for model in provider_models.data: - model_id = _extract_model_id(model) - models.append(_create_model_info(model_id)) - elif hasattr(client, "get_models"): - provider_models = await client.get_models() - if hasattr(provider_models, "models"): - for model in provider_models.models: - model_id = _extract_model_id(model) - models.append(_create_model_info(model_id)) - except Exception as e: - logger.error(f"Error fetching models from provider {provider_name}: {str(e)}") - logger.info(f"Falling back to configured models for provider {provider_name}") - - # Fall back to configured models if none fetched - if not models: - models = [_create_model_info(model_id) for model_id in config.models] - logger.info(f"Using {len(models)} configured models for provider {provider_name}") - - return models - - def create_provider_client_from_config(config: ProviderConfig) -> BaseClient: """ Create a provider client from a configuration. From f4b1d21e4b11f2a31f19b1b2c1325286ec1e48f8 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 19:40:51 -0500 Subject: [PATCH 46/69] Split debug logic for provider issues Signed-off-by: jphillips --- .../features/providers/error_handler.py | 66 ++++++++++++------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/servers/inference_bridge/server/server/features/providers/error_handler.py b/servers/inference_bridge/server/server/features/providers/error_handler.py index 0b466658..eaf2282b 100644 --- a/servers/inference_bridge/server/server/features/providers/error_handler.py +++ b/servers/inference_bridge/server/server/features/providers/error_handler.py @@ -68,35 +68,53 @@ def _generate_error_message(invalid_fields: Set[str]) -> str: return "Invalid provider configuration" +def _get_field_from_error(error: dict) -> str: + """Extract the field name from the error location.""" + return ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else "" + + +def _add_error_type_suggestion(error: dict, field: str, suggestions: list) -> None: + """Add suggestion based on error type.""" + error_type = error.get("type", "") + + if error_type == "missing": + suggestions.append(f"Add the missing required parameter: '{field}'") + elif error_type == "string_type": + suggestions.append(f"Ensure '{field}' is a valid string") + elif error_type == "url_parsing": + suggestions.append(f"Use a valid URL format for '{field}'") + elif error_type and "enum" in error_type: + _add_enum_suggestion(error, field, suggestions) + + +def _add_enum_suggestion(error: dict, field: str, suggestions: list) -> None: + """Add suggestion for enum validation errors.""" + valid_values = error.get("ctx", {}).get("expected", []) + if valid_values: + values_str = ", ".join([f"'{v}'" for v in valid_values]) + suggestions.append(f"Choose a valid option for '{field}': {values_str}") + else: + suggestions.append(f"Choose a valid option for '{field}'") + + +def _add_field_specific_suggestion(field: str, suggestions: list) -> None: + """Add suggestion based on specific field name.""" + if field == "api_key": + suggestions.append("Check the API key is correct for this provider") + elif field == "base_url": + suggestions.append("Verify the base URL format matches the provider's API documentation") + elif field == "environment": + suggestions.append("Valid environment values are typically 'cloud' or 'local'") + + def _generate_suggestions(errors) -> list: """Generate helpful suggestions based on validation errors.""" suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"] for error in errors: - error_type = error.get("type", "") - field = ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else "" - - if error_type == "missing": - suggestions.append(f"Add the missing required parameter: '{field}'") - elif error_type == "string_type": - suggestions.append(f"Ensure '{field}' is a valid string") - elif error_type == "url_parsing": - suggestions.append(f"Use a valid URL format for '{field}'") - elif error_type and "enum" in error_type: - valid_values = error.get("ctx", {}).get("expected", []) - if valid_values: - values_str = ", ".join([f"'{v}'" for v in valid_values]) - suggestions.append(f"Choose a valid option for '{field}': {values_str}") - else: - suggestions.append(f"Choose a valid option for '{field}'") - - # Add provider-specific field suggestions - if field == "api_key": - suggestions.append("Check the API key is correct for this provider") - elif field == "base_url": - suggestions.append("Verify the base URL format matches the provider's API documentation") - elif field == "environment": - suggestions.append("Valid environment values are typically 'cloud' or 'local'") + field = _get_field_from_error(error) + _add_error_type_suggestion(error, field, suggestions) + _add_field_specific_suggestion(field, suggestions) suggestions.append("Check server logs for more details") return list(dict.fromkeys(suggestions)) # Remove duplicates while preserving order From c667123715cc05298da9a000e553972537bbdf38 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 20:12:07 -0500 Subject: [PATCH 47/69] split provider test conn, process single caption Signed-off-by: jphillips --- .../graphcap/perspectives/base_caption.py | 128 +++++++++----- .../graphcap/perspectives/types.py | 2 +- .../server/features/providers/service.py | 159 ++++++++++-------- 3 files changed, 178 insertions(+), 111 deletions(-) diff --git a/servers/inference_bridge/graphcap/perspectives/base_caption.py b/servers/inference_bridge/graphcap/perspectives/base_caption.py index 15a5d761..a2009a5f 100644 --- a/servers/inference_bridge/graphcap/perspectives/base_caption.py +++ b/servers/inference_bridge/graphcap/perspectives/base_caption.py @@ -8,7 +8,7 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast from loguru import logger from pydantic import BaseModel @@ -90,6 +90,71 @@ def _sanitize_json_string(self, text: str) -> str: return result + def _build_prompt_with_context( + self, context: list[str] | None = None, global_context: str | None = None + ) -> str: + """ + Build the prompt with optional context. + + Args: + context: List of context strings + global_context: Global context string + + Returns: + The complete prompt with context if provided + """ + if not context and not global_context: + return self.vision_config.prompt + + context_block = " Consider the following context when generating the caption:\n" + + if global_context: + context_block += f"\n{global_context}\n\n" + + if context: + for entry in context: + context_block += f"\n{entry}\n\n" + + context_block += "\n" + return f"{context_block}{self.vision_config.prompt}" + + def _parse_completion_result(self, completion: Any) -> Dict[str, Any]: + """ + Parse the completion result into a standardized format. + + Args: + completion: The completion response from the vision model + + Returns: + Parsed result as a dictionary + + Raises: + json.JSONDecodeError: If JSON parsing fails + """ + # Handle BaseModel responses through duck typing + if hasattr(completion, 'choices') and hasattr(completion.choices[0], 'message'): + result = completion.choices[0].message.parsed + if hasattr(result, 'model_dump'): + return result.model_dump() + return cast(Dict[str, Any], result) + + result = completion.choices[0].message.parsed + + # Handle string responses + if isinstance(result, str): + sanitized = self._sanitize_json_string(result) + return json.loads(sanitized) + + # Handle nested structure responses + if isinstance(result, dict): + if "choices" in result: + return cast(Dict[str, Any], result["choices"][0]["message"]["parsed"]["parsed"]) + + if "message" in result: + return cast(Dict[str, Any], result["message"]["parsed"]) + + return cast(Dict[str, Any], result) + @abstractmethod def create_rich_table(self, caption_data: Dict[str, Any]) -> Table: """ @@ -114,7 +179,7 @@ async def process_single( repetition_penalty: Optional[float] = 1.15, context: list[str] | None = None, global_context: str | None = None, - ) -> dict: + ) -> Dict[str, Any]: """ Process a single image and return caption data. @@ -125,6 +190,9 @@ async def process_single( max_tokens: Maximum tokens for model response temperature: Sampling temperature top_p: Nucleus sampling parameter + repetition_penalty: Repetition penalty parameter + context: List of context strings + global_context: Global context string Returns: dict: Structured caption data according to schema @@ -132,50 +200,34 @@ async def process_single( Raises: Exception: If image processing fails """ - if context or global_context: - context_block = " Consider the following context when generating the caption:\n" - if global_context: - context_block += f"\n{global_context}\n\n" - if context: - for entry in context: - context_block += f"\n{entry}\n\n" - context_block += "\n" - prompt = f"{context_block}{self.vision_config.prompt}" - else: - prompt = self.vision_config.prompt try: + # Build prompt with context if provided + prompt = self._build_prompt_with_context(context, global_context) + + # Handle optional parameters with defaults + tokens = 4096 if max_tokens is None else max_tokens + temp = 0.8 if temperature is None else temperature + nucleus = 0.9 if top_p is None else top_p + rep_penalty = 1.15 if repetition_penalty is None else repetition_penalty + + # Process image with vision model completion = await provider.vision( prompt=prompt, image=image_path, schema=self.vision_config.schema, model=model, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, + max_tokens=tokens, + temperature=temp, + top_p=nucleus, + repetition_penalty=rep_penalty, ) - # Handle response parsing with sanitization - if isinstance(completion, BaseModel): - result = completion.choices[0].message.parsed - if isinstance(result, BaseModel): - result = result.model_dump() - else: - result = completion.choices[0].message.parsed - # Handle string responses that need parsing - if isinstance(result, str): - sanitized = self._sanitize_json_string(result) - try: - result = json.loads(sanitized) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse sanitized JSON: {e}") - raise - elif "choices" in result: - result = result["choices"][0]["message"]["parsed"]["parsed"] - elif "message" in result: - result = result["message"]["parsed"] - - return result + # Parse the completion result + return self._parse_completion_result(completion) + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON response: {e}") + raise Exception(f"Error parsing response for {image_path}: {str(e)}") except Exception as e: raise Exception(f"Error processing {image_path}: {str(e)}") diff --git a/servers/inference_bridge/graphcap/perspectives/types.py b/servers/inference_bridge/graphcap/perspectives/types.py index 9bc18e70..d41af3f4 100644 --- a/servers/inference_bridge/graphcap/perspectives/types.py +++ b/servers/inference_bridge/graphcap/perspectives/types.py @@ -26,4 +26,4 @@ class StructuredVisionConfig: config_name: str version: str prompt: str - schema: BaseModel + schema: type[BaseModel] diff --git a/servers/inference_bridge/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py index 5c3c204b..d75c8394 100644 --- a/servers/inference_bridge/server/server/features/providers/service.py +++ b/servers/inference_bridge/server/server/features/providers/service.py @@ -6,7 +6,7 @@ """ import datetime -from typing import Any, Dict, Protocol, runtime_checkable +from typing import Any, Dict, Protocol, cast, runtime_checkable from loguru import logger @@ -66,21 +66,9 @@ def create_provider_client_from_config(config: ProviderConfig) -> BaseClient: ) -async def test_provider_connection(provider_name: str, config: ProviderConfig) -> Dict[str, Any]: - """ - Test connection to a provider by initializing the client and performing a simple operation. - - Args: - provider_name: Name of the provider to test - config: Provider configuration for this request - - Returns: - Dictionary containing test results and additional information - - Raises: - Exception: If the connection test fails - """ - result = { +def _create_initial_result(provider_name: str, config: ProviderConfig) -> Dict[str, Any]: + """Create initial result structure for connection test""" + return { "provider": provider_name, "details": {}, "diagnostics": { @@ -95,6 +83,84 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - "warnings": [] } } + + +def _check_configuration_warnings(result: Dict[str, Any], config: ProviderConfig) -> None: + """Check for configuration warnings and add them to the result""" + # Check if an empty API key was provided + if not config.api_key: + result["diagnostics"]["warnings"].append({ + "warning_type": "empty_api_key", + "message": "An empty API key was provided. This might not work with most providers." + }) + + # Check if the base URL seems valid + if not config.base_url.startswith(("http://", "https://")): + result["diagnostics"]["warnings"].append({ + "warning_type": "invalid_base_url", + "message": "The base URL doesn't start with http:// or https://" + }) + + +async def _try_list_models(client: ModelProvider, result: Dict[str, Any]) -> None: + """Attempt to list models from the provider""" + # Add diagnostic step for model list + result["diagnostics"]["connection_steps"].append({ + "step": "list_models", + "status": "pending", + "timestamp": str(datetime.datetime.now()) + }) + + try: + if hasattr(client, "get_available_models"): + provider_models = await client.get_available_models() + result["details"]["method"] = "get_available_models" + + if hasattr(provider_models, "data"): + _extract_models_data(result, provider_models.data) + + elif hasattr(client, "get_models"): + provider_models = await client.get_models() + result["details"]["method"] = "get_models" + + if hasattr(provider_models, "models"): + _extract_models_data(result, provider_models.models) + + # Update diagnostic step + result["diagnostics"]["connection_steps"][-1]["status"] = "success" + + except Exception as e: + logger.warning(f"Could not list models: {str(e)}") + result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" + result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed" + + +def _extract_models_data(result: Dict[str, Any], models_list: Any) -> None: + """Extract model data from provider response""" + models_data = [] + for model in models_list: + model_id = _extract_model_id(model) + models_data.append({"id": model_id}) + + result["details"]["available_models"] = models_data + result["details"]["models_count"] = len(models_data) + + +async def test_provider_connection(provider_name: str, config: ProviderConfig) -> Dict[str, Any]: + """ + Test connection to a provider by initializing the client and performing a simple operation. + + Args: + provider_name: Name of the provider to test + config: Provider configuration for this request + + Returns: + Dictionary containing test results and additional information + + Raises: + Exception: If the connection test fails + """ + result = _create_initial_result(provider_name, config) try: # Add diagnostic step @@ -119,70 +185,19 @@ async def test_provider_connection(provider_name: str, config: ProviderConfig) - result["diagnostics"]["connection_steps"][-1]["status"] = "success" result["client_initialized"] = True - # Check if an empty API key was provided - if not config.api_key: - result["diagnostics"]["warnings"].append({ - "warning_type": "empty_api_key", - "message": "An empty API key was provided. This might not work with most providers." - }) + # Check for configuration warnings + _check_configuration_warnings(result, config) - # Check if the base URL seems valid - if not config.base_url.startswith(("http://", "https://")): - result["diagnostics"]["warnings"].append({ - "warning_type": "invalid_base_url", - "message": "The base URL doesn't start with http:// or https://" - }) + # Try to list models + await _try_list_models(cast(ModelProvider, client), result) - # Try to test the connection with a lightweight operation - # First check if we can get models (most providers support this) - try: - # Add diagnostic step for model list - result["diagnostics"]["connection_steps"].append({ - "step": "list_models", - "status": "pending", - "timestamp": str(datetime.datetime.now()) - }) - - if hasattr(client, "get_available_models"): - provider_models = await client.get_available_models() - result["details"]["method"] = "get_available_models" - - # Add model information if available - if hasattr(provider_models, "data"): - models_data = [] - for model in provider_models.data: - model_id = _extract_model_id(model) - models_data.append({"id": model_id}) - result["details"]["available_models"] = models_data - result["details"]["models_count"] = len(models_data) - - elif hasattr(client, "get_models"): - provider_models = await client.get_models() - result["details"]["method"] = "get_models" - - # Add model information if available - if hasattr(provider_models, "models"): - models_data = [] - for model in provider_models.models: - model_id = _extract_model_id(model) - models_data.append({"id": model_id}) - result["details"]["available_models"] = models_data - result["details"]["models_count"] = len(models_data) - - # Update diagnostic step - result["diagnostics"]["connection_steps"][-1]["status"] = "success" - - except Exception as e: - logger.warning(f"Could not list models for {provider_name}: {str(e)}") - result["diagnostics"]["connection_steps"][-1]["status"] = "skipped" - result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed" - # Connection test successful result["connected"] = True result["success"] = True result["message"] = f"Successfully connected to {provider_name}" return result + except Exception as e: logger.error(f"Error testing connection to {provider_name}: {str(e)}") From 4762eea7df4a4ecc62e76b16c79b6588d845cb75 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 20:23:25 -0500 Subject: [PATCH 48/69] Split provider logic in data service Signed-off-by: jphillips --- .../features/provider_config/controller.ts | 949 ++++++++++-------- .../data_service/src/utils/error-handler.ts | 4 +- .../data_service/src/utils/pino-middleware.ts | 2 +- 3 files changed, 557 insertions(+), 398 deletions(-) diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index 67c0aa2e..18bfb56a 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -7,14 +7,12 @@ import { eq } from "drizzle-orm"; import type { Context } from "hono"; +import type { Logger } from "pino"; import { db } from "../../db"; import { providerModels, providerRateLimits, providers } from "../../db/schema"; import { decryptApiKey, encryptApiKey } from "../../utils/encryption"; import { processApiKeyForUpdate } from "./api-key-manager"; -import type { - ProviderCreate, - ProviderUpdate, -} from "./schemas"; +import type { ProviderCreate, ProviderUpdate } from "./schemas"; // Type for the validated parameters type ValidatedParams = { @@ -39,17 +37,26 @@ export const getProviders = async (c: Context) => { // Decrypt API keys before returning to client for (const provider of allProviders) { if (provider.apiKey) { - logger.debug({ providerId: provider.id }, "Decrypting API key for provider"); + logger.debug( + { providerId: provider.id }, + "Decrypting API key for provider", + ); provider.apiKey = await decryptApiKey(provider.apiKey); - + // Log whether API key is present after decryption (without showing the actual key) - logger.debug({ - providerId: provider.id, - apiKeyPresent: Boolean(provider.apiKey), - apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 - }, "Provider API key decryption result"); + logger.debug( + { + providerId: provider.id, + apiKeyPresent: Boolean(provider.apiKey), + apiKeyLength: provider.apiKey ? provider.apiKey.length : 0, + }, + "Provider API key decryption result", + ); } else { - logger.debug({ providerId: provider.id }, "No API key to decrypt for provider"); + logger.debug( + { providerId: provider.id }, + "No API key to decrypt for provider", + ); } } @@ -86,7 +93,7 @@ export const getProviders = async (c: Context) => { */ export const getProvider = async (c: Context) => { const { logger } = c.var; - const id = c.req.param('id'); + const id = c.req.param("id"); logger.debug({ id }, "Fetching provider by ID"); try { @@ -108,19 +115,25 @@ export const getProvider = async (c: Context) => { // Decrypt API key before returning to client if (provider.apiKey) { - logger.debug({ - providerId: id, - encryptedKeyLength: provider.apiKey.length - }, "Decrypting API key for provider"); - + logger.debug( + { + providerId: id, + encryptedKeyLength: provider.apiKey.length, + }, + "Decrypting API key for provider", + ); + provider.apiKey = await decryptApiKey(provider.apiKey); - + // Log the result of decryption (without showing the actual key) - logger.debug({ - providerId: id, - apiKeyPresent: Boolean(provider.apiKey), - apiKeyLength: provider.apiKey ? provider.apiKey.length : 0 - }, "Provider API key decryption result"); + logger.debug( + { + providerId: id, + apiKeyPresent: Boolean(provider.apiKey), + apiKeyLength: provider.apiKey ? provider.apiKey.length : 0, + }, + "Provider API key decryption result", + ); } else { logger.debug({ providerId: id }, "No API key to decrypt for provider"); } @@ -128,7 +141,7 @@ export const getProvider = async (c: Context) => { logger.info({ providerId: id }, "Provider fetched successfully"); return c.json(provider); } - + // If ID mismatch, return not found (removed else clause) logger.warn({ providerId: id }, "Provider not found"); return c.json({ error: "Provider not found" }, 404); @@ -138,63 +151,214 @@ export const getProvider = async (c: Context) => { } }; +/** + * Validates provider data during creation + */ +const validateProviderCreate = ( + data: ProviderCreate, +): Record => { + const validationErrors: Record = {}; + + // Name validation + if (!data.name) { + validationErrors.name = ["Provider name is required"]; + } else if (data.name.trim() === "") { + validationErrors.name = ["Provider name cannot be just whitespace"]; + } else if (data.name.length < 3) { + validationErrors.name = [ + "Provider name must be at least 3 characters long", + ]; + } + + // Kind validation + if (!data.kind) { + validationErrors.kind = ["Provider kind is required"]; + } else if (data.kind.trim() === "") { + validationErrors.kind = ["Provider kind cannot be just whitespace"]; + } + + // Base URL validation + if (!data.baseUrl) { + validationErrors.baseUrl = ["Base URL is required"]; + } else { + try { + new URL(data.baseUrl); + } catch (e) { + validationErrors.baseUrl = ["Base URL must be a valid URL"]; + } + } + + // Environment validation + if (!data.environment) { + validationErrors.environment = ["Environment is required"]; + } else if (!["cloud", "local"].includes(data.environment)) { + validationErrors.environment = [ + 'Environment must be either "cloud" or "local"', + ]; + } + + return validationErrors; +}; + +/** + * Handles specific error cases for provider creation + */ +const handleProviderCreateError = (c: Context, error: unknown) => { + const logger = c.var.logger; + + logger.error( + { + error, + message: error instanceof Error ? error.message : "Unknown error", + stack: error instanceof Error ? error.stack : undefined, + }, + "Error creating provider", + ); + + // Check for specific error types to provide better error messages + if (error instanceof Error) { + // Handle database unique constraint violation + if ( + error.message.includes("duplicate key value violates unique constraint") + ) { + return c.json( + { + status: "error", + statusCode: 400, + message: "A provider with that name already exists", + details: { + type: "UniqueConstraintViolation", + }, + }, + 400, + ); + } + + // Handle other database errors + if (error.message.includes("database") || error.message.includes("query")) { + return c.json( + { + status: "error", + statusCode: 500, + message: "Database error occurred while creating provider", + details: { + type: "DatabaseError", + }, + }, + 500, + ); + } + + // Handle validation errors from Zod or other validators + if (error.message.includes("validation")) { + return c.json( + { + status: "error", + statusCode: 400, + message: "Validation error", + details: { + message: error.message, + }, + }, + 400, + ); + } + } + + // Generic error fallback + return c.json( + { + status: "error", + statusCode: 500, + message: "Failed to create provider", + details: error instanceof Error ? { message: error.message } : undefined, + }, + 500, + ); +}; + +/** + * Creates provider data in the database + */ +const saveProviderToDatabase = async ( + tx: typeof db, + providerData: Omit, + models?: ProviderCreate["models"], + rateLimits?: ProviderCreate["rateLimits"], +) => { + // Insert provider + const [provider] = await tx + .insert(providers) + .values({ + ...providerData, + createdAt: new Date(), + updatedAt: new Date(), + }) + .returning(); + + // Insert models if provided + if (models && models.length > 0) { + await tx.insert(providerModels).values( + models.map((model) => ({ + providerId: provider.id, + name: model.name, + isEnabled: model.isEnabled ?? true, + createdAt: new Date(), + updatedAt: new Date(), + })), + ); + } + + // Insert rate limits if provided + if (rateLimits) { + await tx.insert(providerRateLimits).values({ + providerId: provider.id, + requestsPerMinute: rateLimits.requestsPerMinute, + tokensPerMinute: rateLimits.tokensPerMinute, + createdAt: new Date(), + updatedAt: new Date(), + }); + } + + // Return the created provider with relations + return await tx.query.providers.findFirst({ + where: eq(providers.id, provider.id), + with: { + models: true, + rateLimits: true, + }, + }); +}; + /** * Create a new provider */ export const createProvider = async (c: Context) => { const { logger } = c.var; - + try { // @ts-ignore - Hono OpenAPI validation types are not properly recognized const data = c.req.valid("json") as ProviderCreate; logger.debug({ data }, "Creating new provider"); - // Enhanced validation with detailed error messages - const validationErrors: Record = {}; - - // Name validation - if (!data.name) { - validationErrors.name = ['Provider name is required']; - } else if (data.name.trim() === '') { - validationErrors.name = ['Provider name cannot be just whitespace']; - } else if (data.name.length < 3) { - validationErrors.name = ['Provider name must be at least 3 characters long']; - } - - // Kind validation - if (!data.kind) { - validationErrors.kind = ['Provider kind is required']; - } else if (data.kind.trim() === '') { - validationErrors.kind = ['Provider kind cannot be just whitespace']; - } - - // Base URL validation - if (!data.baseUrl) { - validationErrors.baseUrl = ['Base URL is required']; - } else { - try { - new URL(data.baseUrl); - } catch (e) { - validationErrors.baseUrl = ['Base URL must be a valid URL']; - } - } - - // Environment validation - if (!data.environment) { - validationErrors.environment = ['Environment is required']; - } else if (!['cloud', 'local'].includes(data.environment)) { - validationErrors.environment = ['Environment must be either "cloud" or "local"']; - } - + // Validate the provider data + const validationErrors = validateProviderCreate(data); + // If there are validation errors, return them if (Object.keys(validationErrors).length > 0) { - logger.debug({ validationErrors }, "Validation errors in provider creation"); - return c.json({ - status: "error", - statusCode: 400, - message: "Validation failed", - validationErrors - }, 400); + logger.debug( + { validationErrors }, + "Validation errors in provider creation", + ); + return c.json( + { + status: "error", + statusCode: 400, + message: "Validation failed", + validationErrors, + }, + 400, + ); } // Extract models and rate limits if provided @@ -207,112 +371,245 @@ export const createProvider = async (c: Context) => { // Start a transaction const result = await db.transaction(async (tx) => { - // Insert provider - const [provider] = await tx - .insert(providers) - .values({ - ...providerData, - createdAt: new Date(), - updatedAt: new Date(), - }) - .returning(); - - // Insert models if provided - if (models && models.length > 0) { - await tx.insert(providerModels).values( - models.map((model) => ({ - providerId: provider.id, - name: model.name, - isEnabled: model.isEnabled, - createdAt: new Date(), - updatedAt: new Date(), - })), + return saveProviderToDatabase(tx, providerData, models, rateLimits); + }); + + logger.info( + { + provider: { + id: result?.id, + name: result?.name, + kind: result?.kind, + }, + }, + "Provider created successfully", + ); + return c.json(result, 201); + } catch (error) { + return handleProviderCreateError(c, error); + } +}; + +/** + * Validates provider data during update + */ +const validateProviderUpdate = ( +): Record => { + const validationErrors: Record = {}; + + // Add validation logic here if needed + + return validationErrors; +}; + +/** + * Logs field changes between existing provider and update data + */ +const logFieldChanges = ( + logger: Logger, + id: string, + existingProvider: Record, + providerData: Partial +): Record => { + const updatedFields: Record = {}; + + // Compare each field being updated with existing values + for (const [key, value] of Object.entries(providerData)) { + const existingValue = (existingProvider as Record)[key]; + // Only log if the value is actually changing + if (existingValue !== value && value !== undefined) { + // Special handling for API key to avoid logging actual values + if (key === "apiKey") { + updatedFields[key] = { + from: existingValue ? "[ENCRYPTED]" : "[EMPTY]", + to: value ? "[NEW_VALUE]" : "[EMPTY]", + }; + logger.info( + { providerId: id }, + `Updating API key from ${existingValue ? "existing value" : "empty"} to ${value ? "new value" : "empty"}`, ); + } else { + updatedFields[key] = { from: existingValue, to: value }; } + } + } - // Insert rate limits if provided - if (rateLimits) { - await tx.insert(providerRateLimits).values({ - providerId: provider.id, - requestsPerMinute: rateLimits.requestsPerMinute, - tokensPerMinute: rateLimits.tokensPerMinute, - createdAt: new Date(), - updatedAt: new Date(), - }); + // Log all field changes + if (Object.keys(updatedFields).length > 0) { + logger.info( + { + providerId: id, + provider: existingProvider.name, + updatedFields, + }, + "Provider fields being updated", + ); + } + + return updatedFields; +}; + +/** + * Processes model updates + */ +const processModelUpdates = async ( + tx: typeof db, + id: string, + models: ProviderUpdate["models"], +) => { + if (!models || models.length === 0) return; + + // First, delete existing models + await tx + .delete(providerModels) + .where(eq(providerModels.providerId, Number.parseInt(id))); + + // Then insert new models + await tx.insert(providerModels).values( + models.map((model) => { + // Create base model data object + const modelData = { + providerId: Number.parseInt(id), + name: model.name, + isEnabled: model.isEnabled ?? true, + createdAt: new Date(), + updatedAt: new Date(), + }; + + // Only include ID if it exists and is a number + if (model.id !== undefined && typeof model.id === "number") { + return { + ...modelData, + id: model.id, + }; } - // Return the created provider with relations - return await tx.query.providers.findFirst({ - where: eq(providers.id, provider.id), - with: { - models: true, - rateLimits: true, - }, - }); + // Let database auto-generate ID + return modelData; + }), + ); +}; + +/** + * Processes rate limit updates + */ +const processRateLimitUpdates = async ( + tx: typeof db, + id: string, + rateLimits: ProviderUpdate["rateLimits"], +) => { + if (!rateLimits) return; + + // Check if rate limits exist + const existingRateLimits = await tx.query.providerRateLimits.findFirst({ + where: eq(providerRateLimits.providerId, Number.parseInt(id)), + }); + + if (existingRateLimits) { + // Update existing rate limits + await tx + .update(providerRateLimits) + .set({ + requestsPerMinute: rateLimits.requestsPerMinute, + tokensPerMinute: rateLimits.tokensPerMinute, + updatedAt: new Date(), + }) + .where(eq(providerRateLimits.providerId, Number.parseInt(id))); + } else { + // Insert new rate limits + await tx.insert(providerRateLimits).values({ + providerId: Number.parseInt(id), + requestsPerMinute: rateLimits.requestsPerMinute, + tokensPerMinute: rateLimits.tokensPerMinute, + createdAt: new Date(), + updatedAt: new Date(), }); + } +}; - logger.info({ - provider: { - id: result?.id, - name: result?.name, - kind: result?.kind - } - }, "Provider created successfully"); - return c.json(result, 201); - } catch (error) { - logger.error({ +/** + * Handles database updates for a provider + */ +const updateProviderInDatabase = async ( + tx: typeof db, + id: string, + providerData: Partial, + models?: ProviderUpdate["models"], + rateLimits?: ProviderUpdate["rateLimits"], +) => { + // Get the current provider from the database to ensure we have the latest data + const currentProvider = await tx.query.providers.findFirst({ + where: eq(providers.id, Number.parseInt(id)), + }); + + if (!currentProvider) { + throw new Error(`Provider not found with id ${id}`); + } + + // Use the API key manager to handle API key updates + const apiKeyToUse = await processApiKeyForUpdate( + currentProvider, + providerData.apiKey, + ); + + // Update provider with the appropriate API key + await tx + .update(providers) + .set({ + ...providerData, + apiKey: apiKeyToUse, // Use the properly determined API key + updatedAt: new Date(), + }) + .where(eq(providers.id, Number.parseInt(id))); + + // Update models if provided + if (models && models.length > 0) { + await processModelUpdates(tx, id, models); + } + + // Update rate limits if provided + if (rateLimits) { + await processRateLimitUpdates(tx, id, rateLimits); + } + + // Return the updated provider with relations + return await tx.query.providers.findFirst({ + where: eq(providers.id, Number.parseInt(id)), + with: { + models: true, + rateLimits: true, + }, + }); +}; + +/** + * Handles specific error cases for provider updates + */ +const handleProviderUpdateError = (c: Context, error: unknown) => { + const logger = c.var.logger; + const id = c.req.param("id"); + + logger.error( + { error, message: error instanceof Error ? error.message : "Unknown error", - stack: error instanceof Error ? error.stack : undefined - }, "Error creating provider"); - - // Check for specific error types to provide better error messages - if (error instanceof Error) { - // Handle database unique constraint violation - if (error.message.includes('duplicate key value violates unique constraint')) { - return c.json({ - status: "error", - statusCode: 400, - message: "A provider with that name already exists", - details: { - type: "UniqueConstraintViolation" - } - }, 400); - } - - // Handle other database errors - if (error.message.includes('database') || error.message.includes('query')) { - return c.json({ - status: "error", - statusCode: 500, - message: "Database error occurred while creating provider", - details: { - type: "DatabaseError" - } - }, 500); - } - - // Handle validation errors from Zod or other validators - if (error.message.includes('validation')) { - return c.json({ - status: "error", - statusCode: 400, - message: "Validation error", - details: { - message: error.message - } - }, 400); - } - } - - // Generic error fallback - return c.json({ + stack: error instanceof Error ? error.stack : undefined, + providerId: id, + }, + "Error updating provider", + ); + + // Return error response + return c.json( + { status: "error", statusCode: 500, - message: "Failed to create provider", - details: error instanceof Error ? { message: error.message } : undefined - }, 500); - } + message: + error instanceof Error ? error.message : "Failed to update provider", + errorType: error instanceof Error ? error.name : "Unknown", + }, + 500, + ); }; /** @@ -320,18 +617,21 @@ export const createProvider = async (c: Context) => { */ export const updateProvider = async (c: Context) => { const { logger } = c.var; - const id = c.req.param('id'); - + const id = c.req.param("id"); + try { // @ts-ignore - Hono OpenAPI validation types are not properly recognized const data = c.req.valid("json") as ProviderUpdate; - logger.debug({ - id, - data: { - ...data, - apiKey: data.apiKey !== undefined ? '[PRESENT]' : '[MISSING]' - } - }, "Updating provider"); + logger.debug( + { + id, + data: { + ...data, + apiKey: data.apiKey !== undefined ? "[PRESENT]" : "[MISSING]", + }, + }, + "Updating provider", + ); // Check if provider exists const existingProvider = await db.query.providers.findFirst({ @@ -344,246 +644,105 @@ export const updateProvider = async (c: Context) => { if (!existingProvider) { logger.debug({ id }, "Provider not found for update"); - return c.json({ - status: "error", - statusCode: 404, - message: "Provider not found", - providerId: id - }, 404); + return c.json( + { + status: "error", + statusCode: 404, + message: "Provider not found", + providerId: id, + }, + 404, + ); } - // Enhanced validation with detailed error messages - const validationErrors: Record = {}; - - + // Validate update data + const validationErrors = validateProviderUpdate(data); + // If there are validation errors, return them if (Object.keys(validationErrors).length > 0) { - logger.debug({ validationErrors }, "Validation errors in provider update"); - return c.json({ - status: "error", - statusCode: 400, - message: "Validation failed", - providerId: id, - validationErrors - }, 400); + logger.debug( + { validationErrors }, + "Validation errors in provider update", + ); + return c.json( + { + status: "error", + statusCode: 400, + message: "Validation failed", + providerId: id, + validationErrors, + }, + 400, + ); } // Extract models and rate limits if provided const { models, rateLimits, ...providerData } = data; - // LOG API KEY STATUS FOR DEBUGGING - logger.debug({ - providerId: id, - original_apiKey_present: existingProvider.apiKey !== null, - update_apiKey_present: 'apiKey' in providerData, - update_apiKey_value_present: providerData.apiKey !== undefined && providerData.apiKey !== null - }, "API key update status"); - - // Log what fields are being updated - const updatedFields: Record = {}; - - // Compare each field being updated with existing values - for (const [key, value] of Object.entries(providerData)) { - const existingValue = (existingProvider as Record)[key]; - // Only log if the value is actually changing - if (existingValue !== value && value !== undefined) { - // Special handling for API key to avoid logging actual values - if (key === 'apiKey') { - updatedFields[key] = { - from: existingValue ? '[ENCRYPTED]' : '[EMPTY]', - to: value ? '[NEW_VALUE]' : '[EMPTY]' - }; - logger.info( - { providerId: id }, - `Updating API key from ${existingValue ? 'existing value' : 'empty'} to ${value ? 'new value' : 'empty'}` - ); - } else { - updatedFields[key] = { from: existingValue, to: value }; - } - } - } - - // Log all field changes - if (Object.keys(updatedFields).length > 0) { - logger.info({ - providerId: id, - provider: existingProvider.name, - updatedFields - }, "Provider fields being updated"); - } - - // Log model changes if applicable + // Log API key status for debugging + logger.debug( + { + providerId: id, + original_apiKey_present: existingProvider.apiKey !== null, + update_apiKey_present: "apiKey" in providerData, + update_apiKey_value_present: + providerData.apiKey !== undefined && providerData.apiKey !== null, + }, + "API key update status", + ); + + // Log field changes + logFieldChanges(logger, id, existingProvider, providerData); + + // Log model and rate limit changes if (models && models.length > 0) { - // First, delete existing models - await db - .delete(providerModels) - .where(eq(providerModels.providerId, Number.parseInt(id))); - - // Then insert new models - handle both full model objects and simple name+isEnabled objects - await db.insert(providerModels).values( - models.map((model) => { - // Create base model data object - const modelData = { - providerId: Number.parseInt(id), - name: model.name, - isEnabled: model.isEnabled ?? true, - createdAt: new Date(), - updatedAt: new Date(), - }; - - // Only include ID if it exists and is a number - if (model.id !== undefined && typeof model.id === 'number') { - return { - ...modelData, - id: model.id - }; - } - - // Let database auto-generate ID - return modelData; - }), + logger.info( + { providerId: id, modelCount: models.length }, + "Updating provider models", ); } - - // Log rate limit changes if applicable + if (rateLimits) { // Query for existing rate limits - const existingRateLimitsQuery = await db.query.providerRateLimits.findFirst({ - where: eq(providerRateLimits.providerId, Number.parseInt(id)) - }); - - logger.info({ - providerId: id, - provider: existingProvider.name, - existingRateLimits: existingRateLimitsQuery - ? { - requestsPerMinute: existingRateLimitsQuery.requestsPerMinute, - tokensPerMinute: existingRateLimitsQuery.tokensPerMinute - } - : { requestsPerMinute: null, tokensPerMinute: null }, - newRateLimits: rateLimits - }, "Updating provider rate limits"); - } - - // Start a transaction - const result = await db.transaction(async (tx) => { - // Get the current provider from the database to ensure we have the latest data - const currentProvider = await tx.query.providers.findFirst({ - where: eq(providers.id, Number.parseInt(id)), - }); - - if (!currentProvider) { - throw new Error(`Provider not found with id ${id}`); - } - - // Use the API key manager to handle API key updates - const apiKeyToUse = await processApiKeyForUpdate(currentProvider, providerData.apiKey); - - // Update provider with the appropriate API key - await tx - .update(providers) - .set({ - ...providerData, - apiKey: apiKeyToUse, // Use the properly determined API key - updatedAt: new Date(), - }) - .where(eq(providers.id, Number.parseInt(id))); - - // Update models if provided - if (models && models.length > 0) { - // First, delete existing models - await tx - .delete(providerModels) - .where(eq(providerModels.providerId, Number.parseInt(id))); - - // Then insert new models - handle both full model objects and simple name+isEnabled objects - await tx.insert(providerModels).values( - models.map((model) => { - // Create base model data object - const modelData = { - providerId: Number.parseInt(id), - name: model.name, - isEnabled: model.isEnabled ?? true, - createdAt: new Date(), - updatedAt: new Date(), - }; - - // Only include ID if it exists and is a number - if (model.id !== undefined && typeof model.id === 'number') { - return { - ...modelData, - id: model.id - }; - } - - // Let database auto-generate ID - return modelData; - }), - ); - } - - // Update rate limits if provided - if (rateLimits) { - // Check if rate limits exist - const existingRateLimits = await tx.query.providerRateLimits.findFirst({ + const existingRateLimitsQuery = + await db.query.providerRateLimits.findFirst({ where: eq(providerRateLimits.providerId, Number.parseInt(id)), }); - if (existingRateLimits) { - // Update existing rate limits - await tx - .update(providerRateLimits) - .set({ - requestsPerMinute: rateLimits.requestsPerMinute, - tokensPerMinute: rateLimits.tokensPerMinute, - updatedAt: new Date(), - }) - .where(eq(providerRateLimits.providerId, Number.parseInt(id))); - } else { - // Insert new rate limits - await tx.insert(providerRateLimits).values({ - providerId: Number.parseInt(id), - requestsPerMinute: rateLimits.requestsPerMinute, - tokensPerMinute: rateLimits.tokensPerMinute, - createdAt: new Date(), - updatedAt: new Date(), - }); - } - } - - // Return the updated provider with relations - return await tx.query.providers.findFirst({ - where: eq(providers.id, Number.parseInt(id)), - with: { - models: true, - rateLimits: true, + logger.info( + { + providerId: id, + provider: existingProvider.name, + existingRateLimits: existingRateLimitsQuery + ? { + requestsPerMinute: existingRateLimitsQuery.requestsPerMinute, + tokensPerMinute: existingRateLimitsQuery.tokensPerMinute, + } + : { requestsPerMinute: null, tokensPerMinute: null }, + newRateLimits: rateLimits, }, - }); + "Updating provider rate limits", + ); + } + + // Update the provider in the database + const result = await db.transaction(async (tx) => { + return updateProviderInDatabase(tx, id, providerData, models, rateLimits); }); - logger.info({ - providerId: id, - provider: { - name: result?.name, - kind: result?.kind - } - }, "Provider updated successfully"); + logger.info( + { + providerId: id, + provider: { + name: result?.name, + kind: result?.kind, + }, + }, + "Provider updated successfully", + ); return c.json(result); } catch (error) { - logger.error({ - error, - message: error instanceof Error ? error.message : "Unknown error", - stack: error instanceof Error ? error.stack : undefined, - providerId: c.req.param('id') - }, "Error updating provider"); - - // Return error response - return c.json({ - status: "error", - statusCode: 500, - message: error instanceof Error ? error.message : "Failed to update provider", - errorType: error instanceof Error ? error.name : 'Unknown' - }, 500); + return handleProviderUpdateError(c, error); } }; @@ -592,10 +751,10 @@ export const updateProvider = async (c: Context) => { */ export const deleteProvider = async (c: Context) => { const { logger } = c.var; - const id = c.req.param('id'); - + const id = c.req.param("id"); + logger.debug({ id }, "Deleting provider"); - + try { // Check if provider exists const existingProvider = await db.query.providers.findFirst({ diff --git a/servers/data_service/src/utils/error-handler.ts b/servers/data_service/src/utils/error-handler.ts index 726e63a3..42f60692 100644 --- a/servers/data_service/src/utils/error-handler.ts +++ b/servers/data_service/src/utils/error-handler.ts @@ -44,13 +44,13 @@ export function createErrorResponse( export function handleValidationError(error: ZodError, c: Context): Response { const validationErrors: Record = {}; - error.errors.forEach((err) => { + for (const err of error.errors) { const path = err.path.join("."); if (!validationErrors[path]) { validationErrors[path] = []; } validationErrors[path].push(err.message); - }); + } const response = createErrorResponse( "Validation error", diff --git a/servers/data_service/src/utils/pino-middleware.ts b/servers/data_service/src/utils/pino-middleware.ts index 0f615716..d47b6965 100644 --- a/servers/data_service/src/utils/pino-middleware.ts +++ b/servers/data_service/src/utils/pino-middleware.ts @@ -91,7 +91,7 @@ const getRequestBody = async (c: Context, method: string): Promise<[unknown, boo // Check if the request can be cloned if (c.req.raw.clone && typeof c.req.raw.clone === 'function') { const clonedReq = c.req.raw.clone(); - const contentType = c.req.header("content-type") || ""; + const contentType = c.req.header("content-type") ?? ""; const body = await parseRequestBody(clonedReq, contentType); return [body, true]; } From 228dffa2f90751a1b7f44768931d7eba3387e747 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 20:32:55 -0500 Subject: [PATCH 49/69] Decompose provider config crud Signed-off-by: jphillips --- .../features/provider_config/controller.ts | 282 ++++++++++++------ 1 file changed, 199 insertions(+), 83 deletions(-) diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts index 18bfb56a..ac07f1a5 100644 --- a/servers/data_service/src/features/provider_config/controller.ts +++ b/servers/data_service/src/features/provider_config/controller.ts @@ -12,7 +12,9 @@ import { db } from "../../db"; import { providerModels, providerRateLimits, providers } from "../../db/schema"; import { decryptApiKey, encryptApiKey } from "../../utils/encryption"; import { processApiKeyForUpdate } from "./api-key-manager"; -import type { ProviderCreate, ProviderUpdate } from "./schemas"; +import type { Provider, ProviderCreate, ProviderUpdate } from "./schemas"; + + // Type for the validated parameters type ValidatedParams = { @@ -402,6 +404,62 @@ const validateProviderUpdate = ( return validationErrors; }; +/** + * Checks if a value has changed + */ +const hasValueChanged = (existingValue: unknown, newValue: unknown): boolean => { + return existingValue !== newValue && newValue !== undefined; +}; + +/** + * Creates a log entry for API key changes + */ +const createApiKeyLogEntry = ( + existingValue: unknown, + value: unknown +): { from: unknown; to: unknown } => { + return { + from: existingValue ? "[ENCRYPTED]" : "[EMPTY]", + to: value ? "[NEW_VALUE]" : "[EMPTY]", + }; +}; + +/** + * Logs an API key change + */ +const logApiKeyChange = ( + logger: Logger, + id: string, + existingValue: unknown, + value: unknown +): void => { + logger.info( + { providerId: id }, + `Updating API key from ${existingValue ? "existing value" : "empty"} to ${value ? "new value" : "empty"}`, + ); +}; + +/** + * Logs all field changes + */ +const logAllFieldChanges = ( + logger: Logger, + id: string, + existingProvider: Record, + updatedFields: Record +): void => { + if (Object.keys(updatedFields).length > 0) { + logger.info( + { + providerId: id, + provider: existingProvider.name, + updatedFields, + }, + "Provider fields being updated", + ); + } +}; + /** * Logs field changes between existing provider and update data */ @@ -416,18 +474,13 @@ const logFieldChanges = ( // Compare each field being updated with existing values for (const [key, value] of Object.entries(providerData)) { const existingValue = (existingProvider as Record)[key]; - // Only log if the value is actually changing - if (existingValue !== value && value !== undefined) { + + // Only process if the value is actually changing + if (hasValueChanged(existingValue, value)) { // Special handling for API key to avoid logging actual values if (key === "apiKey") { - updatedFields[key] = { - from: existingValue ? "[ENCRYPTED]" : "[EMPTY]", - to: value ? "[NEW_VALUE]" : "[EMPTY]", - }; - logger.info( - { providerId: id }, - `Updating API key from ${existingValue ? "existing value" : "empty"} to ${value ? "new value" : "empty"}`, - ); + updatedFields[key] = createApiKeyLogEntry(existingValue, value); + logApiKeyChange(logger, id, existingValue, value); } else { updatedFields[key] = { from: existingValue, to: value }; } @@ -435,16 +488,7 @@ const logFieldChanges = ( } // Log all field changes - if (Object.keys(updatedFields).length > 0) { - logger.info( - { - providerId: id, - provider: existingProvider.name, - updatedFields, - }, - "Provider fields being updated", - ); - } + logAllFieldChanges(logger, id, existingProvider, updatedFields); return updatedFields; }; @@ -536,7 +580,7 @@ const updateProviderInDatabase = async ( providerData: Partial, models?: ProviderUpdate["models"], rateLimits?: ProviderUpdate["rateLimits"], -) => { +): Promise => { // Get the current provider from the database to ensure we have the latest data const currentProvider = await tx.query.providers.findFirst({ where: eq(providers.id, Number.parseInt(id)), @@ -573,13 +617,16 @@ const updateProviderInDatabase = async ( } // Return the updated provider with relations - return await tx.query.providers.findFirst({ + const result = await tx.query.providers.findFirst({ where: eq(providers.id, Number.parseInt(id)), with: { models: true, rateLimits: true, }, }); + + // Cast to ensure type safety + return result as Provider | null; }; /** @@ -612,6 +659,122 @@ const handleProviderUpdateError = (c: Context, error: unknown) => { ); }; +/** + * Checks if a provider exists + */ +const checkProviderExists = async (id: string): Promise => { + const provider = await db.query.providers.findFirst({ + where: eq(providers.id, Number.parseInt(id)), + }); + return !!provider; +}; + +/** + * Fetches existing provider with models and rate limits + */ +const fetchExistingProvider = async (id: string): Promise => { + const provider = await db.query.providers.findFirst({ + where: eq(providers.id, Number.parseInt(id)), + with: { + models: true, + rateLimits: true, + }, + }); + + // Cast to ensure type safety + return provider as Provider | null; +}; + +/** + * Logs API key status for debugging + */ +const logApiKeyStatus = ( + logger: Logger, + id: string, + existingProvider: Provider, + providerData: Partial +): void => { + logger.debug( + { + providerId: id, + original_apiKey_present: existingProvider.apiKey !== null, + update_apiKey_present: "apiKey" in providerData, + update_apiKey_value_present: + providerData.apiKey !== undefined && providerData.apiKey !== null, + }, + "API key update status", + ); +}; + +/** + * Logs model updates + */ +const logModelUpdates = ( + logger: Logger, + id: string, + models: ProviderUpdate["models"] +): void => { + if (models && models.length > 0) { + logger.info( + { providerId: id, modelCount: models.length }, + "Updating provider models", + ); + } +}; + +/** + * Fetches and logs rate limit information + */ +const fetchAndLogRateLimits = async ( + logger: Logger, + id: string, + existingProvider: Provider, + rateLimits: ProviderUpdate["rateLimits"] +): Promise => { + if (!rateLimits) return; + + // Query for existing rate limits + const existingRateLimitsQuery = + await db.query.providerRateLimits.findFirst({ + where: eq(providerRateLimits.providerId, Number.parseInt(id)), + }); + + logger.info( + { + providerId: id, + provider: existingProvider.name, + existingRateLimits: existingRateLimitsQuery + ? { + requestsPerMinute: existingRateLimitsQuery.requestsPerMinute, + tokensPerMinute: existingRateLimitsQuery.tokensPerMinute, + } + : { requestsPerMinute: null, tokensPerMinute: null }, + newRateLimits: rateLimits, + }, + "Updating provider rate limits", + ); +}; + +/** + * Log the result of a successful provider update + */ +const logSuccessfulUpdate = ( + logger: Logger, + id: string, + result: Provider | null +): void => { + logger.info( + { + providerId: id, + provider: { + name: result?.name, + kind: result?.kind, + }, + }, + "Provider updated successfully", + ); +}; + /** * Update an existing provider */ @@ -634,13 +797,7 @@ export const updateProvider = async (c: Context) => { ); // Check if provider exists - const existingProvider = await db.query.providers.findFirst({ - where: eq(providers.id, Number.parseInt(id)), - with: { - models: true, - rateLimits: true, - }, - }); + const existingProvider = await fetchExistingProvider(id); if (!existingProvider) { logger.debug({ id }, "Provider not found for update"); @@ -680,66 +837,27 @@ export const updateProvider = async (c: Context) => { const { models, rateLimits, ...providerData } = data; // Log API key status for debugging - logger.debug( - { - providerId: id, - original_apiKey_present: existingProvider.apiKey !== null, - update_apiKey_present: "apiKey" in providerData, - update_apiKey_value_present: - providerData.apiKey !== undefined && providerData.apiKey !== null, - }, - "API key update status", - ); + logApiKeyStatus(logger, id, existingProvider, providerData); // Log field changes logFieldChanges(logger, id, existingProvider, providerData); // Log model and rate limit changes - if (models && models.length > 0) { - logger.info( - { providerId: id, modelCount: models.length }, - "Updating provider models", - ); - } - - if (rateLimits) { - // Query for existing rate limits - const existingRateLimitsQuery = - await db.query.providerRateLimits.findFirst({ - where: eq(providerRateLimits.providerId, Number.parseInt(id)), - }); - - logger.info( - { - providerId: id, - provider: existingProvider.name, - existingRateLimits: existingRateLimitsQuery - ? { - requestsPerMinute: existingRateLimitsQuery.requestsPerMinute, - tokensPerMinute: existingRateLimitsQuery.tokensPerMinute, - } - : { requestsPerMinute: null, tokensPerMinute: null }, - newRateLimits: rateLimits, - }, - "Updating provider rate limits", - ); - } + logModelUpdates(logger, id, models); + + // Log rate limit changes if provided + await fetchAndLogRateLimits(logger, id, existingProvider, rateLimits); // Update the provider in the database const result = await db.transaction(async (tx) => { return updateProviderInDatabase(tx, id, providerData, models, rateLimits); }); - logger.info( - { - providerId: id, - provider: { - name: result?.name, - kind: result?.kind, - }, - }, - "Provider updated successfully", - ); + // Log successful update (only if result is not null) + if (result) { + logSuccessfulUpdate(logger, id, result); + } + return c.json(result); } catch (error) { return handleProviderUpdateError(c, error); @@ -757,11 +875,9 @@ export const deleteProvider = async (c: Context) => { try { // Check if provider exists - const existingProvider = await db.query.providers.findFirst({ - where: eq(providers.id, Number.parseInt(id)), - }); + const providerExists = await checkProviderExists(id); - if (!existingProvider) { + if (!providerExists) { logger.debug({ id }, "Provider not found for deletion"); return c.json({ error: "Provider not found" }, 404); } From 730f7ced1f3e8c55dd34e7ce251b02cdf5710354 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sat, 29 Mar 2025 21:04:38 -0500 Subject: [PATCH 50/69] Performance and Static Analysis tweaks Signed-off-by: jphillips --- .../ProviderConnectionSuccessDialog.tsx | 29 ++++++++------ .../actions/ProviderModelActions.tsx | 10 +++-- .../inference/providers/ProvidersPanel.tsx | 1 - .../context/PerspectivesDataContext.tsx | 38 ++++++++++++++++--- .../graphcap/perspectives/__init__.py | 7 +--- .../graphcap/providers/factory.py | 4 +- .../pipelines/huggingface/dataset_export.py | 6 +-- 7 files changed, 63 insertions(+), 32 deletions(-) diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx index 4c8b1390..e3977c51 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionSuccessDialog.tsx @@ -81,14 +81,19 @@ export function ProviderConnectionSuccessDialog({ } const { result } = connectionDetails; + + if (typeof result === 'boolean') { + return null; + } + const steps = result.diagnostics.connection_steps; const warnings = result.diagnostics.warnings; const details = result.details; // Check if any required steps were skipped or failed - const hasSkippedSteps = steps.some((step) => step.status === "skipped"); - const hasFailedSteps = steps.some((step) => step.status === "failed"); - const allStepsSuccessful = steps.every((step) => step.status === "success"); + const hasSkippedSteps = steps.some((step: ConnectionStep) => step.status === "skipped"); + const hasFailedSteps = steps.some((step: ConnectionStep) => step.status === "failed"); + const allStepsSuccessful = steps.every((step: ConnectionStep) => step.status === "success"); // Determine the overall status const getStatusInfo = () => { @@ -118,6 +123,14 @@ export function ProviderConnectionSuccessDialog({ }; const status = getStatusInfo(); + + const getButtonColorScheme = () => { + if (allStepsSuccessful) return "green"; + if (hasFailedSteps) return "red"; + return "yellow"; + }; + + const buttonColorScheme = getButtonColorScheme(); return ( !e.open && onClose()}> @@ -156,7 +169,7 @@ export function ProviderConnectionSuccessDialog({ Warnings: - {warnings.map((warning) => ( + {warnings.map((warning: { warning_type: string; message: string }) => ( + @@ -61,7 +52,7 @@ export function ProviderFormView() { isOpen={dialog === "formError"} onClose={() => closeDialog()} error={error} - providerName={provider?.name || "Provider"} + providerName={provider?.name ?? "Provider"} /> {/* Connection Error Dialog */} @@ -69,14 +60,14 @@ export function ProviderFormView() { isOpen={dialog === "error"} onClose={() => closeDialog()} error={error} - providerName={provider?.name || "Provider"} + providerName={provider?.name ?? "Provider"} /> {/* Success Dialog */} closeDialog()} - providerName={provider?.name || "Provider"} + providerName={provider?.name ?? "Provider"} connectionDetails={connectionDetails} /> diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/AddProviderButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/AddProviderButton.tsx index 181a0c20..8f144935 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/AddProviderButton.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/AddProviderButton.tsx @@ -6,9 +6,11 @@ import { useProviderFormContext } from "../../../context/ProviderFormContext"; * Button to add a new provider */ export function AddProviderButton() { - const { setMode } = useProviderFormContext(); + const { setMode, setProvider } = useProviderFormContext(); const handleAddProvider = () => { + // Clear the form and current provider when entering create mode + setProvider(null); setMode("create"); }; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx index 40cee055..336ceacd 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx @@ -91,7 +91,7 @@ export function ConnectionSection() { control={control} render={({ field }) => { // Ensure we always have a defined string value - const value = field.value || ""; + const value = field.value ?? ""; return ( API Key diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index 0ea32893..da5cae90 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -12,8 +12,8 @@ import { ProviderFormProvider } from "../../context/ProviderFormContext"; type DialogType = null | "error" | "success" | "formError" | "save"; interface ProviderFormContainerProps { - children: ReactNode; - initialData?: Partial; + readonly children: ReactNode; + readonly initialData?: Partial; } export function ProviderFormContainer({ @@ -70,7 +70,7 @@ export function ProviderFormContainer({ kind: newProvider.kind, environment: newProvider.environment, baseUrl: newProvider.baseUrl, - apiKey: newProvider.apiKey || "", + apiKey: newProvider.apiKey ?? "", isEnabled: newProvider.isEnabled, defaultModel: newProvider.defaultModel, models: newProvider.models @@ -179,7 +179,21 @@ export function ProviderFormContainer({ const handleSetMode = useCallback((newMode: "view" | "edit" | "create") => { setMode(newMode); setContextMode(newMode); - }, [setContextMode]); + + // When switching to create mode, reset the form to empty values + if (newMode === "create") { + reset({ + name: "", + kind: "", + environment: "cloud", + baseUrl: "", + apiKey: "", + isEnabled: true, + defaultModel: "", + models: [] + }); + } + }, [setContextMode, reset]); // Handle model selection with proper type handling const handleSetSelectedModelId = useCallback((id: string | null) => { diff --git a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx index d1ff7614..f1430ef5 100644 --- a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx +++ b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx @@ -51,8 +51,8 @@ export function useProviderFormContext() { } interface ProviderFormProviderProps { - children: ReactNode; - value: ProviderFormContextType; + readonly children: ReactNode; + readonly value: ProviderFormContextType; } export function ProviderFormProvider({ children, value }: ProviderFormProviderProps) { diff --git a/graphcap_studio/src/types/provider-config-types.ts b/graphcap_studio/src/types/provider-config-types.ts index f73d4718..ce0020b1 100644 --- a/graphcap_studio/src/types/provider-config-types.ts +++ b/graphcap_studio/src/types/provider-config-types.ts @@ -218,7 +218,7 @@ export function toServerConfig(provider: Provider): ServerProviderConfig { kind: provider.kind, environment: provider.environment, base_url: provider.baseUrl, - api_key: provider.apiKey || "", + api_key: provider.apiKey ?? "", default_model: provider.defaultModel, models: provider.models?.map((m) => m.name) || [], }; From ae61df0cd50aaa0693e8fb1f0b478aa05cd8e239 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 30 Mar 2025 04:31:39 -0500 Subject: [PATCH 57/69] Convert provider kind to dropdown Signed-off-by: jphillips --- .../src/features/inference/constants.ts | 4 +- .../components/form/BasicInfoSection.tsx | 43 ++++++++++++++++++- .../containers/ProviderFormContainer.tsx | 2 +- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/graphcap_studio/src/features/inference/constants.ts b/graphcap_studio/src/features/inference/constants.ts index cce1fc1c..0d923b2c 100644 --- a/graphcap_studio/src/features/inference/constants.ts +++ b/graphcap_studio/src/features/inference/constants.ts @@ -5,7 +5,7 @@ */ export const DEFAULT_PROVIDER_FORM_DATA = { name: "", - kind: "", + kind: "openai" as const, environment: "cloud" as const, baseUrl: "", apiKey: "", @@ -24,7 +24,7 @@ export const PROVIDER_ENVIRONMENTS = ["cloud", "local"] as const; */ export const PROVIDER_KINDS = [ "openai", - "google", + "gemini", "ollama", "vllm", ] as const; diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx index 5d0e952b..b15da553 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/BasicInfoSection.tsx @@ -1,3 +1,10 @@ +import { + SelectContent, + SelectItem, + SelectRoot, + SelectTrigger, + SelectValueText, +} from "@/components/ui/select"; import { useColorModeValue } from "@/components/ui/theme/color-mode"; // SPDX-License-Identifier: Apache-2.0 import { @@ -8,8 +15,10 @@ import { Input, Text, VStack, + createListCollection, } from "@chakra-ui/react"; import { Controller } from "react-hook-form"; +import { PROVIDER_KINDS } from "../../../../constants"; import { useProviderFormContext } from "../../../context/ProviderFormContext"; import { EnvironmentSelect } from "./EnvironmentSelect"; @@ -27,6 +36,16 @@ export function BasicInfoSection() { const kind = watch("kind"); const environment = watch("environment"); + // Create collection for provider kinds + const kindItems = PROVIDER_KINDS.map((kind) => ({ + label: kind.charAt(0) + kind.slice(1), + value: kind, + })); + + const kindCollection = createListCollection({ + items: kindItems, + }); + if (!isEditing) { return ( @@ -77,7 +96,29 @@ export function BasicInfoSection() { render={({ field }) => ( Kind - + { + if (details.value && details.value.length > 0) { + field.onChange(details.value[0]); + } else { + field.onChange(""); + } + }} + collection={kindCollection} + > + + + + + {kindItems.map((item) => ( + + {item.label} + + ))} + + {errors.kind?.message} )} diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx index da5cae90..09a76dc5 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/containers/ProviderFormContainer.tsx @@ -184,7 +184,7 @@ export function ProviderFormContainer({ if (newMode === "create") { reset({ name: "", - kind: "", + kind: "openai", // Default to the first provider kind environment: "cloud", baseUrl: "", apiKey: "", From a5323e5a9f44e5b22993b03a20e4d9d4ad78c265 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 30 Mar 2025 04:39:32 -0500 Subject: [PATCH 58/69] Fix api show in view mode, cleanup Signed-off-by: jphillips --- .../components/ConnectionSteps.tsx | 15 ++++++++------- .../components/form/ConnectionSection.tsx | 12 ++++++++++-- .../components/form/ModelSelector.tsx | 6 +++--- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx index 02b158f9..fc37d7e5 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ConnectionSteps.tsx @@ -6,11 +6,11 @@ import { LuCheck, LuCircleAlert, LuSkipForward } from "react-icons/lu"; * Component that displays connection test steps and their results */ interface ConnectionStep { - step: string; - status: "success" | "failed" | "skipped" | "pending"; - timestamp: string; - error?: string; - message?: string; + readonly step: string; + readonly status: "success" | "failed" | "skipped" | "pending"; + readonly timestamp: string; + readonly error?: string; + readonly message?: string; } interface ConnectionStepsProps { @@ -32,7 +32,7 @@ function StepIcon({ status }: { status: ConnectionStep["status"] }) { } function ConnectionStepResult({ step, labels }: { step: ConnectionStep; labels?: Record }) { - const stepLabel = labels?.[step.step] || step.step; + const stepLabel = labels?.[step.step] ?? step.step; return ( @@ -67,4 +67,5 @@ export function ConnectionSteps({ steps, stepLabels = {} }: ConnectionStepsProps ); } -export type { ConnectionStep }; \ No newline at end of file +export type { ConnectionStep }; + diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx index 336ceacd..710b05c9 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ConnectionSection.tsx @@ -19,7 +19,7 @@ import { useProviderFormContext } from "../../../context/ProviderFormContext"; * Component for displaying and editing provider connection settings */ export function ConnectionSection() { - const { control, errors, watch, mode, selectedProvider } = + const { control, errors, watch, mode, provider } = useProviderFormContext(); const isEditing = mode === "edit" || mode === "create"; const [showApiKey, setShowApiKey] = useState(false); @@ -33,6 +33,14 @@ export function ConnectionSection() { // Toggle API key visibility const toggleShowApiKey = () => setShowApiKey(!showApiKey); + // Get API key display value + const getApiKeyDisplayValue = () => { + if (showApiKey) { + return provider?.apiKey; + } + return provider?.apiKey ? "••••••••••••••••" : "Not set"; + }; + if (!isEditing) { return ( @@ -50,7 +58,7 @@ export function ConnectionSection() { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx index 3864ee95..49529913 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ModelSelector.tsx @@ -9,9 +9,9 @@ import { Box, Heading, Text } from "@chakra-ui/react"; export type ModelItem = ModelOption; export interface ModelSelectorProps { - modelItems: ModelItem[]; - selectedModelId: string | null; - setSelectedModelId: (id: string | null) => void; + readonly modelItems: ModelItem[]; + readonly selectedModelId: string | null; + readonly setSelectedModelId: (id: string | null) => void; } /** From 8ea9ba4cb693c9a5986ed293a6138e3c4de60ccc Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 30 Mar 2025 04:46:05 -0500 Subject: [PATCH 59/69] Test Connection cleanup. Fix color Signed-off-by: jphillips --- .../components/LoadingMessage.tsx | 4 +- .../components/ProviderConnectionActions.tsx | 41 ------------------- .../actions/TestConnectionButton.tsx | 3 +- 3 files changed, 4 insertions(+), 44 deletions(-) delete mode 100644 graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx index 7b843cbb..15c3f5b0 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/LoadingMessage.tsx @@ -2,8 +2,8 @@ import { Box, Spinner, Text } from "@chakra-ui/react"; interface LoadingMessageProps { - isSubmitting: boolean; - saveSuccess: boolean; + readonly isSubmitting: boolean; + readonly saveSuccess: boolean; } export function LoadingMessage({ isSubmitting, saveSuccess }: LoadingMessageProps) { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx deleted file mode 100644 index 592613e4..00000000 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/ProviderConnectionActions.tsx +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -import { Button, Flex } from "@chakra-ui/react"; -import { useInferenceProviderContext } from "../../context"; - -interface ProviderConnectionActionsProps { - readonly isTestingConnection: boolean; - readonly onTest: () => Promise; - readonly disabled?: boolean; - readonly showEditButton?: boolean; -} - -/** - * Component for rendering provider connection test actions - */ -export function ProviderConnectionActions({ - isTestingConnection, - onTest, - disabled, - showEditButton = true -}: ProviderConnectionActionsProps) { - const { setMode } = useInferenceProviderContext(); - - return ( - - - {showEditButton && ( - - )} - - ); -} \ No newline at end of file diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx index 6b71a71d..698c8b48 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/TestConnectionButton.tsx @@ -10,7 +10,8 @@ export function TestConnectionButton() { return ( + + !deleteProvider.isPending && onClose()} + > + + + + + + Remove Provider + + + + + + + + Are you sure you want to remove the provider "{provider.name}"? + This action cannot be undone. + + + + + + + + + + + + + ); +} \ No newline at end of file diff --git a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts index 99927776..4ef3dc68 100644 --- a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts +++ b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts @@ -39,8 +39,8 @@ export function getDataServiceUrl(connections: ServerConnection[]): string { ); return ( - dataServiceConnection?.url || - import.meta.env.VITE_DATA_SERVICE_URL || + dataServiceConnection?.url ?? + import.meta.env.VITE_DATA_SERVICE_URL ?? DEFAULT_URLS[SERVER_IDS.DATA_SERVICE] ); } From 115a8740044b15c4f9303622aba0d49fdb5f9733 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 30 Mar 2025 06:48:42 -0500 Subject: [PATCH 67/69] Convert to use provider and model name throughout gen options Signed-off-by: jphillips --- .../components/fields/ModelSelectorField.tsx | 8 ++- .../context/GenerationOptionsContext.tsx | 45 ++++++++-------- .../persist-generation-options.ts | 26 +--------- .../hooks/useProviderModelOptions.ts | 12 +++-- .../actions/RemoveProviderButton.tsx | 8 +-- .../components/form/ProviderFormSelect.tsx | 8 +-- .../containers/ProviderFormContainer.tsx | 19 +++++-- .../PerspectiveActions/PerspectivesFooter.tsx | 51 +++++-------------- .../context/PerspectivesDataContext.tsx | 2 +- .../hooks/useGeneratePerspectiveCaption.ts | 6 +-- .../perspectives/utils/api-adapters.ts | 14 ++--- .../src/types/generation-option-types.ts | 23 ++++++--- 12 files changed, 97 insertions(+), 125 deletions(-) diff --git a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx index 36d48028..6b3a53b6 100644 --- a/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx +++ b/graphcap_studio/src/features/inference/generation-options/components/fields/ModelSelectorField.tsx @@ -34,7 +34,7 @@ export function ModelSelectorField() { items: providers.items.length > 0 ? providers.items.map((provider) => ({ label: provider.name, - value: provider.id, + value: provider.name, disabled: false, })) : [{ label: "No providers available", value: "none", disabled: false }] @@ -68,8 +68,6 @@ export function ModelSelectorField() { // Check if any providers are available const hasProviders = providers.items.length > 0; - // Check if any models are available for the selected provider - // Loading state const isProvidersLoading = providers.isLoading; const isModelsLoading = models.isLoading; @@ -84,7 +82,7 @@ export function ModelSelectorField() { (key: K, value: GenerationOptions[K]) => void; resetOptions: () => void; setOptions: (options: Partial) => void; - selectProvider: (providerId: string) => void; + selectProvider: (providerName: string) => void; selectModel: (modelName: string) => void; }; uiActions: { @@ -112,7 +112,7 @@ export function GenerationOptionsProvider({ defaultModel, isLoading, hasError - } = useProviderModelOptions(options.provider_id); + } = useProviderModelOptions(options.provider_name); // Save options to localStorage when they change useEffect(() => { @@ -126,21 +126,21 @@ export function GenerationOptionsProvider({ // Initialize provider if available and not already set useEffect(() => { - if (providers.length > 0 && !options.provider_id) { + if (providers.length > 0 && !options.provider_name && !isLoadingProviders) { const firstProvider = providers[0]; - updateOption("provider_id", firstProvider.id); + updateOption("provider_name", firstProvider.name); } - }, [providers, options.provider_id]); + }, [providers, options.provider_name, isLoadingProviders]); + - // Initialize model if available and not already set useEffect(() => { - // If we have a provider but no model, and models are available - if (options.provider_id && !options.model_id && models.length > 0) { + // Only set model if we have a provider and no model is selected yet + if (options.provider_name && !options.model_name && models.length > 0) { // Try to use default model first, otherwise use first available model const modelToUse = defaultModel || models[0]; - updateOption("model_id", modelToUse.name); + updateOption("model_name", modelToUse.name); } - }, [options.provider_id, options.model_id, models, defaultModel]); + }, [options.provider_name, options.model_name, models, defaultModel]); // Update a single option const updateOption = useCallback( @@ -183,15 +183,16 @@ export function GenerationOptionsProvider({ ); // Provider selection - const selectProvider = useCallback((providerId: string) => { - updateOption("provider_id", providerId); - // Clear model when provider changes - updateOption("model_id", ""); - }, [updateOption]); + const selectProvider = useCallback((providerName: string) => { + if (providerName !== options.provider_name) { + updateOption("provider_name", providerName); + updateOption("model_name", ""); + } + }, [updateOption, options.provider_name]); // Model selection const selectModel = useCallback((modelName: string) => { - updateOption("model_id", modelName); + updateOption("model_name", modelName); }, [updateOption]); // Dialog controls @@ -255,8 +256,8 @@ export function GenerationOptionsProvider({ selectModel, openDialog, closeDialog, - toggleDialog - ], + toggleDialog, + ] ); return ( @@ -267,17 +268,17 @@ export function GenerationOptionsProvider({ } /** - * Hook to access the generation options context + * Hook to use generation options context + * + * Must be used within a GenerationOptionsProvider */ export function useGenerationOptions() { const context = useContext(GenerationOptionsContext); - - if (context === undefined) { + if (!context) { throw new Error( "useGenerationOptions must be used within a GenerationOptionsProvider", ); } - return context; } diff --git a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts index c4825009..a2e2cdae 100644 --- a/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts +++ b/graphcap_studio/src/features/inference/generation-options/persist-generation-options.ts @@ -20,20 +20,7 @@ const STORAGE_KEY = "graphcap:generation-options"; */ export function saveGenerationOptions(options: GenerationOptions): void { try { - // Create a copy to ensure we don't modify the original - const optionsToSave = { ...options }; - - // Ensure provider_id is stored as a string - if (optionsToSave.provider_id !== undefined) { - optionsToSave.provider_id = String(optionsToSave.provider_id); - } - - // Verify we have a model name, not an ID - if (optionsToSave.model_id && /^\d+$/.test(optionsToSave.model_id)) { - throw new Error('model_id must be a model name, not a numeric ID'); - } - - const serialized = JSON.stringify(optionsToSave); + const serialized = JSON.stringify(options); localStorage.setItem(STORAGE_KEY, serialized); } catch (error) { console.error("Failed to save generation options to localStorage:", error); @@ -52,17 +39,6 @@ export function loadGenerationOptions(): GenerationOptions | null { const parsed = JSON.parse(serialized); - // If provider_id exists and is a number, convert it to string - if (parsed.provider_id !== undefined && typeof parsed.provider_id === 'number') { - parsed.provider_id = parsed.provider_id.toString(); - } - - // Check if model_id appears to be a numeric ID and not a name - if (parsed.model_id && /^\d+$/.test(parsed.model_id)) { - console.error('Invalid model_id format: Must be a model name, not a numeric ID'); - throw new Error('model_id must be a model name, not a numeric ID'); - } - // Validate the loaded data against the schema return GenerationOptionsSchema.parse(parsed); } catch (error) { diff --git a/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts index d58bcfe1..1275116c 100644 --- a/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts +++ b/graphcap_studio/src/features/inference/hooks/useProviderModelOptions.ts @@ -13,10 +13,10 @@ import { useMemo } from "react"; /** * Hook for accessing provider and model selection options * - * @param providerId - The selected provider ID + * @param providerName - The selected provider name * @returns Provider and model data with loading states */ -export function useProviderModelOptions(providerId?: string) { +export function useProviderModelOptions(providerName?: string) { // Fetch all providers const { data: providers = [], @@ -26,9 +26,11 @@ export function useProviderModelOptions(providerId?: string) { // Find the selected provider object const selectedProvider = useMemo(() => { - if (!providerId) return null; - return providers.find((p: Provider) => p.id === providerId) || null; - }, [providers, providerId]); + if (!providerName || !providers.length) return null; + + // Find provider by name + return providers.find((p: Provider) => p.name === providerName) || null; + }, [providers, providerName]); // Process models data directly from the provider const models = useMemo(() => { diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/RemoveProviderButton.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/RemoveProviderButton.tsx index 2cad50a3..145bfbae 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/RemoveProviderButton.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/actions/RemoveProviderButton.tsx @@ -21,8 +21,8 @@ export function RemoveProviderButton() { const cancelRef = useRef(null); const deleteProvider = useDeleteProvider(); - // Only show the button if we have a provider - if (!provider || !provider.id) { + // Only show the button if we don't have a provider or the provider doesn't have an ID + if (!provider?.id) { return null; } @@ -35,11 +35,11 @@ export function RemoveProviderButton() { await deleteProvider.mutateAsync(providerId); - // Reset the provider selection setProvider(null); - // Close the dialog onClose(); + + console.log(`Provider "${provider.name}" successfully removed`); } catch (error) { console.error("Failed to delete provider:", error); } diff --git a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx index fb47a2af..00136cc7 100644 --- a/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx +++ b/graphcap_studio/src/features/inference/providers/ProviderConnection/components/form/ProviderFormSelect.tsx @@ -26,15 +26,15 @@ export function ProviderFormSelect({ // Convert providers to the format expected by ProviderSelector const providerOptions: ProviderOption[] = providers.map((p: Provider) => ({ label: p.name, - value: String(p.id), + value: p.name, id: String(p.id), })); const handleProviderChange = (value: string) => { if (!value) return; - // Find the selected provider from the providers list - const selectedProvider = providers.find((p: Provider) => String(p.id) === value); + // Find the selected provider from the providers list by name + const selectedProvider = providers.find((p: Provider) => p.name === value); if (selectedProvider) { setProvider(selectedProvider); } @@ -43,7 +43,7 @@ export function ProviderFormSelect({ return ( (null); const [connectionDetails, setConnectionDetails] = useState(null); + // Get query client for cache invalidation + const queryClient = useQueryClient(); + // Form setup const { control, @@ -55,8 +59,6 @@ export function ProviderFormContainer({ const createProvider = useCreateProvider(); const updateProvider = useUpdateProvider(); - // Use either API providers or context providers - // Handle provider selection const handleProviderSelect = useCallback((newProvider: Provider | null) => { setProvider(newProvider); @@ -107,15 +109,24 @@ export function ProviderFormContainer({ })(e); }); + let savedProvider: Provider | null = null; + if (mode === "edit" && provider?.id) { - await updateProvider.mutateAsync({ + savedProvider = await updateProvider.mutateAsync({ id: denormalizeProviderId(provider.id), data: formData as ProviderUpdate }); } else if (mode === "create") { - await createProvider.mutateAsync(formData as ProviderCreate); + // Create the provider + savedProvider = await createProvider.mutateAsync(formData as ProviderCreate); } + if (savedProvider) { + setProvider(savedProvider); + } + + await queryClient.invalidateQueries({ queryKey: ['providers'] }); + openDialog("success"); setMode("view"); setContextMode("view"); diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx index 2ef6d50f..b756cb30 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx @@ -75,40 +75,17 @@ export function PerspectivesFooter() { const borderColor = useColorModeValue("gray.200", "gray.700"); const infoTextColor = useColorModeValue("gray.600", "gray.400"); - // Get provider and model names // Log information for debugging console.log("GenerationOptions:", generationOptions); console.log("Available providers:", availableProviders); - // Get provider information safely (without throwing during render) - const providerInfo = useMemo(() => { - // Don't attempt to find providers if the list is empty or loading - if (!availableProviders.length) { - return { providerName: "Loading...", modelName: "Loading..." }; - } - - // Try to find provider by ID - const providerObj = availableProviders.find(p => - p.id.toString() === generationOptions.provider_id - ); - - // If provider not found, return a placeholder but don't throw - if (!providerObj) { - console.warn(`Provider with ID ${generationOptions.provider_id} not found yet`); - return { - providerName: `ID: ${generationOptions.provider_id}`, - modelName: generationOptions.model_id || "None" - }; - } - - // Provider found, return proper info + // Get provider information safely + const { providerName, modelName } = useMemo(() => { return { - providerName: providerObj.name, - modelName: generationOptions.model_id || "None" + providerName: generationOptions.provider_name || "Select Provider", + modelName: generationOptions.model_name || "Select Model" }; - }, [availableProviders, generationOptions.provider_id, generationOptions.model_id]); - - const { providerName, modelName } = providerInfo; + }, [generationOptions.provider_name, generationOptions.model_name]); // Fetch providers on component mount useEffect(() => { @@ -128,7 +105,7 @@ export function PerspectivesFooter() { return false; } - if (!generationOptions.provider_id) { + if (!generationOptions.provider_name) { showMessage( "No provider selected", "Please select an inference provider", @@ -147,7 +124,7 @@ export function PerspectivesFooter() { } return true; - }, [activeSchemaName, generationOptions.provider_id, currentImage, showMessage]); + }, [activeSchemaName, generationOptions.provider_name, currentImage, showMessage]); // Handle generate button click const handleGenerate = useCallback(async () => { @@ -161,11 +138,11 @@ export function PerspectivesFooter() { try { console.log("Calling generatePerspective..."); - // Find the provider object from the available providers using the provider_id - const providerObject = availableProviders.find(p => p.id.toString() === generationOptions.provider_id); + // Find the provider object from the available providers using provider_name + const providerObject = availableProviders.find(p => p.name === generationOptions.provider_name); if (!providerObject) { - throw new Error(`Provider with ID "${generationOptions.provider_id}" not found in available providers`); + throw new Error(`Provider "${generationOptions.provider_name}" not found in available providers`); } await generatePerspective( @@ -201,13 +178,13 @@ export function PerspectivesFooter() { // Combine loading states const isProcessing = isLoading || isGenerating; - // Check if button should be disabled - use generationOptions.provider_id instead of selectedProvider + // Check if button should be disabled const isGenerateDisabled = - isProcessing || !activeSchemaName || !generationOptions.provider_id; + isProcessing || !activeSchemaName || !generationOptions.provider_name; - // Get title for the generate button - also use generationOptions.provider_id + // Get title for the generate button const buttonTitle = getButtonTitle( - generationOptions.provider_id ? providerName : undefined, + generationOptions.provider_name, activeSchemaName, isProcessing, isGenerated, diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index 33ca4d08..dac415e2 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -425,7 +425,7 @@ export function PerspectivesDataProvider({ provider: effectiveProvider.name, content: result.result || {}, options: { - model: effectiveOptions.model_id, // model_id now contains the name + model: effectiveOptions.model_name, max_tokens: effectiveOptions.max_tokens, temperature: effectiveOptions.temperature, top_p: effectiveOptions.top_p, diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts index 8d83dd69..3bebd87e 100644 --- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts +++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts @@ -57,7 +57,7 @@ export function useGeneratePerspectiveCaption() { } // Check if a model is specified in the options - if (!options.model_id) { + if (!options.model_name) { throw new Error("A model must be specified in the options"); } @@ -82,7 +82,7 @@ export function useGeneratePerspectiveCaption() { perspective, image_path: normalizedImagePath, provider: provider.name, - model: options.model_id, // Use model_id from GenerationOptions + model: options.model_name, // Use model_name from GenerationOptions provider_config: providerConfig, // Include the full provider configuration ...apiOptions, // Spread the formatted API options }; @@ -91,7 +91,7 @@ export function useGeneratePerspectiveCaption() { perspective, image_path: normalizedImagePath, provider: provider.name, - model: options.model_id, // Log the model_id from options + model: options.model_name, // Log the model_name from options options: apiOptions, }); diff --git a/graphcap_studio/src/features/perspectives/utils/api-adapters.ts b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts index 9cfa576f..152b9cf2 100644 --- a/graphcap_studio/src/features/perspectives/utils/api-adapters.ts +++ b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts @@ -53,10 +53,10 @@ export function formatCaptionRequest( */ export function legacyCaptionToGenerationOptions( captionOptions: LegacyCaptionOptions, - providerId: string + providerName: string ): GenerationOptions { return { - model_id: captionOptions.model, + model_name: captionOptions.model, max_tokens: captionOptions.max_tokens ?? 4096, temperature: captionOptions.temperature ?? 0.7, top_p: captionOptions.top_p ?? 0.95, @@ -64,7 +64,7 @@ export function legacyCaptionToGenerationOptions( global_context: captionOptions.global_context ?? "You are a visual captioning perspective.", context: captionOptions.context ?? [], resize_resolution: captionOptions.resize_resolution ?? "NONE", - provider_id: providerId + provider_name: providerName }; } @@ -81,20 +81,20 @@ interface PerspectiveDataWithOptions { */ export function perspectiveDataToGenerationOptions( perspectiveData: PerspectiveDataWithOptions, - providerIdMap: Record + providerNameMap: Record ): GenerationOptions { // If we have structured options, use those if (perspectiveData.options) { return legacyCaptionToGenerationOptions( perspectiveData.options, - providerIdMap[perspectiveData.provider] || "" + providerNameMap[perspectiveData.provider] || perspectiveData.provider ); } // Otherwise create minimal options return { - model_id: perspectiveData.model, - provider_id: providerIdMap[perspectiveData.provider] || "", + model_name: perspectiveData.model, + provider_name: providerNameMap[perspectiveData.provider] || perspectiveData.provider, max_tokens: 4096, temperature: 0.7, top_p: 0.95, diff --git a/graphcap_studio/src/types/generation-option-types.ts b/graphcap_studio/src/types/generation-option-types.ts index ec4b06a2..a348a593 100644 --- a/graphcap_studio/src/types/generation-option-types.ts +++ b/graphcap_studio/src/types/generation-option-types.ts @@ -37,8 +37,8 @@ export const DEFAULT_OPTIONS = { resize_resolution: "NONE", // Default to no resize global_context: "You are a visual captioning perspective.", context: [] as string[], // Default to empty context array - provider_id: "", // Default to empty (will be populated later) - model_id: "", // Default to empty (will be populated later) + provider_name: "", // Default to empty (will be populated later) + model_name: "", // Default to empty (will be populated later) } as const; // Schema for generation options @@ -75,10 +75,10 @@ export const GenerationOptionsSchema = z.object({ // Added context array (was in CaptionOptions) context: z.array(z.string()).default([]), - // Provider and model selection - provider_id: z.string().default(DEFAULT_OPTIONS.provider_id), + // Provider and model selection (using names instead of IDs) + provider_name: z.string().default(DEFAULT_OPTIONS.provider_name), - model_id: z.string().default(DEFAULT_OPTIONS.model_id), + model_name: z.string().default(DEFAULT_OPTIONS.model_name), }); // Type for generation options @@ -90,7 +90,7 @@ export type GenerationOptions = z.infer; */ export function formatApiOptions(options: GenerationOptions): Record { return { - model: options.model_id, // API expects 'model' instead of model_id + model: options.model_name, // API expects 'model' instead of model_name temperature: options.temperature, max_tokens: options.max_tokens, top_p: options.top_p, @@ -103,11 +103,18 @@ export function formatApiOptions(options: GenerationOptions): Record Date: Sun, 30 Mar 2025 07:16:00 -0500 Subject: [PATCH 68/69] Fix perspective type issues, metadata for generated caps Signed-off-by: jphillips --- .../src/features/perspectives/README.md | 1 - .../PerspectiveCard/PerspectiveCardTabbed.tsx | 5 +- .../PerspectiveCard/PerspectiveDebug.tsx | 4 +- .../debug-fields/MetadataSection.tsx | 7 +- .../context/PerspectivesDataContext.tsx | 11 +- .../src/features/perspectives/hooks/index.ts | 5 +- .../hooks/useImagePerspectives.ts | 334 ------------------ .../perspectives/hooks/usePerspectiveUI.ts | 5 +- .../features/perspectives/services/index.ts | 5 +- .../src/types/perspective-types.ts | 27 +- 10 files changed, 27 insertions(+), 377 deletions(-) delete mode 100644 graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts diff --git a/graphcap_studio/src/features/perspectives/README.md b/graphcap_studio/src/features/perspectives/README.md index 0c30ae8b..d5f4313b 100644 --- a/graphcap_studio/src/features/perspectives/README.md +++ b/graphcap_studio/src/features/perspectives/README.md @@ -165,4 +165,3 @@ Custom hooks are provided for working with perspectives: - **usePerspectives** - Fetches available perspectives from the server - **useGeneratePerspectiveCaption** - Generates captions for images using perspectives -- **useImagePerspectives** - Manages perspective data for a specific image \ No newline at end of file diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx index e1e23185..e4d6eac5 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx @@ -175,8 +175,9 @@ export function PerspectiveCardTabbed({ {/* Metadata - e.g., timestamps or version info */} - {data?.metadata?.timestamp && - new Date(data.metadata.timestamp).toLocaleString()} + {data?.metadata?.generatedAt || data?.metadata?.timestamp ? + new Date(data?.metadata?.generatedAt || data?.metadata?.timestamp || '').toLocaleString() : + ''} diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx index bc772628..6e7ad947 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx @@ -49,9 +49,7 @@ function processDebugInfo( model: perspectiveData?.model, version: perspectiveData?.version, config_name: perspectiveData?.config_name ?? schema.name, - generatedAt: data.metadata?.timestamp - ? new Date(data.metadata.timestamp).toISOString() - : null, + generatedAt: data.metadata?.generatedAt ?? null, }, // Generation options - directly from the PerspectiveData interface options: perspectiveData?.options || null, diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx index 143ddaee..13663e43 100644 --- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx +++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx @@ -77,6 +77,11 @@ function MetadataItem({ labelColor, valueColor, }: MetadataItemProps) { + // Format date if this is the Generated timestamp field + const formattedValue = label === "Generated:" && value + ? new Date(value).toLocaleString() + : value; + return ( @@ -84,7 +89,7 @@ function MetadataItem({ {value ? ( - {value} + {formattedValue} ) : ( diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx index dac415e2..5b48c819 100644 --- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx +++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx @@ -421,7 +421,7 @@ export function PerspectivesDataProvider({ const perspectiveData = { config_name: schemaName, version: "1.0", - model: result.metadata?.model ?? "MISSING_MODEL", + model: result.metadata?.model ?? effectiveOptions.model_name ?? "MISSING_MODEL", provider: effectiveProvider.name, content: result.result || {}, options: { @@ -433,6 +433,13 @@ export function PerspectivesDataProvider({ global_context: effectiveOptions.global_context, context: effectiveOptions.context, resize_resolution: effectiveOptions.resize_resolution + }, + metadata: { + provider: effectiveProvider.name, + model: result.metadata?.model ?? effectiveOptions.model_name ?? "MISSING_MODEL", + version: "1.0", + config_name: schemaName, + generatedAt: new Date().toISOString() } }; @@ -451,7 +458,7 @@ export function PerspectivesDataProvider({ metadata: { captioned_at: new Date().toISOString(), provider: effectiveProvider?.name || "", - model: result.metadata?.model ?? "unknown", + model: result.metadata?.model ?? effectiveOptions.model_name ?? "unknown", }, }; diff --git a/graphcap_studio/src/features/perspectives/hooks/index.ts b/graphcap_studio/src/features/perspectives/hooks/index.ts index d67eb8be..863e4f4d 100644 --- a/graphcap_studio/src/features/perspectives/hooks/index.ts +++ b/graphcap_studio/src/features/perspectives/hooks/index.ts @@ -9,10 +9,9 @@ export { usePerspectiveUI } from "./usePerspectiveUI"; // API Hooks -export { usePerspectives } from "./usePerspectives"; -export { usePerspectiveModules } from "./usePerspectiveModules"; export { useGeneratePerspectiveCaption } from "./useGeneratePerspectiveCaption"; -export { useImagePerspectives } from "./useImagePerspectives"; +export { usePerspectiveModules } from "./usePerspectiveModules"; +export { usePerspectives } from "./usePerspectives"; // Utilities diff --git a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts deleted file mode 100644 index c9c3dcbe..00000000 --- a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts +++ /dev/null @@ -1,334 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -/** - * useImagePerspectives Hook - * - * This hook manages perspective data for a specific image. - */ - -import { useServerConnectionsContext } from "@/context"; -import { SERVER_IDS } from "@/features/server-connections/constants"; -import { useProviders } from "@/features/server-connections/services/providers"; -import type { Image } from "@/services/images"; -import type { Provider } from "@/types/provider-config-types"; -import { useCallback, useEffect, useState } from "react"; - -import type { - CaptionOptions, - ImageCaptions, - ImagePerspectivesResult, - PerspectiveData, - PerspectiveType, -} from "@/types/perspective-types"; -import { useGeneratePerspectiveCaption } from "./useGeneratePerspectiveCaption"; -import { usePerspectives } from "./usePerspectives"; - -/** - * Hook for fetching and managing perspective data for an image - * - * This hook combines the functionality of the perspectives API and captions - * to provide a unified interface for working with image perspectives. - * - * @param image - The image to get perspectives for - * @returns An object with perspective data and functions to manage it - */ -export function useImagePerspectives( - image: Image | null, -): ImagePerspectivesResult { - const [captions, setCaptions] = useState(null); - const [isLoading, setIsLoading] = useState(false); - const [generatingPerspectives, setGeneratingPerspectives] = useState< - string[] - >([]); - const [error, setError] = useState(null); - - // Get server connection status - const { connections } = useServerConnectionsContext(); - const graphcapServerConnection = connections.find( - (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE, - ); - const isServerConnected = graphcapServerConnection?.status === "connected"; - - console.debug("useImagePerspectives hook initialized", { - imagePath: image?.path, - isServerConnected, - }); - - // Derived state - const generatedPerspectives = captions - ? Object.keys(captions.perspectives) - : []; - - // Get available perspectives from the server - const { data: perspectivesData } = usePerspectives(); - - // Get available providers - const { data: providersData } = useProviders(); - - // Generate caption mutation - const generateCaption = useGeneratePerspectiveCaption(); - - // Derived state for available perspectives - const availablePerspectives = perspectivesData || []; - - // Derived state for available providers - const availableProviders = - providersData?.map((provider) => ({ - id: provider.id, - name: provider.name, - })) || []; - - // Function to generate a perspective using the perspectives API - const generatePerspective = useCallback( - async ( - perspective: PerspectiveType, - providerId?: number, - options?: CaptionOptions, - ) => { - if (!image) { - console.warn("Cannot generate perspective: No image provided"); - setError("No image provided"); - return; - } - - if (!options) { - console.warn("No options provided, using default options"); - setError("No options provided"); - return; - } - - if (!isServerConnected) { - console.warn( - "Cannot generate perspective: Server connection not established", - ); - setError("Server connection not established"); - return; - } - - // Find the provider by ID if provided - let providerObject: Provider | undefined; - if (providerId && providersData) { - providerObject = providersData.find((p) => p.id === providerId); - if (providerObject) { - console.debug( - `Using provider: ${providerObject.name} (ID: ${providerId})`, - ); - } else { - console.warn(`Provider with ID ${providerId} not found`); - setError(`Provider with ID ${providerId} not found`); - return; - } - } else { - console.warn("No provider ID specified"); - setError("No provider ID specified"); - return; - } - - console.log(`Generating perspective: ${perspective}`, { - imagePath: image.path, - provider: providerObject.name, - options, - }); - - setError(null); - // Track which perspective is being generated - setGeneratingPerspectives((prev) => [...prev, perspective]); - setIsLoading(true); - - try { - // Generate the caption - const result = await generateCaption.mutateAsync({ - imagePath: image.path, - perspective, - provider: providerObject, - options, - }); - - // Log the caption result - console.debug("Caption generation result received"); - console.debug( - `Caption content for perspective ${perspective}:`, - result.content || result.result, - ); - - // Create a perspective data object - const perspectiveData: PerspectiveData = { - config_name: perspective, - version: "1.0", - model: "api-generated", - provider: providerObject.name, - content: result.content || result.result || {}, - options: options, - }; - - // Update the captions with the new perspective - setCaptions((prevCaptions) => { - if (!prevCaptions) { - // Create a new captions object if none exists - console.debug("Creating new captions object"); - return { - image, - perspectives: { - [perspective]: perspectiveData, - }, - metadata: { - captioned_at: new Date().toISOString(), - provider: providerObject.name, - model: "api-generated", - }, - }; - } - - // Update existing captions - console.debug("Updating existing captions"); - return { - ...prevCaptions, - perspectives: { - ...prevCaptions.perspectives, - [perspective]: perspectiveData, - }, - metadata: { - ...prevCaptions.metadata, - captioned_at: new Date().toISOString(), - provider: providerObject.name, - model: "api-generated", - }, - }; - }); - } catch (err) { - console.error("Error generating perspective", err); - setError( - err instanceof Error ? err.message : "Failed to generate perspective", - ); - } finally { - // Remove the perspective from the generating list - setGeneratingPerspectives((prev) => - prev.filter((p) => p !== perspective), - ); - // Only set isLoading to false if no perspectives are being generated - setIsLoading(() => { - const updatedGenerating = generatingPerspectives.filter( - (p) => p !== perspective, - ); - return updatedGenerating.length > 0; - }); - } - }, - [ - image, - providersData, - generateCaption, - generatingPerspectives, - isServerConnected, - ], - ); - - // Function to generate all perspectives - const generateAllPerspectives = useCallback(() => { - if (!image || !perspectivesData) { - console.warn( - "Cannot generate all perspectives: No image or perspectives data", - ); - setError("No image or perspectives data available"); - return; - } - - if (!isServerConnected) { - console.warn( - "Cannot generate all perspectives: Server connection not established", - ); - setError("Server connection not established"); - return; - } - - console.log("Generating all perspectives", { - imagePath: image.path, - perspectiveCount: perspectivesData.length, - }); - - setIsLoading(true); - // Track all perspectives as generating - setGeneratingPerspectives(perspectivesData.map((p) => p.name)); - - try { - // Generate each perspective one by one - for (const perspective of perspectivesData) { - console.debug(`Generating perspective: ${perspective.name}`); - generatePerspective(perspective.name); - } - - console.log("All perspectives generated successfully"); - } catch (err) { - console.error("Error generating all perspectives", err); - setError( - err instanceof Error - ? err.message - : "Failed to generate all perspectives", - ); - } finally { - setIsLoading(false); - setGeneratingPerspectives([]); - } - }, [image, perspectivesData, generatePerspective, isServerConnected]); - - // Reset error when server connection changes - useEffect(() => { - if (isServerConnected && error === "Server connection not established") { - setError(null); - } - }, [isServerConnected, error]); - - // Log when the hook's return value changes - useEffect(() => { - console.debug("useImagePerspectives state updated", { - isLoading, - hasError: error !== null, - hasCaptions: captions !== null, - generatedPerspectiveCount: generatedPerspectives.length, - availablePerspectiveCount: availablePerspectives.length, - availableProviderCount: availableProviders.length, - generatingPerspectives, - isServerConnected, - }); - - if (captions?.perspectives) { - console.debug( - "Current perspectives:", - Object.keys(captions.perspectives), - ); - } - }, [ - isLoading, - error, - captions, - generatedPerspectives, - availablePerspectives, - availableProviders, - generatingPerspectives, - isServerConnected, - ]); - - // Create wrapper functions that don't return the promises - const generatePerspectiveWrapper = ( - perspective: PerspectiveType, - providerId?: number, - options?: CaptionOptions, - ): void => { - generatePerspective(perspective, providerId, options); - }; - - const generateAllPerspectivesWrapper = (): void => { - generateAllPerspectives(); - }; - - return { - isLoading, - error, - captions, - generatedPerspectives, - generatingPerspectives, - generatePerspective: generatePerspectiveWrapper, - generateAllPerspectives: generateAllPerspectivesWrapper, - availablePerspectives, - availableProviders, - }; -} diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts index d86fa70d..c4eca286 100644 --- a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts +++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts @@ -5,16 +5,15 @@ * This hook provides UI-related functionality for the perspectives components. */ -import { PerspectiveType } from "@/features/perspectives/types"; import { useCallback, useState } from "react"; interface UsePerspectiveUIOptions { onGeneratePerspective?: ( - perspective: PerspectiveType, + perspective: string, provider?: string, ) => void; initialSelectedProvider?: string; - perspectiveKey?: PerspectiveType; + perspectiveKey?: string; } /** diff --git a/graphcap_studio/src/features/perspectives/services/index.ts b/graphcap_studio/src/features/perspectives/services/index.ts index 7d8b4332..d8783d3b 100644 --- a/graphcap_studio/src/features/perspectives/services/index.ts +++ b/graphcap_studio/src/features/perspectives/services/index.ts @@ -17,7 +17,6 @@ export { perspectivesApi } from "./api"; // Export hooks from the hooks directory export { - usePerspectives, - useGeneratePerspectiveCaption, - useImagePerspectives, + useGeneratePerspectiveCaption, usePerspectives } from "@/features/perspectives/hooks"; + diff --git a/graphcap_studio/src/types/perspective-types.ts b/graphcap_studio/src/types/perspective-types.ts index 8df840f8..03cafd31 100644 --- a/graphcap_studio/src/types/perspective-types.ts +++ b/graphcap_studio/src/types/perspective-types.ts @@ -202,11 +202,6 @@ export type CaptionOptions = { resize_resolution?: string; }; -/** - * String alias that allows any perspective name to be used. - */ -export type PerspectiveType = string; - /** * Describes a provider with id and name. */ @@ -223,7 +218,7 @@ export interface PerspectiveData { version: string; model: string; provider: string; - content: Record; + content: Record; options: CaptionOptions; } @@ -244,24 +239,6 @@ export interface ImageCaptions { // SECTION D - COMPOSITE TYPES // ============================================================================ -/** - * Result type for the useImagePerspectives hook. - */ -export interface ImagePerspectivesResult { - isLoading: boolean; - error: string | null; - captions: ImageCaptions | null; - generatedPerspectives: PerspectiveType[]; - generatingPerspectives: string[]; - generatePerspective: ( - perspective: PerspectiveType, - providerId?: number, - options?: CaptionOptions, - ) => void; - generateAllPerspectives: () => void; - availablePerspectives: Perspective[]; - availableProviders: Provider[]; -} /** * Context type for the perspectives feature. @@ -285,5 +262,5 @@ export interface PerspectivesContextType { */ export interface PerspectivesProviderProps { children: React.ReactNode; - initialSelectedProviderId?: number | undefined; + initialSelectedProviderId?: number; } From 2cfeed729f576074f987e3a3084149909902e865 Mon Sep 17 00:00:00 2001 From: jphillips Date: Sun, 30 Mar 2025 07:23:55 -0500 Subject: [PATCH 69/69] Lint tweaks Signed-off-by: jphillips --- graphcap_studio/src/features/perspectives/services/utils.ts | 5 ++--- .../server-connections/services/inferenceBridgeClient.ts | 4 ++-- .../features/server-connections/services/providerAdapters.ts | 2 +- .../src/features/server-connections/services/providers.ts | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/graphcap_studio/src/features/perspectives/services/utils.ts b/graphcap_studio/src/features/perspectives/services/utils.ts index 75c2b619..3442773e 100644 --- a/graphcap_studio/src/features/perspectives/services/utils.ts +++ b/graphcap_studio/src/features/perspectives/services/utils.ts @@ -5,7 +5,6 @@ * This module provides utility functions for the perspectives service. */ -import { DEFAULTS } from "@/features/perspectives/constants/index"; import type { ServerConnection } from "@/features/perspectives/types"; import { SERVER_IDS } from "@/features/server-connections/constants"; @@ -19,8 +18,8 @@ export function getGraphCapServerUrl(connections: ServerConnection[]): string { // Use connection URL or fallback to environment variable or default const serverUrl = - serverConnection?.url || - import.meta.env.VITE_INFERENCE_BRIDGE_URL || + serverConnection?.url ?? + import.meta.env.VITE_INFERENCE_BRIDGE_URL ?? "http://localhost:32100"; console.debug(`Using Inference Bridge URL: ${serverUrl}`); diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts index f7f114dc..a8d013e2 100644 --- a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts +++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts @@ -74,8 +74,8 @@ export function getInferenceBridgeUrl(connections: ServerConnection[]): string { ); return ( - inferenceBridgeConnection?.url || - import.meta.env.VITE_INFERENCE_BRIDGE_URL || + inferenceBridgeConnection?.url ?? + import.meta.env.VITE_INFERENCE_BRIDGE_URL ?? DEFAULT_URLS[SERVER_IDS.INFERENCE_BRIDGE] ); } diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts index 2f0af535..f156222d 100644 --- a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts +++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts @@ -119,7 +119,7 @@ export function createProviderModel( id?: string, ): ProviderModel { return { - id: id || crypto.randomUUID(), // Generate UUID if no ID provided + id: id ?? crypto.randomUUID(), // Generate UUID if no ID provided providerId, name, isEnabled: true, diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts index 25882225..5ebdc167 100644 --- a/graphcap_studio/src/features/server-connections/services/providers.ts +++ b/graphcap_studio/src/features/server-connections/services/providers.ts @@ -155,7 +155,7 @@ export function useUpdateProvider() { const client = createDataServiceClient(connections); // Convert application data to API format - const apiData = toApiProvider(data as Provider); + const apiData = toApiProvider(data); // Create a new object without the ID const { id: _, ...apiDataWithoutId } = apiData;