diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 85e96af0..9401fec8 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -143,6 +143,29 @@ async def summarize_messages( prompt_messages = list(messages) + [Message(role="user", content=prompt)] llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high") + if llm_params.get("_ml_intern_provider") == "openai-codex": + if session is None: + raise RuntimeError("OpenAI Codex summarization requires a session") + from agent.core.codex_responses import ( + call_codex_responses, + resolve_codex_llm_params, + ) + + codex_params = await resolve_codex_llm_params( + model_name, + user_id=getattr(session, "user_id", None), + reasoning_effort="high", + ) + result = await call_codex_responses( + session, + prompt_messages, + tool_specs, + codex_params, + emit_events=False, + ) + completion_tokens = int(result.usage.get("completion_tokens") or 0) + return result.content or "", completion_tokens + prompt_messages, tool_specs = with_prompt_caching( prompt_messages, tool_specs, llm_params.get("model") ) diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index e32e4e42..4d3cec0c 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -24,6 +24,13 @@ normalize_tool_operation, ) from agent.core.cost_estimation import CostEstimate, estimate_tool_cost +from agent.core.codex_responses import ( + CodexAPIError, + CodexAuthRequiredError, + call_codex_responses, + is_codex_llm_params, + resolve_codex_llm_params, +) from agent.messaging.gateway import NotificationGateway from agent.core import telemetry from agent.core.doom_loop import check_for_doom_loop @@ -496,15 +503,27 @@ async def _heal_effort_and_rebuild_params( session.model_effective_effort[model] = None logger.info("healed: %s probe inconclusive — stripped", model) - return _resolve_llm_params( + params = _resolve_llm_params( model, session.hf_token, 
reasoning_effort=session.effective_effort_for(model), ) + if is_codex_llm_params(params): + return await resolve_codex_llm_params( + model, + user_id=session.user_id, + reasoning_effort=session.effective_effort_for(model), + ) + return params def _friendly_error_message(error: Exception) -> str | None: """Return a user-friendly message for known error types, or None to fall back to traceback.""" + if isinstance(error, CodexAuthRequiredError): + return str(error) + if isinstance(error, CodexAPIError): + return f"OpenAI Codex request failed: {error}" + err_str = str(error).lower() if ( @@ -811,6 +830,22 @@ async def _call_llm_streaming( session: Session, messages, tools, llm_params ) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" + if is_codex_llm_params(llm_params): + result = await call_codex_responses( + session, + messages, + tools, + llm_params, + emit_events=True, + ) + return LLMResult( + content=result.content, + tool_calls_acc=result.tool_calls_acc, + token_count=result.token_count, + finish_reason=result.finish_reason, + usage=result.usage, + ) + response = None _healed_effort = False # one-shot safety net per call _healed_thinking_signature = False @@ -969,6 +1004,26 @@ async def _call_llm_non_streaming( session: Session, messages, tools, llm_params ) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" + if is_codex_llm_params(llm_params): + result = await call_codex_responses( + session, + messages, + tools, + llm_params, + emit_events=False, + ) + if result.content: + await session.send_event( + Event(event_type="assistant_message", data={"content": result.content}) + ) + return LLMResult( + content=result.content, + tool_calls_acc=result.tool_calls_acc, + token_count=result.token_count, + finish_reason=result.finish_reason, + usage=result.usage, + ) + response = None _healed_effort = False _healed_thinking_signature = False @@ -1215,6 +1270,14 @@ async def run_agent( 
session.config.model_name ), ) + if is_codex_llm_params(llm_params): + llm_params = await resolve_codex_llm_params( + session.config.model_name, + user_id=session.user_id, + reasoning_effort=session.effective_effort_for( + session.config.model_name + ), + ) if session.stream: llm_result = await _call_llm_streaming( session, messages, tools, llm_params diff --git a/agent/core/codex_oauth.py b/agent/core/codex_oauth.py new file mode 100644 index 00000000..954a568e --- /dev/null +++ b/agent/core/codex_oauth.py @@ -0,0 +1,565 @@ +"""OpenAI Codex OAuth helpers. + +This mirrors Codex/Pi's ``openai-codex`` provider shape: ChatGPT OAuth +credentials are refreshable bearer tokens for ChatGPT's Codex backend, not +normal OpenAI API keys. +""" + +from __future__ import annotations + +import base64 +import asyncio +import hashlib +import json +import os +import secrets +import time +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib.parse import parse_qs, urlencode, urlparse + +import httpx + +try: + import fcntl +except ImportError: # pragma: no cover - this app is deployed on Linux. + fcntl = None # type: ignore[assignment] + +try: + import msvcrt +except ImportError: # pragma: no cover - Windows-only fallback. 
+ msvcrt = None # type: ignore[assignment] + + +CODEX_PROVIDER_ID = "openai-codex" +CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +CODEX_AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" +CODEX_TOKEN_URL = "https://auth.openai.com/oauth/token" +CODEX_REDIRECT_URI = "http://localhost:1455/auth/callback" +CODEX_SCOPE = "openid profile email offline_access" +CODEX_JWT_CLAIM_PATH = "https://api.openai.com/auth" +CODEX_ORIGINATOR = os.environ.get("ML_INTERN_CODEX_ORIGINATOR", "pi") + +_REFRESH_SKEW_MS = 60_000 +_REFRESH_LOCK = asyncio.Lock() + + +@dataclass(frozen=True) +class CodexCredentials: + access: str + refresh: str + expires: int + account_id: str + source: str = "ml-intern" + path: Path | None = None + + def to_json(self) -> dict[str, Any]: + return { + "type": "oauth", + "access": self.access, + "refresh": self.refresh, + "expires": self.expires, + "accountId": self.account_id, + } + + +@dataclass(frozen=True) +class CodexAuthorizationFlow: + verifier: str + state: str + url: str + + +def _auth_path() -> Path: + configured = os.environ.get("ML_INTERN_CODEX_AUTH_PATH") + if configured: + return Path(configured).expanduser() + return Path.home() / ".cache" / "ml-intern" / "codex-auth.json" + + +def _codex_auth_path() -> Path: + configured = os.environ.get("CODEX_AUTH_PATH") + if configured: + return Path(configured).expanduser() + return Path.home() / ".codex" / "auth.json" + + +def _pi_auth_path() -> Path: + configured = os.environ.get("PI_CODEX_AUTH_PATH") + if configured: + return Path(configured).expanduser() + return Path.home() / ".pi" / "agent" / "auth.json" + + +def _user_key(user_id: str | None) -> str: + normalized = (user_id or "dev").strip() or "dev" + return hashlib.sha256(normalized.encode("utf-8")).hexdigest() + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=") + + +def generate_pkce() -> tuple[str, str]: + verifier = _b64url(secrets.token_bytes(32)) + challenge = 
_b64url(hashlib.sha256(verifier.encode("ascii")).digest()) + return verifier, challenge + + +def create_codex_authorization_flow() -> CodexAuthorizationFlow: + verifier, challenge = generate_pkce() + state = secrets.token_urlsafe(24) + params = { + "response_type": "code", + "client_id": CODEX_CLIENT_ID, + "redirect_uri": CODEX_REDIRECT_URI, + "scope": CODEX_SCOPE, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "originator": CODEX_ORIGINATOR, + } + return CodexAuthorizationFlow( + verifier=verifier, + state=state, + url=f"{CODEX_AUTHORIZE_URL}?{urlencode(params)}", + ) + + +def parse_authorization_response(value: str) -> tuple[str | None, str | None]: + raw = (value or "").strip() + if not raw: + return None, None + + try: + parsed = urlparse(raw) + if parsed.scheme and parsed.netloc: + params = parse_qs(parsed.query) + return _first(params.get("code")), _first(params.get("state")) + except Exception: + pass + + if "#" in raw: + code, state = raw.split("#", 1) + return code or None, state or None + + if "code=" in raw or "state=" in raw: + params = parse_qs(raw) + return _first(params.get("code")), _first(params.get("state")) + + return raw, None + + +def _first(values: list[str] | None) -> str | None: + if not values: + return None + value = values[0].strip() + return value or None + + +def _decode_jwt_payload(token: str) -> dict[str, Any] | None: + try: + parts = token.split(".") + if len(parts) != 3: + return None + payload = parts[1] + payload += "=" * (-len(payload) % 4) + decoded = base64.urlsafe_b64decode(payload.encode("ascii")) + parsed = json.loads(decoded) + return parsed if isinstance(parsed, dict) else None + except Exception: + return None + + +def extract_account_id(access_token: str) -> str | None: + payload = _decode_jwt_payload(access_token) + auth = payload.get(CODEX_JWT_CLAIM_PATH) if payload else None + if not isinstance(auth, 
def _credentials_from_raw(raw: Any, *, source: str) -> CodexCredentials | None:
    """Build credentials from a flat ``access``/``refresh``/``expires`` record.

    Args:
        raw: Decoded JSON value. Anything other than a dict is rejected.
        source: Label stored on the returned credentials.

    Returns:
        The parsed credentials, or ``None`` when a required field is missing
        or malformed, or when no account id is stored and none can be
        extracted from the access token.
    """
    if not isinstance(raw, dict):
        return None

    access = raw.get("access")
    refresh = raw.get("refresh")
    if not isinstance(access, str) or not isinstance(refresh, str):
        return None

    expires = raw.get("expires")
    # bool is an int subclass, but JSON ``true``/``false`` is not a timestamp.
    if isinstance(expires, bool) or not isinstance(expires, (int, float)):
        return None
    try:
        expires_ms = int(expires)
    except (ValueError, OverflowError):
        # json.load accepts NaN/Infinity literals, which int() refuses.
        # Treat them as a malformed record instead of crashing the loader.
        return None

    account_id = raw.get("accountId") or raw.get("account_id")
    if not isinstance(account_id, str) or not account_id:
        # Fall back to the account id embedded in the access token.
        account_id = extract_account_id(access) or ""
    if not account_id:
        return None

    return CodexCredentials(
        access=access,
        refresh=refresh,
        expires=expires_ms,
        account_id=account_id,
        source=source,
    )
@contextmanager
def _locked_json(path: Path):
    """Yield the JSON object at *path* under a lock and write it back on exit.

    The lock is taken on a sibling ``<name>.lock`` file via ``_file_lock``.
    The yielded dict is whatever ``_read_json`` returns; callers mutate it in
    place. When the ``with`` body finishes normally the dict is serialized to
    a temporary file in the same directory, fsynced, and moved over *path*
    with ``os.replace``. An exception raised inside the ``with`` body
    propagates out of the ``yield``, so the write-back is skipped.

    Args:
        path: Location of the JSON file; its parent directory is created if
            it does not exist.
    """
    path.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
    lock_path = path.with_name(f"{path.name}.lock")
    with _file_lock(lock_path):
        data = _read_json(path)
        yield data

        tmp_path = path.with_name(
            f".{path.name}.{os.getpid()}.{secrets.token_hex(4)}.tmp"
        )
        try:
            # Create the temp file owner-only from the start: it holds OAuth
            # tokens, and chmod-after-write would leave it readable under the
            # process umask until the chmod runs. O_EXCL refuses to reuse a
            # pre-existing path.
            fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
                f.write("\n")
                f.flush()
                os.fsync(f.fileno())
            # The umask can only remove bits from the creation mode; pin the
            # final mode explicitly.
            os.chmod(tmp_path, 0o600)
            os.replace(tmp_path, path)
            # Best-effort fsync of the directory entry; skipped when the
            # directory cannot be opened.
            try:
                dir_fd = os.open(path.parent, os.O_RDONLY)
            except OSError:
                dir_fd = None
            if dir_fd is not None:
                try:
                    os.fsync(dir_fd)
                finally:
                    os.close(dir_fd)
        finally:
            # After a successful replace the temp path is already gone.
            try:
                tmp_path.unlink()
            except FileNotFoundError:
                pass
def load_codex_credentials_for_user(user_id: str | None) -> CodexCredentials | None:
    """Return the first available credentials for *user_id*, without refreshing.

    Sources are consulted in priority order and later ones are only read when
    the earlier ones yield nothing: the Codex CLI auth file, then this app's
    per-user store, then Pi's auth file.
    """
    from_cli = load_codex_cli_credentials()
    if from_cli:
        return from_cli

    from_store = load_stored_codex_credentials(user_id)
    if from_store:
        return from_store

    return load_pi_codex_credentials()
"application/x-www-form-urlencoded"}, + ) + return _credentials_from_token_response( + response, + action="refresh", + fallback_refresh=credentials.refresh, + ) + + +def jwt_expires_at_ms(access_token: str) -> int | None: + payload = _decode_jwt_payload(access_token) + exp = payload.get("exp") if payload else None + if not isinstance(exp, (int, float)): + return None + return int(float(exp) * 1000) + + +def _credentials_from_token_response( + response: httpx.Response, + *, + action: str, + fallback_refresh: str | None = None, +) -> CodexCredentials: + if not response.is_success: + text = response.text[:500] if response.text else response.reason_phrase + raise ValueError( + f"OpenAI Codex token {action} failed ({response.status_code}): {text}" + ) + payload = response.json() + access = payload.get("access_token") + refresh = payload.get("refresh_token") + expires_in = payload.get("expires_in") + if not isinstance(access, str) or not access: + raise ValueError(f"OpenAI Codex token {action} response missing access token") + if not isinstance(refresh, str) or not refresh: + refresh = fallback_refresh + if not isinstance(refresh, str) or not refresh: + raise ValueError(f"OpenAI Codex token {action} response missing refresh token") + if not isinstance(expires_in, (int, float)): + raise ValueError(f"OpenAI Codex token {action} response missing expires_in") + account_id = extract_account_id(access) + if not account_id: + raise ValueError("OpenAI Codex token did not include a ChatGPT account id") + jwt_expires = jwt_expires_at_ms(access) + return CodexCredentials( + access=access, + refresh=refresh, + expires=jwt_expires or int(time.time() * 1000 + float(expires_in) * 1000), + account_id=account_id, + ) + + +def _write_codex_cli_credentials( + data: dict[str, Any], + credentials: CodexCredentials, +) -> None: + tokens = data.setdefault("tokens", {}) + tokens["access_token"] = credentials.access + tokens["refresh_token"] = credentials.refresh + tokens["account_id"] = 
async def _refresh_codex_cli_credentials_from_file(
    path: Path,
) -> CodexCredentials | None:
    """Refresh the Codex CLI tokens stored at *path* if they are about to expire.

    The file is read inside ``_locked_json`` so the expiry check, the refresh
    request, and the in-place token update all happen while the file lock is
    held. Note that the awaited network refresh runs under that lock.

    Returns:
        ``None`` when the file holds no usable credentials, the current
        credentials when they are not expiring soon, otherwise the freshly
        refreshed credentials (tagged with the original ``source``/``path``).

    Raises:
        Exception: Any error from the locked section is re-raised unless a
            re-read of the file yields credentials that are still valid.
    """
    try:
        with _locked_json(path) as data:
            # Re-derive credentials from the locked snapshot rather than
            # trusting what the caller saw before the lock was taken.
            current = _credentials_from_codex_data(data, path)
            if current is None:
                return None
            if not _expires_soon(current):
                return current

            refreshed = await refresh_codex_credentials(current)
            # The token response does not know where these credentials live;
            # carry over source/path from the record that was refreshed.
            refreshed = CodexCredentials(
                access=refreshed.access,
                refresh=refreshed.refresh,
                expires=refreshed.expires,
                account_id=refreshed.account_id,
                source=current.source,
                path=current.path,
            )
            # Mutates ``data`` in place; _locked_json persists it on exit.
            _write_codex_cli_credentials(data, refreshed)
            return refreshed
    except Exception:
        # NOTE(review): presumably covers another process having rotated the
        # tokens already — if the file now holds valid credentials, use them
        # instead of failing. Confirm intent.
        latest = _credentials_from_codex_file(path)
        if latest is not None and not _expires_soon(latest):
            return latest
        raise
