From 4e709d0f0a4b483ff10a956651a5b8ea1edf7ed8 Mon Sep 17 00:00:00 2001
From: Brad Hilton
Date: Thu, 12 Mar 2026 11:17:34 -0600
Subject: [PATCH] feat: add tenant-scoped Tinker model aliases

Return the Tinker API key to ART clients and expose /v1/models alias
management so tenant-scoped routing works through the OpenAI-compatible
Tinker server.

Made-with: Cursor
---
 src/art/tinker/backend.py |  15 ++
 src/art/tinker/server.py  | 418 +++++++++++++++++++++++---------------
 src/art/tinker/service.py |   2 +-
 3 files changed, 274 insertions(+), 161 deletions(-)

diff --git a/src/art/tinker/backend.py b/src/art/tinker/backend.py
index ff663b921..0da2dbb30 100644
--- a/src/art/tinker/backend.py
+++ b/src/art/tinker/backend.py
@@ -3,6 +3,8 @@
 
 from mp_actors import move_to_child_process
 
+from .. import dev
+from ..backend import AnyTrainableModel
 from ..local.backend import LocalBackend
 from ..local.service import ModelService
 from ..model import TrainableModel
@@ -26,6 +28,19 @@ def __init__(
             os.environ["TINKER_API_KEY"] = tinker_api_key
         super().__init__(in_process=in_process, path=path)
 
+    async def _prepare_backend_for_training(
+        self,
+        model: AnyTrainableModel,
+        config: dev.OpenAIServerConfig | None = None,
+    ) -> tuple[str, str]:
+        api_key = os.environ["TINKER_API_KEY"]
+        config_dict: dict = dict(config or {})
+        server_args = dict(config_dict.get("server_args", {}))
+        server_args["api_key"] = api_key
+        config_dict["server_args"] = server_args
+        base_url, _ = await super()._prepare_backend_for_training(model, config_dict)
+        return base_url, api_key
+
     async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..dev.get_model_config import get_model_config
         from ..dev.model import TinkerArgs, TinkerTrainingClientArgs
diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py
index 32f41ca16..590f4983e 100644
--- a/src/art/tinker/server.py
+++ b/src/art/tinker/server.py
@@ -4,11 +4,12 @@
 import os
 import socket
 import time
-from typing import Annotated, cast
+from typing import Annotated, Literal
 import uuid
 
 from fastapi import FastAPI, HTTPException, Request
 from openai import AsyncOpenAI
+from openai.types import Model, ModelDeleted
 from openai.types.chat.chat_completion import ChatCompletion, Choice, ChoiceLogprobs
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_function_tool_call import (
@@ -22,9 +23,8 @@
 )
 from openai.types.chat.completion_create_params import CompletionCreateParams
 from openai.types.completion_usage import CompletionUsage
-from pydantic import SkipValidation
+from pydantic import BaseModel, SkipValidation
 import tinker
-from tinker.lib.public_interfaces.rest_client import RestClient as TinkerRestClient
 from transformers.tokenization_utils_base import BatchEncoding
 import uvicorn
 
@@ -35,140 +35,45 @@
 from mp_actors import close_proxy, move_to_child_process
 
 
+class ModelList(BaseModel):
+    object: Literal["list"] = "list"
+    data: list[Model]
+
+
+class ModelUpsert(BaseModel):
+    target: str
+
+
 @dataclass
 class OpenAICompatibleTinkerServer:
     host: str | None = None
     port: int | None = None
     num_workers: int | None = None
-    models: dict[str, str] = field(default_factory=dict)
     _prefix_cache: LRUTrieCache = field(default_factory=LRUTrieCache)
-    _workers: list["Worker"] = field(default_factory=list)
     _task: asyncio.Task[None] | None = None
-    _tenant_clients: dict[str, tuple[tinker.ServiceClient, TinkerRestClient]] = field(
+    _tenants: dict[str, "OpenAICompatibleTinkerServerTenant"] = field(
default_factory=dict ) - _sampling_client_and_base_model_tasks: dict[ - tuple[str, str], asyncio.Task[tuple[tinker.SamplingClient, str]] - ] = field(default_factory=dict) - - @dataclass - class Worker: - _renderers: dict[str, renderers.Renderer] = field(default_factory=dict) - - async def prompt_tokens( - self, - base_model: str, - messages: list[ChatCompletionMessageParam], - tools: list[ChatCompletionToolUnionParam] | None, - ) -> list[int]: - encoding = self._get_renderer(base_model).tokenizer.apply_chat_template( - messages, # type: ignore - tools=tools, # type: ignore - add_generation_prompt=True, - ) - if isinstance(encoding, BatchEncoding): - return encoding.input_ids - else: - return encoding # type: ignore - - async def chat_completion_and_token_discrepancies( - self, - base_model: str, - sample_response: tinker.SampleResponse, - model_name: str, - prompt_tokens: int, - ) -> tuple[ChatCompletion, list[tuple[list[int], list[int]]]]: - renderer = self._get_renderer(base_model) - choices: list[Choice] = [] - token_discrepancies: list[tuple[list[int], list[int]]] = [] - for i, sequence in enumerate(sample_response.sequences): - assert sequence.logprobs is not None, "Logprobs are required" - assert len(sequence.tokens) == len(sequence.logprobs), ( - "Tokens and logprobs must have the same length" - ) - rendered_response_tokens = renderer.tokenizer.encode( - renderer.tokenizer.decode(sequence.tokens) - ) - if rendered_response_tokens != sequence.tokens: - token_discrepancies.append( - (rendered_response_tokens, sequence.tokens) - ) - message, _ = renderer.parse_response(sequence.tokens) - openai_message = renderer.to_openai_message(message) - tool_calls = ( - [ - ChatCompletionMessageFunctionToolCall( - type="function", - id=tool_call.get("id") or "", - function=Function( - name=tool_call["function"]["name"], - arguments=tool_call["function"]["arguments"], - ), - ) - for tool_call in openai_message.get("tool_calls", []) - ] - if openai_message.get("tool_calls") - else None - ) - choices.append( - Choice( - finish_reason=sequence.stop_reason, - index=i, - message=ChatCompletionMessage( - content=openai_message.get("content") or None, - role="assistant", - tool_calls=tool_calls, # type: ignore - ), - logprobs=ChoiceLogprobs( - content=[ - ChatCompletionTokenLogprob( - token=f"token_id:{token}", - bytes=list( - renderer.tokenizer.decode(token).encode() - ), - logprob=logprob, - top_logprobs=[], - ) - for token, logprob in zip( - sequence.tokens, sequence.logprobs - ) - ] - ), - ) - ) - completion_tokens = sum( - len(sequence.tokens) for sequence in sample_response.sequences - ) - return ( - ChatCompletion( - id=str(uuid.uuid4()), - choices=choices, - created=int(time.time()), - model=model_name, - object="chat.completion", - usage=CompletionUsage( - completion_tokens=completion_tokens, - prompt_tokens=prompt_tokens, - total_tokens=completion_tokens + prompt_tokens, - ), - ), - token_discrepancies, - ) + _workers: list["OpenAICompatibleTinkerServerWorker"] = field(default_factory=list) - def _get_renderer(self, base_model: str) -> renderers.Renderer: - if not base_model in self._renderers: - self._renderers[base_model] = renderers.get_renderer( - name=get_renderer_name(base_model), - tokenizer=get_tokenizer(base_model), - ) - return self._renderers[base_model] + @property + def models(self) -> dict[str, str]: + if "TINKER_API_KEY" not in os.environ: + raise ValueError("TINKER_API_KEY is not set") + return self._get_tenant(os.environ["TINKER_API_KEY"])._models + + @models.setter + def 
models(self, models: dict[str, str]) -> None: + if "TINKER_API_KEY" not in os.environ: + raise ValueError("TINKER_API_KEY is not set") + self._get_tenant(os.environ["TINKER_API_KEY"])._models = models async def start(self) -> tuple[str, int]: host = self.host or "0.0.0.0" port = self.port or get_free_port(host) self._workers = [ move_to_child_process( - OpenAICompatibleTinkerServer.Worker(), + OpenAICompatibleTinkerServerWorker(), process_name=f"openai-compatible-tinker-server-worker-{i}", ) for i in range(self.num_workers or self._default_num_workers()) @@ -185,7 +90,7 @@ async def start(self) -> tuple[str, int]: try: await client.completions.create(model="", prompt="") break # Server is ready - except: + except Exception: await asyncio.sleep(0.1) return host, port @@ -197,6 +102,20 @@ async def stop(self) -> None: for worker in self._workers: close_proxy(worker) + def _get_request_tenant( + self, request: Request + ) -> "OpenAICompatibleTinkerServerTenant": + auth = request.headers.get("authorization", "") + scheme, _, api_key = auth.partition(" ") + api_key = api_key.strip() + if scheme.lower() != "bearer" or not api_key: + raise HTTPException( + status_code=401, + detail="Missing or invalid Authorization header", + headers={"WWW-Authenticate": "Bearer"}, + ) + return self._get_tenant(api_key) + async def _run(self, host: str, port: int) -> None: workers = cycle(self._workers) app = FastAPI() @@ -211,28 +130,78 @@ async def completions() -> dict: # Minimal completions endpoint for health checks return {"choices": [{"text": ""}]} + @app.get("/v1/models") + async def list_models(request: Request) -> ModelList: + tenant = self._get_request_tenant(request) + return ModelList( + object="list", + data=[ + Model( + id=model, + created=tenant._model_timestamps.get(model, 0), + object="model", + owned_by="tinker", + ) + for model in tenant._models + ], + ) + + @app.get("/v1/models/{model}") + async def get_model(request: Request, model: str) -> Model: + tenant = self._get_request_tenant(request) + if model not in tenant._models: + raise HTTPException( + status_code=404, + detail=f"Model not found: {model}", + ) + return Model( + id=model, + created=tenant._model_timestamps.get(model, 0), + object="model", + owned_by="tinker", + ) + + @app.put("/v1/models/{model}") + async def put_model( + request: Request, + model: str, + body: ModelUpsert, + ) -> Model: + tenant = self._get_request_tenant(request) + tenant._models[model] = body.target + tenant._model_timestamps.setdefault(model, int(time.time())) + return Model( + id=model, + created=tenant._model_timestamps[model], + object="model", + owned_by="tinker", + ) + + @app.delete("/v1/models/{model}") + async def delete_model(request: Request, model: str) -> ModelDeleted: + tenant = self._get_request_tenant(request) + if model not in tenant._models: + raise HTTPException( + status_code=404, + detail=f"Model not found: {model}", + ) + tenant._models.pop(model) + tenant._model_timestamps.pop(model, None) + return ModelDeleted( + id=model, + deleted=True, + object="model", + ) + @app.post("/v1/chat/completions") async def chat_completions( request: Request, body: Annotated[CompletionCreateParams, SkipValidation] ) -> ChatCompletion: worker = next(workers) - auth = request.headers.get("authorization", "") - scheme, _, api_key = auth.partition(" ") - api_key = api_key.strip() - if scheme.lower() != "bearer" or not api_key: - raise HTTPException( - status_code=401, - detail="Missing or invalid Authorization header", - headers={"WWW-Authenticate": 
"Bearer"}, - ) - ( - sampling_client, - base_model, - ) = await self._get_sampling_client_and_base_model( - api_key, self.models.get(body["model"], body["model"]) - ) + tenant = self._get_request_tenant(request) + samplable_model = await tenant._get_samplable_model(body["model"]) rendered_prompt_tokens = await worker.prompt_tokens( - base_model=base_model, + base_model=samplable_model.base_model, messages=list(body["messages"]), tools=list(body.get("tools", [])) if "tools" in body else None, ) @@ -246,7 +215,7 @@ async def chat_completions( + rendered_prompt_tokens[prefix_entry.rendered_len :] ) try: - sample_response = await sampling_client.sample_async( + sample_response = await samplable_model.sampling_client.sample_async( prompt=tinker.ModelInput.from_ints(tokens=prompt_tokens), num_samples=body.get("n") or 1, sampling_params=tinker.SamplingParams( @@ -273,7 +242,7 @@ async def chat_completions( chat_completion, token_discrepancies, ) = await worker.chat_completion_and_token_discrepancies( - base_model=base_model, + base_model=samplable_model.base_model, sample_response=sample_response, model_name=body["model"], prompt_tokens=len(prompt_tokens), @@ -300,9 +269,32 @@ def _default_num_workers(self) -> int: except (AttributeError, OSError): return os.cpu_count() or 1 - async def _get_sampling_client_and_base_model( - self, api_key: str, model_path_or_base_model: str - ) -> tuple[tinker.SamplingClient, str]: + def _get_tenant(self, api_key: str) -> "OpenAICompatibleTinkerServerTenant": + if api_key not in self._tenants: + self._tenants[api_key] = OpenAICompatibleTinkerServerTenant(api_key) + return self._tenants[api_key] + + +@dataclass +class OpenAICompatibleTinkerServerSamplableModel: + sampling_client: tinker.SamplingClient + base_model: str + + +class OpenAICompatibleTinkerServerTenant: + def __init__(self, api_key: str) -> None: + self._models: dict[str, str] = {} + self._model_timestamps: dict[str, int] = {} + self._service_client = tinker.ServiceClient(api_key=api_key) + self._rest_client = self._service_client.create_rest_client() + self._samplable_models: dict[ + str, asyncio.Task[OpenAICompatibleTinkerServerSamplableModel] + ] = dict() + + async def _get_samplable_model( + self, model: str + ) -> OpenAICompatibleTinkerServerSamplableModel: + model_path_or_base_model = self._models.get(model, model) if not model_path_or_base_model.startswith("tinker://"): try: get_renderer_name(model_path_or_base_model) @@ -314,37 +306,143 @@ async def _get_sampling_client_and_base_model( "A model must be either a valid `tinker://...` path, supported base model, or registered model alias." 
), ) - args = (api_key, model_path_or_base_model) - if (task := self._sampling_client_and_base_model_tasks.get(args)) and ( + if (task := self._samplable_models.get(model_path_or_base_model)) and ( not task.done() or task.exception() is None ): return await task - self._sampling_client_and_base_model_tasks[args] = asyncio.create_task( - self._load_sampling_client_and_base_model(*args) + self._samplable_models[model_path_or_base_model] = asyncio.create_task( + self._load_samplable_model(model_path_or_base_model) ) - return await self._sampling_client_and_base_model_tasks[args] - - async def _load_sampling_client_and_base_model( - self, api_key: str, model_path_or_base_model: str - ) -> tuple[tinker.SamplingClient, str]: - if api_key not in self._tenant_clients: - service_client = tinker.ServiceClient() - rest_client = service_client.create_rest_client() - self._tenant_clients[api_key] = (service_client, rest_client) - service_client, rest_client = self._tenant_clients[api_key] + return await self._samplable_models[model_path_or_base_model] + + async def _load_samplable_model( + self, model_path_or_base_model: str + ) -> OpenAICompatibleTinkerServerSamplableModel: is_model_path = model_path_or_base_model.startswith("tinker://") - sampling_client = await service_client.create_sampling_client_async( + sampling_client = await self._service_client.create_sampling_client_async( model_path=model_path_or_base_model if is_model_path else None, base_model=model_path_or_base_model if not is_model_path else None, ) if is_model_path: - sampler_response = await rest_client.get_sampler_async( + sampler_response = await self._rest_client.get_sampler_async( sampling_client._sampling_session_id ) base_model = sampler_response.base_model else: base_model = model_path_or_base_model - return sampling_client, base_model + return OpenAICompatibleTinkerServerSamplableModel( + sampling_client=sampling_client, + base_model=base_model, + ) + + +@dataclass +class OpenAICompatibleTinkerServerWorker: + _renderers: dict[str, renderers.Renderer] = field(default_factory=dict) + + async def prompt_tokens( + self, + base_model: str, + messages: list[ChatCompletionMessageParam], + tools: list[ChatCompletionToolUnionParam] | None, + ) -> list[int]: + encoding = self._get_renderer(base_model).tokenizer.apply_chat_template( + messages, # type: ignore + tools=tools, # type: ignore + add_generation_prompt=True, + ) + if isinstance(encoding, BatchEncoding): + return encoding.input_ids + else: + return encoding # type: ignore + + async def chat_completion_and_token_discrepancies( + self, + base_model: str, + sample_response: tinker.SampleResponse, + model_name: str, + prompt_tokens: int, + ) -> tuple[ChatCompletion, list[tuple[list[int], list[int]]]]: + renderer = self._get_renderer(base_model) + choices: list[Choice] = [] + token_discrepancies: list[tuple[list[int], list[int]]] = [] + for i, sequence in enumerate(sample_response.sequences): + assert sequence.logprobs is not None, "Logprobs are required" + assert len(sequence.tokens) == len(sequence.logprobs), ( + "Tokens and logprobs must have the same length" + ) + rendered_response_tokens = renderer.tokenizer.encode( + renderer.tokenizer.decode(sequence.tokens) + ) + if rendered_response_tokens != sequence.tokens: + token_discrepancies.append((rendered_response_tokens, sequence.tokens)) + message, _ = renderer.parse_response(sequence.tokens) + openai_message = renderer.to_openai_message(message) + tool_calls = ( + [ + ChatCompletionMessageFunctionToolCall( + type="function", + 
id=tool_call.get("id") or "", + function=Function( + name=tool_call["function"]["name"], + arguments=tool_call["function"]["arguments"], + ), + ) + for tool_call in openai_message.get("tool_calls", []) + ] + if openai_message.get("tool_calls") + else None + ) + choices.append( + Choice( + finish_reason=sequence.stop_reason, + index=i, + message=ChatCompletionMessage( + content=openai_message.get("content") or None, + role="assistant", + tool_calls=tool_calls, # type: ignore + ), + logprobs=ChoiceLogprobs( + content=[ + ChatCompletionTokenLogprob( + token=f"token_id:{token}", + bytes=list(renderer.tokenizer.decode(token).encode()), + logprob=logprob, + top_logprobs=[], + ) + for token, logprob in zip( + sequence.tokens, sequence.logprobs + ) + ] + ), + ) + ) + completion_tokens = sum( + len(sequence.tokens) for sequence in sample_response.sequences + ) + return ( + ChatCompletion( + id=str(uuid.uuid4()), + choices=choices, + created=int(time.time()), + model=model_name, + object="chat.completion", + usage=CompletionUsage( + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=completion_tokens + prompt_tokens, + ), + ), + token_discrepancies, + ) + + def _get_renderer(self, base_model: str) -> renderers.Renderer: + if base_model not in self._renderers: + self._renderers[base_model] = renderers.get_renderer( + name=get_renderer_name(base_model), + tokenizer=get_tokenizer(base_model), + ) + return self._renderers[base_model] def get_free_port(host: str | None = None) -> int: diff --git a/src/art/tinker/service.py b/src/art/tinker/service.py index 1f5970aca..2eebafc45 100644 --- a/src/art/tinker/service.py +++ b/src/art/tinker/service.py @@ -47,8 +47,8 @@ async def start_openai_server( self._server = OpenAICompatibleTinkerServer( host=config.get("host") if config else None, port=config.get("port") if config else None, - models=state.models, ) + self._server.models = state.models with log_timing("Starting OpenAI-compatible Tinker server"): return await self._server.start()
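
Usage sketch (illustrative only): the host, port, alias name, and
tinker:// target below are placeholders, and in practice base_url and
api_key come from TinkerBackend._prepare_backend_for_training. An ART
client registers a tenant-scoped alias with PUT /v1/models/{model},
then lists it, samples through it, and deletes it with the standard
OpenAI client:

    import asyncio

    import httpx
    from openai import AsyncOpenAI

    BASE_URL = "http://localhost:8000/v1"  # placeholder host/port
    API_KEY = "tinker-api-key"  # the tenant's Tinker API key


    async def main() -> None:
        # Register (or retarget) an alias for this tenant; aliases are
        # keyed by the bearer token, so other tenants never see them.
        async with httpx.AsyncClient() as http:
            response = await http.put(
                f"{BASE_URL}/models/my-agent",
                headers={"Authorization": f"Bearer {API_KEY}"},
                json={"target": "tinker://example-run/checkpoint"},
            )
            response.raise_for_status()

        client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)

        # GET /v1/models returns only this tenant's aliases.
        models = await client.models.list()
        print([model.id for model in models.data])

        # Chat completions resolve the alias through the tenant's
        # Tinker sampling client.
        completion = await client.chat.completions.create(
            model="my-agent",
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(completion.choices[0].message.content)

        # DELETE /v1/models/{model} removes the alias.
        deleted = await client.models.delete("my-agent")
        assert deleted.deleted


    asyncio.run(main())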