No providers available
@@ -22,9 +20,18 @@ export default function ProvidersList({
);
}
+ // Create a simple handler for form submission - just updates the global context
+ const handleSubmit = async (data: ProviderCreate | ProviderUpdate) => {
+ // This is a simple wrapper, so we don't actually need to do anything on submit
+ console.log("Provider selected in list:", data);
+ };
+
return (
-
+ {/* Wrap in provider form container to provide the necessary context */}
+
+
+
);
}
diff --git a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx
index 5711a3ce..3154e996 100644
--- a/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx
+++ b/graphcap_studio/src/features/inference/providers/ProvidersPanel.tsx
@@ -1,66 +1,23 @@
import { useColorMode } from "@/components/ui/theme/color-mode";
-import { Box, Button, Center, Flex, Text, VStack } from "@chakra-ui/react";
+import { Box, Center, Flex, Text } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
import { useMemo } from "react";
import { useProviders } from "../services/providers";
-import ProviderForm from "./ProviderForm";
+import { ProviderConnection } from "./ProviderConnection";
import {
InferenceProviderProvider,
- useInferenceProviderContext,
} from "./context";
-import { ProviderSelect } from "./form";
/**
* Panel content that requires context
*/
function PanelContent() {
- const { setMode, providers } =
- useInferenceProviderContext();
-
- const { colorMode } = useColorMode();
- const textColor = colorMode === "light" ? "gray.600" : "gray.300";
- const borderColor = colorMode === "light" ? "gray.200" : "gray.700";
-
- // No providers state
- if (providers.length === 0) {
- return (
-
- No providers configured
-
-
- );
- }
return (
- {/* Header */}
-
- {/* Provider Selection Dropdown */}
-
-
-
-
-
-
{/* Content */}
-
+
);
@@ -69,8 +26,8 @@ function PanelContent() {
/**
* Providers Panel Component
*
- * This component displays a list of providers and allows viewing and editing
- * provider configurations.
+ * This component displays provider configurations in a panel.
+ * It acts as a container for the provider connection form.
*/
export function ProvidersPanel() {
const {
@@ -111,9 +68,6 @@ export function ProvidersPanel() {
providers={providersData}
selectedProvider={initialSelectedProvider}
isCreating={false}
- onSubmit={() => {}}
- onCancel={() => {}}
- isSubmitting={false}
>
diff --git a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx
index 7401a316..bdc1fd02 100644
--- a/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx
+++ b/graphcap_studio/src/features/inference/providers/context/InferenceProviderContext.tsx
@@ -1,3 +1,4 @@
+import type { Provider } from "@/types/provider-config-types";
// SPDX-License-Identifier: Apache-2.0
/**
* Inference Provider Context
@@ -8,9 +9,8 @@
* The context includes:
* - View state (mode, selected provider)
* - Providers data
- * - Form state and validation
* - Model selection state
- * - Form actions and callbacks
+ * - Basic view actions
*/
import {
type ReactNode,
@@ -21,15 +21,12 @@ import {
useMemo,
useState,
} from "react";
-import { DEFAULT_PROVIDER_FORM_DATA } from "../../constants";
-import { useModelSelection, useProviderForm } from "../../hooks";
-import type { Provider, ProviderCreate, ProviderUpdate } from "../types";
+import { useModelSelection } from "../../hooks";
// Local storage key for selected provider
const SELECTED_PROVIDER_STORAGE_KEY = "graphcap-selected-provider";
type ViewMode = "view" | "edit" | "create";
-type FormData = ProviderCreate | ProviderUpdate;
/**
* Type definition for the Inference Provider Context
@@ -47,29 +44,12 @@ type InferenceProviderContextType = {
providers: Provider[];
setProviders: (providers: Provider[]) => void;
- // Form state
- control: any;
- handleSubmit: any;
- errors: any;
- watch: any;
- providerName: string | undefined;
- reset: any;
-
// Model selection state
selectedModelId: string;
setSelectedModelId: (id: string) => void;
- providerModelsData: any;
- isLoadingModels: boolean;
- isModelsError: boolean;
- modelsError: any;
-
- // Form actions
handleModelSelect: () => void;
- isSubmitting: boolean;
- isCreating: boolean;
- // Form callbacks
- onSubmit: (data: FormData) => Promise
;
+ // Basic actions
onCancel: () => void;
};
@@ -86,29 +66,12 @@ const defaultContextValue: InferenceProviderContextType = {
providers: [],
setProviders: () => {},
- // Form state
- control: null,
- handleSubmit: () => ({}),
- errors: {},
- watch: () => undefined,
- providerName: undefined,
- reset: () => {},
-
// Model selection state
selectedModelId: "",
setSelectedModelId: () => {},
- providerModelsData: null,
- isLoadingModels: false,
- isModelsError: false,
- modelsError: null,
-
- // Form actions
handleModelSelect: () => {},
- isSubmitting: false,
- isCreating: false,
- // Form callbacks
- onSubmit: async () => Promise.resolve(),
+ // Basic actions
onCancel: () => {},
};
@@ -135,9 +98,6 @@ export function useInferenceProviderContext() {
return context;
}
-// For backward compatibility
-export const useProviderFormContext = useInferenceProviderContext;
-
/**
* Save provider to localStorage
* @param provider - The provider to save
@@ -176,11 +136,8 @@ const loadProviderFromStorage = (): Provider | null => {
*/
type InferenceProviderProviderProps = {
readonly children: ReactNode;
- readonly initialData?: Partial;
- readonly isCreating: boolean;
- readonly onSubmit: (data: FormData) => void;
- readonly onCancel: () => void;
- readonly isSubmitting: boolean;
+ readonly isCreating?: boolean;
+ readonly onCancel?: () => void;
readonly onModelSelect?: (providerName: string, modelId: string) => void;
readonly selectedProvider?: Provider | null;
readonly providers?: Provider[];
@@ -193,20 +150,16 @@ type InferenceProviderProviderProps = {
* available to all child components through the context. It handles:
*
* - Provider selection and management
- * - Form state and validation
* - Model selection
- * - Form submission and cancellation
+ * - Basic view actions
*
* @param props - The provider props
* @returns A context provider component
*/
export function InferenceProviderProvider({
children,
- initialData = {},
- isCreating,
- onSubmit: onSubmitProp,
- onCancel,
- isSubmitting,
+ isCreating = false,
+ onCancel = () => {},
onModelSelect,
selectedProvider: selectedProviderProp,
providers: providersProp = [],
@@ -223,79 +176,41 @@ export function InferenceProviderProvider({
// Update selected provider when prop changes
useEffect(() => {
- if (selectedProviderProp) {
+ if (selectedProviderProp && JSON.stringify(selectedProviderProp) !== JSON.stringify(selectedProvider)) {
setSelectedProvider(selectedProviderProp);
}
- }, [selectedProviderProp]);
+ }, [selectedProviderProp, selectedProvider]);
// Save selected provider to localStorage when it changes
useEffect(() => {
- saveProviderToStorage(selectedProvider);
+ if (selectedProvider) {
+ saveProviderToStorage(selectedProvider);
+ }
}, [selectedProvider]);
- // Update providers when prop changes
+ // Update providers when prop changes - only if we have providers and they're different
useEffect(() => {
- setProviders(providersProp);
- }, [providersProp]);
-
- // Use the form hook
- const {
- control,
- handleSubmit,
- errors,
- providerName,
- onSubmit: onSubmitForm,
- watch,
- reset,
- } = useProviderForm(initialData);
-
- // Reset form data when selected provider changes
- useEffect(() => {
- if (selectedProvider && mode !== "create") {
- reset({
- name: selectedProvider.name,
- kind: selectedProvider.kind,
- environment: selectedProvider.environment,
- baseUrl: selectedProvider.baseUrl,
- envVar: selectedProvider.envVar,
- isEnabled: selectedProvider.isEnabled,
- rateLimits: selectedProvider.rateLimits || {
- requestsPerMinute: 0,
- tokensPerMinute: 0,
- },
- });
- } else if (mode === "create") {
- reset(DEFAULT_PROVIDER_FORM_DATA);
+ const hasProviders = Array.isArray(providersProp) && providersProp.length > 0;
+ const providersChanged = JSON.stringify(providersProp) !== JSON.stringify(providers);
+
+ if (hasProviders && providersChanged) {
+ setProviders(providersProp);
}
- }, [selectedProvider, mode, reset]);
+ }, [providersProp, providers]);
- // Use the model selection hook with null check
+ // Use the model selection hook with selectedProvider
const {
selectedModelId,
setSelectedModelId,
- providerModelsData,
- isLoadingModels,
- isModelsError,
- modelsError,
handleModelSelect: handleModelSelectBase,
- } = useModelSelection(selectedProvider?.name ?? "", onModelSelect);
+ } = useModelSelection(selectedProvider, onModelSelect);
// Create a memoized version of handleModelSelect
const handleModelSelect = useCallback(() => {
- handleModelSelectBase();
- }, [handleModelSelectBase]);
-
- // Create a memoized version of onSubmit that calls both form and prop handlers
- const onSubmitHandler = useCallback(
- async (data: FormData) => {
- const result = await onSubmitForm(data, isCreating, selectedProvider?.id);
- if (result.success) {
- onSubmitProp(data);
- setMode("view");
- }
- },
- [onSubmitForm, onSubmitProp, setMode, isCreating, selectedProvider?.id],
- );
+ if (selectedProvider) {
+ handleModelSelectBase();
+ }
+ }, [handleModelSelectBase, selectedProvider]);
// Create a memoized version of onCancel that resets mode
const onCancelHandler = useCallback(() => {
@@ -317,51 +232,21 @@ export function InferenceProviderProvider({
providers,
setProviders,
- // Form state
- control,
- handleSubmit,
- errors,
- watch,
- providerName,
- reset,
-
// Model selection state
selectedModelId,
setSelectedModelId,
- providerModelsData,
- isLoadingModels,
- isModelsError,
- modelsError,
-
- // Form actions
handleModelSelect,
- isSubmitting,
- isCreating,
- // Form callbacks
- onSubmit: onSubmitHandler,
+ // Basic actions
onCancel: onCancelHandler,
}),
[
mode,
selectedProvider,
providers,
- control,
- handleSubmit,
- errors,
- watch,
- providerName,
- reset,
selectedModelId,
setSelectedModelId,
- providerModelsData,
- isLoadingModels,
- isModelsError,
- modelsError,
handleModelSelect,
- isSubmitting,
- isCreating,
- onSubmitHandler,
onCancelHandler,
],
);
@@ -372,6 +257,3 @@ export function InferenceProviderProvider({
);
}
-
-// For backward compatibility
-export const ProviderFormProvider = InferenceProviderProvider;
diff --git a/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx
new file mode 100644
index 00000000..5cce41e3
--- /dev/null
+++ b/graphcap_studio/src/features/inference/providers/context/ProviderFormContext.tsx
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: Apache-2.0
+import type { ConnectionDetails, ErrorDetails, Provider, ProviderCreate, ProviderUpdate } from "@/types/provider-config-types";
+import { type ReactNode, createContext, useContext } from "react";
+import type { Control, FieldErrors, UseFormWatch } from "react-hook-form";
+
+// Simplified dialog state type
+type DialogType = null | "error" | "success" | "formError" | "save";
+
+interface ProviderFormContextType {
+ // Core state
+ provider: Provider | null;
+ mode: "view" | "edit" | "create";
+
+ // Form state
+ control: Control;
+ errors: FieldErrors;
+ watch: UseFormWatch;
+
+ // UI state
+ isSubmitting: boolean;
+ dialog: DialogType;
+ error: ErrorDetails | null;
+ connectionDetails: ConnectionDetails | null;
+
+ // Selected model state
+ selectedModelId: string | null;
+ providerModels: Array<{ id: string; name: string; is_default?: boolean }> | null;
+
+ // Actions
+ setProvider: (provider: Provider | null) => void;
+ setMode: (mode: "view" | "edit" | "create") => void;
+ setSelectedModelId: (id: string | null) => void;
+ openDialog: (type: DialogType, error?: ErrorDetails) => void;
+ closeDialog: () => void;
+
+ // Form actions
+ handleSubmit: (e?: React.BaseSyntheticEvent) => Promise;
+ cancelEdit: () => void;
+ testConnection: () => Promise;
+}
+
+const ProviderFormContext = createContext(undefined);
+
+export function useProviderFormContext() {
+ const context = useContext(ProviderFormContext);
+ if (context === undefined) {
+ throw new Error("useProviderFormContext must be used within a ProviderFormProvider");
+ }
+ return context;
+}
+
+interface ProviderFormProviderProps {
+ readonly children: ReactNode;
+ readonly value: ProviderFormContextType;
+}
+
+export function ProviderFormProvider({ children, value }: ProviderFormProviderProps) {
+ return (
+
+ {children}
+
+ );
+}
diff --git a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx b/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx
deleted file mode 100644
index d424a855..00000000
--- a/graphcap_studio/src/features/inference/providers/form/ConnectionSection.tsx
+++ /dev/null
@@ -1,94 +0,0 @@
-import { Switch } from "@/components/ui/buttons/Switch";
-import { useColorModeValue } from "@/components/ui/theme/color-mode";
-import { Box, Field, Input, Text, VStack } from "@chakra-ui/react";
-// SPDX-License-Identifier: Apache-2.0
-import { Controller } from "react-hook-form";
-import { useInferenceProviderContext } from "../context";
-
-/**
- * Component for displaying and editing provider connection settings
- */
-export function ConnectionSection() {
- const { control, errors, watch, isEditing } = useInferenceProviderContext();
- const labelColor = useColorModeValue("gray.600", "gray.300");
- const textColor = useColorModeValue("gray.700", "gray.200");
-
- // Watch form values for read-only display
- const baseUrl = watch("baseUrl");
- const envVar = watch("envVar");
- const isEnabled = watch("isEnabled");
-
- if (!isEditing) {
- return (
-
-
-
- Base URL
-
- {baseUrl}
-
-
-
-
- Environment Variable
-
- {envVar}
-
-
-
-
- Status
-
- {isEnabled ? "Enabled" : "Disabled"}
-
-
- );
- }
-
- return (
-
- (
-
- Base URL
-
- {errors.baseUrl?.message}
-
- )}
- />
-
- (
-
- Environment Variable
-
- {errors.envVar?.message}
-
- )}
- />
-
- (
-
-
-
- Enabled
-
-
-
- )}
- />
-
- );
-}
diff --git a/graphcap_studio/src/features/inference/providers/form/ModelSelector.tsx b/graphcap_studio/src/features/inference/providers/form/ModelSelector.tsx
deleted file mode 100644
index e535ee60..00000000
--- a/graphcap_studio/src/features/inference/providers/form/ModelSelector.tsx
+++ /dev/null
@@ -1,84 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-import { Field } from "@/components/ui/field";
-import {
- SelectContent,
- SelectItem,
- SelectRoot,
- SelectTrigger,
- SelectValueText,
-} from "@/components/ui/select";
-import { useColorMode } from "@/components/ui/theme/color-mode";
-import { Box, Heading, Text, createListCollection } from "@chakra-ui/react";
-
-// Define the model item type for the select component
-export interface ModelItem {
- label: string;
- value: string;
-}
-
-export interface ModelSelectorProps {
- modelItems: ModelItem[];
- selectedModelId: string | null;
- setSelectedModelId: (id: string) => void;
-}
-
-/**
- * Component for selecting a model from a list
- */
-export function ModelSelector({
- modelItems,
- selectedModelId,
- setSelectedModelId,
-}: ModelSelectorProps) {
- const { colorMode } = useColorMode();
- const isDark = colorMode === "dark";
-
- const cardBg = isDark ? "gray.800" : "white";
- const borderColor = isDark ? "gray.700" : "gray.200";
- const headingColor = isDark ? "gray.100" : "gray.700";
- const labelColor = isDark ? "gray.300" : "gray.600";
-
- const modelCollection = createListCollection({
- items: modelItems,
- });
-
- // Convert selectedModelId to string array format
- const value = selectedModelId ? [selectedModelId] : [];
-
- return (
-
-
- Model
-
-
- Select a model to use with this provider
-
-
- setSelectedModelId(details.value[0])}
- >
-
-
-
-
- {modelItems.map((item: ModelItem) => (
-
- {item.label}
-
- ))}
-
-
-
-
- );
-}
diff --git a/graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx b/graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx
deleted file mode 100644
index d3124ee1..00000000
--- a/graphcap_studio/src/features/inference/providers/form/ProviderSelect.tsx
+++ /dev/null
@@ -1,71 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-import {
- SelectContent,
- SelectItem,
- SelectRoot,
- SelectTrigger,
- SelectValueText,
-} from "@/components/ui/select";
-import { createListCollection } from "@chakra-ui/react";
-import { useInferenceProviderContext } from "../context";
-
-type ProviderSelectProps = {
- readonly className?: string;
- readonly "aria-label"?: string;
-};
-
-/**
- * Component for selecting a provider from a dropdown
- */
-export function ProviderSelect({
- className,
- "aria-label": ariaLabel = "Select Provider",
-}: ProviderSelectProps) {
- const { providers, selectedProvider, setSelectedProvider, setMode } =
- useInferenceProviderContext();
-
- const selectedProviderId = selectedProvider?.id ?? null;
-
- // Convert providers to the format expected by SelectRoot
- const providerItems = providers.map((provider) => ({
- label: provider.name,
- value: String(provider.id),
- }));
-
- const providerCollection = createListCollection({
- items: providerItems,
- });
-
- // Convert selectedProviderId to string array format
- const value = selectedProviderId ? [String(selectedProviderId)] : [];
-
- const handleProviderChange = (details: any) => {
- const id = Number(details.value[0]);
- const provider = providers.find((p) => p.id === id);
- if (provider) {
- setSelectedProvider(provider);
- setMode("view");
- }
- };
-
- return (
-
-
-
-
-
- {providerItems.map((item) => (
-
- {item.label}
-
- ))}
-
-
- );
-}
diff --git a/graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx b/graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx
deleted file mode 100644
index 0c6f8089..00000000
--- a/graphcap_studio/src/features/inference/providers/form/RateLimitsSection.tsx
+++ /dev/null
@@ -1,119 +0,0 @@
-import { useColorModeValue } from "@/components/ui/theme/color-mode";
-import {
- Box,
- Field,
- Grid,
- GridItem,
- Input,
- Text,
- VStack,
-} from "@chakra-ui/react";
-// SPDX-License-Identifier: Apache-2.0
-import { ChangeEvent } from "react";
-import { Controller } from "react-hook-form";
-import { useInferenceProviderContext } from "../context";
-
-/**
- * Component for displaying and editing provider rate limits
- */
-export function RateLimitsSection() {
- const { control, errors, watch, isEditing } = useInferenceProviderContext();
- const labelColor = useColorModeValue("gray.600", "gray.300");
- const textColor = useColorModeValue("gray.700", "gray.200");
-
- // Watch form values for read-only display
- const rateLimits = watch("rateLimits");
-
- if (!isEditing) {
- return (
-
-
-
- Rate Limits
-
-
-
-
- Requests per minute
-
-
- {rateLimits?.requestsPerMinute ?? 0}
-
-
-
-
- Tokens per minute
-
- {rateLimits?.tokensPerMinute ?? 0}
-
-
-
-
- );
- }
-
- return (
-
-
-
- Rate Limits
-
-
-
- (
-
-
- Requests per minute
-
- ) =>
- onChange(parseInt(e.target.value) || 0)
- }
- min={0}
- />
-
- {errors.rateLimits?.requestsPerMinute?.message}
-
-
- )}
- />
-
-
-
- (
-
-
- Tokens per minute
-
- ) =>
- onChange(parseInt(e.target.value) || 0)
- }
- min={0}
- />
-
- {errors.rateLimits?.tokensPerMinute?.message}
-
-
- )}
- />
-
-
-
-
- );
-}
diff --git a/graphcap_studio/src/features/inference/providers/form/index.ts b/graphcap_studio/src/features/inference/providers/form/index.ts
deleted file mode 100644
index 74080701..00000000
--- a/graphcap_studio/src/features/inference/providers/form/index.ts
+++ /dev/null
@@ -1,9 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-export * from "./BasicInfoSection";
-export * from "./ConnectionSection";
-export * from "./RateLimitsSection";
-export * from "./EnvironmentSelect";
-export * from "./ProviderSelect";
-export * from "../../../../components/ui/status/StatusMessage";
-export * from "./ModelSelector";
-export * from "../../../../components/ui/buttons/ActionButton";
diff --git a/graphcap_studio/src/features/inference/providers/index.ts b/graphcap_studio/src/features/inference/providers/index.ts
index 351881e1..386cead9 100644
--- a/graphcap_studio/src/features/inference/providers/index.ts
+++ b/graphcap_studio/src/features/inference/providers/index.ts
@@ -1,10 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
-export { default as ProviderForm } from "./ProviderForm";
-export { ProvidersPanel } from "./ProvidersPanel";
-export { default as ProvidersList } from "./ProvidersList";
-export { ModelSelectionSection } from "./ModelSelectionSection";
-export { FormFields } from "./FormFields";
-export { FormActions } from "./FormActions";
-export * from "../hooks";
-export * from "./context";
+export * from '@/types/provider-config-types';
+export * from './context';
+export * from './ProviderConnection';
+export * from './ProvidersPanel';
+
diff --git a/graphcap_studio/src/features/inference/providers/types.ts b/graphcap_studio/src/features/inference/providers/types.ts
deleted file mode 100644
index d6e2915f..00000000
--- a/graphcap_studio/src/features/inference/providers/types.ts
+++ /dev/null
@@ -1,122 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Provider Types
- *
- * Type definitions for provider-related data.
- */
-
-/**
- * Provider model
- */
-export interface ProviderModel {
- id: number;
- providerId: number;
- name: string;
- isEnabled: boolean;
- createdAt: string | Date;
- updatedAt: string | Date;
-}
-
-/**
- * Rate limits configuration
- */
-export interface RateLimits {
- id: number;
- providerId: number;
- requestsPerMinute?: number;
- tokensPerMinute?: number;
- createdAt: string | Date;
- updatedAt: string | Date;
-}
-
-/**
- * Provider configuration
- */
-export interface Provider {
- id: number;
- name: string;
- kind: string;
- environment: "cloud" | "local";
- envVar: string;
- baseUrl: string;
- apiKey?: string;
- isEnabled: boolean;
- createdAt: string | Date;
- updatedAt: string | Date;
- models?: ProviderModel[];
- rateLimits?: RateLimits;
-}
-
-/**
- * Provider creation payload
- */
-export interface ProviderCreate {
- name: string;
- kind: string;
- environment: "cloud" | "local";
- envVar: string;
- baseUrl: string;
- apiKey?: string;
- isEnabled?: boolean;
- models?: Array<{
- name: string;
- isEnabled?: boolean;
- }>;
- rateLimits?: {
- requestsPerMinute?: number;
- tokensPerMinute?: number;
- };
-}
-
-/**
- * Provider update payload
- */
-export interface ProviderUpdate {
- name?: string;
- kind?: string;
- environment?: "cloud" | "local";
- envVar?: string;
- baseUrl?: string;
- isEnabled?: boolean;
- models?: Array<{
- id?: number;
- name: string;
- isEnabled?: boolean;
- }>;
- rateLimits?: {
- requestsPerMinute?: number;
- tokensPerMinute?: number;
- };
-}
-
-/**
- * Provider API key update payload
- */
-export interface ProviderApiKey {
- apiKey: string;
-}
-
-/**
- * Success response
- */
-export interface SuccessResponse {
- success: boolean;
- message: string;
-}
-
-/**
- * Provider model info from GraphCap server
- */
-export interface ProviderModelInfo {
- id: string;
- name: string;
- is_default: boolean;
-}
-
-/**
- * Provider models response from GraphCap server
- */
-export interface ProviderModelsResponse {
- provider: string;
- models: ProviderModelInfo[];
-}
diff --git a/graphcap_studio/src/features/inference/services/providers.ts b/graphcap_studio/src/features/inference/services/providers.ts
index fd95a03e..9d5ce231 100644
--- a/graphcap_studio/src/features/inference/services/providers.ts
+++ b/graphcap_studio/src/features/inference/services/providers.ts
@@ -8,69 +8,32 @@
import { useServerConnectionsContext } from "@/context/ServerConnectionsContext";
import { SERVER_IDS } from "@/features/server-connections/constants";
-import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
-import { hc } from "hono/client";
-import type { AppType } from "../../../../../data_service/src/app"; // TODO: Refactor
+import {
+ createDataServiceClient,
+ createInferenceBridgeClient,
+} from "@/features/server-connections/services/apiClients";
import type {
Provider,
- ProviderApiKey,
ProviderCreate,
- ProviderModelsResponse,
ProviderUpdate,
- SuccessResponse,
-} from "../providers/types";
+ ServerProviderConfig,
+ SuccessResponse
+} from "@/types/provider-config-types";
+import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
// Query keys for TanStack Query
export const queryKeys = {
providers: ["providers"] as const,
- provider: (id: number) => [...queryKeys.providers, id] as const,
+ provider: (id: number) => ["providers", id] as const,
providerModels: (providerName: string) =>
- [...queryKeys.providers, "models", providerName] as const,
+ ["providers", "models", providerName] as const,
};
-// Define a more specific type for the client
-interface DataServiceClient {
- providers: {
- $get: () => Promise;
- $post: (options: { json: ProviderCreate }) => Promise;
- [":id"]: {
- $get: (options: { param: { id: string } }) => Promise;
- $put: (options: {
- param: { id: string };
- json: ProviderUpdate;
- }) => Promise;
- $delete: (options: { param: { id: string } }) => Promise;
- "api-key": {
- $put: (options: {
- param: { id: string };
- json: ProviderApiKey;
- }) => Promise;
- };
- };
- };
-}
-
-/**
- * Get the Data Service URL from server connections context
- */
-function getDataServiceUrl(connections: any[]): string {
- const dataServiceConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.DATA_SERVICE,
- );
-
- return (
- dataServiceConnection?.url ||
- import.meta.env.VITE_DATA_SERVICE_URL ||
- "http://localhost:32550"
- );
-}
-
/**
- * Create a Hono client for the Data Service
+ * Extended Error interface with cause property
*/
-function createDataServiceClient(connections: any[]): DataServiceClient {
- const baseUrl = getDataServiceUrl(connections);
- return hc(`${baseUrl}/api/v1`) as DataServiceClient;
+interface ErrorWithCause extends Error {
+ cause?: unknown;
}
/**
@@ -143,7 +106,33 @@ export function useCreateProvider() {
});
if (!response.ok) {
- throw new Error(`Failed to create provider: ${response.status}`);
+ // Try to get detailed error information
+ try {
+ const errorData = await response.json();
+ console.error("Provider creation error:", errorData);
+
+ // Check if we have a structured error response
+ if (errorData.status === "error" || errorData.validationErrors) {
+ throw errorData;
+ }
+
+ // Simple error with a message
+ if (errorData.message) {
+ throw new Error(errorData.message);
+ }
+
+ // Fallback error
+ throw new Error(`Failed to create provider: ${response.status}`);
+ } catch (parseError) {
+ // If we can't parse the error as JSON, throw a general error
+ if (
+ parseError instanceof Error &&
+ parseError.message !== "Failed to create provider"
+ ) {
+ throw parseError;
+ }
+ throw new Error(`Failed to create provider: ${response.status}`);
+ }
}
return response.json() as Promise;
@@ -164,21 +153,30 @@ export function useUpdateProvider() {
return useMutation({
mutationFn: async ({ id, data }: { id: number; data: ProviderUpdate }) => {
+ console.log("Updating provider with data:", data);
+
+ const apiData = { ...data };
+
const client = createDataServiceClient(connections);
const response = await client.providers[":id"].$put({
param: { id: id.toString() },
- json: data,
+ json: apiData,
});
if (!response.ok) {
- throw new Error(`Failed to update provider: ${response.status}`);
+ const errorData = await response.json();
+ console.error("Provider update error:", errorData);
+ throw errorData;
}
return response.json() as Promise;
},
onSuccess: (data) => {
+ // Convert string ID to number for query invalidation
+ const numericId = typeof data.id === 'string' ? Number.parseInt(data.id, 10) : data.id;
+
// Invalidate specific provider query
- queryClient.invalidateQueries({ queryKey: queryKeys.provider(data.id) });
+ queryClient.invalidateQueries({ queryKey: queryKeys.provider(numericId) });
// Invalidate providers list
queryClient.invalidateQueries({ queryKey: queryKeys.providers });
},
@@ -215,73 +213,59 @@ export function useDeleteProvider() {
}
/**
- * Hook to update a provider's API key
+ * Hook to test provider connection
*/
-export function useUpdateProviderApiKey() {
- const queryClient = useQueryClient();
+export function useTestProviderConnection() {
const { connections } = useServerConnectionsContext();
return useMutation({
- mutationFn: async ({ id, apiKey }: { id: number; apiKey: string }) => {
- const client = createDataServiceClient(connections);
- const response = await client.providers[":id"]["api-key"].$put({
- param: { id: id.toString() },
- json: { apiKey } as ProviderApiKey,
+ mutationFn: async ({
+ providerName,
+ config,
+ }: { providerName: string; config: ServerProviderConfig }) => {
+ const client = createInferenceBridgeClient(connections);
+
+ console.log("Testing connection with config:", JSON.stringify(config));
+
+ const response = await client.providers[":provider_name"][
+ "test-connection"
+ ].$post({
+ param: { provider_name: providerName },
+ json: config,
});
if (!response.ok) {
- throw new Error(`Failed to update API key: ${response.status}`);
+ const errorData = await response.json();
+ console.error("Error response:", errorData);
+
+ if (errorData.status === "error" && errorData.details) {
+ const error = new Error(
+ errorData.message || "Connection test failed",
+ ) as ErrorWithCause;
+ error.cause = errorData;
+ throw error;
+ }
+
+ // Handle different error formats
+ if (errorData.detail) {
+ throw new Error(errorData.detail);
+ }
+
+ if (errorData.message) {
+ throw new Error(errorData.message);
+ }
+
+ if (typeof errorData === "object") {
+ const error = new Error("Connection test failed") as ErrorWithCause;
+ error.cause = errorData;
+ throw error;
+ }
+
+ // Fallback to simple error
+ throw new Error(`Connection test failed: ${response.status}`);
}
- return response.json() as Promise;
+ return response.json();
},
- onSuccess: (_, { id }) => {
- // Invalidate specific provider query
- queryClient.invalidateQueries({ queryKey: queryKeys.provider(id) });
- },
- });
-}
-
-/**
- * Get the GraphCap Server URL from server connections context
- */
-function getGraphCapServerUrl(connections: any[]): string {
- const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
- );
-
- return (
- graphcapServerConnection?.url ||
- import.meta.env.VITE_GRAPHCAP_SERVER_URL ||
- "http://localhost:32100"
- );
-}
-
-/**
- * Hook to get available models for a provider from the GraphCap server
- */
-export function useProviderModels(providerName: string) {
- const { connections } = useServerConnectionsContext();
- const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
- );
- const isConnected = graphcapServerConnection?.status === "connected";
-
- return useQuery({
- queryKey: queryKeys.providerModels(providerName),
- queryFn: async () => {
- const baseUrl = getGraphCapServerUrl(connections);
- const response = await fetch(
- `${baseUrl}/providers/${providerName}/models`,
- );
-
- if (!response.ok) {
- throw new Error(`Failed to fetch provider models: ${response.status}`);
- }
-
- return response.json() as Promise;
- },
- enabled: isConnected && !!providerName,
- staleTime: 1000 * 60 * 5, // 5 minutes
});
}
diff --git a/graphcap_studio/src/features/perspectives/README.md b/graphcap_studio/src/features/perspectives/README.md
index 0c30ae8b..d5f4313b 100644
--- a/graphcap_studio/src/features/perspectives/README.md
+++ b/graphcap_studio/src/features/perspectives/README.md
@@ -165,4 +165,3 @@ Custom hooks are provided for working with perspectives:
- **usePerspectives** - Fetches available perspectives from the server
- **useGeneratePerspectiveCaption** - Generates captions for images using perspectives
-- **useImagePerspectives** - Manages perspective data for a specific image
\ No newline at end of file
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx
index 8435416d..b756cb30 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveActions/PerspectivesFooter.tsx
@@ -6,27 +6,20 @@
*/
import { useColorModeValue } from "@/components/ui/theme/color-mode";
-import {
- GenerationOptionsButton,
- GenerationOptionsProvider,
- ProviderSelector,
-} from "@/features/inference/generation-options";
-import { DEFAULT_OPTIONS } from "@/features/inference/generation-options/schema";
import {
usePerspectiveUI,
usePerspectivesData,
} from "@/features/perspectives/context";
-import type { CaptionOptions } from "@/features/perspectives/types";
import {
Box,
Button,
Flex,
- HStack,
Icon,
- useBreakpointValue,
+ Text,
+ chakra,
} from "@chakra-ui/react";
-import { useCallback, useEffect } from "react";
-import { LuRefreshCw, LuSettings } from "react-icons/lu";
+import { useCallback, useEffect, useMemo } from "react";
+import { LuRefreshCw } from "react-icons/lu";
/**
* Helper function to determine button title text
@@ -64,10 +57,7 @@ export function PerspectivesFooter() {
generatePerspective,
isGenerating,
currentImage,
- captionOptions,
- setCaptionOptions,
- selectedProvider,
- handleProviderChange,
+ generationOptions,
} = usePerspectivesData();
// Use UI context
@@ -83,13 +73,19 @@ export function PerspectivesFooter() {
const bgColor = useColorModeValue("white", "gray.800");
const borderColor = useColorModeValue("gray.200", "gray.700");
-
- // Use responsive selector width based on screen size
- const selectorWidth = useBreakpointValue({
- base: "100%",
- sm: "12rem",
- md: "16rem",
- });
+ const infoTextColor = useColorModeValue("gray.600", "gray.400");
+
+ // Log information for debugging
+ console.log("GenerationOptions:", generationOptions);
+ console.log("Available providers:", availableProviders);
+
+ // Get provider information safely
+ const { providerName, modelName } = useMemo(() => {
+ return {
+ providerName: generationOptions.provider_name || "Select Provider",
+ modelName: generationOptions.model_name || "Select Model"
+ };
+ }, [generationOptions.provider_name, generationOptions.model_name]);
// Fetch providers on component mount
useEffect(() => {
@@ -109,7 +105,7 @@ export function PerspectivesFooter() {
return false;
}
- if (!selectedProvider) {
+ if (!generationOptions.provider_name) {
showMessage(
"No provider selected",
"Please select an inference provider",
@@ -128,21 +124,13 @@ export function PerspectivesFooter() {
}
return true;
- }, [activeSchemaName, selectedProvider, currentImage, showMessage]);
+ }, [activeSchemaName, generationOptions.provider_name, currentImage, showMessage]);
// Handle generate button click
const handleGenerate = useCallback(async () => {
console.log("Generate button clicked");
console.log("Active schema:", activeSchemaName);
- console.log("Selected provider:", selectedProvider);
-
- // Ensure we have valid options by applying defaults if needed
- const effectiveOptions =
- Object.keys(captionOptions).length === 0
- ? DEFAULT_OPTIONS
- : captionOptions;
-
- console.log("Using caption options:", effectiveOptions);
+ console.log("Generation options:", generationOptions);
if (!validateGeneration()) {
return;
@@ -150,12 +138,20 @@ export function PerspectivesFooter() {
try {
console.log("Calling generatePerspective...");
+ // Find the provider object from the available providers using provider_name
+ const providerObject = availableProviders.find(p => p.name === generationOptions.provider_name);
+
+ if (!providerObject) {
+ throw new Error(`Provider "${generationOptions.provider_name}" not found in available providers`);
+ }
+
await generatePerspective(
- activeSchemaName!,
- currentImage!.path,
- selectedProvider,
- effectiveOptions,
+ activeSchemaName as string,
+ currentImage?.path as string,
+ providerObject,
+ generationOptions
);
+
showMessage(
"Generation started",
`Generating ${activeSchemaName} perspective`,
@@ -171,9 +167,9 @@ export function PerspectivesFooter() {
}
}, [
activeSchemaName,
- selectedProvider,
+ availableProviders,
generatePerspective,
- captionOptions,
+ generationOptions,
showMessage,
currentImage,
validateGeneration,
@@ -184,37 +180,16 @@ export function PerspectivesFooter() {
// Check if button should be disabled
const isGenerateDisabled =
- isProcessing || !activeSchemaName || !selectedProvider;
+ isProcessing || !activeSchemaName || !generationOptions.provider_name;
// Get title for the generate button
const buttonTitle = getButtonTitle(
- selectedProvider,
+ generationOptions.provider_name,
activeSchemaName,
isProcessing,
isGenerated,
);
- // Handle options change
- const handleOptionsChange = useCallback(
- (newOptions: CaptionOptions) => {
- setCaptionOptions(newOptions);
- },
- [setCaptionOptions],
- );
-
- // Create a handler for the new ProviderSelector component
- const handleProviderSelection = useCallback(
- (provider: string) => {
- // Create a synthetic event to pass to the original handler
- const syntheticEvent = {
- target: { value: provider },
- } as React.ChangeEvent;
-
- handleProviderChange(syntheticEvent);
- },
- [handleProviderChange],
- );
-
return (
- {/* Provider Selection */}
- {availableProviders.length > 0 ? (
-
- ) : (
-
- )}
-
-
- {/* Options Button with Popover */}
-
-
-
- Options
-
- }
- size="sm"
- variant="ghost"
+ {/* Provider and Model Info */}
+
+ Using: {providerName} / {modelName}
+
+
+ {/* Generate/Regenerate Button */}
+
);
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx
index c6453bdc..e4d6eac5 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveCardTabbed.tsx
@@ -1,5 +1,6 @@
import { ClipboardButton } from "@/components/ui/buttons";
import { useColorModeValue } from "@/components/ui/theme/color-mode";
+import type { PerspectiveSchema } from "@/types/perspective-types";
// SPDX-License-Identifier: Apache-2.0
/**
* PerspectiveCardTabbed Component
@@ -8,7 +9,6 @@ import { useColorModeValue } from "@/components/ui/theme/color-mode";
* This component uses Chakra UI tabs for the tabbed interface.
*/
import { Box, Card, Stack, Tabs, Text } from "@chakra-ui/react";
-import type { PerspectiveSchema } from "../../../types";
import { PerspectiveDebug } from "./PerspectiveDebug";
import { SchemaView } from "./SchemaView";
import { CaptionRenderer } from "./schema-fields";
@@ -175,8 +175,9 @@ export function PerspectiveCardTabbed({
{/* Metadata - e.g., timestamps or version info */}
- {data?.metadata?.timestamp &&
- new Date(data.metadata.timestamp).toLocaleString()}
+ {data?.metadata?.generatedAt || data?.metadata?.timestamp ?
+ new Date(data?.metadata?.generatedAt || data?.metadata?.timestamp || '').toLocaleString() :
+ ''}
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx
index 15f54073..6e7ad947 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/PerspectiveDebug.tsx
@@ -1,4 +1,5 @@
import { ClipboardButton } from "@/components/ui/buttons";
+import type { PerspectiveData, PerspectiveSchema } from "@/types";
import { Box, Stack } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
/**
@@ -8,7 +9,6 @@ import { Box, Stack } from "@chakra-ui/react";
* including its data, options, and metadata.
*/
import { useEffect } from "react";
-import type { PerspectiveData, PerspectiveSchema } from "../../../types";
import {
DataStatistics,
MetadataSection,
@@ -49,9 +49,7 @@ function processDebugInfo(
model: perspectiveData?.model,
version: perspectiveData?.version,
config_name: perspectiveData?.config_name ?? schema.name,
- generatedAt: data.metadata?.timestamp
- ? new Date(data.metadata.timestamp).toISOString()
- : null,
+ generatedAt: data.metadata?.generatedAt ?? null,
},
// Generation options - directly from the PerspectiveData interface
options: perspectiveData?.options || null,
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx
index 34b72a5b..b0cbb982 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/SchemaView.tsx
@@ -5,13 +5,12 @@
* This component displays the schema information for a perspective.
*/
-import React from "react";
-import type { PerspectiveSchema } from "../../../types";
+import type { PerspectiveSchema } from "@/types";
import { SchemaFieldFactory } from "./schema-fields";
interface SchemaViewProps {
- schema: PerspectiveSchema;
- className?: string;
+ readonly schema: PerspectiveSchema;
+ readonly className?: string;
}
export function SchemaView({ schema, className = "" }: SchemaViewProps) {
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx
index 143ddaee..13663e43 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveCaption/PerspectiveCard/debug-fields/MetadataSection.tsx
@@ -77,6 +77,11 @@ function MetadataItem({
labelColor,
valueColor,
}: MetadataItemProps) {
+ // Format date if this is the Generated timestamp field
+ const formattedValue = label === "Generated:" && value
+ ? new Date(value).toLocaleString()
+ : value;
+
return (
@@ -84,7 +89,7 @@ function MetadataItem({
{value ? (
- {value}
+ {formattedValue}
) : (
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx
deleted file mode 100644
index 177d6f2c..00000000
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveFilterPanel.tsx
+++ /dev/null
@@ -1,103 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Perspective Filter Panel
- *
- * This component provides a UI for toggling the visibility of different perspectives.
- */
-
-import { Checkbox } from "@/components/ui/checkbox";
-import {
- Box,
- Button,
- Flex,
- HStack,
- Heading,
- Text,
- VStack,
-} from "@chakra-ui/react";
-import { useMemo } from "react";
-import { usePerspectivesData } from "../../context/PerspectivesDataContext";
-
-/**
- * Component for filtering which perspectives are visible in the UI
- */
-export function PerspectiveFilterPanel() {
- const {
- perspectives,
- hiddenPerspectives,
- togglePerspectiveVisibility,
- isPerspectiveVisible,
- setAllPerspectivesVisible,
- } = usePerspectivesData();
-
- // Count how many perspectives are visible/hidden
- const counts = useMemo(() => {
- const totalCount = perspectives.length;
- const hiddenCount = hiddenPerspectives.length;
- const visibleCount = totalCount - hiddenCount;
-
- return { totalCount, hiddenCount, visibleCount };
- }, [perspectives, hiddenPerspectives]);
-
- return (
-
-
-
- Perspective Visibility
-
- {counts.visibleCount} of {counts.totalCount} visible
-
-
-
-
-
-
- {perspectives.map((perspective) => (
- togglePerspectiveVisibility(perspective.name)}
- colorScheme="blue"
- size="sm"
- >
-
- {perspective.display_name || perspective.name}
-
-
- ))}
-
-
-
-
-
-
- Show All
-
- {
- for (const p of perspectives) {
- togglePerspectiveVisibility(p.name);
- }
- }}
- >
- Toggle All
-
-
-
-
- );
-}
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx
index 808298cc..762df226 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/PerspectiveModuleFilter.tsx
@@ -170,7 +170,7 @@ export function PerspectiveModuleFilter({
alignItems="center"
justifyContent="center"
color={buttonColor}
- width="20px"
+ width="60px"
height="20px"
borderWidth="1px"
borderColor="currentColor"
@@ -178,7 +178,7 @@ export function PerspectiveModuleFilter({
ml={1}
_hover={{ bg: hoverBgColor }}
>
- →
+ View →
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts
index 2efe2459..e6565208 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts
+++ b/graphcap_studio/src/features/perspectives/components/PerspectiveManagement/index.ts
@@ -1,10 +1,10 @@
-export { ModuleList } from './PerspectiveModules/ModuleList';
-export { ModuleInfo } from './PerspectiveModules/ModuleInfo';
-export { NotFound } from './NotFound';
export { ErrorDisplay } from './ErrorDisplay';
export { LoadingDisplay } from './LoadingDisplay';
+export { NotFound } from './NotFound';
export { PerspectiveEditor } from './PerspectiveEditor/PerspectiveEditor';
-export { SchemaValidationError } from './SchemaValidationError';
-export { PerspectiveModuleFilter } from './PerspectiveModuleFilter';
export { PerspectiveManagementPanel } from './PerspectiveManagementPanel';
-export { PerspectiveFilterPanel } from './PerspectiveFilterPanel';
+export { PerspectiveModuleFilter } from './PerspectiveModuleFilter';
+export { ModuleInfo } from './PerspectiveModules/ModuleInfo';
+export { ModuleList } from './PerspectiveModules/ModuleList';
+export { SchemaValidationError } from './SchemaValidationError';
+
diff --git a/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx b/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx
index af12d179..908a091c 100644
--- a/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx
+++ b/graphcap_studio/src/features/perspectives/components/PerspectivesErrorState.tsx
@@ -42,7 +42,7 @@ export function PerspectivesErrorState({
Server Connection Error
- Unable to connect to the GraphCap server. Please check your
+ Unable to connect to the Inference Bridge. Please check your
connection settings and try again.
onReconnect
? onReconnect()
- : handleConnect(SERVER_IDS.GRAPHCAP_SERVER)
+ : handleConnect(SERVER_IDS.INFERENCE_BRIDGE)
}
>
diff --git a/graphcap_studio/src/features/perspectives/components/index.ts b/graphcap_studio/src/features/perspectives/components/index.ts
index 07d89562..da4e9e35 100644
--- a/graphcap_studio/src/features/perspectives/components/index.ts
+++ b/graphcap_studio/src/features/perspectives/components/index.ts
@@ -6,11 +6,11 @@
*/
export * from "./PerspectiveCaption/EmptyPerspectives";
-export * from "./PerspectivesErrorState";
-export * from "./PerspectiveManagement/PerspectiveFilterPanel";
export * from "./PerspectiveCaption/ErrorMessage";
+export * from "./PerspectiveCaption/PerspectiveActions";
+export { MetadataDisplay } from "./PerspectiveCaption/PerspectiveCard/MetadataDisplay";
export { PerspectiveHeader } from "./PerspectiveCaption/PerspectiveNavigation/PerspectiveHeader";
export { PerspectivesPager } from "./PerspectiveCaption/PerspectiveNavigation/PerspectivesPager";
-export { MetadataDisplay } from "./PerspectiveCaption/PerspectiveCard/MetadataDisplay";
-export * from "./PerspectiveCaption/PerspectiveActions";
export * from "./PerspectiveManagement/PerspectiveManagementPanel";
+export * from "./PerspectivesErrorState";
+
diff --git a/graphcap_studio/src/features/perspectives/constants/index.ts b/graphcap_studio/src/features/perspectives/constants/index.ts
index ef436ffe..35a8f247 100644
--- a/graphcap_studio/src/features/perspectives/constants/index.ts
+++ b/graphcap_studio/src/features/perspectives/constants/index.ts
@@ -54,7 +54,6 @@ export const perspectivesQueryKeys = {
// Constants for API endpoints
export const API_ENDPOINTS = {
LIST_PERSPECTIVES: "/perspectives/list",
- GENERATE_CAPTION: "/perspectives/caption",
VIEW_IMAGE: "/images/view",
REST_LIST_PERSPECTIVES: "/perspectives/list",
REST_GENERATE_CAPTION: "/perspectives/caption-from-path",
@@ -72,6 +71,5 @@ export const CACHE_TIMES = {
// Default values
export const DEFAULTS = {
SERVER_URL: "http://localhost:32100",
- PROVIDER: "gemini",
- DEFAULT_FILENAME: "image.jpg",
+
};
diff --git a/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx
index 58f7e9f3..9d2d5e17 100644
--- a/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx
+++ b/graphcap_studio/src/features/perspectives/context/PerspectiveUIContext.tsx
@@ -6,17 +6,18 @@
* It follows the Context API best practices and focuses exclusively on UI concerns.
*/
-import React, {
+import type { PerspectiveSchema } from "@/types";
+import type React from "react";
+import {
+ type ReactNode,
createContext,
+ useCallback,
useContext,
- ReactNode,
useEffect,
useMemo,
- useState,
useRef,
- useCallback,
+ useState,
} from "react";
-import { PerspectiveSchema } from "../types";
import {
getSelectedPerspective,
saveSelectedPerspective,
diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx
index 505f9327..5b48c819 100644
--- a/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx
+++ b/graphcap_studio/src/features/perspectives/context/PerspectivesDataContext.tsx
@@ -10,27 +10,28 @@
*/
import { useServerConnectionsContext } from "@/context";
+import { useGenerationOptions } from "@/features/inference/generation-options/context";
import { SERVER_IDS } from "@/features/server-connections/constants";
+import { useProviders } from "@/features/server-connections/services/providers";
import type { Image } from "@/services/images";
+import type {
+ Perspective,
+ PerspectiveData,
+ PerspectiveSchema
+} from "@/types";
+import type { GenerationOptions } from "@/types/generation-option-types";
+import type { Provider } from "@/types/provider-config-types";
import React, {
createContext,
- useContext,
- type ReactNode,
- useState,
useCallback,
+ useContext,
useEffect,
useMemo,
+ useState,
+ type ReactNode,
} from "react";
-import { useProviders } from "../../inference/services/providers";
import { useGeneratePerspectiveCaption } from "../hooks/useGeneratePerspectiveCaption";
import { usePerspectives } from "../hooks/usePerspectives";
-import type {
- CaptionOptions,
- Perspective,
- PerspectiveData,
- PerspectiveSchema,
- Provider,
-} from "../types";
import {
getAllPerspectiveCaptions,
loadHiddenPerspectives,
@@ -106,14 +107,13 @@ interface PerspectivesDataContextType {
refetchPerspectives: () => Promise;
// Captions data
- captions: Record;
+ captions: Record;
generatedPerspectives: string[];
isGenerating: boolean;
isServerConnected: boolean;
- // Caption options
- captionOptions: CaptionOptions;
- setCaptionOptions: (options: CaptionOptions) => void;
+ // Generation options from the GenerationOptions context
+ generationOptions: GenerationOptions;
// Current image
currentImage: Image | null;
@@ -123,9 +123,9 @@ interface PerspectivesDataContextType {
generatePerspective: (
schemaName: string,
imagePath: string,
- provider_name?: string,
- options?: CaptionOptions,
- ) => Promise;
+ provider?: Provider,
+ options?: GenerationOptions,
+ ) => Promise;
// Status helpers
isPerspectiveGenerated: (schemaName: string) => boolean;
@@ -162,7 +162,6 @@ interface PerspectivesDataProviderProps {
readonly image: Image | null;
readonly initialProvider?: string;
readonly initialProviders?: Provider[];
- readonly initialCaptionOptions?: CaptionOptions;
}
/**
@@ -174,15 +173,17 @@ export function PerspectivesDataProvider({
image: initialImage,
initialProvider,
initialProviders = [],
- initialCaptionOptions = {},
}: PerspectivesDataProviderProps) {
// Server connection state
const { connections } = useServerConnectionsContext();
const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
+ (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE,
);
const isServerConnected = graphcapServerConnection?.status === "connected";
+ // Get generation options from context
+ const generationOptions = useGenerationOptions();
+
// Current image state
const [currentImage, setCurrentImage] = useState(initialImage);
@@ -194,13 +195,8 @@ export function PerspectivesDataProvider({
useState(initialProviders);
const [isGeneratingAll, setIsGeneratingAll] = useState(false);
- // Caption options state
- const [captionOptions, setCaptionOptions] = useState(
- initialCaptionOptions,
- );
-
// Captions state
- const [captions, setCaptions] = useState({});
+ const [captions, setCaptions] = useState>({});
// Generation state
const [generatingPerspectives, setGeneratingPerspectives] = useState<
@@ -289,10 +285,9 @@ export function PerspectivesDataProvider({
if (prev.includes(perspectiveName)) {
// If already hidden, make it visible (remove from hidden list)
return prev.filter((name) => name !== perspectiveName);
- } else {
- // If visible, hide it (add to hidden list)
- return [...prev, perspectiveName];
}
+ // If visible, hide it (add to hidden list)
+ return [...prev, perspectiveName];
});
}, []);
@@ -315,7 +310,7 @@ export function PerspectivesDataProvider({
}, [refetchProviders]);
// Refetch perspectives and return the data
- const refetchPerspectives = async (): Promise => {
+ const refetchPerspectives = useCallback(async (): Promise => {
if (!isServerConnected) {
throw new Error(
"Cannot refetch perspectives: Server connection not established",
@@ -329,7 +324,7 @@ export function PerspectivesDataProvider({
console.error("Error refetching perspectives:", error);
throw error;
}
- };
+ }, [isServerConnected, refetchPerspectivesQuery]);
// Load captions from localStorage when image changes
useEffect(() => {
@@ -360,8 +355,9 @@ export function PerspectivesDataProvider({
// Get generated perspectives based on captions
const generatedPerspectives = React.useMemo(() => {
- if (!captions.perspectives) return [];
- return Object.keys(captions.perspectives);
+ const perspectives = captions.perspectives as Record | undefined;
+ if (!perspectives) return [];
+ return Object.keys(perspectives);
}, [captions]);
// Generate a perspective caption and save to localStorage
@@ -369,8 +365,8 @@ export function PerspectivesDataProvider({
async (
schemaName: string,
imagePath: string,
- provider_name?: string,
- options?: CaptionOptions,
+ provider?: Provider,
+ options?: GenerationOptions,
) => {
if (!isServerConnected) {
throw new Error(
@@ -384,18 +380,25 @@ export function PerspectivesDataProvider({
// Add to generating list
setGeneratingPerspectives((prev) => [...prev, schemaName]);
- // Use provided provider or selected provider
- const effectiveProvider = provider_name ?? selectedProvider;
+ // Use provided provider or get the selected provider by name from available providers
+ let effectiveProvider = provider;
+ if (!effectiveProvider && selectedProvider) {
+ // Find the provider object by name
+ effectiveProvider = availableProviders.find(p => p.name === selectedProvider);
+ }
if (!effectiveProvider) {
throw new Error("No provider selected for caption generation");
}
+ // Get current options from GenerationOptions context if not provided
+ const effectiveOptions = options || generationOptions.options;
+
// Log the options to ensure they're being passed correctly
console.debug(`Generating perspective "${schemaName}" with options:`, {
providedOptions: options,
- contextOptions: captionOptions,
- finalOptions: options ?? captionOptions ?? {},
+ contextOptions: generationOptions.options,
+ finalOptions: effectiveOptions,
provider: effectiveProvider,
});
@@ -403,8 +406,8 @@ export function PerspectivesDataProvider({
const result = await generateCaptionMutation.mutateAsync({
perspective: schemaName,
imagePath,
- provider_name: effectiveProvider,
- options: options ?? captionOptions,
+ provider: effectiveProvider,
+ options: effectiveOptions,
});
// Validate required data
@@ -414,50 +417,48 @@ export function PerspectivesDataProvider({
);
}
- if (!effectiveProvider) {
- console.error(
- `ERROR: Missing provider information for perspective ${schemaName}`,
- );
- }
-
- if (!options && !captionOptions) {
- console.error(
- `ERROR: Missing generation options for perspective ${schemaName}`,
- );
- }
-
- // Format the data as PerspectiveData object - no defaults!
- const perspectiveData: PerspectiveData = {
+ // Format the data as PerspectiveData object
+ const perspectiveData = {
config_name: schemaName,
version: "1.0",
- model:
- result.metadata?.model ??
- (() => {
- console.error(
- `CRITICAL ERROR: Missing model information in API response for perspective ${schemaName}`,
- );
- return "MISSING_MODEL";
- })(),
- provider: effectiveProvider,
+ model: result.metadata?.model ?? effectiveOptions.model_name ?? "MISSING_MODEL",
+ provider: effectiveProvider.name,
content: result.result || {},
- options: options || captionOptions,
+ options: {
+ model: effectiveOptions.model_name,
+ max_tokens: effectiveOptions.max_tokens,
+ temperature: effectiveOptions.temperature,
+ top_p: effectiveOptions.top_p,
+ repetition_penalty: effectiveOptions.repetition_penalty,
+ global_context: effectiveOptions.global_context,
+ context: effectiveOptions.context,
+ resize_resolution: effectiveOptions.resize_resolution
+ },
+ metadata: {
+ provider: effectiveProvider.name,
+ model: result.metadata?.model ?? effectiveOptions.model_name ?? "MISSING_MODEL",
+ version: "1.0",
+ config_name: schemaName,
+ generatedAt: new Date().toISOString()
+ }
};
// Save the perspective directly to localStorage
savePerspectiveCaption(imagePath, schemaName, perspectiveData);
// Update captions state with this new perspective data
- setCaptions((prev: Record) => {
+ setCaptions((prev) => {
+ const prevPerspectives = (prev.perspectives || {}) as Record;
const newCaptions = {
...prev,
perspectives: {
- ...prev.perspectives,
+ ...prevPerspectives,
[schemaName]: perspectiveData,
},
metadata: {
captioned_at: new Date().toISOString(),
- provider: effectiveProvider,
- model: result.metadata?.model ?? "unknown",
+ provider: effectiveProvider?.name || "",
+ model: result.metadata?.model ?? effectiveOptions.model_name ?? "unknown",
},
};
@@ -480,7 +481,8 @@ export function PerspectivesDataProvider({
isServerConnected,
currentImage,
selectedProvider,
- captionOptions,
+ availableProviders,
+ generationOptions.options,
generateCaptionMutation,
],
);
@@ -488,7 +490,8 @@ export function PerspectivesDataProvider({
// Helper to check if a perspective is generated
const isPerspectiveGenerated = useCallback(
(schemaName: string) => {
- return !!captions.perspectives?.[schemaName];
+ const perspectives = captions.perspectives as Record | undefined;
+ return !!perspectives?.[schemaName];
},
[captions],
);
@@ -505,107 +508,99 @@ export function PerspectivesDataProvider({
const getPerspectiveData = useCallback(
(schemaName: string) => {
// Try to get data from our in-memory state
- const perspectiveData = captions.perspectives?.[schemaName];
+ const perspectives = captions.perspectives as Record | undefined;
+ const perspectiveData = perspectives?.[schemaName];
console.debug("getPerspectiveData for", schemaName, perspectiveData);
// Always return the complete perspective data object
// to preserve options and metadata
- return perspectiveData;
+ return perspectiveData ? { ...perspectiveData } as Record : null;
},
[captions],
);
// Create consolidated context value
- const value: PerspectivesDataContextType = useMemo(
- () => ({
- // Provider state
- selectedProvider,
- availableProviders,
- isGeneratingAll,
-
- // Provider actions
- setSelectedProvider,
- setAvailableProviders,
- setIsGeneratingAll,
- handleProviderChange,
-
- // Data fetching - providers
- fetchProviders,
- isLoadingProviders,
- providerError,
-
- // Perspectives data
- perspectives: perspectivesData || [],
- schemas,
- isLoadingPerspectives,
- perspectivesError,
- refetchPerspectives,
-
- // Captions data
- captions,
- generatedPerspectives,
- isGenerating: generatingPerspectives.length > 0,
- isServerConnected,
-
- // Caption options
- captionOptions,
- setCaptionOptions,
-
- // Current image
- currentImage,
- setCurrentImage,
-
- // Generation operations
- generatePerspective,
-
- // Status helpers
- isPerspectiveGenerated,
- isPerspectiveGenerating,
-
- // Data helpers
- getPerspectiveData,
-
- // Perspective visibility
- hiddenPerspectives,
- togglePerspectiveVisibility,
- isPerspectiveVisible,
- setAllPerspectivesVisible,
- }),
- [
- selectedProvider,
- availableProviders,
- isGeneratingAll,
- setSelectedProvider,
- setAvailableProviders,
- setIsGeneratingAll,
- handleProviderChange,
- fetchProviders,
- isLoadingProviders,
- providerError,
- perspectivesData,
- schemas,
- isLoadingPerspectives,
- perspectivesError,
- refetchPerspectives,
- captions,
- generatedPerspectives,
- generatingPerspectives,
- isServerConnected,
- captionOptions,
- setCaptionOptions,
- currentImage,
- setCurrentImage,
- generatePerspective,
- isPerspectiveGenerated,
- isPerspectiveGenerating,
- getPerspectiveData,
- hiddenPerspectives,
- togglePerspectiveVisibility,
- isPerspectiveVisible,
- setAllPerspectivesVisible,
- ],
- );
+ const value = useMemo(() => ({
+ // Provider state
+ selectedProvider,
+ availableProviders,
+ isGeneratingAll,
+
+ // Provider actions
+ setSelectedProvider,
+ setAvailableProviders,
+ setIsGeneratingAll,
+ handleProviderChange,
+
+ // Data fetching - providers
+ fetchProviders,
+ isLoadingProviders,
+ providerError,
+
+ // Perspectives data
+ perspectives: perspectivesData || [],
+ schemas,
+ isLoadingPerspectives,
+ perspectivesError,
+ refetchPerspectives,
+
+ // Captions data
+ captions,
+ generatedPerspectives,
+ isGenerating: generatingPerspectives.length > 0,
+ isServerConnected,
+
+ // Generation options from context
+ generationOptions: generationOptions.options,
+
+ // Current image
+ currentImage,
+ setCurrentImage,
+
+ // Generation operations
+ generatePerspective,
+
+ // Status helpers
+ isPerspectiveGenerated,
+ isPerspectiveGenerating,
+
+ // Data helpers
+ getPerspectiveData,
+
+ // Perspective visibility
+ hiddenPerspectives,
+ togglePerspectiveVisibility,
+ isPerspectiveVisible,
+ setAllPerspectivesVisible,
+ }), [
+ selectedProvider,
+ availableProviders,
+ isGeneratingAll,
+ handleProviderChange,
+ fetchProviders,
+ isLoadingProviders,
+ providerError,
+ perspectivesData,
+ schemas,
+ isLoadingPerspectives,
+ perspectivesError,
+ refetchPerspectives,
+ captions,
+ generatedPerspectives,
+ generatingPerspectives.length,
+ isServerConnected,
+ generationOptions.options,
+ currentImage,
+ generatePerspective,
+ isPerspectiveGenerated,
+ isPerspectiveGenerating,
+ getPerspectiveData,
+ hiddenPerspectives,
+ togglePerspectiveVisibility,
+ isPerspectiveVisible,
+ setAllPerspectivesVisible
+ ]);
return (
diff --git a/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx b/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx
index a22e1d1e..d5909b1f 100644
--- a/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx
+++ b/graphcap_studio/src/features/perspectives/context/PerspectivesProvider.tsx
@@ -6,9 +6,9 @@
* to simplify usage in component trees.
*/
-import { Image } from "@/services/images";
-import { ReactNode } from "react";
-import { Provider } from "../types";
+import type { Image } from "@/services/images";
+import type { Provider } from "@/types/provider-config-types";
+import type { ReactNode } from "react";
import { PerspectiveUIProvider } from "./PerspectiveUIContext";
import { PerspectivesDataProvider } from "./PerspectivesDataContext";
diff --git a/graphcap_studio/src/features/perspectives/hooks/index.ts b/graphcap_studio/src/features/perspectives/hooks/index.ts
index d67eb8be..863e4f4d 100644
--- a/graphcap_studio/src/features/perspectives/hooks/index.ts
+++ b/graphcap_studio/src/features/perspectives/hooks/index.ts
@@ -9,10 +9,9 @@
export { usePerspectiveUI } from "./usePerspectiveUI";
// API Hooks
-export { usePerspectives } from "./usePerspectives";
-export { usePerspectiveModules } from "./usePerspectiveModules";
export { useGeneratePerspectiveCaption } from "./useGeneratePerspectiveCaption";
-export { useImagePerspectives } from "./useImagePerspectives";
+export { usePerspectiveModules } from "./usePerspectiveModules";
+export { usePerspectives } from "./usePerspectives";
// Utilities
diff --git a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts
index a0de62be..3bebd87e 100644
--- a/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts
+++ b/graphcap_studio/src/features/perspectives/hooks/useGeneratePerspectiveCaption.ts
@@ -7,15 +7,21 @@
import { useServerConnectionsContext } from "@/context";
import { SERVER_IDS } from "@/features/server-connections/constants";
-import type { ServerConnection } from "@/features/server-connections/types";
-import { useMutation, useQueryClient } from "@tanstack/react-query";
-import { API_ENDPOINTS, perspectivesQueryKeys } from "../services/constants";
+import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients";
+import type { CaptionResponse } from "@/types";
+import {
+ type GenerationOptions,
+ formatApiOptions
+} from "@/types/generation-option-types";
import {
- ensureWorkspacePath,
- getGraphCapServerUrl,
- handleApiError,
-} from "../services/utils";
-import type { CaptionOptions, CaptionResponse } from "../types";
+ type Provider,
+ toServerConfig,
+} from "@/types/provider-config-types";
+import type { ServerConnection } from "@/types/server-connection-types";
+import { toast } from "@/utils/toast";
+import { useMutation, useQueryClient } from "@tanstack/react-query";
+import { perspectivesQueryKeys } from "../services/constants";
+import { ensureWorkspacePath, handleApiError } from "../services/utils";
/**
* Hook to generate a perspective caption for an image
@@ -32,13 +38,13 @@ export function useGeneratePerspectiveCaption() {
{
perspective: string;
imagePath: string;
- provider_name: string;
- options?: CaptionOptions;
+ provider: Provider;
+ options: GenerationOptions;
}
>({
- mutationFn: async ({ perspective, imagePath, provider_name, options }) => {
+ mutationFn: async ({ perspective, imagePath, provider, options }) => {
const graphcapServerConnection = connections.find(
- (conn: ServerConnection) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
+ (conn: ServerConnection) => conn.id === SERVER_IDS.INFERENCE_BRIDGE,
);
const isConnected = graphcapServerConnection?.status === "connected";
@@ -50,58 +56,47 @@ export function useGeneratePerspectiveCaption() {
throw new Error("Caption generation options are required");
}
- const baseUrl = getGraphCapServerUrl(connections);
- if (!baseUrl) {
- throw new Error("No GraphCap server URL available");
+ // Check if a model is specified in the options
+ if (!options.model_name) {
+ throw new Error("A model must be specified in the options");
}
+ // Use the inference bridge client instead of direct fetch
+ const client = createInferenceBridgeClient(connections);
+
// Normalize the image path to ensure it starts with /workspace
const normalizedImagePath = ensureWorkspacePath(imagePath);
+ // Convert provider to server config
+ const providerConfig = toServerConfig(provider);
+
console.log(
`Generating caption for image: ${normalizedImagePath} using perspective: ${perspective}`,
);
- const endpoint = API_ENDPOINTS.REST_GENERATE_CAPTION;
- const url = `${baseUrl}${endpoint}`;
-
+ // Format options for API request
+ const apiOptions = formatApiOptions(options);
+
// Prepare the request body according to the server's expected format
const requestBody = {
perspective,
image_path: normalizedImagePath,
- provider: provider_name,
- max_tokens: options.max_tokens,
- temperature: options.temperature,
- top_p: options.top_p,
- repetition_penalty: options.repetition_penalty,
- context: options.context || [],
- global_context: options.global_context ?? "",
- resize: options.resize ?? false,
- resize_resolution: options.resize_resolution ?? "HD_720P",
+ provider: provider.name,
+ model: options.model_name, // Use model_name from GenerationOptions
+ provider_config: providerConfig, // Include the full provider configuration
+ ...apiOptions, // Spread the formatted API options
};
- console.log(`Sending caption generation request to: ${url}`, {
+ console.log("Sending caption generation request using API client", {
perspective,
image_path: normalizedImagePath,
- provider: provider_name,
- options: {
- max_tokens: requestBody.max_tokens,
- temperature: requestBody.temperature,
- top_p: requestBody.top_p,
- repetition_penalty: requestBody.repetition_penalty,
- context: requestBody.context,
- global_context: requestBody.global_context,
- resize: requestBody.resize,
- resize_resolution: requestBody.resize_resolution,
- },
+ provider: provider.name,
+ model: options.model_name, // Log the model_name from options
+ options: apiOptions,
});
- const response = await fetch(url, {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- },
- body: JSON.stringify(requestBody),
+ const response = await client.perspectives["caption-from-path"].$post({
+ json: requestBody,
});
if (!response.ok) {
@@ -140,6 +135,10 @@ export function useGeneratePerspectiveCaption() {
},
onError: (error) => {
console.error("Caption generation failed", error);
+ toast.error({
+ title: "Caption generation failed",
+ description: error.message,
+ });
},
});
}
diff --git a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts
deleted file mode 100644
index c9bcf2ed..00000000
--- a/graphcap_studio/src/features/perspectives/hooks/useImagePerspectives.ts
+++ /dev/null
@@ -1,332 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * useImagePerspectives Hook
- *
- * This hook manages perspective data for a specific image.
- */
-
-import { useServerConnectionsContext } from "@/context";
-import { useProviders } from "@/features/inference/services/providers";
-import { SERVER_IDS } from "@/features/server-connections/constants";
-import type { Image } from "@/services/images";
-import { useCallback, useEffect, useState } from "react";
-
-import type {
- CaptionOptions,
- ImageCaptions,
- ImagePerspectivesResult,
- PerspectiveData,
- PerspectiveType,
-} from "../types";
-import { useGeneratePerspectiveCaption } from "./useGeneratePerspectiveCaption";
-import { usePerspectives } from "./usePerspectives";
-
-/**
- * Hook for fetching and managing perspective data for an image
- *
- * This hook combines the functionality of the perspectives API and captions
- * to provide a unified interface for working with image perspectives.
- *
- * @param image - The image to get perspectives for
- * @returns An object with perspective data and functions to manage it
- */
-export function useImagePerspectives(
- image: Image | null,
-): ImagePerspectivesResult {
- const [captions, setCaptions] = useState(null);
- const [isLoading, setIsLoading] = useState(false);
- const [generatingPerspectives, setGeneratingPerspectives] = useState<
- string[]
- >([]);
- const [error, setError] = useState(null);
-
- // Get server connection status
- const { connections } = useServerConnectionsContext();
- const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
- );
- const isServerConnected = graphcapServerConnection?.status === "connected";
-
- console.debug("useImagePerspectives hook initialized", {
- imagePath: image?.path,
- isServerConnected,
- });
-
- // Derived state
- const generatedPerspectives = captions
- ? Object.keys(captions.perspectives)
- : [];
-
- // Get available perspectives from the server
- const { data: perspectivesData } = usePerspectives();
-
- // Get available providers
- const { data: providersData } = useProviders();
-
- // Generate caption mutation
- const generateCaption = useGeneratePerspectiveCaption();
-
- // Derived state for available perspectives
- const availablePerspectives = perspectivesData || [];
-
- // Derived state for available providers
- const availableProviders =
- providersData?.map((provider) => ({
- id: provider.id,
- name: provider.name,
- })) || [];
-
- // Function to generate a perspective using the perspectives API
- const generatePerspective = useCallback(
- async (
- perspective: PerspectiveType,
- providerId?: number,
- options?: CaptionOptions,
- ) => {
- if (!image) {
- console.warn("Cannot generate perspective: No image provided");
- setError("No image provided");
- return;
- }
-
- if (!options) {
- console.warn("No options provided, using default options");
- setError("No options provided");
- return;
- }
-
- if (!isServerConnected) {
- console.warn(
- "Cannot generate perspective: Server connection not established",
- );
- setError("Server connection not established");
- return;
- }
-
- // Find the provider by ID if provided
- let provider_name: string | undefined;
- if (providerId && providersData) {
- const provider = providersData.find((p) => p.id === providerId);
- if (provider) {
- provider_name = provider.name;
- console.debug(`Using provider: ${provider_name} (ID: ${providerId})`);
- } else {
- console.warn(`Provider with ID ${providerId} not found`);
- setError(`Provider with ID ${providerId} not found`);
- return;
- }
- } else {
- console.warn("No provider ID specified");
- setError("No provider ID specified");
- return;
- }
-
- console.log(`Generating perspective: ${perspective}`, {
- imagePath: image.path,
- provider_name,
- options,
- });
-
- setError(null);
- // Track which perspective is being generated
- setGeneratingPerspectives((prev) => [...prev, perspective]);
- setIsLoading(true);
-
- try {
- // Generate the caption
- const result = await generateCaption.mutateAsync({
- imagePath: image.path,
- perspective,
- provider_name,
- options,
- });
-
- // Log the caption result
- console.debug("Caption generation result received");
- console.debug(
- `Caption content for perspective ${perspective}:`,
- result.content || result.result,
- );
-
- // Create a perspective data object
- const perspectiveData: PerspectiveData = {
- config_name: perspective,
- version: "1.0",
- model: "api-generated",
- provider: provider_name,
- content: result.content || result.result || {},
- options: options,
- };
-
- // Update the captions with the new perspective
- setCaptions((prevCaptions) => {
- if (!prevCaptions) {
- // Create a new captions object if none exists
- console.debug("Creating new captions object");
- return {
- image,
- perspectives: {
- [perspective]: perspectiveData,
- },
- metadata: {
- captioned_at: new Date().toISOString(),
- provider: provider_name,
- model: "api-generated",
- },
- };
- }
-
- // Update existing captions
- console.debug("Updating existing captions");
- return {
- ...prevCaptions,
- perspectives: {
- ...prevCaptions.perspectives,
- [perspective]: perspectiveData,
- },
- metadata: {
- ...prevCaptions.metadata,
- captioned_at: new Date().toISOString(),
- provider: provider_name,
- model: "api-generated",
- },
- };
- });
- } catch (err) {
- console.error("Error generating perspective", err);
- setError(
- err instanceof Error ? err.message : "Failed to generate perspective",
- );
- } finally {
- // Remove the perspective from the generating list
- setGeneratingPerspectives((prev) =>
- prev.filter((p) => p !== perspective),
- );
- // Only set isLoading to false if no perspectives are being generated
- setIsLoading(() => {
- const updatedGenerating = generatingPerspectives.filter(
- (p) => p !== perspective,
- );
- return updatedGenerating.length > 0;
- });
- }
- },
- [
- image,
- providersData,
- generateCaption,
- generatingPerspectives,
- isServerConnected,
- ],
- );
-
- // Function to generate all perspectives
- const generateAllPerspectives = useCallback(() => {
- if (!image || !perspectivesData) {
- console.warn(
- "Cannot generate all perspectives: No image or perspectives data",
- );
- setError("No image or perspectives data available");
- return;
- }
-
- if (!isServerConnected) {
- console.warn(
- "Cannot generate all perspectives: Server connection not established",
- );
- setError("Server connection not established");
- return;
- }
-
- console.log("Generating all perspectives", {
- imagePath: image.path,
- perspectiveCount: perspectivesData.length,
- });
-
- setIsLoading(true);
- // Track all perspectives as generating
- setGeneratingPerspectives(perspectivesData.map((p) => p.name));
-
- try {
- // Generate each perspective one by one
- for (const perspective of perspectivesData) {
- console.debug(`Generating perspective: ${perspective.name}`);
- generatePerspective(perspective.name);
- }
-
- console.log("All perspectives generated successfully");
- } catch (err) {
- console.error("Error generating all perspectives", err);
- setError(
- err instanceof Error
- ? err.message
- : "Failed to generate all perspectives",
- );
- } finally {
- setIsLoading(false);
- setGeneratingPerspectives([]);
- }
- }, [image, perspectivesData, generatePerspective, isServerConnected]);
-
- // Reset error when server connection changes
- useEffect(() => {
- if (isServerConnected && error === "Server connection not established") {
- setError(null);
- }
- }, [isServerConnected, error]);
-
- // Log when the hook's return value changes
- useEffect(() => {
- console.debug("useImagePerspectives state updated", {
- isLoading,
- hasError: error !== null,
- hasCaptions: captions !== null,
- generatedPerspectiveCount: generatedPerspectives.length,
- availablePerspectiveCount: availablePerspectives.length,
- availableProviderCount: availableProviders.length,
- generatingPerspectives,
- isServerConnected,
- });
-
- if (captions?.perspectives) {
- console.debug(
- "Current perspectives:",
- Object.keys(captions.perspectives),
- );
- }
- }, [
- isLoading,
- error,
- captions,
- generatedPerspectives,
- availablePerspectives,
- availableProviders,
- generatingPerspectives,
- isServerConnected,
- ]);
-
- // Create wrapper functions that don't return the promises
- const generatePerspectiveWrapper = (
- perspective: PerspectiveType,
- providerId?: number,
- options?: CaptionOptions,
- ): void => {
- generatePerspective(perspective, providerId, options);
- };
-
- const generateAllPerspectivesWrapper = (): void => {
- generateAllPerspectives();
- };
-
- return {
- isLoading,
- error,
- captions,
- generatedPerspectives,
- generatingPerspectives,
- generatePerspective: generatePerspectiveWrapper,
- generateAllPerspectives: generateAllPerspectivesWrapper,
- availablePerspectives,
- availableProviders,
- };
-}
diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts
index 7f43b7d0..01355eeb 100644
--- a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts
+++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveModules.ts
@@ -8,15 +8,12 @@
import { useServerConnectionsContext } from "@/context";
import { SERVER_IDS } from "@/features/server-connections/constants";
+import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients";
+import type { ModuleInfo, ModuleListResponse, Perspective, PerspectiveModule } from "@/types";
import { useQuery } from "@tanstack/react-query";
import { useEffect, useMemo } from "react";
-import {
- API_ENDPOINTS,
- CACHE_TIMES,
- perspectivesQueryKeys,
-} from "../services/constants";
-import { getGraphCapServerUrl, handleApiError } from "../services/utils";
-import type { ModuleInfo, ModuleListResponse, Perspective, PerspectiveModule } from "../types";
+import { CACHE_TIMES, perspectivesQueryKeys } from "../services/constants";
+import { handleApiError } from "../services/utils";
import { PerspectiveError } from "./usePerspectives";
type ModuleQueryResult = {
@@ -85,7 +82,7 @@ export function useModuleInfo() {
export function usePerspectiveModules(): ModuleQueryResult {
const { connections } = useServerConnectionsContext();
const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER
+ (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE
);
const isConnected = graphcapServerConnection?.status === "connected";
@@ -103,33 +100,29 @@ export function usePerspectiveModules(): ModuleQueryResult {
});
}
- const baseUrl = getGraphCapServerUrl(connections);
- if (!baseUrl) {
- console.warn("No GraphCap server URL available");
- throw new PerspectiveError("No GraphCap server URL available", {
- code: "MISSING_SERVER_URL",
- context: { connections },
- });
- }
+ // Use the inference bridge client instead of direct fetch
+ const client = createInferenceBridgeClient(connections);
- const endpoint = API_ENDPOINTS.LIST_MODULES;
- const url = `${baseUrl}${endpoint}`;
+ console.debug("Fetching modules from server using API client");
- console.debug(`Fetching modules from server: ${url}`);
+ try {
+ const response = await client.perspectives.modules.$get();
- const response = await fetch(url);
-
- if (!response.ok) {
- return handleApiError(response, "Failed to fetch modules");
- }
+ if (!response.ok) {
+ return handleApiError(response, "Failed to fetch modules");
+ }
- const data = (await response.json()) as ModuleListResponse;
+ const data = (await response.json()) as ModuleListResponse;
- console.debug(
- `Successfully fetched ${data.modules.length} modules`,
- );
+ console.debug(
+ `Successfully fetched ${data.modules.length} modules`,
+ );
- return data.modules;
+ return data.modules;
+ } catch (error) {
+ console.error("API client error:", error);
+ throw error;
+ }
} catch (error) {
// Improve error handling - log the error and rethrow
console.error("Error fetching modules:", error);
diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts
index d86fa70d..c4eca286 100644
--- a/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts
+++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectiveUI.ts
@@ -5,16 +5,15 @@
* This hook provides UI-related functionality for the perspectives components.
*/
-import { PerspectiveType } from "@/features/perspectives/types";
import { useCallback, useState } from "react";
interface UsePerspectiveUIOptions {
onGeneratePerspective?: (
- perspective: PerspectiveType,
+ perspective: string,
provider?: string,
) => void;
initialSelectedProvider?: string;
- perspectiveKey?: PerspectiveType;
+ perspectiveKey?: string;
}
/**
diff --git a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts
index 147ee79f..a149afcc 100644
--- a/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts
+++ b/graphcap_studio/src/features/perspectives/hooks/usePerspectives.ts
@@ -8,16 +8,16 @@
import { useServerConnectionsContext } from "@/context";
import { SERVER_IDS } from "@/features/server-connections/constants";
-import type { ServerConnection } from "@/features/server-connections/types";
+import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients";
+import type { Perspective, PerspectiveListResponse } from "@/types";
+import type { ServerConnection } from "@/types/server-connection-types";
import { useQuery } from "@tanstack/react-query";
import { useEffect } from "react";
import {
- API_ENDPOINTS,
CACHE_TIMES,
- perspectivesQueryKeys,
+ perspectivesQueryKeys
} from "../services/constants";
-import { getGraphCapServerUrl, handleApiError } from "../services/utils";
-import type { Perspective, PerspectiveListResponse } from "../types";
+import { handleApiError } from "../services/utils";
/**
* Custom error class for perspective fetching errors
@@ -51,7 +51,7 @@ export class PerspectiveError extends Error {
export function usePerspectives() {
const { connections } = useServerConnectionsContext();
const graphcapServerConnection = connections.find(
- (conn: ServerConnection) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
+ (conn: ServerConnection) => conn.id === SERVER_IDS.INFERENCE_BRIDGE,
);
const isConnected = graphcapServerConnection?.status === "connected";
@@ -68,33 +68,29 @@ export function usePerspectives() {
});
}
- const baseUrl = getGraphCapServerUrl(connections);
- if (!baseUrl) {
- console.warn("No GraphCap server URL available");
- throw new PerspectiveError("No GraphCap server URL available", {
- code: "MISSING_SERVER_URL",
- context: { connections },
- });
- }
-
- const endpoint = API_ENDPOINTS.LIST_PERSPECTIVES;
- const url = `${baseUrl}${endpoint}`;
+ // Use the inference bridge client instead of direct fetch
+ const client = createInferenceBridgeClient(connections);
- console.debug(`Fetching perspectives from server: ${url}`);
+ console.debug("Fetching perspectives from server using API client");
- const response = await fetch(url);
+ try {
+ const response = await client.perspectives.list.$get();
- if (!response.ok) {
- return handleApiError(response, "Failed to fetch perspectives");
- }
+ if (!response.ok) {
+ return handleApiError(response, "Failed to fetch perspectives");
+ }
- const data = (await response.json()) as PerspectiveListResponse;
+ const data = (await response.json()) as PerspectiveListResponse;
- console.debug(
- `Successfully fetched ${data.perspectives.length} perspectives`,
- );
+ console.debug(
+ `Successfully fetched ${data.perspectives.length} perspectives`,
+ );
- return data.perspectives;
+ return data.perspectives;
+ } catch (error) {
+ console.error("API client error:", error);
+ throw error;
+ }
} catch (error) {
// Improve error handling - log the error and rethrow
console.error("Error fetching perspectives:", error);
diff --git a/graphcap_studio/src/features/perspectives/services/api.ts b/graphcap_studio/src/features/perspectives/services/api.ts
index 4f86ae88..0877284d 100644
--- a/graphcap_studio/src/features/perspectives/services/api.ts
+++ b/graphcap_studio/src/features/perspectives/services/api.ts
@@ -5,29 +5,29 @@
* This module provides direct API methods for interacting with the perspectives service.
*/
-import { API_ENDPOINTS } from "../constants/index";
-import {
- CaptionRequestSchema,
- CaptionResponseSchema,
- ModuleListResponseSchema,
- ModulePerspectivesResponseSchema,
- PerspectiveListResponseSchema,
-} from "../types";
+import { createInferenceBridgeClient } from "@/features/server-connections/services/apiClients";
import type {
CaptionRequest,
CaptionResponse,
ModuleListResponse,
ModulePerspectivesResponse,
Perspective,
-} from "../types";
-import { ensureWorkspacePath, getGraphCapServerUrl, handleApiError } from "./utils";
+} from "@/types";
+import {
+ CaptionRequestSchema,
+ CaptionResponseSchema,
+ ModuleListResponseSchema,
+ ModulePerspectivesResponseSchema,
+ PerspectiveListResponseSchema,
+} from "@/types";
+import { ensureWorkspacePath, handleApiError } from "./utils";
/**
* Get server connections from local storage
*/
const getConnections = () => {
// Get the current connections from local storage
- const connectionsStr = localStorage.getItem("graphcap-server-connections");
+ const connectionsStr = localStorage.getItem("inference-bridge-connections");
let connections = [];
if (connectionsStr) {
@@ -55,16 +55,13 @@ export const perspectivesApi = {
*/
async listPerspectives(): Promise {
try {
- // Get the base URL using the utility function
+ // Get connections and create the client
const connections = getConnections();
- const baseUrl = getGraphCapServerUrl(connections);
+ const client = createInferenceBridgeClient(connections);
- // Create the full URL by combining base URL and endpoint path
- const url = `${baseUrl}${API_ENDPOINTS.LIST_PERSPECTIVES}`;
+ console.debug("Fetching perspectives using API client");
- console.debug(`Fetching perspectives from: ${url}`);
-
- const response = await fetch(url);
+ const response = await client.perspectives.list.$get();
if (!response.ok) {
await handleApiError(response, "Failed to fetch perspectives");
@@ -87,16 +84,13 @@ export const perspectivesApi = {
*/
async listModules(): Promise {
try {
- // Get the base URL using the utility function
+ // Get connections and create the client
const connections = getConnections();
- const baseUrl = getGraphCapServerUrl(connections);
-
- // Create the full URL by combining base URL and endpoint path
- const url = `${baseUrl}${API_ENDPOINTS.LIST_MODULES}`;
+ const client = createInferenceBridgeClient(connections);
- console.debug(`Fetching modules from: ${url}`);
+ console.debug("Fetching modules using API client");
- const response = await fetch(url);
+ const response = await client.perspectives.modules.$get();
if (!response.ok) {
await handleApiError(response, "Failed to fetch perspective modules");
@@ -120,22 +114,15 @@ export const perspectivesApi = {
*/
async getModulePerspectives(moduleName: string): Promise {
try {
- // Get the base URL using the utility function
+ // Get connections and create the client
const connections = getConnections();
- const baseUrl = getGraphCapServerUrl(connections);
+ const client = createInferenceBridgeClient(connections);
- // Create the endpoint path
- const endpointPath = API_ENDPOINTS.MODULE_PERSPECTIVES.replace(
- "{module_name}",
- encodeURIComponent(moduleName)
- );
+ console.debug(`Fetching perspectives for module '${moduleName}' using API client`);
- // Create the full URL by combining base URL and endpoint path
- const url = `${baseUrl}${endpointPath}`;
-
- console.debug(`Fetching perspectives for module '${moduleName}' from: ${url}`);
-
- const response = await fetch(url);
+ const response = await client.perspectives.modules[":moduleName"].$get({
+ param: { moduleName }
+ });
if (!response.ok) {
// Check if we got HTML instead of JSON
@@ -190,14 +177,13 @@ export const perspectivesApi = {
requestParams: CaptionRequest,
): Promise {
try {
- // Get the base URL using the utility function
+ // Get connections and create the client
const connections = getConnections();
- const baseUrl = getGraphCapServerUrl(connections);
+ const client = createInferenceBridgeClient(connections);
// Ensure the image path has the correct workspace prefix
const normalizedImagePath = ensureWorkspacePath(requestParams.image_path);
console.log("Generating caption for image path:", normalizedImagePath);
- console.log("Request params:", requestParams);
// Create the request body and validate with Zod
const request: CaptionRequest = {
@@ -208,16 +194,11 @@ export const perspectivesApi = {
// Validate the request with Zod
const validatedRequest = CaptionRequestSchema.parse(request);
- // Create the full URL by combining base URL and endpoint path
- const url = `${baseUrl}${API_ENDPOINTS.REST_GENERATE_CAPTION}`;
-
- // Make the API request
- const response = await fetch(url, {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- },
- body: JSON.stringify(validatedRequest),
+ console.debug(`Generating caption for perspective '${requestParams.perspective}' using API client`);
+
+ // Use the API client to post the request
+ const response = await client.perspectives["caption-from-path"].$post({
+ json: validatedRequest
});
if (!response.ok) {
diff --git a/graphcap_studio/src/features/perspectives/services/constants.ts b/graphcap_studio/src/features/perspectives/services/constants.ts
index ff1a850a..eaca9489 100644
--- a/graphcap_studio/src/features/perspectives/services/constants.ts
+++ b/graphcap_studio/src/features/perspectives/services/constants.ts
@@ -58,7 +58,7 @@ export const DEFAULTS = {
/**
* Default server URL
*/
- SERVER_URL: "http://localhost:32100/api",
+ SERVER_URL: "http://localhost:32100/",
};
/**
diff --git a/graphcap_studio/src/features/perspectives/services/index.ts b/graphcap_studio/src/features/perspectives/services/index.ts
index 7d8b4332..d8783d3b 100644
--- a/graphcap_studio/src/features/perspectives/services/index.ts
+++ b/graphcap_studio/src/features/perspectives/services/index.ts
@@ -17,7 +17,6 @@ export { perspectivesApi } from "./api";
// Export hooks from the hooks directory
export {
- usePerspectives,
- useGeneratePerspectiveCaption,
- useImagePerspectives,
+ useGeneratePerspectiveCaption, usePerspectives
} from "@/features/perspectives/hooks";
+
diff --git a/graphcap_studio/src/features/perspectives/services/utils.ts b/graphcap_studio/src/features/perspectives/services/utils.ts
index ca78209b..3442773e 100644
--- a/graphcap_studio/src/features/perspectives/services/utils.ts
+++ b/graphcap_studio/src/features/perspectives/services/utils.ts
@@ -5,31 +5,36 @@
* This module provides utility functions for the perspectives service.
*/
-import { DEFAULTS } from "@/features/perspectives/constants/index";
import type { ServerConnection } from "@/features/perspectives/types";
import { SERVER_IDS } from "@/features/server-connections/constants";
/**
- * Get the GraphCap Server URL from server connections context
+ * Get the Inference Bridge URL from server connections context
*/
export function getGraphCapServerUrl(connections: ServerConnection[]): string {
- const graphcapServerConnection = connections.find(
- (conn) => conn.id === SERVER_IDS.GRAPHCAP_SERVER,
+ const serverConnection = connections.find(
+ (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE,
);
- // Get URL from connection, environment variable, or default
+ // Use connection URL or fallback to environment variable or default
const serverUrl =
- graphcapServerConnection?.url ??
- import.meta.env.VITE_GRAPHCAP_SERVER_URL ??
- import.meta.env.VITE_API_URL ??
- DEFAULTS.SERVER_URL;
-
- // Log the server URL being used for debugging
- console.debug(`Using GraphCap server URL: ${serverUrl}`);
+ serverConnection?.url ??
+ import.meta.env.VITE_INFERENCE_BRIDGE_URL ??
+ "http://localhost:32100";
+ console.debug(`Using Inference Bridge URL: ${serverUrl}`);
return serverUrl;
}
+/**
+ * Get the full Inference Bridge API URL (including /api/v1)
+ */
+export function getInferenceBridgeApiUrl(connections: ServerConnection[]): string {
+ const baseUrl = getGraphCapServerUrl(connections);
+ // Ensure the URL doesn't already have /api/v1
+ return baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`;
+}
+
/**
* Ensure path starts with /workspace if it doesn't already
*/
diff --git a/graphcap_studio/src/features/perspectives/types/index.ts b/graphcap_studio/src/features/perspectives/types/index.ts
index 6b6024f2..5a5d7e94 100644
--- a/graphcap_studio/src/features/perspectives/types/index.ts
+++ b/graphcap_studio/src/features/perspectives/types/index.ts
@@ -6,5 +6,6 @@
* Type definitions are consolidated in their respective files.
*/
-export * from "./perspectivesTypes";
-export * from "./perspectiveModuleTypes";
+export * from "@/types/perspective-module-types";
+export * from "@/types/perspective-types";
+
diff --git a/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts b/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts
index 34bf9d88..638bf6f3 100644
--- a/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts
+++ b/graphcap_studio/src/features/perspectives/utils/__tests__/persist-perspective-caption.test.ts
@@ -3,8 +3,8 @@
* Unit tests for perspective caption persistence utilities
*/
+import type { PerspectiveData } from "@/types/perspective-types";
import { afterAll, beforeEach, describe, expect, it } from "vitest";
-import { PerspectiveData } from "../../types";
import {
clearAllPerspectiveCaptions,
deletePerspectiveCaption,
diff --git a/graphcap_studio/src/features/perspectives/utils/api-adapters.ts b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts
new file mode 100644
index 00000000..152b9cf2
--- /dev/null
+++ b/graphcap_studio/src/features/perspectives/utils/api-adapters.ts
@@ -0,0 +1,106 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Perspectives API Adapters
+ *
+ * This module provides adapter functions for converting between API and application
+ * types for the perspectives feature, including caption request formatting.
+ */
+
+import type { GenerationOptions } from "@/types/generation-option-types";
+import { formatApiOptions } from "@/types/generation-option-types";
+import type { Provider } from "@/types/provider-config-types";
+import { denormalizeProviderId } from "@/types/provider-config-types";
+
+// Legacy caption options interface for migration purposes
+interface LegacyCaptionOptions {
+ model: string;
+ max_tokens?: number;
+ temperature?: number;
+ top_p?: number;
+ repetition_penalty?: number;
+ global_context?: string;
+ context?: string[];
+ resize?: boolean;
+ resize_resolution?: string;
+}
+
+/**
+ * Format a caption generation request
+ * Converts from application types to the API request format
+ */
+export function formatCaptionRequest(
+ imagePath: string,
+ perspective: string,
+ provider: Provider,
+ options: GenerationOptions
+): {
+ image_path: string;
+ perspective: string;
+ provider_id: number;
+ options: Record<string, unknown>;
+} {
+ return {
+ image_path: imagePath,
+ perspective,
+ provider_id: denormalizeProviderId(provider.id),
+ options: formatApiOptions(options)
+ };
+}
+
+/**
+ * Convert from CaptionOptions format to GenerationOptions format
+ * Used during the migration from CaptionOptions to GenerationOptions
+ */
+export function legacyCaptionToGenerationOptions(
+ captionOptions: LegacyCaptionOptions,
+ providerName: string
+): GenerationOptions {
+ return {
+ model_name: captionOptions.model,
+ max_tokens: captionOptions.max_tokens ?? 4096,
+ temperature: captionOptions.temperature ?? 0.7,
+ top_p: captionOptions.top_p ?? 0.95,
+ repetition_penalty: captionOptions.repetition_penalty ?? 1.1,
+ global_context: captionOptions.global_context ?? "You are a visual captioning perspective.",
+ context: captionOptions.context ?? [],
+ resize_resolution: captionOptions.resize_resolution ?? "NONE",
+ provider_name: providerName
+ };
+}
+
+// Interface for perspective data structure
+interface PerspectiveDataWithOptions {
+ model: string;
+ provider: string;
+ options?: LegacyCaptionOptions;
+}
+
+/**
+ * Convert from PerspectiveData to GenerationOptions format
+ * Used for loading saved perspective settings
+ */
+export function perspectiveDataToGenerationOptions(
+ perspectiveData: PerspectiveDataWithOptions,
+ providerNameMap: Record<string, string>
+): GenerationOptions {
+ // If we have structured options, use those
+ if (perspectiveData.options) {
+ return legacyCaptionToGenerationOptions(
+ perspectiveData.options,
+ providerNameMap[perspectiveData.provider] || perspectiveData.provider
+ );
+ }
+
+ // Otherwise create minimal options
+ return {
+ model_name: perspectiveData.model,
+ provider_name: providerNameMap[perspectiveData.provider] || perspectiveData.provider,
+ max_tokens: 4096,
+ temperature: 0.7,
+ top_p: 0.95,
+ repetition_penalty: 1.1,
+ global_context: "You are a visual captioning perspective.",
+ context: [],
+ resize_resolution: "NONE"
+ };
+}
\ No newline at end of file
diff --git a/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts b/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts
index 9adfe5b4..d0d47863 100644
--- a/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts
+++ b/graphcap_studio/src/features/perspectives/utils/persist-perspective-caption.ts
@@ -7,7 +7,7 @@
* and perspective name.
*/
-import { PerspectiveData } from "../types";
+import type { PerspectiveData } from "@/types/perspective-types";
/**
* Storage key prefix for saving perspective captions in localStorage
diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx
index 9db05a2f..48f908ad 100644
--- a/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx
+++ b/graphcap_studio/src/features/server-connections/components/ConnectionActionButton.tsx
@@ -1,5 +1,5 @@
import { CONNECTION_STATUS } from "@/features/server-connections/constants";
-import { ConnectionActionButtonProps } from "@/features/server-connections/types";
+import type { ConnectionActionButtonProps } from "@/types/server-connection-types";
import { Button } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
import { memo } from "react";
diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx
index 2603dd55..1e46f435 100644
--- a/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx
+++ b/graphcap_studio/src/features/server-connections/components/ConnectionCard.tsx
@@ -1,5 +1,5 @@
import { useColorModeValue } from "@/components/ui/theme/color-mode";
-import { ConnectionCardProps } from "@/features/server-connections/types";
+import type { ConnectionCardProps } from "@/types/server-connection-types";
import { Box, Flex, Heading, Stack } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
import { memo } from "react";
diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx
index b746d85b..df872334 100644
--- a/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx
+++ b/graphcap_studio/src/features/server-connections/components/ConnectionStatusIndicator.tsx
@@ -1,6 +1,6 @@
import { Status } from "@/components/ui/status";
import { CONNECTION_STATUS } from "@/features/server-connections/constants";
-import { ConnectionStatusIndicatorProps } from "@/features/server-connections/types";
+import type { ConnectionStatusIndicatorProps } from "@/types/server-connection-types";
// SPDX-License-Identifier: Apache-2.0
import { memo } from "react";
diff --git a/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx b/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx
index 855c773d..996ba3ad 100644
--- a/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx
+++ b/graphcap_studio/src/features/server-connections/components/ConnectionUrlInput.tsx
@@ -1,8 +1,8 @@
import { useColorModeValue } from "@/components/ui/theme/color-mode";
-import { ConnectionUrlInputProps } from "@/features/server-connections/types";
+import type { ConnectionUrlInputProps } from "@/types/server-connection-types";
import { Input } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
-import { ChangeEvent, memo } from "react";
+import { type ChangeEvent, memo } from "react";
/**
* ConnectionUrlInput component
diff --git a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx
index 87d14ddb..7c118b94 100644
--- a/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx
+++ b/graphcap_studio/src/features/server-connections/components/ServerConnectionsPanel.tsx
@@ -1,7 +1,7 @@
import { useColorModeValue } from "@/components/ui/theme/color-mode";
import { useServerConnectionsContext } from "@/context/ServerConnectionsContext";
import { CONNECTION_STATUS } from "@/features/server-connections/constants";
-import { ServerConnectionsPanelProps } from "@/features/server-connections/types";
+import type { ServerConnectionsPanelProps } from "@/types/server-connection-types";
import { Box, Button, Flex, Heading, Spinner, Stack } from "@chakra-ui/react";
// SPDX-License-Identifier: Apache-2.0
import { memo, useMemo } from "react";
diff --git a/graphcap_studio/src/features/server-connections/constants.ts b/graphcap_studio/src/features/server-connections/constants.ts
index 2c6b70c4..35e1a78c 100644
--- a/graphcap_studio/src/features/server-connections/constants.ts
+++ b/graphcap_studio/src/features/server-connections/constants.ts
@@ -5,13 +5,13 @@
*/
export const SERVER_IDS = {
MEDIA_SERVER: "media-server",
- GRAPHCAP_SERVER: "graphcap-server",
+ INFERENCE_BRIDGE: "inference-bridge",
DATA_SERVICE: "data-service",
} as const;
export const SERVER_NAMES = {
[SERVER_IDS.MEDIA_SERVER]: "Media Server",
- [SERVER_IDS.GRAPHCAP_SERVER]: "GraphCap Server",
+ [SERVER_IDS.INFERENCE_BRIDGE]: "Inference Bridge",
[SERVER_IDS.DATA_SERVICE]: "Data Service",
} as const;
@@ -24,6 +24,6 @@ export const CONNECTION_STATUS = {
export const DEFAULT_URLS = {
[SERVER_IDS.MEDIA_SERVER]: "http://localhost:32400",
- [SERVER_IDS.GRAPHCAP_SERVER]: "http://localhost:32100",
+ [SERVER_IDS.INFERENCE_BRIDGE]: "http://localhost:32100",
[SERVER_IDS.DATA_SERVICE]: "http://localhost:32550",
} as const;
diff --git a/graphcap_studio/src/features/server-connections/index.ts b/graphcap_studio/src/features/server-connections/index.ts
index 51a3eb3a..afe983f5 100644
--- a/graphcap_studio/src/features/server-connections/index.ts
+++ b/graphcap_studio/src/features/server-connections/index.ts
@@ -1,4 +1,5 @@
// SPDX-License-Identifier: Apache-2.0
-export * from "./types";
+export * from "@/types/server-connection-types";
export * from "./constants";
export * from "./useServerConnections";
+
diff --git a/graphcap_studio/src/features/server-connections/services/apiClients.ts b/graphcap_studio/src/features/server-connections/services/apiClients.ts
new file mode 100644
index 00000000..df344072
--- /dev/null
+++ b/graphcap_studio/src/features/server-connections/services/apiClients.ts
@@ -0,0 +1,24 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * API Clients Service
+ *
+ * This module re-exports client functions for interacting with various server APIs.
+ */
+
+// Re-export everything from dataServiceClient.ts
+export {
+ type DataServiceClient,
+ getDataServiceUrl,
+ createDataServiceClient,
+} from "./dataServiceClient";
+
+// Re-export everything from inferenceBridgeClient.ts
+export {
+ type InferenceBridgeClient,
+ type ProviderClient,
+ type PerspectivesClient,
+ getInferenceBridgeUrl,
+ createInferenceBridgeClient,
+ createProviderClient,
+ createPerspectivesClient,
+} from "./inferenceBridgeClient";
\ No newline at end of file
diff --git a/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts
new file mode 100644
index 00000000..4ef3dc68
--- /dev/null
+++ b/graphcap_studio/src/features/server-connections/services/dataServiceClient.ts
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Data Service API Client
+ *
+ * This module provides client functions for interacting with the Data Service API.
+ */
+
+import type { ServerConnection } from "@/types/server-connection-types";
+import { hc } from "hono/client";
+import { DEFAULT_URLS, SERVER_IDS } from "../constants";
+
+/**
+ * Interface for the Data Service client
+ */
+export interface DataServiceClient {
+ providers: {
+ $get: () => Promise<Response>;
+ $post: (options: { json: unknown }) => Promise<Response>;
+ ":id": {
+ $get: (options: { param: { id: string } }) => Promise<Response>;
+ $put: (options: {
+ param: { id: string };
+ json: unknown;
+ }) => Promise<Response>;
+ $delete: (options: { param: { id: string } }) => Promise<Response>;
+ };
+ };
+ health: {
+ $get: () => Promise<Response>;
+ };
+}
+
+/**
+ * Get the Data Service URL from server connections
+ */
+export function getDataServiceUrl(connections: ServerConnection[]): string {
+ const dataServiceConnection = connections.find(
+ (conn) => conn.id === SERVER_IDS.DATA_SERVICE,
+ );
+
+ return (
+ dataServiceConnection?.url ??
+ import.meta.env.VITE_DATA_SERVICE_URL ??
+ DEFAULT_URLS[SERVER_IDS.DATA_SERVICE]
+ );
+}
+
+/**
+ * Create a Hono client for the Data Service
+ */
+export function createDataServiceClient(connections: ServerConnection[]): DataServiceClient {
+ const baseUrl = getDataServiceUrl(connections);
+ return hc(`${baseUrl}/api/v1`) as unknown as DataServiceClient;
+}
\ No newline at end of file
diff --git a/graphcap_studio/src/features/server-connections/services/index.ts b/graphcap_studio/src/features/server-connections/services/index.ts
index 225c3d2c..2321ada0 100644
--- a/graphcap_studio/src/features/server-connections/services/index.ts
+++ b/graphcap_studio/src/features/server-connections/services/index.ts
@@ -1,6 +1,20 @@
+// Server health checks
export {
- checkServerHealth,
- checkMediaServerHealth,
- checkGraphCapServerHealth,
- checkServerHealthById,
+ checkInferenceBridgeHealth, checkMediaServerHealth, checkServerHealth, checkServerHealthById
} from "./serverConnections";
+
+// API clients
+export type {
+ DataServiceClient,
+ InferenceBridgeClient, PerspectivesClient, ProviderClient
+} from "./apiClients";
+
+export {
+ createDataServiceClient, createInferenceBridgeClient, createPerspectivesClient, createProviderClient, getDataServiceUrl, getInferenceBridgeUrl
+} from "./apiClients";
+
+// Provider services
+export {
+ queryKeys as providerQueryKeys, useCreateProvider, useDeleteProvider, useProvider, useProviders, useUpdateProvider
+} from "./providers";
+
diff --git a/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts
new file mode 100644
index 00000000..a8d013e2
--- /dev/null
+++ b/graphcap_studio/src/features/server-connections/services/inferenceBridgeClient.ts
@@ -0,0 +1,110 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Inference Bridge API Client
+ *
+ * This module provides client functions for interacting with the Inference Bridge API.
+ */
+
+import type { ServerConnection } from "@/types/server-connection-types";
+import { hc } from "hono/client";
+import { DEFAULT_URLS, SERVER_IDS } from "../constants";
+
+/**
+ * Interface for the Inference Bridge Provider operations
+ */
+export interface ProviderClient {
+ ":provider_name": {
+ "test-connection": {
+ $post: (options: {
+ param: { provider_name: string };
+ json: unknown;
+ }) => Promise<Response>;
+ };
+ "models": {
+ $post: (options: {
+ param: { provider_name: string };
+ json: unknown;
+ }) => Promise<Response>;
+ };
+ };
+}
+
+/**
+ * Interface for the Inference Bridge Perspectives operations
+ */
+export interface PerspectivesClient {
+ list: {
+ $get: () => Promise<Response>;
+ };
+ modules: {
+ $get: () => Promise<Response>;
+ ":moduleName": {
+ $get: (options: { param: { moduleName: string } }) => Promise<Response>;
+ };
+ };
+ "caption-from-path": {
+ $post: (options: { json: unknown }) => Promise<Response>;
+ };
+ ":name": {
+ $post: (options: {
+ param: { name: string };
+ json: unknown;
+ formData?: FormData;
+ }) => Promise<Response>;
+ };
+}
+
+/**
+ * Interface for the Inference Bridge client - combines provider and perspectives APIs
+ */
+export interface InferenceBridgeClient {
+ providers: ProviderClient;
+ perspectives: PerspectivesClient;
+ health: {
+ $get: () => Promise<Response>;
+ };
+}
+
+/**
+ * Get the Inference Bridge URL from server connections
+ */
+export function getInferenceBridgeUrl(connections: ServerConnection[]): string {
+ const inferenceBridgeConnection = connections.find(
+ (conn) => conn.id === SERVER_IDS.INFERENCE_BRIDGE,
+ );
+
+ return (
+ inferenceBridgeConnection?.url ??
+ import.meta.env.VITE_INFERENCE_BRIDGE_URL ??
+ DEFAULT_URLS[SERVER_IDS.INFERENCE_BRIDGE]
+ );
+}
+
+/**
+ * Create a Hono client for the Inference Bridge
+ * Automatically appends /api/v1 to the base URL
+ */
+export function createInferenceBridgeClient(connections: ServerConnection[]): InferenceBridgeClient {
+ const baseUrl = getInferenceBridgeUrl(connections);
+ // Ensure the URL doesn't already have /api/v1
+ const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`;
+ return hc(apiUrl) as unknown as InferenceBridgeClient;
+}
+
+/**
+ * Create a client for provider operations only
+ */
+export function createProviderClient(connections: ServerConnection[]): ProviderClient {
+ const baseUrl = getInferenceBridgeUrl(connections);
+ const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`;
+ return (hc(apiUrl) as unknown as InferenceBridgeClient).providers;
+}
+
+/**
+ * Create a client for perspectives operations only
+ */
+export function createPerspectivesClient(connections: ServerConnection[]): PerspectivesClient {
+ const baseUrl = getInferenceBridgeUrl(connections);
+ const apiUrl = baseUrl.endsWith('/api/v1') ? baseUrl : `${baseUrl}/api/v1`;
+ return (hc(apiUrl) as unknown as InferenceBridgeClient).perspectives;
+}
\ No newline at end of file
diff --git a/graphcap_studio/src/features/server-connections/services/providerAdapters.ts b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts
new file mode 100644
index 00000000..f156222d
--- /dev/null
+++ b/graphcap_studio/src/features/server-connections/services/providerAdapters.ts
@@ -0,0 +1,129 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Provider API Adapters
+ *
+ * This module provides adapter functions for converting between API and application
+ * provider types, handling the conversion between numeric and string IDs.
+ */
+
+import type { Provider, ProviderModel, ProviderModelInfo } from "@/types/provider-config-types";
+import { normalizeProviderId } from "@/types/provider-config-types";
+
+// Type for raw API provider data
+interface ApiProvider {
+ id: number;
+ name: string;
+ kind: string;
+ environment: "cloud" | "local";
+ baseUrl: string;
+ apiKey?: string;
+ isEnabled: boolean;
+ defaultModel?: string;
+ createdAt: string | Date;
+ updatedAt: string | Date;
+ models?: ApiProviderModel[];
+}
+
+// Type for raw API provider model data
+interface ApiProviderModel {
+ id: number;
+ providerId: number;
+ name: string;
+ isEnabled: boolean;
+ createdAt: string | Date;
+ updatedAt: string | Date;
+}
+
+// Type for raw API model info
+interface ApiModelInfo {
+ id: string;
+ name: string;
+ is_default?: boolean;
+}
+
+/**
+ * Convert API provider to application Provider type
+ * This handles ID conversion from number to string
+ */
+export function fromApiProvider(apiProvider: ApiProvider): Provider {
+ return {
+ id: normalizeProviderId(apiProvider.id),
+ name: apiProvider.name,
+ kind: apiProvider.kind,
+ environment: apiProvider.environment,
+ baseUrl: apiProvider.baseUrl,
+ apiKey: apiProvider.apiKey,
+ isEnabled: apiProvider.isEnabled,
+ defaultModel: apiProvider.defaultModel,
+ createdAt: apiProvider.createdAt,
+ updatedAt: apiProvider.updatedAt,
+
+ // Convert nested models
+ models: apiProvider.models?.map((model: ApiProviderModel) => ({
+ id: normalizeProviderId(model.id),
+ providerId: normalizeProviderId(model.providerId),
+ name: model.name,
+ isEnabled: model.isEnabled,
+ createdAt: model.createdAt,
+ updatedAt: model.updatedAt,
+ })),
+ };
+}
+
+/**
+ * Convert application Provider to API provider
+ * This handles ID conversion from string to number
+ */
+export function toApiProvider(provider: Provider): ApiProvider {
+ return {
+ id: Number.parseInt(provider.id, 10),
+ name: provider.name,
+ kind: provider.kind,
+ environment: provider.environment,
+ baseUrl: provider.baseUrl,
+ apiKey: provider.apiKey,
+ isEnabled: provider.isEnabled,
+ defaultModel: provider.defaultModel,
+ createdAt: provider.createdAt,
+ updatedAt: provider.updatedAt,
+
+ // Convert models back to numeric IDs
+ models: provider.models?.map((model) => ({
+ id: Number.parseInt(model.id, 10),
+ providerId: Number.parseInt(model.providerId, 10),
+ name: model.name,
+ isEnabled: model.isEnabled,
+ createdAt: model.createdAt,
+ updatedAt: model.updatedAt,
+ })),
+ };
+}
+
+/**
+ * Convert API model info to application ProviderModelInfo
+ */
+export function fromApiModelInfo(apiModel: ApiModelInfo): ProviderModelInfo {
+ return {
+ id: apiModel.id,
+ name: apiModel.name,
+ is_default: apiModel.is_default,
+ };
+}
+
+/**
+ * Create a provider model with defaults
+ */
+export function createProviderModel(
+ providerId: string,
+ name: string,
+ id?: string,
+): ProviderModel {
+ return {
+ id: id ?? crypto.randomUUID(), // Generate UUID if no ID provided
+ providerId,
+ name,
+ isEnabled: true,
+ createdAt: new Date().toISOString(),
+ updatedAt: new Date().toISOString(),
+ };
+}
diff --git a/graphcap_studio/src/features/server-connections/services/providers.ts b/graphcap_studio/src/features/server-connections/services/providers.ts
new file mode 100644
index 00000000..5ebdc167
--- /dev/null
+++ b/graphcap_studio/src/features/server-connections/services/providers.ts
@@ -0,0 +1,267 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Providers Service
+ *
+ * This module provides functions for interacting with the Data Service's
+ * provider management API using TanStack Query and Hono's RPC client.
+ */
+
+import { useServerConnectionsContext } from "@/context/ServerConnectionsContext";
+import type {
+ Provider,
+ ProviderCreate,
+ ProviderUpdate,
+ ServerProviderConfig,
+ SuccessResponse
+} from "@/types/provider-config-types";
+import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
+import { SERVER_IDS } from "../constants";
+import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients";
+import { fromApiProvider, toApiProvider } from "./providerAdapters";
+
+// Query keys for TanStack Query
+export const queryKeys = {
+ providers: ["providers"] as const,
+ provider: (id: string) => ["providers", id] as const,
+ providerModels: (providerName: string) => ["providers", "models", providerName] as const,
+};
+
+/**
+ * Hook to get all providers
+ */
+export function useProviders() {
+ const { connections } = useServerConnectionsContext();
+ const dataServiceConnection = connections.find(
+ (conn) => conn.id === SERVER_IDS.DATA_SERVICE,
+ );
+ const isConnected = dataServiceConnection?.status === "connected";
+
+ return useQuery({
+ queryKey: queryKeys.providers,
+ queryFn: async () => {
+ console.log("📡 Fetching all providers");
+
+ const client = createDataServiceClient(connections);
+ const response = await client.providers.$get();
+
+ if (!response.ok) {
+ const errorMsg = `Failed to fetch providers: ${response.status}`;
+ console.error(`❌ ${errorMsg}`);
+ throw new Error(errorMsg);
+ }
+
+ // Convert API response to application types
+ const apiProviders = await response.json();
+ const providers = apiProviders.map(fromApiProvider);
+
+ console.log(`✅ Fetched ${providers.length} providers:`, providers);
+ return providers;
+ },
+ enabled: isConnected,
+ staleTime: 1000 * 60 * 5, // 5 minutes
+ });
+}
+
+/**
+ * Hook to get a provider by ID
+ */
+export function useProvider(id: string) {
+ const { connections } = useServerConnectionsContext();
+ const dataServiceConnection = connections.find(
+ (conn) => conn.id === SERVER_IDS.DATA_SERVICE,
+ );
+ const isConnected = dataServiceConnection?.status === "connected";
+
+ return useQuery({
+ queryKey: queryKeys.provider(id),
+ queryFn: async () => {
+ console.log(`📡 Fetching provider with ID: ${id}`);
+
+ const client = createDataServiceClient(connections);
+ const response = await client.providers[":id"].$get({
+ param: { id },
+ });
+
+ if (!response.ok) {
+ const errorMsg = `Failed to fetch provider: ${response.status}`;
+ console.error(`❌ ${errorMsg}`);
+ throw new Error(errorMsg);
+ }
+
+ // Convert API response to application types
+ const apiProvider = await response.json();
+ const provider = fromApiProvider(apiProvider);
+
+ console.log("✅ Fetched provider:", provider);
+ return provider;
+ },
+ enabled: isConnected && !!id,
+ });
+}
+
+/**
+ * Hook to create a provider
+ */
+export function useCreateProvider() {
+ const { connections } = useServerConnectionsContext();
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: async (data: ProviderCreate) => {
+ console.log("📡 Creating provider:", data);
+
+ const client = createDataServiceClient(connections);
+ // Convert application data to API format
+ const apiData = toApiProvider(data as Provider);
+ console.log("📤 API request data:", apiData);
+
+ const response = await client.providers.$post({
+ json: apiData,
+ });
+
+ if (!response.ok) {
+ const errorMsg = `Failed to create provider: ${response.status}`;
+ console.error(`❌ ${errorMsg}`);
+ throw new Error(errorMsg);
+ }
+
+ // Convert API response to application types
+ const apiProvider = await response.json();
+ const provider = fromApiProvider(apiProvider);
+
+ console.log("✅ Provider created:", provider);
+ return provider;
+ },
+ onSuccess: () => {
+ console.log("🔄 Invalidating providers cache after create");
+ queryClient.invalidateQueries({ queryKey: queryKeys.providers });
+ },
+ onError: (error: Error) => {
+ console.error("❌ Error in useCreateProvider:", error);
+ },
+ });
+}
+
+/**
+ * Hook to update a provider
+ */
+export function useUpdateProvider() {
+ const { connections } = useServerConnectionsContext();
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: async ({ id, data }: { id: string; data: ProviderUpdate }) => {
+ console.log(`📡 Updating provider with ID ${id}:`, data);
+
+ const client = createDataServiceClient(connections);
+ // Convert application data to API format
+ const apiData = toApiProvider(data);
+ // Create a new object without the ID
+ const { id: _, ...apiDataWithoutId } = apiData;
+
+ console.log("📤 API request data:", apiDataWithoutId);
+
+ const response = await client.providers[":id"].$put({
+ param: { id },
+ json: apiDataWithoutId,
+ });
+
+ if (!response.ok) {
+ const errorMsg = `Failed to update provider: ${response.status}`;
+ console.error(`❌ ${errorMsg}`);
+ throw new Error(errorMsg);
+ }
+
+ // Convert API response to application types
+ const apiProvider = await response.json();
+ const provider = fromApiProvider(apiProvider);
+
+ console.log("✅ Provider updated:", provider);
+ return provider;
+ },
+ onSuccess: (_data, variables) => {
+ console.log(`🔄 Invalidating providers cache after update for ID ${variables.id}`);
+ queryClient.invalidateQueries({ queryKey: queryKeys.providers });
+ queryClient.invalidateQueries({ queryKey: queryKeys.provider(variables.id) });
+ },
+ onError: (error: Error, variables) => {
+ console.error(`❌ Error in useUpdateProvider for ID ${variables.id}:`, error);
+ },
+ });
+}
+
+/**
+ * Hook to delete a provider
+ */
+export function useDeleteProvider() {
+ const { connections } = useServerConnectionsContext();
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: async (id: string) => {
+ console.log(`📡 Deleting provider with ID: ${id}`);
+
+ const client = createDataServiceClient(connections);
+
+ const response = await client.providers[":id"].$delete({
+ param: { id },
+ });
+
+ if (!response.ok) {
+ const errorMsg = `Failed to delete provider: ${response.status}`;
+ console.error(`❌ ${errorMsg}`);
+ throw new Error(errorMsg);
+ }
+
+ const result = await response.json() as SuccessResponse;
+ console.log("✅ Provider deleted:", result);
+ return result;
+ },
+ onSuccess: (_data, id) => {
+ console.log(`🔄 Invalidating providers cache after delete for ID ${id}`);
+ queryClient.invalidateQueries({ queryKey: queryKeys.providers });
+ queryClient.invalidateQueries({ queryKey: queryKeys.provider(id) });
+ },
+ onError: (error: Error, id) => {
+ console.error(`❌ Error in useDeleteProvider for ID ${id}:`, error);
+ },
+ });
+}
+
+
+
+/**
+ * Hook to test a provider connection
+ */
+export function useTestProviderConnection() {
+ const { connections } = useServerConnectionsContext();
+
+ return useMutation({
+ mutationFn: async ({ providerName, config }: { providerName: string, config: ServerProviderConfig }) => {
+ console.log(`📡 Testing connection for provider: ${providerName}`, config);
+
+ const client = createInferenceBridgeClient(connections);
+
+ console.log("📤 API request data:", config);
+
+ const response = await client.providers[":provider_name"].models.$post({
+ param: { provider_name: providerName },
+ json: config,
+ });
+
+ if (!response.ok) {
+ const errorData = await response.json();
+ const errorMsg = errorData.message || `Failed to test provider connection: ${response.status}`;
+ console.error(`❌ ${errorMsg}`, errorData);
+ throw new Error(errorMsg);
+ }
+
+ const result = await response.json();
+ console.log("✅ Provider connection test successful:", result);
+ return result;
+ },
+ onError: (error: Error, variables) => {
+ console.error(`❌ Error in useTestProviderConnection for provider ${variables.providerName}:`, error);
+ },
+ });
+}
\ No newline at end of file
diff --git a/graphcap_studio/src/features/server-connections/services/serverConnections.ts b/graphcap_studio/src/features/server-connections/services/serverConnections.ts
index 55362831..d1b8503b 100644
--- a/graphcap_studio/src/features/server-connections/services/serverConnections.ts
+++ b/graphcap_studio/src/features/server-connections/services/serverConnections.ts
@@ -3,10 +3,12 @@
* Server Connections Service
*
* This module provides functions for checking the health of server connections
- * such as the Media Server and GraphCap Server.
+ * such as the Media Server and Inference Bridge.
*/
-import { SERVER_IDS } from "../constants";
+import type { ServerConnection } from "@/types/server-connection-types";
+import { CONNECTION_STATUS, SERVER_IDS } from "../constants";
+import { createDataServiceClient, createInferenceBridgeClient } from "./apiClients";
/**
* Interface for health check response
@@ -62,13 +64,85 @@ export async function checkMediaServerHealth(url: string): Promise {
}
/**
- * Check the health of the GraphCap Server
+ * Check the health of the Inference Bridge
*
- * @param url - The base URL of the GraphCap Server
+ * @param url - The base URL of the Inference Bridge
* @returns A promise that resolves to a boolean indicating if the server is healthy
*/
-export async function checkGraphCapServerHealth(url: string): Promise {
- return checkServerHealth(url);
+export async function checkInferenceBridgeHealth(url: string): Promise {
+ try {
+ // Create mock connection array with the URL
+ const mockConnection: ServerConnection[] = [
+ {
+ id: SERVER_IDS.INFERENCE_BRIDGE,
+ name: "Inference Bridge",
+ status: CONNECTION_STATUS.DISCONNECTED,
+ url,
+ },
+ ];
+
+ // Create client with the URL
+ const client = createInferenceBridgeClient(mockConnection);
+
+ // First try the /api/v1/health endpoint using the client
+ try {
+ const response = await client.health.$get();
+
+ if (response.ok) {
+ const data = (await response.json()) as HealthCheckResponse;
+ // Check if the response contains a valid status
+ return data.status === "ok" || data.status === "healthy";
+ }
+ } catch (apiError) {
+ console.warn("Error checking Inference Bridge at /api/v1/health, trying direct health endpoint next:", apiError);
+ }
+
+ // Try direct /api/v1/health endpoint
+ try {
+ // Normalize URL by removing trailing slash if present
+ const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url;
+ const apiResponse = await fetch(`${normalizedUrl}/api/v1/health`, {
+ method: "GET",
+ headers: {
+ Accept: "application/json",
+ },
+ // Set a timeout to prevent long-hanging requests
+ signal: AbortSignal.timeout(3000),
+ });
+
+ if (apiResponse.ok) {
+ const data = (await apiResponse.json()) as HealthCheckResponse;
+ // Check if the response contains a valid status
+ return data.status === "ok" || data.status === "healthy";
+ }
+ } catch (directApiError) {
+ console.warn("Error checking Inference Bridge at direct /api/v1/health, trying /health next:", directApiError);
+ }
+
+ // Fallback to the legacy /health endpoint with direct fetch as last resort
+ // Normalize URL by removing trailing slash if present
+ const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url;
+ const fallbackResponse = await fetch(`${normalizedUrl}/health`, {
+ method: "GET",
+ headers: {
+ Accept: "application/json",
+ },
+ // Set a timeout to prevent long-hanging requests
+ signal: AbortSignal.timeout(3000),
+ });
+
+ if (!fallbackResponse.ok) {
+ console.error(`All health check endpoints failed. Last status: ${fallbackResponse.status}`);
+ return false;
+ }
+
+ const fallbackData = (await fallbackResponse.json()) as HealthCheckResponse;
+ // Check if the response contains a valid status
+ return fallbackData.status === "ok" || fallbackData.status === "healthy";
+ } catch (error) {
+ console.error("Error checking Inference Bridge health:", error);
+ return false;
+ }
}
/**
@@ -78,7 +152,61 @@ export async function checkGraphCapServerHealth(url: string): Promise {
* @returns A promise that resolves to a boolean indicating if the server is healthy
*/
export async function checkDataServiceHealth(url: string): Promise<boolean> {
- return checkServerHealth(url);
+ try {
+ // Create mock connection array with the URL
+ const mockConnection: ServerConnection[] = [
+ {
+ id: SERVER_IDS.DATA_SERVICE,
+ name: "Data Service",
+ status: CONNECTION_STATUS.DISCONNECTED,
+ url,
+ },
+ ];
+
+ // Create client with the URL
+ const client = createDataServiceClient(mockConnection);
+
+ // Try the /api/v1/health endpoint using the client
+ try {
+ const response = await client.health.$get();
+
+ if (response.ok) {
+ const data = (await response.json()) as HealthCheckResponse;
+ // Check if the response contains a valid status
+ return data.status === "ok" || data.status === "healthy";
+ }
+ } catch (apiError) {
+ console.warn("Error checking Data Service at /api/v1/health, trying direct endpoint next:", apiError);
+ }
+
+ // Try direct /api/v1/health endpoint
+ try {
+ // Normalize URL by removing trailing slash if present
+ const normalizedUrl = url.endsWith("/") ? url.slice(0, -1) : url;
+ const apiResponse = await fetch(`${normalizedUrl}/api/v1/health`, {
+ method: "GET",
+ headers: {
+ Accept: "application/json",
+ },
+ // Set a timeout to prevent long-hanging requests
+ signal: AbortSignal.timeout(3000),
+ });
+
+ if (apiResponse.ok) {
+ const data = (await apiResponse.json()) as HealthCheckResponse;
+ // Check if the response contains a valid status
+ return data.status === "ok" || data.status === "healthy";
+ }
+ } catch (directApiError) {
+ console.warn("Error checking Data Service at direct /api/v1/health, trying /health next:", directApiError);
+ }
+
+ // Fallback to the direct /health endpoint check as last resort
+ return checkServerHealth(url);
+ } catch (error) {
+ console.error("Error checking Data Service health:", error);
+ return false;
+ }
}
/**
@@ -95,8 +223,8 @@ export async function checkServerHealthById(
switch (id) {
case SERVER_IDS.MEDIA_SERVER:
return checkMediaServerHealth(url);
- case SERVER_IDS.GRAPHCAP_SERVER:
- return checkGraphCapServerHealth(url);
+ case SERVER_IDS.INFERENCE_BRIDGE:
+ return checkInferenceBridgeHealth(url);
case SERVER_IDS.DATA_SERVICE:
return checkDataServiceHealth(url);
default:
diff --git a/graphcap_studio/src/features/server-connections/useServerConnections.ts b/graphcap_studio/src/features/server-connections/useServerConnections.ts
index 4e003bd9..ecd98566 100644
--- a/graphcap_studio/src/features/server-connections/useServerConnections.ts
+++ b/graphcap_studio/src/features/server-connections/useServerConnections.ts
@@ -1,3 +1,4 @@
+import type { ServerConnection } from "@/types/server-connection-types";
// SPDX-License-Identifier: Apache-2.0
import { useCallback, useEffect, useRef, useState } from "react";
import {
@@ -7,11 +8,10 @@ import {
SERVER_NAMES,
} from "./constants";
import { checkServerHealthById } from "./services/serverConnections";
-import { ServerConnection } from "./types";
// Local storage keys
-const STORAGE_KEY = "graphcap-server-connections";
-const VERSION_KEY = "graphcap-server-connections-version";
+const STORAGE_KEY = "inference-bridge-connections";
+const VERSION_KEY = "inference-bridge-connections-version";
// Current version of the connections schema
const CURRENT_VERSION = 1;
@@ -30,12 +30,12 @@ const getDefaultConnections = (): ServerConnection[] => {
DEFAULT_URLS[SERVER_IDS.MEDIA_SERVER],
},
{
- id: SERVER_IDS.GRAPHCAP_SERVER,
- name: SERVER_NAMES[SERVER_IDS.GRAPHCAP_SERVER],
+ id: SERVER_IDS.INFERENCE_BRIDGE,
+ name: SERVER_NAMES[SERVER_IDS.INFERENCE_BRIDGE],
status: "disconnected",
url:
import.meta.env.VITE_API_URL ||
- DEFAULT_URLS[SERVER_IDS.GRAPHCAP_SERVER],
+ DEFAULT_URLS[SERVER_IDS.INFERENCE_BRIDGE],
},
{
id: SERVER_IDS.DATA_SERVICE,
@@ -146,7 +146,7 @@ const saveConnectionsToStorage = (connections: ServerConnection[]): void => {
* Custom hook for managing server connections
*
* This hook provides state and handlers for managing server connections
- * such as Media Server and GraphCap Server.
+ * such as Media Server and Inference Bridge.
*/
export function useServerConnections() {
// Initialize connections with values from local storage or defaults
diff --git a/graphcap_studio/src/features/inference/generation-options/schema.ts b/graphcap_studio/src/types/generation-option-types.ts
similarity index 53%
rename from graphcap_studio/src/features/inference/generation-options/schema.ts
rename to graphcap_studio/src/types/generation-option-types.ts
index 593045c7..a348a593 100644
--- a/graphcap_studio/src/features/inference/generation-options/schema.ts
+++ b/graphcap_studio/src/types/generation-option-types.ts
@@ -2,7 +2,8 @@
/**
* Generation Options Schema
*
- * This module defines the validation schema for caption generation options.
+ * This module defines the validation schema for generation options,
+ * replacing the legacy CaptionOptions with a consolidated schema.
*/
import { z } from "zod";
@@ -27,7 +28,7 @@ export const RESOLUTION_PRESETS = {
UHD_8K: { label: "8K UHD", value: "UHD_8K" },
} as const;
-// Default options for caption generation
+// Default options for generation
export const DEFAULT_OPTIONS = {
temperature: 0.7,
max_tokens: 4096,
@@ -35,6 +36,9 @@ export const DEFAULT_OPTIONS = {
repetition_penalty: 1.1,
resize_resolution: "NONE", // Default to no resize
global_context: "You are a visual captioning perspective.",
+ context: [] as string[], // Default to empty context array
+ provider_name: "", // Default to empty (will be populated later)
+ model_name: "", // Default to empty (will be populated later)
} as const;
// Schema for generation options
@@ -67,7 +71,60 @@ export const GenerationOptionsSchema = z.object({
resize_resolution: z.string().default(DEFAULT_OPTIONS.resize_resolution),
global_context: z.string().default(DEFAULT_OPTIONS.global_context),
+
+ // Added context array (was in CaptionOptions)
+ context: z.array(z.string()).default([]),
+
+ // Provider and model selection (using names instead of IDs)
+ provider_name: z.string().default(DEFAULT_OPTIONS.provider_name),
+
+ model_name: z.string().default(DEFAULT_OPTIONS.model_name),
});
// Type for generation options
export type GenerationOptions = z.infer<typeof GenerationOptionsSchema>;
+
+/**
+ * Format generation options for API requests
+ * This transforms the frontend GenerationOptions to the format expected by the API
+ */
+export function formatApiOptions(options: GenerationOptions): Record<string, unknown> {
+ return {
+ model: options.model_name, // API expects 'model' instead of model_name
+ temperature: options.temperature,
+ max_tokens: options.max_tokens,
+ top_p: options.top_p,
+ repetition_penalty: options.repetition_penalty,
+ global_context: options.global_context,
+ context: options.context,
+ resize_resolution: options.resize_resolution,
+ };
+}
+
+/**
+ * Format a complete caption request
+ *
+ * @param imagePath Path to the image
+ * @param perspective Perspective name
+ * @param options Generation options
+ * @param providerId Provider ID to use with the API
+ * @returns Formatted request object
+ */
+export function formatCaptionRequest(
+ imagePath: string,
+ perspective: string,
+ options: GenerationOptions,
+ providerId: string
+): {
+ image_path: string;
+ perspective: string;
+ provider_id: string;
+ options: Record<string, unknown>;
+} {
+ return {
+ image_path: imagePath,
+ perspective: perspective,
+ provider_id: providerId,
+ options: formatApiOptions(options),
+ };
+}
diff --git a/graphcap_studio/src/types/index.ts b/graphcap_studio/src/types/index.ts
new file mode 100644
index 00000000..653e5f72
--- /dev/null
+++ b/graphcap_studio/src/types/index.ts
@@ -0,0 +1,6 @@
+export * from "./generation-option-types";
+export * from "./perspective-module-types";
+export * from "./perspective-types";
+export * from "./provider-config-types";
+export * from "./server-connection-types";
+
diff --git a/graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts b/graphcap_studio/src/types/perspective-module-types.ts
similarity index 92%
rename from graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts
rename to graphcap_studio/src/types/perspective-module-types.ts
index 3084b7ab..0d265d1c 100644
--- a/graphcap_studio/src/features/perspectives/types/perspectiveModuleTypes.ts
+++ b/graphcap_studio/src/types/perspective-module-types.ts
@@ -5,9 +5,9 @@
* This module defines types related to perspective modules and management.
*/
+import type { Perspective } from "@/types/perspective-types";
+import { PerspectiveSchema } from "@/types/perspective-types";
import { z } from "zod";
-import { PerspectiveSchema } from "./perspectivesTypes";
-import type { Perspective } from "./perspectivesTypes";
/**
* Schema for module information
diff --git a/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts b/graphcap_studio/src/types/perspective-types.ts
similarity index 90%
rename from graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts
rename to graphcap_studio/src/types/perspective-types.ts
index 2a1804a2..03cafd31 100644
--- a/graphcap_studio/src/features/perspectives/types/perspectivesTypes.ts
+++ b/graphcap_studio/src/types/perspective-types.ts
@@ -101,6 +101,7 @@ export const CaptionRequestSchema = z.object({
provider: z.string().optional(), // For backward compatibility
options: z
.object({
+ model: z.string(), // Required model name to use for processing
max_tokens: z.number().optional(),
temperature: z.number().optional(),
top_p: z.number().optional(),
@@ -190,6 +191,7 @@ export type ServerConnection = {
* Specifies options for generating captions.
*/
export type CaptionOptions = {
+ model: string; // Required model name to use for processing
max_tokens?: number;
temperature?: number;
top_p?: number;
@@ -200,11 +202,6 @@ export type CaptionOptions = {
resize_resolution?: string;
};
-/**
- * String alias that allows any perspective name to be used.
- */
-export type PerspectiveType = string;
-
/**
* Describes a provider with id and name.
*/
@@ -221,7 +218,7 @@ export interface PerspectiveData {
version: string;
model: string;
provider: string;
- content: Record;
+ content: Record<string, unknown>;
options: CaptionOptions;
}
@@ -242,24 +239,6 @@ export interface ImageCaptions {
// SECTION D - COMPOSITE TYPES
// ============================================================================
-/**
- * Result type for the useImagePerspectives hook.
- */
-export interface ImagePerspectivesResult {
- isLoading: boolean;
- error: string | null;
- captions: ImageCaptions | null;
- generatedPerspectives: PerspectiveType[];
- generatingPerspectives: string[];
- generatePerspective: (
- perspective: PerspectiveType,
- providerId?: number,
- options?: CaptionOptions,
- ) => void;
- generateAllPerspectives: () => void;
- availablePerspectives: Perspective[];
- availableProviders: Provider[];
-}
/**
* Context type for the perspectives feature.
@@ -283,5 +262,5 @@ export interface PerspectivesContextType {
*/
export interface PerspectivesProviderProps {
children: React.ReactNode;
- initialSelectedProviderId?: number | undefined;
+ initialSelectedProviderId?: number;
}
diff --git a/graphcap_studio/src/types/provider-config-types.ts b/graphcap_studio/src/types/provider-config-types.ts
new file mode 100644
index 00000000..ce0020b1
--- /dev/null
+++ b/graphcap_studio/src/types/provider-config-types.ts
@@ -0,0 +1,225 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Provider Types
+ *
+ * Type definitions for provider-related data with Zod validation.
+ */
+
+import { z } from "zod";
+
+// ============================================================================
+// SECTION A - ZOD SCHEMAS
+// ============================================================================
+
+/**
+ * Base provider schema
+ */
+export const BaseProviderSchema = z.object({
+ id: z.string(),
+ name: z.string().min(1, "Name is required"),
+ isEnabled: z.boolean().default(true),
+});
+
+/**
+ * Provider model schema
+ */
+export const ProviderModelSchema = z.object({
+ id: z.string(),
+ providerId: z.string(),
+ name: z.string().min(1, "Model name is required"),
+ isEnabled: z.boolean().default(true),
+ createdAt: z.string().or(z.date()),
+ updatedAt: z.string().or(z.date()),
+});
+
+/**
+ * Complete provider schema
+ */
+export const ProviderSchema = BaseProviderSchema.extend({
+ kind: z.string().min(1, "Kind is required"),
+ environment: z.enum(["cloud", "local"]),
+ baseUrl: z.string().url("Must be a valid URL"),
+ apiKey: z.string().optional(),
+ defaultModel: z.string().optional(),
+ createdAt: z.string().or(z.date()),
+ updatedAt: z.string().or(z.date()),
+ models: z.array(ProviderModelSchema).optional(),
+});
+
+// Provider creation schema
+export const ProviderCreateSchema = z.object({
+ name: z.string().min(1, "Name is required"),
+ kind: z.string().min(1, "Kind is required"),
+ environment: z.enum(["cloud", "local"]),
+ baseUrl: z.string().url("Must be a valid URL"),
+ apiKey: z.string().optional(),
+ isEnabled: z.boolean().default(true),
+ defaultModel: z.string().optional(),
+ models: z
+ .array(
+ z.object({
+ name: z.string().min(1, "Model name is required"),
+ isEnabled: z.boolean().default(true),
+ }),
+ )
+ .optional(),
+});
+
+// Provider update schema
+export const ProviderUpdateSchema = z.object({
+ name: z.string().min(1, "Name is required").optional(),
+ kind: z.string().min(1, "Kind is required").optional(),
+ environment: z.enum(["cloud", "local"]).optional(),
+ baseUrl: z.string().url("Must be a valid URL").optional(),
+ apiKey: z.string().optional(),
+ isEnabled: z.boolean().optional(),
+ defaultModel: z.string().optional(),
+ models: z
+ .array(
+ z.object({
+ id: z.string().optional(),
+ name: z.string().min(1, "Model name is required"),
+ isEnabled: z.boolean().default(true),
+ }),
+ )
+ .optional(),
+});
+
+// Provider model info schema
+export const ProviderModelInfoSchema = z.object({
+ id: z.string(),
+ name: z.string(),
+ is_default: z.boolean().optional(),
+});
+
+// Provider models response schema
+export const ProviderModelsResponseSchema = z.object({
+ provider: z.string(),
+ models: z.array(ProviderModelInfoSchema),
+});
+
+// Success response schema
+export const SuccessResponseSchema = z.object({
+ success: z.boolean(),
+ message: z.string(),
+});
+
+// Error details schema
+export const ErrorDetailsSchema = z.object({
+ message: z.string(),
+ code: z.string().optional(),
+ details: z.record(z.unknown()).optional(),
+});
+
+// Connection details schema
+export const ConnectionDetailsSchema = z.object({
+ result: z.boolean(),
+ details: z.record(z.unknown()).optional(),
+ message: z.string().optional(),
+});
+
+// Server provider config schema
+export const ServerProviderConfigSchema = z.object({
+ name: z.string(),
+ kind: z.string(),
+ environment: z.enum(["cloud", "local"]),
+ base_url: z.string(),
+ api_key: z.string(),
+ default_model: z.string().optional(),
+ models: z.array(z.string()),
+});
+
+// ============================================================================
+// SECTION B - INFERRED TYPES
+// ============================================================================
+
+/**
+ * Base provider interface for selection
+ */
+export type BaseProvider = z.infer<typeof BaseProviderSchema>;
+
+/**
+ * Provider model
+ */
+export type ProviderModel = z.infer<typeof ProviderModelSchema>;
+
+/**
+ * Provider configuration stored in data service
+ */
+export type Provider = z.infer<typeof ProviderSchema>;
+
+/**
+ * Provider creation payload
+ */
+export type ProviderCreate = z.infer<typeof ProviderCreateSchema>;
+
+/**
+ * Provider update payload
+ */
+export type ProviderUpdate = z.infer<typeof ProviderUpdateSchema>;
+
+/**
+ * Provider model info from GraphCap server
+ */
+export type ProviderModelInfo = z.infer<typeof ProviderModelInfoSchema>;
+
+/**
+ * Provider models response from GraphCap server
+ */
+export type ProviderModelsResponse = z.infer<
+ typeof ProviderModelsResponseSchema
+>;
+
+/**
+ * Success response
+ */
+export type SuccessResponse = z.infer<typeof SuccessResponseSchema>;
+
+/**
+ * Error details
+ */
+export type ErrorDetails = z.infer<typeof ErrorDetailsSchema>;
+
+/**
+ * Connection details
+ */
+export type ConnectionDetails = z.infer<typeof ConnectionDetailsSchema>;
+
+/**
+ * Server-side provider configuration
+ * This is the configuration that gets sent to the inference server
+ */
+export type ServerProviderConfig = z.infer<typeof ServerProviderConfigSchema>;
+
+// ============================================================================
+// SECTION C - UTILITY FUNCTIONS
+// ============================================================================
+
+/**
+ * Convert string ID to number for API calls
+ */
+export function denormalizeProviderId(id: string): number {
+ return Number.parseInt(id, 10);
+}
+
+/**
+ * Convert number ID to string for frontend use
+ */
+export function normalizeProviderId(id: number | string): string {
+ return id.toString();
+}
+
+/**
+ * Helper function to convert Provider to ServerProviderConfig
+ */
+export function toServerConfig(provider: Provider): ServerProviderConfig {
+ return {
+ name: provider.name,
+ kind: provider.kind,
+ environment: provider.environment,
+ base_url: provider.baseUrl,
+ api_key: provider.apiKey ?? "",
+ default_model: provider.defaultModel,
+ models: provider.models?.map((m) => m.name) || [],
+ };
+}
diff --git a/graphcap_studio/src/features/server-connections/types.ts b/graphcap_studio/src/types/server-connection-types.ts
similarity index 100%
rename from graphcap_studio/src/features/server-connections/types.ts
rename to graphcap_studio/src/types/server-connection-types.ts
diff --git a/graphcap_studio/src/utils/error-handler.ts b/graphcap_studio/src/utils/error-handler.ts
new file mode 100644
index 00000000..9f475a78
--- /dev/null
+++ b/graphcap_studio/src/utils/error-handler.ts
@@ -0,0 +1,128 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Error Handler
+ *
+ * Utilities for handling and formatting errors in the client application.
+ */
+
+// Import from our custom toast utility
+import { toast } from './toast';
+
+interface ServerErrorResponse {
+ status?: string;
+ statusCode?: number;
+ message?: string;
+ timestamp?: string;
+ path?: string;
+ details?: unknown;
+ validationErrors?: Record<string, string[]>;
+}
+
+/**
+ * Extracts a message from a validation error object
+ */
+function extractValidationErrorMessage(validationErrors: Record<string, string[]>): string | null {
+ const validationMessages: string[] = [];
+
+ for (const [field, errors] of Object.entries(validationErrors)) {
+ for (const errorMsg of errors) {
+ validationMessages.push(`${field}: ${errorMsg}`);
+ }
+ }
+
+ if (validationMessages.length > 0) {
+ return `Validation errors:\n${validationMessages.join('\n')}`;
+ }
+
+ return null;
+}
+
+/**
+ * Extracts message from a server error object
+ */
+function extractServerErrorMessage(serverError: ServerErrorResponse): string | null {
+ // If there's a message, use it
+ if (serverError.message) {
+ return serverError.message;
+ }
+
+ // If there are validation errors, format them
+ if (serverError.validationErrors) {
+ return extractValidationErrorMessage(serverError.validationErrors);
+ }
+
+ // If there's an error property with a message (common in Axios errors)
+ if ('error' in serverError && typeof serverError.error === 'string') {
+ return serverError.error;
+ }
+
+ return null;
+}
+
+/**
+ * Formats a server error response into a human-readable message
+ */
+export function formatServerError(error: unknown): string {
+ // If it's already a string, just return it
+ if (typeof error === 'string') {
+ return error;
+ }
+
+ // Try to handle server error response
+ if (error && typeof error === 'object') {
+ const serverError = error as ServerErrorResponse;
+ const message = extractServerErrorMessage(serverError);
+ if (message) {
+ return message;
+ }
+ }
+
+ // Fallback for Error instances
+ if (error instanceof Error) {
+ return error.message;
+ }
+
+ // Last resort
+ return 'An unknown error occurred';
+}
+
+/**
+ * Shows a toast notification for server errors
+ */
+export function showServerError(error: unknown, title = 'Error'): void {
+ const message = formatServerError(error);
+ toast.error({ title, description: message });
+}
+
+/**
+ * Helper to extract validation errors from server responses
+ */
+export function getValidationErrors(error: unknown): Record<string, string> | null {
+ if (!error || typeof error !== 'object') {
+ return null;
+ }
+
+ const serverError = error as ServerErrorResponse;
+
+ if (!serverError.validationErrors) {
+ return null;
+ }
+
+ const formattedErrors: Record<string, string> = {};
+
+ for (const [field, errors] of Object.entries(serverError.validationErrors)) {
+ if (errors && errors.length > 0) {
+ formattedErrors[field] = errors[0];
+ }
+ }
+
+ return Object.keys(formattedErrors).length > 0 ? formattedErrors : null;
+}
+
+/**
+ * Handles common query/mutation errors
+ */
+export function handleApiError(error: unknown): void {
+ showServerError(error);
+ console.error('API error:', error);
+}
\ No newline at end of file
diff --git a/graphcap_studio/src/utils/toast.ts b/graphcap_studio/src/utils/toast.ts
index 81197da5..ed918de2 100644
--- a/graphcap_studio/src/utils/toast.ts
+++ b/graphcap_studio/src/utils/toast.ts
@@ -1,49 +1,174 @@
-import { ToastT, toast } from "sonner";
+// SPDX-License-Identifier: Apache-2.0
+import { toaster } from "@/components/ui/toaster";
type ToastType = "error" | "success";
export const showToast = (
text: string,
type: ToastType,
- options?: Parameters[1],
+ options?: Omit<Parameters<typeof toaster.create>[0], "title" | "description">,
) => {
- const toastFn = type === "error" ? toast.error : toast.success;
- toastFn(text, options);
+ const title = text;
+ if (type === "error") {
+ toaster.create({
+ title,
+ type: "error",
+ ...options,
+ });
+ } else {
+ toaster.create({
+ title,
+ type: "success",
+ ...options,
+ });
+ }
};
export const errorToast = (
text: string,
- options?: Parameters[1],
+ options?: Omit<Parameters<typeof toaster.create>[0], "title" | "description" | "type">,
) => {
console.error(text);
if (text != null && text !== "") {
- toast.error(text, options);
+ toaster.create({
+ title: text,
+ type: "error",
+ ...options,
+ });
}
};
export const successToast = (
text: string,
- options?: Parameters[1],
+ options?: Omit<Parameters<typeof toaster.create>[0], "title" | "description" | "type">,
) => {
if (text != null && text !== "") {
- toast.success(text, options);
+ toaster.create({
+ title: text,
+ type: "success",
+ ...options,
+ });
}
};
type MessageType = {
- success: string | ((data: any) => string);
- error?: string | ((error: any) => string);
+ success: string | ((data: unknown) => string);
+ error?: string | ((error: unknown) => string);
};
export async function promiseToast(
promise: Promise,
message: MessageType,
- options?: Parameters[1],
+ options?: Omit<Parameters<typeof toaster.create>[0], "title" | "description" | "type">,
) {
- return toast.promise(promise, {
- loading: "Loading",
- success: message.success,
- error: message.error || "Error. Please try again",
+ // Show loading toast
+ const loadingToastId = toaster.create({
+ title: "Loading",
+ type: "loading",
...options,
});
+
+ try {
+ const result = await promise;
+ // Close loading toast
+ toaster.dismiss(loadingToastId);
+ // Show success toast
+ const successMessage = typeof message.success === 'function'
+ ? message.success(result)
+ : message.success;
+ toaster.create({
+ title: successMessage,
+ type: "success",
+ ...options,
+ });
+ return result;
+ } catch (error) {
+ // Close loading toast
+ toaster.dismiss(loadingToastId);
+ // Show error toast
+ const errorMessage = message.error && typeof message.error === 'function'
+ ? message.error(error)
+ : message.error || "Error. Please try again";
+ toaster.create({
+ title: errorMessage,
+ type: "error",
+ ...options,
+ });
+ throw error;
+ }
}
+
+/**
+ * Toast notification utility
+ */
+export const toast = {
+ /**
+ * Show a success toast
+ */
+ success: ({ title, description, duration = 1000 }: { title: string; description?: string; duration?: number }) => {
+ return toaster.create({
+ title,
+ description,
+ duration,
+ type: "success",
+ });
+ },
+
+ /**
+ * Show an error toast
+ */
+ error: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => {
+ return toaster.create({
+ title,
+ description,
+ duration,
+ type: "error",
+ });
+ },
+
+ /**
+ * Show an info toast
+ */
+ info: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => {
+ return toaster.create({
+ title,
+ description,
+ duration,
+ type: "info",
+ });
+ },
+
+ /**
+ * Show a warning toast
+ */
+ warning: ({ title, description, duration = 2000 }: { title: string; description?: string; duration?: number }) => {
+ return toaster.create({
+ title,
+ description,
+ duration,
+ type: "warning",
+ });
+ },
+
+ /**
+ * Dismiss a toast by its ID
+ * If no ID is provided, all toasts will be dismissed
+ */
+ dismiss: (id?: string) => {
+ toaster.dismiss(id);
+ },
+
+ /**
+ * Pause a toast by its ID to prevent it from timing out
+ */
+ pause: (id: string) => {
+ toaster.pause(id);
+ },
+
+ /**
+ * Resume a paused toast, re-enabling the timeout with the remaining duration
+ */
+ resume: (id: string) => {
+ toaster.resume(id);
+ }
+};
diff --git a/pyproject.toml b/pyproject.toml
index 99491de5..2057d0dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,7 @@ exclude = [
]
# Source directories
-src = ["./servers/inference_server"]
+src = ["./servers/inference_bridge"]
# Same as Black
line-length = 120
diff --git a/servers/data_service/src/api/controllers/providers.ts b/servers/data_service/src/api/controllers/providers.ts
deleted file mode 100644
index 06ad3e07..00000000
--- a/servers/data_service/src/api/controllers/providers.ts
+++ /dev/null
@@ -1,344 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-/**
- * Provider Controllers
- *
- * This module defines the controller functions for provider management.
- */
-
-import { eq } from "drizzle-orm";
-import type { Context } from "hono";
-import { db } from "../../db";
-import { providerModels, providerRateLimits, providers } from "../../db/schema";
-import { encryptApiKey } from "../../utils/encryption";
-import { logger } from "../../utils/logger";
-import type {
- ProviderApiKey,
- ProviderCreate,
- ProviderUpdate,
-} from "../schemas/providers";
-
-// Type for the validated parameters
-type ValidatedParams = {
- id: string;
-};
-
-/**
- * Get all providers
- */
-export const getProviders = async (c: Context) => {
- try {
- logger.debug("Fetching all providers");
-
- const allProviders = await db.query.providers.findMany({
- with: {
- models: true,
- rateLimits: true,
- },
- });
-
- logger.debug(
- { count: allProviders.length },
- "Providers fetched successfully",
- );
- return c.json(allProviders);
- } catch (error) {
- // Log the full error details for debugging
- logger.error(
- {
- error,
- message: error instanceof Error ? error.message : "Unknown error",
- stack: error instanceof Error ? error.stack : undefined,
- },
- "Error fetching providers",
- );
-
- // Return a more informative error response
- return c.json(
- {
- error: "Failed to fetch providers",
- message:
- error instanceof Error ? error.message : "Unknown database error",
- },
- 500,
- );
- }
-};
-
-/**
- * Get a specific provider by ID
- */
-export const getProvider = async (c: Context) => {
- try {
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const { id } = c.req.valid("param") as ValidatedParams;
- logger.debug({ id }, "Fetching provider by ID");
-
- const provider = await db.query.providers.findFirst({
- where: eq(providers.id, Number.parseInt(id)),
- with: {
- models: true,
- rateLimits: true,
- },
- });
-
- if (!provider) {
- logger.debug({ id }, "Provider not found");
- return c.json({ error: "Provider not found" }, 404);
- }
-
- logger.debug({ id }, "Provider fetched successfully");
- return c.json(provider);
- } catch (error) {
- logger.error({ error }, "Error fetching provider");
- return c.json({ error: "Failed to fetch provider" }, 500);
- }
-};
-
-/**
- * Create a new provider
- */
-export const createProvider = async (c: Context) => {
- try {
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const data = c.req.valid("json") as ProviderCreate;
- logger.debug({ data }, "Creating new provider");
-
- // Extract models and rate limits if provided
- const { models, rateLimits, ...providerData } = data;
-
- // Encrypt API key if provided
- if (providerData.apiKey) {
- providerData.apiKey = await encryptApiKey(providerData.apiKey);
- }
-
- // Start a transaction
- const result = await db.transaction(async (tx) => {
- // Insert provider
- const [provider] = await tx
- .insert(providers)
- .values({
- ...providerData,
- createdAt: new Date(),
- updatedAt: new Date(),
- })
- .returning();
-
- // Insert models if provided
- if (models && models.length > 0) {
- await tx.insert(providerModels).values(
- models.map((model) => ({
- providerId: provider.id,
- name: model.name,
- isEnabled: model.isEnabled,
- createdAt: new Date(),
- updatedAt: new Date(),
- })),
- );
- }
-
- // Insert rate limits if provided
- if (rateLimits) {
- await tx.insert(providerRateLimits).values({
- providerId: provider.id,
- requestsPerMinute: rateLimits.requestsPerMinute,
- tokensPerMinute: rateLimits.tokensPerMinute,
- createdAt: new Date(),
- updatedAt: new Date(),
- });
- }
-
- // Return the created provider with relations
- return await tx.query.providers.findFirst({
- where: eq(providers.id, provider.id),
- with: {
- models: true,
- rateLimits: true,
- },
- });
- });
-
- logger.debug({ id: result?.id }, "Provider created successfully");
- return c.json(result, 201);
- } catch (error) {
- logger.error({ error }, "Error creating provider");
- return c.json({ error: "Failed to create provider" }, 500);
- }
-};
-
-/**
- * Update an existing provider
- */
-export const updateProvider = async (c: Context) => {
- try {
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const { id } = c.req.valid("param") as ValidatedParams;
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const data = c.req.valid("json") as ProviderUpdate;
- logger.debug({ id, data }, "Updating provider");
-
- // Check if provider exists
- const existingProvider = await db.query.providers.findFirst({
- where: eq(providers.id, Number.parseInt(id)),
- });
-
- if (!existingProvider) {
- logger.debug({ id }, "Provider not found for update");
- return c.json({ error: "Provider not found" }, 404);
- }
-
- // Extract models and rate limits if provided
- const { models, rateLimits, ...providerData } = data;
-
- // Start a transaction
- const result = await db.transaction(async (tx) => {
- // Update provider
- await tx
- .update(providers)
- .set({
- ...providerData,
- updatedAt: new Date(),
- })
- .where(eq(providers.id, Number.parseInt(id)));
-
- // Update models if provided
- if (models && models.length > 0) {
- // First, delete existing models
- await tx
- .delete(providerModels)
- .where(eq(providerModels.providerId, Number.parseInt(id)));
-
- // Then insert new models
- await tx.insert(providerModels).values(
- models.map((model) => ({
- providerId: Number.parseInt(id),
- name: model.name,
- isEnabled: model.isEnabled,
- createdAt: new Date(),
- updatedAt: new Date(),
- })),
- );
- }
-
- // Update rate limits if provided
- if (rateLimits) {
- // Check if rate limits exist
- const existingRateLimits = await tx.query.providerRateLimits.findFirst({
- where: eq(providerRateLimits.providerId, Number.parseInt(id)),
- });
-
- if (existingRateLimits) {
- // Update existing rate limits
- await tx
- .update(providerRateLimits)
- .set({
- requestsPerMinute: rateLimits.requestsPerMinute,
- tokensPerMinute: rateLimits.tokensPerMinute,
- updatedAt: new Date(),
- })
- .where(eq(providerRateLimits.providerId, Number.parseInt(id)));
- } else {
- // Insert new rate limits
- await tx.insert(providerRateLimits).values({
- providerId: Number.parseInt(id),
- requestsPerMinute: rateLimits.requestsPerMinute,
- tokensPerMinute: rateLimits.tokensPerMinute,
- createdAt: new Date(),
- updatedAt: new Date(),
- });
- }
- }
-
- // Return the updated provider with relations
- return await tx.query.providers.findFirst({
- where: eq(providers.id, Number.parseInt(id)),
- with: {
- models: true,
- rateLimits: true,
- },
- });
- });
-
- logger.debug({ id }, "Provider updated successfully");
- return c.json(result);
- } catch (error) {
- logger.error({ error }, "Error updating provider");
- return c.json({ error: "Failed to update provider" }, 500);
- }
-};
-
-/**
- * Delete a provider
- */
-export const deleteProvider = async (c: Context) => {
- try {
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const { id } = c.req.valid("param") as ValidatedParams;
- logger.debug({ id }, "Deleting provider");
-
- // Check if provider exists
- const existingProvider = await db.query.providers.findFirst({
- where: eq(providers.id, Number.parseInt(id)),
- });
-
- if (!existingProvider) {
- logger.debug({ id }, "Provider not found for deletion");
- return c.json({ error: "Provider not found" }, 404);
- }
-
- // Delete provider (cascade will handle related records)
- await db.delete(providers).where(eq(providers.id, Number.parseInt(id)));
-
- logger.debug({ id }, "Provider deleted successfully");
- return c.json({
- success: true,
- message: "Provider deleted successfully",
- });
- } catch (error) {
- logger.error({ error }, "Error deleting provider");
- return c.json({ error: "Failed to delete provider" }, 500);
- }
-};
-
-/**
- * Update a provider's API key
- */
-export const updateProviderApiKey = async (c: Context) => {
- try {
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const { id } = c.req.valid("param") as ValidatedParams;
- // @ts-ignore - Hono OpenAPI validation types are not properly recognized
- const { apiKey } = c.req.valid("json") as ProviderApiKey;
- logger.debug({ id }, "Updating provider API key");
-
- // Check if provider exists
- const existingProvider = await db.query.providers.findFirst({
- where: eq(providers.id, Number.parseInt(id)),
- });
-
- if (!existingProvider) {
- logger.debug({ id }, "Provider not found for API key update");
- return c.json({ error: "Provider not found" }, 404);
- }
-
- // Encrypt the API key
- const encryptedApiKey = await encryptApiKey(apiKey);
-
- // Update the provider's API key
- await db
- .update(providers)
- .set({
- apiKey: encryptedApiKey,
- updatedAt: new Date(),
- })
- .where(eq(providers.id, Number.parseInt(id)));
-
- logger.debug({ id }, "Provider API key updated successfully");
- return c.json({
- success: true,
- message: "API key updated successfully",
- });
- } catch (error) {
- logger.error({ error }, "Error updating provider API key");
- return c.json({ error: "Failed to update API key" }, 500);
- }
-};
diff --git a/servers/data_service/src/api/routes/index.ts b/servers/data_service/src/api/routes/index.ts
index 22560994..fac8b6d3 100644
--- a/servers/data_service/src/api/routes/index.ts
+++ b/servers/data_service/src/api/routes/index.ts
@@ -5,7 +5,7 @@
* This file exports route definitions for client consumption.
*/
+import { providerRoutes } from '../../features/provider_config/routes';
import { batchQueueRoutes } from './batch_queue';
-import { providerRoutes } from './providers';
export { providerRoutes, batchQueueRoutes };
\ No newline at end of file
diff --git a/servers/data_service/src/api/routes/log_test.ts b/servers/data_service/src/api/routes/log_test.ts
new file mode 100644
index 00000000..ff1fca4d
--- /dev/null
+++ b/servers/data_service/src/api/routes/log_test.ts
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Log Test Routes
+ *
+ * Test routes to demonstrate the usage of the pino logger.
+ */
+
+import { Hono } from 'hono';
+
+// Create a new router instance
+const router = new Hono();
+
+// Test route that uses the logger
+router.get('/', (c) => {
+ // Get the pino logger instance from context
+ const { logger } = c.var;
+
+ // Log at different levels
+ logger.trace('This is a trace message');
+ logger.debug('This is a debug message');
+ logger.info('This is an info message');
+ logger.warn('This is a warning message');
+
+ // Log objects
+ logger.info({ user: { id: 1, name: 'test user' } }, 'User info');
+
+ // Log with additional context
+ logger.assign({ requestId: 'test-123' });
+ logger.info('Message with request ID');
+
+ // Set a response message
+ logger.setResMessage('Log test successful');
+
+ // Return a basic response
+ return c.json({
+ message: 'Logger test route',
+ levels: ['trace', 'debug', 'info', 'warn', 'error', 'fatal'],
+ time: new Date().toISOString()
+ });
+});
+
+// Add a route that triggers an error
+router.get('/error', (c) => {
+ const { logger } = c.var;
+
+ try {
+ // Simulate an error
+ throw new Error('Test error for logging');
+ } catch (error) {
+ // Log the error
+ logger.error({ error }, 'An error occurred');
+
+ // Return an error response
+ return c.json({
+ status: 'error',
+ message: 'Test error triggered',
+ time: new Date().toISOString()
+ }, 500);
+ }
+});
+
+// Export the router
+export const logTestRoutes = router;
\ No newline at end of file
diff --git a/servers/data_service/src/app.ts b/servers/data_service/src/app.ts
index 5faa524a..de3bdac5 100644
--- a/servers/data_service/src/app.ts
+++ b/servers/data_service/src/app.ts
@@ -14,32 +14,32 @@ import { timing } from 'hono/timing';
import { z } from 'zod';
import { batchQueueRoutes } from './api/routes/batch_queue';
-import { providerRoutes } from './api/routes/providers';
+import { logTestRoutes } from './api/routes/log_test';
import { checkDatabaseConnection } from './db/init';
import { env } from './env';
+import { providerRoutes } from './features/provider_config/routes';
+import { errorHandlerMiddleware, notFoundHandler } from './utils/error-handler';
import { logger } from './utils/logger';
+import { createDetailedLoggingMiddleware, createPinoLoggerMiddleware } from './utils/pino-middleware';
// Create OpenAPI Hono app
const app = new OpenAPIHono();
+// Add error handling middleware first so it can catch errors from other middleware
+app.use('*', errorHandlerMiddleware({ logErrors: true }));
+
// Add middleware
app.use('*', cors());
app.use('*', prettyJSON());
app.use('*', timing());
app.use('*', secureHeaders());
-// Add custom logger middleware
-app.use('*', async (c, next) => {
- const { method, url } = c.req;
- logger.info({ method, url }, 'Request received');
-
- const start = Date.now();
- await next();
- const end = Date.now();
-
- const status = c.res.status;
- logger.info({ method, url, status, responseTime: end - start }, 'Request completed');
-});
+// Add pino logger middleware
+app.use('*', createPinoLoggerMiddleware());
+
+// Add detailed logging middleware for API routes
+app.use('/api/*', createDetailedLoggingMiddleware());
+app.use(`${env.API_PREFIX}/*`, createDetailedLoggingMiddleware());
// Health check endpoint
const healthCheckRoute = createRoute({
@@ -138,6 +138,7 @@ app.openapi(dbHealthCheckRoute, async (c) => {
// API routes with v1 prefix
app.route(`${env.API_PREFIX}/v1/providers`, providerRoutes);
app.route(`${env.API_PREFIX}/v1/perspectives/batch`, batchQueueRoutes);
+app.route(`${env.API_PREFIX}/v1/logs`, logTestRoutes);
// OpenAPI documentation
app.doc('openapi', {
@@ -164,12 +165,24 @@ app.get('/docs', apiReference({
layout: 'modern',
}));
-// Error handling
+// Error handling - replace existing onError handler
app.onError((err, c) => {
- logger.error({ err, path: c.req.path }, 'Unhandled error');
- return c.json({ error: 'Internal server error' }, 500);
+ // The middleware should handle most errors,
+ // but this is a fallback for errors that somehow bypass the middleware
+ logger.error({ err, path: c.req.path }, 'Unhandled error in onError handler');
+
+ return c.json({
+ status: 'error',
+ statusCode: 500,
+ message: 'Internal server error',
+ timestamp: new Date().toISOString(),
+ path: c.req.path
+ }, 500);
});
+// Add not found handler
+app.notFound(notFoundHandler);
+
// Export the app
export default app;
diff --git a/servers/data_service/src/db/migrations/20250326170753_great_wallop.sql b/servers/data_service/src/db/migrations/20250326170753_great_wallop.sql
new file mode 100644
index 00000000..4fc36a53
--- /dev/null
+++ b/servers/data_service/src/db/migrations/20250326170753_great_wallop.sql
@@ -0,0 +1 @@
+ALTER TABLE "providers" DROP COLUMN IF EXISTS "env_var";
\ No newline at end of file
diff --git a/servers/data_service/src/db/migrations/meta/20250326170753_snapshot.json b/servers/data_service/src/db/migrations/meta/20250326170753_snapshot.json
new file mode 100644
index 00000000..8190fd11
--- /dev/null
+++ b/servers/data_service/src/db/migrations/meta/20250326170753_snapshot.json
@@ -0,0 +1,446 @@
+{
+ "id": "5cdc19fe-6d18-47e2-ab22-cd834b775abf",
+ "prevId": "d01eab73-fa91-4e9c-ab62-d5b64395049a",
+ "version": "7",
+ "dialect": "postgresql",
+ "tables": {
+ "public.provider_models": {
+ "name": "provider_models",
+ "schema": "",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "serial",
+ "primaryKey": true,
+ "notNull": true
+ },
+ "provider_id": {
+ "name": "provider_id",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "name": {
+ "name": "name",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "is_enabled": {
+ "name": "is_enabled",
+ "type": "boolean",
+ "primaryKey": false,
+ "notNull": false,
+ "default": true
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {
+ "provider_models_provider_id_providers_id_fk": {
+ "name": "provider_models_provider_id_providers_id_fk",
+ "tableFrom": "provider_models",
+ "tableTo": "providers",
+ "columnsFrom": [
+ "provider_id"
+ ],
+ "columnsTo": [
+ "id"
+ ],
+ "onDelete": "cascade",
+ "onUpdate": "no action"
+ }
+ },
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {}
+ },
+ "public.provider_rate_limits": {
+ "name": "provider_rate_limits",
+ "schema": "",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "serial",
+ "primaryKey": true,
+ "notNull": true
+ },
+ "provider_id": {
+ "name": "provider_id",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "requests_per_minute": {
+ "name": "requests_per_minute",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "tokens_per_minute": {
+ "name": "tokens_per_minute",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {
+ "provider_rate_limits_provider_id_providers_id_fk": {
+ "name": "provider_rate_limits_provider_id_providers_id_fk",
+ "tableFrom": "provider_rate_limits",
+ "tableTo": "providers",
+ "columnsFrom": [
+ "provider_id"
+ ],
+ "columnsTo": [
+ "id"
+ ],
+ "onDelete": "cascade",
+ "onUpdate": "no action"
+ }
+ },
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {}
+ },
+ "public.providers": {
+ "name": "providers",
+ "schema": "",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "serial",
+ "primaryKey": true,
+ "notNull": true
+ },
+ "name": {
+ "name": "name",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "kind": {
+ "name": "kind",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "environment": {
+ "name": "environment",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "base_url": {
+ "name": "base_url",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "api_key": {
+ "name": "api_key",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "is_enabled": {
+ "name": "is_enabled",
+ "type": "boolean",
+ "primaryKey": false,
+ "notNull": false,
+ "default": true
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ },
+ "updated_at": {
+ "name": "updated_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false,
+ "default": "now()"
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {
+ "providers_name_unique": {
+ "name": "providers_name_unique",
+ "nullsNotDistinct": false,
+ "columns": [
+ "name"
+ ]
+ }
+ }
+ },
+ "public.batch_job_dependencies": {
+ "name": "batch_job_dependencies",
+ "schema": "",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "serial",
+ "primaryKey": true,
+ "notNull": true
+ },
+ "job_id": {
+ "name": "job_id",
+ "type": "uuid",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "depends_on_job_id": {
+ "name": "depends_on_job_id",
+ "type": "uuid",
+ "primaryKey": false,
+ "notNull": true
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {
+ "batch_job_dependencies_job_id_batch_jobs_job_id_fk": {
+ "name": "batch_job_dependencies_job_id_batch_jobs_job_id_fk",
+ "tableFrom": "batch_job_dependencies",
+ "tableTo": "batch_jobs",
+ "columnsFrom": [
+ "job_id"
+ ],
+ "columnsTo": [
+ "job_id"
+ ],
+ "onDelete": "cascade",
+ "onUpdate": "no action"
+ },
+ "batch_job_dependencies_depends_on_job_id_batch_jobs_job_id_fk": {
+ "name": "batch_job_dependencies_depends_on_job_id_batch_jobs_job_id_fk",
+ "tableFrom": "batch_job_dependencies",
+ "tableTo": "batch_jobs",
+ "columnsFrom": [
+ "depends_on_job_id"
+ ],
+ "columnsTo": [
+ "job_id"
+ ],
+ "onDelete": "no action",
+ "onUpdate": "no action"
+ }
+ },
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {}
+ },
+ "public.batch_job_items": {
+ "name": "batch_job_items",
+ "schema": "",
+ "columns": {
+ "id": {
+ "name": "id",
+ "type": "serial",
+ "primaryKey": true,
+ "notNull": true
+ },
+ "job_id": {
+ "name": "job_id",
+ "type": "uuid",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "image_path": {
+ "name": "image_path",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "perspective": {
+ "name": "perspective",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "status": {
+ "name": "status",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "default": "'pending'"
+ },
+ "error": {
+ "name": "error",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "started_at": {
+ "name": "started_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "completed_at": {
+ "name": "completed_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "processing_time": {
+ "name": "processing_time",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {
+ "batch_job_items_job_id_batch_jobs_job_id_fk": {
+ "name": "batch_job_items_job_id_batch_jobs_job_id_fk",
+ "tableFrom": "batch_job_items",
+ "tableTo": "batch_jobs",
+ "columnsFrom": [
+ "job_id"
+ ],
+ "columnsTo": [
+ "job_id"
+ ],
+ "onDelete": "cascade",
+ "onUpdate": "no action"
+ }
+ },
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {}
+ },
+ "public.batch_jobs": {
+ "name": "batch_jobs",
+ "schema": "",
+ "columns": {
+ "job_id": {
+ "name": "job_id",
+ "type": "uuid",
+ "primaryKey": true,
+ "notNull": true,
+ "default": "gen_random_uuid()"
+ },
+ "type": {
+ "name": "type",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "status": {
+ "name": "status",
+ "type": "text",
+ "primaryKey": false,
+ "notNull": true,
+ "default": "'pending'"
+ },
+ "priority": {
+ "name": "priority",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true,
+ "default": 100
+ },
+ "config": {
+ "name": "config",
+ "type": "jsonb",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "total_images": {
+ "name": "total_images",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true
+ },
+ "processed_images": {
+ "name": "processed_images",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true,
+ "default": 0
+ },
+ "failed_images": {
+ "name": "failed_images",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true,
+ "default": 0
+ },
+ "progress": {
+ "name": "progress",
+ "type": "integer",
+ "primaryKey": false,
+ "notNull": true,
+ "default": 0
+ },
+ "created_at": {
+ "name": "created_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": true,
+ "default": "now()"
+ },
+ "started_at": {
+ "name": "started_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "completed_at": {
+ "name": "completed_at",
+ "type": "timestamp",
+ "primaryKey": false,
+ "notNull": false
+ },
+ "archived": {
+ "name": "archived",
+ "type": "boolean",
+ "primaryKey": false,
+ "notNull": true,
+ "default": false
+ }
+ },
+ "indexes": {},
+ "foreignKeys": {},
+ "compositePrimaryKeys": {},
+ "uniqueConstraints": {}
+ }
+ },
+ "enums": {},
+ "schemas": {},
+ "sequences": {},
+ "_meta": {
+ "columns": {},
+ "schemas": {},
+ "tables": {}
+ }
+}
\ No newline at end of file
diff --git a/servers/data_service/src/db/migrations/meta/_journal.json b/servers/data_service/src/db/migrations/meta/_journal.json
index cdde405a..0354da4a 100644
--- a/servers/data_service/src/db/migrations/meta/_journal.json
+++ b/servers/data_service/src/db/migrations/meta/_journal.json
@@ -15,6 +15,13 @@
"when": 1742476062417,
"tag": "20250320130742_pink_vertigo",
"breakpoints": true
+ },
+ {
+ "idx": 2,
+ "version": "7",
+ "when": 1743008873474,
+ "tag": "20250326170753_great_wallop",
+ "breakpoints": true
}
]
}
\ No newline at end of file
diff --git a/servers/data_service/src/db/schema/index.ts b/servers/data_service/src/db/schema/index.ts
index 2003e966..7ca32784 100644
--- a/servers/data_service/src/db/schema/index.ts
+++ b/servers/data_service/src/db/schema/index.ts
@@ -5,5 +5,5 @@
* This file exports all database schema definitions for use with Drizzle ORM.
*/
-export * from './providers';
+export * from '../../features/provider_config/db_providers';
export * from './batch_queue';
\ No newline at end of file
diff --git a/servers/data_service/src/db/seed/index.ts b/servers/data_service/src/db/seed/index.ts
index d6cf6a9c..7fbe0124 100644
--- a/servers/data_service/src/db/seed/index.ts
+++ b/servers/data_service/src/db/seed/index.ts
@@ -6,8 +6,8 @@
* Add new seed operations here in the desired order.
*/
+import { seedProviders } from '../../features/provider_config/seed_providers';
import { logger } from '../../utils/logger';
-import { seedProviders } from './providers';
/**
* Main seed function that orchestrates all seeding operations
diff --git a/servers/data_service/src/features/provider_config/api-key-manager.ts b/servers/data_service/src/features/provider_config/api-key-manager.ts
new file mode 100644
index 00000000..af233074
--- /dev/null
+++ b/servers/data_service/src/features/provider_config/api-key-manager.ts
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * API Key Manager
+ *
+ * This module handles API key encryption, decryption, and management operations.
+ */
+
+import { decryptApiKey, encryptApiKey } from "../../utils/encryption";
+import { logger } from "../../utils/logger";
+import type { Provider } from "./schemas";
+
+// Simple type for objects with an optional API key property
+export type WithApiKey = {
+ apiKey: string | null | undefined;
+};
+
+/**
+ * Processes an API key for update operations
+ * Determines if we should use existing key, encrypt a new key, or clear the key
+ */
+export const processApiKeyForUpdate = async (
+ currentProvider: WithApiKey,
+ newApiKeyValue: string | undefined | null
+): Promise => {
+ // CASE 1: Keep existing key if no new value provided
+ if (newApiKeyValue === undefined) {
+ logger.debug("Keeping existing API key - no change requested");
+ return currentProvider.apiKey ?? null;
+ }
+
+ // CASE 2: Explicitly clear the key
+ if (newApiKeyValue === null || newApiKeyValue === "") {
+ logger.debug("API key explicitly cleared in update");
+ return null;
+ }
+
+ // CASE 3: Encrypt new key value
+ logger.debug("Encrypting new API key for provider update");
+ const encryptedKey = await encryptApiKey(newApiKeyValue);
+ logger.debug("API key encrypted for update");
+ return encryptedKey;
+};
+
+/**
+ * Safely decrypts a provider's API key for client response
+ */
+export const decryptProviderApiKey = async (
+ provider: Provider
+): Promise => {
+ const providerCopy = { ...provider };
+
+ if (providerCopy.apiKey) {
+ logger.debug({
+ providerId: provider.id,
+ encryptedKeyLength: providerCopy.apiKey.length
+ }, "Decrypting API key for provider");
+
+ providerCopy.apiKey = await decryptApiKey(providerCopy.apiKey);
+
+ // Log the result of decryption (without showing the actual key)
+ logger.debug({
+ providerId: provider.id,
+ apiKeyPresent: Boolean(providerCopy.apiKey),
+ apiKeyLength: providerCopy.apiKey ? providerCopy.apiKey.length : 0
+ }, "Provider API key decryption result");
+ } else {
+ logger.debug({ providerId: provider.id }, "No API key to decrypt for provider");
+ }
+
+ return providerCopy;
+};
\ No newline at end of file
diff --git a/servers/data_service/src/features/provider_config/controller.ts b/servers/data_service/src/features/provider_config/controller.ts
new file mode 100644
index 00000000..ac07f1a5
--- /dev/null
+++ b/servers/data_service/src/features/provider_config/controller.ts
@@ -0,0 +1,897 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Provider Controllers
+ *
+ * This module defines the controller functions for provider management.
+ */
+
+import { eq } from "drizzle-orm";
+import type { Context } from "hono";
+import type { Logger } from "pino";
+import { db } from "../../db";
+import { providerModels, providerRateLimits, providers } from "../../db/schema";
+import { decryptApiKey, encryptApiKey } from "../../utils/encryption";
+import { processApiKeyForUpdate } from "./api-key-manager";
+import type { Provider, ProviderCreate, ProviderUpdate } from "./schemas";
+
+
+
+// Type for the validated parameters
+type ValidatedParams = {
+ id: string;
+};
+
+/**
+ * Get all providers
+ */
+export const getProviders = async (c: Context) => {
+ const { logger } = c.var;
+ logger.debug("Fetching all providers");
+
+ try {
+ const allProviders = await db.query.providers.findMany({
+ with: {
+ models: true,
+ rateLimits: true,
+ },
+ });
+
+ // Decrypt API keys before returning to client
+ for (const provider of allProviders) {
+ if (provider.apiKey) {
+ logger.debug(
+ { providerId: provider.id },
+ "Decrypting API key for provider",
+ );
+ provider.apiKey = await decryptApiKey(provider.apiKey);
+
+ // Log whether API key is present after decryption (without showing the actual key)
+ logger.debug(
+ {
+ providerId: provider.id,
+ apiKeyPresent: Boolean(provider.apiKey),
+ apiKeyLength: provider.apiKey ? provider.apiKey.length : 0,
+ },
+ "Provider API key decryption result",
+ );
+ } else {
+ logger.debug(
+ { providerId: provider.id },
+ "No API key to decrypt for provider",
+ );
+ }
+ }
+
+ logger.info(
+ { count: allProviders.length },
+ "Providers fetched successfully",
+ );
+ return c.json(allProviders);
+ } catch (error) {
+ // Log the full error details for debugging
+ logger.error(
+ {
+ error,
+ message: error instanceof Error ? error.message : "Unknown error",
+ stack: error instanceof Error ? error.stack : undefined,
+ },
+ "Error fetching providers",
+ );
+
+ // Return a more informative error response
+ return c.json(
+ {
+ error: "Failed to fetch providers",
+ message:
+ error instanceof Error ? error.message : "Unknown database error",
+ },
+ 500,
+ );
+ }
+};
+
+/**
+ * Get a specific provider by ID
+ */
+export const getProvider = async (c: Context) => {
+ const { logger } = c.var;
+ const id = c.req.param("id");
+ logger.debug({ id }, "Fetching provider by ID");
+
+ try {
+ // @ts-ignore - Hono OpenAPI validation types are not properly recognized
+ const { id: paramId } = c.req.valid("param") as ValidatedParams;
+ if (id === paramId) {
+ const provider = await db.query.providers.findFirst({
+ where: eq(providers.id, Number.parseInt(id)),
+ with: {
+ models: true,
+ rateLimits: true,
+ },
+ });
+
+ if (!provider) {
+ logger.debug({ id }, "Provider not found");
+ return c.json({ error: "Provider not found" }, 404);
+ }
+
+ // Decrypt API key before returning to client
+ if (provider.apiKey) {
+ logger.debug(
+ {
+ providerId: id,
+ encryptedKeyLength: provider.apiKey.length,
+ },
+ "Decrypting API key for provider",
+ );
+
+ provider.apiKey = await decryptApiKey(provider.apiKey);
+
+ // Log the result of decryption (without showing the actual key)
+ logger.debug(
+ {
+ providerId: id,
+ apiKeyPresent: Boolean(provider.apiKey),
+ apiKeyLength: provider.apiKey ? provider.apiKey.length : 0,
+ },
+ "Provider API key decryption result",
+ );
+ } else {
+ logger.debug({ providerId: id }, "No API key to decrypt for provider");
+ }
+
+ logger.info({ providerId: id }, "Provider fetched successfully");
+ return c.json(provider);
+ }
+
+ // If ID mismatch, return not found (removed else clause)
+ logger.warn({ providerId: id }, "Provider not found");
+ return c.json({ error: "Provider not found" }, 404);
+ } catch (error) {
+ logger.error({ error, providerId: id }, "Error fetching provider");
+ return c.json({ error: "Failed to fetch provider" }, 500);
+ }
+};
+
+/**
+ * Validates provider data during creation
+ */
+const validateProviderCreate = (
+ data: ProviderCreate,
+): Record => {
+ const validationErrors: Record = {};
+
+ // Name validation
+ if (!data.name) {
+ validationErrors.name = ["Provider name is required"];
+ } else if (data.name.trim() === "") {
+ validationErrors.name = ["Provider name cannot be just whitespace"];
+ } else if (data.name.length < 3) {
+ validationErrors.name = [
+ "Provider name must be at least 3 characters long",
+ ];
+ }
+
+ // Kind validation
+ if (!data.kind) {
+ validationErrors.kind = ["Provider kind is required"];
+ } else if (data.kind.trim() === "") {
+ validationErrors.kind = ["Provider kind cannot be just whitespace"];
+ }
+
+ // Base URL validation
+ if (!data.baseUrl) {
+ validationErrors.baseUrl = ["Base URL is required"];
+ } else {
+ try {
+ new URL(data.baseUrl);
+ } catch (e) {
+ validationErrors.baseUrl = ["Base URL must be a valid URL"];
+ }
+ }
+
+ // Environment validation
+ if (!data.environment) {
+ validationErrors.environment = ["Environment is required"];
+ } else if (!["cloud", "local"].includes(data.environment)) {
+ validationErrors.environment = [
+ 'Environment must be either "cloud" or "local"',
+ ];
+ }
+
+ return validationErrors;
+};
+
+/**
+ * Handles specific error cases for provider creation
+ */
+const handleProviderCreateError = (c: Context, error: unknown) => {
+ const logger = c.var.logger;
+
+ logger.error(
+ {
+ error,
+ message: error instanceof Error ? error.message : "Unknown error",
+ stack: error instanceof Error ? error.stack : undefined,
+ },
+ "Error creating provider",
+ );
+
+ // Check for specific error types to provide better error messages
+ if (error instanceof Error) {
+ // Handle database unique constraint violation
+ if (
+ error.message.includes("duplicate key value violates unique constraint")
+ ) {
+ return c.json(
+ {
+ status: "error",
+ statusCode: 400,
+ message: "A provider with that name already exists",
+ details: {
+ type: "UniqueConstraintViolation",
+ },
+ },
+ 400,
+ );
+ }
+
+ // Handle other database errors
+ if (error.message.includes("database") || error.message.includes("query")) {
+ return c.json(
+ {
+ status: "error",
+ statusCode: 500,
+ message: "Database error occurred while creating provider",
+ details: {
+ type: "DatabaseError",
+ },
+ },
+ 500,
+ );
+ }
+
+ // Handle validation errors from Zod or other validators
+ if (error.message.includes("validation")) {
+ return c.json(
+ {
+ status: "error",
+ statusCode: 400,
+ message: "Validation error",
+ details: {
+ message: error.message,
+ },
+ },
+ 400,
+ );
+ }
+ }
+
+ // Generic error fallback
+ return c.json(
+ {
+ status: "error",
+ statusCode: 500,
+ message: "Failed to create provider",
+ details: error instanceof Error ? { message: error.message } : undefined,
+ },
+ 500,
+ );
+};
+
+/**
+ * Creates provider data in the database
+ */
+const saveProviderToDatabase = async (
+ tx: typeof db,
+ providerData: Omit,
+ models?: ProviderCreate["models"],
+ rateLimits?: ProviderCreate["rateLimits"],
+) => {
+ // Insert provider
+ const [provider] = await tx
+ .insert(providers)
+ .values({
+ ...providerData,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ })
+ .returning();
+
+ // Insert models if provided
+ if (models && models.length > 0) {
+ await tx.insert(providerModels).values(
+ models.map((model) => ({
+ providerId: provider.id,
+ name: model.name,
+ isEnabled: model.isEnabled ?? true,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ })),
+ );
+ }
+
+ // Insert rate limits if provided
+ if (rateLimits) {
+ await tx.insert(providerRateLimits).values({
+ providerId: provider.id,
+ requestsPerMinute: rateLimits.requestsPerMinute,
+ tokensPerMinute: rateLimits.tokensPerMinute,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ });
+ }
+
+ // Return the created provider with relations
+ return await tx.query.providers.findFirst({
+ where: eq(providers.id, provider.id),
+ with: {
+ models: true,
+ rateLimits: true,
+ },
+ });
+};
+
+/**
+ * Create a new provider
+ */
+export const createProvider = async (c: Context) => {
+ const { logger } = c.var;
+
+ try {
+ // @ts-ignore - Hono OpenAPI validation types are not properly recognized
+ const data = c.req.valid("json") as ProviderCreate;
+ logger.debug({ data }, "Creating new provider");
+
+ // Validate the provider data
+ const validationErrors = validateProviderCreate(data);
+
+ // If there are validation errors, return them
+ if (Object.keys(validationErrors).length > 0) {
+ logger.debug(
+ { validationErrors },
+ "Validation errors in provider creation",
+ );
+ return c.json(
+ {
+ status: "error",
+ statusCode: 400,
+ message: "Validation failed",
+ validationErrors,
+ },
+ 400,
+ );
+ }
+
+ // Extract models and rate limits if provided
+ const { models, rateLimits, ...providerData } = data;
+
+ // Encrypt API key if provided
+ if (providerData.apiKey) {
+ providerData.apiKey = await encryptApiKey(providerData.apiKey);
+ }
+
+ // Start a transaction
+ const result = await db.transaction(async (tx) => {
+ return saveProviderToDatabase(tx, providerData, models, rateLimits);
+ });
+
+ logger.info(
+ {
+ provider: {
+ id: result?.id,
+ name: result?.name,
+ kind: result?.kind,
+ },
+ },
+ "Provider created successfully",
+ );
+ return c.json(result, 201);
+ } catch (error) {
+ return handleProviderCreateError(c, error);
+ }
+};
+
+/**
+ * Validates provider data during update
+ */
+const validateProviderUpdate = (
+): Record => {
+ const validationErrors: Record = {};
+
+ // Add validation logic here if needed
+
+ return validationErrors;
+};
+
+/**
+ * Checks if a value has changed
+ */
+const hasValueChanged = (existingValue: unknown, newValue: unknown): boolean => {
+ return existingValue !== newValue && newValue !== undefined;
+};
+
+/**
+ * Creates a log entry for API key changes
+ */
+const createApiKeyLogEntry = (
+ existingValue: unknown,
+ value: unknown
+): { from: unknown; to: unknown } => {
+ return {
+ from: existingValue ? "[ENCRYPTED]" : "[EMPTY]",
+ to: value ? "[NEW_VALUE]" : "[EMPTY]",
+ };
+};
+
+/**
+ * Logs an API key change
+ */
+const logApiKeyChange = (
+ logger: Logger,
+ id: string,
+ existingValue: unknown,
+ value: unknown
+): void => {
+ logger.info(
+ { providerId: id },
+ `Updating API key from ${existingValue ? "existing value" : "empty"} to ${value ? "new value" : "empty"}`,
+ );
+};
+
+/**
+ * Logs all field changes
+ */
+const logAllFieldChanges = (
+ logger: Logger,
+ id: string,
+ existingProvider: Record,
+ updatedFields: Record
+): void => {
+ if (Object.keys(updatedFields).length > 0) {
+ logger.info(
+ {
+ providerId: id,
+ provider: existingProvider.name,
+ updatedFields,
+ },
+ "Provider fields being updated",
+ );
+ }
+};
+
+/**
+ * Logs field changes between existing provider and update data
+ */
+const logFieldChanges = (
+ logger: Logger,
+ id: string,
+ existingProvider: Record,
+ providerData: Partial
+): Record => {
+ const updatedFields: Record = {};
+
+ // Compare each field being updated with existing values
+ for (const [key, value] of Object.entries(providerData)) {
+ const existingValue = (existingProvider as Record)[key];
+
+ // Only process if the value is actually changing
+ if (hasValueChanged(existingValue, value)) {
+ // Special handling for API key to avoid logging actual values
+ if (key === "apiKey") {
+ updatedFields[key] = createApiKeyLogEntry(existingValue, value);
+ logApiKeyChange(logger, id, existingValue, value);
+ } else {
+ updatedFields[key] = { from: existingValue, to: value };
+ }
+ }
+ }
+
+ // Log all field changes
+ logAllFieldChanges(logger, id, existingProvider, updatedFields);
+
+ return updatedFields;
+};
+
+/**
+ * Processes model updates
+ */
+const processModelUpdates = async (
+ tx: typeof db,
+ id: string,
+ models: ProviderUpdate["models"],
+) => {
+ if (!models || models.length === 0) return;
+
+ // First, delete existing models
+ await tx
+ .delete(providerModels)
+ .where(eq(providerModels.providerId, Number.parseInt(id)));
+
+ // Then insert new models
+ await tx.insert(providerModels).values(
+ models.map((model) => {
+ // Create base model data object
+ const modelData = {
+ providerId: Number.parseInt(id),
+ name: model.name,
+ isEnabled: model.isEnabled ?? true,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ };
+
+ // Only include ID if it exists and is a number
+ if (model.id !== undefined && typeof model.id === "number") {
+ return {
+ ...modelData,
+ id: model.id,
+ };
+ }
+
+ // Let database auto-generate ID
+ return modelData;
+ }),
+ );
+};
+
+/**
+ * Processes rate limit updates
+ */
+const processRateLimitUpdates = async (
+ tx: typeof db,
+ id: string,
+ rateLimits: ProviderUpdate["rateLimits"],
+) => {
+ if (!rateLimits) return;
+
+ // Check if rate limits exist
+ const existingRateLimits = await tx.query.providerRateLimits.findFirst({
+ where: eq(providerRateLimits.providerId, Number.parseInt(id)),
+ });
+
+ if (existingRateLimits) {
+ // Update existing rate limits
+ await tx
+ .update(providerRateLimits)
+ .set({
+ requestsPerMinute: rateLimits.requestsPerMinute,
+ tokensPerMinute: rateLimits.tokensPerMinute,
+ updatedAt: new Date(),
+ })
+ .where(eq(providerRateLimits.providerId, Number.parseInt(id)));
+ } else {
+ // Insert new rate limits
+ await tx.insert(providerRateLimits).values({
+ providerId: Number.parseInt(id),
+ requestsPerMinute: rateLimits.requestsPerMinute,
+ tokensPerMinute: rateLimits.tokensPerMinute,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ });
+ }
+};
+
+/**
+ * Handles database updates for a provider
+ */
+const updateProviderInDatabase = async (
+ tx: typeof db,
+ id: string,
+ providerData: Partial,
+ models?: ProviderUpdate["models"],
+ rateLimits?: ProviderUpdate["rateLimits"],
+): Promise => {
+ // Get the current provider from the database to ensure we have the latest data
+ const currentProvider = await tx.query.providers.findFirst({
+ where: eq(providers.id, Number.parseInt(id)),
+ });
+
+ if (!currentProvider) {
+ throw new Error(`Provider not found with id ${id}`);
+ }
+
+ // Use the API key manager to handle API key updates
+ const apiKeyToUse = await processApiKeyForUpdate(
+ currentProvider,
+ providerData.apiKey,
+ );
+
+ // Update provider with the appropriate API key
+ await tx
+ .update(providers)
+ .set({
+ ...providerData,
+ apiKey: apiKeyToUse, // Use the properly determined API key
+ updatedAt: new Date(),
+ })
+ .where(eq(providers.id, Number.parseInt(id)));
+
+ // Update models if provided
+ if (models && models.length > 0) {
+ await processModelUpdates(tx, id, models);
+ }
+
+ // Update rate limits if provided
+ if (rateLimits) {
+ await processRateLimitUpdates(tx, id, rateLimits);
+ }
+
+ // Return the updated provider with relations
+ const result = await tx.query.providers.findFirst({
+ where: eq(providers.id, Number.parseInt(id)),
+ with: {
+ models: true,
+ rateLimits: true,
+ },
+ });
+
+ // Cast to ensure type safety
+ return result as Provider | null;
+};
+
+/**
+ * Handles specific error cases for provider updates
+ */
+const handleProviderUpdateError = (c: Context, error: unknown) => {
+ const logger = c.var.logger;
+ const id = c.req.param("id");
+
+ logger.error(
+ {
+ error,
+ message: error instanceof Error ? error.message : "Unknown error",
+ stack: error instanceof Error ? error.stack : undefined,
+ providerId: id,
+ },
+ "Error updating provider",
+ );
+
+ // Return error response
+ return c.json(
+ {
+ status: "error",
+ statusCode: 500,
+ message:
+ error instanceof Error ? error.message : "Failed to update provider",
+ errorType: error instanceof Error ? error.name : "Unknown",
+ },
+ 500,
+ );
+};
+
+/**
+ * Checks if a provider exists
+ */
+const checkProviderExists = async (id: string): Promise => {
+ const provider = await db.query.providers.findFirst({
+ where: eq(providers.id, Number.parseInt(id)),
+ });
+ return !!provider;
+};
+
+/**
+ * Fetches existing provider with models and rate limits
+ */
+const fetchExistingProvider = async (id: string): Promise => {
+ const provider = await db.query.providers.findFirst({
+ where: eq(providers.id, Number.parseInt(id)),
+ with: {
+ models: true,
+ rateLimits: true,
+ },
+ });
+
+ // Cast to ensure type safety
+ return provider as Provider | null;
+};
+
+/**
+ * Logs API key status for debugging
+ */
+const logApiKeyStatus = (
+ logger: Logger,
+ id: string,
+ existingProvider: Provider,
+ providerData: Partial
+): void => {
+ logger.debug(
+ {
+ providerId: id,
+ original_apiKey_present: existingProvider.apiKey !== null,
+ update_apiKey_present: "apiKey" in providerData,
+ update_apiKey_value_present:
+ providerData.apiKey !== undefined && providerData.apiKey !== null,
+ },
+ "API key update status",
+ );
+};
+
+/**
+ * Logs model updates
+ */
+const logModelUpdates = (
+ logger: Logger,
+ id: string,
+ models: ProviderUpdate["models"]
+): void => {
+ if (models && models.length > 0) {
+ logger.info(
+ { providerId: id, modelCount: models.length },
+ "Updating provider models",
+ );
+ }
+};
+
+/**
+ * Fetches and logs rate limit information
+ */
+const fetchAndLogRateLimits = async (
+ logger: Logger,
+ id: string,
+ existingProvider: Provider,
+ rateLimits: ProviderUpdate["rateLimits"]
+): Promise => {
+ if (!rateLimits) return;
+
+ // Query for existing rate limits
+ const existingRateLimitsQuery =
+ await db.query.providerRateLimits.findFirst({
+ where: eq(providerRateLimits.providerId, Number.parseInt(id)),
+ });
+
+ logger.info(
+ {
+ providerId: id,
+ provider: existingProvider.name,
+ existingRateLimits: existingRateLimitsQuery
+ ? {
+ requestsPerMinute: existingRateLimitsQuery.requestsPerMinute,
+ tokensPerMinute: existingRateLimitsQuery.tokensPerMinute,
+ }
+ : { requestsPerMinute: null, tokensPerMinute: null },
+ newRateLimits: rateLimits,
+ },
+ "Updating provider rate limits",
+ );
+};
+
+/**
+ * Log the result of a successful provider update
+ */
+const logSuccessfulUpdate = (
+ logger: Logger,
+ id: string,
+ result: Provider | null
+): void => {
+ logger.info(
+ {
+ providerId: id,
+ provider: {
+ name: result?.name,
+ kind: result?.kind,
+ },
+ },
+ "Provider updated successfully",
+ );
+};
+
+/**
+ * Update an existing provider
+ */
+export const updateProvider = async (c: Context) => {
+ const { logger } = c.var;
+ const id = c.req.param("id");
+
+ try {
+ // @ts-ignore - Hono OpenAPI validation types are not properly recognized
+ const data = c.req.valid("json") as ProviderUpdate;
+ logger.debug(
+ {
+ id,
+ data: {
+ ...data,
+ apiKey: data.apiKey !== undefined ? "[PRESENT]" : "[MISSING]",
+ },
+ },
+ "Updating provider",
+ );
+
+ // Check if provider exists
+ const existingProvider = await fetchExistingProvider(id);
+
+ if (!existingProvider) {
+ logger.debug({ id }, "Provider not found for update");
+ return c.json(
+ {
+ status: "error",
+ statusCode: 404,
+ message: "Provider not found",
+ providerId: id,
+ },
+ 404,
+ );
+ }
+
+ // Validate update data
+ const validationErrors = validateProviderUpdate(data);
+
+ // If there are validation errors, return them
+ if (Object.keys(validationErrors).length > 0) {
+ logger.debug(
+ { validationErrors },
+ "Validation errors in provider update",
+ );
+ return c.json(
+ {
+ status: "error",
+ statusCode: 400,
+ message: "Validation failed",
+ providerId: id,
+ validationErrors,
+ },
+ 400,
+ );
+ }
+
+ // Extract models and rate limits if provided
+ const { models, rateLimits, ...providerData } = data;
+
+ // Log API key status for debugging
+ logApiKeyStatus(logger, id, existingProvider, providerData);
+
+ // Log field changes
+ logFieldChanges(logger, id, existingProvider, providerData);
+
+ // Log model and rate limit changes
+ logModelUpdates(logger, id, models);
+
+ // Log rate limit changes if provided
+ await fetchAndLogRateLimits(logger, id, existingProvider, rateLimits);
+
+ // Update the provider in the database
+ const result = await db.transaction(async (tx) => {
+ return updateProviderInDatabase(tx, id, providerData, models, rateLimits);
+ });
+
+ // Log successful update (only if result is not null)
+ if (result) {
+ logSuccessfulUpdate(logger, id, result);
+ }
+
+ return c.json(result);
+ } catch (error) {
+ return handleProviderUpdateError(c, error);
+ }
+};
+
+/**
+ * Delete a provider
+ */
+export const deleteProvider = async (c: Context) => {
+ const { logger } = c.var;
+ const id = c.req.param("id");
+
+ logger.debug({ id }, "Deleting provider");
+
+ try {
+ // Check if provider exists
+ const providerExists = await checkProviderExists(id);
+
+ if (!providerExists) {
+ logger.debug({ id }, "Provider not found for deletion");
+ return c.json({ error: "Provider not found" }, 404);
+ }
+
+ // Delete provider (cascade will handle related records)
+ await db.delete(providers).where(eq(providers.id, Number.parseInt(id)));
+
+ logger.info({ providerId: id }, "Provider deleted successfully");
+ return c.json({
+ success: true,
+ message: "Provider deleted successfully",
+ });
+ } catch (error) {
+ logger.error({ error }, "Error deleting provider");
+ return c.json({ error: "Failed to delete provider" }, 500);
+ }
+};
diff --git a/servers/data_service/src/db/schema/providers.ts b/servers/data_service/src/features/provider_config/db_providers.ts
similarity index 95%
rename from servers/data_service/src/db/schema/providers.ts
rename to servers/data_service/src/features/provider_config/db_providers.ts
index 80552a18..780f998c 100644
--- a/servers/data_service/src/db/schema/providers.ts
+++ b/servers/data_service/src/features/provider_config/db_providers.ts
@@ -1,6 +1,6 @@
-// SPDX-License-Identifier: Apache-2.0
-import { pgTable, serial, text, boolean, integer, timestamp } from 'drizzle-orm/pg-core';
import { relations } from 'drizzle-orm';
+// SPDX-License-Identifier: Apache-2.0
+import { boolean, integer, pgTable, serial, text, timestamp } from 'drizzle-orm/pg-core';
/**
* Providers table schema
@@ -11,7 +11,6 @@ export const providers = pgTable('providers', {
name: text('name').notNull().unique(),
kind: text('kind').notNull(), // openai, gemini, etc.
environment: text('environment').notNull(), // cloud, local
- envVar: text('env_var').notNull(),
baseUrl: text('base_url').notNull(),
apiKey: text('api_key'), // Will store encrypted API key
isEnabled: boolean('is_enabled').default(true),
diff --git a/servers/data_service/src/api/routes/providers.ts b/servers/data_service/src/features/provider_config/routes.ts
similarity index 76%
rename from servers/data_service/src/api/routes/providers.ts
rename to servers/data_service/src/features/provider_config/routes.ts
index d2724668..85906778 100644
--- a/servers/data_service/src/api/routes/providers.ts
+++ b/servers/data_service/src/features/provider_config/routes.ts
@@ -7,9 +7,9 @@
import { OpenAPIHono, createRoute } from '@hono/zod-openapi';
import { z } from 'zod';
-import * as handlers from '../controllers/providers';
-import { providerSchema, providerCreateSchema, providerUpdateSchema, providerApiKeySchema } from '../schemas/providers';
-import { commonResponses, notFoundResponse, invalidRequestResponse, successResponse } from '../schemas/common';
+import { commonResponses, invalidRequestResponse, notFoundResponse, successResponse } from '../../api/schemas/common';
+import * as handlers from './controller';
+import { providerCreateSchema, providerSchema, providerUpdateSchema } from './schemas';
// Create a new OpenAPI router
const router = new OpenAPIHono();
@@ -149,46 +149,12 @@ const deleteProviderRoute = createRoute({
},
});
-const updateProviderApiKeyRoute = createRoute({
- method: 'put',
- path: '/{id}/api-key',
- tags: ['Providers'],
- summary: 'Update provider API key',
- description: 'Updates the API key for an existing provider',
- request: {
- params: z.object({
- id: z.string().min(1),
- }),
- body: {
- content: {
- 'application/json': {
- schema: providerApiKeySchema,
- },
- },
- },
- },
- responses: {
- 200: {
- description: 'API key updated successfully',
- content: {
- 'application/json': {
- schema: successResponse,
- },
- },
- },
- ...notFoundResponse,
- ...invalidRequestResponse,
- ...commonResponses,
- },
-});
-
// Register routes with handlers
router.openapi(getAllProvidersRoute, handlers.getProviders);
router.openapi(getProviderRoute, handlers.getProvider);
router.openapi(createProviderRoute, handlers.createProvider);
router.openapi(updateProviderRoute, handlers.updateProvider);
router.openapi(deleteProviderRoute, handlers.deleteProvider);
-router.openapi(updateProviderApiKeyRoute, handlers.updateProviderApiKey);
// Export the router
export const providerRoutes = router;
\ No newline at end of file
diff --git a/servers/data_service/src/api/schemas/providers.ts b/servers/data_service/src/features/provider_config/schemas.ts
similarity index 84%
rename from servers/data_service/src/api/schemas/providers.ts
rename to servers/data_service/src/features/provider_config/schemas.ts
index 89f2856f..0435e744 100644
--- a/servers/data_service/src/api/schemas/providers.ts
+++ b/servers/data_service/src/features/provider_config/schemas.ts
@@ -13,7 +13,6 @@ export const providerSchema = z.object({
name: z.string().min(1, 'Name is required'),
kind: z.string().min(1, 'Kind is required'),
environment: z.enum(['cloud', 'local']),
- envVar: z.string().min(1, 'Environment variable name is required'),
baseUrl: z.string().url('Must be a valid URL'),
apiKey: z.string().optional(),
isEnabled: z.boolean().default(true),
@@ -44,7 +43,6 @@ export const providerCreateSchema = z.object({
name: z.string().min(1, 'Name is required'),
kind: z.string().min(1, 'Kind is required'),
environment: z.enum(['cloud', 'local']),
- envVar: z.string().min(1, 'Environment variable name is required'),
baseUrl: z.string().url('Must be a valid URL'),
apiKey: z.string().optional(),
isEnabled: z.boolean().default(true),
@@ -65,12 +63,12 @@ export const providerUpdateSchema = z.object({
name: z.string().min(1, 'Name is required').optional(),
kind: z.string().min(1, 'Kind is required').optional(),
environment: z.enum(['cloud', 'local']).optional(),
- envVar: z.string().min(1, 'Environment variable name is required').optional(),
baseUrl: z.string().url('Must be a valid URL').optional(),
+ apiKey: z.string().optional(),
isEnabled: z.boolean().optional(),
models: z.array(
z.object({
- id: z.number().optional(),
+ id: z.number().or(z.string()).optional(),
name: z.string().min(1, 'Model name is required'),
isEnabled: z.boolean().default(true),
})
@@ -81,13 +79,7 @@ export const providerUpdateSchema = z.object({
}).optional(),
});
-// Schema for updating a provider's API key
-export const providerApiKeySchema = z.object({
- apiKey: z.string().min(1, 'API key is required'),
-});
-
// Export types
export type Provider = z.infer;
export type ProviderCreate = z.infer;
-export type ProviderUpdate = z.infer;
-export type ProviderApiKey = z.infer;
\ No newline at end of file
+export type ProviderUpdate = z.infer;
\ No newline at end of file
diff --git a/servers/data_service/src/db/seed/providers.ts b/servers/data_service/src/features/provider_config/seed_providers.ts
similarity index 96%
rename from servers/data_service/src/db/seed/providers.ts
rename to servers/data_service/src/features/provider_config/seed_providers.ts
index 2361d817..5d515bf4 100644
--- a/servers/data_service/src/db/seed/providers.ts
+++ b/servers/data_service/src/features/provider_config/seed_providers.ts
@@ -5,9 +5,9 @@
* This script seeds the database with predefined provider configurations.
*/
-import { db } from '../index';
-import { providers, providerModels} from '../schema';
import { eq } from 'drizzle-orm';
+import { db } from '../../db/index';
+import { providerModels, providers } from '../../db/schema';
import { logger } from '../../utils/logger';
// Define interfaces for provider configurations
@@ -105,7 +105,6 @@ export async function seedProviders() {
name,
kind: providerConfig.kind,
environment: providerConfig.environment,
- envVar: providerConfig.env_var,
baseUrl: providerConfig.base_url,
isEnabled: true
}).returning();
diff --git a/servers/data_service/src/utils/error-handler.ts b/servers/data_service/src/utils/error-handler.ts
new file mode 100644
index 00000000..42f60692
--- /dev/null
+++ b/servers/data_service/src/utils/error-handler.ts
@@ -0,0 +1,142 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Error Handler
+ *
+ * Utility for handling and formatting errors in a consistent way.
+ */
+
+import type { Context } from "hono";
+import { ZodError } from "zod";
+import { logger } from "./logger";
+
+interface ErrorResponse {
+ status: "error";
+ statusCode: number;
+ message: string;
+ timestamp: string;
+ path?: string;
+ details?: unknown;
+ validationErrors?: Record;
+}
+
+/**
+ * Creates a standardized error response object
+ */
+export function createErrorResponse(
+ message: string,
+ statusCode = 400,
+ details?: unknown,
+ path?: string,
+): ErrorResponse {
+ return {
+ status: "error",
+ statusCode,
+ message,
+ timestamp: new Date().toISOString(),
+ path,
+ details,
+ };
+}
+
+/**
+ * Handles validation errors from Zod
+ */
+export function handleValidationError(error: ZodError, c: Context): Response {
+ const validationErrors: Record = {};
+
+ for (const err of error.errors) {
+ const path = err.path.join(".");
+ if (!validationErrors[path]) {
+ validationErrors[path] = [];
+ }
+ validationErrors[path].push(err.message);
+ }
+
+ const response = createErrorResponse(
+ "Validation error",
+ 400,
+ undefined,
+ c.req.path,
+ );
+
+ response.validationErrors = validationErrors;
+
+ logger.debug({ validationErrors }, "Validation errors");
+
+ return c.json(response, 400);
+}
+
+/**
+ * Handles general application errors
+ */
+export function handleApplicationError(error: unknown, c: Context): Response {
+ if (error instanceof ZodError) {
+ return handleValidationError(error, c);
+ }
+
+ const statusCode = 500;
+ let message = "Internal server error";
+ let details = undefined;
+
+ if (error instanceof Error) {
+ message = error.message;
+ details = {
+ name: error.name,
+ stack: process.env.NODE_ENV === "development" ? error.stack : undefined,
+ };
+ } else if (typeof error === "string") {
+ message = error;
+ } else if (typeof error === "object" && error !== null) {
+ message = "Application error";
+ details = error;
+ }
+
+ logger.error({ error, path: c.req.path }, message);
+
+ const response = createErrorResponse(
+ message,
+ statusCode,
+ details,
+ c.req.path,
+ );
+ return c.json(response, statusCode);
+}
+
+/**
+ * Error handling middleware for Hono
+ */
+export function errorHandlerMiddleware(options: { logErrors?: boolean } = {}) {
+ return async (c: Context, next: () => Promise) => {
+ try {
+ await next();
+ } catch (error) {
+ if (options.logErrors !== false) {
+ logger.error(
+ {
+ error,
+ path: c.req.path,
+ method: c.req.method,
+ headers: Object.fromEntries(c.req.headers.entries()),
+ },
+ "Error caught in middleware",
+ );
+ }
+
+ return handleApplicationError(error, c);
+ }
+ };
+}
+
+/**
+ * Not found error handler for Hono
+ */
+export function notFoundHandler(c: Context) {
+ const response = createErrorResponse(
+ `Route not found: ${c.req.method} ${c.req.path}`,
+ 404,
+ undefined,
+ c.req.path,
+ );
+
+ return c.json(response, 404);
+}
diff --git a/servers/data_service/src/utils/logger.ts b/servers/data_service/src/utils/logger.ts
index 8371dc8b..869e87cf 100644
--- a/servers/data_service/src/utils/logger.ts
+++ b/servers/data_service/src/utils/logger.ts
@@ -8,6 +8,19 @@
import pino from 'pino';
import { env } from '../env';
+// Define request and response types for serializers
+interface PinoRequest {
+ method: string;
+ url: string;
+ headers: Record;
+ body?: unknown;
+}
+
+interface PinoResponse {
+ statusCode: number;
+ getHeaders?: () => Record;
+}
+
// Configure logger based on environment
const loggerConfig = {
level: env.NODE_ENV === 'production' ? 'info' : 'debug',
@@ -17,9 +30,41 @@ const loggerConfig = {
options: {
colorize: true,
translateTime: 'SYS:standard',
+ ignore: 'pid,hostname',
+ messageFormat: '{msg} {reqId} {req.method} {req.url}',
},
}
: undefined,
+ // Add request ID generation - helps to correlate logs from the same request
+ genReqId: (req: PinoRequest) => {
+ // If we have an existing ID in headers, use it
+ if (req.headers?.['x-request-id']) {
+ return req.headers['x-request-id'];
+ }
+ return crypto.randomUUID();
+ },
+ // Return request/response serializers - these functions control what gets logged
+ serializers: {
+ req: (req: PinoRequest) => ({
+ method: req.method,
+ url: req.url,
+ headers: req.headers,
+ // Skip body for GET/HEAD requests where there shouldn't be any
+ ...(req.method !== 'GET' && req.method !== 'HEAD' && { body: req.body }),
+ }),
+ res: (res: PinoResponse) => ({
+ status: res.statusCode,
+ headers: res.getHeaders?.() || {},
+ }),
+ err: pino.stdSerializers.err,
+ },
+ // Ensure timestamps are present
+ timestamp: pino.stdTimeFunctions.isoTime,
+ // Base object included in every log
+ base: {
+ service: 'graphcap-data-service',
+ env: env.NODE_ENV,
+ },
};
// Create and export the logger instance
diff --git a/servers/data_service/src/utils/pino-middleware.ts b/servers/data_service/src/utils/pino-middleware.ts
new file mode 100644
index 00000000..d47b6965
--- /dev/null
+++ b/servers/data_service/src/utils/pino-middleware.ts
@@ -0,0 +1,189 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Pino Logger Middleware for Hono
+ *
+ * Sets up structured logging with hono-pino.
+ */
+
+import type { Context } from "hono";
+import { pinoLogger } from "hono-pino";
+import type { Logger } from "pino";
+import { logger } from "./logger";
+
+// Define our logger property in the Hono context variables
+declare module "hono" {
+ interface ContextVariableMap {
+ logger: Logger;
+ }
+}
+
+/**
+ * Create a customized Pino logger middleware for Hono
+ *
+ * This middleware will log all requests and responses with structured data
+ * and provide a child logger instance accessible via c.var.logger.
+ */
+export const createPinoLoggerMiddleware = () => {
+ // Return the pino logger middleware
+ return pinoLogger({
+ pino: logger,
+ });
+};
+
+/**
+ * Get safe query parameters with error handling
+ */
+const getSafeQueryParams = (c: Context): Record => {
+ try {
+ return c.req.query();
+ } catch (e) {
+ logger.debug({ error: e }, "Failed to get query parameters");
+ return {}; // Fallback to empty object if query() throws an error
+ }
+};
+
+/**
+ * Extract and parse request body based on content type
+ */
+const parseRequestBody = async (clonedReq: Request, contentType: string): Promise => {
+ if (contentType.includes("application/json")) {
+ try {
+ return await clonedReq.json();
+ } catch (jsonError) {
+ logger.debug({ error: jsonError }, "Failed to parse JSON body");
+ return "[Unparseable JSON]";
+ }
+ }
+
+ if (contentType.includes("multipart/form-data")) {
+ return "[Multipart form data]";
+ }
+
+ if (contentType.includes("application/x-www-form-urlencoded")) {
+ try {
+ return Object.fromEntries(await clonedReq.formData());
+ } catch (formError) {
+ logger.debug({ error: formError }, "Failed to parse form data");
+ return "[Unparseable form data]";
+ }
+ }
+
+ // Default to text handling
+ try {
+ const textBody = await clonedReq.text();
+ return textBody.length > 1000 ? `${textBody.substring(0, 1000)}... [truncated]` : textBody;
+ } catch (textError) {
+ logger.debug({ error: textError }, "Failed to get text body");
+ return "[Unreadable text body]";
+ }
+};
+
+/**
+ * Extract request body with proper error handling
+ */
+const getRequestBody = async (c: Context, method: string): Promise<[unknown, boolean]> => {
+ // Skip for GET and HEAD requests
+ if (method === "GET" || method === "HEAD") {
+ return [null, true];
+ }
+
+ try {
+ // Check if the request can be cloned
+ if (c.req.raw.clone && typeof c.req.raw.clone === 'function') {
+ const clonedReq = c.req.raw.clone();
+ const contentType = c.req.header("content-type") ?? "";
+ const body = await parseRequestBody(clonedReq, contentType);
+ return [body, true];
+ }
+
+ // If request cloning is not supported
+ logger.debug("Request body logging skipped - Request.clone() not supported");
+ return ["[Body logging disabled - clone not supported]", false];
+ } catch (e) {
+ logger.debug({ error: e }, "Error while attempting to read request body");
+ return ["[Error reading request body]", true];
+ }
+};
+
+/**
+ * Log request information
+ */
+const logRequest = (
+ method: string,
+ url: string,
+ path: string,
+ queryParams: Record,
+ headers: Record,
+ body: unknown,
+ bodyReadable: boolean
+) => {
+ logger.info(
+ {
+ type: "request",
+ method,
+ url,
+ path,
+ query: queryParams,
+ headers,
+ body,
+ bodyReadable,
+ },
+ "API Request",
+ );
+};
+
+/**
+ * Log response information
+ */
+const logResponse = (
+ method: string,
+ url: string,
+ path: string,
+ status: number | undefined,
+ headers: Headers | undefined,
+ responseTime: number
+) => {
+ logger.info(
+ {
+ type: "response",
+ method,
+ url,
+ path,
+ status,
+ headers,
+ responseTime,
+ body: "[Response body not captured]",
+ },
+ "API Response",
+ );
+};
+
+/**
+ * Detailed logging middleware
+ *
+ * This middleware captures and logs detailed request/response information
+ * separately from the pino logger middleware.
+ */
+export const createDetailedLoggingMiddleware = () => {
+ return async (c: Context, next: () => Promise) => {
+ // Extract basic request information
+ const { method } = c.req;
+ const url = c.req.url;
+ const path = c.req.path;
+
+ // Get request components
+ const queryParams = getSafeQueryParams(c);
+ const [reqBody, bodyReadable] = await getRequestBody(c, method);
+
+ // Log the request
+ logRequest(method, url, path, queryParams, c.req.header(), reqBody, bodyReadable);
+
+ // Process the request and measure response time
+ const startTime = Date.now();
+ await next();
+ const responseTime = Date.now() - startTime;
+
+ // Log response details
+ logResponse(method, url, path, c.res?.status, c.res?.headers, responseTime);
+ };
+};
diff --git a/servers/inference_server/README.md b/servers/inference_bridge/README.md
similarity index 100%
rename from servers/inference_server/README.md
rename to servers/inference_bridge/README.md
diff --git a/servers/inference_server/__init__.py b/servers/inference_bridge/__init__.py
similarity index 100%
rename from servers/inference_server/__init__.py
rename to servers/inference_bridge/__init__.py
diff --git a/servers/inference_server/graphcap/__init__.py b/servers/inference_bridge/graphcap/__init__.py
similarity index 100%
rename from servers/inference_server/graphcap/__init__.py
rename to servers/inference_bridge/graphcap/__init__.py
diff --git a/servers/inference_server/graphcap/perspectives/__init__.py b/servers/inference_bridge/graphcap/perspectives/__init__.py
similarity index 97%
rename from servers/inference_server/graphcap/perspectives/__init__.py
rename to servers/inference_bridge/graphcap/perspectives/__init__.py
index fffa256b..3b6a3068 100644
--- a/servers/inference_server/graphcap/perspectives/__init__.py
+++ b/servers/inference_bridge/graphcap/perspectives/__init__.py
@@ -5,23 +5,18 @@
Provides utilities for working with different perspectives/views of data.
"""
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
from loguru import logger
-from .constants import WORKSPACE_PERSPECTIVES_DIR
from .perspective_loader import (
- # Classes
JsonPerspectiveProcessor,
ModuleConfig,
PerspectiveConfig,
PerspectiveModule,
PerspectiveSettings,
- # Models
SchemaField,
get_all_modules,
- # Functions
get_perspective_directories,
load_all_perspectives,
load_module_settings,
@@ -55,7 +50,7 @@
logger.exception(e)
-def get_perspective(perspective_name: str, **kwargs):
+def get_perspective(perspective_name: str):
"""
Get a perspective processor by name.
diff --git a/servers/inference_server/graphcap/perspectives/base.py b/servers/inference_bridge/graphcap/perspectives/base.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/base.py
rename to servers/inference_bridge/graphcap/perspectives/base.py
diff --git a/servers/inference_bridge/graphcap/perspectives/base_caption.py b/servers/inference_bridge/graphcap/perspectives/base_caption.py
new file mode 100644
index 00000000..d6644611
--- /dev/null
+++ b/servers/inference_bridge/graphcap/perspectives/base_caption.py
@@ -0,0 +1,271 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Base Caption Module
+
+Provides base classes and shared functionality for different caption types.
+"""
+
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, Optional, cast
+
+from loguru import logger
+from pydantic import BaseModel
+from rich.console import Console
+from rich.table import Table
+
+from ..providers.clients.base_client import BaseClient
+from .types import StructuredVisionConfig
+
+# Initialize Rich console
+console = Console()
+
+
+def pretty_print_caption(caption_data: Dict[str, Any]) -> str:
+ """Format caption data for pretty console output."""
+ return json.dumps(caption_data["parsed"], indent=2, ensure_ascii=False)
+
+
+class CaptionError(Exception):
+ """Base exception for caption processing errors."""
+ pass
+
+
+class CaptionParsingError(CaptionError):
+ """Exception raised when parsing JSON response fails."""
+ pass
+
+
+class CaptionProcessingError(CaptionError):
+ """Exception raised when processing an image fails."""
+ pass
+
+
+class BaseCaptionProcessor(ABC):
+ """
+ Base class for caption processors.
+
+ Provides shared functionality for processing images with vision models
+ and handling responses. Subclasses implement specific caption formats.
+
+ Attributes:
+ config_name (str): Name of this caption processor
+ version (str): Version of the processor
+ prompt (str): Instruction prompt for the vision model
+ schema (BaseModel): Pydantic model for response validation
+ """
+
+ def __init__(
+ self,
+ config_name: str,
+ version: str,
+ prompt: str,
+ schema: type[BaseModel],
+ ):
+ self.vision_config = StructuredVisionConfig(
+ config_name=config_name,
+ version=version,
+ prompt=prompt,
+ schema=schema,
+ )
+
+ def _sanitize_json_string(self, text: str) -> str:
+ """
+ Sanitize JSON string by properly escaping control characters.
+
+ Args:
+ text: Raw JSON string that may contain control characters
+
+ Returns:
+ Sanitized JSON string with properly escaped control characters
+ """
+ # Define escape sequences for common control characters
+ control_char_map = {
+ "\n": "\\n", # Line feed
+ "\r": "\\r", # Carriage return
+ "\t": "\\t", # Tab
+ "\b": "\\b", # Backspace
+ "\f": "\\f", # Form feed
+ "\v": "\\u000b", # Vertical tab
+ "\0": "", # Null character - remove it
+ }
+
+ # First pass: handle known control characters
+ for char, escape_seq in control_char_map.items():
+ text = text.replace(char, escape_seq)
+
+ # Second pass: handle any remaining control characters
+ result = ""
+ for char in text:
+ if ord(char) < 32: # Control characters are below ASCII 32
+ result += f"\\u{ord(char):04x}"
+ else:
+ result += char
+
+ return result
+
+ def _build_prompt_with_context(
+ self, context: list[str] | None = None, global_context: str | None = None
+ ) -> str:
+ """
+ Build the prompt with optional context.
+
+ Args:
+ context: List of context strings
+ global_context: Global context string
+
+ Returns:
+ The complete prompt with context if provided
+ """
+ if not context and not global_context:
+ return self.vision_config.prompt
+
+ context_block = " Consider the following context when generating the caption:\n"
+
+ if global_context:
+ context_block += f"\n{global_context}\n\n"
+
+ if context:
+ for entry in context:
+ context_block += f"\n{entry}\n\n"
+
+ context_block += "\n"
+ return f"{context_block}{self.vision_config.prompt}"
+
+ def _parse_completion_result(self, completion: Any) -> Dict[str, Any]:
+ """
+ Parse the completion result into a standardized format.
+
+ Args:
+ completion: The completion response from the vision model
+
+ Returns:
+ Parsed result as a dictionary
+
+ Raises:
+ json.JSONDecodeError: If JSON parsing fails
+ """
+ # Handle BaseModel responses through duck typing
+ if hasattr(completion, 'choices') and hasattr(completion.choices[0], 'message'):
+ result = completion.choices[0].message.parsed
+ if hasattr(result, 'model_dump'):
+ return result.model_dump()
+ return cast(Dict[str, Any], result)
+
+ result = completion.choices[0].message.parsed
+
+ # Handle string responses
+ if isinstance(result, str):
+ sanitized = self._sanitize_json_string(result)
+ return json.loads(sanitized)
+
+ # Handle nested structure responses
+ if isinstance(result, dict):
+ if "choices" in result:
+ return cast(Dict[str, Any], result["choices"][0]["message"]["parsed"]["parsed"])
+
+ if "message" in result:
+ return cast(Dict[str, Any], result["message"]["parsed"])
+
+ return cast(Dict[str, Any], result)
+
+ @abstractmethod
+ def create_rich_table(self, caption_data: Dict[str, Any]) -> Table:
+ """
+ Create a Rich table for displaying caption data.
+
+ Args:
+ caption_data: The caption data to format
+
+ Returns:
+ Rich Table object for display
+ """
+ pass
+
+ async def process_single(
+ self,
+ provider: BaseClient,
+ image_path: Path,
+ model: str,
+ max_tokens: Optional[int] = 4096,
+ temperature: Optional[float] = 0.8,
+ top_p: Optional[float] = 0.9,
+ repetition_penalty: Optional[float] = 1.15,
+ context: list[str] | None = None,
+ global_context: str | None = None,
+ ) -> Dict[str, Any]:
+ """
+ Process a single image and return caption data.
+
+ Args:
+ provider: Vision AI provider client instance
+ image_path: Path to the image file
+ model: Model name to use for processing
+ max_tokens: Maximum tokens for model response
+ temperature: Sampling temperature
+ top_p: Nucleus sampling parameter
+ repetition_penalty: Repetition penalty parameter
+ context: List of context strings
+ global_context: Global context string
+
+ Returns:
+ dict: Structured caption data according to schema
+
+ Raises:
+ CaptionParsingError: If parsing the JSON response fails
+ CaptionProcessingError: If processing the image fails
+ """
+ try:
+ # Build prompt with context if provided
+ prompt = self._build_prompt_with_context(context, global_context)
+
+ # Handle optional parameters with defaults
+ tokens = 4096 if max_tokens is None else max_tokens
+ temp = 0.8 if temperature is None else temperature
+ nucleus = 0.9 if top_p is None else top_p
+ rep_penalty = 1.15 if repetition_penalty is None else repetition_penalty
+
+ # Process image with vision model
+ completion = await provider.vision(
+ prompt=prompt,
+ image=image_path,
+ schema=self.vision_config.schema,
+ model=model,
+ max_tokens=tokens,
+ temperature=temp,
+ top_p=nucleus,
+ repetition_penalty=rep_penalty,
+ )
+
+ # Parse the completion result
+ return self._parse_completion_result(completion)
+
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse JSON response: {e}")
+ raise CaptionParsingError(f"Error parsing response for {image_path}: {str(e)}")
+ except Exception as e:
+ raise CaptionProcessingError(f"Error processing {image_path}: {str(e)}")
+
+ # Note: process_batch has been removed as batch processing is being migrated to Kafka.
+ # Batch processing functionality should now be implemented in Kafka-based pipeline components.
+
+ @abstractmethod
+ def to_table(self, caption_data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Convert caption data to a flat dictionary suitable for tabular representation.
+
+ Args:
+ caption_data: The caption data to format
+
+ Returns:
+ Dict[str, Any]: Flattened dictionary for tabular representation
+ """
+ pass
+
+ @abstractmethod
+ def to_context(self, caption_data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Convert caption data to a context string suitable for downstream perspectives.
+ """
+ pass
diff --git a/servers/inference_server/graphcap/perspectives/constants.py b/servers/inference_bridge/graphcap/perspectives/constants.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/constants.py
rename to servers/inference_bridge/graphcap/perspectives/constants.py
diff --git a/servers/inference_server/graphcap/perspectives/loaders/__init__.py b/servers/inference_bridge/graphcap/perspectives/loaders/__init__.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/loaders/__init__.py
rename to servers/inference_bridge/graphcap/perspectives/loaders/__init__.py
diff --git a/servers/inference_server/graphcap/perspectives/loaders/directory.py b/servers/inference_bridge/graphcap/perspectives/loaders/directory.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/loaders/directory.py
rename to servers/inference_bridge/graphcap/perspectives/loaders/directory.py
diff --git a/servers/inference_server/graphcap/perspectives/loaders/json_file.py b/servers/inference_bridge/graphcap/perspectives/loaders/json_file.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/loaders/json_file.py
rename to servers/inference_bridge/graphcap/perspectives/loaders/json_file.py
diff --git a/servers/inference_server/graphcap/perspectives/loaders/modules.py b/servers/inference_bridge/graphcap/perspectives/loaders/modules.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/loaders/modules.py
rename to servers/inference_bridge/graphcap/perspectives/loaders/modules.py
diff --git a/servers/inference_server/graphcap/perspectives/loaders/settings.py b/servers/inference_bridge/graphcap/perspectives/loaders/settings.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/loaders/settings.py
rename to servers/inference_bridge/graphcap/perspectives/loaders/settings.py
diff --git a/servers/inference_server/graphcap/perspectives/models.py b/servers/inference_bridge/graphcap/perspectives/models.py
similarity index 95%
rename from servers/inference_server/graphcap/perspectives/models.py
rename to servers/inference_bridge/graphcap/perspectives/models.py
index 09631863..a7230e62 100644
--- a/servers/inference_server/graphcap/perspectives/models.py
+++ b/servers/inference_bridge/graphcap/perspectives/models.py
@@ -8,9 +8,6 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
-from typing_extensions import override
-
-from .base import PerspectiveData
class SchemaField(BaseModel):
diff --git a/servers/inference_server/graphcap/perspectives/module.py b/servers/inference_bridge/graphcap/perspectives/module.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/module.py
rename to servers/inference_bridge/graphcap/perspectives/module.py
diff --git a/servers/inference_server/graphcap/perspectives/perspective_loader.py b/servers/inference_bridge/graphcap/perspectives/perspective_loader.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/perspective_loader.py
rename to servers/inference_bridge/graphcap/perspectives/perspective_loader.py
diff --git a/servers/inference_server/graphcap/perspectives/processor.py b/servers/inference_bridge/graphcap/perspectives/processor.py
similarity index 100%
rename from servers/inference_server/graphcap/perspectives/processor.py
rename to servers/inference_bridge/graphcap/perspectives/processor.py
diff --git a/servers/inference_server/graphcap/perspectives/types.py b/servers/inference_bridge/graphcap/perspectives/types.py
similarity index 95%
rename from servers/inference_server/graphcap/perspectives/types.py
rename to servers/inference_bridge/graphcap/perspectives/types.py
index 9bc18e70..d41af3f4 100644
--- a/servers/inference_server/graphcap/perspectives/types.py
+++ b/servers/inference_bridge/graphcap/perspectives/types.py
@@ -26,4 +26,4 @@ class StructuredVisionConfig:
config_name: str
version: str
prompt: str
- schema: BaseModel
+ schema: type[BaseModel]
diff --git a/servers/inference_server/graphcap/providers/__init__.py b/servers/inference_bridge/graphcap/providers/__init__.py
similarity index 52%
rename from servers/inference_server/graphcap/providers/__init__.py
rename to servers/inference_bridge/graphcap/providers/__init__.py
index 950cb1d5..57392b17 100644
--- a/servers/inference_server/graphcap/providers/__init__.py
+++ b/servers/inference_bridge/graphcap/providers/__init__.py
@@ -14,6 +14,23 @@
Components:
clients: Provider-specific client implementations
- provider_config: Configuration management
- provider_manager: Provider lifecycle management
+ factory: Provider client factory
+ types: Common type definitions
"""
+
+from .factory import (
+ ProviderFactory,
+ clear_provider_cache,
+ create_provider_client,
+ get_provider_factory,
+)
+from .types import ProviderConfig, RateLimits
+
+__all__ = [
+ "ProviderFactory",
+ "create_provider_client",
+ "get_provider_factory",
+ "clear_provider_cache",
+ "ProviderConfig",
+ "RateLimits",
+]
diff --git a/servers/inference_server/graphcap/providers/clients/__init__.py b/servers/inference_bridge/graphcap/providers/clients/__init__.py
similarity index 100%
rename from servers/inference_server/graphcap/providers/clients/__init__.py
rename to servers/inference_bridge/graphcap/providers/clients/__init__.py
diff --git a/servers/inference_server/graphcap/providers/clients/base_client.py b/servers/inference_bridge/graphcap/providers/clients/base_client.py
similarity index 93%
rename from servers/inference_server/graphcap/providers/clients/base_client.py
rename to servers/inference_bridge/graphcap/providers/clients/base_client.py
index f2401492..9f77e88a 100644
--- a/servers/inference_server/graphcap/providers/clients/base_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/base_client.py
@@ -19,12 +19,10 @@
environment (str): Deployment environment
env_var (str): Environment variable for API key
base_url (str): Base API URL
- default_model (str): Default model identifier
"""
import asyncio
import base64
-import os
import time
from abc import ABC, abstractmethod
from pathlib import Path
@@ -38,15 +36,7 @@
class BaseClient(AsyncOpenAI, ABC):
"""Abstract base class for all provider clients"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
- # Check for required environment variable
- if env_var and env_var != "NONE":
- api_key = os.getenv(env_var)
- if api_key is None:
- raise ValueError(f"Environment variable {env_var} is not set")
- else:
- api_key = "stub_key"
-
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str):
# Initialize OpenAI client
super().__init__(api_key=api_key, base_url=base_url)
@@ -54,9 +44,7 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur
self.name = name
self.kind = kind
self.environment = environment
- self.env_var = env_var
self.base_url = base_url
- self.default_model = default_model
# Rate limiting state
self._request_times: list[float] = []
diff --git a/servers/inference_server/graphcap/providers/clients/gemini_client.py b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py
similarity index 88%
rename from servers/inference_server/graphcap/providers/clients/gemini_client.py
rename to servers/inference_bridge/graphcap/providers/clients/gemini_client.py
index 4e232c56..16aa6dc3 100644
--- a/servers/inference_server/graphcap/providers/clients/gemini_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/gemini_client.py
@@ -15,7 +15,6 @@
GeminiClient: Gemini API client implementation
"""
-import time
from typing import Any
from loguru import logger
@@ -27,22 +26,19 @@
class GeminiClient(BaseClient):
"""Client for Google's Gemini API with OpenAI compatibility layer"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str):
logger.info(f"GeminiClient initialized with base_url: {base_url}")
super().__init__(
name=name,
kind=kind,
environment=environment,
- env_var=env_var,
base_url=base_url.rstrip("/"),
- default_model=default_model,
+ api_key=api_key,
)
def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]:
"""Format vision content for Gemini API"""
# TODO: Add feature flag to handle gemini free tier rate limits instead of this hack
- logger.info("Sleeping for 3 seconds to avoid rate limits")
- time.sleep(3)
return [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}},
diff --git a/servers/inference_server/graphcap/providers/clients/ollama_client.py b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py
similarity index 91%
rename from servers/inference_server/graphcap/providers/clients/ollama_client.py
rename to servers/inference_bridge/graphcap/providers/clients/ollama_client.py
index fe42c180..3a76c3ff 100644
--- a/servers/inference_server/graphcap/providers/clients/ollama_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/ollama_client.py
@@ -26,13 +26,12 @@
class OllamaClient(BaseClient):
"""Client for Ollama API with OpenAI compatibility layer"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str = "stub_key"):
logger.info("Initializing OllamaClient:")
logger.info(f" - name: {name}")
logger.info(f" - kind: {kind}")
logger.info(f" - environment: {environment}")
logger.info(f" - base_url: {base_url}")
- logger.info(f" - default_model: {default_model}")
# Store the raw base URL for Ollama-specific endpoints
base_url = base_url.rstrip("/")
@@ -56,9 +55,8 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur
name=name,
kind=kind,
environment=environment,
- env_var=env_var,
base_url=openai_base_url,
- default_model=default_model,
+ api_key=api_key,
)
logger.debug(f"OllamaClient initialized with environment: {environment}, kind: {kind}")
logger.debug(f"Using base URL {self._raw_base_url} for Ollama endpoints")
@@ -77,7 +75,6 @@ async def get_models(self):
try:
logger.info("Fetching models from Ollama:")
logger.info(f" - URL: {self._raw_base_url}/models")
- logger.info(f" - Default model: {self.default_model}")
async with httpx.AsyncClient() as client:
response = await client.get(f"{self._raw_base_url}/models")
@@ -90,7 +87,6 @@ async def get_models(self):
logger.error("Connection error while fetching models from Ollama:")
logger.error(f" - Error: {str(e)}")
logger.error(f" - URL: {self._raw_base_url}/models")
- logger.error(f" - Default model: {self.default_model}")
raise
except Exception as e:
logger.error(f"Failed to get models from Ollama: {str(e)}")
diff --git a/servers/inference_server/graphcap/providers/clients/openai_client.py b/servers/inference_bridge/graphcap/providers/clients/openai_client.py
similarity index 95%
rename from servers/inference_server/graphcap/providers/clients/openai_client.py
rename to servers/inference_bridge/graphcap/providers/clients/openai_client.py
index 46eebda0..5459b02f 100644
--- a/servers/inference_server/graphcap/providers/clients/openai_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/openai_client.py
@@ -29,15 +29,14 @@
class OpenAIClient(BaseClient):
"""Client for OpenAI API"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str):
logger.info(f"OpenAIClient initialized with base_url: {base_url}")
super().__init__(
name=name,
kind=kind,
environment=environment,
- env_var=env_var,
base_url=base_url.rstrip("/"),
- default_model=default_model,
+ api_key=api_key,
)
def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]:
diff --git a/servers/inference_server/graphcap/providers/clients/openrouter_client.py b/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py
similarity index 92%
rename from servers/inference_server/graphcap/providers/clients/openrouter_client.py
rename to servers/inference_bridge/graphcap/providers/clients/openrouter_client.py
index 031a9e51..b6c220f3 100644
--- a/servers/inference_server/graphcap/providers/clients/openrouter_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/openrouter_client.py
@@ -25,17 +25,22 @@
class OpenRouterClient(BaseClient):
- """Client for OpenRouter API with OpenAI compatibility layer"""
+ """Client for OpenRouter API"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str):
logger.info(f"OpenRouterClient initialized with base_url: {base_url}")
+
+ # Base URL handling for OpenRouter
+ if not base_url.endswith("/v1"):
+ base_url = f"{base_url}/v1"
+ logger.info(f"Added /v1 to base URL: {base_url}")
+
super().__init__(
name=name,
kind=kind,
environment=environment,
- env_var=env_var,
base_url=base_url.rstrip("/"),
- default_model=default_model,
+ api_key=api_key,
)
def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]:
diff --git a/servers/inference_server/graphcap/providers/clients/vllm_client.py b/servers/inference_bridge/graphcap/providers/clients/vllm_client.py
similarity index 93%
rename from servers/inference_server/graphcap/providers/clients/vllm_client.py
rename to servers/inference_bridge/graphcap/providers/clients/vllm_client.py
index 26de0d9f..6789783d 100644
--- a/servers/inference_server/graphcap/providers/clients/vllm_client.py
+++ b/servers/inference_bridge/graphcap/providers/clients/vllm_client.py
@@ -26,9 +26,9 @@
class VLLMClient(BaseClient):
- """Client for VLLM API with OpenAI compatibility layer"""
+ """Client for vLLM API"""
- def __init__(self, name: str, kind: str, environment: str, env_var: str, base_url: str, default_model: str):
+ def __init__(self, name: str, kind: str, environment: str, base_url: str, api_key: str = "stub_key"):
# If base_url doesn't include /v1, append it
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
@@ -38,9 +38,8 @@ def __init__(self, name: str, kind: str, environment: str, env_var: str, base_ur
name=name,
kind=kind,
environment=environment,
- env_var=env_var,
base_url=base_url.rstrip("/"),
- default_model=default_model,
+ api_key=api_key,
)
def _format_vision_content(self, text: str, image_data: str) -> list[dict[str, Any]]:
diff --git a/servers/inference_bridge/graphcap/providers/factory.py b/servers/inference_bridge/graphcap/providers/factory.py
new file mode 100644
index 00000000..f66343e2
--- /dev/null
+++ b/servers/inference_bridge/graphcap/providers/factory.py
@@ -0,0 +1,165 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Provider Factory Module
+
+This module provides factory functionality for creating provider clients.
+
+Key features:
+- Client instantiation
+- Environment validation
+- Rate limit configuration
+- Client caching
+"""
+
+from typing import Dict, Optional
+
+from loguru import logger
+
+from .clients import BaseClient, get_client
+
+
+class ProviderFactory:
+ """Factory class for creating provider clients with specific configurations"""
+
+ def __init__(self):
+ """Initialize provider factory"""
+ logger.info("Initializing ProviderFactory")
+ self._client_cache: Dict[str, BaseClient] = {}
+
+ def create_client(
+ self,
+ name: str,
+ kind: str,
+ environment: str,
+ base_url: str,
+ api_key: str,
+ rate_limits: Optional[dict] = None,
+ use_cache: bool = True,
+ ) -> BaseClient:
+ """Create a client with the given configuration.
+
+ Args:
+ name: Unique identifier for the provider
+ kind: Type of provider (e.g., 'openai', 'anthropic', 'gemini')
+ environment: Provider environment (cloud, local)
+ base_url: Base URL for the provider API
+ api_key: API key for the provider
+ rate_limits: Rate limiting configuration
+ use_cache: Whether to cache and reuse client instances (default: True)
+
+ Returns:
+ BaseClient: The provider client instance
+
+ Raises:
+ ValueError: If client creation fails
+ """
+ # Check cache first if enabled
+ cache_key = f"{name}:{kind}:{environment}:{base_url}:{api_key}"
+ if use_cache and cache_key in self._client_cache:
+ logger.debug(f"Using cached client for provider: {name}")
+ return self._client_cache[cache_key]
+
+ logger.info(f"Creating new client for provider: {name}")
+ logger.info("Provider config details:")
+ logger.info(f" - kind: {kind}")
+ logger.info(f" - environment: {environment}")
+ logger.info(f" - base_url: {base_url}")
+
+ try:
+ client = get_client(
+ name=name,
+ kind=kind,
+ environment=environment,
+ api_key=api_key,
+ base_url=base_url,
+ )
+
+ # Set rate limits if configured
+ if rate_limits:
+ logger.debug(
+ f"Setting rate limits for {name} - requests: {rate_limits.get('requests_per_minute')}/min, ",
+ f"tokens: {rate_limits.get('tokens_per_minute')}/min"
+ )
+ client.requests_per_minute = rate_limits.get("requests_per_minute")
+ client.tokens_per_minute = rate_limits.get("tokens_per_minute")
+
+ # Cache the client if enabled
+ if use_cache:
+ self._client_cache[cache_key] = client
+
+ return client
+
+ except Exception as e:
+ logger.error(f"Failed to create client for {name}: {str(e)}")
+ logger.error("Provider config details:")
+ logger.error(f" - kind: {kind}")
+ logger.error(f" - environment: {environment}")
+ logger.error(f" - base_url: {base_url}")
+ raise ValueError(f"Failed to create client for {name}: {str(e)}")
+
+ def clear_cache(self) -> None:
+ """Clear the client cache"""
+ self._client_cache.clear()
+
+
+# Global provider factory instance
+_provider_factory: Optional[ProviderFactory] = None
+
+
+def get_provider_factory() -> ProviderFactory:
+ """Get or create the global provider factory instance.
+
+ Returns:
+ ProviderFactory: The global provider factory instance
+ """
+ global _provider_factory
+
+ if _provider_factory is None:
+ _provider_factory = ProviderFactory()
+ logger.info("Created new provider factory instance")
+
+ return _provider_factory
+
+
+def create_provider_client(
+ name: str,
+ kind: str,
+ environment: str,
+ base_url: str,
+ api_key: str,
+ rate_limits: Optional[dict] = None,
+ use_cache: bool = True,
+) -> BaseClient:
+ """Create a provider client with the given configuration.
+
+ Args:
+ name: Unique identifier for the provider
+ kind: Type of provider (e.g., 'openai', 'anthropic', 'gemini')
+ environment: Provider environment (cloud, local)
+ base_url: Base URL for the provider API
+ api_key: API key for the provider
+ rate_limits: Rate limiting configuration
+ use_cache: Whether to cache and reuse client instances (default: True)
+
+ Returns:
+ BaseClient: The provider client instance
+
+ Raises:
+ ValueError: If client creation fails
+ """
+ factory = get_provider_factory()
+ return factory.create_client(
+ name=name,
+ kind=kind,
+ environment=environment,
+ base_url=base_url,
+ api_key=api_key,
+ rate_limits=rate_limits,
+ use_cache=use_cache,
+ )
+
+
+def clear_provider_cache() -> None:
+ """Clear the provider client cache"""
+ if _provider_factory is not None:
+ _provider_factory.clear_cache()
diff --git a/servers/inference_server/graphcap/providers/types.py b/servers/inference_bridge/graphcap/providers/types.py
similarity index 96%
rename from servers/inference_server/graphcap/providers/types.py
rename to servers/inference_bridge/graphcap/providers/types.py
index 4ef7922d..d7dce188 100644
--- a/servers/inference_server/graphcap/providers/types.py
+++ b/servers/inference_bridge/graphcap/providers/types.py
@@ -22,6 +22,5 @@ class ProviderConfig:
env_var: str
base_url: str
models: list[str]
- default_model: str
fetch_models: bool = False
rate_limits: Optional[RateLimits] = None
diff --git a/servers/inference_server/pipelines/.dep_hash b/servers/inference_bridge/pipelines/.dep_hash
similarity index 100%
rename from servers/inference_server/pipelines/.dep_hash
rename to servers/inference_bridge/pipelines/.dep_hash
diff --git a/servers/inference_server/pipelines/.dockerignore b/servers/inference_bridge/pipelines/.dockerignore
similarity index 100%
rename from servers/inference_server/pipelines/.dockerignore
rename to servers/inference_bridge/pipelines/.dockerignore
diff --git a/servers/inference_server/pipelines/Dockerfile.pipelines.dev b/servers/inference_bridge/pipelines/Dockerfile.pipelines.dev
similarity index 100%
rename from servers/inference_server/pipelines/Dockerfile.pipelines.dev
rename to servers/inference_bridge/pipelines/Dockerfile.pipelines.dev
diff --git a/servers/inference_server/pipelines/README.md b/servers/inference_bridge/pipelines/README.md
similarity index 100%
rename from servers/inference_server/pipelines/README.md
rename to servers/inference_bridge/pipelines/README.md
diff --git a/servers/inference_server/pipelines/Taskfile.pipelines.yml b/servers/inference_bridge/pipelines/Taskfile.pipelines.yml
similarity index 100%
rename from servers/inference_server/pipelines/Taskfile.pipelines.yml
rename to servers/inference_bridge/pipelines/Taskfile.pipelines.yml
diff --git a/servers/inference_server/pipelines/_scripts/pipeline_entrypoint.sh b/servers/inference_bridge/pipelines/_scripts/pipeline_entrypoint.sh
similarity index 100%
rename from servers/inference_server/pipelines/_scripts/pipeline_entrypoint.sh
rename to servers/inference_bridge/pipelines/_scripts/pipeline_entrypoint.sh
diff --git a/servers/inference_server/pipelines/dagster.example.yml b/servers/inference_bridge/pipelines/dagster.example.yml
similarity index 100%
rename from servers/inference_server/pipelines/dagster.example.yml
rename to servers/inference_bridge/pipelines/dagster.example.yml
diff --git a/servers/inference_server/pipelines/pipelines/__init__.py b/servers/inference_bridge/pipelines/pipelines/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/assets.py b/servers/inference_bridge/pipelines/pipelines/assets.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/assets.py
rename to servers/inference_bridge/pipelines/pipelines/assets.py
diff --git a/servers/inference_server/pipelines/pipelines/common/__init__.py b/servers/inference_bridge/pipelines/pipelines/common/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/common/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/common/constants.py b/servers/inference_bridge/pipelines/pipelines/common/constants.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/constants.py
rename to servers/inference_bridge/pipelines/pipelines/common/constants.py
diff --git a/servers/inference_server/pipelines/pipelines/common/io.py b/servers/inference_bridge/pipelines/pipelines/common/io.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/io.py
rename to servers/inference_bridge/pipelines/pipelines/common/io.py
diff --git a/servers/inference_server/pipelines/pipelines/common/logging.py b/servers/inference_bridge/pipelines/pipelines/common/logging.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/logging.py
rename to servers/inference_bridge/pipelines/pipelines/common/logging.py
diff --git a/servers/inference_server/pipelines/pipelines/common/resources.py b/servers/inference_bridge/pipelines/pipelines/common/resources.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/resources.py
rename to servers/inference_bridge/pipelines/pipelines/common/resources.py
diff --git a/servers/inference_server/pipelines/pipelines/common/utils.py b/servers/inference_bridge/pipelines/pipelines/common/utils.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/utils.py
rename to servers/inference_bridge/pipelines/pipelines/common/utils.py
diff --git a/servers/inference_server/pipelines/pipelines/common/workspace.py b/servers/inference_bridge/pipelines/pipelines/common/workspace.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/common/workspace.py
rename to servers/inference_bridge/pipelines/pipelines/common/workspace.py
diff --git a/servers/inference_server/pipelines/pipelines/definitions.py b/servers/inference_bridge/pipelines/pipelines/definitions.py
similarity index 85%
rename from servers/inference_server/pipelines/pipelines/definitions.py
rename to servers/inference_bridge/pipelines/pipelines/definitions.py
index c2b86142..5e01f0b4 100644
--- a/servers/inference_server/pipelines/pipelines/definitions.py
+++ b/servers/inference_bridge/pipelines/pipelines/definitions.py
@@ -13,17 +13,10 @@
from .common.resources import FileSystemConfig, PerspectiveConfig, PostgresConfig, ProviderConfigFile
from .huggingface import huggingface_client
from .huggingface.types import HfUploadManifestConfig
-from .perspectives.jobs import PerspectivePipelineRunConfig
# Import jobs
from .jobs import JOBS
-
-# # Import sensors
-# from .sensors.image_sensors import (
-# new_image_sensor,
-# art_analysis_asset_sensor
-# )
-
+from .perspectives.jobs import PerspectivePipelineRunConfig
# Configure custom loggers
loggers = configure_loggers()
@@ -43,6 +36,4 @@
**loggers, # Integrate custom loggers into resources
},
jobs=[*JOBS],
- # schedules=[daily_caption_schedule],
- # sensors=[new_image_sensor, art_analysis_asset_sensor]
)
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/__init__.py b/servers/inference_bridge/pipelines/pipelines/huggingface/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/client.py b/servers/inference_bridge/pipelines/pipelines/huggingface/client.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/client.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/client.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_export.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_export.py
similarity index 97%
rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_export.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_export.py
index e38663f2..07d8f4c0 100644
--- a/servers/inference_server/pipelines/pipelines/huggingface/dataset_export.py
+++ b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_export.py
@@ -7,10 +7,7 @@
from graphcap.perspectives.types import PerspectiveCaptionOutput
from huggingface_hub import HfApi, upload_file
-from .dataset_manifest import (
- create_dataset_manifest,
- load_perspective_results_from_manifest,
-)
+from .dataset_manifest import create_dataset_manifest, load_perspective_results_from_manifest
from .dataset_prep import create_huggingface_dataset
from .dataset_readme import generate_readme_content
from .perspective_export import upload_perspective_dataset_to_huggingface
@@ -92,7 +89,6 @@ def huggingface_upload_manifest(
if not manifest_path_str:
raise ValueError("Manifest path not found in dataset_export_manifest metadata.")
- # export_dir = Path(str(export_dir_str))
manifest_path = Path(str(manifest_path_str))
# Load perspective results from manifest
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_import.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py
similarity index 70%
rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_import.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py
index d3b979ea..cba5f393 100644
--- a/servers/inference_server/pipelines/pipelines/huggingface/dataset_import.py
+++ b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_import.py
@@ -27,7 +27,8 @@
from huggingface_hub import hf_hub_download
from tqdm import tqdm
-from .types import DatasetImportConfig, DatasetParquetUrlDownloadConfig, DatasetParseConfig
+from .types import (DatasetImportConfig, DatasetParquetUrlDownloadConfig,
+ DatasetParseConfig)
def _clone_with_git_lfs(
@@ -242,93 +243,152 @@ def dataset_download_urls(
dataset_download: Path to the downloaded dataset
config: Configuration for URL downloading
"""
+ input_dir, output_dir = _setup_directories(context, dataset_download, config)
+ parquet_files = _find_parquet_files(context, input_dir)
+
+ successful_downloads = 0
+ failed_downloads = 0
+ total_urls = 0
+
+ for parquet_file in parquet_files:
+ df = _load_parquet_file(context, parquet_file, config)
+ download_results = _process_dataframe(context, df, output_dir, config)
+
+ successful_downloads += download_results["successful"]
+ failed_downloads += download_results["failed"]
+ total_urls += download_results["total"]
+
+ # Log summary
+ context.log.info(
+ f"Download complete. Successful: {successful_downloads}, Failed: {failed_downloads}, Total: {total_urls}"
+ )
+
+
+def _setup_directories(
+ context: dg.AssetExecutionContext, dataset_download: str, config: DatasetParquetUrlDownloadConfig
+) -> tuple[Path, Path]:
+ """Setup input and output directories for URL downloads."""
input_dir = Path(dataset_download) / config.parquet_dir
if not input_dir.exists():
raise ValueError(f"Parquet directory not found at {input_dir}")
+
+ output_dir = Path(config.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ return input_dir, output_dir
- # Find all parquet files
+
+def _find_parquet_files(context: dg.AssetExecutionContext, input_dir: Path) -> list[Path]:
+ """Find parquet files in the input directory."""
parquet_files = list(input_dir.glob("*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {input_dir}")
-
+
context.log.info(f"Found {len(parquet_files)} parquet files")
+ return parquet_files
+
+
+def _load_parquet_file(
+ context: dg.AssetExecutionContext, parquet_file: Path, config: DatasetParquetUrlDownloadConfig
+) -> pd.DataFrame:
+ """Load and validate a parquet file."""
+ context.log.info(f"Processing {parquet_file}")
+ df = pd.read_parquet(parquet_file)
+
+ context.log.info(f"Loaded parquet file with {len(df)} rows")
+ context.log.info(f"Columns: {df.columns.tolist()}")
+
+ if config.url_column not in df.columns:
+ raise ValueError(
+ f"URL column '{config.url_column}' not found in {parquet_file}. "
+ f"Available columns: {df.columns.tolist()}"
+ )
+
+ return df
- # Create output directory
- output_dir = Path(config.output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
- # Track progress
+def _process_dataframe(
+ context: dg.AssetExecutionContext,
+ df: pd.DataFrame,
+ output_dir: Path,
+ config: DatasetParquetUrlDownloadConfig,
+) -> dict:
+ """Process a dataframe and download URLs."""
successful_downloads = 0
failed_downloads = 0
total_urls = 0
-
- # Process each parquet file
- for parquet_file in parquet_files:
- context.log.info(f"Processing {parquet_file}")
- df = pd.read_parquet(parquet_file)
-
- context.log.info(f"Loaded parquet file with {len(df)} rows")
- context.log.info(f"Columns: {df.columns.tolist()}")
-
- if config.url_column not in df.columns:
- raise ValueError(
- f"URL column '{config.url_column}' not found in {parquet_file}. "
- f"Available columns: {df.columns.tolist()}"
+
+ with ThreadPoolExecutor(max_workers=config.max_workers) as executor:
+ future_to_url = {}
+
+ for idx, row in df.iterrows():
+ if idx % 1000 == 0:
+ context.log.info(f"Processing row {idx}")
+
+ download_batch = _prepare_download_batch(context, row, output_dir, config)
+ if not download_batch:
+ continue
+
+ for url, output_path in download_batch:
+ if len(future_to_url) >= config.max_workers * 2:
+ # Process current batch before adding more
+ results = _process_completed_downloads(
+ future_to_url, context, successful_downloads, failed_downloads
+ )
+ successful_downloads, failed_downloads = results
+ future_to_url = {}
+
+ # Submit download task
+ future = executor.submit(_download_url, url, output_path, context)
+ future_to_url[future] = (url, output_path)
+ total_urls += 1
+
+ # Process remaining downloads
+ if future_to_url:
+ results = _process_completed_downloads(
+ future_to_url, context, successful_downloads, failed_downloads
)
+ successful_downloads, failed_downloads = results
+
+ return {
+ "successful": successful_downloads,
+ "failed": failed_downloads,
+ "total": total_urls
+ }
- # Process each row
- with ThreadPoolExecutor(max_workers=config.max_workers) as executor:
- future_to_url = {}
-
- for idx, row in df.iterrows():
- if idx % 1000 == 0:
- context.log.info(f"Processing row {idx}")
-
- # Extract URLs using helper function
- urls = _extract_urls(row[config.url_column])
- if not urls:
- context.log.debug(f"Skipping row {idx}: no valid URLs")
- continue
-
- # Use row ID as base filename
- base_filename = row["id"]
- if not base_filename:
- context.log.warning(f"Skipping row {idx}: no ID")
- continue
-
- for i, url in enumerate(urls):
- # Generate unique filename for each URL
- filename = f"{base_filename}_{i}.{config.default_extension}"
- output_path = output_dir / filename
-
- # Skip if file exists and no overwrite
- if output_path.exists() and not config.overwrite_existing:
- context.log.debug(f"Skipping existing file: {output_path}")
- continue
-
- # Limit batch size for rate limiting
- if len(future_to_url) >= config.max_workers * 2:
- # Wait for some downloads to complete before adding more
- successful_downloads, failed_downloads = _process_completed_downloads(
- future_to_url, context, successful_downloads, failed_downloads
- )
- future_to_url = {}
-
- # Submit download task
- future = executor.submit(_download_url, url, output_path, context)
- future_to_url[future] = (url, output_path)
- total_urls += 1
-
- # Process remaining downloads
- if future_to_url:
- successful_downloads, failed_downloads = _process_completed_downloads(
- future_to_url, context, successful_downloads, failed_downloads
- )
- # Log summary
- context.log.info(
- f"Download complete. Successful: {successful_downloads}, Failed: {failed_downloads}, Total: {total_urls}"
- )
+def _prepare_download_batch(
+ context: dg.AssetExecutionContext,
+ row: pd.Series,
+ output_dir: Path,
+ config: DatasetParquetUrlDownloadConfig,
+) -> list[tuple[str, Path]]:
+ """Prepare a batch of URLs to download for a single row."""
+ # Extract URLs
+ urls = _extract_urls(row[config.url_column])
+ if not urls:
+ return []
+
+ # Use row ID as base filename
+ base_filename = row.get("id")
+ if not base_filename:
+ context.log.warning("Skipping row: no ID")
+ return []
+
+ download_batch = []
+ for i, url in enumerate(urls):
+ # Generate unique filename for each URL
+ filename = f"{base_filename}_{i}.{config.default_extension}"
+ output_path = output_dir / filename
+
+ # Skip if file exists and no overwrite
+ if output_path.exists() and not config.overwrite_existing:
+ context.log.debug(f"Skipping existing file: {output_path}")
+ continue
+
+ download_batch.append((url, output_path))
+
+ return download_batch
def _process_completed_downloads(
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_manifest.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_manifest.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_manifest.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_manifest.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_prep.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_prep.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_prep.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_prep.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/dataset_readme.py b/servers/inference_bridge/pipelines/pipelines/huggingface/dataset_readme.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/dataset_readme.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/dataset_readme.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/perspective_export.py b/servers/inference_bridge/pipelines/pipelines/huggingface/perspective_export.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/huggingface/perspective_export.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/perspective_export.py
diff --git a/servers/inference_server/pipelines/pipelines/huggingface/types.py b/servers/inference_bridge/pipelines/pipelines/huggingface/types.py
similarity index 96%
rename from servers/inference_server/pipelines/pipelines/huggingface/types.py
rename to servers/inference_bridge/pipelines/pipelines/huggingface/types.py
index 40bd5aad..afaf1eec 100644
--- a/servers/inference_server/pipelines/pipelines/huggingface/types.py
+++ b/servers/inference_bridge/pipelines/pipelines/huggingface/types.py
@@ -151,13 +151,13 @@ class DatasetAnnotation:
id: str
content: str
annotation: dict
- manuallyAdjusted: bool
+ manually_adjusted: bool
embedding: Optional[Any]
- fromUser: str
- fromTeam: str
- createdAt: str
- updatedAt: str
- overallRating: Optional[Any]
+ from_user: str
+ from_team: str
+ created_at: str
+ updated_at: str
+ overall_rating: Optional[Any]
@dataclass
@@ -172,18 +172,18 @@ class DatasetRow:
status: str
flags: int
meta: dict
- fromUser: str
- fromTeam: str
+ from_user: str
+ from_team: str
embeddings: List[Any]
- createdAt: str
- updatedAt: str
+ created_at: str
+ updated_at: str
name: Optional[str]
width: Optional[int]
height: Optional[int]
format: Optional[str]
license: Optional[str]
- licenseUrl: Optional[str]
- contentAuthor: Optional[str]
+ license_url: Optional[str]
+ content_author: Optional[str]
annotations: List[DatasetAnnotation]
image_column: Optional[str]
diff --git a/servers/inference_server/pipelines/pipelines/io/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/io/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py
similarity index 90%
rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py
index a93250a3..ced14179 100644
--- a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py
+++ b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/iptc_metadata.py
@@ -29,16 +29,16 @@ def extract_iptc_metadata(description: str) -> Dict[str, Any]:
Dict[str, Any]: Extracted IPTC metadata.
"""
result: Dict[str, Any] = {}
- caption_match = re.search(r"IPTCCaption:\s*(.+?)(?:\s|$)", description)
+ caption_match = re.search(r"IPTCCaption:\s*(.+)(?:\s|$)", description)
if caption_match:
result["caption"] = caption_match.group(1).strip()
- keywords_match = re.search(r"IPTCKeywords:\s*(.+?)(?:\s|$)", description)
+ keywords_match = re.search(r"IPTCKeywords:\s*(.+)(?:\s|$)", description)
if keywords_match:
result["keywords"] = keywords_match.group(1).strip()
- location_match = re.search(r"IPTCLocation:\s*(.+?)(?:\s|$)", description)
+ location_match = re.search(r"IPTCLocation:\s*(.+)(?:\s|$)", description)
if location_match:
result["location"] = location_match.group(1).strip()
- credits_match = re.search(r"IPTCCredits:\s*(.+?)(?:\s|$)", description)
+ credits_match = re.search(r"IPTCCredits:\s*(.+)(?:\s|$)", description)
if credits_match:
result["credits"] = credits_match.group(1).strip()
# Extend with additional IPTC fields as needed.
diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/common_formats/xmp_metadata.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/image_metadata/extract_exif.py b/servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/extract_exif.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/image_metadata/extract_exif.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/image_metadata/extract_exif.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/load_images.py b/servers/inference_bridge/pipelines/pipelines/io/image/load_images.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/load_images.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/load_images.py
diff --git a/servers/inference_server/pipelines/pipelines/io/image/types.py b/servers/inference_bridge/pipelines/pipelines/io/image/types.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/io/image/types.py
rename to servers/inference_bridge/pipelines/pipelines/io/image/types.py
diff --git a/servers/inference_server/pipelines/pipelines/jobs/__init__.py b/servers/inference_bridge/pipelines/pipelines/jobs/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/jobs/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/jobs/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/jobs/dataset_import_job.py b/servers/inference_bridge/pipelines/pipelines/jobs/dataset_import_job.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/jobs/dataset_import_job.py
rename to servers/inference_bridge/pipelines/pipelines/jobs/dataset_import_job.py
diff --git a/servers/inference_server/pipelines/pipelines/jobs/image_metadata.py b/servers/inference_bridge/pipelines/pipelines/jobs/image_metadata.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/jobs/image_metadata.py
rename to servers/inference_bridge/pipelines/pipelines/jobs/image_metadata.py
diff --git a/servers/inference_server/pipelines/pipelines/jobs/omi.py b/servers/inference_bridge/pipelines/pipelines/jobs/omi.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/jobs/omi.py
rename to servers/inference_bridge/pipelines/pipelines/jobs/omi.py
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/__init__.py b/servers/inference_bridge/pipelines/pipelines/perspectives/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/perspectives/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/assets.py b/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py
similarity index 62%
rename from servers/inference_server/pipelines/pipelines/perspectives/assets.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/assets.py
index ef1f4a2e..1b4fef19 100644
--- a/servers/inference_server/pipelines/pipelines/perspectives/assets.py
+++ b/servers/inference_bridge/pipelines/pipelines/perspectives/assets.py
@@ -1,18 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
"""Assets and ops for basic text captioning."""
+import asyncio
+import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
import dagster as dg
import pandas as pd
+from loguru import logger
+from tqdm.asyncio import tqdm_asyncio
+
from graphcap.perspectives import get_perspective, get_synthesizer
from ..common.logging import write_caption_results
from ..perspectives.jobs.config import PerspectivePipelineConfig
from ..providers.util import get_provider
+# File constants
+JOB_INFO_FILENAME = "job_info.json"
+CAPTIONS_FILENAME = "captions.jsonl"
+
+# Temporary batch processing function to replace BaseCaptionProcessor.process_batch
+# This will be replaced with Kafka-based processing in the future
+async def process_images_in_batch(
+ processor,
+ provider,
+ image_paths,
+ model="gemini-2.0-flash-exp",
+ max_tokens=4096,
+ temperature=0.8,
+ top_p=0.9,
+ repetition_penalty=1.15,
+ max_concurrent=3,
+ output_dir=None,
+ global_context=None,
+ contexts=None,
+ name=None,
+):
+ """
+ Temporary batch processing function to replace BaseCaptionProcessor.process_batch.
+ Will be replaced with Kafka-based processing in the future.
+
+ Processes multiple images by calling process_single for each in parallel.
+ """
+ logger.info(f"[DEPRECATED] Processing {len(image_paths)} images with {provider.name}")
+ logger.info(f"Using max concurrency of {max_concurrent} requests")
+
+ # Create job directory for output if requested
+ job_dir = None
+ if output_dir:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ job_dir = output_dir / f"batch_{name or timestamp}"
+ job_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Writing results to {job_dir}")
+
+ # Create captions.jsonl file and job_info.json with basic info
+ with open(job_dir / JOB_INFO_FILENAME, "w") as f:
+ job_info = {
+ "started_at": timestamp,
+ "provider": provider.name,
+ "model": model,
+ "config_name": getattr(processor, "config_name", name),
+ "version": getattr(processor, "version", "1.0"),
+ "total_images": len(image_paths),
+ "global_context": global_context,
+ "note": "This is a temporary implementation of batch processing until Kafka-based processing is implemented."
+ }
+ json.dump(job_info, f, indent=2)
+
+ # Process images in parallel with limited concurrency
+ semaphore = asyncio.Semaphore(max_concurrent)
+
+ async def process_image(path):
+ async with semaphore:
+ try:
+ result = await processor.process_single(
+ provider=provider,
+ image_path=path,
+ model=model,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ context=contexts.get(path.name) if contexts else None,
+ global_context=global_context,
+ )
+
+ caption_data = {
+ "filename": f"./{path.name}",
+ "config_name": getattr(processor, "config_name", name),
+ "version": getattr(processor, "version", "1.0"),
+ "model": model,
+ "provider": provider.name,
+ "parsed": result,
+ }
+
+ # Write to captions.jsonl if job_dir exists
+ if job_dir:
+ with open(job_dir / CAPTIONS_FILENAME, "a") as f:
+ f.write(json.dumps(caption_data) + "\n")
+
+ return caption_data
+ except Exception as e:
+ logger.error(f"Error processing {path}: {e}")
+ error_data = {
+ "filename": f"./{path.name}",
+ "config_name": getattr(processor, "config_name", name),
+ "version": getattr(processor, "version", "1.0"),
+ "model": model,
+ "provider": provider.name,
+ "parsed": {"error": str(e)},
+ }
+
+ # Write error to captions.jsonl if job_dir exists
+ if job_dir:
+ with open(job_dir / CAPTIONS_FILENAME, "a") as f:
+ f.write(json.dumps(error_data) + "\n")
+
+ return error_data
+
+ tasks = [process_image(path) for path in image_paths]
+ results = await tqdm_asyncio.gather(*tasks, desc=f"Processing images with {provider.name}")
+
+ # Update job_info.json with completion info
+ if job_dir:
+ with open(job_dir / JOB_INFO_FILENAME, "r") as f:
+ job_info = json.load(f)
+
+ job_info["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S")
+ job_info["success_count"] = sum(1 for r in results if "error" not in r["parsed"])
+ job_info["failed_count"] = sum(1 for r in results if "error" in r["parsed"])
+
+ with open(job_dir / JOB_INFO_FILENAME, "w") as f:
+ json.dump(job_info, f, indent=2)
+
+ return results
+
@dg.asset(
group_name="perspectives",
@@ -45,11 +170,13 @@ async def perspective_caption(
processor = get_perspective(perspective)
try:
- # Process images in batch
+ # Process images using the temporary batch processing function
image_paths = [Path(image) for image in perspective_image_list]
- caption_data_list = await processor.process_batch(
+ caption_data_list = await process_images_in_batch(
+ processor,
client,
image_paths,
+ model=getattr(provider_config, "model", "gemini-2.0-flash-exp"),
output_dir=Path(io_config.run_dir),
global_context=perspective_config.global_context,
name=perspective,
@@ -128,8 +255,15 @@ async def synthesizer_caption(
image_dir = Path(io_config.output_dir) / "images"
paths = [image_dir / path for path in caption_contexts.keys()]
- results = await synthesizer.process_batch(
- client, paths, output_dir=Path(io_config.run_dir), contexts=caption_contexts, name="synthesized_caption"
+
+ # Use the temporary batch processing function
+ results = await process_images_in_batch(
+ synthesizer,
+ client,
+ paths,
+ output_dir=Path(io_config.run_dir),
+ contexts=caption_contexts,
+ name="synthesized_caption"
)
# Format the results to match the perspective_caption output
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/__init__.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/basic_perspective_pipeline.py
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/jobs/config.py b/servers/inference_bridge/pipelines/pipelines/perspectives/jobs/config.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/perspectives/jobs/config.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/jobs/config.py
diff --git a/servers/inference_server/pipelines/pipelines/perspectives/types.py b/servers/inference_bridge/pipelines/pipelines/perspectives/types.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/perspectives/types.py
rename to servers/inference_bridge/pipelines/pipelines/perspectives/types.py
diff --git a/servers/inference_server/pipelines/pipelines/providers/__init__.py b/servers/inference_bridge/pipelines/pipelines/providers/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/providers/__init__.py
rename to servers/inference_bridge/pipelines/pipelines/providers/__init__.py
diff --git a/servers/inference_bridge/pipelines/pipelines/providers/assets.py b/servers/inference_bridge/pipelines/pipelines/providers/assets.py
new file mode 100644
index 00000000..f3a606c0
--- /dev/null
+++ b/servers/inference_bridge/pipelines/pipelines/providers/assets.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Assets for loading provider configurations."""
+
+import dagster as dg
+
+from graphcap.providers.types import ProviderConfig
+
+from ..common.resources import ProviderConfigFile
+
+deprecation_msg = "Provider configuration is now managed by the data service"
+
+@dg.asset(compute_kind="python", group_name="providers")
+def provider_list(
+ context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile
+) -> dict[str, ProviderConfig]:
+ """Loads the list of providers (now from data service API)."""
+ # TODO: Call data service API to get providers instead of loading from file
+ # For now, return an empty dictionary to avoid errors
+ context.log.info(deprecation_msg)
+
+ # Sample provider for testing
+ gemini_config = ProviderConfig(
+ kind="gemini",
+ environment="cloud",
+ env_var="GOOGLE_API_KEY",
+ base_url="https://generativelanguage.googleapis.com/v1beta",
+ models=["gemini-2.0-flash-exp"],
+ fetch_models=False,
+ )
+
+ providers = {"gemini": gemini_config}
+
+ context.add_output_metadata(
+ {
+ "num_providers": len(providers),
+ "providers": "gemini: gemini-2.0-flash-exp",
+ "note": deprecation_msg
+ }
+ )
+ return providers
+
+
+@dg.asset(compute_kind="python", group_name="providers")
+def default_provider(context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile) -> str | None:
+ """Returns the default provider."""
+ selected_provider_name = provider_config_file.default_provider
+ context.log.info(f"Using default provider: {selected_provider_name}")
+
+ context.add_output_metadata(
+ {
+ "selected_provider": selected_provider_name,
+ "note": deprecation_msg
+ }
+ )
+ return selected_provider_name
diff --git a/servers/inference_bridge/pipelines/pipelines/providers/util.py b/servers/inference_bridge/pipelines/pipelines/providers/util.py
new file mode 100644
index 00000000..58b154ba
--- /dev/null
+++ b/servers/inference_bridge/pipelines/pipelines/providers/util.py
@@ -0,0 +1,16 @@
+
+
+def get_provider(config_path: str, default_provider: str):
+ """Instantiates the client based on the provider configuration.
+
+ Args:
+ config_path (str): Path to the provider configuration file (deprecated).
+ default_provider (str): The name of the default provider.
+
+ Returns:
+ The instantiated client.
+ """
+ # TODO: Get provider configuration from the data service API
+ # For now, hardcode a default configuration for Gemini
+ raise NotImplementedError("v2 provider configuration not implemented")
+
diff --git a/servers/inference_server/pipelines/pipelines/start.py b/servers/inference_bridge/pipelines/pipelines/start.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines/start.py
rename to servers/inference_bridge/pipelines/pipelines/start.py
diff --git a/servers/inference_server/pipelines/pipelines_tests/__init__.py b/servers/inference_bridge/pipelines/pipelines_tests/__init__.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines_tests/__init__.py
rename to servers/inference_bridge/pipelines/pipelines_tests/__init__.py
diff --git a/servers/inference_server/pipelines/pipelines_tests/test_assets.py b/servers/inference_bridge/pipelines/pipelines_tests/test_assets.py
similarity index 100%
rename from servers/inference_server/pipelines/pipelines_tests/test_assets.py
rename to servers/inference_bridge/pipelines/pipelines_tests/test_assets.py
diff --git a/servers/inference_server/pipelines/pyproject.toml b/servers/inference_bridge/pipelines/pyproject.toml
similarity index 100%
rename from servers/inference_server/pipelines/pyproject.toml
rename to servers/inference_bridge/pipelines/pyproject.toml
diff --git a/servers/inference_server/pipelines/setup.cfg b/servers/inference_bridge/pipelines/setup.cfg
similarity index 100%
rename from servers/inference_server/pipelines/setup.cfg
rename to servers/inference_bridge/pipelines/setup.cfg
diff --git a/servers/inference_server/pipelines/setup.py b/servers/inference_bridge/pipelines/setup.py
similarity index 100%
rename from servers/inference_server/pipelines/setup.py
rename to servers/inference_bridge/pipelines/setup.py
diff --git a/servers/inference_server/pipelines/uv.lock b/servers/inference_bridge/pipelines/uv.lock
similarity index 100%
rename from servers/inference_server/pipelines/uv.lock
rename to servers/inference_bridge/pipelines/uv.lock
diff --git a/servers/inference_server/pyproject.toml b/servers/inference_bridge/pyproject.toml
similarity index 100%
rename from servers/inference_server/pyproject.toml
rename to servers/inference_bridge/pyproject.toml
diff --git a/servers/inference_server/pytest.ini b/servers/inference_bridge/pytest.ini
similarity index 100%
rename from servers/inference_server/pytest.ini
rename to servers/inference_bridge/pytest.ini
diff --git a/servers/inference_server/scripts/__init__.py b/servers/inference_bridge/scripts/__init__.py
similarity index 100%
rename from servers/inference_server/scripts/__init__.py
rename to servers/inference_bridge/scripts/__init__.py
diff --git a/servers/inference_server/scripts/__main__.py b/servers/inference_bridge/scripts/__main__.py
similarity index 100%
rename from servers/inference_server/scripts/__main__.py
rename to servers/inference_bridge/scripts/__main__.py
diff --git a/servers/inference_server/scripts/config_writer.py b/servers/inference_bridge/scripts/config_writer.py
similarity index 100%
rename from servers/inference_server/scripts/config_writer.py
rename to servers/inference_bridge/scripts/config_writer.py
diff --git a/servers/inference_server/scripts/setup.py b/servers/inference_bridge/scripts/setup.py
similarity index 100%
rename from servers/inference_server/scripts/setup.py
rename to servers/inference_bridge/scripts/setup.py
diff --git a/servers/inference_server/server/.dep_hash b/servers/inference_bridge/server/.dep_hash
similarity index 100%
rename from servers/inference_server/server/.dep_hash
rename to servers/inference_bridge/server/.dep_hash
diff --git a/servers/inference_server/server/.dockerignore b/servers/inference_bridge/server/.dockerignore
similarity index 100%
rename from servers/inference_server/server/.dockerignore
rename to servers/inference_bridge/server/.dockerignore
diff --git a/servers/inference_server/server/.env.local.template b/servers/inference_bridge/server/.env.local.template
similarity index 100%
rename from servers/inference_server/server/.env.local.template
rename to servers/inference_bridge/server/.env.local.template
diff --git a/servers/inference_server/server/Dockerfile.server.dev b/servers/inference_bridge/server/Dockerfile.server.dev
similarity index 100%
rename from servers/inference_server/server/Dockerfile.server.dev
rename to servers/inference_bridge/server/Dockerfile.server.dev
diff --git a/servers/inference_server/server/README.md b/servers/inference_bridge/server/README.md
similarity index 100%
rename from servers/inference_server/server/README.md
rename to servers/inference_bridge/server/README.md
diff --git a/servers/inference_server/server/Taskfile.inference.yml b/servers/inference_bridge/server/Taskfile.inference.yml
similarity index 69%
rename from servers/inference_server/server/Taskfile.inference.yml
rename to servers/inference_bridge/server/Taskfile.inference.yml
index a2c64d9b..66ca0b6b 100644
--- a/servers/inference_server/server/Taskfile.inference.yml
+++ b/servers/inference_bridge/server/Taskfile.inference.yml
@@ -4,26 +4,26 @@ tasks:
dev:
desc: Start the inference server container with watch mode
cmds:
- - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up graphcap_server --watch --build
+ - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up inference_bridge --watch --build
start:
desc: Start the inference server container
cmds:
- - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d graphcap_server
+ - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d inference_bridge
stop:
desc: Stop the inference server container
cmds:
- - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env stop graphcap_server
+ - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env stop inference_bridge
logs:
desc: View logs for the inference server container
cmds:
- - docker compose -f ./docker-compose.yml logs -f graphcap_server
+ - docker compose -f ./docker-compose.yml logs -f inference_bridge
rebuild:
desc: Rebuild and restart the inference server container
cmds:
- - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env build graphcap_server
- - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d --force-recreate graphcap_server
+ - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env build inference_bridge
+ - docker compose -f ./docker-compose.yml --env-file ./workspace/config/.env up -d --force-recreate inference_bridge
diff --git a/servers/inference_server/server/__init__.py b/servers/inference_bridge/server/__init__.py
similarity index 100%
rename from servers/inference_server/server/__init__.py
rename to servers/inference_bridge/server/__init__.py
diff --git a/servers/inference_server/server/_scripts/endpoints-entrypoint.sh b/servers/inference_bridge/server/_scripts/endpoints-entrypoint.sh
similarity index 100%
rename from servers/inference_server/server/_scripts/endpoints-entrypoint.sh
rename to servers/inference_bridge/server/_scripts/endpoints-entrypoint.sh
diff --git a/servers/inference_server/server/_scripts/gunicorn.conf.py b/servers/inference_bridge/server/_scripts/gunicorn.conf.py
similarity index 100%
rename from servers/inference_server/server/_scripts/gunicorn.conf.py
rename to servers/inference_bridge/server/_scripts/gunicorn.conf.py
diff --git a/servers/inference_server/server/pyproject.toml b/servers/inference_bridge/server/pyproject.toml
similarity index 100%
rename from servers/inference_server/server/pyproject.toml
rename to servers/inference_bridge/server/pyproject.toml
diff --git a/servers/inference_server/server/server/__init__.py b/servers/inference_bridge/server/server/__init__.py
similarity index 100%
rename from servers/inference_server/server/server/__init__.py
rename to servers/inference_bridge/server/server/__init__.py
diff --git a/servers/inference_server/server/server/config.py b/servers/inference_bridge/server/server/config.py
similarity index 100%
rename from servers/inference_server/server/server/config.py
rename to servers/inference_bridge/server/server/config.py
diff --git a/servers/inference_server/server/server/config/router.py b/servers/inference_bridge/server/server/config/router.py
similarity index 100%
rename from servers/inference_server/server/server/config/router.py
rename to servers/inference_bridge/server/server/config/router.py
diff --git a/servers/inference_server/server/server/db.py b/servers/inference_bridge/server/server/db.py
similarity index 100%
rename from servers/inference_server/server/server/db.py
rename to servers/inference_bridge/server/server/db.py
diff --git a/servers/inference_server/server/server/dependencies.py b/servers/inference_bridge/server/server/dependencies.py
similarity index 100%
rename from servers/inference_server/server/server/dependencies.py
rename to servers/inference_bridge/server/server/dependencies.py
diff --git a/servers/inference_server/server/server/features/perspectives/__init__.py b/servers/inference_bridge/server/server/features/perspectives/__init__.py
similarity index 100%
rename from servers/inference_server/server/server/features/perspectives/__init__.py
rename to servers/inference_bridge/server/server/features/perspectives/__init__.py
diff --git a/servers/inference_server/server/server/features/perspectives/models.py b/servers/inference_bridge/server/server/features/perspectives/models.py
similarity index 78%
rename from servers/inference_server/server/server/features/perspectives/models.py
rename to servers/inference_bridge/server/server/features/perspectives/models.py
index f975e49c..777a4350 100644
--- a/servers/inference_server/server/server/features/perspectives/models.py
+++ b/servers/inference_bridge/server/server/features/perspectives/models.py
@@ -7,7 +7,6 @@
from typing import Any, Dict, List, Optional, Union
-from fastapi import File, Form, UploadFile
from pydantic import BaseModel, Field
# Field description constants
@@ -138,50 +137,11 @@ class CaptionResponse(BaseModel):
"""Response model for a generated caption."""
perspective: str = Field(..., description="Name of the perspective used")
- provider: str = Field("gemini", description="Name of the provider used")
+ provider: str = Field(..., description="Name of the provider used")
result: dict = Field(..., description="Structured caption result")
raw_text: Optional[str] = Field(None, description="Raw text response from the model")
-# Form data model for multipart/form-data requests with file uploads
-class CaptionFormRequest:
- """Form request model for generating a caption with a perspective using file upload."""
-
- def __init__(
- self,
- perspective: str = Form(..., description=DESC_PERSPECTIVE_NAME),
- file: Optional[UploadFile] = File(None, description="Image file to caption"),
- url: Optional[str] = Form(None, description="URL of the image to caption"),
- base64: Optional[str] = Form(None, description="Base64-encoded image data"),
- max_tokens: Optional[int] = Form(4096, description=DESC_MAX_TOKENS),
- temperature: Optional[float] = Form(0.8, description=DESC_TEMPERATURE),
- top_p: Optional[float] = Form(0.9, description=DESC_TOP_P),
- repetition_penalty: Optional[float] = Form(1.15, description=DESC_REPETITION_PENALTY),
- global_context: Optional[str] = Form(None, description=DESC_GLOBAL_CONTEXT),
- context: Optional[str] = Form(None, description="Additional context for the caption (JSON array string)"),
- resize_resolution: Optional[str] = Form(None, description=DESC_RESIZE_RESOLUTION),
- ):
- self.perspective = perspective
- self.file = file
- self.url = url
- self.base64 = base64
- self.max_tokens = max_tokens
- self.temperature = temperature
- self.top_p = top_p
- self.repetition_penalty = repetition_penalty
- self.global_context = global_context
- self.resize_resolution = resize_resolution
-
- # Parse context from JSON string if provided
- self.context = None
- if context:
- import json
-
- try:
- self.context = json.loads(context)
- except json.JSONDecodeError:
- # If not valid JSON array, treat as a single context item
- self.context = [context]
class CaptionPathRequest(BaseModel):
@@ -189,7 +149,9 @@ class CaptionPathRequest(BaseModel):
perspective: str = Field(..., description=DESC_PERSPECTIVE_NAME)
image_path: str = Field(..., description="Path to the image file in the workspace")
- provider: str = Field("gemini", description="Name of the provider to use")
+ provider: str = Field(..., description="Name of the provider to use")
+ provider_config: dict = Field(..., description="Provider configuration")
+ model: str = Field(..., description="Model name to use for processing")
max_tokens: Optional[int] = Field(4096, description=DESC_MAX_TOKENS)
temperature: Optional[float] = Field(0.8, description=DESC_TEMPERATURE)
top_p: Optional[float] = Field(0.9, description=DESC_TOP_P)
@@ -204,6 +166,15 @@ class Config:
"perspective": "custom_caption",
"image_path": "/workspace/datasets/example.jpg",
"provider": "gemini",
+ "model": "gemini-pro-vision",
+ "provider_config": {
+ "name": "gemini",
+ "kind": "gemini",
+ "environment": "cloud",
+ "api_key": "your_api_key_here",
+ "base_url": "https://generativelanguage.googleapis.com/v1beta",
+ "models": ["gemini-pro-vision"]
+ },
"max_tokens": 4096,
"temperature": 0.8,
"resize_resolution": "HD_720P",
diff --git a/servers/inference_server/server/server/features/perspectives/router.py b/servers/inference_bridge/server/server/features/perspectives/router.py
similarity index 56%
rename from servers/inference_server/server/server/features/perspectives/router.py
rename to servers/inference_bridge/server/server/features/perspectives/router.py
index 3671273e..58d7df3d 100644
--- a/servers/inference_server/server/server/features/perspectives/router.py
+++ b/servers/inference_bridge/server/server/features/perspectives/router.py
@@ -19,24 +19,15 @@
from pathlib import Path
from typing import List, Optional
-from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile, status
+from fastapi import APIRouter, HTTPException, status
from loguru import logger
-from ...utils.resizing import ResolutionPreset, log_resize_options, resize_image
-from .models import (
- CaptionPathRequest,
- CaptionResponse,
- ModuleListResponse,
- ModulePerspectivesResponse,
- PerspectiveListResponse,
-)
-from .service import (
- generate_caption,
- get_available_modules,
- get_available_perspectives,
- get_perspectives_by_module,
- save_uploaded_file,
-)
+from ...utils.resizing import (ResolutionPreset, log_resize_options,
+ resize_image)
+from .models import (CaptionPathRequest, CaptionResponse, ModuleListResponse,
+ ModulePerspectivesResponse, PerspectiveListResponse)
+from .service import (generate_caption, get_available_modules,
+ get_available_perspectives, get_perspectives_by_module)
router = APIRouter(prefix="/perspectives", tags=["perspectives"])
@@ -53,127 +44,6 @@ async def list_perspectives() -> PerspectiveListResponse:
return PerspectiveListResponse(perspectives=perspectives)
-@router.post("/caption", response_model=CaptionResponse, status_code=status.HTTP_200_OK)
-async def create_caption(
- background_tasks: BackgroundTasks,
- file: UploadFile = File(..., description="Image file to upload"),
- perspective: str = Form(..., description="Name of the perspective to use"),
- provider: str = Form("gemini", description="Name of the provider to use"),
- max_tokens: Optional[int] = Form(4096, description="Maximum number of tokens"),
- temperature: Optional[float] = Form(0.8, description="Temperature for generation"),
- top_p: Optional[float] = Form(0.9, description="Top-p sampling parameter"),
- repetition_penalty: Optional[float] = Form(1.15, description="Repetition penalty"),
- global_context: Optional[str] = Form(None, description="Global context for the caption"),
- context: Optional[str] = Form(None, description="Additional context for the caption as JSON array string"),
- resize_resolution: Optional[str] = Form(None, description="Resolution to resize to (None to disable resizing)"),
-) -> CaptionResponse:
- """
- Generate a caption for an image using a perspective.
-
- This endpoint supports file uploads only.
-
- Args:
- background_tasks: Background tasks for cleanup
- file: Image file to upload (required)
- perspective: Name of the perspective to use (required)
- provider: Name of the provider to use (optional, default: "default")
- max_tokens: Maximum number of tokens (optional, default: 4096)
- temperature: Temperature for generation (optional, default: 0.8)
- top_p: Top-p sampling parameter (optional, default: 0.9)
- repetition_penalty: Repetition penalty (optional, default: 1.15)
- context: JSON array string of context items (optional)
- global_context: Global context string (optional)
- resize_resolution: Resolution to resize to (optional, default: None - no resizing)
-
- Returns:
- Generated caption with structured result and optional raw text
-
- Raises:
- HTTPException: If the request is invalid or processing fails
- """
- try:
- # Parse context from JSON string if provided
- parsed_context = _parse_context(context)
-
- # Process the uploaded file
- image_path = await save_uploaded_file(file)
-
- # Log resize options
- options = {"resize_resolution": resize_resolution}
- log_resize_options(options)
-
- # Resize the image if resize_resolution is provided
- if resize_resolution:
- try:
- # Get the resolution enum value
- try:
- resolution = ResolutionPreset[resize_resolution]
- except (KeyError, ValueError):
- logger.warning(f"Invalid resolution: {resize_resolution}. Using HD_720P.")
- resolution = ResolutionPreset.HD_720P
-
- # Create temporary file for resized image
- suffix = os.path.splitext(image_path)[1]
- fd, resized_path = tempfile.mkstemp(suffix=suffix)
- os.close(fd)
-
- # Resize the image
- logger.info(f"Resizing image to {resolution.name} ({resolution.value})")
- resized_img = resize_image(image_path, resolution)
- resized_img.save(resized_path)
-
- # Add cleanup task for original image
- background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None)
-
- # Use the resized image
- image_path = Path(resized_path)
- logger.info(f"Image resized successfully to {resolution.name}")
- except Exception as e:
- logger.error(f"Error resizing image: {str(e)}")
- logger.warning("Using original image instead")
- # Continue with original image if resizing fails
-
- # Add cleanup task
- background_tasks.add_task(lambda: os.unlink(image_path) if os.path.exists(image_path) else None)
-
- # Generate the caption
- caption_data = await generate_caption(
- perspective_name=perspective,
- image_path=image_path,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- context=parsed_context,
- global_context=global_context,
- provider_name=provider,
- )
-
- # Log the caption data for debugging
- logger.debug(f"Caption data: {caption_data}")
-
- # Extract the parsed result and raw text
- parsed_result = caption_data.get("parsed", {})
- raw_text = caption_data.get("raw_text")
-
- # If parsed result is empty but raw_text exists, try to create a basic result
- if not parsed_result and raw_text:
- logger.warning("Parsed result is empty but raw_text exists. Creating basic result.")
- parsed_result = {"text": raw_text}
-
- # Return the response
- return CaptionResponse(
- perspective=perspective,
- provider=provider,
- result=parsed_result,
- raw_text=raw_text,
- )
- except Exception as e:
- logger.error(f"Error creating caption: {str(e)}")
- if isinstance(e, HTTPException):
- raise
- raise HTTPException(status_code=500, detail=f"Error creating caption: {str(e)}")
-
@router.post("/caption-from-path", response_model=CaptionResponse, status_code=status.HTTP_200_OK)
async def create_caption_from_path(
@@ -204,10 +74,27 @@ async def create_caption_from_path(
# Process context
context = _process_context(request.context)
+ # Validate that provider_config is present
+ if not hasattr(request, 'provider_config') or not request.provider_config:
+ logger.error(f"No provider configuration provided for {request.provider}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Provider configuration not provided for '{request.provider}'. Please include provider_config in the request."
+ )
+
+ # Validate that model is provided
+ if not hasattr(request, 'model') or not request.model:
+ logger.error(f"No model specified for {request.provider}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Model name not provided for '{request.provider}'. Please include model in the request."
+ )
+
# Generate the caption
caption_data = await generate_caption(
perspective_name=request.perspective,
image_path=image_path,
+ model=request.model,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
@@ -215,6 +102,7 @@ async def create_caption_from_path(
context=context,
global_context=request.global_context,
provider_name=request.provider,
+ provider_config=request.provider_config,
)
# Clean up temporary file if we created one
@@ -253,9 +141,9 @@ async def _resize_image_if_needed(image_path: Path, resize_resolution: Optional[
try:
resolution = ResolutionPreset[resize_resolution]
except (KeyError, ValueError):
- logger.warning(f"Invalid resolution: {resize_resolution}. Using HD_720P.")
- resolution = ResolutionPreset.HD_720P
-
+ logger.warning(f"Invalid resolution: {resize_resolution}. Skipping resize.")
+ return image_path, temp_path
+
# Create temporary file for resized image
suffix = os.path.splitext(str(image_path))[1]
fd, resized_path = tempfile.mkstemp(suffix=suffix)
@@ -330,20 +218,6 @@ def _prepare_caption_response(caption_data: dict, perspective: str, provider: st
)
-def _parse_context(context_str) -> Optional[List[str]]:
- """Parse context from a JSON string."""
- if not context_str or not isinstance(context_str, str):
- return None
-
- try:
- context = json.loads(context_str)
- if isinstance(context, list):
- return context
- return [context_str]
- except json.JSONDecodeError:
- return [context_str]
-
-
@router.get("/modules", response_model=ModuleListResponse)
diff --git a/servers/inference_server/server/server/features/perspectives/service.py b/servers/inference_bridge/server/server/features/perspectives/service.py
similarity index 69%
rename from servers/inference_server/server/server/features/perspectives/service.py
rename to servers/inference_bridge/server/server/features/perspectives/service.py
index d52d969b..217486df 100644
--- a/servers/inference_server/server/server/features/perspectives/service.py
+++ b/servers/inference_bridge/server/server/features/perspectives/service.py
@@ -7,23 +7,19 @@
import base64
import os
-import socket
import tempfile
+from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional
-from collections import defaultdict
import aiohttp
-from fastapi import HTTPException, UploadFile
-from graphcap.perspectives import (
- get_perspective,
- get_perspective_list,
-)
-from graphcap.providers.clients.base_client import BaseClient
+from fastapi import HTTPException
from loguru import logger
-from ..providers.service import get_provider_manager
-from .models import ModuleInfo, PerspectiveInfo, PerspectiveSchema, SchemaField, TableColumn
+from graphcap.perspectives import get_perspective, get_perspective_list
+
+from .models import (ModuleInfo, PerspectiveInfo, PerspectiveSchema,
+ SchemaField, TableColumn)
async def download_image(url: str) -> Path:
@@ -108,7 +104,7 @@ def load_perspective_schema(perspective_name: str) -> Optional[PerspectiveSchema
try:
# Import perspective function
from graphcap.perspectives import get_perspective
-
+
# Get the perspective processor
perspective = get_perspective(perspective_name)
if perspective and hasattr(perspective, 'config'):
@@ -263,6 +259,7 @@ def get_perspectives_by_module(module_name: str) -> List[PerspectiveInfo]:
async def generate_caption(
perspective_name: str,
image_path: Path,
+ model: str,
max_tokens: Optional[int] = 4096,
temperature: Optional[float] = 0.8,
top_p: Optional[float] = 0.9,
@@ -270,6 +267,7 @@ async def generate_caption(
context: Optional[List[str]] = None,
global_context: Optional[str] = None,
provider_name: str = "gemini",
+ provider_config: Optional[dict] = None,
) -> Dict:
"""
Generate a caption for an image using a perspective.
@@ -277,6 +275,7 @@ async def generate_caption(
Args:
perspective_name: Name of the perspective to use
image_path: Path to the image file
+ model: Model name to use for processing
max_tokens: Maximum number of tokens in the response
temperature: Temperature for generation
top_p: Top-p sampling parameter
@@ -284,6 +283,7 @@ async def generate_caption(
context: Additional context for the caption
global_context: Global context for the caption
provider_name: Name of the provider to use (default: "gemini")
+ provider_config: Full provider configuration if available
Returns:
Caption data
@@ -295,35 +295,23 @@ async def generate_caption(
# Get the perspective
perspective = get_perspective(perspective_name)
- # Get the provider client from the provider manager
- provider_manager = get_provider_manager()
-
- # Debug: Log available providers
- available_providers = provider_manager.available_providers()
- logger.debug(f"Available providers: {available_providers}")
-
- # Debug: Try to resolve host.docker.internal
- try:
- host_ip = socket.gethostbyname("host.docker.internal")
- logger.debug(f"host.docker.internal resolves to: {host_ip}")
- except socket.gaierror as e:
- logger.warning(f"Could not resolve host.docker.internal: {e}")
-
- try:
- provider: BaseClient = provider_manager.get_client(provider_name)
- # Debug: Log provider details
- logger.debug("Provider details:")
- logger.debug(f" - Name: {provider_name}")
- logger.debug(f" - Kind: {provider.kind}")
- logger.debug(f" - Environment: {provider.environment}")
- logger.debug(f" - Base URL: {provider.base_url}")
- logger.debug(f" - Default Model: {provider.default_model}")
- except ValueError as e:
- logger.error(f"Provider '{provider_name}' not found: {str(e)}")
+ # Create a provider client using the config if provided
+ if provider_config:
+ from ..providers.models import ProviderConfig
+ from ..providers.service import create_provider_client_from_config
+
+ # Convert dict to ProviderConfig
+ config = ProviderConfig(**provider_config)
+ provider = create_provider_client_from_config(config)
+ logger.info(f"Created provider client from provided config for {provider_name}")
+ else:
+ # Legacy path - will likely fail as no provider manager exists
+ logger.error(f"No provider configuration provided for {provider_name}. Caption generation will likely fail.")
+ logger.error("Provider configuration must be provided in the request.")
raise HTTPException(
- status_code=404,
- detail=f"""Provider '{provider_name}' not found.
- Available providers: {', '.join(provider_manager.available_providers())}""",
+ status_code=400,
+ detail=f"""Provider configuration not provided for '{provider_name}'.
+ Provider configuration must be included in the request.""",
)
# Create a temporary output directory
@@ -335,41 +323,28 @@ async def generate_caption(
f"Generating caption for {image_path} using {perspective_name} perspective and {provider_name} provider"
)
- # Check if the perspective has process_batch method
- if hasattr(perspective, "process_batch"):
- logger.info(f"Using process_batch method for {perspective_name}")
- # Use process_batch with a single image to match the pipeline implementation
- caption_data_list = await perspective.process_batch(
- provider=provider,
- image_paths=[image_path],
- output_dir=output_dir,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- global_context=global_context,
- name=perspective_name,
- )
+ # Use process_single directly as process_batch has been deprecated
+ logger.info(f"Using process_single for {perspective_name}")
+ caption_data = await perspective.process_single(
+ provider=provider,
+ image_path=image_path,
+ model=model,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ context=context,
+ global_context=global_context,
+ )
- # Get the first (and only) result
- if not caption_data_list or len(caption_data_list) == 0:
- logger.error(f"No caption data returned for {image_path}")
- raise HTTPException(status_code=500, detail="No caption data returned")
-
- caption_data = caption_data_list[0]
- else:
- # Fallback to process_single if process_batch is not available
- logger.info(f"Falling back to process_single method for {perspective_name}")
- caption_data = await perspective.process_single(
- provider=provider,
- image_path=image_path,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- context=context,
- global_context=global_context,
- )
+ caption_data = {
+ "filename": f"./{image_path.name}",
+ "config_name": getattr(perspective, 'config_name', perspective_name),
+ "version": getattr(perspective, 'version', '1.0'),
+ "model": model,
+ "provider": provider.name,
+ "parsed": caption_data,
+ }
# Log the result
logger.info(f"Caption generated successfully: {caption_data.keys() if caption_data else 'None'}")
@@ -383,35 +358,3 @@ async def generate_caption(
raise HTTPException(status_code=500, detail=f"Error generating caption: {str(e)}")
-async def save_uploaded_file(file: UploadFile) -> Path:
- """
- Save an uploaded file to a temporary location.
-
- Args:
- file: Uploaded file object
-
- Returns:
- Path to the saved file
-
- Raises:
- HTTPException: If the file cannot be saved
- """
- try:
- # Create a temporary file with appropriate extension
- suffix = os.path.splitext(file.filename)[1] if file.filename else ".jpg"
- fd, temp_path = tempfile.mkstemp(suffix=suffix)
- os.close(fd)
- temp_file = Path(temp_path)
-
- # Save the uploaded file
- content = await file.read()
- with open(temp_file, "wb") as f:
- f.write(content)
-
- # Reset file pointer for potential future reads
- await file.seek(0)
-
- return temp_file
- except Exception as e:
- logger.error(f"Error saving uploaded file: {str(e)}")
- raise HTTPException(status_code=400, detail=f"Error saving uploaded file: {str(e)}")
diff --git a/servers/inference_server/server/server/features/providers/__init__.py b/servers/inference_bridge/server/server/features/providers/__init__.py
similarity index 100%
rename from servers/inference_server/server/server/features/providers/__init__.py
rename to servers/inference_bridge/server/server/features/providers/__init__.py
diff --git a/servers/inference_bridge/server/server/features/providers/error_handler.py b/servers/inference_bridge/server/server/features/providers/error_handler.py
new file mode 100644
index 00000000..eaf2282b
--- /dev/null
+++ b/servers/inference_bridge/server/server/features/providers/error_handler.py
@@ -0,0 +1,260 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Provider Error Handler
+
+Handles provider-specific error formatting and responses.
+"""
+
+import datetime
+from typing import Any, Dict, Set
+
+from fastapi.responses import JSONResponse
+from pydantic import ValidationError
+
+from .models import ProviderConfig
+
+
+def _extract_invalid_fields(errors) -> Set[str]:
+ """Extract the set of invalid field names from validation errors."""
+ invalid_fields: Set[str] = set()
+
+ for error in errors:
+ loc = error.get("loc", [])
+ if len(loc) > 1:
+ field_name = loc[1] if isinstance(loc[1], str) else str(loc[1])
+ invalid_fields.add(field_name)
+
+ return invalid_fields
+
+
+def _build_invalid_params(errors) -> Dict[str, Dict]:
+ """Build dictionary mapping fields to their error details."""
+ invalid_params = {}
+
+ for error in errors:
+ # Get field location
+ loc = error.get("loc", [])
+ field = ".".join(str(loc) for loc in error.get("loc", [])) if error.get("loc") else ""
+ message = error.get("msg", "Validation error")
+ error_type = error.get("type", "unknown_error")
+
+ # Add context if available
+ context = {}
+ if error.get("ctx"):
+ for key, value in error.get("ctx", {}).items():
+ if key != "expected" or not isinstance(value, list) or len(value) < 5:
+ context[key] = value
+
+ invalid_params[field] = {
+ "message": message,
+ "error_type": error_type
+ }
+
+ if context:
+ invalid_params[field]["context"] = context
+
+ return invalid_params
+
+
+def _generate_error_message(invalid_fields: Set[str]) -> str:
+ """Generate an appropriate error message based on invalid fields."""
+ if len(invalid_fields) == 1:
+ field = next(iter(invalid_fields))
+ return f"Invalid provider configuration: '{field}' parameter is invalid"
+ elif len(invalid_fields) > 1:
+ field_list = "', '".join(sorted(invalid_fields))
+ return f"Invalid provider configuration: Parameters '{field_list}' are invalid"
+ else:
+ return "Invalid provider configuration"
+
+
+def _get_field_from_error(error: dict) -> str:
+ """Extract the field name from the error location."""
+ return ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else ""
+
+
+def _add_error_type_suggestion(error: dict, field: str, suggestions: list) -> None:
+ """Add suggestion based on error type."""
+ error_type = error.get("type", "")
+
+ if error_type == "missing":
+ suggestions.append(f"Add the missing required parameter: '{field}'")
+ elif error_type == "string_type":
+ suggestions.append(f"Ensure '{field}' is a valid string")
+ elif error_type == "url_parsing":
+ suggestions.append(f"Use a valid URL format for '{field}'")
+ elif error_type and "enum" in error_type:
+ _add_enum_suggestion(error, field, suggestions)
+
+
+def _add_enum_suggestion(error: dict, field: str, suggestions: list) -> None:
+ """Add suggestion for enum validation errors."""
+ valid_values = error.get("ctx", {}).get("expected", [])
+ if valid_values:
+ values_str = ", ".join([f"'{v}'" for v in valid_values])
+ suggestions.append(f"Choose a valid option for '{field}': {values_str}")
+ else:
+ suggestions.append(f"Choose a valid option for '{field}'")
+
+
+def _add_field_specific_suggestion(field: str, suggestions: list) -> None:
+ """Add suggestion based on specific field name."""
+ if field == "api_key":
+ suggestions.append("Check the API key is correct for this provider")
+ elif field == "base_url":
+ suggestions.append("Verify the base URL format matches the provider's API documentation")
+ elif field == "environment":
+ suggestions.append("Valid environment values are typically 'cloud' or 'local'")
+
+
+def _generate_suggestions(errors) -> list:
+ """Generate helpful suggestions based on validation errors."""
+ suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"]
+
+ for error in errors:
+ field = _get_field_from_error(error)
+ _add_error_type_suggestion(error, field, suggestions)
+ _add_field_specific_suggestion(field, suggestions)
+
+ suggestions.append("Check server logs for more details")
+ return list(dict.fromkeys(suggestions)) # Remove duplicates while preserving order
+
+
+def format_provider_validation_error(e: ValidationError) -> JSONResponse:
+ """
+ Format a provider validation error into a standardized response.
+
+ Args:
+ e: The validation error
+
+ Returns:
+ A JSONResponse with detailed error information
+ """
+ errors = e.errors()
+
+ # Extract field names and build error details
+ invalid_fields = _extract_invalid_fields(errors)
+ invalid_params = _build_invalid_params(errors)
+
+ # Generate appropriate message and suggestions
+ message = _generate_error_message(invalid_fields)
+ suggestions = _generate_suggestions(errors)
+
+ # Build the response
+ error_response = {
+ "title": "Connection failed",
+ "timestamp": datetime.datetime.now().isoformat(),
+ "message": message,
+ "name": "Error",
+ "details": "The server rejected the request due to invalid provider parameters.",
+ "invalid_parameters": invalid_params,
+ "suggestions": suggestions
+ }
+
+ return JSONResponse(
+ status_code=400,
+ content=error_response
+ )
+
+
+def _create_safe_config(config: ProviderConfig) -> Dict[str, Any]:
+ """Create a copy of the config without sensitive information."""
+ return {
+ "kind": config.kind,
+ "environment": config.environment,
+ "base_url": config.base_url,
+ "models": config.models,
+ "fetch_models": config.fetch_models,
+ }
+
+
+def _determine_error_code(error_message: str) -> str:
+ """Determine the error code based on the error message."""
+ error_message = error_message.lower()
+
+ if "authentication failed" in error_message or "unauthorized" in error_message:
+ return "AUTH_ERROR"
+ elif "not found" in error_message or "404" in error_message:
+ return "ENDPOINT_NOT_FOUND"
+ elif "timeout" in error_message:
+ return "TIMEOUT"
+ elif "connection" in error_message:
+ return "CONNECTION_ERROR"
+ elif "rate limit" in error_message or "too many requests" in error_message:
+ return "RATE_LIMIT"
+ elif "quota" in error_message or "exceeded" in error_message:
+ return "QUOTA_EXCEEDED"
+ else:
+ return "UNKNOWN_ERROR"
+
+
+def _generate_connection_suggestions(error_code: str) -> list:
+ """Generate suggestions based on the error code."""
+ suggestions = ["Check API key and endpoint URL", "Verify the provider is correctly configured"]
+
+ if error_code == "AUTH_ERROR":
+ suggestions.append("Check if the API key is valid and has the necessary permissions")
+ elif error_code == "ENDPOINT_NOT_FOUND":
+ suggestions.append("Verify the base URL is correct for this provider")
+ elif error_code == "TIMEOUT":
+ suggestions.append("The server took too long to respond. Check network connectivity or try again later")
+ elif error_code == "CONNECTION_ERROR":
+ suggestions.append("Failed to establish connection to the provider. Check network connectivity")
+ elif error_code == "RATE_LIMIT":
+ suggestions.append("You've exceeded the provider's rate limits. Try again later")
+ elif error_code == "QUOTA_EXCEEDED":
+ suggestions.append("You've exceeded your provider quota. Check your usage dashboard")
+
+ suggestions.append("Check server logs for more details")
+ return suggestions
+
+
+def format_provider_connection_error(e: Exception, provider_name: str, config: ProviderConfig) -> JSONResponse:
+ """
+ Format a provider connection error into a standardized response.
+
+ Args:
+ e: The exception that occurred
+ provider_name: Name of the provider
+ config: The provider configuration
+
+ Returns:
+ A JSONResponse with detailed error information
+ """
+ # Get the error message as string
+ error_message = str(e)
+
+ # Determine error code
+ error_code = _determine_error_code(error_message)
+
+ # Create safe configuration without sensitive data
+ safe_config = _create_safe_config(config)
+
+ # Generate provider details
+ provider_details = {
+ "provider": provider_name,
+ "error_type": type(e).__name__,
+ "error_code": error_code,
+ "config": safe_config
+ }
+
+ # Generate suggestions
+ suggestions = _generate_connection_suggestions(error_code)
+
+ # Create the error response
+ error_response = {
+ "title": "Connection failed",
+ "timestamp": datetime.datetime.now().isoformat(),
+ "status": "error",
+ "message": error_message,
+ "name": "Error",
+ "details": "Failed to connect to the provider service.",
+ "provider_details": provider_details,
+ "suggestions": suggestions
+ }
+
+ # Return a structured error response with HTTP 400 status
+ return JSONResponse(
+ status_code=400,
+ content=error_response
+ )
diff --git a/servers/inference_server/server/server/features/providers/models.py b/servers/inference_bridge/server/server/features/providers/models.py
similarity index 54%
rename from servers/inference_server/server/server/features/providers/models.py
rename to servers/inference_bridge/server/server/features/providers/models.py
index 506f7298..928bcd82 100644
--- a/servers/inference_server/server/server/features/providers/models.py
+++ b/servers/inference_bridge/server/server/features/providers/models.py
@@ -5,7 +5,7 @@
Defines data models for the providers API endpoints.
"""
-from typing import List
+from typing import List, Optional
from pydantic import BaseModel, Field
@@ -15,7 +15,6 @@ class ProviderInfo(BaseModel):
name: str = Field(..., description="Unique identifier for the provider")
kind: str = Field(..., description="Type of provider (e.g., 'openai', 'anthropic', 'gemini')")
- default_model: str = Field("", description="Default model used by the provider")
class ProviderListResponse(BaseModel):
@@ -37,3 +36,22 @@ class ProviderModelsResponse(BaseModel):
provider: str = Field(..., description="Name of the provider")
models: List[ModelInfo] = Field(..., description="List of available models")
+
+
+class ProviderConfig(BaseModel):
+ """Provider configuration model."""
+
+ name: str = Field(..., description="Unique identifier for the provider")
+ kind: str = Field(..., description="Type of provider (e.g., 'openai', 'anthropic', 'gemini')")
+ environment: str = Field(..., description="Provider environment (cloud, local)")
+ base_url: str = Field(..., description="Base URL for the provider API")
+ api_key: str = Field(..., description="API key for the provider")
+ models: List[str] = Field(default_factory=list, description="List of available model IDs")
+ fetch_models: bool = Field(default=True, description="Whether to fetch models from the provider API")
+ rate_limits: Optional[dict] = Field(None, description="Rate limiting configuration")
+
+
+class ProviderConfigureRequest(BaseModel):
+ """Request model for configuring a provider."""
+
+ config: ProviderConfig = Field(..., description="Provider configuration")
diff --git a/servers/inference_bridge/server/server/features/providers/router.py b/servers/inference_bridge/server/server/features/providers/router.py
new file mode 100644
index 00000000..05aac4e2
--- /dev/null
+++ b/servers/inference_bridge/server/server/features/providers/router.py
@@ -0,0 +1,47 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Providers Router
+
+Defines API routes for working with AI providers.
+
+This module provides the following endpoints:
+- POST /providers/{provider_name}/test-connection - Test connection to a provider using provided configuration
+"""
+
+import traceback
+
+from fastapi import APIRouter
+from pydantic import ValidationError
+
+from ...utils.logger import logger
+from .error_handler import (format_provider_connection_error,
+ format_provider_validation_error)
+from .models import ProviderConfig
+from .service import test_provider_connection
+
+router = APIRouter(prefix="/providers", tags=["providers"])
+
+@router.post("/{provider_name}/test-connection")
+async def test_connection(provider_name: str, config: ProviderConfig):
+ """
+ Test connection to a provider using provided configuration.
+
+ Args:
+ provider_name: Name of the provider to test connection for
+ config: Provider configuration for this request
+
+ Returns:
+ A success message if connection is successful
+
+ Raises:
+ HTTPException: If there is an error connecting to the provider
+ """
+ try:
+ result = await test_provider_connection(provider_name, config)
+ return {"status": "success", "message": "Connection successful", "result": result}
+ except ValidationError as e:
+ return format_provider_validation_error(e)
+ except Exception as e:
+ logger.error(f"Error testing connection to {provider_name}: {str(e)}")
+ logger.error(traceback.format_exc())
+ return format_provider_connection_error(e, provider_name, config)
diff --git a/servers/inference_bridge/server/server/features/providers/service.py b/servers/inference_bridge/server/server/features/providers/service.py
new file mode 100644
index 00000000..d75c8394
--- /dev/null
+++ b/servers/inference_bridge/server/server/features/providers/service.py
@@ -0,0 +1,215 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Providers Service
+
+Provides services for working with AI providers.
+"""
+
+import datetime
+from typing import Any, Dict, Protocol, cast, runtime_checkable
+
+from loguru import logger
+
+from graphcap.providers.clients.base_client import BaseClient
+from graphcap.providers.factory import create_provider_client
+
+from .models import ModelInfo, ProviderConfig
+
+
+@runtime_checkable
+class ModelProvider(Protocol):
+ """Protocol for model providers"""
+ async def get_available_models(self) -> Any: ...
+ async def get_models(self) -> Any: ...
+
+
+def _extract_model_id(model: Any) -> str:
+ """Extract model ID from provider response"""
+ if hasattr(model, "id"):
+ return model.id
+ if hasattr(model, "name"):
+ return model.name
+ return str(model)
+
+
+def _create_model_info(model_id: str) -> ModelInfo:
+ """Create a ModelInfo instance"""
+ return ModelInfo(
+ id=model_id,
+ name=model_id,
+ is_default=False
+ )
+
+
+def create_provider_client_from_config(config: ProviderConfig) -> BaseClient:
+ """
+ Create a provider client from a configuration.
+
+ Args:
+ config: Provider configuration
+
+ Returns:
+ Provider client
+
+ Raises:
+ ValueError: If client creation fails
+ """
+ logger.info(f"Creating provider client from config for {config.name}")
+ return create_provider_client(
+ name=config.name,
+ kind=config.kind,
+ environment=config.environment,
+ base_url=config.base_url,
+ api_key=config.api_key,
+ rate_limits=config.rate_limits,
+ use_cache=True,
+ )
+
+
+def _create_initial_result(provider_name: str, config: ProviderConfig) -> Dict[str, Any]:
+ """Create initial result structure for connection test"""
+ return {
+ "provider": provider_name,
+ "details": {},
+ "diagnostics": {
+ "config_summary": {
+ "kind": config.kind,
+ "environment": config.environment,
+ "base_url_valid": bool(config.base_url),
+ "api_key_provided": bool(config.api_key),
+ "models_count": len(config.models),
+ },
+ "connection_steps": [],
+ "warnings": []
+ }
+ }
+
+
+def _check_configuration_warnings(result: Dict[str, Any], config: ProviderConfig) -> None:
+ """Check for configuration warnings and add them to the result"""
+ # Check if an empty API key was provided
+ if not config.api_key:
+ result["diagnostics"]["warnings"].append({
+ "warning_type": "empty_api_key",
+ "message": "An empty API key was provided. This might not work with most providers."
+ })
+
+ # Check if the base URL seems valid
+ if not config.base_url.startswith(("http://", "https://")):
+ result["diagnostics"]["warnings"].append({
+ "warning_type": "invalid_base_url",
+ "message": "The base URL doesn't start with http:// or https://"
+ })
+
+
+async def _try_list_models(client: ModelProvider, result: Dict[str, Any]) -> None:
+ """Attempt to list models from the provider"""
+ # Add diagnostic step for model list
+ result["diagnostics"]["connection_steps"].append({
+ "step": "list_models",
+ "status": "pending",
+ "timestamp": str(datetime.datetime.now())
+ })
+
+ try:
+ if hasattr(client, "get_available_models"):
+ provider_models = await client.get_available_models()
+ result["details"]["method"] = "get_available_models"
+
+ if hasattr(provider_models, "data"):
+ _extract_models_data(result, provider_models.data)
+
+ elif hasattr(client, "get_models"):
+ provider_models = await client.get_models()
+ result["details"]["method"] = "get_models"
+
+ if hasattr(provider_models, "models"):
+ _extract_models_data(result, provider_models.models)
+
+ # Update diagnostic step
+ result["diagnostics"]["connection_steps"][-1]["status"] = "success"
+
+ except Exception as e:
+ logger.warning(f"Could not list models: {str(e)}")
+ result["diagnostics"]["connection_steps"][-1]["status"] = "skipped"
+ result["diagnostics"]["connection_steps"][-1]["message"] = "Model listing not supported or failed"
+
+
+def _extract_models_data(result: Dict[str, Any], models_list: Any) -> None:
+ """Extract model data from provider response"""
+ models_data = []
+ for model in models_list:
+ model_id = _extract_model_id(model)
+ models_data.append({"id": model_id})
+
+ result["details"]["available_models"] = models_data
+ result["details"]["models_count"] = len(models_data)
+
+
+async def test_provider_connection(provider_name: str, config: ProviderConfig) -> Dict[str, Any]:
+ """
+ Test connection to a provider by initializing the client and performing a simple operation.
+
+ Args:
+ provider_name: Name of the provider to test
+ config: Provider configuration for this request
+
+ Returns:
+ Dictionary containing test results and additional information
+
+ Raises:
+ Exception: If the connection test fails
+ """
+ result = _create_initial_result(provider_name, config)
+
+ try:
+ # Add diagnostic step
+ result["diagnostics"]["connection_steps"].append({
+ "step": "initialize_client",
+ "status": "pending",
+ "timestamp": str(datetime.datetime.now())
+ })
+
+ # Initialize client with provided configuration
+ client = create_provider_client(
+ name=provider_name,
+ kind=config.kind,
+ environment=config.environment,
+ base_url=config.base_url,
+ api_key=config.api_key,
+ rate_limits=config.rate_limits,
+ use_cache=False, # Don't cache test clients
+ )
+
+ # Update diagnostic step
+ result["diagnostics"]["connection_steps"][-1]["status"] = "success"
+ result["client_initialized"] = True
+
+ # Check for configuration warnings
+ _check_configuration_warnings(result, config)
+
+ # Try to list models
+ await _try_list_models(cast(ModelProvider, client), result)
+
+ # Connection test successful
+ result["connected"] = True
+ result["success"] = True
+ result["message"] = f"Successfully connected to {provider_name}"
+
+ return result
+
+ except Exception as e:
+ logger.error(f"Error testing connection to {provider_name}: {str(e)}")
+
+ # Update the last diagnostic step if it's pending
+ if result["diagnostics"]["connection_steps"] and result["diagnostics"]["connection_steps"][-1]["status"] == "pending":
+ result["diagnostics"]["connection_steps"][-1]["status"] = "failed"
+ result["diagnostics"]["connection_steps"][-1]["error"] = str(e)
+
+ # Add overall failure information
+ result["connected"] = False
+ result["success"] = False
+ result["message"] = f"Failed to connect to {provider_name}: {str(e)}"
+ result["error"] = str(e)
+
+ return result
diff --git a/servers/inference_server/server/server/features/repositories/types.py b/servers/inference_bridge/server/server/features/repositories/types.py
similarity index 100%
rename from servers/inference_server/server/server/features/repositories/types.py
rename to servers/inference_bridge/server/server/features/repositories/types.py
diff --git a/servers/inference_server/server/server/main.py b/servers/inference_bridge/server/server/main.py
similarity index 95%
rename from servers/inference_server/server/server/main.py
rename to servers/inference_bridge/server/server/main.py
index 6eb37a96..a7b41725 100644
--- a/servers/inference_server/server/server/main.py
+++ b/servers/inference_bridge/server/server/main.py
@@ -16,9 +16,10 @@
from .db import init_app_db
from .routers import main_router
from .utils.logger import logger
+from .utils.middleware import setup_middlewares
-class GracefulExit(SystemExit):
+class GracefulExit(Exception):
"""Custom exception for graceful shutdown."""
pass
@@ -84,6 +85,9 @@ async def lifespan(app: FastAPI):
lifespan=lifespan,
)
+# Set up middleware
+setup_middlewares(app)
+
# Configure CORS
app.add_middleware(
CORSMiddleware,
diff --git a/servers/inference_server/server/server/pipelines/__init__py b/servers/inference_bridge/server/server/pipelines/__init__py
similarity index 100%
rename from servers/inference_server/server/server/pipelines/__init__py
rename to servers/inference_bridge/server/server/pipelines/__init__py
diff --git a/servers/inference_server/server/server/pipelines/dagster_client.py b/servers/inference_bridge/server/server/pipelines/dagster_client.py
similarity index 100%
rename from servers/inference_server/server/server/pipelines/dagster_client.py
rename to servers/inference_bridge/server/server/pipelines/dagster_client.py
diff --git a/servers/inference_server/server/server/routers.py b/servers/inference_bridge/server/server/routers.py
similarity index 100%
rename from servers/inference_server/server/server/routers.py
rename to servers/inference_bridge/server/server/routers.py
diff --git a/servers/inference_bridge/server/server/utils/__init__.py b/servers/inference_bridge/server/server/utils/__init__.py
new file mode 100644
index 00000000..46c32ec9
--- /dev/null
+++ b/servers/inference_bridge/server/server/utils/__init__.py
@@ -0,0 +1,17 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Utils Module
+
+Provides utility functions and classes for the FastAPI application.
+
+Key components:
+- logger: Configured loguru logger
+- resizing: Image resizing utilities
+- middleware: FastAPI middleware components
+"""
+
+from . import logger
+from . import resizing
+from . import middleware
+
+__all__ = ["logger", "resizing", "middleware"]
diff --git a/servers/inference_server/server/server/utils/logger.py b/servers/inference_bridge/server/server/utils/logger.py
similarity index 100%
rename from servers/inference_server/server/server/utils/logger.py
rename to servers/inference_bridge/server/server/utils/logger.py
diff --git a/servers/inference_bridge/server/server/utils/middleware.py b/servers/inference_bridge/server/server/utils/middleware.py
new file mode 100644
index 00000000..0633b899
--- /dev/null
+++ b/servers/inference_bridge/server/server/utils/middleware.py
@@ -0,0 +1,208 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+Middleware for FastAPI
+
+Contains middleware components for the FastAPI application.
+"""
+
+import datetime
+import json
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from fastapi import FastAPI, Request, status
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from pydantic import ValidationError
+
+from ..utils.logger import logger
+
+
+class ValidationErrorMiddleware:
+ """
+ Middleware for handling validation errors and providing detailed error messages.
+
+ This middleware intercepts RequestValidationError exceptions and transforms them
+ into user-friendly error responses with specific details about what parameters
+ were invalid.
+ """
+
+ def __init__(self, app: FastAPI):
+ """Initialize the middleware with the FastAPI app."""
+ self.app = app
+
+ # Register the exception handler
+ @app.exception_handler(RequestValidationError)
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ return self.handle_validation_error(request, exc)
+
+ # Register the pydantic ValidationError handler
+ @app.exception_handler(ValidationError)
+ async def pydantic_validation_exception_handler(request: Request, exc: ValidationError):
+ return self.handle_validation_error(request, exc)
+
+ def handle_validation_error(self, request: Request, exc: Union[RequestValidationError, ValidationError]):
+ """
+ Handle validation errors and transform them into detailed error responses.
+
+ Args:
+ request: The FastAPI request
+ exc: The validation exception
+
+ Returns:
+ JSONResponse: A detailed error response
+ """
+ # Extract error details from the exception
+ errors = exc.errors()
+
+ # Log the error
+ logger.error(f"Validation error: {errors}")
+
+ # Generate an overall message about the invalid parameters
+ message = self._generate_overall_message(errors)
+
+ # Generate suggestions based on error types
+ suggestions = self._generate_suggestions(errors)
+
+ # Generate field-specific error details
+ invalid_params = self._format_error_details(errors)
+
+ # Build the response
+ error_response = {
+ "title": "Validation Error",
+ "timestamp": datetime.datetime.now().isoformat(),
+ "message": message,
+ "name": "Error",
+ "details": "The request was rejected due to invalid parameters.",
+ "invalid_parameters": invalid_params,
+ "suggestions": suggestions
+ }
+
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content=error_response
+ )
+
+ def _generate_overall_message(self, errors: Sequence[Dict[str, Any]]) -> str:
+ """
+ Generate a clear overall error message summarizing what's invalid.
+
+ Args:
+ errors: List of error dictionaries
+
+ Returns:
+ A summary message string
+ """
+ # Start with a default message
+ if not errors:
+ return "Invalid request parameters"
+
+ # Count how many fields have errors
+ invalid_fields = set()
+ for error in errors:
+ loc = error.get("loc", [])
+ if len(loc) > 1: # Skip the body/query prefix
+ field_name = loc[1] if isinstance(loc[1], str) else str(loc[1])
+ invalid_fields.add(field_name)
+
+ if len(invalid_fields) == 1:
+ field = next(iter(invalid_fields))
+ return f"Invalid request: '{field}' parameter is invalid"
+ elif len(invalid_fields) > 1:
+ field_list = "', '".join(sorted(invalid_fields))
+ return f"Invalid request: Parameters '{field_list}' are invalid"
+ else:
+ return "Invalid request parameters"
+
+ def _generate_suggestions(self, errors: Sequence[Dict[str, Any]]) -> List[str]:
+ """
+ Generate helpful suggestions based on error types.
+
+ Args:
+ errors: List of error dictionaries
+
+ Returns:
+ List of suggestion strings
+ """
+ suggestions = []
+
+ # Add specific suggestions based on error types
+ for error in errors:
+ error_type = error.get("type", "")
+ field = ".".join(str(loc) for loc in error.get("loc", [])[1:]) if error.get("loc") else ""
+
+ if error_type == "missing":
+ suggestions.append(f"Add the missing required parameter: '{field}'")
+ elif error_type == "string_type":
+ suggestions.append(f"Ensure '{field}' is a valid string")
+ elif error_type == "url_parsing":
+ suggestions.append(f"Use a valid URL format for '{field}'")
+ elif error_type and "enum" in error_type:
+ valid_values = error.get("ctx", {}).get("expected", [])
+ if valid_values:
+ values_str = ", ".join([f"'{v}'" for v in valid_values])
+ suggestions.append(f"Choose a valid option for '{field}': {values_str}")
+ else:
+ suggestions.append(f"Choose a valid option for '{field}'")
+ elif error_type == "value_error":
+ suggestions.append(f"Provide a valid value for '{field}'")
+ elif error_type == "type_error":
+ suggestions.append(f"Check the data type for '{field}'")
+
+ # Add generic suggestion at the end
+ suggestions.append("Check the documentation for correct parameter formats")
+
+ # Return unique suggestions
+ return list(dict.fromkeys(suggestions))
+
+ def _format_error_details(self, errors: Sequence[Dict[str, Any]]) -> Dict[str, Dict[str, str]]:
+ """
+ Format validation errors into a structured dictionary.
+
+ Args:
+ errors: List of error dictionaries
+
+ Returns:
+ Dictionary of field names to error details
+ """
+ invalid_params = {}
+
+ for error in errors:
+ # Extract location (field name)
+ location = error.get("loc", [])
+ if len(location) < 2:
+ continue
+
+ # Skip the first element (usually 'body' or 'query')
+ field_path = ".".join(str(loc) for loc in location[1:])
+
+ # Extract error message and type
+ message = error.get("msg", "Validation error")
+ error_type = error.get("type", "unknown_error")
+
+ # Add any context information from the error
+ context = {}
+ if error.get("ctx"):
+ for key, value in error.get("ctx", {}).items():
+ if key != "expected" or not isinstance(value, list) or len(value) < 5:
+ context[key] = value
+
+ invalid_params[field_path] = {
+ "message": message,
+ "error_type": error_type
+ }
+
+ # Add context if available
+ if context:
+ invalid_params[field_path]["context"] = context
+
+ return invalid_params
+
+
+def setup_middlewares(app: FastAPI) -> None:
+ """
+ Set up all middleware for the FastAPI application.
+
+ Args:
+ app: The FastAPI application instance
+ """
+ # Initialize the validation error middleware
+ ValidationErrorMiddleware(app)
\ No newline at end of file
diff --git a/servers/inference_server/server/server/utils/resizing.py b/servers/inference_bridge/server/server/utils/resizing.py
similarity index 100%
rename from servers/inference_server/server/server/utils/resizing.py
rename to servers/inference_bridge/server/server/utils/resizing.py
diff --git a/servers/inference_server/server/uv.lock b/servers/inference_bridge/server/uv.lock
similarity index 100%
rename from servers/inference_server/server/uv.lock
rename to servers/inference_bridge/server/uv.lock
diff --git a/servers/inference_server/tests/test_perspective_modules.py b/servers/inference_bridge/tests/test_perspective_modules.py
similarity index 100%
rename from servers/inference_server/tests/test_perspective_modules.py
rename to servers/inference_bridge/tests/test_perspective_modules.py
diff --git a/servers/inference_server/uv.lock b/servers/inference_bridge/uv.lock
similarity index 100%
rename from servers/inference_server/uv.lock
rename to servers/inference_bridge/uv.lock
diff --git a/servers/inference_server/graphcap/perspectives/base_caption.py b/servers/inference_server/graphcap/perspectives/base_caption.py
deleted file mode 100644
index 737e09a3..00000000
--- a/servers/inference_server/graphcap/perspectives/base_caption.py
+++ /dev/null
@@ -1,413 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Base Caption Module
-
-Provides base classes and shared functionality for different caption types.
-"""
-
-import asyncio
-import json
-import shutil
-from abc import ABC, abstractmethod
-from datetime import datetime
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
-from loguru import logger
-from pydantic import BaseModel
-from rich.console import Console
-from rich.table import Table
-from tenacity import retry, stop_after_attempt, wait_exponential
-from tqdm.asyncio import tqdm_asyncio
-
-from ..providers.clients.base_client import BaseClient
-from .types import StructuredVisionConfig
-
-# Initialize Rich console
-console = Console()
-
-
-def pretty_print_caption(caption_data: Dict[str, Any]) -> str:
- """Format caption data for pretty console output."""
- return json.dumps(caption_data["parsed"], indent=2, ensure_ascii=False)
-
-
-class BaseCaptionProcessor(ABC):
- """
- Base class for caption processors.
-
- Provides shared functionality for processing images with vision models
- and handling responses. Subclasses implement specific caption formats.
-
- Attributes:
- config_name (str): Name of this caption processor
- version (str): Version of the processor
- prompt (str): Instruction prompt for the vision model
- schema (BaseModel): Pydantic model for response validation
- """
-
- def __init__(
- self,
- config_name: str,
- version: str,
- prompt: str,
- schema: type[BaseModel],
- ):
- self.vision_config = StructuredVisionConfig(
- config_name=config_name,
- version=version,
- prompt=prompt,
- schema=schema,
- )
-
- def _sanitize_json_string(self, text: str) -> str:
- """
- Sanitize JSON string by properly escaping control characters.
-
- Args:
- text: Raw JSON string that may contain control characters
-
- Returns:
- Sanitized JSON string with properly escaped control characters
- """
- # Define escape sequences for common control characters
- control_char_map = {
- "\n": "\\n", # Line feed
- "\r": "\\r", # Carriage return
- "\t": "\\t", # Tab
- "\b": "\\b", # Backspace
- "\f": "\\f", # Form feed
- "\v": "\\u000b", # Vertical tab
- "\0": "", # Null character - remove it
- }
-
- # First pass: handle known control characters
- for char, escape_seq in control_char_map.items():
- text = text.replace(char, escape_seq)
-
- # Second pass: handle any remaining control characters
- result = ""
- for char in text:
- if ord(char) < 32: # Control characters are below ASCII 32
- result += f"\\u{ord(char):04x}"
- else:
- result += char
-
- return result
-
- @abstractmethod
- def create_rich_table(self, caption_data: Dict[str, Any]) -> Table:
- """
- Create a Rich table for displaying caption data.
-
- Args:
- caption_data: The caption data to format
-
- Returns:
- Rich Table object for display
- """
- pass
-
- async def process_single(
- self,
- provider: BaseClient,
- image_path: Path,
- max_tokens: Optional[int] = 4096,
- temperature: Optional[float] = 0.8,
- top_p: Optional[float] = 0.9,
- repetition_penalty: Optional[float] = 1.15,
- context: list[str] | None = None,
- global_context: str | None = None,
- ) -> dict:
- """
- Process a single image and return caption data.
-
- Args:
- provider: Vision AI provider client instance
- image_path: Path to the image file
- max_tokens: Maximum tokens for model response
- temperature: Sampling temperature
- top_p: Nucleus sampling parameter
-
- Returns:
- dict: Structured caption data according to schema
-
- Raises:
- Exception: If image processing fails
- """
- if context or global_context:
- context_block = " Consider the following context when generating the caption:\n"
- if global_context:
- context_block += f"\n{global_context}\n\n"
- if context:
- for entry in context:
- context_block += f"\n{entry}\n\n"
- context_block += "\n"
- prompt = f"{context_block}{self.vision_config.prompt}"
- else:
- prompt = self.vision_config.prompt
- try:
- completion = await provider.vision(
- prompt=prompt,
- image=image_path,
- schema=self.vision_config.schema,
- model=provider.default_model,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- )
-
- # Handle response parsing with sanitization
- if isinstance(completion, BaseModel):
- result = completion.choices[0].message.parsed
- if isinstance(result, BaseModel):
- result = result.model_dump()
- else:
- result = completion.choices[0].message.parsed
- # Handle string responses that need parsing
- if isinstance(result, str):
- sanitized = self._sanitize_json_string(result)
- try:
- result = json.loads(sanitized)
- except json.JSONDecodeError as e:
- logger.error(f"Failed to parse sanitized JSON: {e}")
- raise
- elif "choices" in result:
- result = result["choices"][0]["message"]["parsed"]["parsed"]
- elif "message" in result:
- result = result["message"]["parsed"]
-
- return result
- except Exception as e:
- raise Exception(f"Error processing {image_path}: {str(e)}")
-
- async def process_batch(
- self,
- provider: BaseClient,
- image_paths: List[Path],
- max_tokens: Optional[int] = 4096,
- temperature: Optional[float] = 0.8,
- top_p: Optional[float] = 0.9,
- max_concurrent: Optional[int] = 1,
- repetition_penalty: Optional[float] = 1.15,
- output_dir: Optional[Path] = None,
- store_logs: bool = False,
- formats: Optional[List[str]] = None,
- copy_images: bool = False,
- global_context: str | None = None,
- contexts: dict[str, list[str]] | None = None,
- name: str | None = None,
- ) -> List[Dict[str, Any]]:
- """
- Process multiple images and return their captions.
-
- Args:
- provider: Vision AI provider client instance
- image_paths: List of paths to image files
- max_tokens: Maximum tokens for model response
- temperature: Sampling temperature
- top_p: Nucleus sampling parameter
- max_concurrent: Maximum number of concurrent API requests
- output_dir: Directory to store incremental results and job info
- store_logs: Whether to store logs in the output directory
- formats: List of additional formats to write caption data
- copy_images: Whether to copy images to the output directory
- contexts: Additional context for the vision model based on image paths
- Returns:
- List[Dict[str, Any]]: List of caption results with metadata
- """
- # Create job directory with timestamp if output_dir provided
- job_dir = None
- job_output = None
- job_info = None
- log_file = None
-
- if output_dir:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- job_dir = output_dir / f"batch_{name or timestamp}"
- job_dir.mkdir(parents=True, exist_ok=True)
-
- # Create output file and job info
- job_output = job_dir / "captions.jsonl"
- job_info = job_dir / "job_info.json"
-
- # Configure logging if requested
- if store_logs:
- log_file = job_dir / "process.log"
- # Add file logger while keeping console output
- logger.add(
- log_file,
- format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
- level="INFO",
- rotation="100 MB",
- )
-
- # Write initial job info
- job_info_data = {
- "started_at": timestamp,
- "provider": provider.name,
- "model": provider.default_model,
- "config_name": self.vision_config.config_name,
- "version": self.vision_config.version,
- "total_images": len(image_paths),
- "sampling": {
- "original_count": getattr(image_paths, "original_count", len(image_paths)),
- "sample_size": getattr(image_paths, "sample_size", len(image_paths)),
- "sample_method": getattr(image_paths, "sample_method", "all"),
- },
- "params": {
- "max_tokens": max_tokens,
- "temperature": temperature,
- "top_p": top_p,
- "max_concurrent": max_concurrent,
- "repetition_penalty": repetition_penalty,
- },
- "log_file": str(log_file.relative_to(job_dir)) if log_file else None,
- "formats": formats or [],
- "copy_images": copy_images,
- "global_context": global_context,
- }
- job_info.write_text(json.dumps(job_info_data, indent=2))
-
- # Copy images if requested
- if copy_images:
- images_dir = job_dir / "images"
- images_dir.mkdir(exist_ok=True)
- for path in image_paths:
- try:
- shutil.copy2(path, images_dir / path.name)
- except Exception as e:
- logger.error(f"Failed to copy image {path}: {e}")
-
- logger.info(f"Processing {len(image_paths)} images with {provider.name} provider")
- logger.info(f"Using max concurrency of {max_concurrent} requests")
- if job_dir:
- logger.info(f"Writing results to {job_dir}")
- if log_file:
- logger.info(f"Logging to {log_file}")
-
- semaphore = asyncio.Semaphore(max_concurrent)
- active_requests = 0
- processed_count = 0
- failed_count = 0
-
- @retry(
- stop=stop_after_attempt(3),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- reraise=True,
- )
- async def process_with_semaphore(path: Path) -> Dict[str, Any]:
- nonlocal active_requests, processed_count, failed_count
-
- async with semaphore:
- try:
- active_requests += 1
- logger.info(f"Starting request for {path.name} (Active requests: {active_requests})")
-
- result = await self.process_single(
- provider=provider,
- image_path=path,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- context=contexts.get(path.name) if contexts else None,
- global_context=global_context,
- )
-
- active_requests -= 1
- processed_count += 1
- logger.info(f"Completed request for {path.name} (Active requests: {active_requests})")
-
- caption_data = {
- "filename": f"./{path.name}",
- "config_name": self.vision_config.config_name,
- "version": self.vision_config.version,
- "model": provider.default_model,
- "provider": provider.name,
- "parsed": result,
- }
-
- # Write result incrementally if output file exists
- if job_output:
- with job_output.open("a") as f:
- f.write(json.dumps(caption_data) + "\n")
-
- # Update job info
- job_info_data["processed_count"] = processed_count
- job_info_data["failed_count"] = failed_count
- job_info_data["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S")
- job_info.write_text(json.dumps(job_info_data, indent=2))
-
- # Create and display Rich table
- console.print(f"\n[bold cyan]Processed {path.name}:[/bold cyan]")
- table = self.create_rich_table(caption_data)
- console.print(table)
-
- return caption_data
- except Exception as e:
- active_requests -= 1
- failed_count += 1
- logger.error(f"Failed request for {path.name} (Active requests: {active_requests})")
- error_data = {
- "filename": f"./{path.name}",
- "config_name": self.vision_config.config_name,
- "version": self.vision_config.version,
- "model": provider.default_model,
- "provider": provider.name,
- "parsed": {"error": str(e)},
- }
-
- # Write error result if output file exists
- if job_output:
- with job_output.open("a") as f:
- f.write(json.dumps(error_data) + "\n")
-
- # Update job info
- job_info_data["processed_count"] = processed_count
- job_info_data["failed_count"] = failed_count
- job_info_data["completed_at"] = datetime.now().strftime("%Y%m%d_%H%M%S")
- job_info.write_text(json.dumps(job_info_data, indent=2))
-
- console.print(f"\n[bold red]Failed to process {path.name}:[/bold red] {str(e)}")
- return error_data
-
- results = await tqdm_asyncio.gather(
- *[process_with_semaphore(path) for path in image_paths],
- desc=f"Processing images with {provider.name}",
- )
-
- # Log summary with Rich
- success_count = sum(1 for r in results if "error" not in r["parsed"])
- summary_table = Table(title="Processing Summary", show_header=False)
- summary_table.add_column("Metric", style="cyan")
- summary_table.add_column("Value", style="green")
- summary_table.add_row("Total Images", str(len(results)))
- summary_table.add_row("Successful", str(success_count))
- summary_table.add_row("Failed", str(len(results) - success_count))
-
- console.print("\n")
- console.print(summary_table)
-
- return results
-
- @abstractmethod
- def to_table(self, caption_data: Dict[str, Any]) -> Dict[str, Any]:
- """
- Convert caption data to a flat dictionary suitable for tabular representation.
-
- Args:
- caption_data: The caption data to format
-
- Returns:
- Dict[str, Any]: Flattened dictionary for tabular representation
- """
- pass
-
- @abstractmethod
- def to_context(self, caption_data: Dict[str, Any]) -> Dict[str, Any]:
- """
- Convert caption data to a context string suitable for downstream perspectives.
- """
- pass
diff --git a/servers/inference_server/graphcap/providers/README.md b/servers/inference_server/graphcap/providers/README.md
deleted file mode 100644
index 545981bc..00000000
--- a/servers/inference_server/graphcap/providers/README.md
+++ /dev/null
@@ -1,131 +0,0 @@
-# Provider Management System
-
-A flexible provider management system for handling multiple AI service providers with an OpenAI-compatible interface.
-
-## Overview
-
-This system provides a unified way to manage and interact with different AI providers, including cloud-based and custom implementations. It supports standard chat completions, vision capabilities, and structured outputs across providers.
-
-## Configuration
-
-### Provider Config File (provider.config.toml)
-
-The system is configured using a TOML file that defines both cloud and custom providers:
-
-```toml
-[provider.cloud.openai]
-api_key = "OPENAI_API_KEY"
-base_url = "https://api.openai.com/v1"
-models = ["gpt-4-vision", "gpt-4"]
-
-[provider.cloud.gemini]
-api_key = "GOOGLE_API_KEY"
-base_url = "https://generativelanguage.googleapis.com/v1beta"
-models = ["gemini-2.0-flash-exp"]
-
-[provider.cloud.openrouter]
-api_key = "OPENROUTER_API_KEY"
-base_url = "https://openrouter.ai/api/v1"
-models = ["openai/gpt-4", "google/gemini-2.0-flash-exp:free"]
-
-[providers.custom.ollama]
-api_key = ""
-base_url = "http://localhost:11434"
-fetch_models = true
-
-[providers.custom.vllm-pixtral]
-api_key = ""
-base_url = "http://localhost:11435"
-models = ["vision-worker"]
-```
-
-### Provider Types
-
-1. Cloud Providers (`provider.cloud.*`)
- - OpenAI
- - Gemini
- - OpenRouter
-
-2. Custom Providers (`providers.custom.*`)
- - Ollama
- - VLLM
- - Other custom implementations
-
-## Usage
-
-### Basic Usage
-
-```python
-from graphcap.providers.provider_manager import ProviderManager
-
-# Initialize the manager
-manager = ProviderManager("provider.config.toml")
-
-# Get all initialized clients
-clients = manager.clients()
-
-# Get a specific client
-openai_client = manager.get_client("cloud.openai")
-gemini_client = manager.get_client("cloud.gemini")
-```
-
-### Provider Clients
-
-All provider clients inherit from BaseClient and implement an OpenAI-compatible interface:
-
-- `OpenAIClient`: Standard OpenAI implementation
-- `GeminiClient`: Google's Gemini API implementation
-- `OpenRouterClient`: OpenRouter API implementation
-- `OllamaClient`: Ollama-specific implementation
-- `VLLMClient`: VLLM-specific implementation
-
-### Vision Capabilities
-
-All providers support a unified vision interface:
-
-```python
-completion = client.vision(
- prompt="What's in this image?",
- image=image_path,
- model=client.default_model
-)
-```
-
-### Structured Output
-
-Providers support structured completions using JSON schemas or Pydantic models:
-
-```python
-completion = client.create_structured_completion(
- messages=messages,
- schema=MyPydanticModel,
- model="model-name"
-)
-```
-
-## Features
-
-- **Unified Interface**: All providers use an OpenAI-compatible interface
-- **Vision Support**: Standardized vision capabilities across providers
-- **Structured Output**: JSON schema and Pydantic model support
-- **Configuration Management**: TOML-based configuration
-- **Automatic Initialization**: Providers are initialized at startup
-- **Error Handling**: Robust error handling with detailed logging
-- **Caching**: Clients are cached after initialization
-
-## REST API
-
-The system includes a FastAPI router with endpoints:
-
-- `GET /providers/`: List all available providers
-- `GET /providers/{provider_name}`: Get provider details
-- `POST /providers/{provider_name}/vision`: Analyze image with provider
-
-## Error Handling
-
-The system includes comprehensive error handling:
-- Configuration validation
-- Client initialization errors
-- Runtime errors with detailed logging
-- API error responses
-
diff --git a/servers/inference_server/graphcap/providers/factory.py b/servers/inference_server/graphcap/providers/factory.py
deleted file mode 100644
index efbac5ab..00000000
--- a/servers/inference_server/graphcap/providers/factory.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Provider Factory Module
-
-This module provides factory functions for creating provider clients.
-"""
-
-import os
-import tempfile
-from pathlib import Path
-from typing import Optional
-
-from loguru import logger
-
-from .clients import BaseClient, get_client
-from .provider_config import get_providers_config
-from .provider_manager import ProviderManager
-
-# Global provider manager instance
-_provider_manager: Optional[ProviderManager] = None
-
-
-def initialize_provider_manager(config_path: Optional[str | Path] = None) -> ProviderManager:
- """Initialize the global provider manager with the given config path.
-
- Args:
- config_path: Path to the provider configuration file. If None, uses default locations.
-
- Returns:
- ProviderManager: The initialized provider manager
- """
- global _provider_manager
-
- if config_path is None:
- # Try to find config in standard locations
- possible_paths = [
- os.environ.get("PROVIDER_CONFIG_PATH"),
- "./provider.config.toml",
- "./config/provider.config.toml",
- "/app/provider.config.toml",
- "/app/config/provider.config.toml",
- ]
-
- for path in possible_paths:
- if path and Path(path).exists():
- config_path = path
- break
-
- if not config_path or not Path(str(config_path)).exists():
- logger.warning(f"No provider config found at {config_path}. Using empty configuration.")
- # Create a temporary empty config file
- with tempfile.NamedTemporaryFile(delete=False, suffix=".toml") as temp:
- temp.write(b"# Empty provider config\n")
- config_path = temp.name
-
- # At this point, config_path should not be None
- _provider_manager = ProviderManager(str(config_path))
- return _provider_manager
-
-
-def get_provider_client(provider_name: str = "default") -> BaseClient:
- """Get a provider client by name.
-
- Args:
- provider_name: Name of the provider to get. Defaults to "default".
-
- Returns:
- BaseClient: The provider client
-
- Raises:
- ValueError: If the provider is not found
- """
- global _provider_manager
-
- if _provider_manager is None:
- initialize_provider_manager()
-
- if _provider_manager is None:
- raise ValueError("Failed to initialize provider manager")
-
- try:
- return _provider_manager.get_client(provider_name)
- except ValueError as e:
- logger.error(f"Failed to get provider client: {e}")
- raise
diff --git a/servers/inference_server/graphcap/providers/provider_config.py b/servers/inference_server/graphcap/providers/provider_config.py
deleted file mode 100644
index a5e4c4da..00000000
--- a/servers/inference_server/graphcap/providers/provider_config.py
+++ /dev/null
@@ -1,163 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Provider Configuration Module
-
-This module handles loading and validating provider configurations from TOML files.
-
-Key features:
-- TOML configuration loading
-- Provider config validation
-- Default model handling
-- Environment variable management
-
-Classes:
- ProviderConfig: Configuration dataclass for providers
-
-Functions:
- load_provider_config: Load config from TOML file
- parse_provider_config: Parse config into ProviderConfig object
- get_providers_config: Load and parse all provider configs
- validate_config: Validate provider configurations
-"""
-
-import tomllib
-from pathlib import Path
-from typing import Any
-
-from loguru import logger
-
-from .types import ProviderConfig, RateLimits
-
-
-def _load_provider_config(config_path: str | Path) -> dict[str, ProviderConfig]:
- """Load provider configuration from a TOML file."""
-
- config_path = Path(config_path)
-
- if not config_path.exists():
- raise FileNotFoundError(f"Configuration file not found: {config_path}")
-
- config_data = {}
- with config_path.open("rb") as f:
- config_data = tomllib.load(f)
- return config_data
-
-
-def _parse_provider_config(config_data: dict[str, Any]) -> ProviderConfig:
- """Parse a provider's configuration data into a ProviderConfig object"""
- # Get models list and default model
- models: list[str] = config_data.get("models", [])
- default_model: str = config_data.get("default_model", "")
- fetch_models: bool = config_data.get("fetch_models", False)
-
- kind: str = config_data["kind"]
- environment: str = config_data["environment"]
- env_var: str = config_data.get("env_var", "")
- base_url: str = config_data["base_url"]
-
- # If no default model specified, require one to be set
- if not default_model:
- if models:
- default_model = models[0]
- logger.debug(f"Using first model as default: {default_model}")
- else:
- raise ValueError("Must specify default_model when no models list is provided")
-
- # Parse rate limits if present
- rate_limits = None
- if "rate_limits" in config_data:
- rate_limits_data: dict[str, int | None] = config_data["rate_limits"]
- rate_limits = RateLimits(
- requests_per_minute=rate_limits_data.get("requests_per_minute"),
- tokens_per_minute=rate_limits_data.get("tokens_per_minute"),
- )
-
- return ProviderConfig(
- kind=kind,
- environment=environment,
- env_var=env_var,
- base_url=base_url,
- models=models,
- default_model=default_model,
- fetch_models=fetch_models,
- rate_limits=rate_limits,
- )
-
-
-def get_providers_config(config_path: str | Path) -> dict[str, ProviderConfig]:
- """
- Load and parse the providers configuration.
-
-
- Args:
- config_path: Path to the TOML configuration file
-
- Returns:
- Dictionary mapping provider names to their configurations
-
- Example config:
- [openai]
- kind = "openai"
- environment = "cloud"
- env_var = "OPENAI_API_KEY"
- base_url = "https://api.openai.com/v1"
- models = ["gpt-4o", "gpt-4o-mini"]
- default_model = "gpt-4o-mini" # Optional, defaults to first model in list
-
- [ollama]
- kind = "ollama"
- environment = "local"
- env_var = "CUSTOM_KEY"
- base_url = "http://localhost:11434"
- fetch_models = true
- default_model = "llama3.2" # Optional, defaults to "default" if no models
- """
- config = _load_provider_config(config_path)
- providers = {}
-
- # Parse all top-level provider configs
- for name, provider_config in config.items():
- if isinstance(provider_config, dict): # Skip non-provider sections
- try:
- providers[name] = _parse_provider_config(provider_config)
- except KeyError as e:
- logger.warning(f"Skipping provider '{name}': Missing required field {e}")
- provider_errors = validate_config(providers)
- if provider_errors:
- logger.error(f"Provider configuration errors: {provider_errors}")
- raise ValueError(f"Provider configuration errors: {provider_errors}")
- logger.info(f"Loaded {len(providers)} providers")
- logger.debug(f"Providers: {providers}")
- return providers
-
-
-def validate_config(providers: dict[str, ProviderConfig]) -> list[str]:
- """Validate the provider configuration."""
- errors: list[str] = []
-
- for name, provider in providers.items():
- # Required fields
- if not provider.base_url:
- errors.append(f"{name}: Missing base URL")
- if not provider.kind:
- errors.append(f"{name}: Missing kind")
- if not provider.environment:
- errors.append(f"{name}: Missing environment")
- if not provider.default_model:
- errors.append(f"{name}: Missing default_model")
-
- # Environment validation
- if provider.environment not in ["cloud", "local"]:
- errors.append(f"{name}: Environment must be 'cloud' or 'local'")
-
- # URL format
- if provider.base_url and not (
- provider.base_url.startswith("http://") or provider.base_url.startswith("https://")
- ):
- errors.append(f"{name}: Base URL must start with http:// or https://")
-
- # Models list when fetch_models is False
- if not provider.fetch_models and not provider.models:
- errors.append(f"{name}: Must specify models list when fetch_models is False")
-
- return errors
diff --git a/servers/inference_server/graphcap/providers/provider_manager.py b/servers/inference_server/graphcap/providers/provider_manager.py
deleted file mode 100644
index 8db97294..00000000
--- a/servers/inference_server/graphcap/providers/provider_manager.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Provider Manager Module
-
-This module handles provider lifecycle management and client initialization.
-
-Key features:
-- Provider configuration loading
-- Client initialization and caching
-- Environment validation
-- Rate limit management
-
-Classes:
- ProviderManager: Main provider management class
-"""
-
-from pathlib import Path
-from typing import Dict
-
-from loguru import logger
-
-from .clients import BaseClient, get_client
-from .provider_config import get_providers_config
-from .types import ProviderConfig
-
-
-class ProviderManager:
- """Manager class for handling provider lifecycle and client initialization"""
-
- def __init__(self, config_path: str | Path):
- """Initialize provider manager with configuration file"""
- logger.info(f"Initializing ProviderManager with config from: {config_path}")
- self.providers = get_providers_config(config_path)
- self._clients: Dict[str, BaseClient] = {}
- logger.info(f"Loaded {len(self.providers)} provider configurations")
- for name, config in self.providers.items():
- logger.info(f"Provider '{name}' configuration:")
- logger.info(f" - kind: {config.kind}")
- logger.info(f" - environment: {config.environment}")
- logger.info(f" - base_url: {config.base_url}")
- logger.info(f" - default_model: {config.default_model}")
- if config.rate_limits:
- logger.info(f" - rate_limits: {config.rate_limits}")
-
- def get_client(self, provider_name: str) -> BaseClient:
- """Get or create a client for the specified provider"""
- if provider_name not in self.providers:
- logger.error(f"Requested unknown provider: {provider_name}")
- logger.debug(f"Available providers: {', '.join(self.providers.keys())}")
- raise ValueError(f"Unknown provider: {provider_name}")
-
- # Return cached client if available
- if provider_name in self._clients:
- logger.debug(f"Using cached client for provider: {provider_name}")
- return self._clients[provider_name]
-
- # Create new client
- config = self.providers[provider_name]
- logger.info(f"Initializing new client for provider: {provider_name}")
- logger.info(f"Provider config details:")
- logger.info(f" - kind: {config.kind}")
- logger.info(f" - environment: {config.environment}")
- logger.info(f" - base_url: {config.base_url}")
- logger.info(f" - default_model: {config.default_model}")
-
- try:
- client = get_client(
- name=provider_name,
- kind=config.kind,
- environment=config.environment,
- env_var=config.env_var,
- base_url=config.base_url,
- default_model=config.default_model,
- )
-
- # Set rate limits if configured
- if config.rate_limits:
- logger.debug(
- f"Setting rate limits for {provider_name} - requests: {config.rate_limits.requests_per_minute}/min, tokens: {config.rate_limits.tokens_per_minute}/min"
- )
- client.requests_per_minute = config.rate_limits.requests_per_minute
- client.tokens_per_minute = config.rate_limits.tokens_per_minute
-
- self._clients[provider_name] = client
- logger.info(f"Successfully initialized client for provider: {provider_name}")
- return client
-
- except Exception as e:
- logger.error(f"Failed to initialize client for {provider_name}: {str(e)}")
- logger.error(f"Provider config details:")
- logger.error(f" - kind: {config.kind}")
- logger.error(f" - environment: {config.environment}")
- logger.error(f" - base_url: {config.base_url}")
- logger.error(f" - default_model: {config.default_model}")
- raise
-
- def clients(self) -> Dict[str, BaseClient]:
- """Get all initialized clients"""
- logger.debug(f"Returning {len(self._clients)} initialized clients")
- return self._clients.copy()
-
- def available_providers(self) -> list[str]:
- """Get list of available provider names"""
- providers = list(self.providers.keys())
- logger.debug(f"Available providers: {', '.join(providers)}")
- return providers
-
- def get_provider_config(self, provider_name: str) -> ProviderConfig:
- """Get configuration for a specific provider"""
- if provider_name not in self.providers:
- logger.error(f"Requested config for unknown provider: {provider_name}")
- raise ValueError(f"Unknown provider: {provider_name}")
- logger.debug(f"Returning config for provider: {provider_name}")
- return self.providers[provider_name]
diff --git a/servers/inference_server/pipelines/pipelines/providers/assets.py b/servers/inference_server/pipelines/pipelines/providers/assets.py
deleted file mode 100644
index c5952a59..00000000
--- a/servers/inference_server/pipelines/pipelines/providers/assets.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Assets for loading provider configurations."""
-
-import dagster as dg
-from graphcap.providers.provider_config import get_providers_config
-from graphcap.providers.types import ProviderConfig
-
-from ..common.resources import ProviderConfigFile
-
-
-@dg.asset(compute_kind="python", group_name="providers")
-def provider_list(
- context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile
-) -> dict[str, ProviderConfig]:
- """Loads the list of providers from the provider.config.toml file."""
- config_path = provider_config_file.provider_config
- try:
- providers = get_providers_config(config_path)
- context.log.info(f"Loaded providers from {config_path}")
- provider_info = [f"{name}: {provider.default_model}" for name, provider in providers.items()]
- context.add_output_metadata(
- {
- "num_providers": len(providers),
- "config_path": config_path,
- "providers": ", ".join(provider_info),
- }
- )
- return providers
- except FileNotFoundError:
- context.log.error(f"Provider config file not found: {config_path}")
- return {}
- except Exception as e:
- context.log.error(f"Error loading provider config: {e}")
- return {}
-
-
-# TODO: Remove this asset
-@dg.asset(compute_kind="python", group_name="providers")
-def default_provider(context: dg.AssetExecutionContext, provider_config_file: ProviderConfigFile) -> str | None:
- """Loads the default provider based on the selected_provider config."""
- config_path = provider_config_file.provider_config
- try:
- providers = get_providers_config(config_path)
- selected_provider_name = provider_config_file.default_provider
-
- if selected_provider_name not in providers:
- context.log.warning(f"Selected provider '{selected_provider_name}' not found in config.")
- return None
-
- selected_provider_config = providers[selected_provider_name]
-
- context.log.info(f"Loaded default provider: {selected_provider_name}")
- context.add_output_metadata(
- {
- "selected_provider": selected_provider_name,
- "provider_kind": selected_provider_config.kind,
- "provider_environment": selected_provider_config.environment,
- "provider_default_model": selected_provider_config.default_model,
- }
- )
- return selected_provider_name
- except FileNotFoundError:
- context.log.error(f"Provider config file not found: {config_path}")
- return None
- except Exception as e:
- context.log.error(f"Error loading provider config: {e}")
- return None
diff --git a/servers/inference_server/pipelines/pipelines/providers/util.py b/servers/inference_server/pipelines/pipelines/providers/util.py
deleted file mode 100644
index 5f058ba2..00000000
--- a/servers/inference_server/pipelines/pipelines/providers/util.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from graphcap.providers.provider_config import get_providers_config
-from graphcap.providers.clients import get_client
-from ..perspectives.jobs.config import PerspectivePipelineConfig
-
-
-def get_provider(config_path: str, default_provider: str):
- """Instantiates the client based on the provider configuration.
-
- Args:
- config_path (str): Path to the provider configuration file.
- default_provider (str): The name of the default provider.
-
- Returns:
- The instantiated client.
- """
- providers = get_providers_config(config_path)
- selected_provider_config = providers[default_provider]
- client_args = {
- "name": default_provider,
- "environment": selected_provider_config.environment,
- "env_var": selected_provider_config.env_var,
- "base_url": selected_provider_config.base_url,
- "default_model": selected_provider_config.default_model,
- }
- client = get_client(selected_provider_config.kind, **client_args)
- return client
diff --git a/servers/inference_server/server/server/features/providers/router.py b/servers/inference_server/server/server/features/providers/router.py
deleted file mode 100644
index 4e71725a..00000000
--- a/servers/inference_server/server/server/features/providers/router.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Providers Router
-
-Defines API routes for working with AI providers.
-
-This module provides the following endpoints:
-- GET /providers/list - List all available providers
-- GET /providers/check/{provider_name} - Check if a specific provider is available
-- GET /providers/{provider_name}/models - List available models for a specific provider
-"""
-
-from fastapi import APIRouter, HTTPException
-
-from .models import ProviderListResponse, ProviderModelsResponse
-from .service import get_available_providers, get_provider_manager, get_provider_models
-
-router = APIRouter(prefix="/providers", tags=["providers"])
-
-
-@router.get("/list", response_model=ProviderListResponse)
-async def list_providers() -> ProviderListResponse:
- """
- List all available providers.
-
- Returns:
- List of available providers
- """
- providers = get_available_providers()
- return ProviderListResponse(providers=providers)
-
-
-@router.get("/check/{provider_name}")
-async def check_provider(provider_name: str) -> dict:
- """
- Check if a specific provider is available.
-
- Args:
- provider_name: Name of the provider to check
-
- Returns:
- Status of the provider
-
- Raises:
- HTTPException: If the provider is not found
- """
- provider_manager = get_provider_manager()
- available_providers = provider_manager.available_providers()
-
- if provider_name not in available_providers:
- raise HTTPException(
- status_code=404,
- detail=f"Provider '{provider_name}' not found. Available providers: {', '.join(available_providers)}",
- )
-
- # Get the provider config
- provider_config = provider_manager.get_provider_config(provider_name)
-
- return {
- "status": "available",
- "provider": provider_name,
- "kind": provider_config.kind,
- "environment": provider_config.environment,
- "default_model": provider_config.default_model or "",
- }
-
-
-@router.get("/{provider_name}/models", response_model=ProviderModelsResponse)
-async def list_provider_models(provider_name: str) -> ProviderModelsResponse:
- """
- List available models for a specific provider.
-
- Args:
- provider_name: Name of the provider to get models for
-
- Returns:
- List of available models for the provider
-
- Raises:
- HTTPException: If the provider is not found
- """
- try:
- models = await get_provider_models(provider_name)
- return ProviderModelsResponse(provider=provider_name, models=models)
- except ValueError as e:
- raise HTTPException(status_code=404, detail=str(e))
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"Error getting models: {str(e)}")
diff --git a/servers/inference_server/server/server/features/providers/service.py b/servers/inference_server/server/server/features/providers/service.py
deleted file mode 100644
index c4186d7a..00000000
--- a/servers/inference_server/server/server/features/providers/service.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Providers Service
-
-Provides services for working with AI providers.
-"""
-
-import os
-from pathlib import Path
-from typing import Any, List, Optional
-
-from graphcap.providers.factory import initialize_provider_manager
-from graphcap.providers.provider_manager import ProviderManager
-from loguru import logger
-
-from ...config import settings
-from .models import ModelInfo, ProviderInfo
-
-# Global provider manager instance
-_provider_manager: Optional[ProviderManager] = None
-
-
-def get_provider_manager() -> ProviderManager:
- """
- Get or initialize the provider manager.
- Returns:
- ProviderManager: The initialized provider manager
- """
- global _provider_manager
-
- if _provider_manager is None:
- # Use the provider config path from server settings
- config_path = settings.PROVIDER_CONFIG_PATH
-
- # Verify the config path exists
- if config_path is None:
- logger.warning("Provider config path is None, using default locations")
- elif not os.path.exists(str(config_path)):
- logger.warning(f"Provider config path does not exist: {config_path}")
- # Check if the directory exists
- config_dir = Path(str(config_path)).parent
- if not os.path.exists(str(config_dir)):
- logger.warning(f"Config directory does not exist: {config_dir}")
- else:
- logger.info(f"Config directory exists: {config_dir}, but provider.config.toml is missing")
- # List files in the directory
- files = os.listdir(str(config_dir))
- logger.info(f"Files in config directory: {files}")
- else:
- logger.info(f"Provider config file exists: {config_path}")
-
- logger.info(f"Initializing provider manager with config path: {config_path}")
- _provider_manager = initialize_provider_manager(config_path)
-
- # Log the available providers
- provider_names = _provider_manager.available_providers()
- if provider_names:
- logger.info(f"Available providers: {', '.join(provider_names)}")
- else:
- logger.warning("No providers available. Check your provider.config.toml file.")
-
- return _provider_manager
-
-
-def get_available_providers() -> List[ProviderInfo]:
- """
- Get a list of available providers.
- Returns:
- List of provider information
- """
- # Get the provider manager
- provider_manager = get_provider_manager()
-
- # Get the list of available providers
- provider_names = provider_manager.available_providers()
- providers = []
-
- for name in provider_names:
- try:
- config = provider_manager.get_provider_config(name)
- providers.append(
- ProviderInfo(
- name=name,
- kind=config.kind,
- default_model=config.default_model or "",
- )
- )
- except Exception as e:
- logger.error(f"Error getting provider {name}: {str(e)}")
-
- return providers
-
-
-def _create_model_info(model_id: str, default_model: str) -> ModelInfo:
- """Create a ModelInfo instance with the given ID and default model."""
- return ModelInfo(id=model_id, name=model_id, is_default=(model_id == default_model))
-
-
-def _extract_model_id(model: Any) -> str:
- """Extract model ID from a model object."""
- if hasattr(model, "id"):
- return model.id
- return model.name if hasattr(model, "name") else str(model)
-
-
-async def _fetch_models_from_available_models(client: Any, default_model: str) -> List[ModelInfo]:
- """Fetch models using get_available_models method."""
- models = []
- provider_models = await client.get_available_models()
-
- if hasattr(provider_models, "data"):
- for model in provider_models.data:
- model_id = _extract_model_id(model)
- models.append(_create_model_info(model_id, default_model))
-
- return models
-
-
-async def _fetch_models_from_get_models(client: Any, default_model: str) -> List[ModelInfo]:
- """Fetch models using get_models method."""
- models = []
- provider_models = await client.get_models()
-
- if hasattr(provider_models, "models"):
- for model in provider_models.models:
- model_id = _extract_model_id(model)
- models.append(_create_model_info(model_id, default_model))
-
- return models
-
-
-def _get_configured_models(config: Any) -> List[ModelInfo]:
- """Get models from configuration."""
- return [_create_model_info(model_id, config.default_model) for model_id in config.models]
-
-
-async def _fetch_provider_models(client: Any, provider_name: str, config: Any) -> List[ModelInfo]:
- """Attempt to fetch models from the provider."""
- models = []
-
- try:
- logger.info(f"Fetching models from provider {provider_name}")
-
- if hasattr(client, "get_available_models"):
- models = await _fetch_models_from_available_models(client, config.default_model)
- elif hasattr(client, "get_models"):
- models = await _fetch_models_from_get_models(client, config.default_model)
-
- logger.info(f"Found {len(models)} models for provider {provider_name}")
- except Exception as e:
- logger.error(f"Error fetching models from provider {provider_name}: {str(e)}")
- logger.info(f"Falling back to configured models for provider {provider_name}")
-
- return models
-
-
-async def get_provider_models(provider_name: str) -> List[ModelInfo]:
- """
- Get a list of available models for a specific provider.
-
- Args:
- provider_name: Name of the provider to get models for
- Returns:
- List of model information
- Raises:
- ValueError: If the provider is not found
- """
- provider_manager = get_provider_manager()
- available_providers = provider_manager.available_providers()
-
- if provider_name not in available_providers:
- raise ValueError(f"Provider '{provider_name}' not found. Available providers: {', '.join(available_providers)}")
-
- config = provider_manager.get_provider_config(provider_name)
- client = provider_manager.get_client(provider_name)
- models = []
-
- if config.fetch_models:
- models = await _fetch_provider_models(client, provider_name, config)
-
- if not models:
- models = _get_configured_models(config)
- logger.info(f"Using {len(models)} configured models for provider {provider_name}")
-
- return models
diff --git a/servers/inference_server/server/server/models.py b/servers/inference_server/server/server/models.py
deleted file mode 100644
index 81b1b736..00000000
--- a/servers/inference_server/server/server/models.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Model Index
-
-Provides centralized access to all database models.
-"""
-
-# from server.features.job.models import JobStatus, PipelineJob, PipelineNodeState
-# from server.features.workflows.models import Workflow
-
-# __all__ = ["Workflow", "JobStatus", "PipelineJob", "PipelineNodeState"]
-__all__ = []
diff --git a/servers/inference_server/server/server/utils/__init__.py b/servers/inference_server/server/server/utils/__init__.py
deleted file mode 100644
index 22acaa32..00000000
--- a/servers/inference_server/server/server/utils/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""
-# SPDX-License-Identifier: Apache-2.0
-Utility Module Collection
-
-Collection of utility functions and helpers used throughout the graphcap server.
-
-Key features:
-- Logging configuration
-- JSON formatting
-- Error handling
-- Common utilities
-
-Components:
- logger: Configured loguru logger with JSON formatting
-"""
diff --git a/test/inference_tests/ollama_graphcap_REST.py b/test/inference_tests/ollama_graphcap_REST.py
index 5937f9cb..520113e5 100644
--- a/test/inference_tests/ollama_graphcap_REST.py
+++ b/test/inference_tests/ollama_graphcap_REST.py
@@ -1,6 +1,6 @@
"""
curl -X 'POST' \
- 'http://localhost:32100/api/v1/perspectives/caption-from-path' \
+  'http://localhost:32100/v1/perspectives/caption-from-path' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
diff --git a/test/library_tests/provider_tests/test_provider_factory.py b/test/library_tests/provider_tests/test_provider_factory.py
new file mode 100644
index 00000000..57bf8a00
--- /dev/null
+++ b/test/library_tests/provider_tests/test_provider_factory.py
@@ -0,0 +1,226 @@
+"""
+# SPDX-License-Identifier: Apache-2.0
+graphcap.tests.lib.providers.test_provider_factory
+
+Tests for provider factory functionality.
+
+Key features:
+- Provider client creation and caching
+- Environment validation
+- Client-specific configurations
+"""
+
+import pytest
+from unittest.mock import patch, MagicMock
+
+from graphcap.providers.factory import (
+ ProviderFactory,
+ create_provider_client,
+ get_provider_factory,
+ clear_provider_cache
+)
+from graphcap.providers.types import ProviderConfig, RateLimits
+
+
+def test_provider_factory_initialization():
+ """
+ GIVEN a provider factory
+ WHEN initializing a new instance
+ THEN should create an empty client cache
+ """
+ factory = ProviderFactory()
+ assert hasattr(factory, '_client_cache')
+ assert factory._client_cache == {}
+
+
+@pytest.mark.parametrize(
+ "provider_config",
+ [
+ {
+ "name": "test-openai",
+ "kind": "openai",
+ "environment": "cloud",
+ "base_url": "https://api.openai.com/v1",
+ "api_key": "test-key",
+ "default_model": "gpt-4o-mini",
+ },
+ {
+ "name": "test-gemini",
+ "kind": "gemini",
+ "environment": "cloud",
+ "base_url": "https://generativelanguage.googleapis.com/v1beta",
+ "api_key": "test-key",
+ "default_model": "gemini-2.0-flash-exp",
+ },
+ ],
+)
+@patch("graphcap.providers.factory.get_client")
+def test_create_client(mock_get_client, provider_config):
+ """
+ GIVEN valid provider configurations
+ WHEN creating a client
+ THEN should call get_client with correct parameters
+ AND should return the expected client instance
+ """
+ # Setup mock
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ # Create factory and client
+ factory = ProviderFactory()
+ client = factory.create_client(**provider_config)
+
+ # Verify
+ mock_get_client.assert_called_once_with(
+ name=provider_config["name"],
+ kind=provider_config["kind"],
+ environment=provider_config["environment"],
+ api_key=provider_config["api_key"],
+ base_url=provider_config["base_url"],
+ default_model=provider_config["default_model"],
+ )
+ assert client == mock_client
+
+
+@patch("graphcap.providers.factory.get_client")
+def test_client_caching(mock_get_client):
+ """
+ GIVEN a client that has been created
+ WHEN creating the same client again
+ THEN should return the cached client
+ AND should not call get_client again
+ """
+ # Setup mock
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ # Create configuration
+ config = {
+ "name": "test-openai",
+ "kind": "openai",
+ "environment": "cloud",
+ "base_url": "https://test.com",
+ "api_key": "test-key",
+ "default_model": "test-model",
+ }
+
+ # Create factory and client
+ factory = ProviderFactory()
+
+ # First call should create the client
+ client1 = factory.create_client(**config, use_cache=True)
+ assert mock_get_client.call_count == 1
+
+ # Second call should use cached client
+ client2 = factory.create_client(**config, use_cache=True)
+ assert mock_get_client.call_count == 1 # Count should still be 1
+ assert client1 is client2 # Should be the same instance
+
+ # Call with use_cache=False should create a new client
+ client3 = factory.create_client(**config, use_cache=False)
+ assert mock_get_client.call_count == 2 # Count should now be 2
+ assert client1 is not client3 # Should be different instances
+
+
+def test_clear_cache():
+ """
+ GIVEN a factory with cached clients
+ WHEN clearing the cache
+ THEN should remove all cached clients
+ """
+ with patch("graphcap.providers.factory.get_client") as mock_get_client:
+ # Setup mock
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ # Create factory and add some clients to cache
+ factory = ProviderFactory()
+ factory.create_client(
+ name="test1",
+ kind="openai",
+ environment="cloud",
+ base_url="https://test1.com",
+ api_key="key1",
+ default_model="model1",
+ )
+ factory.create_client(
+ name="test2",
+ kind="gemini",
+ environment="cloud",
+ base_url="https://test2.com",
+ api_key="key2",
+ default_model="model2",
+ )
+
+ # Verify cache has clients
+ assert len(factory._client_cache) == 2
+
+ # Clear cache
+ factory.clear_cache()
+
+ # Verify cache is empty
+ assert len(factory._client_cache) == 0
+
+
+@patch("graphcap.providers.factory._provider_factory", None)
+@patch("graphcap.providers.factory.ProviderFactory")
+def test_get_provider_factory(mock_factory_class):
+ """
+ GIVEN no existing provider factory
+ WHEN calling get_provider_factory
+ THEN should create a new factory instance
+ """
+ # Setup mock
+ mock_factory = MagicMock()
+ mock_factory_class.return_value = mock_factory
+
+ # Call function
+ factory = get_provider_factory()
+
+ # Verify
+ mock_factory_class.assert_called_once()
+ assert factory == mock_factory
+
+
+@patch("graphcap.providers.factory.get_provider_factory")
+def test_create_provider_client(mock_get_factory):
+ """
+ GIVEN valid provider configuration
+ WHEN calling create_provider_client
+ THEN should get factory and call create_client
+ """
+ # Setup mock
+ mock_factory = MagicMock()
+ mock_client = MagicMock()
+ mock_factory.create_client.return_value = mock_client
+ mock_get_factory.return_value = mock_factory
+
+ # Call function
+ config = {
+ "name": "test",
+ "kind": "openai",
+ "environment": "cloud",
+ "base_url": "https://test.com",
+ "api_key": "test-key",
+ "default_model": "test-model",
+ }
+ client = create_provider_client(**config)
+
+ # Verify
+ mock_get_factory.assert_called_once()
+ mock_factory.create_client.assert_called_once_with(**config)
+ assert client == mock_client
+
+
+@patch("graphcap.providers.factory._provider_factory")
+def test_clear_provider_cache(mock_factory):
+ """
+ GIVEN an existing provider factory
+ WHEN calling clear_provider_cache
+ THEN should call clear_cache on the factory
+ """
+ # Call function
+ clear_provider_cache()
+
+ # Verify
+ mock_factory.clear_cache.assert_called_once()
\ No newline at end of file
diff --git a/workspace/config/.env.template b/workspace/config/.env.template
index 70779218..2baa814e 100644
--- a/workspace/config/.env.template
+++ b/workspace/config/.env.template
@@ -17,8 +17,6 @@ POSTGRES_HOST=graphcap_postgres
POSTGRES_PORT=5432
POSTGRES_DB=graphcap
-# Configuration Paths
-DEFAULT_PROVIDER_CONFIG="./provider.config.toml"
GRAPHCAP_SERVER=http://localhost:32100
diff --git a/workspace/config/provider.example.config.toml b/workspace/config/provider.example.config.toml
deleted file mode 100644
index 03417698..00000000
--- a/workspace/config/provider.example.config.toml
+++ /dev/null
@@ -1,74 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-#
-# This is a provider configuration file that allows you to customize provider configurations.
-# To use this file:
-#
-# 1. Uncomment the providers you want to enable
-# 2. Make your desired changes to those providers
-# 3. Save as 'provider.config.toml'
-# 4. Run 'docker compose build' and 'docker compose up -d' as normal
-
-[openai]
-kind = "openai"
-environment = "cloud"
-env_var = "OPENAI_API_KEY"
-base_url = "https://api.openai.com/v1"
-models = [
- "gpt-4o-mini",
- "gpt-4o",
-]
-
-[gemini]
-kind = "gemini"
-environment = "cloud"
-env_var = "GOOGLE_API_KEY"
-base_url = "https://generativelanguage.googleapis.com/v1beta"
-models = [
- "gemini-2.0-flash-exp",
-]
-# Rate limits configuration
-rate_limits.requests_per_minute = 10
-rate_limits.tokens_per_minute = 4000000
-
-# [openrouter]
-# kind = "openrouter"
-# environment = "cloud"
-# env_var = "OPENROUTER_API_KEY"
-# base_url = "https://openrouter.ai/api/v1"
-# models = [
-# "minimax/minimax-01",
-# "qwen/qvq-72b-preview",
-# "qwen/qvq-32b-preview",
-# "qwen/qvq-1.5b-preview",
-# "google/gemini-2.0-flash-exp:free",
-# "mistralai/pixtral-large-2411",
-# "meta-llama/llama-3.2-90b-vision-instruct:free",
-# "qwen/qwen-2-vl-72b-instruct"
-# ]
-
-# [custom]
-# # Custom provider configuration
-# # Each provider needs a unique name, env_var (or stub), and base_url
-
-# [ollama]
-# kind = "ollama"
-# environment = "local"
-# env_var = "CUSTOM_PROVIDER_1_KEY"
-# base_url = "http://localhost:11434"
-# fetch_models = true
-
-# [my_provider_2]
-# kind = "ollama"
-# environment = "local"
-# env_var = "CUSTOM_PROVIDER_2_KEY"
-# base_url = "http://localhost:11435"
-# fetch_models = true
-
-
-# # Add more custom providers as needed following the same pattern:
-# # [provider_name]
-# # environment = "cloud"
-# # kind = "vllm"
-# # env_var = "API_KEY"
-# # base_url = "BASE_URL"
-# # models = ["model1", "model2"]