def codex_request_headers(
    credentials: CodexCredentials,
    *,
    session_id: str | None = None,
) -> dict[str, str]:
    """Assemble the HTTP headers for a Codex responses request.

    Args:
        credentials: Supplies the bearer token and the ChatGPT account id.
        session_id: When truthy, sent as both ``session_id`` and
            ``x-client-request-id``.

    Returns:
        A new header mapping; the session headers come last when present.
    """
    pairs: list[tuple[str, str]] = [
        ("Authorization", f"Bearer {credentials.access}"),
        ("chatgpt-account-id", credentials.account_id),
        ("originator", CODEX_ORIGINATOR),
        ("User-Agent", "pi (ml-intern)"),
        ("OpenAI-Beta", "responses=experimental"),
        ("accept", "text/event-stream"),
        ("content-type", "application/json"),
    ]
    if session_id:
        pairs.append(("session_id", session_id))
        pairs.append(("x-client-request-id", session_id))
    return dict(pairs)
def codex_model_id(model_name: str) -> str:
    """Return the backend model id for an ``openai-codex/`` model name.

    Raises:
        ValueError: If *model_name* is not an OpenAI Codex OAuth model, or if
            nothing but whitespace follows the prefix.
    """
    if is_openai_codex_model(model_name):
        bare_id = model_name.removeprefix(CODEX_MODEL_PREFIX).strip()
        if bare_id:
            return bare_id
        raise ValueError("OpenAI Codex model id is empty")
    raise ValueError(f"Not an OpenAI Codex OAuth model: {model_name}")
