Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backend/omni/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `enabled` field on `ModelProvider`
- `GET /api/models` now skips providers where `enabled=False`, so only active providers contribute models to the aggregated list.

## [0.0.3] - 2026-04-28

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,7 @@ def _make_provider(
name=name,
base_url=base_url,
api_key=api_key,
enabled=True,
properties={},
created_at=None,
updated_at=None,
Expand All @@ -1240,6 +1241,7 @@ def _real_provider() -> ModelProviderResponse:
name="myopenai",
base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
api_key=os.environ.get("UNIT_TEST_OPENAI_API_KEY", ""),
enabled=True,
properties={},
created_at=None,
updated_at=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def mock_provider_modules(self):
name="OpenAI Production",
base_url="https://api.openai.com/v1",
api_key="sk-...",
enabled=True,
properties={},
created_at="2024-01-15T10:30:00Z",
updated_at="2024-01-15T10:30:00Z",
Expand Down Expand Up @@ -106,6 +107,7 @@ def mock_provider_modules(self):
name="Local Ollama",
base_url="http://localhost:11434",
api_key="",
enabled=True,
properties={},
created_at="2024-01-16T14:20:00Z",
updated_at="2024-01-16T14:20:00Z",
Expand Down Expand Up @@ -203,6 +205,49 @@ def test_get_all_models_endpoint(self, test_client):
assert "created" in model
assert "owned_by" in model

def test_get_all_models_excludes_disabled_providers(self, mock_session_module):
"""Test GET /models does not return models from disabled providers."""
disabled_provider = ModelProviderResponse(
id="openai-disabled",
type="openai",
name="Disabled Provider",
base_url="https://api.disabled.com/v1",
api_key="sk-disabled",
enabled=False,
properties={},
created_at="2024-01-15T10:30:00Z",
updated_at="2024-01-15T10:30:00Z",
)
disabled_models = ModelResponse(
data=[
{
"id": "gpt-secret",
"object": "model",
"created": 1686935002,
"owned_by": "openai",
}
]
)
disabled_module = DummyModelProviderModule(
"openai", [disabled_provider], {"openai-disabled": disabled_models}
)

dependencies = ModuleDependencies(
{"disabled_provider": disabled_module, "session": mock_session_module}
)
router = CentralModelProviderRouter(dependencies, config={})
app = FastAPI()
app.include_router(router.router)
client = TestClient(app)

response = client.get("/api/models")

assert response.status_code == 200
data = response.json()
model_ids = [m["id"] for m in data["data"]]
assert not any("gpt-secret" in mid for mid in model_ids)
assert data["data"] == []

def test_get_all_providers_with_pagination(self, test_client):
"""Test GET /api/models/providers with pagination"""
response = test_client.get("/api/models/providers?limit=1&offset=0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def test_create_provider(
name="NewProvider",
url="https://api.new.com",
properties=expected_properties,
enabled=False,
)

def test_create_provider_validation_error(
Expand Down Expand Up @@ -295,6 +296,7 @@ def test_update_provider(
name="UpdatedProvider",
url="https://api.updated.com",
properties=expected_properties,
enabled=None,
)

def test_update_provider_not_found(
Expand Down Expand Up @@ -379,6 +381,7 @@ def test_complex_properties_handling(
name="ComplexProvider",
url="https://api.complex.com",
properties=expected_properties,
enabled=False,
)

@pytest.mark.skipif(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,10 @@ async def get_all_models(self, request: Request) -> ModelsListResponse:
request, limit=None, offset=None
)

# For each provider, get its models
# For each provider, get its models (only enabled ones)
for provider in providers_response.providers:
if not provider.enabled:
continue
try:
models_response = await provider_module.get_models(
request, provider.id
Expand Down
2 changes: 2 additions & 0 deletions backend/omni/src/modai/modules/model_provider/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ModelProviderResponse(BaseModel):
name: str
base_url: str
api_key: str
enabled: bool
properties: dict[str, Any]
created_at: str | None
updated_at: str | None
Expand All @@ -30,6 +31,7 @@ class ModelProviderCreateRequest(BaseModel):
name: str
base_url: str
api_key: str
enabled: bool | None = None
properties: dict[str, Any] = {}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ async def create_provider(
properties = (provider_data.properties or {}).copy()
properties["api_key"] = provider_data.api_key

# Create new provider
# Create new provider (disabled by default)
provider = await self.provider_store.add_provider(
name=provider_data.name,
url=provider_data.base_url,
properties=properties,
enabled=False,
)

return self._create_provider_response(provider)
Expand Down Expand Up @@ -129,6 +130,7 @@ async def update_provider(
name=provider_data.name,
url=provider_data.base_url,
properties=properties,
enabled=provider_data.enabled,
)
if not provider:
raise HTTPException(
Expand Down Expand Up @@ -208,6 +210,7 @@ def _create_provider_response(
name=provider.name,
base_url=provider.url,
api_key=api_key,
enabled=provider.enabled,
properties=properties,
created_at=provider.created_at.isoformat() if provider.created_at else None,
updated_at=provider.updated_at.isoformat() if provider.updated_at else None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ModelProvider:
name: str
url: str
properties: dict[str, Any]
enabled: bool = False
created_at: datetime | None = None
updated_at: datetime | None = None

Expand Down Expand Up @@ -67,14 +68,15 @@ async def get_provider(self, provider_id: str) -> ModelProvider | None:

@abstractmethod
async def add_provider(
self, name: str, url: str, properties: dict[str, Any]
self, name: str, url: str, properties: dict[str, Any], enabled: bool = False
) -> ModelProvider:
"""
Adds a new model provider configuration.
Args:
name: Human-readable name for the provider
url: API endpoint URL for the provider
properties: Configuration properties specific to the provider
enabled: Whether the provider is active (default: False)

Returns:
Created ModelProvider object
Expand All @@ -91,6 +93,7 @@ async def update_provider(
name: str,
url: str,
properties: dict[str, Any],
enabled: bool | None = None,
) -> ModelProvider | None:
"""
Updates an existing model provider configuration.
Expand All @@ -100,6 +103,7 @@ async def update_provider(
name: New name for the provider (optional)
url: New URL for the provider (optional)
properties: New properties for the provider (optional)
enabled: Whether the provider is active (None = keep current)

Returns:
Updated ModelProvider object if found, None otherwise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json

from sqlalchemy import (
Boolean,
create_engine,
MetaData,
Table,
Expand Down Expand Up @@ -58,6 +59,9 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]):
Column("name", String(128), unique=True, index=True),
Column("url", String(1000)),
Column("properties", JSON),
Column(
"enabled", Boolean, default=False, nullable=False, server_default="0"
),
Column("created_at", DateTime, default=datetime.now),
Column("updated_at", DateTime, default=datetime.now),
)
Expand Down Expand Up @@ -85,6 +89,7 @@ def _row_to_provider(self, row) -> ModelProvider:
name=row.name,
url=row.url,
properties=properties,
enabled=bool(row.enabled) if row.enabled is not None else False,
created_at=row.created_at,
updated_at=row.updated_at,
)
Expand Down Expand Up @@ -131,7 +136,7 @@ async def get_provider(self, provider_id: str) -> ModelProvider | None:
return None

async def add_provider(
self, name: str, url: str, properties: dict[str, Any]
self, name: str, url: str, properties: dict[str, Any], enabled: bool = False
) -> ModelProvider:
provider_id = self._generate_provider_id()
with self._get_session() as session:
Expand All @@ -146,6 +151,7 @@ async def add_provider(
name=name.strip(),
url=url.strip(),
properties=properties,
enabled=enabled,
created_at=now,
updated_at=now,
)
Expand All @@ -159,6 +165,7 @@ async def add_provider(
name=name.strip(),
url=url.strip(),
properties=properties,
enabled=enabled,
created_at=now,
updated_at=now,
)
Expand All @@ -169,6 +176,7 @@ async def update_provider(
name: str,
url: str,
properties: dict[str, Any],
enabled: bool | None = None,
) -> ModelProvider | None:
with self._get_session() as session:
# Check if provider exists first
Expand All @@ -185,6 +193,7 @@ async def update_provider(
properties = {}

now = datetime.now()
new_enabled = enabled if enabled is not None else bool(existing_row.enabled)

# Update the provider
update_stmt = (
Expand All @@ -194,6 +203,7 @@ async def update_provider(
name=name.strip(),
url=url.strip(),
properties=properties,
enabled=new_enabled,
updated_at=now,
)
)
Expand All @@ -207,6 +217,7 @@ async def update_provider(
name=name.strip(),
url=url.strip(),
properties=properties,
enabled=new_enabled,
created_at=existing_row.created_at,
updated_at=now,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/biome.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"$schema": "https://biomejs.dev/schemas/2.4.13/schema.json",
"$schema": "https://biomejs.dev/schemas/2.4.15/schema.json",
"root": true,
"vcs": {
"enabled": true,
Expand Down
2 changes: 1 addition & 1 deletion docs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"mermaid": "^11.14.0"
},
"devDependencies": {
"@biomejs/biome": "^2.4.13",
"@biomejs/biome": "^2.4.15",
"@rspress/core": "^2.0.9",
"typescript": "^6.0.3"
}
Expand Down
Loading