def _flatten_content(content: Any) -> str:
    """Collapse message content into a single string.

    ``None`` becomes ``""`` and a string is returned unchanged. A list is
    reduced to its text fragments (string items, or the ``text`` / ``content``
    value of dict items when that value is a string), with empty fragments
    dropped and the rest joined by newlines. Any other value is passed
    through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    def _fragments():
        # Yield every textual piece found in the list, in order.
        for entry in content:
            if isinstance(entry, str):
                yield entry
            elif isinstance(entry, dict):
                candidate = entry.get("text") or entry.get("content")
                if isinstance(candidate, str):
                    yield candidate

    return "\n".join(filter(None, _fragments()))
return [] + return [{"type": "input_text", "text": text}] + + +def _assistant_text_item(text: str, index: int) -> dict[str, Any]: + return { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text, "annotations": []}], + "status": "completed", + "id": f"msg_{index}", + } + + +def _tool_call_dict(tool_call: Any) -> dict[str, Any] | None: + if hasattr(tool_call, "model_dump"): + try: + data = tool_call.model_dump(mode="json") + except TypeError: + data = tool_call.model_dump() + elif isinstance(tool_call, dict): + data = tool_call + else: + data = { + "id": getattr(tool_call, "id", None), + "function": getattr(tool_call, "function", None), + } + function = data.get("function") or {} + if not isinstance(function, dict): + function = { + "name": getattr(function, "name", None), + "arguments": getattr(function, "arguments", None), + } + call_id = str(data.get("id") or "").strip() + name = str(function.get("name") or "").strip() + arguments = function.get("arguments") + if not call_id or not name: + return None + if not isinstance(arguments, str): + arguments = json.dumps(arguments or {}) + normalized_id = call_id.split("|", 1) + item_id = normalized_id[1] if len(normalized_id) == 2 else None + item: dict[str, Any] = { + "type": "function_call", + "call_id": normalized_id[0], + "name": name, + "arguments": arguments, + } + if item_id: + item["id"] = item_id + return item + + +def convert_messages_for_codex(messages: list[Any]) -> tuple[str, list[dict[str, Any]]]: + instructions: list[str] = [] + input_items: list[dict[str, Any]] = [] + msg_index = 0 + + for raw_message in messages: + message = _message_to_dict(raw_message) + role = message.get("role") + if role in {"system", "developer"}: + text = _flatten_content(message.get("content")).strip() + if text: + instructions.append(text) + continue + + if role == "user": + content = _content_parts_for_user(message.get("content")) + if content: + input_items.append({"role": "user", 
"content": content}) + elif role == "assistant": + text = _flatten_content(message.get("content")).strip() + if text: + input_items.append(_assistant_text_item(text, msg_index)) + tool_calls = message.get("tool_calls") or [] + for tool_call in tool_calls: + item = _tool_call_dict(tool_call) + if item: + input_items.append(item) + elif role == "tool": + call_id = str(message.get("tool_call_id") or "").split("|", 1)[0] + if call_id: + input_items.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": _flatten_content(message.get("content")), + } + ) + msg_index += 1 + + return "\n\n".join(instructions) or "You are a helpful assistant.", input_items + + +def convert_tools_for_codex(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + converted: list[dict[str, Any]] = [] + for tool in tools or []: + function = tool.get("function") if isinstance(tool, dict) else None + if not isinstance(function, dict): + function = tool if isinstance(tool, dict) else {} + name = function.get("name") + if not isinstance(name, str) or not name: + continue + converted.append( + { + "type": "function", + "name": name, + "description": function.get("description") or "", + "parameters": function.get("parameters") or {"type": "object"}, + "strict": None, + } + ) + return converted + + +def _codex_url(base_url: str = CODEX_BASE_URL) -> str: + normalized = base_url.rstrip("/") + if normalized.endswith("/codex/responses"): + return normalized + if normalized.endswith("/codex"): + return f"{normalized}/responses" + return f"{normalized}/codex/responses" + + +def _request_body( + *, + session_id: str | None, + messages: list[Any], + tools: list[dict[str, Any]] | None, + llm_params: dict[str, Any], +) -> dict[str, Any]: + instructions, input_items = convert_messages_for_codex(messages) + body: dict[str, Any] = { + "model": llm_params["codex_model"], + "store": False, + "stream": True, + "instructions": instructions, + "input": input_items, + "text": {"verbosity": 
"low"}, + "include": ["reasoning.encrypted_content"], + "prompt_cache_key": session_id, + "tool_choice": "auto", + "parallel_tool_calls": True, + } + converted_tools = convert_tools_for_codex(tools) + if converted_tools: + body["tools"] = converted_tools + if llm_params.get("reasoning_effort"): + body["reasoning"] = { + "effort": llm_params["reasoning_effort"], + "summary": "auto", + } + return body + + +async def _iter_sse_events(response: httpx.Response) -> AsyncIterator[dict[str, Any]]: + data_lines: list[str] = [] + async for raw_line in response.aiter_lines(): + line = raw_line.rstrip("\r") + if line == "": + if data_lines: + data = "\n".join(data_lines).strip() + data_lines = [] + if data and data != "[DONE]": + try: + parsed = json.loads(data) + except json.JSONDecodeError as exc: + raise CodexAPIError(f"Invalid Codex SSE JSON: {exc}") from exc + if isinstance(parsed, dict): + yield parsed + continue + if line.startswith("data:"): + data_lines.append(line[5:].strip()) + if data_lines: + data = "\n".join(data_lines).strip() + if data and data != "[DONE]": + parsed = json.loads(data) + if isinstance(parsed, dict): + yield parsed + + +def _usage_from_codex(response: dict[str, Any] | None) -> dict[str, Any]: + usage = response.get("usage") if isinstance(response, dict) else None + if not isinstance(usage, dict): + return {} + input_tokens = int(usage.get("input_tokens") or 0) + output_tokens = int(usage.get("output_tokens") or 0) + total_tokens = int(usage.get("total_tokens") or input_tokens + output_tokens) + details = usage.get("input_tokens_details") + cached_tokens = ( + int(details.get("cached_tokens") or 0) if isinstance(details, dict) else 0 + ) + return { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": {"cached_tokens": cached_tokens}, + } + + +def _finish_reason( + response: dict[str, Any] | None, tool_calls_acc: dict[int, dict] +) -> str: + status = response.get("status") 
def _response_error_message(status: int, text: str) -> str:
    """Derive a user-facing message from a failed Codex HTTP response.

    Args:
        status: HTTP status code of the response.
        text: Decoded response body.

    Returns:
        The ``error.message`` from a JSON error body when there is one. For a
        429 status, or an error code/type containing ``usage_limit`` or
        ``rate_limit``, a stock usage-limit message is used when the body has
        no message. Otherwise the first 500 characters of the body, or a
        generic ``HTTP <status>`` message when the body is empty.
    """
    fallback = text[:500] or f"Codex request failed with HTTP {status}"

    # Only the JSON parse is expected to fail here; keep the handler narrow
    # so unrelated programming errors are not silently swallowed.
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        return fallback

    err = parsed.get("error") if isinstance(parsed, dict) else None
    if not isinstance(err, dict):
        return fallback

    code = str(err.get("code") or err.get("type") or "")
    message = str(err.get("message") or "")
    if status == 429 or "usage_limit" in code or "rate_limit" in code:
        return message or "You have hit your ChatGPT usage limit."
    return message or fallback
async def _sleep_retry(attempt: int) -> None:
    """Sleep for the exponential-backoff delay of retry number *attempt* (0-based)."""
    import asyncio

    await asyncio.sleep(_BASE_RETRY_DELAY_SECONDS * (2**attempt))


def _ensure_tool_slot(
    tool_calls_acc: dict[int, dict],
    index: int,
) -> dict[str, Any]:
    """Return the accumulator slot at *index*, creating an empty one if needed."""
    if index not in tool_calls_acc:
        tool_calls_acc[index] = {
            "id": "",
            "type": "function",
            "function": {"name": "", "arguments": ""},
        }
    return tool_calls_acc[index]


def _output_index(event: dict[str, Any]) -> int | None:
    """Return the event's ``output_index``, or None when it is absent.

    ``0`` is a valid index, so absence is tested explicitly instead of with
    truthiness (``event.get("output_index") or fallback`` silently replaced
    index 0 with the fallback).
    """
    raw = event.get("output_index")
    if raw is None or raw == "":
        return None
    return int(raw)


def _argument_slot_index(
    event: dict[str, Any],
    tool_calls_acc: dict[int, dict],
) -> int:
    """Pick the slot an arguments delta/done event belongs to.

    Falls back to the highest existing slot when the event carries no
    ``output_index``; returns -1 when there is no slot to fall back to.
    """
    index = _output_index(event)
    if index is None:
        return max(tool_calls_acc, default=-1)
    return index


def _tool_call_id(item: dict[str, Any]) -> str:
    """Build the combined ``call_id|item_id`` identifier stored in a slot."""
    call_id = str(item.get("call_id") or "")
    item_id = str(item.get("id") or "")
    return f"{call_id}|{item_id}" if item_id else call_id


def _record_function_call(
    tool_calls_acc: dict[int, dict],
    event: dict[str, Any],
    item: dict[str, Any],
    *,
    final: bool,
) -> None:
    """Store a ``function_call`` output item in the accumulator.

    Args:
        tool_calls_acc: Accumulator keyed by output index; mutated in place.
        event: The ``response.output_item.added`` / ``.done`` event.
        item: The event's ``item`` payload (already checked to be a dict).
        final: True for ``.done`` events. A final item only overwrites the
            accumulated arguments when it actually carries a string; an
            ``added`` item always (re)initialises them.
    """
    tool_id = _tool_call_id(item)
    index = _output_index(event)
    if index is None:
        if final and tool_id:
            # Without an index, a finished item updates the slot opened for
            # the same call instead of creating a duplicate tool call.
            index = next(
                (i for i, slot in tool_calls_acc.items() if slot["id"] == tool_id),
                None,
            )
        if index is None:
            # Fresh slot after the highest key; len() could collide with an
            # existing key when indices are sparse.
            index = max(tool_calls_acc, default=-1) + 1

    slot = _ensure_tool_slot(tool_calls_acc, index)
    slot["id"] = tool_id
    slot["function"]["name"] = str(item.get("name") or "")
    arguments = item.get("arguments")
    if not final:
        slot["function"]["arguments"] = str(arguments or "")
    elif isinstance(arguments, str):
        slot["function"]["arguments"] = arguments


async def call_codex_responses(
    session: Any,
    messages: list[Any],
    tools: list[dict[str, Any]] | None,
    llm_params: dict[str, Any],
    *,
    emit_events: bool,
) -> CodexCompletionResult:
    """Send one request to the Codex backend and collect the streamed result.

    Streams server-sent events, accumulating assistant text and function-call
    items, records the call through ``telemetry.record_llm_call`` and returns
    the aggregated result.

    Args:
        session: Session object; its ``session_id`` (if any) is forwarded in
            the request, and ``send_event`` is awaited for each text delta
            when *emit_events* is true.
        messages: Conversation messages passed to ``_request_body``.
        tools: Tool specs passed to ``_request_body`` (may be None).
        llm_params: Resolved Codex params; must contain ``"credentials"``.
        emit_events: Emit an ``assistant_chunk`` event per text delta.

    Returns:
        A ``CodexCompletionResult`` with the joined text (None when empty),
        the tool-call accumulator, token count, finish reason and usage.

    Raises:
        KeyError: If *llm_params* has no ``"credentials"`` entry.
        CodexAPIError: On an ``error`` or ``response.failed`` stream event
            (``_post_with_retries`` may also raise).
    """
    credentials = llm_params["credentials"]
    session_id = getattr(session, "session_id", None)
    body = _request_body(
        session_id=session_id,
        messages=messages,
        tools=tools,
        llm_params=llm_params,
    )
    headers = codex_request_headers(credentials, session_id=session_id)

    content_parts: list[str] = []
    tool_calls_acc: dict[int, dict] = {}
    final_response: dict[str, Any] | None = None
    t_start = time.monotonic()

    async with httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=30.0)) as client:
        response = await _post_with_retries(client, headers=headers, body=body)
        try:
            async for event in _iter_sse_events(response):
                event_type = event.get("type")
                if event_type == "error":
                    message = event.get("message") or event.get("code") or "Codex error"
                    raise CodexAPIError(str(message))
                if event_type == "response.failed":
                    failed_response = event.get("response")
                    error = (
                        failed_response.get("error")
                        if isinstance(failed_response, dict)
                        else None
                    )
                    if isinstance(error, dict):
                        raise CodexAPIError(str(error.get("message") or error))
                    raise CodexAPIError("Codex response failed")
                if event_type in {
                    "response.completed",
                    "response.done",
                    "response.incomplete",
                }:
                    response_payload = event.get("response")
                    final_response = (
                        response_payload if isinstance(response_payload, dict) else {}
                    )
                    if event_type == "response.incomplete":
                        final_response["status"] = "incomplete"
                    # Terminal event: nothing after it belongs to this response.
                    break
                if event_type == "response.output_text.delta":
                    delta = event.get("delta")
                    if isinstance(delta, str) and delta:
                        content_parts.append(delta)
                        if emit_events:
                            await session.send_event(
                                Event(
                                    event_type="assistant_chunk",
                                    data={"content": delta},
                                )
                            )
                elif event_type == "response.output_item.added":
                    item = event.get("item")
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        _record_function_call(tool_calls_acc, event, item, final=False)
                elif event_type == "response.function_call_arguments.delta":
                    index = _argument_slot_index(event, tool_calls_acc)
                    delta = event.get("delta")
                    if index >= 0 and isinstance(delta, str):
                        slot = _ensure_tool_slot(tool_calls_acc, index)
                        slot["function"]["arguments"] += delta
                elif event_type == "response.function_call_arguments.done":
                    index = _argument_slot_index(event, tool_calls_acc)
                    arguments = event.get("arguments")
                    if index >= 0 and isinstance(arguments, str):
                        slot = _ensure_tool_slot(tool_calls_acc, index)
                        slot["function"]["arguments"] = arguments
                elif event_type == "response.output_item.done":
                    item = event.get("item")
                    if not isinstance(item, dict):
                        continue
                    if item.get("type") == "function_call":
                        _record_function_call(tool_calls_acc, event, item, final=True)
                    elif item.get("type") == "message" and not content_parts:
                        # No text deltas were seen: recover the text from the
                        # finished message item instead.
                        for part in item.get("content") or []:
                            if (
                                isinstance(part, dict)
                                and part.get("type") == "output_text"
                            ):
                                text = part.get("text")
                                if isinstance(text, str):
                                    content_parts.append(text)
        finally:
            await response.aclose()

    usage = _usage_from_codex(final_response)
    finish_reason = _finish_reason(final_response, tool_calls_acc)
    await telemetry.record_llm_call(
        session,
        model=str(llm_params.get("model") or llm_params.get("codex_model")),
        response={"usage": usage},
        latency_ms=int((time.monotonic() - t_start) * 1000),
        finish_reason=finish_reason,
    )

    return CodexCompletionResult(
        content="".join(content_parts) or None,
        tool_calls_acc=tool_calls_acc,
        token_count=int(usage.get("total_tokens") or 0),
        finish_reason=finish_reason,
        usage=usage,
    )
_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"} _OPENAI_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"} +_OPENAI_CODEX_EFFORTS = {"minimal", "low", "medium", "high", "xhigh"} _HF_EFFORTS = {"low", "medium", "high"} @@ -171,6 +173,11 @@ def _resolve_llm_params( • ``openai/`` — ``reasoning_effort`` forwarded as a top-level kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. + • ``openai-codex/`` — ChatGPT Codex OAuth. This is not a LiteLLM + provider and not the public OpenAI API; the agent loop recognizes the + returned marker and sends the request to ChatGPT's Codex backend with + credentials from ``~/.codex/auth.json``. + • ``ollama/``, ``vllm/``, ``lm_studio/``, and ``llamacpp/`` — local OpenAI-compatible endpoints. The id prefix selects a configurable localhost base URL, and the model suffix is sent @@ -243,6 +250,31 @@ def _resolve_llm_params( params["reasoning_effort"] = reasoning_effort return params + if model_name.startswith("openai-codex/"): + codex_model = model_name.removeprefix("openai-codex/").strip() + if not codex_model: + raise ValueError(f"Unsupported OpenAI Codex model id: {model_name}") + params = { + "_ml_intern_provider": "openai-codex", + "model": model_name, + "codex_model": codex_model, + } + if reasoning_effort: + level = "low" if reasoning_effort == "minimal" else reasoning_effort + if level == "max": + if strict: + raise UnsupportedEffortError( + "OpenAI Codex doesn't accept effort='max'" + ) + elif level not in _OPENAI_CODEX_EFFORTS: + if strict: + raise UnsupportedEffortError( + f"OpenAI Codex doesn't accept effort={level!r}" + ) + else: + params["reasoning_effort"] = level + return params + if is_reserved_local_model_id(model_name): raise ValueError(f"Unsupported local model id: {model_name}") diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index 34eaccdd..ecdb8fcf 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -19,6 +19,10 @@ from 
litellm import acompletion +from agent.core.codex_responses import ( + CODEX_MODEL_PREFIX, + resolve_codex_llm_params, +) from agent.core.effort_probe import ProbeInconclusive, probe_effort from agent.core.llm_params import _resolve_llm_params from agent.core.local_models import ( @@ -34,6 +38,7 @@ # ":cheapest" / ":preferred" / ":" to override the default # routing policy (auto = fastest with failover). SUGGESTED_MODELS = [ + {"id": "openai-codex/gpt-5.5", "label": "GPT-5.5 Codex"}, {"id": "openai/gpt-5.5", "label": "GPT-5.5"}, {"id": "openai/gpt-5.4", "label": "GPT-5.4"}, {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, @@ -50,7 +55,7 @@ _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} -_DIRECT_PREFIXES = ("anthropic/", "openai/", *LOCAL_MODEL_PREFIXES) +_DIRECT_PREFIXES = ("anthropic/", "openai/", CODEX_MODEL_PREFIX, *LOCAL_MODEL_PREFIXES) _LOCAL_PROBE_TIMEOUT = 15.0 @@ -163,6 +168,7 @@ def print_model_listing(config, console) -> None: "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" "Use 'anthropic/' or 'openai/' for direct API access.\n" + "Use 'openai-codex/' for ChatGPT Codex subscription access.\n" "Use 'ollama/', 'vllm/', 'lm_studio/', or " "'llamacpp/' for local OpenAI-compatible endpoints.[/dim]" ) @@ -175,6 +181,7 @@ def print_invalid_id(arg: str, console) -> None: " • /[:tag] (HF router — paste from huggingface.co)\n" " • anthropic/\n" " • openai/\n" + " • openai-codex/\n" " • ollama/ | vllm/ | lm_studio/ | llamacpp/[/dim]" ) @@ -215,6 +222,7 @@ async def probe_and_switch_model( persistent. Local models reject every probe error, including timeouts, and keep the current model. 
""" + preference = config.reasoning_effort if is_local_model_id(model_id): console.print(f"[dim]checking local model {model_id}...[/dim]") try: @@ -230,7 +238,25 @@ async def probe_and_switch_model( ) return - preference = config.reasoning_effort + if model_id.startswith(CODEX_MODEL_PREFIX): + try: + await resolve_codex_llm_params( + model_id, + user_id=getattr(session, "user_id", None), + reasoning_effort=preference, + ) + except Exception as e: + console.print(f"[bold red]Switch failed:[/bold red] {e}") + console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") + return + + _commit_switch(model_id, config, session, effective=None, cache=False) + console.print( + f"[green]Model switched to {model_id}[/green] " + "[dim](ChatGPT Codex OAuth)[/dim]" + ) + return + if not _print_hf_routing_info(model_id, console): return diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 0a742b7c..e97c8c53 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -54,6 +54,7 @@ DEFAULT_CLAUDE_MODEL_ID = "bedrock/us.anthropic.claude-opus-4-6-v1" DEFAULT_FREE_MODEL_ID = "moonshotai/Kimi-K2.6" +DEFAULT_CODEX_MODEL_ID = "openai-codex/gpt-5.5" PREMIUM_MODEL_IDS = { DEFAULT_CLAUDE_MODEL_ID, "openai/gpt-5.5", @@ -94,6 +95,12 @@ def _available_models() -> list[dict[str, Any]]: "provider": "openai", "tier": "pro", }, + { + "id": DEFAULT_CODEX_MODEL_ID, + "label": "GPT-5.5 Codex", + "provider": "openai-codex", + "tier": "subscription", + }, { "id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7", diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index 8a8810ea..27ade828 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -24,6 +24,7 @@ import { useAgentStore } from '@/store/agentStore'; import { useSessionStore } from '@/store/sessionStore'; import { CLAUDE_MODEL_PATH, + CODEX_GPT_55_MODEL_PATH, FIRST_FREE_MODEL_PATH, GPT_55_MODEL_PATH, 
isClaudePath, @@ -69,6 +70,13 @@ const DEFAULT_MODEL_OPTIONS: ModelOption[] = [ modelPath: GPT_55_MODEL_PATH, avatarUrl: 'https://huggingface.co/api/avatars/openai', }, + { + id: 'codex-gpt-5.5', + name: 'GPT-5.5 Codex', + description: 'ChatGPT subscription', + modelPath: CODEX_GPT_55_MODEL_PATH, + avatarUrl: 'https://huggingface.co/api/avatars/openai', + }, { id: 'minimax-m2.7', name: 'MiniMax M2.7', diff --git a/frontend/src/utils/model.ts b/frontend/src/utils/model.ts index 84754f99..1269760a 100644 --- a/frontend/src/utils/model.ts +++ b/frontend/src/utils/model.ts @@ -8,6 +8,7 @@ export const CLAUDE_MODEL_PATH = 'bedrock/us.anthropic.claude-opus-4-6-v1'; export const GPT_55_MODEL_PATH = 'openai/gpt-5.5'; +export const CODEX_GPT_55_MODEL_PATH = 'openai-codex/gpt-5.5'; export const FIRST_FREE_MODEL_PATH = 'moonshotai/Kimi-K2.6'; export function isClaudePath(modelPath: string | undefined): boolean { diff --git a/tests/unit/test_agent_model_gating.py b/tests/unit/test_agent_model_gating.py index 8a3f88d8..bf7104b8 100644 --- a/tests/unit/test_agent_model_gating.py +++ b/tests/unit/test_agent_model_gating.py @@ -26,6 +26,7 @@ def test_premium_model_predicate_includes_bedrock_claude_and_gpt55_only(): assert agent._is_premium_model("bedrock/us.anthropic.claude-opus-4-6-v1") assert agent._is_premium_model("openai/gpt-5.5") assert not agent._is_premium_model("anthropic/claude-opus-4-6") + assert not agent._is_premium_model("openai-codex/gpt-5.5") assert not agent._is_premium_model("moonshotai/Kimi-K2.6") diff --git a/tests/unit/test_cli_local_models.py b/tests/unit/test_cli_local_models.py index 836fb3fd..6ba8158e 100644 --- a/tests/unit/test_cli_local_models.py +++ b/tests/unit/test_cli_local_models.py @@ -16,6 +16,7 @@ def test_model_switcher_accepts_supported_local_prefixes(): assert model_switcher.is_valid_model_id("vllm/meta-llama/Llama-3.1-8B") assert model_switcher.is_valid_model_id("lm_studio/google/gemma-3-4b") assert 
def test_codex_models_skip_hf_router_catalog_output():
    """The HF routing check accepts a Codex id without printing anything."""

    class NoPrintConsole:
        # Any console output fails the test immediately.
        def print(self, *args, **kwargs):
            raise AssertionError("Codex models should not print HF catalog info")

    silent_console = NoPrintConsole()
    accepted = model_switcher._print_hf_routing_info(
        "openai-codex/gpt-5.5",
        silent_console,
    )
    assert accepted
def _jwt(payload: dict) -> str:
    """Build an unsigned JWT-shaped token carrying *payload* for the tests.

    The header is ``{"alg": "none"}``; both header and payload are compact
    JSON, base64url-encoded with the ``=`` padding stripped. The third
    segment is the literal string ``signature``.
    """
    segments = []
    for part in ({"alg": "none"}, payload):
        compact = json.dumps(part, separators=(",", ":"))
        encoded = base64.urlsafe_b64encode(compact.encode("utf-8"))
        segments.append(encoded.decode("ascii").rstrip("="))
    segments.append("signature")
    return ".".join(segments)
def test_refresh_token_response_can_preserve_existing_refresh_token():
    """A refresh response without ``refresh_token`` falls back to the old one."""
    claims = {
        "exp": int(time.time()) + 3600,
        "https://api.openai.com/auth": {
            "chatgpt_account_id": "acct_test",
        },
    }
    access_token = _jwt(claims)
    # The token endpoint payload deliberately omits "refresh_token".
    token_payload = {
        "access_token": access_token,
        "expires_in": 3600,
    }
    response = httpx.Response(200, json=token_payload)

    credentials = _credentials_from_token_response(
        response,
        action="refresh",
        fallback_refresh="old_refresh",
    )

    assert credentials.access == access_token
    assert credentials.refresh == "old_refresh"
    assert credentials.account_id == "acct_test"
fake_refresh(credentials): + assert credentials.refresh == "old_refresh" + return CodexCredentials( + access=new_access, + refresh="new_refresh", + expires=int(time.time() * 1000) + 3_600_000, + account_id="new_account", + ) + + monkeypatch.setattr(codex_oauth, "refresh_codex_credentials", fake_refresh) + + credentials = await get_codex_credentials_for_user("user-1") + + assert credentials is not None + assert credentials.source == "codex" + assert credentials.access == new_access + assert credentials.refresh == "new_refresh" + + data = json.loads(auth_path.read_text(encoding="utf-8")) + assert data["tokens"]["access_token"] == new_access + assert data["tokens"]["refresh_token"] == "new_refresh" + assert data["tokens"]["account_id"] == "new_account" + assert data["tokens"]["id_token"] == "keep_id_token" + + +@pytest.mark.asyncio +async def test_get_codex_credentials_refreshes_internal_store_under_lock( + monkeypatch, + tmp_path, +): + auth_path = tmp_path / "ml-intern-auth.json" + old_access = _jwt( + { + "exp": int(time.time()) - 10, + "https://api.openai.com/auth": { + "chatgpt_account_id": "old_account", + }, + } + ) + new_access = _jwt( + { + "exp": int(time.time()) + 3600, + "https://api.openai.com/auth": { + "chatgpt_account_id": "new_account", + }, + } + ) + monkeypatch.setenv("ML_INTERN_CODEX_AUTH_PATH", str(auth_path)) + monkeypatch.setenv("ML_INTERN_CODEX_USE_CODEX_AUTH", "0") + monkeypatch.setenv("ML_INTERN_CODEX_USE_PI_AUTH", "0") + store_codex_credentials( + "user-1", + CodexCredentials( + access=old_access, + refresh="old_refresh", + expires=int(time.time() * 1000) - 10_000, + account_id="old_account", + ), + ) + + async def fake_refresh(credentials): + assert credentials.refresh == "old_refresh" + return CodexCredentials( + access=new_access, + refresh="new_refresh", + expires=int(time.time() * 1000) + 3_600_000, + account_id="new_account", + ) + + monkeypatch.setattr(codex_oauth, "refresh_codex_credentials", fake_refresh) + + credentials = await 
@pytest.mark.asyncio
async def test_expired_pi_credentials_are_not_refreshed_or_copied(
    monkeypatch,
    tmp_path,
):
    """Expired Pi credentials yield None, trigger no refresh, and are not copied."""
    pi_auth_path = tmp_path / "pi-auth.json"
    ml_intern_auth_path = tmp_path / "ml-intern-auth.json"

    expired_claims = {
        "exp": int(time.time()) - 10,
        "https://api.openai.com/auth": {
            "chatgpt_account_id": "pi_account",
        },
    }
    expired_access = _jwt(expired_claims)
    pi_entry = {
        "type": "oauth",
        "access": expired_access,
        "refresh": "pi_refresh",
        "expires": int(time.time() * 1000) - 10_000,
        "accountId": "pi_account",
    }
    pi_auth_path.write_text(
        json.dumps({"openai-codex": pi_entry}),
        encoding="utf-8",
    )

    env_overrides = {
        "ML_INTERN_CODEX_USE_CODEX_AUTH": "0",
        "ML_INTERN_CODEX_AUTH_PATH": str(ml_intern_auth_path),
        "PI_CODEX_AUTH_PATH": str(pi_auth_path),
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    async def fail_refresh(credentials):
        raise AssertionError("Pi-owned refresh tokens must not be rotated by ml-intern")

    monkeypatch.setattr(codex_oauth, "refresh_codex_credentials", fail_refresh)

    credentials = await get_codex_credentials_for_user("user-1")

    assert credentials is None
    assert not ml_intern_auth_path.exists()
    # The Pi-owned file must still hold its original refresh token.
    stored = json.loads(pi_auth_path.read_text(encoding="utf-8"))
    assert stored["openai-codex"]["refresh"] == "pi_refresh"
def test_openai_codex_model_uses_chatgpt_backend_marker():
    """Resolving an ``openai-codex/`` id returns the Codex marker params."""
    params = _resolve_llm_params(
        "openai-codex/gpt-5.5",
        reasoning_effort="xhigh",
        strict=True,
    )

    expected = {
        "_ml_intern_provider": "openai-codex",
        "model": "openai-codex/gpt-5.5",
        "codex_model": "gpt-5.5",
        "reasoning_effort": "xhigh",
    }
    for key, value in expected.items():
        assert params[key] == value


def test_openai_codex_max_effort_is_rejected_in_strict_mode():
    """Strict mode raises for ``max`` effort on a Codex model."""
    strict_max = {"reasoning_effort": "max", "strict": True}
    with pytest.raises(UnsupportedEffortError, match="OpenAI Codex"):
        _resolve_llm_params("openai-codex/gpt-5.5", **strict_max)