diff --git a/.gitignore b/.gitignore index b086116945..32ebbf74dd 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,5 @@ _bmad-output/* # macOS .DS_Store ._* +.gocache/ +.porting/ diff --git a/.goreleaser.yml b/.goreleaser.yml index f8bebfc1d9..c479255eaf 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -19,6 +19,8 @@ builds: archives: - id: "cli-proxy-api" format: tar.gz + name_template: >- + {{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{- if eq .Arch "arm64" -}}aarch64{{- else -}}{{ .Arch }}{{- end -}} format_overrides: - goos: windows format: zip diff --git a/Dockerfile b/Dockerfile index 3e10c4f9f8..b4caaee325 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ ARG BUILD_DATE=unknown RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/ -FROM alpine:3.22.0 +FROM alpine:3.23 RUN apk add --no-cache tzdata @@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone -CMD ["./CLIProxyAPI"] \ No newline at end of file +CMD ["./CLIProxyAPI"] diff --git a/README.md b/README.md index 53acdd5178..8064db7d77 100644 --- a/README.md +++ b/README.md @@ -10,23 +10,19 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/ ## Sponsor -[](https://z.ai/subscribe?ic=8JVLJQFSKB) +[](https://www.packyapi.com/register?aff=cliproxyapi) -This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN. +Thanks to PackyCode for sponsoring this project! -GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. 
+PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. -Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB +PackyCode provides special discounts for our software users: register using this link and enter the "cliproxyapi" promo code during recharge to get 10% off. ---
![]() |
-Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "cliproxyapi" promo code during recharge to get 10% off. | -||
![]() |
Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off! | Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)! | |
![]() |
-Thanks to LingtrueAPI for its sponsorship of this project! LingtrueAPI is a global large - model API intermediary service platform that provides API calling services for various top - notch models such as Claude Code, Codex, and Gemini. It is committed to enabling users to connect to global AI capabilities at low cost and with high stability. LingtrueAPI offers special discounts to users of this software: register using this link, and enter the promo code "LingtrueAPI" when making the first recharge to enjoy a 10% discount. | -||
![]() |
-Thanks to Poixe AI for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive CLIProxyAPI referral link and receive a bonus of $5 USD on your first top-up. | +![]() |
+Thanks to VisionCoder for supporting this project. VisionCoder Developer Platform is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity. + +VisionCoder is also offering our users a limited-time Token Plan promotion: buy 1 month and get 1 month free. |
![]() |
-感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。 | -||
![]() |
-感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! | +感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! | |
![]() |
-感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! | +感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! | |
![]() |
-感谢 LingtrueAPI 对本项目的赞助!LingtrueAPI 是一家全球大模型API中转服务平台,提供Claude Code、Codex、Gemini 等多种顶级模型API调用服务,致力于让用户以低成本、高稳定性链接全球AI能力。LingtrueAPI为本软件用户提供了特别优惠:使用此链接注册,并在首次充值时输入 "LingtrueAPI" 优惠码即可享受9折优惠。 | -||
![]() |
-感谢 Poixe AI 对本项目的赞助!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 CLIProxyAPI 专属链接注册,充值额外赠送 $5 美金 | +![]() |
+感谢 VisionCoder 对本项目的支持。VisionCoder 开发平台 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。 + +VisionCoder 还为我们的用户提供 Token Plan 限时活动:购买 1 个月,赠送 1 个月。 |
![]() |
-PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:こちらのリンクから登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。 | -||
![]() |
AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:こちらのリンクから登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます! | 本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます! | |
![]() |
-LingtrueAPIのスポンサーシップに感謝します!LingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています:こちらのリンクから登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。 | -||
![]() |
-Poixe AIのスポンサーシップに感謝します!Poixe AIは信頼できるAIモデルAPIサービスを提供しており、プラットフォームが提供するLLM APIを使って簡単にAI製品を構築できます。また、サプライヤーとしてプラットフォームに大規模モデルのリソースを提供し、収益を得ることも可能です。CLIProxyAPIの専用リンクから登録すると、チャージ時に追加で$5が付与されます。 | +![]() |
+VisionCoderのご支援に感謝します!VisionCoder 開発プラットフォーム は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに Token Plan の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。 |
You can close this tab.
This tab will close automatically in 3 seconds.
Login with your HuaweiCloud account to use CodeArts IDE models through CLIProxyAPI.
+Start Login +Click the button below to open HuaweiCloud login page. After login, you will be redirected back here.
+Open HuaweiCloud Login +Credential captured, syncing. Please return to the command line.
Login with your JD account to use JoyCode models through CLIProxyAPI.
+Start Login +Click the button below to open JoyCode login page. After login, credentials will be captured automatically.
+Open JoyCode Login +%s
You can close this window.
`, html.EscapeString(errParam)) + resultChan <- AuthResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `Invalid state parameter
You can close this window.
`) + resultChan <- AuthResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `You can close this window and return to the terminal.
`) + resultChan <- AuthResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(authTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow. +func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(o.cfg) + return ssoClient.LoginWithBuilderID(ctx) +} + +// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow. +// This provides a better UX than device code flow as it uses automatic browser callback. +func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(o.cfg) + return ssoClient.LoginWithBuilderIDAuthCode(ctx) +} + +// exchangeCodeForToken exchanges the authorization code for tokens. 
+func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) { + payload := map[string]string{ + "code": code, + "code_verifier": codeVerifier, + "redirect_uri": redirectURI, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + tokenURL := kiroAuthEndpoint + "/oauth/token" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID)) + req.Header.Set("Accept", "application/json, text/plain, */*") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + Region: "us-east-1", + }, nil +} + +// 
RefreshToken refreshes an expired access token. +// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior. +func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + return o.RefreshTokenWithFingerprint(ctx, refreshToken, "") +} + +// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint. +// tokenKey is used to generate a consistent fingerprint for the token. +func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) { + payload := map[string]string{ + "refreshToken": refreshToken, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + refreshURL := kiroAuthEndpoint + "/refreshToken" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID)) + req.Header.Set("Accept", "application/json, text/plain, */*") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + 
if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + Region: "us-east-1", + }, nil +} + +// LoginWithGoogle performs OAuth login with Google using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGoogle(ctx) +} + +// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGitHub(ctx) +} diff --git a/internal/auth/kiro/oauth_web.go b/internal/auth/kiro/oauth_web.go new file mode 100644 index 0000000000..4618a838da --- /dev/null +++ b/internal/auth/kiro/oauth_web.go @@ -0,0 +1,975 @@ +// Package kiro provides OAuth Web authentication for Kiro. 
+package kiro + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "html/template" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + defaultSessionExpiry = 10 * time.Minute + pollIntervalSeconds = 5 +) + +type authSessionStatus string + +const ( + statusPending authSessionStatus = "pending" + statusSuccess authSessionStatus = "success" + statusFailed authSessionStatus = "failed" +) + +type webAuthSession struct { + stateID string + deviceCode string + userCode string + authURL string + verificationURI string + expiresIn int + interval int + status authSessionStatus + startedAt time.Time + completedAt time.Time + expiresAt time.Time + error string + tokenData *KiroTokenData + ssoClient *SSOOIDCClient + clientID string + clientSecret string + region string + cancelFunc context.CancelFunc + authMethod string // "google", "github", "builder-id", "idc" + startURL string // Used for IDC + codeVerifier string // Used for social auth PKCE + codeChallenge string // Used for social auth PKCE +} + +type OAuthWebHandler struct { + cfg *config.Config + sessions map[string]*webAuthSession + mu sync.RWMutex + onTokenObtained func(*KiroTokenData) +} + +func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler { + return &OAuthWebHandler{ + cfg: cfg, + sessions: make(map[string]*webAuthSession), + } +} + +func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) { + h.onTokenObtained = callback +} + +func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) { + oauth := router.Group("/v0/oauth/kiro") + { + oauth.GET("", h.handleSelect) + oauth.GET("/start", h.handleStart) + oauth.GET("/callback", h.handleCallback) + oauth.GET("/social/callback", h.handleSocialCallback) + oauth.GET("/status", h.handleStatus) + 
oauth.POST("/import", h.handleImportToken) + oauth.POST("/refresh", h.handleManualRefresh) + } +} + +func generateStateID() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func (h *OAuthWebHandler) handleSelect(c *gin.Context) { + h.renderSelectPage(c) +} + +func (h *OAuthWebHandler) handleStart(c *gin.Context) { + method := c.Query("method") + + if method == "" { + c.Redirect(http.StatusFound, "/v0/oauth/kiro") + return + } + + switch method { + case "google", "github": + // Google/GitHub social login is not supported for third-party apps + // due to AWS Cognito redirect_uri restrictions + h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.") + case "builder-id": + h.startBuilderIDAuth(c) + case "idc": + h.startIDCAuth(c) + default: + h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method)) + } +} + +func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) { + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + codeVerifier, codeChallenge, err := generatePKCE() + if err != nil { + h.renderError(c, "Failed to generate PKCE parameters") + return + } + + socialClient := NewSocialAuthClient(h.cfg) + + var provider string + if method == "google" { + provider = string(ProviderGoogle) + } else { + provider = string(ProviderGitHub) + } + + redirectURI := h.getSocialCallbackURL(c) + authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + + session := &webAuthSession{ + stateID: stateID, + authMethod: method, + authURL: authURL, + status: statusPending, + startedAt: time.Now(), + expiresIn: 600, + codeVerifier: codeVerifier, + codeChallenge: 
codeChallenge, + region: "us-east-1", + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go func() { + <-ctx.Done() + h.mu.Lock() + if session.status == statusPending { + session.status = statusFailed + session.error = "Authentication timed out" + } + h.mu.Unlock() + }() + + c.Redirect(http.StatusFound, authURL) +} + +func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string { + scheme := "http" + if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" { + scheme = "https" + } + return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host) +} + +func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) { + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + region := defaultIDCRegion + startURL := builderIDStartURL + + ssoClient := NewSSOOIDCClient(h.cfg) + + regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) + if err != nil { + log.Errorf("OAuth Web: failed to register client: %v", err) + h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( + c.Request.Context(), + regResp.ClientID, + regResp.ClientSecret, + startURL, + region, + ) + if err != nil { + log.Errorf("OAuth Web: failed to start device authorization: %v", err) + h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) + + session := &webAuthSession{ + stateID: stateID, + deviceCode: authResp.DeviceCode, + userCode: authResp.UserCode, + authURL: authResp.VerificationURIComplete, + verificationURI: authResp.VerificationURI, + expiresIn: authResp.ExpiresIn, + interval: authResp.Interval, + status: statusPending, + startedAt: time.Now(), + ssoClient: ssoClient, + clientID: 
regResp.ClientID, + clientSecret: regResp.ClientSecret, + region: region, + authMethod: "builder-id", + startURL: startURL, + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go h.pollForToken(ctx, session) + + h.renderStartPage(c, session) +} + +func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) { + startURL := c.Query("startUrl") + region := c.Query("region") + + if startURL == "" { + h.renderError(c, "Missing startUrl parameter for IDC authentication") + return + } + if region == "" { + region = defaultIDCRegion + } + + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + ssoClient := NewSSOOIDCClient(h.cfg) + + regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) + if err != nil { + log.Errorf("OAuth Web: failed to register client: %v", err) + h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( + c.Request.Context(), + regResp.ClientID, + regResp.ClientSecret, + startURL, + region, + ) + if err != nil { + log.Errorf("OAuth Web: failed to start device authorization: %v", err) + h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) + + session := &webAuthSession{ + stateID: stateID, + deviceCode: authResp.DeviceCode, + userCode: authResp.UserCode, + authURL: authResp.VerificationURIComplete, + verificationURI: authResp.VerificationURI, + expiresIn: authResp.ExpiresIn, + interval: authResp.Interval, + status: statusPending, + startedAt: time.Now(), + ssoClient: ssoClient, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + region: region, + authMethod: "idc", + startURL: startURL, + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + 
h.mu.Unlock() + + go h.pollForToken(ctx, session) + + h.renderStartPage(c, session) +} + +func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) { + defer session.cancelFunc() + + interval := time.Duration(session.interval) * time.Second + if interval < time.Duration(pollIntervalSeconds)*time.Second { + interval = time.Duration(pollIntervalSeconds) * time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + h.mu.Lock() + if session.status == statusPending { + session.status = statusFailed + session.error = "Authentication timed out" + } + h.mu.Unlock() + return + case <-ticker.C: + tokenResp, err := h.ssoClient(session).CreateTokenWithRegion( + ctx, + session.clientID, + session.clientSecret, + session.deviceCode, + session.region, + ) + + if err != nil { + errStr := err.Error() + if errStr == ErrAuthorizationPending.Error() { + continue + } + if errStr == ErrSlowDown.Error() { + interval += 5 * time.Second + ticker.Reset(interval) + continue + } + + h.mu.Lock() + session.status = statusFailed + session.error = errStr + session.completedAt = time.Now() + h.mu.Unlock() + + log.Errorf("OAuth Web: token polling failed: %v", err) + return + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + // Fetch profileArn for IDC + var profileArn string + if session.authMethod == "idc" { + profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken) + } + + email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken) + + tokenData := &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: session.authMethod, + Provider: "AWS", + ClientID: session.clientID, + ClientSecret: session.clientSecret, + Email: email, + Region: 
session.region, + StartURL: session.startURL, + } + + h.mu.Lock() + session.status = statusSuccess + session.completedAt = time.Now() + session.expiresAt = expiresAt + session.tokenData = tokenData + h.mu.Unlock() + + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + log.Infof("OAuth Web: authentication successful for %s", email) + return + } + } +} + +// saveTokenToFile saves the token data to the auth directory +func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) { + // Get auth directory from config or use default + authDir := "" + if h.cfg != nil && h.cfg.AuthDir != "" { + var err error + authDir, err = util.ResolveAuthDir(h.cfg.AuthDir) + if err != nil { + log.Errorf("OAuth Web: failed to resolve auth directory: %v", err) + } + } + + // Fall back to default location + if authDir == "" { + home, err := os.UserHomeDir() + if err != nil { + log.Errorf("OAuth Web: failed to get home directory: %v", err) + return + } + authDir = filepath.Join(home, ".cli-proxy-api") + } + + // Create directory if not exists + if err := os.MkdirAll(authDir, 0700); err != nil { + log.Errorf("OAuth Web: failed to create auth directory: %v", err) + return + } + + // Generate filename using the unified function + fileName := GenerateTokenFileName(tokenData) + + authFilePath := filepath.Join(authDir, fileName) + + // Convert to storage format and save + storage := &KiroTokenStorage{ + Type: "kiro", + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + ProfileArn: tokenData.ProfileArn, + ExpiresAt: tokenData.ExpiresAt, + AuthMethod: tokenData.AuthMethod, + Provider: tokenData.Provider, + LastRefresh: time.Now().Format(time.RFC3339), + ClientID: tokenData.ClientID, + ClientSecret: tokenData.ClientSecret, + Region: tokenData.Region, + StartURL: tokenData.StartURL, + Email: tokenData.Email, + } + + if err := storage.SaveTokenToFile(authFilePath); err != nil { + 
log.Errorf("OAuth Web: failed to save token to file: %v", err) + return + } + + log.Infof("OAuth Web: token saved to %s", authFilePath) +} + +func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient { + return session.ssoClient +} + +func (h *OAuthWebHandler) handleCallback(c *gin.Context) { + stateID := c.Query("state") + errParam := c.Query("error") + + if errParam != "" { + h.renderError(c, errParam) + return + } + + if stateID == "" { + h.renderError(c, "Missing state parameter") + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + h.renderError(c, "Invalid or expired session") + return + } + + if session.status == statusSuccess { + h.renderSuccess(c, session) + } else if session.status == statusFailed { + h.renderError(c, session.error) + } else { + c.Redirect(http.StatusFound, "/v0/oauth/kiro/start") + } +} + +func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) { + stateID := c.Query("state") + code := c.Query("code") + errParam := c.Query("error") + + if errParam != "" { + h.renderError(c, errParam) + return + } + + if stateID == "" { + h.renderError(c, "Missing state parameter") + return + } + + if code == "" { + h.renderError(c, "Missing authorization code") + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + h.renderError(c, "Invalid or expired session") + return + } + + if session.authMethod != "google" && session.authMethod != "github" { + h.renderError(c, "Invalid session type for social callback") + return + } + + socialClient := NewSocialAuthClient(h.cfg) + redirectURI := h.getSocialCallbackURL(c) + + tokenReq := &CreateTokenRequest{ + Code: code, + CodeVerifier: session.codeVerifier, + RedirectURI: redirectURI, + } + + tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq) + if err != nil { + log.Errorf("OAuth Web: social token exchange failed: %v", err) + h.mu.Lock() + session.status = 
statusFailed + session.error = fmt.Sprintf("Token exchange failed: %v", err) + session.completedAt = time.Now() + h.mu.Unlock() + h.renderError(c, session.error) + return + } + + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + var provider string + if session.authMethod == "google" { + provider = string(ProviderGoogle) + } else { + provider = string(ProviderGitHub) + } + + tokenData := &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: session.authMethod, + Provider: provider, + Email: email, + Region: "us-east-1", + } + + h.mu.Lock() + session.status = statusSuccess + session.completedAt = time.Now() + session.expiresAt = expiresAt + session.tokenData = tokenData + h.mu.Unlock() + + if session.cancelFunc != nil { + session.cancelFunc() + } + + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider) + h.renderSuccess(c, session) +} + +func (h *OAuthWebHandler) handleStatus(c *gin.Context) { + stateID := c.Query("state") + if stateID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"}) + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + + response := gin.H{ + "status": string(session.status), + } + + switch session.status { + case statusPending: + elapsed := time.Since(session.startedAt).Seconds() + remaining := float64(session.expiresIn) - elapsed + if remaining < 0 { + remaining = 0 + } + response["remaining_seconds"] = int(remaining) + case 
statusSuccess: + response["completed_at"] = session.completedAt.Format(time.RFC3339) + response["expires_at"] = session.expiresAt.Format(time.RFC3339) + case statusFailed: + response["error"] = session.error + response["failed_at"] = session.completedAt.Format(time.RFC3339) + } + + c.JSON(http.StatusOK, response) +} + +func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("start").Parse(oauthWebStartPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "AuthURL": session.authURL, + "UserCode": session.userCode, + "ExpiresIn": session.expiresIn, + "StateID": session.stateID, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render template: %v", err) + } +} + +func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) { + tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse select template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, nil); err != nil { + log.Errorf("OAuth Web: failed to render select template: %v", err) + } +} + +func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) { + tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse error template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "Error": errMsg, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + c.Status(http.StatusBadRequest) + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render 
error template: %v", err) + } +} + +func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse success template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "ExpiresAt": session.expiresAt.Format(time.RFC3339), + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render success template: %v", err) + } +} + +func (h *OAuthWebHandler) CleanupExpiredSessions() { + h.mu.Lock() + defer h.mu.Unlock() + + now := time.Now() + for id, session := range h.sessions { + if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute { + delete(h.sessions, id) + } else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry { + session.cancelFunc() + delete(h.sessions, id) + } + } +} + +func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + session, exists := h.sessions[stateID] + return session, exists +} + +// ImportTokenRequest represents the request body for token import +type ImportTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +// handleImportToken handles manual refresh token import from Kiro IDE +func (h *OAuthWebHandler) handleImportToken(c *gin.Context) { + var req ImportTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Invalid request body", + }) + return + } + + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Refresh token is required", + }) + return + } + + // Validate token format + if !strings.HasPrefix(refreshToken, 
"aorAAAAAG") { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Invalid token format. Token should start with aorAAAAAG...", + }) + return + } + + // Create social auth client to refresh and validate the token + socialClient := NewSocialAuthClient(h.cfg) + + // Refresh the token to validate it and get access token + tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken) + if err != nil { + log.Errorf("OAuth Web: token refresh failed during import: %v", err) + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": fmt.Sprintf("Token validation failed: %v", err), + }) + return + } + + // Set the original refresh token (the refreshed one might be empty) + if tokenData.RefreshToken == "" { + tokenData.RefreshToken = refreshToken + } + tokenData.AuthMethod = "social" + tokenData.Provider = "imported" + + // Notify callback if set + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + // Generate filename for response using the unified function + fileName := GenerateTokenFileName(tokenData) + + log.Infof("OAuth Web: token imported successfully") + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Token imported successfully", + "fileName": fileName, + }) +} + +// handleManualRefresh handles manual token refresh requests from the web UI. +// This allows users to trigger a token refresh when needed, without waiting +// for the automatic 30-second check and 20-minute-before-expiry refresh cycle. +// Uses the same refresh logic as kiro_executor.Refresh for consistency. 
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
	// Resolve the auth directory from config; fall back to ~/.cli-proxy-api
	// when the config is absent or resolution fails.
	authDir := ""
	if h.cfg != nil && h.cfg.AuthDir != "" {
		var err error
		authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
		if err != nil {
			log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
		}
	}

	if authDir == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{
				"success": false,
				"error":   "Failed to get home directory",
			})
			return
		}
		authDir = filepath.Join(home, ".cli-proxy-api")
	}

	// Find all kiro token files in the auth directory
	files, err := os.ReadDir(authDir)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"error":   fmt.Sprintf("Failed to read auth directory: %v", err),
		})
		return
	}

	// Per-file failures are accumulated rather than aborting, so one bad
	// credential file does not block refreshing the others.
	var refreshedCount int
	var errors []string

	for _, file := range files {
		if file.IsDir() {
			continue
		}
		name := file.Name()
		// Only Kiro credential files (kiro-*.json) are candidates.
		if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
			continue
		}

		filePath := filepath.Join(authDir, name)
		data, err := os.ReadFile(filePath)
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
			continue
		}

		var storage KiroTokenStorage
		if err := json.Unmarshal(data, &storage); err != nil {
			errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
			continue
		}

		if storage.RefreshToken == "" {
			errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
			continue
		}

		// Refresh token using the same logic as kiro_executor.Refresh
		tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
			continue
		}

		// Update storage with new token data. The refresh endpoint may omit a
		// new refresh token or profile ARN, in which case the stored values
		// are kept.
		storage.AccessToken = tokenData.AccessToken
		if tokenData.RefreshToken != "" {
			storage.RefreshToken = tokenData.RefreshToken
		}
		storage.ExpiresAt = tokenData.ExpiresAt
		storage.LastRefresh = time.Now().Format(time.RFC3339)
		if tokenData.ProfileArn != "" {
			storage.ProfileArn = tokenData.ProfileArn
		}

		// Write updated token back to file
		updatedData, err := json.MarshalIndent(storage, "", " ")
		if err != nil {
			errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
			continue
		}

		// Write to a temp file then rename, so a crash mid-write cannot leave
		// a truncated credential file behind.
		tmpFile := filePath + ".tmp"
		if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
			errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
			continue
		}
		if err := os.Rename(tmpFile, filePath); err != nil {
			errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
			continue
		}

		log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
		refreshedCount++

		// Notify callback if set
		if h.onTokenObtained != nil {
			h.onTokenObtained(tokenData)
		}
	}

	// Only report an error when every attempt failed; a partial success
	// returns 200 with the failures surfaced as warnings below.
	if refreshedCount == 0 && len(errors) > 0 {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"error":   fmt.Sprintf("All refresh attempts failed: %v", errors),
		})
		return
	}

	response := gin.H{
		"success":        true,
		"message":        fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
		"refreshedCount": refreshedCount,
	}
	if len(errors) > 0 {
		response["warnings"] = errors
	}

	c.JSON(http.StatusOK, response)
}

// refreshTokenData refreshes a token using the appropriate method based on auth type.
// This mirrors the logic in kiro_executor.Refresh for consistency.
+func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(h.cfg) + + switch { + case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "": + // IDC refresh with region-specific endpoint + log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region) + return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL) + + case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id": + // Builder ID refresh with default endpoint + log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID") + return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken) + + default: + // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub) + log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint") + oauth := NewKiroOAuth(h.cfg) + return oauth.RefreshToken(ctx, storage.RefreshToken) + } +} diff --git a/internal/auth/kiro/oauth_web_templates.go b/internal/auth/kiro/oauth_web_templates.go new file mode 100644 index 0000000000..228677a511 --- /dev/null +++ b/internal/auth/kiro/oauth_web_templates.go @@ -0,0 +1,779 @@ +// Package kiro provides OAuth Web authentication templates. +package kiro + +const ( + oauthWebStartPageHTML = ` + + + + +Follow the steps below to complete authentication
+ ++ Use your AWS SSO account to login and authorize +
+Choose how you want to authenticate with Kiro
+ +~/.kiro/kiro-auth-token.jsonrefreshToken value and paste it above
+ Error: %s
+You can close this window.
+ +`, html.EscapeString(errParam)) + } else { + fmt.Fprint(w, ` + +You can close this window and return to the terminal.
+ + +`) + } +} + +// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed. +func IsProtocolHandlerInstalled() bool { + switch runtime.GOOS { + case "linux": + return isLinuxHandlerInstalled() + case "windows": + return isWindowsHandlerInstalled() + case "darwin": + return isDarwinHandlerInstalled() + default: + return false + } +} + +// InstallProtocolHandler installs the kiro:// protocol handler for the current platform. +func InstallProtocolHandler(handlerPort int) error { + switch runtime.GOOS { + case "linux": + return installLinuxHandler(handlerPort) + case "windows": + return installWindowsHandler(handlerPort) + case "darwin": + return installDarwinHandler(handlerPort) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// UninstallProtocolHandler removes the kiro:// protocol handler. +func UninstallProtocolHandler() error { + switch runtime.GOOS { + case "linux": + return uninstallLinuxHandler() + case "windows": + return uninstallWindowsHandler() + case "darwin": + return uninstallDarwinHandler() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// --- Linux Implementation --- + +func getLinuxDesktopPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop") +} + +func getLinuxHandlerScriptPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler") +} + +func isLinuxHandlerInstalled() bool { + desktopPath := getLinuxDesktopPath() + _, err := os.Stat(desktopPath) + return err == nil +} + +func installLinuxHandler(handlerPort int) error { + // Create directories + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + binDir := filepath.Join(homeDir, ".local", "bin") + appDir := filepath.Join(homeDir, ".local", "share", "applications") + + if err := os.MkdirAll(binDir, 0755); err != nil { + return 
fmt.Errorf("failed to create bin directory: %w", err) + } + if err := os.MkdirAll(appDir, 0755); err != nil { + return fmt.Errorf("failed to create applications directory: %w", err) + } + + // Create handler script - tries multiple ports to handle dynamic port allocation + scriptPath := getLinuxHandlerScriptPath() + scriptContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler +# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE + +URL="$1" + +# Check curl availability +if ! command -v curl &> /dev/null; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try CLI proxy on multiple possible ports (default + dynamic range) +CLI_OK=0 +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break + fi +done + +# If CLI not available, forward to Kiro IDE +if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then + /usr/share/kiro/kiro --open-url "$URL" & +fi +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create .desktop file + desktopPath := getLinuxDesktopPath() + desktopContent := fmt.Sprintf(`[Desktop Entry] +Name=Kiro OAuth Handler +Comment=Handle kiro:// protocol for CLI Proxy API authentication +Exec=%s %%u +Type=Application +Terminal=false +NoDisplay=true +MimeType=x-scheme-handler/kiro; +Categories=Utility; +`, scriptPath) + + if err := 
os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil { + return fmt.Errorf("failed to write desktop file: %w", err) + } + + // Register handler with xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("xdg-mime registration failed (may need manual setup): %v", err) + } + + // Update desktop database + cmd = exec.Command("update-desktop-database", appDir) + _ = cmd.Run() // Ignore errors, not critical + + log.Info("Kiro protocol handler installed for Linux") + return nil +} + +func uninstallLinuxHandler() error { + desktopPath := getLinuxDesktopPath() + scriptPath := getLinuxHandlerScriptPath() + + if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove desktop file: %w", err) + } + if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove handler script: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- Windows Implementation --- + +func isWindowsHandlerInstalled() bool { + // Check registry key existence + cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve") + return cmd.Run() == nil +} + +func installWindowsHandler(handlerPort int) error { + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create handler script (PowerShell) + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + if err := os.MkdirAll(scriptDir, 0755); err != nil { + return fmt.Errorf("failed to create script directory: %w", err) + } + + scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1") + scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows +param([string]$url) + +# Load required assembly for HttpUtility +Add-Type -AssemblyName System.Web + +# Parse URL parameters +$uri = [System.Uri]$url +$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query) +$code = 
$query["code"] +$state = $query["state"] +$errorParam = $query["error"] + +# Try multiple ports (default + dynamic range) +$ports = @(%d, %d, %d, %d, %d) +$success = $false + +foreach ($port in $ports) { + if ($success) { break } + $callbackUrl = "http://127.0.0.1:$port/oauth/callback" + try { + if ($errorParam) { + $fullUrl = $callbackUrl + "?error=" + $errorParam + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } elseif ($code -and $state) { + $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } + } catch { + # Try next port + } +} +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create batch wrapper + batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat") + batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath) + + if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil { + return fmt.Errorf("failed to write batch wrapper: %w", err) + } + + // Register in Windows registry + commands := [][]string{ + {"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) 
+ if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run registry command: %w", err) + } + } + + log.Info("Kiro protocol handler installed for Windows") + return nil +} + +func uninstallWindowsHandler() error { + // Remove registry keys + cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f") + if err := cmd.Run(); err != nil { + log.Warnf("failed to remove registry key: %v", err) + } + + // Remove scripts + homeDir, _ := os.UserHomeDir() + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1")) + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat")) + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- macOS Implementation --- + +func getDarwinAppPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app") +} + +func isDarwinHandlerInstalled() bool { + appPath := getDarwinAppPath() + _, err := os.Stat(appPath) + return err == nil +} + +func installDarwinHandler(handlerPort int) error { + // Create app bundle structure + appPath := getDarwinAppPath() + contentsPath := filepath.Join(appPath, "Contents") + macOSPath := filepath.Join(contentsPath, "MacOS") + + if err := os.MkdirAll(macOSPath, 0755); err != nil { + return fmt.Errorf("failed to create app bundle: %w", err) + } + + // Create Info.plist + plistPath := filepath.Join(contentsPath, "Info.plist") + plistContent := ` + +%s
You can close this window.
`, html.EscapeString(errParam)) + resultChan <- WebCallbackResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, ` +Invalid state parameter
You can close this window.
`) + resultChan <- WebCallbackResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, ` +You can close this window and return to the terminal.
+`) + resultChan <- WebCallbackResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("kiro social auth callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(socialAuthTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// generatePKCE generates PKCE code verifier and challenge. +func generatePKCE() (verifier, challenge string, err error) { + // Generate 32 bytes of random data for verifier + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + // Generate SHA256 hash of verifier for challenge + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} + +// generateState generates a random state parameter. +func generateStateParam() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// buildLoginURL constructs the Kiro OAuth login URL. +// The login endpoint expects a GET request with query parameters. +// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account +// The prompt=select_account parameter forces the account selection screen even if already logged in. 
+func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string { + return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + kiroAuthServiceEndpoint, + provider, + url.QueryEscape(redirectURI), + codeChallenge, + state, + ) +} + +// CreateToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal token request: %w", err) + } + + tokenURL := kiroAuthServiceEndpoint + "/oauth/token" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID)) + httpReq.Header.Set("Accept", "application/json, text/plain, */*") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResp, nil +} + +// RefreshSocialToken refreshes an expired social auth token. 
+func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken}) + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh request: %w", err) + } + + refreshURL := kiroAuthServiceEndpoint + "/refreshToken" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID)) + httpReq.Header.Set("Accept", "application/json, text/plain, */*") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse refresh response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 // Default 1 hour + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + Region: "us-east-1", + }, nil +} + +// LoginWithSocial 
performs OAuth login with Google or GitHub. +// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors. +func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) { + providerName := string(provider) + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName) + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Start local HTTP callback server (instead of kiro:// protocol handler) + // This avoids redirect_mismatch errors with AWS Cognito + fmt.Println("\nSetting up authentication...") + + // Step 2: Generate PKCE codes + codeVerifier, codeChallenge, err := generatePKCE() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + // Step 3: Generate state + state, err := generateStateParam() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Step 4: Start local HTTP callback server + redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("kiro social auth: callback server started at %s", redirectURI) + + // Step 5: Build the login URL using HTTP redirect URI + authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). 
Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Step 6: Open browser for user authentication + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Printf(" Opening browser for %s authentication...\n", providerName) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authentication callback...") + + // Step 7: Wait for callback from HTTP server + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(socialAuthTimeout): + return nil, fmt.Errorf("authentication timed out") + case callback := <-resultChan: + if callback.Error != "" { + return nil, fmt.Errorf("authentication error: %s", callback.Error) + } + + // State is already validated by the callback server + if callback.Code == "" { + return nil, fmt.Errorf("no authorization code received") + } + + fmt.Println("\n✓ Authorization received!") + + // Step 8: Exchange code for tokens + fmt.Println("Exchanging code for tokens...") + + tokenReq := &CreateTokenRequest{ + Code: callback.Code, + CodeVerifier: codeVerifier, + RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol + } + + tokenResp, err := c.CreateToken(ctx, tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + // Close the browser 
window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + // Try to extract email from JWT access token first + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + // If no email in JWT, ask user for account label (only in interactive mode) + if email == "" && isInteractiveTerminal() { + fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ") + reader := bufio.NewReader(os.Stdin) + var err error + email, err = reader.ReadString('\n') + if err != nil { + log.Debugf("Failed to read account label: %v", err) + } + email = strings.TrimSpace(email) + } + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: providerName, + Email: email, // JWT email or user-provided label + Region: "us-east-1", + }, nil + } +} + +// LoginWithGoogle performs OAuth login with Google. +func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGoogle) +} + +// LoginWithGitHub performs OAuth login with GitHub. +func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGitHub) +} + +// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs. +// This prevents the "Open with" dialog from appearing on Linux. +// On non-Linux platforms, this is a no-op as they use different mechanisms. 
+func forceDefaultProtocolHandler() { + if runtime.GOOS != "linux" { + return // Non-Linux platforms use different handler mechanisms + } + + // Set our handler as default using xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err) + } +} + +// isInteractiveTerminal checks if stdin is connected to an interactive terminal. +// Returns false in CI/automated environments or when stdin is piped. +func isInteractiveTerminal() bool { + return term.IsTerminal(int(os.Stdin.Fd())) +} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go new file mode 100644 index 0000000000..22d8648e4d --- /dev/null +++ b/internal/auth/kiro/sso_oidc.go @@ -0,0 +1,1603 @@ +// Package kiro provides AWS SSO OIDC authentication for Kiro. +package kiro + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // AWS SSO OIDC endpoints + ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" + + // Kiro's start URL for Builder ID + builderIDStartURL = "https://view.awsapps.com/start" + + // Default region for IDC + defaultIDCRegion = "us-east-1" + + // Polling interval + pollInterval = 5 * time.Second + + // Authorization code flow callback + authCodeCallbackPath = "/oauth/callback" + authCodeCallbackPort = 19877 +) + +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrSlowDown = errors.New("slow_down") +) + +type SSOOIDCClient struct { + httpClient *http.Client + cfg *config.Config +} + 
+// NewSSOOIDCClient creates a new SSO OIDC client. +func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SSOOIDCClient{ + httpClient: client, + cfg: cfg, + } +} + +// RegisterClientResponse from AWS SSO OIDC. +type RegisterClientResponse struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` +} + +// StartDeviceAuthResponse from AWS SSO OIDC. +type StartDeviceAuthResponse struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationURI string `json:"verificationUri"` + VerificationURIComplete string `json:"verificationUriComplete"` + ExpiresIn int `json:"expiresIn"` + Interval int `json:"interval"` +} + +// CreateTokenResponse from AWS SSO OIDC. +type CreateTokenResponse struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + ExpiresIn int `json:"expiresIn"` + RefreshToken string `json:"refreshToken"` +} + +// getOIDCEndpoint returns the OIDC endpoint for the given region. +func getOIDCEndpoint(region string) string { + if region == "" { + region = defaultIDCRegion + } + return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) +} + +// promptInput prompts the user for input with an optional default value. 
+func promptInput(prompt, defaultValue string) string { + reader := bufio.NewReader(os.Stdin) + if defaultValue != "" { + fmt.Printf("%s [%s]: ", prompt, defaultValue) + } else { + fmt.Printf("%s: ", prompt) + } + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return defaultValue + } + input = strings.TrimSpace(input) + if input == "" { + return defaultValue + } + return input +} + +// promptSelect prompts the user to select from options using number input. +func promptSelect(prompt string, options []string) int { + reader := bufio.NewReader(os.Stdin) + + for { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Printf("Enter selection (1-%d): ", len(options)) + + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return 0 // Default to first option on error + } + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) + continue + } + return selection - 1 + } +} + +// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. 
+func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. 
+func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": startURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateTokenWithRegion polls for the access token after user authorization using a specific region. 
+func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, ErrAuthorizationPending + } + if errResp.Error == "slow_down" { + return nil, ErrSlowDown + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region. 
+func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { + if region == "" { + region = defaultIDCRegion + } + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + StartURL: startURL, + Region: region, + }, nil +} + +// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC). 
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Register client with the specified region
	fmt.Println("\nRegistering client...")
	regResp, err := c.RegisterClientWithRegion(ctx, region)
	if err != nil {
		return nil, fmt.Errorf("failed to register client: %w", err)
	}
	log.Debugf("Client registered: %s", regResp.ClientID)

	// Step 2: Start device authorization with IDC start URL
	fmt.Println("Starting device authorization...")
	authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
	if err != nil {
		return nil, fmt.Errorf("failed to start device auth: %w", err)
	}

	// Step 3: Show user the verification URL
	fmt.Printf("\n")
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf(" Confirm the following code in the browser:\n")
	fmt.Printf(" Code: %s\n", authResp.UserCode)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)

	// Set incognito mode based on config. Incognito bypasses cached browser
	// sessions so a different AWS account can be selected.
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true)
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Open browser; a failure here is non-fatal — the user can open the URL manually.
	if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println(" Please open the URL manually in your browser.")
	} else {
		fmt.Println(" (Browser opened automatically)")
	}

	// Step 4: Poll for token until the user approves, the context is
	// cancelled, or the device code expires.
	fmt.Println("Waiting for authorization...")

	interval := pollInterval
	if authResp.Interval > 0 {
		// Server-suggested polling interval takes precedence over our default.
		interval = time.Duration(authResp.Interval) * time.Second
	}

	deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)

	for time.Now().Before(deadline) {
		select {
		case <-ctx.Done():
			browser.CloseBrowser()
			return nil, ctx.Err()
		case <-time.After(interval):
			tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
			if err != nil {
				if errors.Is(err, ErrAuthorizationPending) {
					// User has not approved yet; keep polling.
					fmt.Print(".")
					continue
				}
				if errors.Is(err, ErrSlowDown) {
					// Server requested back-off: widen the polling interval.
					interval += 5 * time.Second
					continue
				}
				browser.CloseBrowser()
				return nil, fmt.Errorf("token creation failed: %w", err)
			}

			fmt.Println("\n\n✓ Authorization successful!")

			// Close the browser window (best-effort).
			if err := browser.CloseBrowser(); err != nil {
				log.Debugf("Failed to close browser: %v", err)
			}

			// Step 5: Get profile ARN from CodeWhisperer API (best-effort;
			// FetchProfileArn returns "" when none can be resolved).
			fmt.Println("Fetching profile information...")
			profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)

			// Fetch user email (best-effort, for display and record keeping).
			email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
			if email != "" {
				fmt.Printf(" Logged in as: %s\n", email)
			}

			expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)

			// Persist everything needed for later refreshes (client creds,
			// start URL, region) alongside the tokens.
			return &KiroTokenData{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ProfileArn:   profileArn,
				ExpiresAt:    expiresAt.Format(time.RFC3339),
				AuthMethod:   "idc",
				Provider:     "AWS",
				ClientID:     regResp.ClientID,
				ClientSecret: regResp.ClientSecret,
				Email:        email,
				StartURL:     startURL,
				Region:       region,
			}, nil
		}
	}

	// Close browser on timeout
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser on timeout: %v", err)
	}
	return nil, fmt.Errorf("authorization timed out")
}

// IDCLoginOptions holds optional parameters for IDC login.
type IDCLoginOptions struct {
	StartURL      string // Pre-configured start URL (skips prompt if set)
	Region        string // OIDC region for login and token refresh (defaults to us-east-1)
	UseDeviceCode bool   // Use Device Code flow instead of Auth Code flow
}

// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
// Options can be provided to pre-configure IDC parameters (startURL, region).
// If StartURL is provided in opts, IDC flow is used directly without prompting.
+func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // If IDC options with StartURL are provided, skip method selection and use IDC directly + if opts != nil && opts.StartURL != "" { + region := opts.Region + if region == "" { + region = defaultIDCRegion + } + fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL) + fmt.Printf(" Region: %s\n", region) + + if opts.UseDeviceCode { + return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region) + } + return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region) + } + + // Prompt for login method + options := []string{ + "Use with Builder ID (personal AWS account)", + "Use with IDC Account (organization SSO)", + } + selection := promptSelect("\n? Select login method:", options) + + if selection == 0 { + // Builder ID flow - use existing implementation + return c.LoginWithBuilderID(ctx) + } + + // IDC flow - use pre-configured values or prompt + var startURL, region string + + if opts != nil { + startURL = opts.StartURL + region = opts.Region + } + + fmt.Println() + + // Use pre-configured startURL or prompt + if startURL == "" { + startURL = promptInput("? Enter Start URL", "") + if startURL == "" { + return nil, fmt.Errorf("start URL is required for IDC login") + } + } else { + fmt.Printf(" Using pre-configured Start URL: %s\n", startURL) + } + + // Use pre-configured region or prompt + if region == "" { + region = promptInput("? 
Enter Region", defaultIDCRegion) + } else { + fmt.Printf(" Using pre-configured Region: %s\n", region) + } + + if opts != nil && opts.UseDeviceCode { + return c.LoginWithIDCAndOptions(ctx, startURL, region) + } + return c.LoginWithIDCAuthCode(ctx, startURL, region) +} + +// LoginWithIDCAndOptions performs IDC login with the specified region. +func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + return c.LoginWithIDC(ctx, startURL, region) +} + +// RegisterClient registers a new OIDC client with AWS. +func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorization starts the device authorization flow. 
+func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateToken polls for the access token after user authorization. 
+func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, ErrAuthorizationPending + } + if errResp.Error == "slow_down" { + return nil, ErrSlowDown + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshToken refreshes an access token using the refresh token. +// Includes retry logic and improved error handling for better reliability. 
+func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + Region: defaultIDCRegion, + }, nil +} + +// LoginWithBuilderID performs the full device code flow for AWS Builder ID. 
func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
	fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
	fmt.Println("║ Kiro Authentication (AWS Builder ID) ║")
	fmt.Println("╚══════════════════════════════════════════════════════════╝")

	// Step 1: Register client
	fmt.Println("\nRegistering client...")
	regResp, err := c.RegisterClient(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to register client: %w", err)
	}
	log.Debugf("Client registered: %s", regResp.ClientID)

	// Step 2: Start device authorization
	fmt.Println("Starting device authorization...")
	authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to start device auth: %w", err)
	}

	// Step 3: Show user the verification URL (complete URL first, then the
	// manual URL + user code as a fallback).
	fmt.Printf("\n")
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf(" Open this URL in your browser:\n")
	fmt.Printf(" %s\n", authResp.VerificationURIComplete)
	fmt.Println("════════════════════════════════════════════════════════════")
	fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI)
	fmt.Printf(" And enter code: %s\n\n", authResp.UserCode)

	// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
	// Incognito mode enables multi-account support by bypassing cached sessions
	if c.cfg != nil {
		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
		if !c.cfg.IncognitoBrowser {
			log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
		} else {
			log.Debug("kiro: using incognito mode for multi-account support")
		}
	} else {
		browser.SetIncognitoMode(true) // Default to incognito if no config
		log.Debug("kiro: using incognito mode for multi-account support (default)")
	}

	// Open browser using cross-platform browser package; failure is non-fatal.
	if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
		log.Warnf("Could not open browser automatically: %v", err)
		fmt.Println(" Please open the URL manually in your browser.")
	} else {
		fmt.Println(" (Browser opened automatically)")
	}

	// Step 4: Poll for token until approval, context cancellation, or expiry
	// of the device code.
	fmt.Println("Waiting for authorization...")

	interval := pollInterval
	if authResp.Interval > 0 {
		// Server-suggested polling interval takes precedence over our default.
		interval = time.Duration(authResp.Interval) * time.Second
	}

	deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)

	for time.Now().Before(deadline) {
		select {
		case <-ctx.Done():
			browser.CloseBrowser() // Cleanup on cancel
			return nil, ctx.Err()
		case <-time.After(interval):
			tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
			if err != nil {
				if errors.Is(err, ErrAuthorizationPending) {
					// User has not approved yet; keep polling.
					fmt.Print(".")
					continue
				}
				if errors.Is(err, ErrSlowDown) {
					// Server requested back-off: widen the polling interval.
					interval += 5 * time.Second
					continue
				}
				// Close browser on error before returning
				browser.CloseBrowser()
				return nil, fmt.Errorf("token creation failed: %w", err)
			}

			fmt.Println("\n\n✓ Authorization successful!")

			// Close the browser window
			if err := browser.CloseBrowser(); err != nil {
				log.Debugf("Failed to close browser: %v", err)
			}

			// Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
			email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken)
			if email != "" {
				fmt.Printf(" Logged in as: %s\n", email)
			}

			expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)

			return &KiroTokenData{
				AccessToken:  tokenResp.AccessToken,
				RefreshToken: tokenResp.RefreshToken,
				ProfileArn:   "", // Builder ID has no profile
				ExpiresAt:    expiresAt.Format(time.RFC3339),
				AuthMethod:   "builder-id",
				Provider:     "AWS",
				ClientID:     regResp.ClientID,
				ClientSecret: regResp.ClientSecret,
				Email:        email,
				Region:       defaultIDCRegion,
			}, nil
		}
	}

	// Close browser on timeout for better UX
	if err := browser.CloseBrowser(); err != nil {
		log.Debugf("Failed to close browser on timeout: %v", err)
	}
	return nil, fmt.Errorf("authorization timed out")
}

// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
// Falls back to JWT parsing if userinfo fails. Returns "" when neither source
// yields an email.
func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string {
	// Method 1: Try userinfo endpoint (standard OIDC)
	email := c.tryUserInfoEndpoint(ctx, accessToken)
	if email != "" {
		return email
	}

	// Method 2: Fallback to JWT parsing
	return ExtractEmailFromJWT(accessToken)
}

// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint.
+func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil) + if err != nil { + return "" + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + log.Debugf("userinfo request failed: %v", err) + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody)) + return "" + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "" + } + + log.Debugf("userinfo response: %s", string(respBody)) + + var userInfo struct { + Email string `json:"email"` + Sub string `json:"sub"` + PreferredUsername string `json:"preferred_username"` + Name string `json:"name"` + } + + if err := json.Unmarshal(respBody, &userInfo); err != nil { + return "" + } + + if userInfo.Email != "" { + return userInfo.Email + } + if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") { + return userInfo.PreferredUsername + } + return "" +} + +// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API. +// This is used to get profileArn for imported accounts that may not have it. 
+func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string { + profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken) + if profileArn != "" { + return profileArn + } + return c.tryListProfilesLegacy(ctx, accessToken) +} + +func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}")) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/json") + accountKey := GetAccountKey(clientID, refreshToken) + setRuntimeHeaders(req, accessToken, accountKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + log.Debugf("ListAvailableProfiles request failed: %v", err) + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListAvailableProfiles response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + ProfileName string `json:"profileName"` + } `json:"profiles"` + NextToken *string `json:"nextToken"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + log.Debugf("ListAvailableProfiles parse error: %v", err) + return "" + } + + if len(result.Profiles) > 0 { + log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn) + return result.Profiles[0].Arn + } + + return "" +} + +func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + // Use the legacy CodeWhisperer endpoint for JSON-RPC style requests. 
+ // The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers. + req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListProfiles (legacy) response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Profiles) > 0 { + return result.Profiles[0].Arn + } + + return "" +} + +// RegisterClientForAuthCode registers a new OIDC client for authorization code flow. 
+func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) { + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"authorization_code", "refresh_token"}, + "redirectUris": []string{redirectURI}, + "issuerUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"authorization_code", "refresh_token"}, + "redirectUris": []string{redirectURI}, + "issuerUrl": issuerUrl, + } + + body, err := json.Marshal(payload) 
+ if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// AuthCodeCallbackResult contains the result from authorization code callback. +type AuthCodeCallbackResult struct { + Code string + State string + Error string +} + +// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback. 
+func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) { + // Try to find an available port + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort)) + if err != nil { + // Try with dynamic port + log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort) + listener, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath) + resultChan := make(chan AuthCodeCallbackResult, 1) + doneChan := make(chan struct{}) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + // Send response to browser + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` +Error: %s
You can close this window.
`, html.EscapeString(errParam)) + resultChan <- AuthCodeCallbackResult{Error: errParam} + close(doneChan) + return + } + + if state != expectedState { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, ` +Invalid state parameter
You can close this window.
`) + resultChan <- AuthCodeCallbackResult{Error: "state mismatch"} + close(doneChan) + return + } + + fmt.Fprint(w, ` +You can close this window and return to the terminal.
+`) + resultChan <- AuthCodeCallbackResult{Code: code, State: state} + close(doneChan) + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("auth code callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(10 * time.Minute): + case <-doneChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// generatePKCEForAuthCode generates PKCE code verifier and challenge for authorization code flow. +func generatePKCEForAuthCode() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + return verifier, challenge, nil +} + +// generateStateForAuthCode generates a random state parameter. +func generateStateForAuthCode() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// CreateTokenWithAuthCode exchanges authorization code for tokens. 
+func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "code": code, + "codeVerifier": codeVerifier, + "redirectUri": redirectURI, + "grantType": "authorization_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "code": code, + "codeVerifier": codeVerifier, + "redirectUri": redirectURI, + "grantType": "authorization_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + 
respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID. +// This provides a better UX than device code flow as it uses automatic browser callback. +func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Generate PKCE and state + codeVerifier, codeChallenge, err := generatePKCEForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + state, err := generateStateForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Step 2: Start callback server + fmt.Println("\nStarting callback server...") + redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("Callback server started, redirect URI: %s", redirectURI) + + // Step 3: Register client with auth code grant type + fmt.Println("Registering client...") + regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 4: Build authorization URL + scopes := 
"codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations"
+	// Build the URL via buildAuthorizationURL so every query parameter is
+	// percent-encoded (scopes contain ':' and ','); matches the IDC flow.
+	authURL := buildAuthorizationURL(ssoOIDCEndpoint,
+		regResp.ClientID,
+		redirectURI,
+		scopes,
+		state,
+		codeChallenge,
+	)
+
+	// Step 5: Open browser
+	fmt.Println("\n════════════════════════════════════════════════════════════")
+	fmt.Println("  Opening browser for authentication...")
+	fmt.Println("════════════════════════════════════════════════════════════")
+	fmt.Printf("\n  URL: %s\n\n", authURL)
+
+	// Set incognito mode
+	if c.cfg != nil {
+		browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
+	} else {
+		browser.SetIncognitoMode(true)
+	}
+
+	if err := browser.OpenURL(authURL); err != nil {
+		log.Warnf("Could not open browser automatically: %v", err)
+		fmt.Println("  ⚠ Could not open browser automatically.")
+		fmt.Println("  Please open the URL above in your browser manually.")
+	} else {
+		fmt.Println("  (Browser opened automatically)")
+	}
+
+	fmt.Println("\n  Waiting for authorization callback...")
+
+	// Step 6: Wait for callback
+	select {
+	case <-ctx.Done():
+		browser.CloseBrowser()
+		return nil, ctx.Err()
+	case <-time.After(10 * time.Minute):
+		browser.CloseBrowser()
+		return nil, fmt.Errorf("authorization timed out")
+	case result := <-resultChan:
+		if result.Error != "" {
+			browser.CloseBrowser()
+			return nil, fmt.Errorf("authorization failed: %s", result.Error)
+		}
+
+		fmt.Println("\n✓ Authorization received!")
+
+		// Close browser
+		if err := browser.CloseBrowser(); err != nil {
+			log.Debugf("Failed to close browser: %v", err)
+		}
+
+		// Step 7: Exchange code for tokens
+		fmt.Println("Exchanging code for tokens...")
+		tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI)
+		if err != nil {
+			return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
+		}
+
+		
fmt.Println("\n✓ Authentication successful!") + + // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: "", // Builder ID has no profile + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + Region: defaultIDCRegion, + }, nil + } +} + +func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + if region == "" { + region = defaultIDCRegion + } + + codeVerifier, codeChallenge, err := generatePKCEForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + state, err := generateStateForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + fmt.Println("\nStarting callback server...") + redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("Callback server started, redirect URI: %s", redirectURI) + + fmt.Println("Registering client...") + regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", 
regResp.ClientID) + + endpoint := getOIDCEndpoint(region) + scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist" + authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge) + + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Println(" Opening browser for authentication...") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + } else { + browser.SetIncognitoMode(true) + } + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authorization callback...") + + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(10 * time.Minute): + browser.CloseBrowser() + return nil, fmt.Errorf("authorization timed out") + case result := <-resultChan: + if result.Error != "" { + browser.CloseBrowser() + return nil, fmt.Errorf("authorization failed: %s", result.Error) + } + + fmt.Println("\n✓ Authorization received!") + + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + fmt.Println("Exchanging code for tokens...") + tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + fmt.Println("Fetching profile information...") + profileArn := c.FetchProfileArn(ctx, 
tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) + + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + StartURL: startURL, + Region: region, + }, nil + } +} + +func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string { + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", clientID) + params.Set("redirect_uri", redirectURI) + params.Set("scopes", scopes) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode()) +} diff --git a/internal/auth/kiro/sso_oidc_test.go b/internal/auth/kiro/sso_oidc_test.go new file mode 100644 index 0000000000..760a6033ad --- /dev/null +++ b/internal/auth/kiro/sso_oidc_test.go @@ -0,0 +1,261 @@ +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type recordingRoundTripper struct { + lastReq *http.Request +} + +func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.lastReq = req + body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil +} + +func 
TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) { + rt := &recordingRoundTripper{} + client := &SSOOIDCClient{ + httpClient: &http.Client{Transport: rt}, + } + + profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456") + if profileArn == "" { + t.Fatal("expected profileArn, got empty result") + } + + accountKey := GetAccountKey("client-id-123", "refresh-token-456") + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash) + got := rt.lastReq.Header.Get("X-Amz-User-Agent") + if got != expected { + t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected) + } +} + +func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) { + rt := &recordingRoundTripper{} + client := &SSOOIDCClient{ + httpClient: &http.Client{Transport: rt}, + } + + profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789") + if profileArn == "" { + t.Fatal("expected profileArn, got empty result") + } + + accountKey := GetAccountKey("", "refresh-token-789") + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash) + got := rt.lastReq.Header.Get("X-Amz-User-Agent") + if got != expected { + t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected) + } +} + +func TestRegisterClientForAuthCodeWithIDC(t *testing.T) { + var capturedReq struct { + Method string + Path string + Headers http.Header + Body map[string]interface{} + } + + mockResp := RegisterClientResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + ClientIDIssuedAt: 1700000000, + ClientSecretExpiresAt: 1700086400, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq.Method = 
r.Method + capturedReq.Path = r.URL.Path + capturedReq.Headers = r.Header.Clone() + + bodyBytes, _ := io.ReadAll(r.Body) + json.Unmarshal(bodyBytes, &capturedReq.Body) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResp) + })) + defer ts.Close() + + // Extract host to build a region that resolves to our test server. + // Override getOIDCEndpoint by passing region="" and patching the endpoint. + // Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we + // instead inject the test server URL directly via a custom HTTP client transport. + client := &SSOOIDCClient{ + httpClient: ts.Client(), + } + + // We need to route the request to our test server. Use a transport that rewrites the URL. + client.httpClient.Transport = &rewriteTransport{ + base: ts.Client().Transport, + targetURL: ts.URL, + } + + resp, err := client.RegisterClientForAuthCodeWithIDC( + context.Background(), + "http://127.0.0.1:19877/oauth/callback", + "https://my-idc-instance.awsapps.com/start", + "us-east-1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify request method and path + if capturedReq.Method != http.MethodPost { + t.Errorf("method = %q, want POST", capturedReq.Method) + } + if capturedReq.Path != "/client/register" { + t.Errorf("path = %q, want /client/register", capturedReq.Path) + } + + // Verify headers + if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + ua := capturedReq.Headers.Get("User-Agent") + if !strings.Contains(ua, "KiroIDE") { + t.Errorf("User-Agent %q does not contain KiroIDE", ua) + } + if !strings.Contains(ua, "sso-oidc") { + t.Errorf("User-Agent %q does not contain sso-oidc", ua) + } + xua := capturedReq.Headers.Get("X-Amz-User-Agent") + if !strings.Contains(xua, "KiroIDE") { + t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua) + } + + // Verify body fields + if v, _ := 
capturedReq.Body["clientName"].(string); v != "Kiro IDE" { + t.Errorf("clientName = %q, want %q", v, "Kiro IDE") + } + if v, _ := capturedReq.Body["clientType"].(string); v != "public" { + t.Errorf("clientType = %q, want %q", v, "public") + } + if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" { + t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start") + } + + // Verify scopes array + scopesRaw, ok := capturedReq.Body["scopes"].([]interface{}) + if !ok || len(scopesRaw) != 5 { + t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"]) + } + expectedScopes := []string{ + "codewhisperer:completions", "codewhisperer:analysis", + "codewhisperer:conversations", "codewhisperer:transformations", + "codewhisperer:taskassist", + } + for i, s := range expectedScopes { + if scopesRaw[i].(string) != s { + t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s) + } + } + + // Verify grantTypes + grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{}) + if !ok || len(grantTypesRaw) != 2 { + t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"]) + } + if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" { + t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw) + } + + // Verify redirectUris + redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{}) + if !ok || len(redirectRaw) != 1 { + t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"]) + } + if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" { + t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback") + } + + // Verify response parsing + if resp.ClientID != "test-client-id" { + t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id") + } + if resp.ClientSecret != "test-client-secret" { + 
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret") + } + if resp.ClientIDIssuedAt != 1700000000 { + t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000) + } + if resp.ClientSecretExpiresAt != 1700086400 { + t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400) + } +} + +// rewriteTransport redirects all requests to the test server URL. +type rewriteTransport struct { + base http.RoundTripper + targetURL string +} + +func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + target, _ := url.Parse(t.targetURL) + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + if t.base != nil { + return t.base.RoundTrip(req) + } + return http.DefaultTransport.RoundTrip(req) +} + +func TestBuildAuthorizationURL(t *testing.T) { + scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist" + endpoint := "https://oidc.us-east-1.amazonaws.com" + redirectURI := "http://127.0.0.1:19877/oauth/callback" + + authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge") + + // Verify colons and commas in scopes are percent-encoded + if !strings.Contains(authURL, "codewhisperer%3Acompletions") { + t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL) + } + if !strings.Contains(authURL, "completions%2Ccodewhisperer") { + t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL) + } + + // Parse back and verify all parameters round-trip correctly + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if !strings.HasPrefix(authURL, endpoint+"/authorize?") { + t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL) + } + + q := parsed.Query() + checks := map[string]string{ + "response_type": "code", + "client_id": 
"test-client-id", + "redirect_uri": redirectURI, + "scopes": scopes, + "state": "random-state", + "code_challenge": "test-challenge", + "code_challenge_method": "S256", + } + for key, want := range checks { + if got := q.Get(key); got != want { + t.Errorf("%s = %q, want %q", key, got, want) + } + } +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go new file mode 100644 index 0000000000..91a4995b33 --- /dev/null +++ b/internal/auth/kiro/token.go @@ -0,0 +1,89 @@ +package kiro + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// KiroTokenStorage holds the persistent token data for Kiro authentication. +type KiroTokenStorage struct { + // Type is the provider type for management UI recognition (must be "kiro") + Type string `json:"type"` + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"access_token"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refresh_token"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profile_arn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expires_at"` + // AuthMethod indicates the authentication method used + AuthMethod string `json:"auth_method"` + // Provider indicates the OAuth provider + Provider string `json:"provider"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` + // ClientID is the OAuth client ID (required for token refresh) + ClientID string `json:"client_id,omitempty"` + // ClientSecret is the OAuth client secret (required for token refresh) + ClientSecret string `json:"client_secret,omitempty"` + // Region is the OIDC region for IDC login and token refresh + Region string `json:"region,omitempty"` + // StartURL is the AWS Identity Center start URL (for IDC auth) + StartURL string `json:"start_url,omitempty"` + // Email is the user's email address + Email string `json:"email,omitempty"` +} + 
+// SaveTokenToFile persists the token storage to the specified file path. +func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { + dir := filepath.Dir(authFilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token storage: %w", err) + } + + if err := os.WriteFile(authFilePath, data, 0600); err != nil { + return fmt.Errorf("failed to write token file: %w", err) + } + + return nil +} + +// LoadFromFile loads token storage from the specified file path. +func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { + data, err := os.ReadFile(authFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var storage KiroTokenStorage + if err := json.Unmarshal(data, &storage); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &storage, nil +} + +// ToTokenData converts storage to KiroTokenData for API use. 
+func (s *KiroTokenStorage) ToTokenData() *KiroTokenData { + return &KiroTokenData{ + AccessToken: s.AccessToken, + RefreshToken: s.RefreshToken, + ProfileArn: s.ProfileArn, + ExpiresAt: s.ExpiresAt, + AuthMethod: s.AuthMethod, + Provider: s.Provider, + ClientID: s.ClientID, + ClientSecret: s.ClientSecret, + Region: s.Region, + StartURL: s.StartURL, + Email: s.Email, + } +} diff --git a/internal/auth/kiro/token_repository.go b/internal/auth/kiro/token_repository.go new file mode 100644 index 0000000000..3ddf620e8f --- /dev/null +++ b/internal/auth/kiro/token_repository.go @@ -0,0 +1,260 @@ +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储 +type FileTokenRepository struct { + mu sync.RWMutex + baseDir string +} + +// NewFileTokenRepository 创建一个新的文件 token 存储库 +func NewFileTokenRepository(baseDir string) *FileTokenRepository { + return &FileTokenRepository{ + baseDir: baseDir, + } +} + +// SetBaseDir 设置基础目录 +func (r *FileTokenRepository) SetBaseDir(dir string) { + r.mu.Lock() + r.baseDir = strings.TrimSpace(dir) + r.mu.Unlock() +} + +// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序) +func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token { + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + log.Debug("token repository: base directory not configured") + return nil + } + + var tokens []*Token + + err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil // 忽略错误,继续遍历 + } + if d.IsDir() { + return nil + } + if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") { + return nil + } + + // 只处理 kiro 相关的 token 文件 + if !strings.HasPrefix(d.Name(), "kiro-") { + return nil + } + + token, err := r.readTokenFile(path) + if err != nil { + log.Debugf("token repository: failed to read token 
file %s: %v", path, err) + return nil + } + + if token != nil && token.RefreshToken != "" { + // 检查 token 是否需要刷新(过期前 5 分钟) + if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute { + tokens = append(tokens, token) + } + } + + return nil + }) + + if err != nil { + log.Warnf("token repository: error walking directory: %v", err) + } + + // 按最后验证时间排序(最旧的优先) + sort.Slice(tokens, func(i, j int) bool { + return tokens[i].LastVerified.Before(tokens[j].LastVerified) + }) + + // 限制返回数量 + if limit > 0 && len(tokens) > limit { + tokens = tokens[:limit] + } + + return tokens +} + +// UpdateToken 更新 token 并持久化到文件 +func (r *FileTokenRepository) UpdateToken(token *Token) error { + if token == nil { + return fmt.Errorf("token repository: token is nil") + } + + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + return fmt.Errorf("token repository: base directory not configured") + } + + // 构建文件路径 + filePath := filepath.Join(baseDir, token.ID) + if !strings.HasSuffix(filePath, ".json") { + filePath += ".json" + } + + // 读取现有文件内容 + existingData := make(map[string]any) + if data, err := os.ReadFile(filePath); err == nil { + _ = json.Unmarshal(data, &existingData) + } + + // 更新字段 + existingData["access_token"] = token.AccessToken + existingData["refresh_token"] = token.RefreshToken + existingData["last_refresh"] = time.Now().Format(time.RFC3339) + + if !token.ExpiresAt.IsZero() { + existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339) + } + + // 保持原有的关键字段 + if token.ClientID != "" { + existingData["client_id"] = token.ClientID + } + if token.ClientSecret != "" { + existingData["client_secret"] = token.ClientSecret + } + if token.AuthMethod != "" { + existingData["auth_method"] = token.AuthMethod + } + if token.Region != "" { + existingData["region"] = token.Region + } + if token.StartURL != "" { + existingData["start_url"] = token.StartURL + } + + // 序列化并写入文件 + raw, err := json.MarshalIndent(existingData, "", " ") + if err != 
nil { + return fmt.Errorf("token repository: marshal failed: %w", err) + } + + // 原子写入:先写入临时文件,再重命名 + tmpPath := filePath + ".tmp" + if err := os.WriteFile(tmpPath, raw, 0o600); err != nil { + return fmt.Errorf("token repository: write temp file failed: %w", err) + } + if err := os.Rename(tmpPath, filePath); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("token repository: rename failed: %w", err) + } + + log.Debugf("token repository: updated token %s", token.ID) + return nil +} + +// readTokenFile 从文件读取 token +func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var metadata map[string]any + if err := json.Unmarshal(data, &metadata); err != nil { + return nil, err + } + + // 检查是否是 kiro token + tokenType, _ := metadata["type"].(string) + if tokenType != "kiro" { + return nil, nil + } + + // 检查 auth_method (case-insensitive comparison to handle "IdC", "IDC", "idc", etc.) + authMethod, _ := metadata["auth_method"].(string) + authMethod = strings.ToLower(authMethod) + if authMethod != "idc" && authMethod != "builder-id" { + return nil, nil // 只处理 IDC 和 Builder ID token + } + + token := &Token{ + ID: filepath.Base(path), + AuthMethod: authMethod, + } + + // 解析各字段 + token.AccessToken, _ = metadata["access_token"].(string) + token.RefreshToken, _ = metadata["refresh_token"].(string) + token.ClientID, _ = metadata["client_id"].(string) + token.ClientSecret, _ = metadata["client_secret"].(string) + token.Region, _ = metadata["region"].(string) + token.StartURL, _ = metadata["start_url"].(string) + token.Provider, _ = metadata["provider"].(string) + + // 解析时间字段 + if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { + if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil { + token.ExpiresAt = t + } + } + if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" { + if t, err := 
time.Parse(time.RFC3339, lastRefreshStr); err == nil { + token.LastVerified = t + } + } + + return token, nil +} + +// ListKiroTokens 列出所有 Kiro token(用于调试) +func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) { + r.mu.RLock() + baseDir := r.baseDir + r.mu.RUnlock() + + if baseDir == "" { + return nil, fmt.Errorf("token repository: base directory not configured") + } + + var tokens []*Token + + err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error { + if walkErr != nil { + return nil + } + if d.IsDir() { + return nil + } + if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") { + return nil + } + + token, err := r.readTokenFile(path) + if err != nil { + return nil + } + if token != nil { + tokens = append(tokens, token) + } + return nil + }) + + return tokens, err +} diff --git a/internal/auth/kiro/usage_checker.go b/internal/auth/kiro/usage_checker.go new file mode 100644 index 0000000000..30ec1c5dc1 --- /dev/null +++ b/internal/auth/kiro/usage_checker.go @@ -0,0 +1,236 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// This file implements usage quota checking and monitoring. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" +) + +// UsageQuotaResponse represents the API response structure for usage quota checking. +type UsageQuotaResponse struct { + UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"` + SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"` + NextDateReset float64 `json:"nextDateReset,omitempty"` +} + +// UsageBreakdownExtended represents detailed usage information for quota checking. 
+// Note: UsageBreakdown is already defined in codewhisperer_client.go +type UsageBreakdownExtended struct { + ResourceType string `json:"resourceType"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` + FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"` +} + +// FreeTrialInfoExtended represents free trial usage information. +type FreeTrialInfoExtended struct { + FreeTrialStatus string `json:"freeTrialStatus"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` +} + +// QuotaStatus represents the quota status for a token. +type QuotaStatus struct { + TotalLimit float64 + CurrentUsage float64 + RemainingQuota float64 + IsExhausted bool + ResourceType string + NextReset time.Time +} + +// UsageChecker provides methods for checking token quota usage. +type UsageChecker struct { + httpClient *http.Client +} + +// NewUsageChecker creates a new UsageChecker instance. +func NewUsageChecker(cfg *config.Config) *UsageChecker { + return &UsageChecker{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + } +} + +// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client. +func NewUsageCheckerWithClient(client *http.Client) *UsageChecker { + return &UsageChecker{ + httpClient: client, + } +} + +// CheckUsage retrieves usage limits for the given token. 
+func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) { + if tokenData == nil { + return nil, fmt.Errorf("token data is nil") + } + + if tokenData.AccessToken == "" { + return nil, fmt.Errorf("access token is empty") + } + + queryParams := map[string]string{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + "resourceType": "AGENTIC_REQUEST", + } + + // Use endpoint from profileArn if available + endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn) + url := buildURL(endpoint, pathGetUsageLimits, queryParams) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken) + setRuntimeHeaders(req, tokenData.AccessToken, accountKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result UsageQuotaResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse usage response: %w", err) + } + + return &result, nil +} + +// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly. +func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) { + tokenData := &KiroTokenData{ + AccessToken: accessToken, + ProfileArn: profileArn, + } + return c.CheckUsage(ctx, tokenData) +} + +// GetRemainingQuota calculates the remaining quota from usage limits. 
+func GetRemainingQuota(usage *UsageQuotaResponse) float64 { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return 0 + } + + var totalRemaining float64 + for _, breakdown := range usage.UsageBreakdownList { + remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision + if remaining > 0 { + totalRemaining += remaining + } + + if breakdown.FreeTrialInfo != nil { + freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision + if freeRemaining > 0 { + totalRemaining += freeRemaining + } + } + } + + return totalRemaining +} + +// IsQuotaExhausted checks if the quota is exhausted based on usage limits. +func IsQuotaExhausted(usage *UsageQuotaResponse) bool { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return true + } + + for _, breakdown := range usage.UsageBreakdownList { + if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision { + return false + } + + if breakdown.FreeTrialInfo != nil { + if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision { + return false + } + } + } + + return true +} + +// GetQuotaStatus retrieves a comprehensive quota status for a token. 
+func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) { + usage, err := c.CheckUsage(ctx, tokenData) + if err != nil { + return nil, err + } + + status := &QuotaStatus{ + IsExhausted: IsQuotaExhausted(usage), + } + + if len(usage.UsageBreakdownList) > 0 { + breakdown := usage.UsageBreakdownList[0] + status.TotalLimit = breakdown.UsageLimitWithPrecision + status.CurrentUsage = breakdown.CurrentUsageWithPrecision + status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision + status.ResourceType = breakdown.ResourceType + + if breakdown.FreeTrialInfo != nil { + status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision + status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision + freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision + if freeRemaining > 0 { + status.RemainingQuota += freeRemaining + } + } + } + + if usage.NextDateReset > 0 { + status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0) + } + + return status, nil +} + +// CalculateAvailableCount calculates the available request count based on usage limits. +func CalculateAvailableCount(usage *UsageQuotaResponse) float64 { + return GetRemainingQuota(usage) +} + +// GetUsagePercentage calculates the usage percentage. 
+func GetUsagePercentage(usage *UsageQuotaResponse) float64 { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return 100.0 + } + + var totalLimit, totalUsage float64 + for _, breakdown := range usage.UsageBreakdownList { + totalLimit += breakdown.UsageLimitWithPrecision + totalUsage += breakdown.CurrentUsageWithPrecision + + if breakdown.FreeTrialInfo != nil { + totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision + totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision + } + } + + if totalLimit == 0 { + return 100.0 + } + + return (totalUsage / totalLimit) * 100 +} diff --git a/internal/auth/qoder/auth.go b/internal/auth/qoder/auth.go new file mode 100644 index 0000000000..cc2f800770 --- /dev/null +++ b/internal/auth/qoder/auth.go @@ -0,0 +1,197 @@ +// Package qoder provides OAuth2 authentication functionality for the Qoder provider. +package qoder + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" +) + +// UserStatusResponse represents the response from the user status endpoint. +type UserStatusResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Email string `json:"email"` +} + +// QoderAuth handles Qoder PKCE + URI-scheme authentication. +type QoderAuth struct { + httpClient *http.Client +} + +// NewQoderAuth creates a new Qoder auth service. +func NewQoderAuth(httpClient *http.Client) *QoderAuth { + if httpClient == nil { + httpClient = &http.Client{} + } + return &QoderAuth{httpClient: httpClient} +} + +// GeneratePKCE generates a PKCE verifier/challenge pair and a nonce. 
+func GeneratePKCE() (nonce, challenge, verifier string, err error) { + // Generate 32-byte random verifier + verifierBytes := make([]byte, 32) + if _, err = rand.Read(verifierBytes); err != nil { + return "", "", "", fmt.Errorf("qoder: generate verifier: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(verifierBytes) + + // SHA-256 challenge + challengeHash := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(challengeHash[:]) + + // Nonce + nonceBytes := make([]byte, 16) + if _, err = rand.Read(nonceBytes); err != nil { + return "", "", "", fmt.Errorf("qoder: generate nonce: %w", err) + } + nonce = fmt.Sprintf("%x", nonceBytes) + + return nonce, challenge, verifier, nil +} + +// BuildAuthURL constructs the Qoder login URL for browser-based authentication. +func BuildAuthURL(nonce, challenge, machineID string) string { + params := url.Values{} + params.Set("nonce", nonce) + params.Set("challenge", challenge) + params.Set("challenge_method", "S256") + params.Set("redirect_uri", RedirectURI) + params.Set("machine_id", machineID) + return AuthBase + SelectAccountsPath + "?" + params.Encode() +} + +// FetchUserStatus retrieves user info using a device token. 
+func (o *QoderAuth) FetchUserStatus(deviceToken string) (*UserStatusResponse, error) { + deviceToken = strings.TrimSpace(deviceToken) + if deviceToken == "" { + return nil, fmt.Errorf("qoder user status: missing device token") + } + reqURL := OpenAPIBase + UserStatusPath + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("qoder user status: create request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+deviceToken) + req.Header.Set("Cosy-Version", IDEVersion) + req.Header.Set("Cosy-Clienttype", "0") + + resp, errDo := o.httpClient.Do(req) + if errDo != nil { + return nil, fmt.Errorf("qoder user status: execute request: %w", errDo) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("qoder user status: close body error: %v", errClose) + } + }() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10)) + if errRead != nil { + return nil, fmt.Errorf("qoder user status: read response: %w", errRead) + } + body := strings.TrimSpace(string(bodyBytes)) + if body == "" { + return nil, fmt.Errorf("qoder user status: request failed: status %d", resp.StatusCode) + } + return nil, fmt.Errorf("qoder user status: request failed: status %d: %s", resp.StatusCode, body) + } + + var user UserStatusResponse + if errDecode := json.NewDecoder(resp.Body).Decode(&user); errDecode != nil { + return nil, fmt.Errorf("qoder user status: decode response: %w", errDecode) + } + return &user, nil +} + +// DecodeAuthField decodes the obfuscated auth callback field to extract user info. 
+func DecodeAuthField(authStr string) (map[string]any, error) { + if strings.TrimSpace(authStr) == "" { + return nil, fmt.Errorf("qoder: empty auth field") + } + + // Reverse custom alphabet to standard base64 + var b64 strings.Builder + for _, c := range authStr { + ch := string(c) + if ch == CustomPad { + b64.WriteByte('=') + } else { + idx := strings.Index(CustomAlphabet, ch) + if idx >= 0 { + b64.WriteByte(StdAlphabet[idx]) + } else { + b64.WriteString(ch) + } + } + } + + decoded := b64.String() + + // Find the base64-encoded JSON payload starting with "eyJ" + eqPos := strings.Index(decoded, "=") + var head, tail string + if eqPos < 0 { + head = decoded + tail = "" + } else { + tail = decoded[:eqPos] + head = decoded[eqPos+1:] + } + + eyjPos := strings.Index(head, "eyJ") + var reconstructed string + if eyjPos < 0 { + eyjFull := strings.Index(decoded, "eyJ") + if eyjFull < 0 { + return nil, fmt.Errorf("qoder: no JWT payload found in auth field") + } + reconstructed = decoded[eyjFull:] + } else { + reconstructed = head[eyjPos:] + head[:eyjPos] + tail + "=" + } + + // Try decoding with different padding + for _, pad := range []string{"", "=", "==", "==="} { + raw, errDec := base64.StdEncoding.DecodeString(reconstructed + pad) + if errDec != nil { + raw, errDec = base64.RawStdEncoding.DecodeString(reconstructed + pad) + if errDec != nil { + continue + } + } + var result map[string]any + if errJSON := json.Unmarshal(raw, &result); errJSON != nil { + continue + } + return result, nil + } + + return nil, fmt.Errorf("qoder: failed to decode auth field") +} + +// GenerateMachineID creates a deterministic machine identifier. 
+func GenerateMachineID(hostname, macAddr, system, machine string) string { + raw := fmt.Sprintf("%s-%s-%s-%s", hostname, macAddr, system, machine) + digest := sha256.Sum256([]byte(raw)) + encoded := base64.RawURLEncoding.EncodeToString(digest[:]) + var parts []string + for i := 0; i < len(encoded); i += 22 { + end := i + 22 + if end > len(encoded) { + end = len(encoded) + } + parts = append(parts, encoded[i:end]) + } + return strings.Join(parts, "-") +} diff --git a/internal/auth/qoder/constants.go b/internal/auth/qoder/constants.go new file mode 100644 index 0000000000..65d1a34bc4 --- /dev/null +++ b/internal/auth/qoder/constants.go @@ -0,0 +1,41 @@ +// Package qoder provides OAuth2 authentication functionality for the Qoder provider. +package qoder + +// Qoder login configuration +const ( + CallbackPort = 51122 + AuthBase = "https://qoder.com" + CenterBase = "https://center.qoder.sh" + ChatBase = "https://api3.qoder.sh" + OpenAPIBase = "https://openapi.qoder.sh" + IDEVersion = "0.14.2" + CosyVersion = "1.0.0" + RedirectURI = "qoder://aicoding.aicoding-agent/login-success" +) + +// SelectAccountsPath is the browser login page path. +const SelectAccountsPath = "/device/selectAccounts" + +// ServerPublicKeyPEM is the RSA public key for COSY authentication. +const ServerPublicKeyPEM = `-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDA8iMH5c02LilrsERw9t6Pv5Nc +4k6Pz1EaDicBMpdpxKduSZu5OANqUq8er4GM95omAGIOPOh+Nx0spthYA2BqGz+l +6HRkPJ7S236FZz73In/KVuLnwI8JJ2CbuJap8kvheCCZpmAWpb/cPx/3Vr/J6I17 +XcW+ML9FoCI6AOvOzwIDAQAB +-----END PUBLIC KEY-----` + +// Custom base64 encoding alphabet used by Qoder body encoding. +const ( + CustomPad = "$" + StdAlphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + CustomAlphabet = "_doRTgHZBKcGVjlvpC,@aFSx#DPuNJme&i*MzLOEn)sUrthbf%Y^w.(kIQyXqWA!" 
+) + +// Chat endpoint path +const ( + ChatPath = "/algo/api/v2/service/pro/sse/agent_chat_generation" + ChatQueryExtra = "FetchKeys=llm_model_result&AgentId=agent_common" + ModelListPath = "/algo/api/v2/model/list" + UserPlanPath = "/algo/api/v2/user/plan" + UserStatusPath = "/api/v3/user/status" +) diff --git a/internal/auth/qoder/filename.go b/internal/auth/qoder/filename.go new file mode 100644 index 0000000000..1fab171a91 --- /dev/null +++ b/internal/auth/qoder/filename.go @@ -0,0 +1,16 @@ +package qoder + +import ( + "fmt" + "strings" +) + +// CredentialFileName returns the filename used to persist Qoder credentials. +// It uses the uid as a suffix to disambiguate accounts. +func CredentialFileName(uid string) string { + uid = strings.TrimSpace(uid) + if uid == "" { + return "qoder.json" + } + return fmt.Sprintf("qoder-%s.json", uid) +} diff --git a/internal/auth/qoder/uri_handler_other.go b/internal/auth/qoder/uri_handler_other.go new file mode 100644 index 0000000000..e5fff6b05e --- /dev/null +++ b/internal/auth/qoder/uri_handler_other.go @@ -0,0 +1,13 @@ +//go:build !windows + +package qoder + +// RegisterURIHandler is a no-op on non-Windows platforms. +// On Linux/macOS, the qoder:// protocol would need xdg-open or other platform-specific handling. +// For now, users on non-Windows platforms should paste the callback URL manually. +func RegisterURIHandler(callbackPort int) func() { + return func() {} +} + +// UnregisterURIHandler is a no-op on non-Windows platforms. 
+func UnregisterURIHandler() {} diff --git a/internal/auth/qoder/uri_handler_windows.go b/internal/auth/qoder/uri_handler_windows.go new file mode 100644 index 0000000000..79fb4aad77 --- /dev/null +++ b/internal/auth/qoder/uri_handler_windows.go @@ -0,0 +1,89 @@ +//go:build windows + +package qoder + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + + log "github.com/sirupsen/logrus" +) + +const vbsFileName = "qoder_login_handler.vbs" + +// RegisterURIHandler registers the qoder:// URI protocol handler on Windows. +// It creates a VBS script that forwards the qoder:// callback URL to the local +// HTTP callback server, then registers the protocol in the Windows registry. +// Returns a cleanup function that should be deferred. +func RegisterURIHandler(callbackPort int) func() { + vbsPath := filepath.Join(os.TempDir(), vbsFileName) + + vbsContent := fmt.Sprintf(`Set objHTTP = CreateObject("MSXML2.XMLHTTP") +On Error Resume Next +url = "http://127.0.0.1:%d/forward?url=" +url = url & UrlEncode(WScript.Arguments(0)) +objHTTP.Open "GET", url, False +objHTTP.send + +Function UrlEncode(str) + Dim result, i, c + result = "" + For i = 1 To Len(str) + c = Mid(str, i, 1) + If c = " " Then + result = result & "+" + ElseIf InStr("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.~", c) > 0 Then + result = result & c + Else + result = result & "%%" & Right("0" & Hex(Asc(c)), 2) + End If + Next + UrlEncode = result +End Function +`, callbackPort) + + if err := os.WriteFile(vbsPath, []byte(vbsContent), 0o644); err != nil { + log.Errorf("qoder: failed to write VBS handler script: %v", err) + return func() {} + } + + regCmds := [][]string{ + {"reg", "add", `HKCU\Software\Classes\qoder`, "/ve", "/t", "REG_SZ", "/d", "URL:QoderLogin", "/f"}, + {"reg", "add", `HKCU\Software\Classes\qoder`, "/v", "URL Protocol", "/t", "REG_SZ", "/d", "", "/f"}, + {"reg", "add", `HKCU\Software\Classes\qoder\shell`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\qoder\shell\open`, 
"/f"}, + {"reg", "add", `HKCU\Software\Classes\qoder\shell\open\command`, + "/ve", "/t", "REG_SZ", "/d", fmt.Sprintf(`wscript.exe "%s" %%1`, vbsPath), "/f"}, + } + + for _, args := range regCmds { + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdout = nil + cmd.Stderr = nil + _ = cmd.Run() + } + + log.Infof("qoder: registered qoder:// URI handler (VBS: %s)", vbsPath) + + return func() { + UnregisterURIHandler() + } +} + +// UnregisterURIHandler removes the qoder:// URI protocol handler from Windows registry +// and cleans up the temporary VBS script. +func UnregisterURIHandler() { + cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\qoder`, "/f") + cmd.Stdout = nil + cmd.Stderr = nil + _ = cmd.Run() + + vbsPath := filepath.Join(os.TempDir(), vbsFileName) + if _, err := os.Stat(vbsPath); err == nil { + _ = os.Remove(vbsPath) + } + + log.Info("qoder: unregistered qoder:// URI handler") +} diff --git a/internal/auth/vertex/vertex_credentials.go b/internal/auth/vertex/vertex_credentials.go index 9f830994ed..db214bd6e2 100644 --- a/internal/auth/vertex/vertex_credentials.go +++ b/internal/auth/vertex/vertex_credentials.go @@ -8,7 +8,7 @@ import ( "os" "path/filepath" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" log "github.com/sirupsen/logrus" ) diff --git a/internal/browser/browser.go b/internal/browser/browser.go index b24dc5e112..d56f877537 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -144,3 +144,29 @@ func GetPlatformInfo() map[string]interface{} { return info } + +// incognitoMode controls whether to open URLs in incognito/private mode. +var incognitoMode bool + +// lastBrowserCmd stores the last opened browser process for cleanup. +var lastBrowserCmd *exec.Cmd + +// SetIncognitoMode enables or disables incognito/private browsing mode. 
+func SetIncognitoMode(enabled bool) { + incognitoMode = enabled +} + +// IsIncognitoMode returns whether incognito mode is enabled. +func IsIncognitoMode() bool { + return incognitoMode +} + +// CloseBrowser closes the last opened browser process. +func CloseBrowser() error { + if lastBrowserCmd == nil || lastBrowserCmd.Process == nil { + return nil + } + err := lastBrowserCmd.Process.Kill() + lastBrowserCmd = nil + return err +} diff --git a/internal/cmd/anthropic_login.go b/internal/cmd/anthropic_login.go index f7381461a6..cc1bfc8e7c 100644 --- a/internal/cmd/anthropic_login.go +++ b/internal/cmd/anthropic_login.go @@ -6,9 +6,9 @@ import ( "fmt" "os" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/claude" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" log "github.com/sirupsen/logrus" ) diff --git a/internal/cmd/antigravity_login.go b/internal/cmd/antigravity_login.go index 2efbaeee01..f2bd5505a2 100644 --- a/internal/cmd/antigravity_login.go +++ b/internal/cmd/antigravity_login.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" log "github.com/sirupsen/logrus" ) diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index 2654717901..7896a7023a 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -1,7 +1,7 @@ package cmd import ( - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" ) // newAuthManager creates a new authentication manager instance with all 
supported diff --git a/internal/cmd/bt_login.go b/internal/cmd/bt_login.go new file mode 100644 index 0000000000..54cb17acff --- /dev/null +++ b/internal/cmd/bt_login.go @@ -0,0 +1,80 @@ +package cmd + +import ( + "context" + "fmt" + "os" + + btauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/bt" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +func DoBTLogin(cfg *config.Config) { + if cfg == nil || len(cfg.BTKey) == 0 { + log.Error("No BT credentials configured. Add bt entries to config.yaml.") + return + } + + store := sdkAuth.GetTokenStore() + successCount := 0 + + for i, entry := range cfg.BTKey { + phone := entry.Phone + password := entry.Password + if phone == "" || password == "" { + log.Warnf("bt[%d]: missing phone or password, skipping", i) + continue + } + + fmt.Printf("Logging in BT account: %s ...\n", phone) + + token, err := btauth.Login(phone, password) + if err != nil { + log.Errorf("bt[%d]: login failed for phone %s: %v", i, phone, err) + continue + } + + authID := fmt.Sprintf("bt-%s.json", phone) + auth := &cliproxyauth.Auth{ + ID: authID, + FileName: authID, + Provider: "bt", + Label: phone, + Storage: token, + Metadata: map[string]any{ + "type": "bt", + "bt_phone": phone, + "bt_password": password, + "uid": token.UID, + "access_key": token.AccessKey, + "serverid": token.ServerID, + }, + } + + if cfg != nil { + if dirSetter, ok := store.(interface{ SetBaseDir(string) }); ok { + dirSetter.SetBaseDir(cfg.AuthDir) + } + } + + savedPath, err := store.Save(context.Background(), auth) + if err != nil { + log.Errorf("bt[%d]: failed to save credentials: %v", i, err) + continue + } + + fmt.Printf("BT authentication saved to %s\n", savedPath) + fmt.Printf("Authenticated as %s (uid: %s)\n", phone, token.UID) + successCount++ + } + + if successCount > 0 { + 
fmt.Printf("\nBT authentication successful! (%d account(s))\n", successCount) + } else { + fmt.Println("\nNo BT accounts were authenticated successfully.") + os.Exit(1) + } +} diff --git a/internal/cmd/codearts_login.go b/internal/cmd/codearts_login.go new file mode 100644 index 0000000000..be5f73f525 --- /dev/null +++ b/internal/cmd/codearts_login.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +func DoCodeArtsLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + } + + record, savedPath, err := manager.Login(context.Background(), "codearts", cfg, authOpts) + if err != nil { + log.Errorf("CodeArts authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("CodeArts authentication successful!") +} diff --git a/internal/cmd/codebuddy_ai_login.go b/internal/cmd/codebuddy_ai_login.go new file mode 100644 index 0000000000..33a4056be0 --- /dev/null +++ b/internal/cmd/codebuddy_ai_login.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +func DoCodeBuddyAILogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + } + + record, savedPath, err 
:= manager.Login(context.Background(), "codebuddy-ai", cfg, authOpts) + if err != nil { + log.Errorf("CodeBuddy AI authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("CodeBuddy AI authentication successful!") +} diff --git a/internal/cmd/codebuddy_login.go b/internal/cmd/codebuddy_login.go new file mode 100644 index 0000000000..dd718851c6 --- /dev/null +++ b/internal/cmd/codebuddy_login.go @@ -0,0 +1,43 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoCodeBuddyLogin triggers the browser OAuth polling flow for CodeBuddy and saves tokens. +// It initiates the OAuth authentication, displays the user code for the user to enter +// at the CodeBuddy verification URL, and waits for authorization before saving the tokens. 
+// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoCodeBuddyLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + } + + record, savedPath, err := manager.Login(context.Background(), "codebuddy", cfg, authOpts) + if err != nil { + log.Errorf("CodeBuddy authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("CodeBuddy authentication successful!") +} diff --git a/internal/cmd/cursor_login.go b/internal/cmd/cursor_login.go new file mode 100644 index 0000000000..7317044767 --- /dev/null +++ b/internal/cmd/cursor_login.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "context" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoCursorLogin triggers the OAuth PKCE flow for Cursor and saves tokens. 
+func DoCursorLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "cursor", cfg, authOpts) + if err != nil { + log.Errorf("Cursor authentication failed: %v", err) + return + } + + if savedPath != "" { + log.Infof("Authentication saved to %s", savedPath) + } + if record != nil && record.Label != "" { + log.Infof("Authenticated as %s", record.Label) + } + log.Info("Cursor authentication successful!") +} diff --git a/internal/cmd/github_copilot_login.go b/internal/cmd/github_copilot_login.go new file mode 100644 index 0000000000..f605e87d7e --- /dev/null +++ b/internal/cmd/github_copilot_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens. +// It initiates the device flow authentication, displays the user code for the user to enter +// at GitHub's verification URL, and waits for authorization before saving the tokens. 
+// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts) + if err != nil { + log.Errorf("GitHub Copilot authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("GitHub Copilot authentication successful!") +} diff --git a/internal/cmd/gitlab_login.go b/internal/cmd/gitlab_login.go new file mode 100644 index 0000000000..df8e56f54a --- /dev/null +++ b/internal/cmd/gitlab_login.go @@ -0,0 +1,69 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" +) + +func DoGitLabLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{ + "login_mode": "oauth", + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts) + if err != nil { + fmt.Printf("GitLab Duo authentication failed: %v\n", err) + return + } + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("GitLab Duo authentication 
successful!") +} + +func DoGitLabTokenLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + Metadata: map[string]string{ + "login_mode": "pat", + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "gitlab", cfg, authOpts) + if err != nil { + fmt.Printf("GitLab Duo PAT authentication failed: %v\n", err) + return + } + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("GitLab Duo PAT authentication successful!") +} diff --git a/internal/cmd/iflow_cookie.go b/internal/cmd/iflow_cookie.go new file mode 100644 index 0000000000..61408deac8 --- /dev/null +++ b/internal/cmd/iflow_cookie.go @@ -0,0 +1,98 @@ +package cmd + +import ( + "bufio" + "context" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/iflow" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" +) + +// DoIFlowCookieAuth performs the iFlow cookie-based authentication. 
+func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + reader := bufio.NewReader(os.Stdin) + promptFn = func(prompt string) (string, error) { + fmt.Print(prompt) + value, err := reader.ReadString('\n') + if err != nil { + return "", err + } + return strings.TrimSpace(value), nil + } + } + + // Prompt user for cookie + cookie, err := promptForCookie(promptFn) + if err != nil { + fmt.Printf("Failed to get cookie: %v\n", err) + return + } + + // Check for duplicate BXAuth before authentication + bxAuth := iflow.ExtractBXAuth(cookie) + if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil { + fmt.Printf("Failed to check duplicate: %v\n", err) + return + } else if existingFile != "" { + fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile)) + return + } + + // Authenticate with cookie + auth := iflow.NewIFlowAuth(cfg) + ctx := context.Background() + + tokenData, err := auth.AuthenticateWithCookie(ctx, cookie) + if err != nil { + fmt.Printf("iFlow cookie authentication failed: %v\n", err) + return + } + + // Create token storage + tokenStorage := auth.CreateCookieTokenStorage(tokenData) + + // Get auth file path using email in filename + authFilePath := getAuthFilePath(cfg, "iflow", tokenData.Email) + + // Save token to file + if err := tokenStorage.SaveTokenToFile(authFilePath); err != nil { + fmt.Printf("Failed to save authentication: %v\n", err) + return + } + + fmt.Printf("Authentication successful! 
API key: %s\n", tokenData.APIKey) + fmt.Printf("Expires at: %s\n", tokenData.Expire) + fmt.Printf("Authentication saved to: %s\n", authFilePath) +} + +// promptForCookie prompts the user to enter their iFlow cookie +func promptForCookie(promptFn func(string) (string, error)) (string, error) { + line, err := promptFn("Enter iFlow Cookie (from browser cookies): ") + if err != nil { + return "", fmt.Errorf("failed to read cookie: %w", err) + } + + cookie, err := iflow.NormalizeCookie(line) + if err != nil { + return "", err + } + + return cookie, nil +} + +// getAuthFilePath returns the auth file path for the given provider and email +func getAuthFilePath(cfg *config.Config, provider, email string) string { + fileName := iflow.SanitizeIFlowFileName(email) + return fmt.Sprintf("%s/%s-%s-%d.json", cfg.AuthDir, provider, fileName, time.Now().Unix()) +} diff --git a/internal/cmd/iflow_login.go b/internal/cmd/iflow_login.go new file mode 100644 index 0000000000..9015f5be67 --- /dev/null +++ b/internal/cmd/iflow_login.go @@ -0,0 +1,48 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoIFlowLogin performs the iFlow OAuth login via the shared authentication manager. 
+func DoIFlowLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts) + if err != nil { + if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok { + log.Error(emailErr.Error()) + return + } + fmt.Printf("iFlow authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + + fmt.Println("iFlow authentication successful!") +} diff --git a/internal/cmd/joycode_login.go b/internal/cmd/joycode_login.go new file mode 100644 index 0000000000..f1c7a6e9eb --- /dev/null +++ b/internal/cmd/joycode_login.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +func DoJoyCodeLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + } + + record, savedPath, err := manager.Login(context.Background(), "joycode", cfg, authOpts) + if err != nil { + log.Errorf("JoyCode authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("JoyCode authentication successful!") +} diff --git a/internal/cmd/kilo_login.go 
b/internal/cmd/kilo_login.go new file mode 100644 index 0000000000..4f9b37f25b --- /dev/null +++ b/internal/cmd/kilo_login.go @@ -0,0 +1,54 @@ +package cmd + +import ( + "context" + "fmt" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" +) + +// DoKiloLogin handles the Kilo device flow using the shared authentication manager. +// It initiates the device-based authentication process for Kilo AI services and saves +// the authentication tokens to the configured auth directory. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including browser behavior and prompts +func DoKiloLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + promptFn := options.Prompt + if promptFn == nil { + promptFn = func(prompt string) (string, error) { + fmt.Print(prompt) + var value string + fmt.Scanln(&value) + return strings.TrimSpace(value), nil + } + } + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts) + if err != nil { + fmt.Printf("Kilo authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + + fmt.Println("Kilo authentication successful!") +} diff --git a/internal/cmd/kimi_login.go b/internal/cmd/kimi_login.go index eb5f11fb37..ffc470fda0 100644 --- a/internal/cmd/kimi_login.go +++ b/internal/cmd/kimi_login.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" log 
"github.com/sirupsen/logrus" ) diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go new file mode 100644 index 0000000000..b2fd6f59e5 --- /dev/null +++ b/internal/cmd/kiro_login.go @@ -0,0 +1,257 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoKiroLogin triggers the Kiro authentication flow with Google OAuth. +// This is the default login method (same as --kiro-google-login). +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including Prompt field +func DoKiroLogin(cfg *config.Config, options *LoginOptions) { + // Use Google login as default + DoKiroGoogleLogin(cfg, options) +} + +// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth. +// This uses a custom protocol handler (kiro://) to receive the callback. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with Google login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro Google authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure the protocol handler is installed") + fmt.Println("2. Complete the Google login in the browser") + fmt.Println("3. 
If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro Google authentication successful!") +} + +// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID. +// This uses the device code flow for AWS SSO OIDC authentication. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with AWS Builder ID login (device code flow) + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro AWS authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure you have an AWS Builder ID") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. 
If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro AWS authentication successful!") +} + +// DoKiroAWSAuthCodeLogin triggers Kiro authentication with AWS Builder ID using authorization code flow. +// This provides a better UX than device code flow as it uses automatic browser callback. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroAWSAuthCodeLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with AWS Builder ID login (authorization code flow) + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.LoginWithAuthCode(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro AWS authentication (auth code) failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure you have an AWS Builder ID") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. 
If callback fails, try: --kiro-aws-login (device code flow)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro AWS authentication successful!") +} + +// DoKiroImport imports Kiro token from Kiro IDE's token file. +// This is useful for users who have already logged in via Kiro IDE +// and want to use the same credentials in CLI Proxy API. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options (currently unused for import) +func DoKiroImport(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + // Use ImportFromKiroIDE instead of Login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg) + if err != nil { + log.Errorf("Kiro token import failed: %v", err) + fmt.Println("\nMake sure you have logged in to Kiro IDE first:") + fmt.Println("1. Open Kiro IDE") + fmt.Println("2. Click 'Sign in with Google' (or GitHub)") + fmt.Println("3. Complete the login process") + fmt.Println("4. 
Run this command again") + return + } + + // Save the imported auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Imported as %s\n", record.Label) + } + fmt.Println("Kiro token import successful!") +} + +func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) { + if options == nil { + options = &LoginOptions{} + } + + if startURL == "" { + log.Errorf("Kiro IDC login requires --kiro-idc-start-url") + fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start") + return + } + + manager := newAuthManager() + + authenticator := sdkAuth.NewKiroAuthenticator() + metadata := map[string]string{ + "start-url": startURL, + "region": region, + "flow": flow, + } + + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: metadata, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro IDC authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure your IDC Start URL is correct") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. 
If auth code flow fails, try: --kiro-idc-flow device") + return + } + + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro IDC authentication successful!") +} diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 16af718ebb..a71bb28263 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -17,12 +17,12 @@ import ( "strings" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -333,42 +333,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage finalProjectID := projectID if responseProjectID != "" { if explicitProject && !strings.EqualFold(responseProjectID, projectID) { - // Check if this is a free user (gen-lang-client projects or free/legacy tier) - isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") || - strings.EqualFold(tierID, "FREE") || - strings.EqualFold(tierID, "LEGACY") - - if isFreeUser { - // Interactive prompt for free users - 
fmt.Printf("\nGoogle returned a different project ID:\n") - fmt.Printf(" Requested (frontend): %s\n", projectID) - fmt.Printf(" Returned (backend): %s\n\n", responseProjectID) - fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n") - fmt.Printf(" This is normal for free tier users.\n\n") - fmt.Printf("Which project ID would you like to use?\n") - fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID) - fmt.Printf(" [2] Frontend: %s\n\n", projectID) - fmt.Printf("Enter choice [1]: ") - - reader := bufio.NewReader(os.Stdin) - choice, _ := reader.ReadString('\n') - choice = strings.TrimSpace(choice) - - if choice == "2" { - log.Infof("Using frontend project ID: %s", projectID) - fmt.Println(". Warning: Frontend project IDs may not have access to preview models.") - finalProjectID = projectID - } else { - log.Infof("Using backend project ID: %s (recommended)", responseProjectID) - finalProjectID = responseProjectID - } - } else { - // Pro users: keep requested project ID (original behavior) - log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID) - } - } else { - finalProjectID = responseProjectID + log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID) + log.Infof("Using backend project ID: %s", responseProjectID) } + finalProjectID = responseProjectID } storage.ProjectID = strings.TrimSpace(finalProjectID) diff --git a/internal/cmd/openai_device_login.go b/internal/cmd/openai_device_login.go index 1b7351e63a..3fa9307b9c 100644 --- a/internal/cmd/openai_device_login.go +++ b/internal/cmd/openai_device_login.go @@ -6,9 +6,9 @@ import ( "fmt" "os" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" + 
"github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" log "github.com/sirupsen/logrus" ) diff --git a/internal/cmd/openai_login.go b/internal/cmd/openai_login.go index 783a948400..ee8a025067 100644 --- a/internal/cmd/openai_login.go +++ b/internal/cmd/openai_login.go @@ -6,9 +6,9 @@ import ( "fmt" "os" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" log "github.com/sirupsen/logrus" ) diff --git a/internal/cmd/qoder_login.go b/internal/cmd/qoder_login.go new file mode 100644 index 0000000000..a27bf64d8c --- /dev/null +++ b/internal/cmd/qoder_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoQoderLogin triggers the PKCE browser flow for the Qoder provider and saves tokens. 
+func DoQoderLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{}, + Prompt: promptFn, + } + + record, savedPath, err := manager.Login(context.Background(), "qoder", cfg, authOpts) + if err != nil { + log.Errorf("Qoder authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Qoder authentication successful!") +} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index d8c4f01938..38f189b4a9 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -10,9 +10,9 @@ import ( "syscall" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/api" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy" + "github.com/router-for-me/CLIProxyAPI/v7/internal/api" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy" log "github.com/sirupsen/logrus" ) diff --git a/internal/cmd/vertex_import.go b/internal/cmd/vertex_import.go index 4aa0d74b59..ffb6200b1a 100644 --- a/internal/cmd/vertex_import.go +++ b/internal/cmd/vertex_import.go @@ -9,11 +9,11 @@ import ( "os" "strings" - "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" - coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/vertex" + 
"github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" ) diff --git a/internal/config/config.go b/internal/config/config.go index 760d43ec4a..0760f987fb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,7 +13,7 @@ import ( "strings" "syscall" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" @@ -22,6 +22,7 @@ import ( const ( DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center" DefaultPprofAddr = "127.0.0.1:8316" + DefaultAuthDir = "~/.cli-proxy-api" ) // Config represents the application's configuration, loaded from a YAML file. @@ -36,6 +37,9 @@ type Config struct { // TLS config controls HTTPS server settings. TLS TLSConfig `yaml:"tls" json:"tls"` + // Home config enables the Redis-based control plane integration. + Home HomeConfig `yaml:"home" json:"-"` + // RemoteManagement nests management-related options under 'remote-management'. RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"` @@ -65,6 +69,11 @@ type Config struct { // UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded. UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"` + // RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items + // are retained in memory for the Redis RESP interface (LPOP/RPOP). + // Default: 60. Max: 3600. + RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"` + // DisableCooling disables quota cooldown scheduling when true. 
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"` @@ -120,6 +129,20 @@ type Config struct { // Used for services that use Vertex AI-style paths but with simple API key authentication. VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` + // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. + KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + + // BTKey defines a list of BaoTa (BT Panel) AI configurations. + BTKey []BTKey `yaml:"bt" json:"bt"` + + // KiroFingerprint defines a global fingerprint configuration for all Kiro requests. + // When set, all Kiro requests will use this fixed fingerprint instead of random generation. + KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"` + + // KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers. + // Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q). + KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"` + // AmpCode contains Amp CLI upstream configuration, management restrictions, and model mappings. AmpCode AmpCode `yaml:"ampcode" json:"ampcode"` @@ -134,6 +157,11 @@ type Config struct { // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"` + // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode. + // This is useful when you want to login with a different account without logging out + // from your current session. Default: false. + IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"` + // Payload defines default and override rules for provider payload parameters. 
Payload PayloadConfig `yaml:"payload" json:"payload"` @@ -206,8 +234,9 @@ type QuotaExceeded struct { // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` - // AntigravityCredits indicates whether to retry Antigravity quota_exhausted 429s once - // on the same credential with enabledCreditTypes=["GOOGLE_ONE_AI"]. + // AntigravityCredits enables credits-based last-resort fallback for Claude models. + // When all free-tier auths are exhausted (429/503), the conductor retries with + // an auth that has available Google One AI credits. AntigravityCredits bool `yaml:"antigravity-credits" json:"antigravity-credits"` } @@ -217,15 +246,11 @@ type RoutingConfig struct { // Supported values: "round-robin" (default), "fill-first". Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` - // ClaudeCodeSessionAffinity enables session-sticky routing for Claude Code clients. - // When enabled, requests with the same session ID (extracted from metadata.user_id) - // are routed to the same auth credential when available. - // Deprecated: Use SessionAffinity instead for universal session support. - ClaudeCodeSessionAffinity bool `yaml:"claude-code-session-affinity,omitempty" json:"claude-code-session-affinity,omitempty"` - // SessionAffinity enables universal session-sticky routing for all clients. // Session IDs are extracted from multiple sources: - // X-Session-ID header, Idempotency-Key, metadata.user_id, conversation_id, or message hash. + // metadata.user_id (Claude Code session format), X-Session-ID, Session_id (Codex), + // X-Amp-Thread-Id (Amp CLI thread), X-Client-Request-Id (PI), metadata.user_id, + // conversation_id, or message hash. // Automatic failover is always enabled when bound auth becomes unavailable. 
SessionAffinity bool `yaml:"session-affinity,omitempty" json:"session-affinity,omitempty"` @@ -392,6 +417,9 @@ type ClaudeKey struct { // ExcludedModels lists model IDs that should be excluded for this provider. ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` + // DisableCooling disables auth/model cooldown scheduling for this credential when true. + DisableCooling bool `yaml:"disable-cooling,omitempty" json:"disable-cooling,omitempty"` + // Cloak configures request cloaking for non-Claude-Code clients. Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"` @@ -447,6 +475,9 @@ type CodexKey struct { // ExcludedModels lists model IDs that should be excluded for this provider. ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` + + // DisableCooling disables auth/model cooldown scheduling for this credential when true. + DisableCooling bool `yaml:"disable-cooling,omitempty" json:"disable-cooling,omitempty"` } func (k CodexKey) GetAPIKey() string { return k.APIKey } @@ -491,6 +522,9 @@ type GeminiKey struct { // ExcludedModels lists model IDs that should be excluded for this provider. ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` + + // DisableCooling disables auth/model cooldown scheduling for this credential when true. + DisableCooling bool `yaml:"disable-cooling,omitempty" json:"disable-cooling,omitempty"` } func (k GeminiKey) GetAPIKey() string { return k.APIKey } @@ -518,6 +552,9 @@ type OpenAICompatibility struct { // Higher values are preferred; defaults to 0. Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + // Disabled prevents this provider from being used for routing. + Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` + // Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2"). 
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` @@ -532,6 +569,9 @@ type OpenAICompatibility struct { // Headers optionally adds extra HTTP headers for requests sent to this provider. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // DisableCooling disables auth/model cooldown scheduling for this provider when true. + DisableCooling bool `yaml:"disable-cooling,omitempty" json:"disable-cooling,omitempty"` } // OpenAICompatibilityAPIKey represents an API key configuration with optional proxy setting. @@ -603,7 +643,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.LogsMaxTotalSizeMB = 0 cfg.ErrorLogsMaxFiles = 10 cfg.UsageStatisticsEnabled = false + cfg.RedisUsageQueueRetentionSeconds = 60 + cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) cfg.DisableCooling = false + cfg.DisableImageGeneration = DisableImageGenerationOff cfg.Pprof.Enable = false cfg.Pprof.Addr = DefaultPprofAddr cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient @@ -664,6 +707,13 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.ErrorLogsMaxFiles = 10 } + if cfg.RedisUsageQueueRetentionSeconds <= 0 { + cfg.RedisUsageQueueRetentionSeconds = 60 + } else if cfg.RedisUsageQueueRetentionSeconds > 3600 { + log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600") + cfg.RedisUsageQueueRetentionSeconds = 3600 + } + if cfg.MaxRetryCredentials < 0 { cfg.MaxRetryCredentials = 0 } @@ -695,6 +745,12 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Normalize global OAuth model name aliases. 
cfg.SanitizeOAuthModelAlias() + // Sanitize Kiro keys: trim whitespace from credential fields + cfg.SanitizeKiroKeys() + + // Sanitize BT keys: trim whitespace from credential fields + cfg.SanitizeBTKeys() + // Validate raw payload rules and drop invalid entries. cfg.SanitizePayloadRules() @@ -1921,3 +1977,116 @@ func removeLegacyAuthBlock(root *yaml.Node) { } removeMapKey(root, "auth") } + +// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. +type KiroKey struct { + // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) + TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"` + + // AccessToken is the OAuth access token for direct configuration. + AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"` + + // RefreshToken is the OAuth refresh token for token renewal. + RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"` + + // ProfileArn is the AWS CodeWhisperer profile ARN. + ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"` + + // Region is the AWS region (default: us-east-1). + Region string `yaml:"region,omitempty" json:"region,omitempty"` + + // StartURL is the IAM Identity Center (IDC) start URL for SSO login. + StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"` + + // ProxyURL optionally overrides the global proxy for this configuration. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat". + // Leave empty to let API use defaults. Different values may inject different system prompts. + AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"` + + // PreferredEndpoint sets the preferred Kiro API endpoint/quota. + // Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota). 
+ PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"` +} + +// BTKey represents the configuration for BaoTa (BT Panel) AI authentication. +// Phone is stored in plaintext; password is stored as base64-encoded. +type BTKey struct { + // Phone is the BaoTa account phone number (plaintext). + Phone string `yaml:"phone" json:"phone"` + + // Password is the BaoTa account password (base64-encoded). + Password string `yaml:"password" json:"password"` + + // Prefix optionally namespaces model aliases for this provider. + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + + // ProxyURL optionally overrides the global proxy for this configuration. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // Models defines optional model aliases for this credential. + Models []OpenAICompatibilityModel `yaml:"models,omitempty" json:"models,omitempty"` + + // ExcludedModels defines models to exclude from listing. + ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` + + // Priority controls selection preference. Higher values are preferred; defaults to 0. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + + // Headers optionally adds extra HTTP headers for requests sent to this provider. + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` +} + +// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests. +// When configured, all Kiro requests will use this fixed fingerprint instead of random generation. +// Empty fields will fall back to random selection from built-in pools. 
+type KiroFingerprintConfig struct { + OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"` + RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"` + StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"` + OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"` + OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"` + NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"` + KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"` + KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"` +} + +// SanitizeKiroKeys trims whitespace from Kiro credential fields. +func (cfg *Config) SanitizeKiroKeys() { + if cfg == nil || len(cfg.KiroKey) == 0 { + return + } + for i := range cfg.KiroKey { + entry := &cfg.KiroKey[i] + entry.TokenFile = strings.TrimSpace(entry.TokenFile) + entry.AccessToken = strings.TrimSpace(entry.AccessToken) + entry.RefreshToken = strings.TrimSpace(entry.RefreshToken) + entry.ProfileArn = strings.TrimSpace(entry.ProfileArn) + entry.Region = strings.TrimSpace(entry.Region) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint) + } +} + +// SanitizeBTKeys trims whitespace and validates BT credential fields. 
+func (cfg *Config) SanitizeBTKeys() { + if cfg == nil || len(cfg.BTKey) == 0 { + return + } + out := make([]BTKey, 0, len(cfg.BTKey)) + for i := range cfg.BTKey { + entry := cfg.BTKey[i] + entry.Phone = strings.TrimSpace(entry.Phone) + entry.Password = strings.TrimSpace(entry.Password) + entry.Prefix = normalizeModelPrefix(entry.Prefix) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) + if entry.Phone == "" || entry.Password == "" { + continue + } + out = append(out, entry) + } + cfg.BTKey = out +} diff --git a/internal/config/disable_image_generation_mode.go b/internal/config/disable_image_generation_mode.go new file mode 100644 index 0000000000..1712638b86 --- /dev/null +++ b/internal/config/disable_image_generation_mode.go @@ -0,0 +1,136 @@ +package config + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + + "gopkg.in/yaml.v3" +) + +// DisableImageGenerationMode is a tri-state config value for disable-image-generation. 
+// +// It supports: +// - false: enabled +// - true: disabled everywhere (including /v1/images/* endpoints) +// - "chat": disabled for all non-images endpoints, but enabled for /v1/images/generations and /v1/images/edits +type DisableImageGenerationMode int + +const ( + DisableImageGenerationOff DisableImageGenerationMode = iota + DisableImageGenerationAll + DisableImageGenerationChat +) + +func (m DisableImageGenerationMode) String() string { + switch m { + case DisableImageGenerationOff: + return "false" + case DisableImageGenerationAll: + return "true" + case DisableImageGenerationChat: + return "chat" + default: + return "false" + } +} + +func (m DisableImageGenerationMode) MarshalYAML() (any, error) { + switch m { + case DisableImageGenerationAll: + return true, nil + case DisableImageGenerationChat: + return "chat", nil + default: + return false, nil + } +} + +func (m *DisableImageGenerationMode) UnmarshalYAML(value *yaml.Node) error { + mode, err := parseDisableImageGenerationNode(value) + if err != nil { + return err + } + *m = mode + return nil +} + +func (m DisableImageGenerationMode) MarshalJSON() ([]byte, error) { + switch m { + case DisableImageGenerationAll: + return []byte("true"), nil + case DisableImageGenerationChat: + return json.Marshal("chat") + default: + return []byte("false"), nil + } +} + +func (m *DisableImageGenerationMode) UnmarshalJSON(data []byte) error { + mode, err := parseDisableImageGenerationJSON(data) + if err != nil { + return err + } + *m = mode + return nil +} + +func parseDisableImageGenerationNode(value *yaml.Node) (DisableImageGenerationMode, error) { + if value == nil { + return DisableImageGenerationOff, nil + } + + // First try a typed bool decode (covers unquoted true/false and YAML 1.1 bools). 
+ var b bool + if err := value.Decode(&b); err == nil && value.Kind == yaml.ScalarNode && value.ShortTag() == "!!bool" { + if b { + return DisableImageGenerationAll, nil + } + return DisableImageGenerationOff, nil + } + + // Fall back to string decoding (covers quoted "true"/"false" and "chat"). + var s string + if err := value.Decode(&s); err != nil { + return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value") + } + return parseDisableImageGenerationString(s) +} + +func parseDisableImageGenerationJSON(data []byte) (DisableImageGenerationMode, error) { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + return DisableImageGenerationOff, nil + } + + // bool + var b bool + if err := json.Unmarshal(trimmed, &b); err == nil { + if b { + return DisableImageGenerationAll, nil + } + return DisableImageGenerationOff, nil + } + + // string + var s string + if err := json.Unmarshal(trimmed, &s); err != nil { + return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value") + } + return parseDisableImageGenerationString(s) +} + +func parseDisableImageGenerationString(s string) (DisableImageGenerationMode, error) { + s = strings.TrimSpace(strings.ToLower(s)) + switch s { + case "", "false", "0", "off", "no": + return DisableImageGenerationOff, nil + case "true", "1", "on", "yes": + return DisableImageGenerationAll, nil + case "chat": + return DisableImageGenerationChat, nil + default: + return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value %q (allowed: true, false, chat)", s) + } +} diff --git a/internal/config/disable_image_generation_mode_test.go b/internal/config/disable_image_generation_mode_test.go new file mode 100644 index 0000000000..433a5cbf96 --- /dev/null +++ b/internal/config/disable_image_generation_mode_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "encoding/json" + "testing" + + "gopkg.in/yaml.v3" +) + +func 
TestDisableImageGenerationMode_UnmarshalYAML(t *testing.T) { + type wrapper struct { + V DisableImageGenerationMode `yaml:"disable-image-generation"` + } + + { + var w wrapper + if err := yaml.Unmarshal([]byte("disable-image-generation: false\n"), &w); err != nil { + t.Fatalf("unmarshal false: %v", err) + } + if w.V != DisableImageGenerationOff { + t.Fatalf("false => %v, want %v", w.V, DisableImageGenerationOff) + } + } + + { + var w wrapper + if err := yaml.Unmarshal([]byte("disable-image-generation: true\n"), &w); err != nil { + t.Fatalf("unmarshal true: %v", err) + } + if w.V != DisableImageGenerationAll { + t.Fatalf("true => %v, want %v", w.V, DisableImageGenerationAll) + } + } + + { + var w wrapper + if err := yaml.Unmarshal([]byte("disable-image-generation: chat\n"), &w); err != nil { + t.Fatalf("unmarshal chat: %v", err) + } + if w.V != DisableImageGenerationChat { + t.Fatalf("chat => %v, want %v", w.V, DisableImageGenerationChat) + } + } +} + +func TestDisableImageGenerationMode_UnmarshalJSON(t *testing.T) { + { + var v DisableImageGenerationMode + if err := json.Unmarshal([]byte("false"), &v); err != nil { + t.Fatalf("unmarshal false: %v", err) + } + if v != DisableImageGenerationOff { + t.Fatalf("false => %v, want %v", v, DisableImageGenerationOff) + } + } + + { + var v DisableImageGenerationMode + if err := json.Unmarshal([]byte("true"), &v); err != nil { + t.Fatalf("unmarshal true: %v", err) + } + if v != DisableImageGenerationAll { + t.Fatalf("true => %v, want %v", v, DisableImageGenerationAll) + } + } + + { + var v DisableImageGenerationMode + if err := json.Unmarshal([]byte(`"chat"`), &v); err != nil { + t.Fatalf("unmarshal chat: %v", err) + } + if v != DisableImageGenerationChat { + t.Fatalf("chat => %v, want %v", v, DisableImageGenerationChat) + } + } +} diff --git a/internal/config/home.go b/internal/config/home.go new file mode 100644 index 0000000000..03c9173239 --- /dev/null +++ b/internal/config/home.go @@ -0,0 +1,9 @@ +package config + +// 
HomeConfig configures the optional "home" control plane integration over Redis protocol. +type HomeConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + Host string `yaml:"host" json:"-"` + Port int `yaml:"port" json:"-"` + Password string `yaml:"password" json:"-"` +} diff --git a/internal/config/oauth_model_alias_defaults.go b/internal/config/oauth_model_alias_defaults.go new file mode 100644 index 0000000000..5eda56abe5 --- /dev/null +++ b/internal/config/oauth_model_alias_defaults.go @@ -0,0 +1,61 @@ +package config + +import "strings" + +// defaultKiroAliases returns default oauth-model-alias entries for Kiro. +// These aliases expose standard Claude IDs for Kiro-prefixed upstream models. +func defaultKiroAliases() []OAuthModelAlias { + return []OAuthModelAlias{ + // Sonnet 4.6 + {Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true}, + // Sonnet 4.5 + {Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true}, + {Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true}, + // Sonnet 4 + {Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true}, + {Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true}, + // Opus 4.6 + {Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true}, + // Opus 4.5 + {Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true}, + {Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true}, + // Haiku 4.5 + {Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true}, + {Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true}, + } +} + +// defaultGitHubCopilotAliases returns default oauth-model-alias entries for +// GitHub Copilot Claude models. It exposes hyphen-style IDs used by clients. 
+func defaultGitHubCopilotAliases() []OAuthModelAlias { + return []OAuthModelAlias{ + {Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true}, + {Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true}, + {Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true}, + {Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true}, + {Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true}, + {Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true}, + } +} + +// GitHubCopilotAliasesFromModels generates oauth-model-alias entries from a dynamic +// list of model IDs fetched from the Copilot API. It auto-creates aliases for +// models whose ID contains a dot (e.g. "claude-opus-4.6" → "claude-opus-4-6"), +// which is the pattern used by Claude models on Copilot. +func GitHubCopilotAliasesFromModels(modelIDs []string) []OAuthModelAlias { + var aliases []OAuthModelAlias + seen := make(map[string]struct{}) + for _, id := range modelIDs { + if !strings.Contains(id, ".") { + continue + } + hyphenID := strings.ReplaceAll(id, ".", "-") + key := id + "→" + hyphenID + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + aliases = append(aliases, OAuthModelAlias{Name: id, Alias: hyphenID, Fork: true}) + } + return aliases +} diff --git a/internal/config/parse.go b/internal/config/parse.go new file mode 100644 index 0000000000..283740e5f0 --- /dev/null +++ b/internal/config/parse.go @@ -0,0 +1,89 @@ +package config + +import ( + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" +) + +// ParseConfigBytes parses a YAML configuration payload into Config and applies the same +// in-memory normalizations as LoadConfigOptional, without persisting any changes to disk. +func ParseConfigBytes(data []byte) (*Config, error) { + if len(data) == 0 { + return nil, fmt.Errorf("config payload is empty") + } + + var cfg Config + // Keep defaults aligned with LoadConfigOptional. 
+ cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6) + cfg.LoggingToFile = false + cfg.LogsMaxTotalSizeMB = 0 + cfg.ErrorLogsMaxFiles = 10 + cfg.UsageStatisticsEnabled = false + cfg.RedisUsageQueueRetentionSeconds = 60 + cfg.DisableCooling = false + cfg.DisableImageGeneration = DisableImageGenerationOff + cfg.Pprof.Enable = false + cfg.Pprof.Addr = DefaultPprofAddr + cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient + cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository + + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config payload: %w", err) + } + + // Hash remote management key if plaintext is detected (nested), but do NOT persist. + if cfg.RemoteManagement.SecretKey != "" && !looksLikeBcrypt(cfg.RemoteManagement.SecretKey) { + hashed, errHash := bcrypt.GenerateFromPassword([]byte(cfg.RemoteManagement.SecretKey), bcrypt.DefaultCost) + if errHash != nil { + return nil, fmt.Errorf("hash remote management key: %w", errHash) + } + cfg.RemoteManagement.SecretKey = string(hashed) + } + + cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository) + if cfg.RemoteManagement.PanelGitHubRepository == "" { + cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository + } + + cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr) + if cfg.Pprof.Addr == "" { + cfg.Pprof.Addr = DefaultPprofAddr + } + + if cfg.LogsMaxTotalSizeMB < 0 { + cfg.LogsMaxTotalSizeMB = 0 + } + + if cfg.ErrorLogsMaxFiles < 0 { + cfg.ErrorLogsMaxFiles = 10 + } + + if cfg.RedisUsageQueueRetentionSeconds <= 0 { + cfg.RedisUsageQueueRetentionSeconds = 60 + } else if cfg.RedisUsageQueueRetentionSeconds > 3600 { + log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600") + cfg.RedisUsageQueueRetentionSeconds = 3600 + } + + if cfg.MaxRetryCredentials < 0 { 
+ cfg.MaxRetryCredentials = 0 + } + + // Apply the same sanitization pipeline. + cfg.SanitizeGeminiKeys() + cfg.SanitizeVertexCompatKeys() + cfg.SanitizeCodexKeys() + cfg.SanitizeCodexHeaderDefaults() + cfg.SanitizeClaudeHeaderDefaults() + cfg.SanitizeClaudeKeys() + cfg.SanitizeOpenAICompatibility() + cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels) + cfg.SanitizeOAuthModelAlias() + cfg.SanitizePayloadRules() + + return &cfg, nil +} diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index aa27526d1e..48c0fe5f17 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -9,6 +9,16 @@ type SDKConfig struct { // ProxyURL is the URL of an optional proxy server to use for outbound requests. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // DisableImageGeneration controls whether the built-in image_generation tool is injected/allowed. + // + // Supported values: + // - false (default): image_generation is enabled everywhere (normal behavior). + // - true: image_generation is disabled everywhere. The server stops injecting it, removes it from request payloads, + // and returns 404 for /v1/images/generations and /v1/images/edits. + // - "chat": disable image_generation injection for all non-images endpoints (e.g. /v1/responses, /v1/chat/completions), + // while keeping /v1/images/generations and /v1/images/edits enabled and preserving image_generation there. + DisableImageGeneration DisableImageGenerationMode `yaml:"disable-image-generation" json:"disable-image-generation"` + // EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled. // Default is false for safety; when false, /v1internal:* requests are rejected. 
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"` diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 58b388a138..7977f3896c 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -24,4 +24,16 @@ const ( // Antigravity represents the Antigravity response format identifier. Antigravity = "antigravity" + + // Kiro represents the AWS CodeWhisperer (Kiro) provider identifier. + Kiro = "kiro" + + // Kilo represents the Kilo AI provider identifier. + Kilo = "kilo" + + // CodeArts represents the HuaweiCloud CodeArts IDE provider identifier. + CodeArts = "codearts" + + // JoyCode represents the JD JoyCode provider identifier. + JoyCode = "joycode" ) diff --git a/internal/home/client.go b/internal/home/client.go new file mode 100644 index 0000000000..23082cc69c --- /dev/null +++ b/internal/home/client.go @@ -0,0 +1,393 @@ +package home + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + log "github.com/sirupsen/logrus" +) + +const ( + redisKeyConfig = "config" + redisChannelConfig = "config" + redisKeyModels = "models" + redisKeyUsage = "usage" + redisKeyRequestLog = "request-log" + + homeReconnectInterval = time.Second +) + +var ( + ErrDisabled = errors.New("home client disabled") + ErrNotConnected = errors.New("home not connected") + ErrEmptyResponse = errors.New("home returned empty response") + ErrAuthNotFound = errors.New("home auth not found") + ErrConfigNotFound = errors.New("home config not found") + ErrModelsNotFound = errors.New("home models not found") +) + +type Client struct { + homeCfg config.HomeConfig + + cmd *redis.Client + sub *redis.Client + + heartbeatOK atomic.Bool +} + +func New(homeCfg config.HomeConfig) *Client { + return &Client{homeCfg: homeCfg} +} + +func (c *Client) Enabled() bool { + if c 
== nil { + return false + } + return c.homeCfg.Enabled +} + +func (c *Client) HeartbeatOK() bool { + if c == nil { + return false + } + if !c.Enabled() { + return false + } + return c.heartbeatOK.Load() +} + +func (c *Client) Close() { + if c == nil { + return + } + c.heartbeatOK.Store(false) + if c.cmd != nil { + _ = c.cmd.Close() + } + if c.sub != nil { + _ = c.sub.Close() + } + c.cmd = nil + c.sub = nil +} + +func (c *Client) addr() (string, bool) { + if c == nil { + return "", false + } + host := strings.TrimSpace(c.homeCfg.Host) + if host == "" { + return "", false + } + if c.homeCfg.Port <= 0 { + return "", false + } + return fmt.Sprintf("%s:%d", host, c.homeCfg.Port), true +} + +func (c *Client) ensureClients() error { + if c == nil { + return ErrDisabled + } + if !c.Enabled() { + return ErrDisabled + } + addr, ok := c.addr() + if !ok { + return fmt.Errorf("home: invalid address (host=%q port=%d)", c.homeCfg.Host, c.homeCfg.Port) + } + + if c.cmd == nil { + c.cmd = redis.NewClient(&redis.Options{ + Addr: addr, + Password: c.homeCfg.Password, + }) + } + if c.sub == nil { + c.sub = redis.NewClient(&redis.Options{ + Addr: addr, + Password: c.homeCfg.Password, + }) + } + return nil +} + +func (c *Client) Ping(ctx context.Context) error { + if err := c.ensureClients(); err != nil { + return err + } + if c.cmd == nil { + return ErrNotConnected + } + return c.cmd.Ping(ctx).Err() +} + +func (c *Client) GetConfig(ctx context.Context) ([]byte, error) { + if err := c.ensureClients(); err != nil { + return nil, err + } + raw, err := c.cmd.Get(ctx, redisKeyConfig).Bytes() + if errors.Is(err, redis.Nil) { + return nil, ErrConfigNotFound + } + if err != nil { + return nil, err + } + if len(raw) == 0 { + return nil, ErrEmptyResponse + } + return raw, nil +} + +func (c *Client) GetModels(ctx context.Context) ([]byte, error) { + if err := c.ensureClients(); err != nil { + return nil, err + } + raw, err := c.cmd.Get(ctx, redisKeyModels).Bytes() + if errors.Is(err, redis.Nil) { 
+ return nil, ErrModelsNotFound + } + if err != nil { + return nil, err + } + if len(raw) == 0 { + return nil, ErrEmptyResponse + } + return raw, nil +} + +func headersToLowerMap(headers http.Header) map[string]string { + if len(headers) == 0 { + return nil + } + out := make(map[string]string, len(headers)) + for key, values := range headers { + k := strings.ToLower(strings.TrimSpace(key)) + if k == "" { + continue + } + if len(values) == 0 { + out[k] = "" + continue + } + trimmed := make([]string, 0, len(values)) + for _, v := range values { + trimmed = append(trimmed, strings.TrimSpace(v)) + } + out[k] = strings.Join(trimmed, ", ") + } + if len(out) == 0 { + return nil + } + return out +} + +func newAuthDispatchRequest(requestedModel string, sessionID string, headers http.Header, count int) authDispatchRequest { + if count <= 0 { + count = 1 + } + return authDispatchRequest{ + Type: "auth", + Model: requestedModel, + Count: count, + SessionID: strings.TrimSpace(sessionID), + Headers: headersToLowerMap(headers), + } +} + +func (c *Client) RPopAuth(ctx context.Context, requestedModel string, sessionID string, headers http.Header, count int) ([]byte, error) { + if err := c.ensureClients(); err != nil { + return nil, err + } + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return nil, fmt.Errorf("home: requested model is empty") + } + req := newAuthDispatchRequest(requestedModel, sessionID, headers, count) + keyBytes, err := json.Marshal(&req) + if err != nil { + return nil, err + } + + raw, err := c.cmd.RPop(ctx, string(keyBytes)).Bytes() + if errors.Is(err, redis.Nil) { + return nil, ErrAuthNotFound + } + if err != nil { + return nil, err + } + if len(raw) == 0 { + return nil, ErrEmptyResponse + } + return raw, nil +} + +func (c *Client) GetRefreshAuth(ctx context.Context, authIndex string) ([]byte, error) { + if err := c.ensureClients(); err != nil { + return nil, err + } + authIndex = strings.TrimSpace(authIndex) + if authIndex 
== "" { + return nil, fmt.Errorf("home: auth_index is empty") + } + req := refreshRequest{ + Type: "refresh", + AuthIndex: authIndex, + } + keyBytes, err := json.Marshal(&req) + if err != nil { + return nil, err + } + + raw, err := c.cmd.Get(ctx, string(keyBytes)).Bytes() + if errors.Is(err, redis.Nil) { + return nil, ErrAuthNotFound + } + if err != nil { + return nil, err + } + if len(raw) == 0 { + return nil, ErrEmptyResponse + } + return raw, nil +} + +func (c *Client) LPushUsage(ctx context.Context, payload []byte) error { + if err := c.ensureClients(); err != nil { + return err + } + if len(payload) == 0 { + return nil + } + return c.cmd.LPush(ctx, redisKeyUsage, payload).Err() +} + +func (c *Client) RPushRequestLog(ctx context.Context, payload []byte) error { + if err := c.ensureClients(); err != nil { + return err + } + if len(payload) == 0 { + return nil + } + return c.cmd.RPush(ctx, redisKeyRequestLog, payload).Err() +} + +// StartConfigSubscriber connects to home, fetches config once via GET config, then subscribes to +// the "config" channel to receive runtime config updates. +// +// The subscription connection is treated as the home heartbeat. HeartbeatOK is set to true only +// after the initial GET config succeeds and the SUBSCRIBE connection is established. When the +// subscription ends unexpectedly, HeartbeatOK becomes false and the loop reconnects. 
+func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte) error) { + if c == nil { + return + } + if !c.Enabled() { + return + } + if onConfig == nil { + return + } + + for { + if ctx != nil { + select { + case <-ctx.Done(): + c.heartbeatOK.Store(false) + return + default: + } + } + + c.heartbeatOK.Store(false) + c.Close() + + if errEnsure := c.ensureClients(); errEnsure != nil { + log.Warn("unable to connect to home control center, retrying in 1 second") + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + if errPing := c.Ping(ctx); errPing != nil { + log.Warn("unable to connect to home control center, retrying in 1 second") + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + raw, errGet := c.GetConfig(ctx) + if errGet != nil { + log.Warn("unable to fetch config from home control center, retrying in 1 second") + sleepWithContext(ctx, homeReconnectInterval) + continue + } + if errApply := onConfig(raw); errApply != nil { + log.Warn("unable to apply config from home control center, retrying in 1 second") + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + if c.sub == nil { + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + pubsub := c.sub.Subscribe(ctx, redisChannelConfig) + if pubsub == nil { + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + // Ensure the subscription is established before marking heartbeat OK. 
+ if _, errReceive := pubsub.Receive(ctx); errReceive != nil { + _ = pubsub.Close() + sleepWithContext(ctx, homeReconnectInterval) + continue + } + + c.heartbeatOK.Store(true) + + for { + msg, errMsg := pubsub.ReceiveMessage(ctx) + if errMsg != nil { + _ = pubsub.Close() + c.heartbeatOK.Store(false) + sleepWithContext(ctx, homeReconnectInterval) + break + } + if msg == nil { + continue + } + if payload := strings.TrimSpace(msg.Payload); payload != "" { + if errApply := onConfig([]byte(payload)); errApply != nil { + log.Warn("failed to apply config update from home control center, ignoring") + } + } + } + } +} + +func sleepWithContext(ctx context.Context, d time.Duration) { + if d <= 0 { + return + } + timer := time.NewTimer(d) + defer timer.Stop() + if ctx == nil { + <-timer.C + return + } + select { + case <-ctx.Done(): + return + case <-timer.C: + return + } +} diff --git a/internal/home/client_test.go b/internal/home/client_test.go new file mode 100644 index 0000000000..625e77bcac --- /dev/null +++ b/internal/home/client_test.go @@ -0,0 +1,32 @@ +package home + +import ( + "encoding/json" + "net/http" + "testing" +) + +func TestAuthDispatchRequestIncludesCount(t *testing.T) { + req := newAuthDispatchRequest("gpt-5.4", "session-1", http.Header{"Authorization": {"Bearer test"}}, 2) + + raw, err := json.Marshal(&req) + if err != nil { + t.Fatalf("marshal auth dispatch request: %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(raw, &payload); err != nil { + t.Fatalf("unmarshal auth dispatch request: %v", err) + } + if got := int(payload["count"].(float64)); got != 2 { + t.Fatalf("count = %d, want 2", got) + } +} + +func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) { + req := newAuthDispatchRequest("gpt-5.4", "", nil, 0) + + if req.Count != 1 { + t.Fatalf("count = %d, want 1", req.Count) + } +} diff --git a/internal/home/global.go b/internal/home/global.go new file mode 100644 index 0000000000..a79121a487 --- /dev/null +++ 
b/internal/home/global.go @@ -0,0 +1,25 @@ +package home + +import "sync/atomic" + +var currentClient atomic.Value // *Client + +// SetCurrent sets the active home client used by runtime integrations. +func SetCurrent(client *Client) { + currentClient.Store(client) +} + +// Current returns the active home client instance, if any. +func Current() *Client { + if v := currentClient.Load(); v != nil { + if client, ok := v.(*Client); ok { + return client + } + } + return nil +} + +// ClearCurrent removes the active home client. +func ClearCurrent() { + currentClient.Store((*Client)(nil)) +} diff --git a/internal/home/requests.go b/internal/home/requests.go new file mode 100644 index 0000000000..0757766468 --- /dev/null +++ b/internal/home/requests.go @@ -0,0 +1,14 @@ +package home + +type authDispatchRequest struct { + Type string `json:"type"` + Model string `json:"model"` + Count int `json:"count"` + SessionID string `json:"session_id,omitempty"` + Headers map[string]string `json:"headers,omitempty"` +} + +type refreshRequest struct { + Type string `json:"type"` + AuthIndex string `json:"auth_index"` +} diff --git a/internal/interfaces/types.go b/internal/interfaces/types.go index 9fb1e7f3b8..dfdfc02a84 100644 --- a/internal/interfaces/types.go +++ b/internal/interfaces/types.go @@ -3,7 +3,7 @@ // transformation operations, maintaining compatibility with the SDK translator package. package interfaces -import sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" +import sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" // Backwards compatible aliases for translator function types. 
type TranslateRequestFunc = sdktranslator.RequestTransform diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index b94d7afe6d..6e3559b8c3 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -12,7 +12,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" log "github.com/sirupsen/logrus" ) @@ -20,13 +20,17 @@ import ( var aiAPIPrefixes = []string{ "/v1/chat/completions", "/v1/completions", + "/v1/images", "/v1/messages", "/v1/responses", "/v1beta/models/", "/api/provider/", } -const skipGinLogKey = "__gin_skip_request_logging__" +const ( + skipGinLogKey = "__gin_skip_request_logging__" + creditsUsedKey = "__antigravity_credits_used__" +) // GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses // using logrus. It captures request details including method, path, status code, latency, @@ -78,6 +82,9 @@ func GinLogrusLogger() gin.HandlerFunc { requestID = "--------" } logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path) + if creditsUsed(c) { + logLine += " [credits]" + } if errorMessage != "" { logLine = logLine + " | " + errorMessage } @@ -148,3 +155,15 @@ func shouldSkipGinRequestLogging(c *gin.Context) bool { flag, ok := val.(bool) return ok && flag } + +func creditsUsed(c *gin.Context) bool { + if c == nil { + return false + } + val, exists := c.Get(creditsUsedKey) + if !exists { + return false + } + flag, ok := val.(bool) + return ok && flag +} diff --git a/internal/logging/gin_logger_test.go b/internal/logging/gin_logger_test.go index 7de1833865..9bd3ddfba6 100644 --- a/internal/logging/gin_logger_test.go +++ b/internal/logging/gin_logger_test.go @@ -58,3 +58,12 @@ func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) { t.Fatalf("expected 500, got %d", recorder.Code) } } + +func TestIsAIAPIPathIncludesImages(t 
*testing.T) { + if !isAIAPIPath("/v1/images/generations") { + t.Fatalf("expected /v1/images/generations to be treated as AI API path") + } + if !isAIAPIPath("/v1/images/edits") { + t.Fatalf("expected /v1/images/edits to be treated as AI API path") + } +} diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 372222a545..4b4ef62c85 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -10,8 +10,8 @@ import ( "sync" "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" ) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index 2db2a504d3..44b2c95264 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -8,6 +8,8 @@ import ( "bytes" "compress/flate" "compress/gzip" + "context" + "encoding/json" "fmt" "io" "os" @@ -22,13 +24,23 @@ import ( "github.com/klauspost/compress/zstd" log "github.com/sirupsen/logrus" - "github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo" - "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v7/internal/buildinfo" + "github.com/router-for-me/CLIProxyAPI/v7/internal/home" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" ) var requestLogID atomic.Uint64 +type homeRequestLogClient interface { + HeartbeatOK() bool + RPushRequestLog(ctx context.Context, payload []byte) error +} + +var currentHomeRequestLogClient = func() homeRequestLogClient { + return home.Current() +} + // RequestLogger defines the interface for logging HTTP requests and responses. 
// It provides methods for logging both regular and streaming HTTP request/response cycles. type RequestLogger interface { @@ -148,6 +160,58 @@ type FileRequestLogger struct { // errorLogsMaxFiles limits the number of error log files retained. errorLogsMaxFiles int + + homeEnabled bool +} + +type homeRequestLogPayload struct { + Headers map[string][]string `json:"headers,omitempty"` + RequestLog string `json:"request_log,omitempty"` +} + +func cloneHeaders(headers map[string][]string) map[string][]string { + if len(headers) == 0 { + return nil + } + out := make(map[string][]string, len(headers)) + for key, values := range headers { + if strings.TrimSpace(key) == "" { + continue + } + if values == nil { + out[key] = nil + continue + } + copied := make([]string, len(values)) + copy(copied, values) + out[key] = copied + } + if len(out) == 0 { + return nil + } + return out +} + +func (l *FileRequestLogger) forwardRequestLogToHome(ctx context.Context, headers map[string][]string, logText string) error { + if l == nil || !l.homeEnabled { + return nil + } + client := currentHomeRequestLogClient() + if client == nil || !client.HeartbeatOK() { + return nil + } + payload := homeRequestLogPayload{ + Headers: cloneHeaders(headers), + RequestLog: logText, + } + raw, errMarshal := json.Marshal(&payload) + if errMarshal != nil { + return errMarshal + } + if ctx == nil { + ctx = context.Background() + } + return client.RPushRequestLog(ctx, raw) } // NewFileRequestLogger creates a new file-based request logger. @@ -173,7 +237,17 @@ func NewFileRequestLogger(enabled bool, logsDir string, configDir string, errorL enabled: enabled, logsDir: logsDir, errorLogsMaxFiles: errorLogsMaxFiles, + homeEnabled: false, + } +} + +// SetHomeEnabled toggles home request-log forwarding. +// When enabled, request logs are not written to disk and are instead forwarded to home via Redis RESP. 
+func (l *FileRequestLogger) SetHomeEnabled(enabled bool) { + if l == nil { + return } + l.homeEnabled = enabled } // IsEnabled returns whether request logging is currently enabled. @@ -231,6 +305,38 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st return nil } + if l.homeEnabled && l.enabled { + responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response) + if decompressErr != nil { + responseToWrite = response + } + + var buf bytes.Buffer + writeErr := l.writeNonStreamingLog( + &buf, + url, + method, + requestHeaders, + body, + "", + websocketTimeline, + apiRequest, + apiResponse, + apiWebsocketTimeline, + apiResponseErrors, + statusCode, + responseHeaders, + responseToWrite, + decompressErr, + requestTimestamp, + apiResponseTimestamp, + ) + if writeErr != nil { + return fmt.Errorf("failed to build request log content: %w", writeErr) + } + return l.forwardRequestLogToHome(context.Background(), requestHeaders, buf.String()) + } + // Ensure logs directory exists if errEnsure := l.ensureLogsDir(); errEnsure != nil { return fmt.Errorf("failed to create logs directory: %w", errEnsure) @@ -321,6 +427,14 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ return &NoOpStreamingLogWriter{}, nil } + if l.homeEnabled { + client := home.Current() + if client == nil || !client.HeartbeatOK() { + return &NoOpStreamingLogWriter{}, nil + } + return newHomeStreamingLogWriter(url, method, headers, body, requestID), nil + } + // Ensure logs directory exists if err := l.ensureLogsDir(); err != nil { return nil, fmt.Errorf("failed to create logs directory: %w", err) @@ -1498,3 +1612,165 @@ func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {} // Returns: // - error: Always returns nil func (w *NoOpStreamingLogWriter) Close() error { return nil } + +type homeStreamingLogWriter struct { + url string + method string + timestamp time.Time + + requestHeaders map[string][]string + 
requestBody []byte + + chunkChan chan []byte + doneChan chan struct{} + + responseStatus int + statusWritten bool + responseHeaders map[string][]string + responseBody bytes.Buffer + apiRequest []byte + apiResponse []byte + apiWebsocketTime []byte + apiResponseTS time.Time + firstChunkTS time.Time +} + +func newHomeStreamingLogWriter(url, method string, headers map[string][]string, body []byte, _ string) *homeStreamingLogWriter { + requestHeaders := make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + requestHeaders[key] = headerValues + } + + writer := &homeStreamingLogWriter{ + url: url, + method: method, + timestamp: time.Now(), + requestHeaders: requestHeaders, + requestBody: append([]byte(nil), body...), + chunkChan: make(chan []byte, 100), + doneChan: make(chan struct{}), + } + + go writer.asyncWriter() + return writer +} + +func (w *homeStreamingLogWriter) asyncWriter() { + defer close(w.doneChan) + for chunk := range w.chunkChan { + if len(chunk) == 0 { + continue + } + _, _ = w.responseBody.Write(chunk) + } +} + +func (w *homeStreamingLogWriter) WriteChunkAsync(chunk []byte) { + if w == nil || w.chunkChan == nil || len(chunk) == 0 { + return + } + select { + case w.chunkChan <- append([]byte(nil), chunk...): + default: + } +} + +func (w *homeStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { + if w == nil || status == 0 { + return nil + } + w.responseStatus = status + w.statusWritten = true + if headers != nil { + w.responseHeaders = make(map[string][]string, len(headers)) + for key, values := range headers { + copied := make([]string, len(values)) + copy(copied, values) + w.responseHeaders[key] = copied + } + } + return nil +} + +func (w *homeStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error { + if w == nil || len(apiRequest) == 0 { + return nil + } + w.apiRequest = bytes.Clone(apiRequest) + return nil +} + +func (w 
*homeStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error { + if w == nil || len(apiResponse) == 0 { + return nil + } + w.apiResponse = bytes.Clone(apiResponse) + return nil +} + +func (w *homeStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error { + if w == nil || len(apiWebsocketTimeline) == 0 { + return nil + } + w.apiWebsocketTime = bytes.Clone(apiWebsocketTimeline) + return nil +} + +func (w *homeStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) { + if w == nil { + return + } + if !timestamp.IsZero() { + w.firstChunkTS = timestamp + w.apiResponseTS = timestamp + } +} + +func (w *homeStreamingLogWriter) Close() error { + if w == nil { + return nil + } + + client := currentHomeRequestLogClient() + if client == nil || !client.HeartbeatOK() { + return nil + } + + if w.chunkChan != nil { + close(w.chunkChan) + <-w.doneChan + w.chunkChan = nil + } + + responsePayload := w.responseBody.Bytes() + + var buf bytes.Buffer + upstreamTransport := inferUpstreamTransport(w.apiRequest, w.apiResponse, w.apiWebsocketTime, nil) + if errWrite := writeRequestInfoWithBody(&buf, w.url, w.method, w.requestHeaders, w.requestBody, "", w.timestamp, "http", upstreamTransport, true); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(&buf, "=== API WEBSOCKET TIMELINE ===\n", "=== API WEBSOCKET TIMELINE", w.apiWebsocketTime, time.Time{}); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(&buf, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(&buf, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse, w.apiResponseTS); errWrite != nil { + return errWrite + } + if errWrite := writeResponseSection(&buf, w.responseStatus, w.statusWritten, w.responseHeaders, bytes.NewReader(responsePayload), nil, false); errWrite != nil { + return errWrite + } + + payload := homeRequestLogPayload{ + Headers: 
cloneHeaders(w.requestHeaders), + RequestLog: buf.String(), + } + raw, errMarshal := json.Marshal(&payload) + if errMarshal != nil { + return errMarshal + } + return client.RPushRequestLog(context.Background(), raw) +} diff --git a/internal/logging/request_logger_home_test.go b/internal/logging/request_logger_home_test.go new file mode 100644 index 0000000000..f8cdf1e453 --- /dev/null +++ b/internal/logging/request_logger_home_test.go @@ -0,0 +1,154 @@ +package logging + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "os" + "testing" + "time" +) + +type stubHomeRequestLogClient struct { + heartbeatOK bool + pushed [][]byte +} + +func (c *stubHomeRequestLogClient) HeartbeatOK() bool { return c.heartbeatOK } + +func (c *stubHomeRequestLogClient) RPushRequestLog(_ context.Context, payload []byte) error { + c.pushed = append(c.pushed, bytes.Clone(payload)) + return nil +} + +func TestFileRequestLogger_HomeEnabled_ForwardsWhenRequestLogEnabled(t *testing.T) { + original := currentHomeRequestLogClient + defer func() { + currentHomeRequestLogClient = original + }() + + stub := &stubHomeRequestLogClient{heartbeatOK: true} + currentHomeRequestLogClient = func() homeRequestLogClient { + return stub + } + + logsDir := t.TempDir() + logger := NewFileRequestLogger(true, logsDir, "", 0) + logger.SetHomeEnabled(true) + + requestHeaders := map[string][]string{ + "Content-Type": {"application/json"}, + "Authorization": {"Bearer secret"}, + } + + errLog := logger.LogRequest( + "/v1/chat/completions", + http.MethodPost, + requestHeaders, + []byte(`{"input":"hello"}`), + http.StatusOK, + map[string][]string{"Content-Type": {"application/json"}}, + []byte(`{"ok":true}`), + nil, + nil, + nil, + nil, + nil, + "req-1", + time.Now(), + time.Now(), + ) + if errLog != nil { + t.Fatalf("LogRequest error: %v", errLog) + } + + entries, errRead := os.ReadDir(logsDir) + if errRead != nil { + t.Fatalf("failed to read logs dir: %v", errRead) + } + if len(entries) != 0 { + 
t.Fatalf("expected no local request log files, got entries: %+v", entries) + } + + if len(stub.pushed) != 1 { + t.Fatalf("home pushed records = %d, want 1", len(stub.pushed)) + } + + var got struct { + Headers map[string][]string `json:"headers"` + RequestLog string `json:"request_log"` + } + if errUnmarshal := json.Unmarshal(stub.pushed[0], &got); errUnmarshal != nil { + t.Fatalf("unmarshal payload: %v payload=%s", errUnmarshal, string(stub.pushed[0])) + } + if got.Headers == nil || got.Headers["Content-Type"][0] != "application/json" { + t.Fatalf("headers.content-type = %+v, want application/json", got.Headers["Content-Type"]) + } + if got.Headers == nil || got.Headers["Authorization"][0] != "Bearer secret" { + t.Fatalf("headers.authorization = %+v, want Bearer secret", got.Headers["Authorization"]) + } + if got.RequestLog == "" { + t.Fatalf("request_log empty, want non-empty") + } +} + +func TestFileRequestLogger_HomeEnabled_DoesNotForwardForcedErrorLogsWhenRequestLogDisabled(t *testing.T) { + original := currentHomeRequestLogClient + defer func() { + currentHomeRequestLogClient = original + }() + + stub := &stubHomeRequestLogClient{heartbeatOK: true} + currentHomeRequestLogClient = func() homeRequestLogClient { + return stub + } + + logsDir := t.TempDir() + logger := NewFileRequestLogger(false, logsDir, "", 0) + logger.SetHomeEnabled(true) + + errLog := logger.LogRequestWithOptions( + "/v1/chat/completions", + http.MethodPost, + map[string][]string{"Content-Type": {"application/json"}}, + []byte(`{"input":"hello"}`), + http.StatusBadGateway, + map[string][]string{"Content-Type": {"application/json"}}, + []byte(`{"error":"upstream failure"}`), + nil, + nil, + nil, + nil, + nil, + true, + "req-2", + time.Now(), + time.Now(), + ) + if errLog != nil { + t.Fatalf("LogRequestWithOptions error: %v", errLog) + } + + if len(stub.pushed) != 0 { + t.Fatalf("home pushed records = %d, want 0", len(stub.pushed)) + } + + entries, errRead := os.ReadDir(logsDir) + if errRead != 
nil { + t.Fatalf("failed to read logs dir: %v", errRead) + } + found := false + for _, entry := range entries { + if entry.IsDir() { + continue + } + if entry.Name() != "" { + found = true + break + } + } + if !found { + t.Fatalf("expected local forced error log file when request-log disabled") + } +} diff --git a/internal/logging/requestmeta.go b/internal/logging/requestmeta.go new file mode 100644 index 0000000000..a28d7c6287 --- /dev/null +++ b/internal/logging/requestmeta.go @@ -0,0 +1,62 @@ +package logging + +import ( + "context" + "sync/atomic" +) + +type endpointKey struct{} +type responseStatusKey struct{} + +type responseStatusHolder struct { + status atomic.Int32 +} + +func WithEndpoint(ctx context.Context, endpoint string) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, endpointKey{}, endpoint) +} + +func GetEndpoint(ctx context.Context) string { + if ctx == nil { + return "" + } + if endpoint, ok := ctx.Value(endpointKey{}).(string); ok { + return endpoint + } + return "" +} + +func WithResponseStatusHolder(ctx context.Context) context.Context { + if ctx == nil { + ctx = context.Background() + } + if holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder); ok && holder != nil { + return ctx + } + return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{}) +} + +func SetResponseStatus(ctx context.Context, status int) { + if ctx == nil || status <= 0 { + return + } + holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder) + if !ok || holder == nil { + return + } + holder.status.Store(int32(status)) +} + +func GetResponseStatus(ctx context.Context) int { + if ctx == nil { + return 0 + } + holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder) + if !ok || holder == nil { + return 0 + } + return int(holder.status.Load()) +} diff --git a/internal/managementasset/embed.go b/internal/managementasset/embed.go new file mode 100644 index 
0000000000..b4be5eafb7 --- /dev/null +++ b/internal/managementasset/embed.go @@ -0,0 +1,41 @@ +package managementasset + +import ( + "embed" + "io/fs" + "net/http" +) + +//go:embed all:web_static +var embeddedWebFS embed.FS + +func GetEmbeddedFileSystem() http.FileSystem { + sub, err := fs.Sub(embeddedWebFS, "web_static") + if err != nil { + return nil + } + return http.FS(sub) +} + +func GetEmbeddedFS() fs.FS { + sub, err := fs.Sub(embeddedWebFS, "web_static") + if err != nil { + return nil + } + return sub +} + +func HasEmbeddedAssets() bool { + fsys := GetEmbeddedFileSystem() + if fsys == nil { + return false + } + f, err := fsys.Open("index.html") + if err != nil { + return false + } + defer func() { + _ = f.Close() + }() + return true +} diff --git a/internal/managementasset/updater.go b/internal/managementasset/updater.go index ae2bc81956..ea7ca3f502 100644 --- a/internal/managementasset/updater.go +++ b/internal/managementasset/updater.go @@ -17,9 +17,9 @@ import ( "sync/atomic" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config" log "github.com/sirupsen/logrus" "golang.org/x/sync/singleflight" ) diff --git a/internal/managementasset/web_static/README.md b/internal/managementasset/web_static/README.md new file mode 100644 index 0000000000..551dc8a2fc --- /dev/null +++ b/internal/managementasset/web_static/README.md @@ -0,0 +1,3 @@ +This directory contains embedded web static assets generated from `web/out`. 
+ +To update: `cp -r web/out internal/managementasset/web_static` diff --git a/internal/misc/antigravity_version.go b/internal/misc/antigravity_version.go index 595cfefd96..0d187c254f 100644 --- a/internal/misc/antigravity_version.go +++ b/internal/misc/antigravity_version.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/http" + "strings" "sync" "time" @@ -18,6 +19,8 @@ const ( antigravityFallbackVersion = "1.21.9" antigravityVersionCacheTTL = 6 * time.Hour antigravityFetchTimeout = 10 * time.Second + AntigravityNodeAPIClientUA = "google-api-nodejs-client/10.3.0" + AntigravityGoogAPIClientUA = "gl-node/22.21.1" ) type antigravityRelease struct { @@ -107,6 +110,65 @@ func AntigravityUserAgent() string { return fmt.Sprintf("antigravity/%s darwin/arm64", AntigravityLatestVersion()) } +func antigravityBaseUserAgent(userAgent string) string { + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" { + return AntigravityUserAgent() + } + lower := strings.ToLower(userAgent) + if strings.HasPrefix(lower, "antigravity/") { + if idx := strings.Index(lower, " google-api-nodejs-client/"); idx >= 0 { + trimmed := strings.TrimSpace(userAgent[:idx]) + if trimmed != "" { + return trimmed + } + } + } + return userAgent +} + +// AntigravityRequestUserAgent returns the short Antigravity runtime UA used by +// generate/stream/model-list requests. +func AntigravityRequestUserAgent(userAgent string) string { + return antigravityBaseUserAgent(userAgent) +} + +// AntigravityLoadCodeAssistUserAgent returns the long Antigravity control-plane +// UA used by loadCodeAssist requests. 
+func AntigravityLoadCodeAssistUserAgent(userAgent string) string { + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" { + return AntigravityUserAgent() + " " + AntigravityNodeAPIClientUA + } + lower := strings.ToLower(userAgent) + if !strings.HasPrefix(lower, "antigravity/") { + return userAgent + } + if strings.Contains(lower, "google-api-nodejs-client/") { + return userAgent + } + return antigravityBaseUserAgent(userAgent) + " " + AntigravityNodeAPIClientUA +} + +// AntigravityVersionFromUserAgent extracts the Antigravity version prefix from +// either the short or long Antigravity UA forms. +func AntigravityVersionFromUserAgent(userAgent string) string { + base := antigravityBaseUserAgent(userAgent) + lower := strings.ToLower(base) + if !strings.HasPrefix(lower, "antigravity/") { + return AntigravityLatestVersion() + } + rest := base[len("antigravity/"):] + if idx := strings.IndexAny(rest, " \t"); idx >= 0 { + rest = rest[:idx] + } + rest = strings.TrimSpace(rest) + if rest == "" { + return AntigravityLatestVersion() + } + return rest +} + func fetchAntigravityLatestVersion(ctx context.Context) (string, error) { if ctx == nil { ctx = context.Background() diff --git a/internal/misc/header_utils.go b/internal/misc/header_utils.go index 5752a26956..ac022a9627 100644 --- a/internal/misc/header_utils.go +++ b/internal/misc/header_utils.go @@ -12,7 +12,7 @@ import ( const ( // GeminiCLIVersion is the version string reported in the User-Agent for upstream requests. - GeminiCLIVersion = "0.31.0" + GeminiCLIVersion = "0.34.0" // GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream. 
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0" @@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string { if model == "" { model = "unknown" } - return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) + return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) } // ScrubProxyAndFingerprintHeaders removes all headers that could reveal diff --git a/internal/redisqueue/plugin.go b/internal/redisqueue/plugin.go new file mode 100644 index 0000000000..e5b74cb24b --- /dev/null +++ b/internal/redisqueue/plugin.go @@ -0,0 +1,160 @@ +package redisqueue + +import ( + "context" + "encoding/json" + "strings" + "time" + + internallogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" + coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage" +) + +func init() { + coreusage.RegisterPlugin(&usageQueuePlugin{}) +} + +type usageQueuePlugin struct{} + +func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Record) { + if p == nil { + return + } + if !Enabled() || !UsageStatisticsEnabled() { + return + } + + timestamp := record.RequestedAt + if timestamp.IsZero() { + timestamp = time.Now() + } + + modelName := strings.TrimSpace(record.Model) + if modelName == "" { + modelName = "unknown" + } + aliasName := strings.TrimSpace(record.Alias) + if aliasName == "" { + aliasName = modelName + } + provider := strings.TrimSpace(record.Provider) + if provider == "" { + provider = "unknown" + } + authType := strings.TrimSpace(record.AuthType) + if authType == "" { + authType = "unknown" + } + apiKey := strings.TrimSpace(record.APIKey) + requestID := strings.TrimSpace(internallogging.GetRequestID(ctx)) + + tokens := tokenStats{ + InputTokens: record.Detail.InputTokens, + OutputTokens: record.Detail.OutputTokens, + ReasoningTokens: record.Detail.ReasoningTokens, + CachedTokens: record.Detail.CachedTokens, + TotalTokens: 
record.Detail.TotalTokens, + } + if tokens.TotalTokens == 0 { + tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + } + if tokens.TotalTokens == 0 { + tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens + } + + failed := record.Failed + if !failed { + failed = !resolveSuccess(ctx) + } + fail := resolveFail(ctx, record, failed) + + detail := requestDetail{ + Timestamp: timestamp, + LatencyMs: record.Latency.Milliseconds(), + Source: record.Source, + AuthIndex: record.AuthIndex, + Tokens: tokens, + Failed: failed, + Fail: fail, + } + + payload, err := json.Marshal(queuedUsageDetail{ + requestDetail: detail, + Provider: provider, + Model: modelName, + Alias: aliasName, + Endpoint: resolveEndpoint(ctx), + AuthType: authType, + APIKey: apiKey, + RequestID: requestID, + }) + if err != nil { + return + } + Enqueue(payload) +} + +type queuedUsageDetail struct { + requestDetail + Provider string `json:"provider"` + Model string `json:"model"` + Alias string `json:"alias"` + Endpoint string `json:"endpoint"` + AuthType string `json:"auth_type"` + APIKey string `json:"api_key"` + RequestID string `json:"request_id"` +} + +type requestDetail struct { + Timestamp time.Time `json:"timestamp"` + LatencyMs int64 `json:"latency_ms"` + Source string `json:"source"` + AuthIndex string `json:"auth_index"` + Tokens tokenStats `json:"tokens"` + Failed bool `json:"failed"` + Fail failDetail `json:"fail"` +} + +type tokenStats struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + ReasoningTokens int64 `json:"reasoning_tokens"` + CachedTokens int64 `json:"cached_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +type failDetail struct { + StatusCode int `json:"status_code"` + Body string `json:"body"` +} + +func resolveFail(ctx context.Context, record coreusage.Record, failed bool) failDetail { + fail := failDetail{ + StatusCode: 
record.Fail.StatusCode, + Body: strings.TrimSpace(record.Fail.Body), + } + if !failed { + return failDetail{StatusCode: 200} + } + if fail.StatusCode <= 0 { + fail.StatusCode = internallogging.GetResponseStatus(ctx) + } + if fail.StatusCode <= 0 { + fail.StatusCode = 500 + } + return fail +} + +func resolveSuccess(ctx context.Context) bool { + status := internallogging.GetResponseStatus(ctx) + if status == 0 { + return true + } + return status < httpStatusBadRequest +} + +func resolveEndpoint(ctx context.Context) string { + return strings.TrimSpace(internallogging.GetEndpoint(ctx)) +} + +const httpStatusBadRequest = 400 diff --git a/internal/redisqueue/plugin_test.go b/internal/redisqueue/plugin_test.go new file mode 100644 index 0000000000..e2af6af709 --- /dev/null +++ b/internal/redisqueue/plugin_test.go @@ -0,0 +1,278 @@ +package redisqueue + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + internallogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" + coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage" +) + +func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) { + withEnabledQueue(t, func() { + ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id") + ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions") + ctx = internallogging.WithResponseStatusHolder(ctx) + internallogging.SetResponseStatus(ctx, http.StatusOK) + + plugin := &usageQueuePlugin{} + plugin.HandleUsage(ctx, coreusage.Record{ + Provider: "openai", + Model: "gpt-5.4", + Alias: "client-gpt", + APIKey: "test-key", + AuthIndex: "0", + AuthType: "apikey", + Source: "user@example.com", + RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC), + Latency: 1500 * time.Millisecond, + Detail: coreusage.Detail{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }) + + payload := popSinglePayload(t) + requireStringField(t, 
payload, "provider", "openai") + requireStringField(t, payload, "model", "gpt-5.4") + requireStringField(t, payload, "alias", "client-gpt") + requireStringField(t, payload, "endpoint", "POST /v1/chat/completions") + requireStringField(t, payload, "auth_type", "apikey") + requireMissingField(t, payload, "user_api_key") + requireStringField(t, payload, "request_id", "ctx-request-id") + requireBoolField(t, payload, "failed", false) + requireFailField(t, payload, http.StatusOK, "") + }) +} + +func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) { + withEnabledQueue(t, func() { + ctx := internallogging.WithRequestID(context.Background(), "gin-request-id") + ctx = internallogging.WithEndpoint(ctx, "GET /v1/responses") + ctx = internallogging.WithResponseStatusHolder(ctx) + internallogging.SetResponseStatus(ctx, http.StatusInternalServerError) + + plugin := &usageQueuePlugin{} + plugin.HandleUsage(ctx, coreusage.Record{ + Provider: "openai", + Model: "gpt-5.4-mini", + Alias: "client-mini", + APIKey: "test-key", + AuthIndex: "0", + AuthType: "apikey", + Source: "user@example.com", + RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC), + Latency: 2500 * time.Millisecond, + Fail: coreusage.Failure{ + StatusCode: http.StatusInternalServerError, + Body: "upstream failed", + }, + Detail: coreusage.Detail{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }) + + payload := popSinglePayload(t) + requireStringField(t, payload, "provider", "openai") + requireStringField(t, payload, "model", "gpt-5.4-mini") + requireStringField(t, payload, "alias", "client-mini") + requireStringField(t, payload, "endpoint", "GET /v1/responses") + requireStringField(t, payload, "auth_type", "apikey") + requireMissingField(t, payload, "user_api_key") + requireStringField(t, payload, "request_id", "gin-request-id") + requireBoolField(t, payload, "failed", true) + requireFailField(t, payload, http.StatusInternalServerError, "upstream failed") + 
}) +} + +func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) { + withEnabledQueue(t, func() { + ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK) + ctx := context.WithValue(context.Background(), "gin", ginCtx) + ctx = internallogging.WithRequestID(ctx, "ctx-request-id") + ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions") + ctx = internallogging.WithResponseStatusHolder(ctx) + internallogging.SetResponseStatus(ctx, http.StatusInternalServerError) + + mgr := coreusage.NewManager(16) + defer mgr.Stop() + + mgr.Register(pluginFunc(func(_ context.Context, _ coreusage.Record) { + ginCtx.Request = httptest.NewRequest(http.MethodGet, "http://example.com/v1/responses", nil) + ginCtx.Status(http.StatusOK) + })) + mgr.Register(&usageQueuePlugin{}) + + mgr.Publish(ctx, coreusage.Record{ + Provider: "openai", + Model: "gpt-5.4", + Alias: "client-gpt", + APIKey: "test-key", + AuthIndex: "0", + AuthType: "apikey", + Source: "user@example.com", + RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC), + Latency: 1500 * time.Millisecond, + Fail: coreusage.Failure{ + StatusCode: http.StatusBadGateway, + Body: "bad gateway", + }, + Detail: coreusage.Detail{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }) + + payload := waitForSinglePayload(t, 2*time.Second) + requireStringField(t, payload, "endpoint", "POST /v1/chat/completions") + requireStringField(t, payload, "alias", "client-gpt") + requireMissingField(t, payload, "user_api_key") + requireStringField(t, payload, "request_id", "ctx-request-id") + requireBoolField(t, payload, "failed", true) + requireFailField(t, payload, http.StatusBadGateway, "bad gateway") + }) +} + +func withEnabledQueue(t *testing.T, fn func()) { + t.Helper() + + prevQueueEnabled := Enabled() + prevUsageEnabled := UsageStatisticsEnabled() + + SetEnabled(false) + SetEnabled(true) + SetUsageStatisticsEnabled(true) + + defer func() { + SetEnabled(false) + 
SetEnabled(prevQueueEnabled) + SetUsageStatisticsEnabled(prevUsageEnabled) + }() + + fn() +} + +func newTestGinContext(t *testing.T, method, path string, status int) *gin.Context { + t.Helper() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Request = httptest.NewRequest(method, "http://example.com"+path, nil) + if status != 0 { + ginCtx.Status(status) + } + return ginCtx +} + +func popSinglePayload(t *testing.T) map[string]json.RawMessage { + t.Helper() + + items := PopOldest(10) + if len(items) != 1 { + t.Fatalf("PopOldest() items = %d, want 1", len(items)) + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(items[0], &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + return payload +} + +func waitForSinglePayload(t *testing.T, timeout time.Duration) map[string]json.RawMessage { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + items := PopOldest(10) + if len(items) == 0 { + time.Sleep(10 * time.Millisecond) + continue + } + if len(items) != 1 { + t.Fatalf("PopOldest() items = %d, want 1", len(items)) + } + var payload map[string]json.RawMessage + if err := json.Unmarshal(items[0], &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + return payload + } + t.Fatalf("timeout waiting for queued payload") + return nil +} + +func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) { + t.Helper() + + raw, ok := payload[key] + if !ok { + t.Fatalf("payload missing %q", key) + } + var got string + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", key, err) + } + if got != want { + t.Fatalf("%s = %q, want %q", key, got, want) + } +} + +func requireMissingField(t *testing.T, payload map[string]json.RawMessage, key string) { + t.Helper() + + if _, ok := payload[key]; ok { + t.Fatalf("payload unexpectedly contains %q", key) + } +} + +type 
// requireBoolField fails the test unless payload[key] exists and decodes to the
// expected boolean value.
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
	t.Helper()

	raw, ok := payload[key]
	if !ok {
		t.Fatalf("payload missing %q", key)
	}
	var got bool
	if err := json.Unmarshal(raw, &got); err != nil {
		t.Fatalf("unmarshal %q: %v", key, err)
	}
	if got != want {
		t.Fatalf("%s = %t, want %t", key, got, want)
	}
}

// requireFailField fails the test unless payload["fail"] exists and decodes to
// the expected status code and body.
func requireFailField(t *testing.T, payload map[string]json.RawMessage, wantStatus int, wantBody string) {
	t.Helper()

	raw, ok := payload["fail"]
	if !ok {
		t.Fatalf("payload missing %q", "fail")
	}
	var got struct {
		StatusCode int    `json:"status_code"`
		Body       string `json:"body"`
	}
	if err := json.Unmarshal(raw, &got); err != nil {
		t.Fatalf("unmarshal fail: %v", err)
	}
	if got.StatusCode != wantStatus || got.Body != wantBody {
		t.Fatalf("fail = {status_code:%d body:%q}, want {status_code:%d body:%q}", got.StatusCode, got.Body, wantStatus, wantBody)
	}
}

// retention bounds for buffered payloads, in seconds.
const (
	defaultRetentionSeconds int64 = 60
	maxRetentionSeconds     int64 = 3600
)

// queueItem pairs a payload copy with its enqueue time so expired entries can
// be pruned on every queue operation.
type queueItem struct {
	enqueuedAt time.Time
	payload    []byte
}

// queue is a mutex-guarded FIFO. head indexes the first live element so pops
// advance an index instead of reslicing; maybeCompactLocked reclaims the dead
// prefix once it grows large.
type queue struct {
	mu    sync.Mutex
	items []queueItem
	head  int
}

var (
	enabled          atomic.Bool
	retentionSeconds atomic.Int64
	global           queue
)

func init() {
	retentionSeconds.Store(defaultRetentionSeconds)
}

// SetEnabled toggles the queue. Disabling drops all buffered payloads.
func SetEnabled(value bool) {
	enabled.Store(value)
	if !value {
		global.clear()
	}
}

// Enabled reports whether the queue currently accepts and serves payloads.
func Enabled() bool {
	return enabled.Load()
}

// SetRetentionSeconds sets how long payloads are retained. Non-positive values
// fall back to the default; values above maxRetentionSeconds are clamped.
func SetRetentionSeconds(value int) {
	normalized := int64(value)
	if normalized <= 0 {
		normalized = defaultRetentionSeconds
	} else if normalized > maxRetentionSeconds {
		normalized = maxRetentionSeconds
	}
	retentionSeconds.Store(normalized)
}

// Enqueue buffers a defensive copy of payload. It is a no-op when the queue is
// disabled or the payload is empty.
func Enqueue(payload []byte) {
	if !Enabled() {
		return
	}
	if len(payload) == 0 {
		return
	}
	global.enqueue(payload)
}

// PopOldest removes and returns up to count of the oldest unexpired payloads,
// or nil when the queue is disabled, count is non-positive, or nothing is left.
func PopOldest(count int) [][]byte {
	if !Enabled() {
		return nil
	}
	if count <= 0 {
		return nil
	}
	return global.popOldest(count)
}

func (q *queue) clear() {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.items = nil
	q.head = 0
}

func (q *queue) enqueue(payload []byte) {
	now := time.Now()

	q.mu.Lock()
	defer q.mu.Unlock()

	q.pruneLocked(now)
	q.items = append(q.items, queueItem{
		enqueuedAt: now,
		payload:    append([]byte(nil), payload...), // copy: caller may reuse its buffer
	})
	q.maybeCompactLocked()
}

func (q *queue) popOldest(count int) [][]byte {
	now := time.Now()

	q.mu.Lock()
	defer q.mu.Unlock()

	q.pruneLocked(now)
	available := len(q.items) - q.head
	if available <= 0 {
		q.items = nil
		q.head = 0
		return nil
	}
	if count > available {
		count = available
	}

	out := make([][]byte, 0, count)
	for i := 0; i < count; i++ {
		idx := q.head + i
		out = append(out, q.items[idx].payload)
		// Fix: release the queue's reference immediately so popped payloads
		// can be collected before the next compaction, instead of being
		// retained by the backing array.
		q.items[idx].payload = nil
	}
	q.head += count
	q.maybeCompactLocked()
	return out
}

// pruneLocked advances head past entries older than the retention window.
// Caller must hold q.mu.
func (q *queue) pruneLocked(now time.Time) {
	if q.head >= len(q.items) {
		q.items = nil
		q.head = 0
		return
	}

	windowSeconds := retentionSeconds.Load()
	if windowSeconds <= 0 {
		windowSeconds = defaultRetentionSeconds
	}
	cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
	for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
		// Fix: drop the expired payload reference right away (see popOldest).
		q.items[q.head].payload = nil
		q.head++
	}
}

// maybeCompactLocked reclaims the consumed prefix once it is large both in
// absolute terms and relative to the slice. Caller must hold q.mu.
func (q *queue) maybeCompactLocked() {
	if q.head == 0 {
		return
	}
	if q.head >= len(q.items) {
		q.items = nil
		q.head = 0
		return
	}
	if q.head < 1024 && q.head*2 < len(q.items) {
		return
	}
	q.items = append([]queueItem(nil), q.items[q.head:]...)
	q.head = 0
}

var usageStatisticsEnabled atomic.Bool

func init() {
	usageStatisticsEnabled.Store(true)
}

// SetUsageStatisticsEnabled toggles whether usage records are enqueued into the
// redisqueue payload buffer. This is controlled by the config field
// `usage-statistics-enabled` and the corresponding management API.
func SetUsageStatisticsEnabled(enabled bool) { usageStatisticsEnabled.Store(enabled) }

// UsageStatisticsEnabled reports whether the usage queue plugin should publish records.
func UsageStatisticsEnabled() bool { return usageStatisticsEnabled.Load() }
+// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format, +// and merging with static metadata for thinking support and other capabilities. +package registry + +import ( + "strings" + "time" +) + +// KiroAPIModel represents a model from Kiro API response. +// This is a local copy to avoid import cycles with the kiro package. +// The structure mirrors kiro.KiroModel for easy data conversion. +type KiroAPIModel struct { + // ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5") + ModelID string + // ModelName is the human-readable name + ModelName string + // Description is the model description + Description string + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string + // MaxInputTokens is the maximum input token limit + MaxInputTokens int +} + +// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models. +// All Kiro models support thinking with the following budget range. +var DefaultKiroThinkingSupport = &ThinkingSupport{ + Min: 1024, // Minimum thinking budget tokens + Max: 32000, // Maximum thinking budget tokens + ZeroAllowed: true, // Allow disabling thinking with 0 + DynamicAllowed: true, // Allow dynamic thinking budget (-1) +} + +// DefaultKiroContextLength is the default context window size for Kiro models. +const DefaultKiroContextLength = 200000 + +// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models. +const DefaultKiroMaxCompletionTokens = 64000 + +// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format. 
+// It performs the following transformations: +// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5) +// - Adds default thinking support metadata +// - Sets default context length and max completion tokens if not provided +// +// Parameters: +// - kiroModels: List of models from Kiro API response +// +// Returns: +// - []*ModelInfo: Converted model information list +func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo { + if len(kiroModels) == 0 { + return nil + } + + now := time.Now().Unix() + result := make([]*ModelInfo, 0, len(kiroModels)) + + for _, km := range kiroModels { + // Skip nil models + if km == nil { + continue + } + + // Skip models without valid ID + if km.ModelID == "" { + continue + } + + // Normalize the model ID to kiro-* format + normalizedID := normalizeKiroModelID(km.ModelID) + + // Create ModelInfo with converted data + info := &ModelInfo{ + ID: normalizedID, + Object: "model", + Created: now, + OwnedBy: "aws", + Type: "kiro", + DisplayName: generateKiroDisplayName(km.ModelName, normalizedID), + Description: km.Description, + // Use MaxInputTokens from API if available, otherwise use default + ContextLength: getContextLength(km.MaxInputTokens), + MaxCompletionTokens: DefaultKiroMaxCompletionTokens, + // All Kiro models support thinking + Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport), + } + + result = append(result, info) + } + + return result +} + +// GenerateAgenticVariants creates -agentic variants for each model. +// Agentic variants are optimized for coding agents with chunked writes. 
+// +// Parameters: +// - models: Base models to generate variants for +// +// Returns: +// - []*ModelInfo: Combined list of base models and their agentic variants +func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return nil + } + + // Pre-allocate result with capacity for both base models and variants + result := make([]*ModelInfo, 0, len(models)*2) + + for _, model := range models { + if model == nil { + continue + } + + // Add the base model first + result = append(result, model) + + // Skip if model already has -agentic suffix + if strings.HasSuffix(model.ID, "-agentic") { + continue + } + + // Skip special models that shouldn't have agentic variants + if model.ID == "kiro-auto" { + continue + } + + // Create agentic variant + agenticModel := &ModelInfo{ + ID: model.ID + "-agentic", + Object: model.Object, + Created: model.Created, + OwnedBy: model.OwnedBy, + Type: model.Type, + DisplayName: model.DisplayName + " (Agentic)", + Description: generateAgenticDescription(model.Description), + ContextLength: model.ContextLength, + MaxCompletionTokens: model.MaxCompletionTokens, + Thinking: cloneThinkingSupport(model.Thinking), + } + + result = append(result, agenticModel) + } + + return result +} + +// MergeWithStaticMetadata merges dynamic models with static metadata. +// Static metadata takes priority for any overlapping fields. +// This allows manual overrides for specific models while keeping dynamic discovery. 
+// +// Parameters: +// - dynamicModels: Models from Kiro API (converted to ModelInfo) +// - staticModels: Predefined model metadata (from GetKiroModels()) +// +// Returns: +// - []*ModelInfo: Merged model list with static metadata taking priority +func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo { + if len(dynamicModels) == 0 && len(staticModels) == 0 { + return nil + } + + // Build a map of static models for quick lookup + staticMap := make(map[string]*ModelInfo, len(staticModels)) + for _, sm := range staticModels { + if sm != nil && sm.ID != "" { + staticMap[sm.ID] = sm + } + } + + // Build result, preferring static metadata where available + seenIDs := make(map[string]struct{}) + result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels)) + + // First, process dynamic models and merge with static if available + for _, dm := range dynamicModels { + if dm == nil || dm.ID == "" { + continue + } + + // Skip duplicates + if _, seen := seenIDs[dm.ID]; seen { + continue + } + seenIDs[dm.ID] = struct{}{} + + // Check if static metadata exists for this model + if sm, exists := staticMap[dm.ID]; exists { + // Static metadata takes priority - use static model + result = append(result, sm) + } else { + // No static metadata - use dynamic model + result = append(result, dm) + } + } + + // Add any static models not in dynamic list + for _, sm := range staticModels { + if sm == nil || sm.ID == "" { + continue + } + if _, seen := seenIDs[sm.ID]; seen { + continue + } + seenIDs[sm.ID] = struct{}{} + result = append(result, sm) + } + + return result +} + +// normalizeKiroModelID converts Kiro API model IDs to internal format. 
// normalizeKiroModelID converts Kiro API model IDs to the internal kiro-* format.
//
// Transformation rules:
//   - surrounding whitespace is trimmed
//   - dots become hyphens (e.g., 4.5 -> 4-5)
//   - a "kiro-" prefix is added when missing
//
// Examples:
//   - "claude-sonnet-4.5" -> "kiro-claude-sonnet-4-5"
//   - "auto"              -> "kiro-auto"
//   - "kiro-claude-sonnet-4-5" -> "kiro-claude-sonnet-4-5" (unchanged)
//
// Empty or whitespace-only input yields "".
func normalizeKiroModelID(modelID string) string {
	// Fix: trim before the emptiness check so a whitespace-only ID returns ""
	// instead of the bogus bare prefix "kiro-".
	modelID = strings.TrimSpace(modelID)
	if modelID == "" {
		return ""
	}

	// Replace dots with hyphens (e.g., 4.5 -> 4-5).
	normalized := strings.ReplaceAll(modelID, ".", "-")

	// Add kiro- prefix if not present.
	if !strings.HasPrefix(normalized, "kiro-") {
		normalized = "kiro-" + normalized
	}

	return normalized
}

// generateKiroDisplayName creates a human-readable display name. It prefers
// the API-provided model name; otherwise it derives one from the normalized ID
// by stripping the kiro- prefix and title-casing each hyphen-separated word.
func generateKiroDisplayName(modelName, normalizedID string) string {
	if modelName != "" {
		return "Kiro " + modelName
	}

	displayID := strings.TrimPrefix(normalizedID, "kiro-")
	words := strings.Split(displayID, "-")
	for i, word := range words {
		if len(word) > 0 {
			// Byte-based capitalization; model IDs are ASCII by construction.
			words[i] = strings.ToUpper(word[:1]) + word[1:]
		}
	}
	return "Kiro " + strings.Join(words, " ")
}

// generateAgenticDescription creates the description for an agentic variant,
// appending a marker to the base description or supplying a default when the
// base is empty.
func generateAgenticDescription(baseDescription string) string {
	if baseDescription == "" {
		return "Optimized for coding agents with chunked writes"
	}
	return baseDescription + " (Agentic mode: chunked writes)"
}
+func getContextLength(maxInputTokens int) int { + if maxInputTokens > 0 { + return maxInputTokens + } + return DefaultKiroContextLength +} + +// cloneThinkingSupport creates a deep copy of ThinkingSupport. +// Returns nil if input is nil. +func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport { + if ts == nil { + return nil + } + + clone := &ThinkingSupport{ + Min: ts.Min, + Max: ts.Max, + ZeroAllowed: ts.ZeroAllowed, + DynamicAllowed: ts.DynamicAllowed, + } + + // Deep copy Levels slice if present + if len(ts.Levels) > 0 { + clone.Levels = make([]string, len(ts.Levels)) + copy(clone.Levels, ts.Levels) + } + + return clone +} diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index ab7258f845..e8a5f39e72 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -6,6 +6,15 @@ import ( "strings" ) +const codexBuiltinImageModelID = "gpt-image-2" + +// defaultCopilotClaudeContextLength is the conservative prompt token limit for +// Claude models accessed via the GitHub Copilot API. Individual accounts are +// capped at 128K; business accounts at 168K. When the dynamic /models API fetch +// succeeds, the real per-account limit overrides this value. This constant is +// only used as a safe fallback. +const defaultCopilotClaudeContextLength = 128000 + // staticModelsJSON mirrors the top-level structure of models.json. type staticModelsJSON struct { Claude []*ModelInfo `json:"claude"` @@ -48,22 +57,22 @@ func GetAIStudioModels() []*ModelInfo { // GetCodexFreeModels returns model definitions for the Codex free plan tier. func GetCodexFreeModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexFree) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexFree)) } // GetCodexTeamModels returns model definitions for the Codex team plan tier. 
func GetCodexTeamModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexTeam) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexTeam)) } // GetCodexPlusModels returns model definitions for the Codex plus plan tier. func GetCodexPlusModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPlus) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexPlus)) } // GetCodexProModels returns model definitions for the Codex pro plan tier. func GetCodexProModels() []*ModelInfo { - return cloneModelInfos(getModels().CodexPro) + return WithCodexBuiltins(cloneModelInfos(getModels().CodexPro)) } // GetKimiModels returns the standard Kimi (Moonshot AI) model definitions. @@ -76,6 +85,71 @@ func GetAntigravityModels() []*ModelInfo { return cloneModelInfos(getModels().Antigravity) } +// WithCodexBuiltins injects hard-coded Codex-only model definitions that should +// not depend on remote models.json updates. Built-ins replace any matching IDs +// already present in the provided slice. 
+func WithCodexBuiltins(models []*ModelInfo) []*ModelInfo { + return upsertModelInfos(models, codexBuiltinImageModelInfo()) +} + +func codexBuiltinImageModelInfo() *ModelInfo { + return &ModelInfo{ + ID: codexBuiltinImageModelID, + Object: "model", + Created: 1704067200, // 2024-01-01 + OwnedBy: "openai", + Type: "openai", + DisplayName: "GPT Image 2", + Version: codexBuiltinImageModelID, + } +} + +func upsertModelInfos(models []*ModelInfo, extras ...*ModelInfo) []*ModelInfo { + if len(extras) == 0 { + return models + } + + extraIDs := make(map[string]struct{}, len(extras)) + extraList := make([]*ModelInfo, 0, len(extras)) + for _, extra := range extras { + if extra == nil { + continue + } + id := strings.TrimSpace(extra.ID) + if id == "" { + continue + } + key := strings.ToLower(id) + if _, exists := extraIDs[key]; exists { + continue + } + extraIDs[key] = struct{}{} + extraList = append(extraList, cloneModelInfo(extra)) + } + + if len(extraList) == 0 { + return models + } + + filtered := make([]*ModelInfo, 0, len(models)+len(extraList)) + for _, model := range models { + if model == nil { + continue + } + id := strings.TrimSpace(model.ID) + if id == "" { + continue + } + if _, exists := extraIDs[strings.ToLower(id)]; exists { + continue + } + filtered = append(filtered, model) + } + + filtered = append(filtered, extraList...) + return filtered +} + // cloneModelInfos returns a shallow copy of the slice with each element deep-cloned. 
func cloneModelInfos(models []*ModelInfo) []*ModelInfo { if len(models) == 0 { @@ -117,8 +191,26 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetCodexProModels() case "kimi": return GetKimiModels() + case "github-copilot": + return GetGitHubCopilotModels() + case "kiro": + return GetKiroModels() + case "kilo": + return GetKiloModels() + case "amazonq": + return GetAmazonQModels() case "antigravity": return GetAntigravityModels() + case "qoder": + return GetQoderModels() + case "bt": + return GetBTModels() + case "codebuddy": + return GetCodeBuddyModels() + case "codebuddy-ai": + return GetCodeBuddyAIModels() + case "cursor": + return GetCursorModels() default: return nil } @@ -141,6 +233,14 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { data.CodexPro, data.Kimi, data.Antigravity, + GetQoderModels(), + GetGitHubCopilotModels(), + GetKiroModels(), + GetKiloModels(), + GetAmazonQModels(), + GetCodeBuddyModels(), + GetCodeBuddyAIModels(), + GetCursorModels(), } for _, models := range allModels { for _, m := range models { @@ -152,3 +252,1019 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { return nil } + +func GetQoderModels() []*ModelInfo { + now := int64(1748044800) // 2025-05-24 + return []*ModelInfo{ + { + ID: "auto", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Auto", + Description: "Automatic model selection", + }, + { + ID: "ultimate", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Ultimate", + Description: "Qoder Ultimate tier model", + }, + { + ID: "performance", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Performance", + Description: "Qoder Performance tier model", + }, + { + ID: "efficient", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Efficient", + Description: "Qoder Efficient tier model", + }, + { + ID: "lite", + Object: "model", 
+ Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Lite", + Description: "Qoder Lite tier model", + }, + { + ID: "qmodel", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Qwen3.6-Plus", + Description: "Qwen 3.6 Plus via Qoder", + }, + { + ID: "dmodel", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "DeepSeek-V4-Pro", + Description: "DeepSeek V4 Pro via Qoder", + }, + { + ID: "dfmodel", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "DeepSeek-V4-Flash", + Description: "DeepSeek V4 Flash via Qoder", + }, + { + ID: "gm51model", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "GLM-5.1", + Description: "GLM 5.1 via Qoder", + }, + { + ID: "kmodel", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "Kimi-K2.6", + Description: "Kimi K2.6 via Qoder", + }, + { + ID: "mmodel", + Object: "model", + Created: now, + OwnedBy: "qoder", + Type: "qoder", + DisplayName: "MiniMax-M2.7", + Description: "MiniMax M2.7 via Qoder", + }, + } +} + +func GetCodeBuddyModels() []*ModelInfo { + now := int64(1748044800) + return []*ModelInfo{ + { + ID: "auto", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "Auto", Description: "Automatic model selection via CodeBuddy", + ContextLength: 168000, MaxCompletionTokens: 32000, SupportedEndpoints: []string{"/chat/completions"}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "hy3-preview", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "Hy3 Preview", Description: "Hunyuan thinking model with enhanced reasoning capabilities via CodeBuddy", + ContextLength: 192000, MaxCompletionTokens: 64000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "glm-5v-turbo", 
Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "GLM-5v Turbo", Description: "Native multimodal model via CodeBuddy", + ContextLength: 200000, MaxCompletionTokens: 38000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "glm-5.1", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "GLM-5.1", Description: "GLM-5.1 via CodeBuddy", + ContextLength: 200000, MaxCompletionTokens: 48000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "glm-5.0-turbo", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "GLM-5.0 Turbo", Description: "GLM-5.0 Turbo via CodeBuddy", + ContextLength: 200000, MaxCompletionTokens: 48000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "kimi-k2.6", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "Kimi K2.6", Description: "Kimi K2.6 via CodeBuddy", + ContextLength: 256000, MaxCompletionTokens: 32000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "kimi-k2.5", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "Kimi K2.5", Description: "Kimi K2.5 via CodeBuddy", + ContextLength: 256000, MaxCompletionTokens: 32000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "minimax-m2.7", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "MiniMax 
M2.7", Description: "MiniMax M2.7 via CodeBuddy", + ContextLength: 200000, MaxCompletionTokens: 48000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "deepseek-v4-flash", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "DeepSeek V4 Flash", Description: "DeepSeek V4 Flash via CodeBuddy", + ContextLength: 1000000, MaxCompletionTokens: 50000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"high", "max"}}, + }, + { + ID: "deepseek-v3-2-volc", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "DeepSeek V3.2", Description: "DeepSeek V3.2 via CodeBuddy", + ContextLength: 96000, MaxCompletionTokens: 32000, SupportedEndpoints: []string{"/chat/completions"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + SupportedInputModalities: []string{"TEXT", "IMAGE"}, + }, + { + ID: "deepseek-v3-1-volc", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "DeepSeek V3.1 Terminus", Description: "DeepSeek V3.1 Terminus via CodeBuddy", + ContextLength: 128000, MaxCompletionTokens: 32000, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "deepseek-r1-0528-lkeap", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "DeepSeek R1 0528", Description: "DeepSeek R1 0528 via CodeBuddy", + ContextLength: 112000, MaxCompletionTokens: 16000, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "hunyuan-chat", Object: "model", Created: now, OwnedBy: "tencent", + Type: "codebuddy", DisplayName: "Hunyuan Turbos", Description: "Tencent Hunyuan Turbos via CodeBuddy", + ContextLength: 128000, MaxCompletionTokens: 8192, SupportedEndpoints: []string{"/chat/completions"}, + }, + } +} + +func GetCodeBuddyAIModels() 
[]*ModelInfo { + now := int64(1748044800) + return []*ModelInfo{ + { + ID: "default-model", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "Default Model", Description: "Default model via CodeBuddy AI", + ContextLength: 128000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "glm-5v-turbo", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GLM-5v Turbo", Description: "GLM-5v Turbo via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "kimi-k2.5", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "Kimi-K2.5", Description: "Kimi K2.5 via CodeBuddy AI", + ContextLength: 256000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.4", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.4", Description: "GPT-5.4 via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.3-codex", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.3-Codex", Description: "GPT-5.3 Codex via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.2-codex", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.2-Codex", Description: "GPT-5.2 Codex via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.2", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.2", Description: "GPT-5.2 via CodeBuddy AI", + ContextLength: 200000, 
MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.1", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.1", Description: "GPT-5.1 via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gpt-5.1-codex-max", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "GPT-5.1-Codex-Max", Description: "GPT-5.1 Codex Max via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gemini-3.0-pro", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "Gemini-3.0-Pro", Description: "Gemini 3.0 Pro via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gemini-3.0-flash", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "Gemini-3.0-Flash", Description: "Gemini 3.0 Flash via CodeBuddy AI", + ContextLength: 200000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "deepseek-v3.2", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 via CodeBuddy AI", + ContextLength: 128000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "auto-chat", Object: "model", Created: now, OwnedBy: "codebuddy-ai", + Type: "codebuddy-ai", DisplayName: "Auto-Chat", Description: "Auto Chat via CodeBuddy AI", + ContextLength: 128000, MaxCompletionTokens: 32768, SupportedEndpoints: []string{"/chat/completions"}, + }, + } +} + +func GetBTModels() []*ModelInfo { + now := int64(1745548800) // 2025-04-25 + return []*ModelInfo{ + {ID: "ernie-4.5-21b-a3b-thinking", Object: "model", 
Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "ERNIE 4.5 21B A3B Thinking", Description: "Baidu ERNIE 4.5 thinking model via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}}, + {ID: "ernie-x1-turbo-32k-preview", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "ERNIE X1 Turbo 32K Preview", Description: "Baidu ERNIE X1 Turbo via BaoTa", ContextLength: 32768, MaxCompletionTokens: 8192}, + {ID: "ernie-x1.1", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "ERNIE X1.1", Description: "Baidu ERNIE X1.1 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "ernie-5.0", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "ERNIE 5.0", Description: "Baidu ERNIE 5.0 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "ernie-5.0-thinking-preview", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "ERNIE 5.0 Thinking Preview", Description: "Baidu ERNIE 5.0 thinking model via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}}, + {ID: "hunyuan-2.0-instruct-20251111", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Hunyuan 2.0 Instruct", Description: "Tencent Hunyuan 2.0 Instruct via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "hunyuan-2.0-thinking-20251109", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Hunyuan 2.0 Thinking", Description: "Tencent Hunyuan 2.0 thinking model via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}}, + {ID: "deepseek-r1-250528", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "DeepSeek R1 250528", Description: "DeepSeek R1 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384, Thinking: 
&ThinkingSupport{Levels: []string{"low", "medium", "high"}}}, + {ID: "deepseek-v3-2-251201", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "DeepSeek V3.2 251201", Description: "DeepSeek V3.2 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "glm-4-7-251222", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "GLM-4.7 251222", Description: "Zhipu GLM-4.7 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "doubao-seed-2-0-code-preview-260215", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Doubao Seed 2.0 Code Preview", Description: "ByteDance Doubao Seed 2.0 Code via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "doubao-seed-2-0-mini-260215", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Doubao Seed 2.0 Mini", Description: "ByteDance Doubao Seed 2.0 Mini via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "doubao-seed-2-0-lite-260215", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Doubao Seed 2.0 Lite", Description: "ByteDance Doubao Seed 2.0 Lite via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "doubao-seed-2-0-pro-260215", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Doubao Seed 2.0 Pro", Description: "ByteDance Doubao Seed 2.0 Pro via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "text-embedding-v4", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Text Embedding V4", Description: "Baidu Text Embedding V4 via BaoTa"}, + {ID: "kimi-k2.5", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Kimi K2.5", Description: "Moonshot Kimi K2.5 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "deepseek-v3.2", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "DeepSeek V3.2", Description: "DeepSeek V3.2 via BaoTa", 
ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen-max-2025-01-25", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen Max 2025-01-25", Description: "Alibaba Qwen Max via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "glm-5", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "GLM-5", Description: "Zhipu GLM-5 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "qwen-flash", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen Flash", Description: "Alibaba Qwen Flash via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen-plus", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen Plus", Description: "Alibaba Qwen Plus via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "doubao-seed-1-8-251228", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Doubao Seed 1.8 251228", Description: "ByteDance Doubao Seed 1.8 via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen-plus-2025-12-01", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen Plus 2025-12-01", Description: "Alibaba Qwen Plus via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "deepseek-v4-flash", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "DeepSeek V4 Flash", Description: "DeepSeek V4 Flash via BaoTa", ContextLength: 1000000, MaxCompletionTokens: 384000}, + {ID: "deepseek-v4-pro", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "DeepSeek V4 Pro", Description: "DeepSeek V4 Pro via BaoTa", ContextLength: 1000000, MaxCompletionTokens: 384000}, + {ID: "qwen3.5-plus", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3.5 Plus", Description: "Alibaba Qwen3.5 Plus via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen3.5-flash", Object: "model", Created: 
now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3.5 Flash", Description: "Alibaba Qwen3.5 Flash via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen3-coder-flash", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3 Coder Flash", Description: "Alibaba Qwen3 Coder Flash via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen3-coder-plus", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3 Coder Plus", Description: "Alibaba Qwen3 Coder Plus via BaoTa", ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "qwen3.6-plus", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3.6 Plus", Description: "Alibaba Qwen3.6 Plus via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen3-max", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3 Max", Description: "Alibaba Qwen3 Max via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + {ID: "qwen3-max-2026-01-23", Object: "model", Created: now, OwnedBy: "bt", Type: "bt", DisplayName: "Qwen3 Max 2026-01-23", Description: "Alibaba Qwen3 Max via BaoTa", ContextLength: 128000, MaxCompletionTokens: 8192}, + } +} + +func GetCursorModels() []*ModelInfo { + return []*ModelInfo{ + {ID: "composer-2", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Composer 2", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}}, + {ID: "claude-4-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 4 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}}, + {ID: "claude-3.5-sonnet", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Claude 3.5 Sonnet", ContextLength: 200000, MaxCompletionTokens: 8192}, + {ID: "gpt-4o", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "GPT-4o", 
ContextLength: 128000, MaxCompletionTokens: 16384}, + {ID: "cursor-small", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Cursor Small", ContextLength: 200000, MaxCompletionTokens: 64000}, + {ID: "gemini-2.5-pro", Object: "model", OwnedBy: "cursor", Type: "cursor", DisplayName: "Gemini 2.5 Pro", ContextLength: 1000000, MaxCompletionTokens: 65536, Thinking: &ThinkingSupport{Max: 50000, DynamicAllowed: true}}, + } +} + +func GetGitHubCopilotModels() []*ModelInfo { + now := int64(1732752000) // 2024-11-27 + copilotClaudeEndpoints := []string{"/chat/completions", "/messages"} + gpt4oEntries := []struct { + ID string + DisplayName string + Description string + }{ + {ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"}, + {ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"}, + {ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"}, + {ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"}, + {ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"}, + } + + models := []*ModelInfo{ + { + ID: "gpt-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-4.1", + Description: "OpenAI GPT-4.1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + }, + } + + for _, entry := range gpt4oEntries { + models = append(models, &ModelInfo{ + ID: entry.ID, + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: entry.DisplayName, + Description: entry.Description, + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + }) 
+ } + + return append(models, []*ModelInfo{ + { + ID: "gpt-5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5", + Description: "OpenAI GPT-5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "gpt-5-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Mini", + Description: "OpenAI GPT-5 Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "gpt-5-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Codex", + Description: "OpenAI GPT-5 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "gpt-5.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1", + Description: "OpenAI GPT-5.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}}, + }, + { + ID: "gpt-5.1-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex", + Description: "OpenAI GPT-5.1 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}}, + }, + { 
+ ID: "gpt-5.1-codex-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex Mini", + Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}}, + }, + { + ID: "gpt-5.1-codex-max", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex Max", + Description: "OpenAI GPT-5.1 Codex Max via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, + }, + { + ID: "gpt-5.2", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.2", + Description: "OpenAI GPT-5.2 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, + }, + { + ID: "gpt-5.2-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.2 Codex", + Description: "OpenAI GPT-5.2 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, + }, + { + ID: "gpt-5.3-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.3 Codex", + Description: "OpenAI GPT-5.3 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", 
"medium", "high", "xhigh"}}, + }, + { + ID: "gpt-5.4", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.4", + Description: "OpenAI GPT-5.4 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, + }, + { + ID: "gpt-5.4-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.4 mini", + Description: "OpenAI GPT-5.4 mini via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + SupportedEndpoints: []string{"/responses"}, + Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, + }, + { + ID: "claude-haiku-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Haiku 4.5", + Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + }, + { + ID: "claude-opus-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.1", + Description: "Anthropic Claude Opus 4.1 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 32000, + SupportedEndpoints: copilotClaudeEndpoints, + }, + { + ID: "claude-opus-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.5", + Description: "Anthropic Claude Opus 4.5 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "claude-opus-4.6", + Object: "model", + 
Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.6", + Description: "Anthropic Claude Opus 4.6 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "claude-sonnet-4", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4", + Description: "Anthropic Claude Sonnet 4 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "claude-sonnet-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4.5", + Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "claude-sonnet-4.6", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4.6", + Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot", + ContextLength: defaultCopilotClaudeContextLength, + MaxCompletionTokens: 64000, + SupportedEndpoints: copilotClaudeEndpoints, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "gemini-2.5-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 2.5 Pro", + Description: "Google Gemini 2.5 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: 
"gemini-3-pro-preview", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3 Pro (Preview)", + Description: "Google Gemini 3 Pro Preview via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gemini-3.1-pro-preview", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3.1 Pro (Preview)", + Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot", + ContextLength: 173000, + MaxCompletionTokens: 65536, + SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "gemini-3-flash-preview", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3 Flash (Preview)", + Description: "Google Gemini 3 Flash Preview via GitHub Copilot", + ContextLength: 173000, + MaxCompletionTokens: 65536, + SupportedEndpoints: []string{"/chat/completions"}, + }, + { + ID: "grok-code-fast-1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Grok Code Fast 1", + Description: "xAI Grok Code Fast 1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "oswe-vscode-prime", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Raptor mini (Preview)", + Description: "Raptor mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + SupportedEndpoints: []string{"/chat/completions", "/responses"}, + }, + }...) 
+} + +func GetKiroModels() []*ModelInfo { + return []*ModelInfo{ + // --- Base Models --- + { + ID: "kiro-auto", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Auto", + Description: "Automatic model selection by Kiro", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-opus-4-6", + Object: "model", + Created: 1736899200, // 2025-01-15 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.6", + Description: "Claude Opus 4.6 via Kiro (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-6", + Object: "model", + Created: 1739836800, // 2025-02-18 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.6", + Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-opus-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5", + Description: "Claude Opus 4.5 via Kiro (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + 
Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4", + Description: "Claude Sonnet 4 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-haiku-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + // --- 第三方模型 (通过 Kiro 接入) --- + { + ID: "kiro-deepseek-3-2", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro DeepSeek 3.2", + Description: "DeepSeek 3.2 via Kiro", + ContextLength: 128000, + MaxCompletionTokens: 32768, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-minimax-m2-1", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro MiniMax M2.1", + Description: "MiniMax M2.1 via Kiro", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-qwen3-coder-next", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Qwen3 Coder Next", + Description: "Qwen3 Coder Next via Kiro", + ContextLength: 128000, + MaxCompletionTokens: 32768, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-gpt-4o", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro GPT-4o", + Description: "OpenAI GPT-4o via Kiro", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "kiro-gpt-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + 
Type: "kiro", + DisplayName: "Kiro GPT-4", + Description: "OpenAI GPT-4 via Kiro", + ContextLength: 128000, + MaxCompletionTokens: 8192, + }, + { + ID: "kiro-gpt-4-turbo", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro GPT-4 Turbo", + Description: "OpenAI GPT-4 Turbo via Kiro", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "kiro-gpt-3-5-turbo", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro GPT-3.5 Turbo", + Description: "OpenAI GPT-3.5 Turbo via Kiro", + ContextLength: 16384, + MaxCompletionTokens: 4096, + }, + // --- Agentic Variants (Optimized for coding agents with chunked writes) --- + { + ID: "kiro-claude-opus-4-6-agentic", + Object: "model", + Created: 1736899200, // 2025-01-15 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.6 (Agentic)", + Description: "Claude Opus 4.6 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-6-agentic", + Object: "model", + Created: 1739836800, // 2025-02-18 + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)", + Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-opus-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Agentic)", + Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-5-agentic", + 
Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)", + Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4 (Agentic)", + Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-haiku-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5 (Agentic)", + Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-deepseek-3-2-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro DeepSeek 3.2 (Agentic)", + Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)", + ContextLength: 128000, + MaxCompletionTokens: 32768, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-minimax-m2-1-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro MiniMax M2.1 (Agentic)", + Description: "MiniMax M2.1 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-qwen3-coder-next-agentic", + Object: 
"model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Qwen3 Coder Next (Agentic)", + Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)", + ContextLength: 128000, + MaxCompletionTokens: 32768, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + } +} + +func GetAmazonQModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "amazonq-auto", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", // Uses Kiro executor - same API + DisplayName: "Amazon Q Auto", + Description: "Automatic model selection by Amazon Q", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Opus 4.5", + Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4", + Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} diff --git a/internal/registry/model_definitions_test.go b/internal/registry/model_definitions_test.go new file mode 100644 index 0000000000..bb2fc46046 --- /dev/null +++ 
b/internal/registry/model_definitions_test.go @@ -0,0 +1,94 @@ +package registry + +import "testing" + +func TestCodexFreeModelsExcludeGPT55(t *testing.T) { + model := findModelInfo(GetCodexFreeModels(), "gpt-5.5") + if model != nil { + t.Fatal("expected codex free tier to NOT include gpt-5.5") + } +} + +func TestCodexStaticModelsIncludeGPT55(t *testing.T) { + tierModels := map[string][]*ModelInfo{ + "team": GetCodexTeamModels(), + "plus": GetCodexPlusModels(), + "pro": GetCodexProModels(), + } + + for tier, models := range tierModels { + t.Run(tier, func(t *testing.T) { + model := findModelInfo(models, "gpt-5.5") + if model == nil { + t.Fatalf("expected codex %s tier to include gpt-5.5", tier) + } + assertGPT55ModelInfo(t, tier, model) + }) + } + + model := LookupStaticModelInfo("gpt-5.5") + if model == nil { + t.Fatal("expected LookupStaticModelInfo to find gpt-5.5") + } + assertGPT55ModelInfo(t, "lookup", model) +} + +func findModelInfo(models []*ModelInfo, id string) *ModelInfo { + for _, model := range models { + if model != nil && model.ID == id { + return model + } + } + return nil +} + +func assertGPT55ModelInfo(t *testing.T, source string, model *ModelInfo) { + t.Helper() + + if model.ID != "gpt-5.5" { + t.Fatalf("%s id mismatch: got %q", source, model.ID) + } + if model.Object != "model" { + t.Fatalf("%s object mismatch: got %q", source, model.Object) + } + if model.Created != 1776902400 { + t.Fatalf("%s created timestamp mismatch: got %d", source, model.Created) + } + if model.OwnedBy != "openai" { + t.Fatalf("%s owned_by mismatch: got %q", source, model.OwnedBy) + } + if model.Type != "openai" { + t.Fatalf("%s type mismatch: got %q", source, model.Type) + } + if model.DisplayName != "GPT 5.5" { + t.Fatalf("%s display name mismatch: got %q", source, model.DisplayName) + } + if model.Version != "gpt-5.5" { + t.Fatalf("%s version mismatch: got %q", source, model.Version) + } + if model.Description != "Frontier model for complex coding, research, and 
real-world work." { + t.Fatalf("%s description mismatch: got %q", source, model.Description) + } + if model.ContextLength != 272000 { + t.Fatalf("%s context length mismatch: got %d", source, model.ContextLength) + } + if model.MaxCompletionTokens != 128000 { + t.Fatalf("%s max completion tokens mismatch: got %d", source, model.MaxCompletionTokens) + } + if len(model.SupportedParameters) != 1 || model.SupportedParameters[0] != "tools" { + t.Fatalf("%s supported parameters mismatch: got %v", source, model.SupportedParameters) + } + if model.Thinking == nil { + t.Fatalf("%s missing thinking support", source) + } + + want := []string{"low", "medium", "high", "xhigh"} + if len(model.Thinking.Levels) != len(want) { + t.Fatalf("%s thinking level count mismatch: got %d, want %d", source, len(model.Thinking.Levels), len(want)) + } + for i, level := range want { + if model.Thinking.Levels[i] != level { + t.Fatalf("%s thinking level %d mismatch: got %q, want %q", source, i, model.Thinking.Levels[i], level) + } + } +} diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 3f3f530d27..560c8ccc0e 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -11,7 +11,7 @@ import ( "sync" "time" - misc "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + misc "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" log "github.com/sirupsen/logrus" ) @@ -56,6 +56,9 @@ type ModelInfo struct { // This is optional and currently used for Gemini thinking budget normalization. Thinking *ThinkingSupport `json:"thinking,omitempty"` + // SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses"). + SupportedEndpoints []string `json:"supported_endpoints,omitempty"` + // UserDefined indicates this model was defined through config file's models[] // array (e.g., openai-compatibility.*.models[], *-api-key.models[]). 
// UserDefined models have thinking configuration passed through without validation. @@ -1141,6 +1144,9 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if len(model.SupportedParameters) > 0 { result["supported_parameters"] = append([]string(nil), model.SupportedParameters...) } + if len(model.SupportedEndpoints) > 0 { + result["supported_endpoints"] = model.SupportedEndpoints + } return result case "claude": diff --git a/internal/registry/model_registry_safety_test.go b/internal/registry/model_registry_safety_test.go index 5f4f65d298..be5bf7908c 100644 --- a/internal/registry/model_registry_safety_test.go +++ b/internal/registry/model_registry_safety_test.go @@ -136,13 +136,13 @@ func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) { } func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) { - first := LookupModelInfo("glm-4.6") + first := LookupModelInfo("claude-sonnet-4-6") if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 { t.Fatalf("expected static model with thinking levels, got %+v", first) } first.Thinking.Levels[0] = "mutated" - second := LookupModelInfo("glm-4.6") + second := LookupModelInfo("claude-sonnet-4-6") if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" { t.Fatalf("expected static lookup clone, got %+v", second) } diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json index 65d8325169..fa56bb42a2 100644 --- a/internal/registry/models/models.json +++ b/internal/registry/models/models.json @@ -1292,6 +1292,52 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + 
"supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "codex-auto-review", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "Codex Auto Review", + "version": "Codex Auto Review", + "description": "Automatic approval review model for Codex.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-team": [ @@ -1387,6 +1433,52 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "codex-auto-review", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "Codex Auto Review", + "version": "Codex Auto Review", + "description": "Automatic approval review model for Codex.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-plus": [ @@ -1505,6 +1597,52 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + 
"xhigh" + ] + } + }, + { + "id": "codex-auto-review", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "Codex Auto Review", + "version": "Codex Auto Review", + "description": "Automatic approval review model for Codex.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "codex-pro": [ @@ -1623,6 +1761,52 @@ "xhigh" ] } + }, + { + "id": "gpt-5.5", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.5", + "version": "gpt-5.5", + "description": "Frontier model for complex coding, research, and real-world work.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "codex-auto-review", + "object": "model", + "created": 1776902400, + "owned_by": "openai", + "type": "openai", + "display_name": "Codex Auto Review", + "version": "Codex Auto Review", + "description": "Automatic approval review model for Codex.", + "context_length": 272000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } } ], "kimi": [ @@ -1670,6 +1854,23 @@ "zero_allowed": true, "dynamic_allowed": true } + }, + { + "id": "kimi-k2.6", + "object": "model", + "created": 1776729600, + "owned_by": "moonshot", + "type": "kimi", + "display_name": "Kimi K2.6", + "description": "Kimi K2.6 - Latest Moonshot AI coding model with improved capabilities", + "context_length": 262144, + "max_completion_tokens": 65536, + "thinking": { + "min": 1024, + "max": 32000, + "zero_allowed": true, + "dynamic_allowed": true + } } ], "antigravity": [ diff --git a/internal/runtime/executor/aistudio_executor.go 
b/internal/runtime/executor/aistudio_executor.go index f53e3e4d1d..41365b5f7a 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -13,14 +13,14 @@ import ( "net/url" "strings" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" - "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + "github.com/router-for-me/CLIProxyAPI/v7/internal/wsrelay" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -284,8 +284,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth processEvent := func(event wsrelay.StreamEvent) bool { if event.Err != nil { helps.RecordAPIResponseError(ctx, e.cfg, event.Err) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} + reporter.PublishFailure(ctx, event.Err) + select { + case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}: + case <-ctx.Done(): + } return false } switch event.Type { @@ -303,7 +306,11 @@ func (e *AIStudioExecutor) 
ExecuteStream(ctx context.Context, auth *cliproxyauth } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m) for i := range lines { - out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}: + case <-ctx.Done(): + return false + } } break } @@ -319,14 +326,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m) for i := range lines { - out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}: + case <-ctx.Done(): + return false + } } reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload)) return false case wsrelay.MessageTypeError: helps.RecordAPIResponseError(ctx, e.cfg, event.Err) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} + reporter.PublishFailure(ctx, event.Err) + select { + case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}: + case <-ctx.Done(): + } return false } return true @@ -400,7 +414,10 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A } // Refresh refreshes the authentication credentials (no-op for AI Studio). 
-func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { +func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled { + return refreshed, err + } return auth, nil } @@ -428,7 +445,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c } payload = fixGeminiImageAspectRatio(baseModel, payload) requestedModel := helps.PayloadRequestedModel(opts, req.Model) - payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel) + requestPath := helps.PayloadRequestPath(opts) + payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel, requestPath) payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema") diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 163b2d9279..2f8dff927c 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -23,18 +23,18 @@ import ( "time" "github.com/google/uuid" - "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" - "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - antigravityclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - sdkAuth 
"github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + antigravityclaude "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/antigravity/claude" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -52,8 +52,8 @@ const ( defaultAntigravityAgent = "antigravity/1.21.9 darwin/arm64" // fallback only; overridden at runtime by misc.AntigravityUserAgent() antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second - antigravityCreditsRetryTTL = 5 * time.Hour - antigravityCreditsAutoDisableDuration = 5 * time.Hour + antigravityCreditsHintRefreshInterval = 10 * time.Minute + antigravityCreditsHintRefreshTimeout = 5 * time.Second antigravityShortQuotaCooldownThreshold = 5 * time.Minute antigravityInstantRetryThreshold = 3 * time.Second // systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. 
The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" @@ -62,8 +62,6 @@ const ( type antigravity429Category string type antigravityCreditsFailureState struct { - Count int - DisabledUntil time.Time PermanentlyDisabled bool ExplicitBalanceExhausted bool } @@ -91,28 +89,85 @@ var ( randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSourceMutex sync.Mutex antigravityCreditsFailureByAuth sync.Map - antigravityPreferCreditsByModel sync.Map antigravityShortCooldownByAuth sync.Map + antigravityCreditsBalanceByAuth sync.Map // auth.ID → antigravityCreditsBalance + antigravityCreditsHintRefreshByID sync.Map // auth.ID → *antigravityCreditsHintRefreshState antigravityQuotaExhaustedKeywords = []string{ "quota_exhausted", "quota exhausted", } - antigravityCreditsExhaustedKeywords = []string{ - "google_one_ai", - "insufficient credit", - "insufficient credits", - "not enough credit", - "not enough credits", - "credit exhausted", - "credits exhausted", - "credit balance", - "minimumcreditamountforusage", - "minimum credit amount for usage", - "minimum credit", - "resource has been exhausted", - } ) +type antigravityCreditsBalance struct { + CreditAmount float64 + MinCreditAmount float64 + PaidTierID string + Known bool +} + +type antigravityCreditsHintRefreshState struct { + mu sync.Mutex + lastAttempt time.Time +} + +func antigravityAuthHasCredits(auth *cliproxyauth.Auth) bool { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return false + } + if hint, ok := cliproxyauth.GetAntigravityCreditsHint(auth.ID); ok && hint.Known { + return hint.Available + } + val, ok := antigravityCreditsBalanceByAuth.Load(strings.TrimSpace(auth.ID)) + if !ok { + return true // optimistic: assume credits available when balance unknown + } + bal, valid := val.(antigravityCreditsBalance) + if !valid { + antigravityCreditsBalanceByAuth.Delete(strings.TrimSpace(auth.ID)) 
+ return false + } + if !bal.Known { + return false + } + available := bal.CreditAmount >= bal.MinCreditAmount + cliproxyauth.SetAntigravityCreditsHint(strings.TrimSpace(auth.ID), cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: available, + CreditAmount: bal.CreditAmount, + MinCreditAmount: bal.MinCreditAmount, + PaidTierID: bal.PaidTierID, + UpdatedAt: time.Now(), + }) + return available +} + +// parseMetaFloat extracts a float64 from auth.Metadata (handles string and numeric types). +func parseMetaFloat(metadata map[string]any, key string) (float64, bool) { + v, ok := metadata[key] + if !ok { + return 0, false + } + switch typed := v.(type) { + case float64: + return typed, true + case int: + return float64(typed), true + case int64: + return float64(typed), true + case uint64: + return float64(typed), true + case json.Number: + if f, err := typed.Float64(); err == nil { + return f, true + } + case string: + if f, err := strconv.ParseFloat(strings.TrimSpace(typed), 64); err == nil { + return f, true + } + } + return 0, false +} + // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { cfg *config.Config @@ -189,7 +244,7 @@ func validateAntigravityRequestSignatures(from sdktranslator.Format, rawJSON []b if from.String() != "claude" { return rawJSON, nil } - // Always strip thinking blocks with empty signatures (proxy-generated). + // Always strip thinking blocks with invalid signatures (empty or non-Claude-format). 
rawJSON = antigravityclaude.StripEmptySignatureThinkingBlocks(rawJSON) if cache.SignatureCacheEnabled() { return rawJSON, nil @@ -298,49 +353,46 @@ func decideAntigravity429(body []byte) antigravity429Decision { decision.retryAfter = retryAfter } - lowerBody := strings.ToLower(string(body)) - for _, keyword := range antigravityQuotaExhaustedKeywords { - if strings.Contains(lowerBody, keyword) { - decision.kind = antigravity429DecisionFullQuotaExhausted - decision.reason = "quota_exhausted" - return decision - } - } - status := strings.TrimSpace(gjson.GetBytes(body, "error.status").String()) if !strings.EqualFold(status, "RESOURCE_EXHAUSTED") { return decision } details := gjson.GetBytes(body, "error.details") - if !details.Exists() || !details.IsArray() { - decision.kind = antigravity429DecisionSoftRetry - return decision - } - - for _, detail := range details.Array() { - if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { - continue - } - reason := strings.TrimSpace(detail.Get("reason").String()) - decision.reason = reason - switch { - case strings.EqualFold(reason, "QUOTA_EXHAUSTED"): - decision.kind = antigravity429DecisionFullQuotaExhausted - return decision - case strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED"): - if decision.retryAfter == nil { - decision.kind = antigravity429DecisionSoftRetry - return decision + if details.Exists() && details.IsArray() { + for _, detail := range details.Array() { + if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { + continue } + reason := strings.TrimSpace(detail.Get("reason").String()) + decision.reason = reason switch { - case *decision.retryAfter < antigravityInstantRetryThreshold: - decision.kind = antigravity429DecisionInstantRetrySameAuth - case *decision.retryAfter < antigravityShortQuotaCooldownThreshold: - decision.kind = antigravity429DecisionShortCooldownSwitchAuth - default: + case strings.EqualFold(reason, "QUOTA_EXHAUSTED"): decision.kind = 
antigravity429DecisionFullQuotaExhausted + return decision + case strings.EqualFold(reason, "RATE_LIMIT_EXCEEDED"): + if decision.retryAfter == nil { + decision.kind = antigravity429DecisionSoftRetry + return decision + } + switch { + case *decision.retryAfter < antigravityInstantRetryThreshold: + decision.kind = antigravity429DecisionInstantRetrySameAuth + case *decision.retryAfter < antigravityShortQuotaCooldownThreshold: + decision.kind = antigravity429DecisionShortCooldownSwitchAuth + default: + decision.kind = antigravity429DecisionFullQuotaExhausted + } + return decision } + } + } + + lowerBody := strings.ToLower(string(body)) + for _, keyword := range antigravityQuotaExhaustedKeywords { + if strings.Contains(lowerBody, keyword) { + decision.kind = antigravity429DecisionFullQuotaExhausted + decision.reason = "quota_exhausted" return decision } } @@ -349,81 +401,10 @@ func decideAntigravity429(body []byte) antigravity429Decision { return decision } -func antigravityHasQuotaResetDelayOrModelInfo(body []byte) bool { - if len(body) == 0 { - return false - } - details := gjson.GetBytes(body, "error.details") - if !details.Exists() || !details.IsArray() { - return false - } - for _, detail := range details.Array() { - if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" { - continue - } - if strings.TrimSpace(detail.Get("metadata.quotaResetDelay").String()) != "" { - return true - } - if strings.TrimSpace(detail.Get("metadata.model").String()) != "" { - return true - } - } - return false -} - func antigravityCreditsRetryEnabled(cfg *config.Config) bool { return cfg != nil && cfg.QuotaExceeded.AntigravityCredits } -func antigravityCreditsFailureStateForAuth(auth *cliproxyauth.Auth) (string, antigravityCreditsFailureState, bool) { - if auth == nil || strings.TrimSpace(auth.ID) == "" { - return "", antigravityCreditsFailureState{}, false - } - authID := strings.TrimSpace(auth.ID) - value, ok := antigravityCreditsFailureByAuth.Load(authID) - if 
!ok { - return authID, antigravityCreditsFailureState{}, true - } - state, ok := value.(antigravityCreditsFailureState) - if !ok { - antigravityCreditsFailureByAuth.Delete(authID) - return authID, antigravityCreditsFailureState{}, true - } - return authID, state, true -} - -func antigravityCreditsDisabled(auth *cliproxyauth.Auth, now time.Time) bool { - authID, state, ok := antigravityCreditsFailureStateForAuth(auth) - if !ok { - return false - } - if state.PermanentlyDisabled { - return true - } - if state.DisabledUntil.IsZero() { - return false - } - if state.DisabledUntil.After(now) { - return true - } - antigravityCreditsFailureByAuth.Delete(authID) - return false -} - -func recordAntigravityCreditsFailure(auth *cliproxyauth.Auth, now time.Time) { - authID, state, ok := antigravityCreditsFailureStateForAuth(auth) - if !ok { - return - } - if state.PermanentlyDisabled { - antigravityCreditsFailureByAuth.Store(authID, state) - return - } - state.Count++ - state.DisabledUntil = now.Add(antigravityCreditsAutoDisableDuration) - antigravityCreditsFailureByAuth.Store(authID, state) -} - func clearAntigravityCreditsFailureState(auth *cliproxyauth.Auth) { if auth == nil || strings.TrimSpace(auth.ID) == "" { return @@ -440,6 +421,25 @@ func markAntigravityCreditsPermanentlyDisabled(auth *cliproxyauth.Auth) { ExplicitBalanceExhausted: true, } antigravityCreditsFailureByAuth.Store(authID, state) + antigravityCreditsBalanceByAuth.Store(authID, antigravityCreditsBalance{ + CreditAmount: 0, + MinCreditAmount: 1, + Known: true, + }) + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: false, + CreditAmount: 0, + MinCreditAmount: 1, + UpdatedAt: time.Now(), + }) +} + +func clearAntigravityCreditsPermanentlyDisabled(auth *cliproxyauth.Auth) { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return + } + antigravityCreditsFailureByAuth.Delete(strings.TrimSpace(auth.ID)) } func 
antigravityHasExplicitCreditsBalanceExhaustedReason(body []byte) bool { @@ -462,81 +462,6 @@ func antigravityHasExplicitCreditsBalanceExhaustedReason(body []byte) bool { return false } -func antigravityPreferCreditsKey(auth *cliproxyauth.Auth, modelName string) string { - if auth == nil { - return "" - } - authID := strings.TrimSpace(auth.ID) - modelName = strings.TrimSpace(modelName) - if authID == "" || modelName == "" { - return "" - } - return authID + "|" + modelName -} - -func antigravityShouldPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time) bool { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return false - } - value, ok := antigravityPreferCreditsByModel.Load(key) - if !ok { - return false - } - until, ok := value.(time.Time) - if !ok || until.IsZero() { - antigravityPreferCreditsByModel.Delete(key) - return false - } - if !until.After(now) { - antigravityPreferCreditsByModel.Delete(key) - return false - } - return true -} - -func markAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string, now time.Time, retryAfter *time.Duration) { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return - } - until := now.Add(antigravityCreditsRetryTTL) - if retryAfter != nil && *retryAfter > 0 { - until = now.Add(*retryAfter) - } - antigravityPreferCreditsByModel.Store(key, until) -} - -func clearAntigravityPreferCredits(auth *cliproxyauth.Auth, modelName string) { - key := antigravityPreferCreditsKey(auth, modelName) - if key == "" { - return - } - antigravityPreferCreditsByModel.Delete(key) -} - -func shouldMarkAntigravityCreditsExhausted(statusCode int, body []byte, reqErr error) bool { - if reqErr != nil || statusCode == 0 { - return false - } - if statusCode >= http.StatusInternalServerError || statusCode == http.StatusRequestTimeout { - return false - } - lowerBody := strings.ToLower(string(body)) - for _, keyword := range antigravityCreditsExhaustedKeywords { - if 
strings.Contains(lowerBody, keyword) { - if keyword == "resource has been exhausted" && - statusCode == http.StatusTooManyRequests && - decideAntigravity429(body).kind == antigravity429DecisionSoftRetry && - !antigravityHasQuotaResetDelayOrModelInfo(body) { - return false - } - return true - } - } - return false -} - func newAntigravityStatusErr(statusCode int, body []byte) statusErr { err := statusErr{code: statusCode, msg: string(body)} if statusCode == http.StatusTooManyRequests { @@ -547,136 +472,13 @@ func newAntigravityStatusErr(statusCode int, body []byte) statusErr { return err } -func (e *AntigravityExecutor) attemptCreditsFallback( - ctx context.Context, - auth *cliproxyauth.Auth, - httpClient *http.Client, - token string, - modelName string, - payload []byte, - stream bool, - alt string, - baseURL string, - originalBody []byte, -) (*http.Response, bool) { - if !antigravityCreditsRetryEnabled(e.cfg) { - return nil, false - } - if decideAntigravity429(originalBody).kind != antigravity429DecisionFullQuotaExhausted { - return nil, false - } - now := time.Now() - if shouldForcePermanentDisableCredits(originalBody) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, false - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(originalBody) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, false - } - - if antigravityCreditsDisabled(auth, now) { - return nil, false - } - creditsPayload := injectEnabledCreditTypes(payload) - if len(creditsPayload) == 0 { - return nil, false - } - - httpReq, errReq := e.buildRequest(ctx, auth, token, modelName, creditsPayload, stream, alt, baseURL) - if errReq != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errReq) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - httpResp, errDo := httpClient.Do(httpReq) - if errDo != nil { 
- helps.RecordAPIResponseError(ctx, e.cfg, errDo) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - retryAfter, _ := parseRetryDelay(originalBody) - markAntigravityPreferCredits(auth, modelName, now, retryAfter) - clearAntigravityCreditsFailureState(auth) - return httpResp, true - } - - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - bodyBytes, errRead := io.ReadAll(httpResp.Body) - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close credits fallback response body error: %v", errClose) - } - if errRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errRead) - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true - } - helps.AppendAPIResponseChunk(ctx, e.cfg, bodyBytes) - if shouldForcePermanentDisableCredits(bodyBytes) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, true - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return nil, true - } - - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, now) - return nil, true -} - -func (e *AntigravityExecutor) handleDirectCreditsFailure(ctx context.Context, auth *cliproxyauth.Auth, modelName string, reqErr error) { - if reqErr != nil { - if shouldForcePermanentDisableCredits(reqErrBody(reqErr)) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - return - } - - if antigravityHasExplicitCreditsBalanceExhaustedReason(reqErrBody(reqErr)) { - clearAntigravityPreferCredits(auth, modelName) - markAntigravityCreditsPermanentlyDisabled(auth) - 
return - } - - helps.RecordAPIResponseError(ctx, e.cfg, reqErr) - } - clearAntigravityPreferCredits(auth, modelName) - recordAntigravityCreditsFailure(auth, time.Now()) -} -func reqErrBody(reqErr error) []byte { - if reqErr == nil { - return nil - } - msg := reqErr.Error() - if strings.TrimSpace(msg) == "" { - return nil - } - return []byte(msg) -} - -func shouldForcePermanentDisableCredits(body []byte) bool { - return antigravityHasExplicitCreditsBalanceExhaustedReason(body) -} - // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { if opts.Alt == "responses/compact" { return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } baseModel := thinking.ParseSuffix(req.Model).ModelName - if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown { + if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown && !antigravityShouldBypassShortCooldown(ctx, e.cfg) { log.Debugf("antigravity executor: auth %s in short cooldown for model %s (%s remaining), returning 429 to switch auth", auth.ID, baseModel, remaining) d := remaining return resp, statusErr{code: http.StatusTooManyRequests, msg: fmt.Sprintf("auth in short cooldown, %s remaining", remaining), retryAfter: &d} @@ -719,7 +521,10 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au } requestedModel := helps.PayloadRequestedModel(opts, req.Model) - translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + requestPath := helps.PayloadRequestPath(opts) + translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, 
requestPath) + + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -733,11 +538,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } @@ -785,7 +589,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != nil { - return resp, errWait } } @@ -794,34 +597,13 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && *decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, false, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - helps.RecordAPIResponseMetadata(ctx, e.cfg, creditsResp.StatusCode, creditsResp.Header.Clone()) - creditsBody, 
errCreditsRead := io.ReadAll(creditsResp.Body) - if errClose := creditsResp.Body.Close(); errClose != nil { - log.Errorf("antigravity executor: close credits success response body error: %v", errClose) - } - if errCreditsRead != nil { - helps.RecordAPIResponseError(ctx, e.cfg, errCreditsRead) - err = errCreditsRead - return resp, err - } - helps.AppendAPIResponseChunk(ctx, e.cfg, creditsBody) - reporter.Publish(ctx, helps.ParseAntigravityUsage(creditsBody)) - var param any - converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, creditsBody, ¶m) - resp = cliproxyexecutor.Response{Payload: converted, Headers: creditsResp.Header.Clone()} - reporter.EnsurePublished(ctx) - return resp, nil - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } @@ -870,6 +652,10 @@ attemptLoop: return resp, err } + // Success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } reporter.Publish(ctx, helps.ParseAntigravityUsage(bodyBytes)) var param any converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m) @@ -895,7 +681,7 @@ attemptLoop: // executeClaudeNonStream performs a claude non-streaming request to the Antigravity API. 
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName - if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown { + if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown && !antigravityShouldBypassShortCooldown(ctx, e.cfg) { log.Debugf("antigravity executor: auth %s in short cooldown for model %s (%s remaining), returning 429 to switch auth", auth.ID, baseModel, remaining) d := remaining return resp, statusErr{code: http.StatusTooManyRequests, msg: fmt.Sprintf("auth in short cooldown, %s remaining", remaining), retryAfter: &d} @@ -933,7 +719,10 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * } requestedModel := helps.PayloadRequestedModel(opts, req.Model) - translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + requestPath := helps.PayloadRequestPath(opts) + translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, requestPath) + + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -948,11 +737,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := 
injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL) @@ -1014,7 +802,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != nil { - return resp, errWait } } @@ -1023,25 +810,16 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && *decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - httpResp = creditsResp - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - goto streamSuccessClaudeNonStream - } lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), bodyBytes...) 
lastErr = nil @@ -1085,7 +863,10 @@ attemptLoop: return resp, err } - streamSuccessClaudeNonStream: + // Stream success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } out := make(chan cliproxyexecutor.StreamChunk) go func(resp *http.Response) { defer close(out) @@ -1117,7 +898,7 @@ attemptLoop: } if errScan := scanner.Err(); errScan != nil { helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx) + reporter.PublishFailure(ctx, errScan) out <- cliproxyexecutor.StreamChunk{Err: errScan} } else { reporter.EnsurePublished(ctx) @@ -1360,7 +1141,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya baseModel := thinking.ParseSuffix(req.Model).ModelName ctx = context.WithValue(ctx, "alt", "") - if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown { + if inCooldown, remaining := antigravityIsInShortCooldown(auth, baseModel, time.Now()); inCooldown && !antigravityShouldBypassShortCooldown(ctx, e.cfg) { log.Debugf("antigravity executor: auth %s in short cooldown for model %s (%s remaining), returning 429 to switch auth", auth.ID, baseModel, remaining) d := remaining return nil, statusErr{code: http.StatusTooManyRequests, msg: fmt.Sprintf("auth in short cooldown, %s remaining", remaining), retryAfter: &d} @@ -1389,6 +1170,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya if updatedAuth != nil { auth = updatedAuth } + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) @@ -1398,7 +1180,10 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya } requestedModel := helps.PayloadRequestedModel(opts, req.Model) - translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) + requestPath := 
helps.PayloadRequestPath(opts) + translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, requestPath) + + useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) @@ -1413,11 +1198,10 @@ attemptLoop: for idx, baseURL := range baseURLs { requestPayload := translated - usedCreditsDirect := false - if antigravityCreditsRetryEnabled(e.cfg) && antigravityShouldPreferCredits(auth, baseModel, time.Now()) { - if creditsPayload := injectEnabledCreditTypes(translated); len(creditsPayload) > 0 { - requestPayload = creditsPayload - usedCreditsDirect = true + if useCredits { + if cp := injectEnabledCreditTypes(translated); len(cp) > 0 { + requestPayload = cp + helps.MarkCreditsUsed(ctx) } } httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, requestPayload, true, opts.Alt, baseURL) @@ -1478,7 +1262,6 @@ attemptLoop: wait := antigravityInstantRetryDelay(*decision.retryAfter) log.Debugf("antigravity executor: instant retry for model %s, waiting %s", baseModel, wait) if errWait := antigravityWait(ctx, wait); errWait != nil { - return nil, errWait } } @@ -1487,25 +1270,16 @@ attemptLoop: case antigravity429DecisionShortCooldownSwitchAuth: if decision.retryAfter != nil && *decision.retryAfter > 0 { markAntigravityShortCooldown(auth, baseModel, time.Now(), *decision.retryAfter) - log.Debugf("antigravity executor: short quota cooldown (%s) for model %s, recorded cooldown and skipping credits fallback", *decision.retryAfter, baseModel) + log.Debugf("antigravity executor: short quota cooldown (%s) for model %s recorded", *decision.retryAfter, baseModel) } case antigravity429DecisionFullQuotaExhausted: - if usedCreditsDirect { - clearAntigravityPreferCredits(auth, baseModel) - recordAntigravityCreditsFailure(auth, time.Now()) - } else { - 
creditsResp, _ := e.attemptCreditsFallback(ctx, auth, httpClient, token, baseModel, translated, true, opts.Alt, baseURL, bodyBytes) - if creditsResp != nil { - httpResp = creditsResp - helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) - } + if useCredits && antigravityHasExplicitCreditsBalanceExhaustedReason(bodyBytes) { + markAntigravityCreditsPermanentlyDisabled(auth) } + // No credits logic - just fall through to error return below } } - if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices { - goto streamSuccessExecuteStream - } lastStatus = httpResp.StatusCode lastBody = append([]byte(nil), bodyBytes...) lastErr = nil @@ -1549,7 +1323,10 @@ attemptLoop: return nil, err } - streamSuccessExecuteStream: + // Stream success + if useCredits { + clearAntigravityCreditsFailureState(auth) + } out := make(chan cliproxyexecutor.StreamChunk) go func(resp *http.Response) { defer close(out) @@ -1580,17 +1357,28 @@ attemptLoop: chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m) for i := range chunks { - out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: + case <-ctx.Done(): + return + } } } tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m) for i := range tail { - out <- cliproxyexecutor.StreamChunk{Payload: tail[i]} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}: + case <-ctx.Done(): + return + } } if errScan := scanner.Err(); errScan != nil { helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: errScan} + reporter.PublishFailure(ctx, errScan) + select { + case out <- cliproxyexecutor.StreamChunk{Err: errScan}: + case <-ctx.Done(): + } } else { reporter.EnsurePublished(ctx) } @@ 
-1614,6 +1402,9 @@ attemptLoop: // Refresh refreshes the authentication credentials using the refresh token. func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled { + return refreshed, err + } if auth == nil { return auth, nil } @@ -1792,6 +1583,7 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr accessToken := metaStringValue(auth.Metadata, "access_token") expiry := tokenExpiry(auth.Metadata) if accessToken != "" && expiry.After(time.Now().Add(refreshSkew)) { + e.maybeRefreshAntigravityCreditsHint(ctx, auth, accessToken) return accessToken, nil, nil } refreshCtx := context.Background() @@ -1800,6 +1592,18 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt) } } + if refreshed, handled, err := helps.RefreshAuthViaHome(refreshCtx, e.cfg, auth); handled { + if err != nil { + return "", nil, err + } + token := metaStringValue(refreshed.Metadata, "access_token") + if strings.TrimSpace(token) == "" { + return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"} + } + e.maybeRefreshAntigravityCreditsHint(ctx, refreshed, token) + return token, refreshed, nil + } + updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone()) if errRefresh != nil { return "", nil, errRefresh @@ -1807,6 +1611,63 @@ func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *clipr return metaStringValue(updated.Metadata, "access_token"), updated, nil } +func (e *AntigravityExecutor) maybeRefreshAntigravityCreditsHint(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) { + if e == nil || auth == nil || !antigravityCreditsRetryEnabled(e.cfg) { + return + } + if ctx != nil && ctx.Err() != nil { + return + } + authID := strings.TrimSpace(auth.ID) + if authID == "" 
{ + return + } + if hint, ok := cliproxyauth.GetAntigravityCreditsHint(authID); ok && hint.Known { + return + } + if strings.TrimSpace(accessToken) == "" { + accessToken = metaStringValue(auth.Metadata, "access_token") + } + if strings.TrimSpace(accessToken) == "" { + return + } + + state := &antigravityCreditsHintRefreshState{} + if existing, loaded := antigravityCreditsHintRefreshByID.LoadOrStore(authID, state); loaded { + if cast, ok := existing.(*antigravityCreditsHintRefreshState); ok && cast != nil { + state = cast + } else { + antigravityCreditsHintRefreshByID.Delete(authID) + antigravityCreditsHintRefreshByID.Store(authID, state) + } + } + + now := time.Now() + if !state.mu.TryLock() { + return + } + if !state.lastAttempt.IsZero() && now.Sub(state.lastAttempt) < antigravityCreditsHintRefreshInterval { + state.mu.Unlock() + return + } + state.lastAttempt = now + + refreshCtx := context.Background() + if ctx != nil { + if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil { + refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", rt) + } + } + refreshCtx, cancel := context.WithTimeout(refreshCtx, antigravityCreditsHintRefreshTimeout) + authCopy := auth.Clone() + + go func(state *antigravityCreditsHintRefreshState, auth *cliproxyauth.Auth, token string) { + defer cancel() + defer state.mu.Unlock() + e.updateAntigravityCreditsBalance(refreshCtx, auth, token) + }(state, authCopy, accessToken) +} + func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { if auth == nil { return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} @@ -1882,6 +1743,7 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil { log.Warnf("antigravity executor: ensure project id failed: %v", errProject) } + 
e.updateAntigravityCreditsBalance(ctx, auth, tokenResp.AccessToken) return auth, nil } @@ -1918,6 +1780,107 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au return nil } +func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) { + if auth == nil || strings.TrimSpace(auth.ID) == "" { + return + } + token := strings.TrimSpace(accessToken) + if token == "" { + token = metaStringValue(auth.Metadata, "access_token") + } + if token == "" { + return + } + + userAgent := resolveLoadCodeAssistUserAgent(auth) + loadReqBody, errMarshal := json.Marshal(map[string]any{ + "metadata": map[string]string{ + "ide_type": "ANTIGRAVITY", + "ide_version": misc.AntigravityVersionFromUserAgent(userAgent), + "ide_name": "antigravity", + }, + }) + if errMarshal != nil { + log.Debugf("antigravity executor: marshal loadCodeAssist request error: %v", errMarshal) + return + } + baseURL := buildBaseURL(auth) + endpointURL := strings.TrimSuffix(baseURL, "/") + "/v1internal:loadCodeAssist" + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(loadReqBody)) + if errReq != nil { + log.Debugf("antigravity executor: create loadCodeAssist request error: %v", errReq) + return + } + httpReq.Header.Set("Authorization", "Bearer "+token) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", userAgent) + httpReq.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA) + + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + log.Debugf("antigravity executor: loadCodeAssist request error: %v", errDo) + return + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("antigravity executor: close loadCodeAssist response body error: %v", errClose) + } + }() + + bodyBytes, errRead := io.ReadAll(httpResp.Body) + if 
errRead != nil || httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { + log.Debugf("antigravity executor: loadCodeAssist returned status %d, err=%v", httpResp.StatusCode, errRead) + return + } + + authID := strings.TrimSpace(auth.ID) + paidTierID := strings.TrimSpace(gjson.GetBytes(bodyBytes, "paidTier.id").String()) + + credits := gjson.GetBytes(bodyBytes, "paidTier.availableCredits") + if !credits.IsArray() { + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: false, + PaidTierID: paidTierID, + UpdatedAt: time.Now(), + }) + return + } + for _, credit := range credits.Array() { + if !strings.EqualFold(credit.Get("creditType").String(), "GOOGLE_ONE_AI") { + continue + } + creditAmount, errCA := strconv.ParseFloat(strings.TrimSpace(credit.Get("creditAmount").String()), 64) + if errCA != nil { + continue + } + minAmount, errMA := strconv.ParseFloat(strings.TrimSpace(credit.Get("minimumCreditAmountForUsage").String()), 64) + if errMA != nil { + continue + } + bal := antigravityCreditsBalance{ + CreditAmount: creditAmount, + MinCreditAmount: minAmount, + PaidTierID: paidTierID, + Known: true, + } + antigravityCreditsBalanceByAuth.Store(authID, bal) + cliproxyauth.SetAntigravityCreditsHint(authID, cliproxyauth.AntigravityCreditsHint{ + Known: true, + Available: creditAmount >= minAmount, + CreditAmount: creditAmount, + MinCreditAmount: minAmount, + PaidTierID: paidTierID, + UpdatedAt: time.Now(), + }) + if creditAmount >= minAmount { + clearAntigravityCreditsPermanentlyDisabled(auth) + } + return + } +} + func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) { if token == "" { return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"} @@ -2149,19 +2112,28 @@ func resolveHost(base string) string { } func 
resolveUserAgent(auth *cliproxyauth.Auth) string { + return misc.AntigravityRequestUserAgent(antigravityConfiguredUserAgent(auth)) +} + +func resolveLoadCodeAssistUserAgent(auth *cliproxyauth.Auth) string { + return misc.AntigravityLoadCodeAssistUserAgent(antigravityConfiguredUserAgent(auth)) +} + +func antigravityConfiguredUserAgent(auth *cliproxyauth.Auth) string { + raw := "" if auth != nil { if auth.Attributes != nil { if ua := strings.TrimSpace(auth.Attributes["user_agent"]); ua != "" { - return ua + raw = ua } } - if auth.Metadata != nil { + if raw == "" && auth.Metadata != nil { if ua, ok := auth.Metadata["user_agent"].(string); ok && strings.TrimSpace(ua) != "" { - return strings.TrimSpace(ua) + raw = strings.TrimSpace(ua) } } } - return misc.AntigravityUserAgent() + return raw } func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int { @@ -2220,6 +2192,10 @@ func antigravityShouldRetrySoftRateLimit(statusCode int, body []byte) bool { return decideAntigravity429(body).kind == antigravity429DecisionSoftRetry } +func antigravityShouldBypassShortCooldown(ctx context.Context, cfg *config.Config) bool { + return cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(cfg) +} + func antigravitySoftRateLimitDelay(attempt int) time.Duration { if attempt < 0 { attempt = 0 @@ -2321,9 +2297,9 @@ var antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string { return []string{base} } return []string{ - antigravityBaseURLProd, antigravityBaseURLDaily, - antigravitySandboxBaseURLDaily, + antigravityBaseURLProd, + // antigravitySandboxBaseURLDaily, } } diff --git a/internal/runtime/executor/antigravity_executor_buildrequest_test.go b/internal/runtime/executor/antigravity_executor_buildrequest_test.go index ed2d79e632..f0711752e4 100644 --- a/internal/runtime/executor/antigravity_executor_buildrequest_test.go +++ b/internal/runtime/executor/antigravity_executor_buildrequest_test.go @@ -6,7 +6,7 @@ import ( "io" 
"testing" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" ) func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) { diff --git a/internal/runtime/executor/antigravity_executor_credits_test.go b/internal/runtime/executor/antigravity_executor_credits_test.go index cf968ac794..e16e64434f 100644 --- a/internal/runtime/executor/antigravity_executor_credits_test.go +++ b/internal/runtime/executor/antigravity_executor_credits_test.go @@ -10,16 +10,17 @@ import ( "testing" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" ) func resetAntigravityCreditsRetryState() { antigravityCreditsFailureByAuth = sync.Map{} - antigravityPreferCreditsByModel = sync.Map{} antigravityShortCooldownByAuth = sync.Map{} + antigravityCreditsBalanceByAuth = sync.Map{} + antigravityCreditsHintRefreshByID = sync.Map{} } func TestClassifyAntigravity429(t *testing.T) { @@ -30,6 +31,43 @@ func TestClassifyAntigravity429(t *testing.T) { } }) + t.Run("standard antigravity rate limit with ui message stays rate limited", func(t *testing.T) { + body := []byte(`{ + "error": { + "code": 429, + "message": "You have exhausted your capacity on this model. 
Your quota will reset after 0s.", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "RATE_LIMIT_EXCEEDED", + "domain": "cloudcode-pa.googleapis.com", + "metadata": { + "model": "claude-opus-4-6-thinking", + "quotaResetDelay": "479.417207ms", + "quotaResetTimeStamp": "2026-04-20T09:19:49Z", + "uiMessage": "true" + } + }, + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "0.479417207s" + } + ] + } + }`) + if got := classifyAntigravity429(body); got != antigravity429RateLimited { + t.Fatalf("classifyAntigravity429() = %q, want %q", got, antigravity429RateLimited) + } + decision := decideAntigravity429(body) + if decision.kind != antigravity429DecisionInstantRetrySameAuth { + t.Fatalf("decideAntigravity429().kind = %q, want %q", decision.kind, antigravity429DecisionInstantRetrySameAuth) + } + if decision.retryAfter == nil { + t.Fatal("decideAntigravity429().retryAfter = nil") + } + }) + t.Run("structured rate limit", func(t *testing.T) { body := []byte(`{ "error": { @@ -67,8 +105,31 @@ func TestClassifyAntigravity429(t *testing.T) { }) } +func TestAntigravityShouldRetryNoCapacity_Standard503(t *testing.T) { + body := []byte(`{ + "error": { + "code": 503, + "message": "No capacity available for model gemini-3.1-flash-image on the server", + "status": "UNAVAILABLE", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "MODEL_CAPACITY_EXHAUSTED", + "domain": "cloudcode-pa.googleapis.com", + "metadata": { + "model": "gemini-3.1-flash-image" + } + } + ] + } + }`) + if !antigravityShouldRetryNoCapacity(http.StatusServiceUnavailable, body) { + t.Fatal("antigravityShouldRetryNoCapacity() = false, want true") + } +} + func TestInjectEnabledCreditTypes(t *testing.T) { - body := []byte(`{"model":"gemini-2.5-flash","request":{}}`) + body := []byte(`{"model":"claude-sonnet-4-6","request":{}}`) got := injectEnabledCreditTypes(body) if got == nil { 
t.Fatal("injectEnabledCreditTypes() returned nil") @@ -82,34 +143,18 @@ func TestInjectEnabledCreditTypes(t *testing.T) { } } -func TestShouldMarkAntigravityCreditsExhausted(t *testing.T) { - t.Run("credit errors are marked", func(t *testing.T) { - for _, body := range [][]byte{ - []byte(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`), - []byte(`{"error":{"message":"minimumCreditAmountForUsage requirement not met"}}`), - } { - if !shouldMarkAntigravityCreditsExhausted(http.StatusForbidden, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body)) - } - } - }) - - t.Run("transient 429 resource exhausted is not marked", func(t *testing.T) { - body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}`) - if shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = true, want false", string(body)) - } - }) - - t.Run("resource exhausted with quota metadata is still marked", func(t *testing.T) { - body := []byte(`{"error":{"code":429,"message":"Resource has been exhausted","status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","metadata":{"quotaResetDelay":"1h","model":"claude-sonnet-4-6"}}]}}`) - if !shouldMarkAntigravityCreditsExhausted(http.StatusTooManyRequests, body, nil) { - t.Fatalf("shouldMarkAntigravityCreditsExhausted(%s) = false, want true", string(body)) - } - }) - - if shouldMarkAntigravityCreditsExhausted(http.StatusServiceUnavailable, []byte(`{"error":{"message":"credits exhausted"}}`), nil) { - t.Fatal("shouldMarkAntigravityCreditsExhausted() = true for 5xx, want false") +func TestParseRetryDelay_HumanReadableDuration(t *testing.T) { + body := []byte(`{"error":{"message":"You have exhausted your capacity on this model. 
Your quota will reset after 1h43m56s."}}`) + retryAfter, err := parseRetryDelay(body) + if err != nil { + t.Fatalf("parseRetryDelay() error = %v", err) + } + if retryAfter == nil { + t.Fatal("parseRetryDelay() returned nil") + } + want := time.Hour + 43*time.Minute + 56*time.Second + if *retryAfter != want { + t.Fatalf("parseRetryDelay() = %v, want %v", *retryAfter, want) } } @@ -147,7 +192,7 @@ func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) { } resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -163,32 +208,23 @@ func TestAntigravityExecute_RetriesTransient429ResourceExhausted(t *testing.T) { } } -func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { +func TestAntigravityExecute_CreditsInjectedWhenConductorRequests(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var ( - mu sync.Mutex - requestBodies []string - ) - + var requestBodies []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) _ = r.Body.Close() - - mu.Lock() - requestBodies = append(requestBodies, string(body)) - reqNum := len(requestBodies) - mu.Unlock() - - if reqNum == 1 { - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) + if r.URL.Path == "/v1internal:loadCodeAssist" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"paidTier":{"id":"tier-1","availableCredits":[{"creditType":"GOOGLE_ONE_AI","creditAmount":"25000","minimumCreditAmountForUsage":"50"}]}}`)) return } + requestBodies = append(requestBodies, string(body)) if 
!strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("second request body missing enabledCreditTypes: %s", string(body)) + t.Fatalf("request body missing enabledCreditTypes: %s", string(body)) } w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) @@ -199,7 +235,7 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-credits-ok", + ID: "auth-credits-conductor", Attributes: map[string]string{ "base_url": server.URL, }, @@ -210,8 +246,11 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { }, } - resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + // Simulate conductor setting credits requested flag in context + ctx := cliproxyauth.WithAntigravityCredits(context.Background()) + + resp, err := exec.Execute(ctx, auth, cliproxyexecutor.Request{ + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -222,21 +261,25 @@ func TestAntigravityExecute_RetriesQuotaExhaustedWithCredits(t *testing.T) { if len(resp.Payload) == 0 { t.Fatal("Execute() returned empty payload") } - - mu.Lock() - defer mu.Unlock() - if len(requestBodies) != 2 { - t.Fatalf("request count = %d, want 2", len(requestBodies)) + if len(requestBodies) != 1 { + t.Fatalf("request count = %d, want 1", len(requestBodies)) } } -func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) { +func TestAntigravityExecute_NoCreditsWithoutConductorFlag(t *testing.T) { resetAntigravityCreditsRetryState() 
t.Cleanup(resetAntigravityCreditsRetryState) - var requestCount int + var requestBodies []string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ + body, _ := io.ReadAll(r.Body) + _ = r.Body.Close() + if r.URL.Path == "/v1internal:loadCodeAssist" { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"paidTier":{"id":"tier-1","availableCredits":[{"creditType":"GOOGLE_ONE_AI","creditAmount":"25000","minimumCreditAmountForUsage":"50"}]}}`)) + return + } + requestBodies = append(requestBodies, string(body)) w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) })) @@ -246,7 +289,7 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-credits-exhausted", + ID: "auth-no-conductor-flag", Attributes: map[string]string{ "base_url": server.URL, }, @@ -256,10 +299,10 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), }, } - recordAntigravityCreditsFailure(auth, time.Now()) + // No conductor credits flag set in context _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", + Model: "claude-sonnet-4-6", Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), }, cliproxyexecutor.Options{ SourceFormat: sdktranslator.FormatAntigravity, @@ -267,224 +310,194 @@ func TestAntigravityExecute_SkipsCreditsRetryWhenAlreadyExhausted(t *testing.T) if err == nil { t.Fatal("Execute() error = nil, want 429") } - sErr, ok := err.(statusErr) - if !ok { - t.Fatalf("Execute() error type = %T, want statusErr", err) - } - if got := sErr.StatusCode(); got != http.StatusTooManyRequests { - t.Fatalf("Execute() status code = %d, 
want %d", got, http.StatusTooManyRequests) + if len(requestBodies) != 1 { + t.Fatalf("request count = %d, want 1", len(requestBodies)) } - if requestCount != 1 { - t.Fatalf("request count = %d, want 1", requestCount) + // Should NOT contain credits since conductor didn't request them + if strings.Contains(requestBodies[0], `"enabledCreditTypes"`) { + t.Fatalf("request should not contain enabledCreditTypes without conductor flag: %s", requestBodies[0]) } } -func TestAntigravityExecute_PrefersCreditsAfterSuccessfulFallback(t *testing.T) { - resetAntigravityCreditsRetryState() - t.Cleanup(resetAntigravityCreditsRetryState) - - var ( - mu sync.Mutex - requestBodies []string - ) +func TestAntigravityAuthHasCredits(t *testing.T) { + t.Run("sufficient balance", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-sufficient"} + antigravityCreditsBalanceByAuth.Store("test-sufficient", antigravityCreditsBalance{ + CreditAmount: 25000, + MinCreditAmount: 50, + Known: true, + }) + if !antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = false, want true") + } + }) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() + t.Run("insufficient balance", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-insufficient"} + antigravityCreditsBalanceByAuth.Store("test-insufficient", antigravityCreditsBalance{ + CreditAmount: 30, + MinCreditAmount: 50, + Known: true, + }) + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = true, want false") + } + }) - mu.Lock() - requestBodies = append(requestBodies, string(body)) - reqNum := len(requestBodies) - mu.Unlock() + t.Run("no balance stored returns true (optimistic)", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-no-balance"} + if 
!antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = false with no balance stored, want true (optimistic default)") + } + }) - switch reqNum { - case 1: - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"10s"}]}}`)) - case 2, 3: - if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("request %d body missing enabledCreditTypes: %s", reqNum, string(body)) - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"OK"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) - default: - t.Fatalf("unexpected request count %d", reqNum) + t.Run("nil auth returns false", func(t *testing.T) { + if antigravityAuthHasCredits(nil) { + t.Fatal("antigravityAuthHasCredits(nil) = true, want false") } - })) - defer server.Close() + }) - exec := NewAntigravityExecutor(&config.Config{ - QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, + t.Run("empty ID returns false", func(t *testing.T) { + auth := &cliproxyauth.Auth{} + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits(empty ID) = true, want false") + } }) - auth := &cliproxyauth.Auth{ - ID: "auth-prefer-credits", - Attributes: map[string]string{ - "base_url": server.URL, - }, - Metadata: map[string]any{ - "access_token": "token", - "project_id": "project-1", - "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), - }, - } - request := cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - } - opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatAntigravity} + t.Run("unknown balance 
returns false", func(t *testing.T) { + resetAntigravityCreditsRetryState() + auth := &cliproxyauth.Auth{ID: "test-unknown"} + antigravityCreditsBalanceByAuth.Store("test-unknown", antigravityCreditsBalance{ + Known: false, + }) + if antigravityAuthHasCredits(auth) { + t.Fatal("antigravityAuthHasCredits() = true for unknown balance, want false") + } + }) +} - if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil { - t.Fatalf("first Execute() error = %v", err) - } - if _, err := exec.Execute(context.Background(), auth, request, opts); err != nil { - t.Fatalf("second Execute() error = %v", err) - } +type roundTripperFunc func(*http.Request) (*http.Response, error) - mu.Lock() - defer mu.Unlock() - if len(requestBodies) != 3 { - t.Fatalf("request count = %d, want 3", len(requestBodies)) - } - if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("first request unexpectedly used credits: %s", requestBodies[0]) - } - if !strings.Contains(requestBodies[1], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("fallback request missing credits: %s", requestBodies[1]) - } - if !strings.Contains(requestBodies[2], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("preferred request missing credits: %s", requestBodies[2]) - } +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) } -func TestAntigravityExecute_PreservesBaseURLFallbackAfterCreditsRetryFailure(t *testing.T) { +func TestEnsureAccessToken_WarmTokenLoadsCreditsHint(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var ( - mu sync.Mutex - firstCount int - secondCount int - ) - - firstServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - - mu.Lock() - firstCount++ - reqNum := firstCount - mu.Unlock() - - switch reqNum { - case 1: - 
w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"QUOTA_EXHAUSTED"}]}}`)) - case 2: - if !strings.Contains(string(body), `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("credits retry missing enabledCreditTypes: %s", string(body)) - } - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte(`{"error":{"message":"permission denied"}}`)) - default: - t.Fatalf("unexpected first server request count %d", reqNum) - } - })) - defer firstServer.Close() - - secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - secondCount++ - mu.Unlock() - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}}`)) - })) - defer secondServer.Close() - exec := NewAntigravityExecutor(&config.Config{ QuotaExceeded: config.QuotaExceeded{AntigravityCredits: true}, }) auth := &cliproxyauth.Auth{ - ID: "auth-baseurl-fallback", - Attributes: map[string]string{ - "base_url": firstServer.URL, - }, + ID: "auth-warm-token-credits", Metadata: map[string]any{ "access_token": "token", - "project_id": "project-1", "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), }, } + ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { + t.Fatalf("unexpected request url %s", req.URL.String()) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: 
io.NopCloser(strings.NewReader(`{"paidTier":{"id":"tier-1","availableCredits":[{"creditType":"GOOGLE_ONE_AI","creditAmount":"25000","minimumCreditAmountForUsage":"50"}]}}`)), + }, nil + })) - originalOrder := antigravityBaseURLFallbackOrder - defer func() { antigravityBaseURLFallbackOrder = originalOrder }() - antigravityBaseURLFallbackOrder = func(auth *cliproxyauth.Auth) []string { - return []string{firstServer.URL, secondServer.URL} - } - - resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FormatAntigravity, - }) + token, updatedAuth, err := exec.ensureAccessToken(ctx, auth) if err != nil { - t.Fatalf("Execute() error = %v", err) + t.Fatalf("ensureAccessToken() error = %v", err) } - if len(resp.Payload) == 0 { - t.Fatal("Execute() returned empty payload") + if token != "token" { + t.Fatalf("ensureAccessToken() token = %q, want %q", token, "token") + } + if updatedAuth != nil { + t.Fatalf("ensureAccessToken() updatedAuth = %v, want nil", updatedAuth) } - if firstCount != 2 { - t.Fatalf("first server request count = %d, want 2", firstCount) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) && !cliproxyauth.HasKnownAntigravityCreditsHint(auth.ID) { + time.Sleep(10 * time.Millisecond) } - if secondCount != 1 { - t.Fatalf("second server request count = %d, want 1", secondCount) + if !cliproxyauth.HasKnownAntigravityCreditsHint(auth.ID) { + t.Fatal("expected credits hint to be populated for warm token auth") + } + hint, ok := cliproxyauth.GetAntigravityCreditsHint(auth.ID) + if !ok { + t.Fatal("expected credits hint lookup to succeed") + } + if !hint.Available { + t.Fatalf("hint.Available = %v, want true", hint.Available) + } + if hint.CreditAmount != 25000 || hint.MinCreditAmount != 50 { + t.Fatalf("hint amounts = (%v, %v), want 
(25000, 50)", hint.CreditAmount, hint.MinCreditAmount) } } -func TestAntigravityExecute_DoesNotDirectInjectCreditsWhenFlagDisabled(t *testing.T) { +func TestUpdateAntigravityCreditsBalance_LoadCodeAssistUserAgent(t *testing.T) { resetAntigravityCreditsRetryState() t.Cleanup(resetAntigravityCreditsRetryState) - var requestBodies []string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - requestBodies = append(requestBodies, string(body)) - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)) - })) - defer server.Close() - - exec := NewAntigravityExecutor(&config.Config{ - QuotaExceeded: config.QuotaExceeded{AntigravityCredits: false}, - }) + exec := NewAntigravityExecutor(&config.Config{}) + const userAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0" auth := &cliproxyauth.Auth{ - ID: "auth-flag-disabled", - Attributes: map[string]string{ - "base_url": server.URL, - }, - Metadata: map[string]any{ - "access_token": "token", - "project_id": "project-1", - "expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339), - }, + ID: "auth-load-code-assist-ua", + Attributes: map[string]string{"user_agent": userAgent}, } - markAntigravityPreferCredits(auth, "gemini-2.5-flash", time.Now(), nil) + ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" { + t.Fatalf("unexpected request url %s", req.URL.String()) + } + if got := req.Header.Get("User-Agent"); got != userAgent { + t.Fatalf("User-Agent = %q, want %q", got, userAgent) + } + if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" { + t.Fatalf("X-Goog-Api-Client = %q, want %q", got, "gl-node/22.21.1") + } + body, _ := 
io.ReadAll(req.Body) + _ = req.Body.Close() + if string(body) != `{"metadata":{"ide_name":"antigravity","ide_type":"ANTIGRAVITY","ide_version":"1.23.2"}}` { + t.Fatalf("loadCodeAssist body = %s", string(body)) + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(`{"paidTier":{"id":"tier-1","availableCredits":[{"creditType":"GOOGLE_ONE_AI","creditAmount":"25000","minimumCreditAmountForUsage":"50"}]}}`)), + }, nil + })) - _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ - Model: "gemini-2.5-flash", - Payload: []byte(`{"request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`), - }, cliproxyexecutor.Options{ - SourceFormat: sdktranslator.FormatAntigravity, - }) - if err == nil { - t.Fatal("Execute() error = nil, want 429") - } - if len(requestBodies) != 1 { - t.Fatalf("request count = %d, want 1", len(requestBodies)) - } - if strings.Contains(requestBodies[0], `"enabledCreditTypes":["GOOGLE_ONE_AI"]`) { - t.Fatalf("request unexpectedly used enabledCreditTypes with flag disabled: %s", requestBodies[0]) + exec.updateAntigravityCreditsBalance(ctx, auth, "token") +} + +func TestParseMetaFloat(t *testing.T) { + tests := []struct { + name string + value any + wantVal float64 + wantOK bool + }{ + {"string", "25000", 25000, true}, + {"float64", float64(100), 100, true}, + {"int", int(50), 50, true}, + {"int64", int64(75), 75, true}, + {"empty string", "", 0, false}, + {"invalid string", "abc", 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + meta := map[string]any{"key": tt.value} + got, ok := parseMetaFloat(meta, "key") + if ok != tt.wantOK { + t.Fatalf("parseMetaFloat() ok = %v, want %v", ok, tt.wantOK) + } + if ok && got != tt.wantVal { + t.Fatalf("parseMetaFloat() = %f, want %f", got, tt.wantVal) + } + }) } } diff --git a/internal/runtime/executor/antigravity_executor_signature_test.go 
b/internal/runtime/executor/antigravity_executor_signature_test.go index 226daf5c67..7d84bfe890 100644 --- a/internal/runtime/executor/antigravity_executor_signature_test.go +++ b/internal/runtime/executor/antigravity_executor_signature_test.go @@ -10,10 +10,10 @@ import ( "testing" "time" - "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/router-for-me/CLIProxyAPI/v7/internal/cache" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" ) func testGeminiSignaturePayload() string { diff --git a/internal/runtime/executor/bt_executor.go b/internal/runtime/executor/bt_executor.go new file mode 100644 index 0000000000..737cd0f267 --- /dev/null +++ b/internal/runtime/executor/bt_executor.go @@ -0,0 +1,429 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + btauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/bt" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" +) + +type BTExecutor struct { + cfg 
*config.Config +} + +func NewBTExecutor(cfg *config.Config) *BTExecutor { + return &BTExecutor{cfg: cfg} +} + +func (e *BTExecutor) Identifier() string { return "bt" } + +func btCredentials(auth *cliproxyauth.Auth) (uid, accessKey, serverID string) { + uid = metaStringValue(auth.Metadata, "uid") + accessKey = metaStringValue(auth.Metadata, "access_key") + serverID = metaStringValue(auth.Metadata, "serverid") + return +} + +func (e *BTExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if req == nil { + return nil + } + uid, accessKey, serverID := btCredentials(auth) + if uid == "" || accessKey == "" { + return fmt.Errorf("bt executor: missing credentials") + } + req.Header.Set("uid", uid) + req.Header.Set("access-key", accessKey) + req.Header.Set("serverid", serverID) + req.Header.Set("appid", btauth.AppID) + req.Header.Set("Content-Type", "application/json") + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(req, attrs) + return nil +} + +func (e *BTExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("bt executor: request is nil") + } + if ctx == nil { + ctx = req.Context() + } + httpReq := req.WithContext(ctx) + if err := e.PrepareRequest(httpReq, auth); err != nil { + return nil, err + } + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + return httpClient.Do(httpReq) +} + +func (e *BTExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.TrackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + endpoint := "/chat/completions" + if opts.Alt == 
"responses/compact" { + to = sdktranslator.FromString("openai-response") + endpoint = "/responses/compact" + } + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := originalPayloadSource + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream) + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream) + requestedModel := helps.PayloadRequestedModel(opts, req.Model) + translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, "") + if opts.Alt == "responses/compact" { + if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil { + translated = updated + } + } + + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return resp, err + } + + uid, accessKey, serverID := btCredentials(auth) + if uid == "" || accessKey == "" { + err = statusErr{code: http.StatusUnauthorized, msg: "bt: missing credentials in auth metadata"} + return resp, err + } + + baseURL := btauth.CloudURL + "/plugin_api/chat/openai/v1" + upstreamURL := strings.TrimSuffix(baseURL, "/") + endpoint + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(translated)) + if err != nil { + return resp, err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("uid", uid) + httpReq.Header.Set("access-key", accessKey) + httpReq.Header.Set("serverid", serverID) + httpReq.Header.Set("appid", btauth.AppID) + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + httpReq.Header.Set("User-Agent", "cli-proxy-bt") + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + 
authType, authValue = auth.AccountInfo() + } + helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + URL: upstreamURL, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translated, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("bt executor: close response body error: %v", errClose) + } + }() + helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + helps.AppendAPIResponseChunk(ctx, e.cfg, b) + helps.LogWithRequestID(ctx).Debugf("bt executor: request error, status: %d, message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + body, err := io.ReadAll(httpResp.Body) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + helps.AppendAPIResponseChunk(ctx, e.cfg, body) + reporter.Publish(ctx, helps.ParseOpenAIUsage(body)) + reporter.EnsurePublished(ctx) + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m) + resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} + return resp, nil +} + +func (e *BTExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + + reporter := helps.NewUsageReporter(ctx, e.Identifier(), 
baseModel, auth) + defer reporter.TrackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := originalPayloadSource + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) + requestedModel := helps.PayloadRequestedModel(opts, req.Model) + translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, "") + + translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return nil, err + } + + translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true) + + uid, accessKey, serverID := btCredentials(auth) + if uid == "" || accessKey == "" { + return nil, fmt.Errorf("bt: missing credentials in auth metadata") + } + + baseURL := btauth.CloudURL + "/plugin_api/chat/openai/v1" + upstreamURL := strings.TrimSuffix(baseURL, "/") + "/chat/completions" + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(translated)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("uid", uid) + httpReq.Header.Set("access-key", accessKey) + httpReq.Header.Set("serverid", serverID) + httpReq.Header.Set("appid", btauth.AppID) + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + httpReq.Header.Set("User-Agent", "cli-proxy-bt") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + 
authType, authValue = auth.AccountInfo() + } + helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ + URL: upstreamURL, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translated, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + helps.AppendAPIResponseChunk(ctx, e.cfg, b) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("bt executor: close response body error: %v", errClose) + } + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return nil, err + } + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("bt executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 52_428_800) + var param any + for scanner.Scan() { + line := scanner.Bytes() + helps.AppendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := helps.ParseOpenAIStreamUsage(line); ok { + reporter.Publish(ctx, detail) + } + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} + } + } + if errScan := scanner.Err(); errScan != nil { + helps.RecordAPIResponseError(ctx, e.cfg, errScan) + reporter.PublishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: 
errScan} + } else { + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} + } + } + reporter.EnsurePublished(ctx) + }() + return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil +} + +func (e *BTExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + baseModel := thinking.ParseSuffix(req.Model).ModelName + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + + translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return cliproxyexecutor.Response{}, err + } + + enc, err := helps.TokenizerForModel(baseModel) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("bt executor: tokenizer init failed: %w", err) + } + + count, err := helps.CountOpenAIChatTokens(enc, translated) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("bt executor: token counting failed: %w", err) + } + + usageJSON := helps.BuildOpenAIUsageJSON(count) + translatedUsage := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON) + return cliproxyexecutor.Response{Payload: translatedUsage}, nil +} + +func (e *BTExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + log.Debugf("bt executor: refresh called") + _ = ctx + return auth, nil +} + +func FetchBTModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo { + if auth == nil { + return nil + } + + uid, accessKey, serverID := btCredentials(auth) + if uid == "" || accessKey == "" { + log.Debug("bt: missing credentials, skipping dynamic model fetch") + 
return nil + } + + httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 15*time.Second) + modelsURL := btauth.CloudURL + "/plugin_api/chat/openai/v1/models" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsURL, nil) + if err != nil { + log.Warnf("bt: failed to create model fetch request: %v", err) + return nil + } + req.Header.Set("uid", uid) + req.Header.Set("access-key", accessKey) + req.Header.Set("serverid", serverID) + req.Header.Set("appid", btauth.AppID) + req.Header.Set("User-Agent", "cli-proxy-bt") + util.ApplyCustomHeadersFromAttrs(req, auth.Attributes) + + resp, err := httpClient.Do(req) + if err != nil { + log.Warnf("bt: fetch models failed: %v", err) + return nil + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Debugf("bt: close model fetch response error: %v", err) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Warnf("bt: failed to read models response: %v", err) + return nil + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("bt: fetch models failed: status %d, body: %s", resp.StatusCode, string(body)) + return nil + } + + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(body, &result); err != nil { + log.Warnf("bt: failed to parse models response: %v", err) + return nil + } + + now := time.Now().Unix() + models := make([]*registry.ModelInfo, 0, len(result.Data)) + seen := make(map[string]struct{}) + for _, m := range result.Data { + id := strings.TrimSpace(m.ID) + if id == "" { + continue + } + key := strings.ToLower(id) + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + models = append(models, ®istry.ModelInfo{ + ID: id, + Object: "model", + Created: now, + OwnedBy: "bt", + Type: "bt", + DisplayName: id, + UserDefined: false, + Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}, + }) + } + return models +} diff --git 
a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 0311827bae..eb17864d6e 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -11,23 +11,22 @@ import ( "fmt" "io" "net/http" - "net/textproto" "strings" "time" "github.com/andybalholm/brotli" "github.com/google/uuid" "github.com/klauspost/compress/zstd" - claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" - "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + claudeauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/claude" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v7/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -66,14 +65,13 @@ var oauthToolRenameMap = map[string]string{ "notebookedit": "NotebookEdit", } -// oauthToolRenameReverseMap is the 
inverse of oauthToolRenameMap for response decoding. -var oauthToolRenameReverseMap = func() map[string]string { - m := make(map[string]string, len(oauthToolRenameMap)) - for k, v := range oauthToolRenameMap { - m[v] = k - } - return m -}() +// The reverse map is now computed per-request in remapOAuthToolNames so that +// only names the client actually caused us to rewrite are restored on the +// response. A global reverse map — as used previously — corrupted responses +// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase +// alongside `glob` lowercase; the request flagged renames via `glob→Glob`, +// then the global reverse map incorrectly rewrote every `Bash` in the +// response to `bash`, causing Amp to reject the tool_use as unknown). // oauthToolsToRemove lists tool names that must be stripped from OAuth requests // even after remapping. Currently empty — all tools are mapped instead of removed. @@ -165,7 +163,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey) requestedModel := helps.PayloadRequestedModel(opts, req.Model) - body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) + requestPath := helps.PayloadRequestPath(opts) + body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath) body = ensureModelMaxTokens(body, baseModel) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) @@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r bodyForTranslation := body bodyForUpstream := body oauthToken := isClaudeOAuthToken(apiKey) - oauthToolNamesRemapped := false - if oauthToken && !auth.ToolPrefixDisabled() { - bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap third-party tool names to Claude Code equivalents and 
remove - // tools without official counterparts. This prevents Anthropic from - // fingerprinting the request as third-party via tool naming patterns. + var oauthToolNamesReverseMap map[string]string if oauthToken { - bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream) + bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled()) } // Enable cch signing by default for OAuth tokens (not just experimental flag). // Claude Code always computes cch; missing or invalid cch is a detectable fingerprint. @@ -285,6 +278,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } helps.AppendAPIResponseChunk(ctx, e.cfg, data) if stream { + if errValidate := validateClaudeStreamingResponse(data); errValidate != nil { + helps.RecordAPIResponseError(ctx, e.cfg, errValidate) + return resp, errValidate + } lines := bytes.Split(data, []byte("\n")) for _, line := range lines { if detail, ok := helps.ParseClaudeStreamUsage(line); ok { @@ -294,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } else { reporter.Publish(ctx, helps.ParseClaudeUsage(data)) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix) - } - // Reverse the OAuth tool name remap so the downstream client sees original names. 
- if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - data = reverseRemapOAuthToolNames(data) - } + data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) var param any out := sdktranslator.TranslateNonStream( ctx, @@ -350,7 +341,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey) requestedModel := helps.PayloadRequestedModel(opts, req.Model) - body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) + requestPath := helps.PayloadRequestPath(opts) + body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath) body = ensureModelMaxTokens(body, baseModel) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) @@ -374,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A bodyForTranslation := body bodyForUpstream := body oauthToken := isClaudeOAuthToken(apiKey) - oauthToolNamesRemapped := false - if oauthToken && !auth.ToolPrefixDisabled() { - bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap third-party tool names to Claude Code equivalents and remove - // tools without official counterparts. This prevents Anthropic from - // fingerprinting the request as third-party via tool naming patterns. + var oauthToolNamesReverseMap map[string]string if oauthToken { - bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream) + bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled()) } // Enable cch signing by default for OAuth tokens (not just experimental flag). 
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) { @@ -473,22 +459,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := helps.ParseClaudeStreamUsage(line); ok { reporter.Publish(ctx, detail) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) - } - if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - line = reverseRemapOAuthToolNamesFromStreamLine(line) - } + line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) // Forward the line as-is to preserve SSE format cloned := make([]byte, len(line)+1) copy(cloned, line) cloned[len(line)] = '\n' - out <- cliproxyexecutor.StreamChunk{Payload: cloned} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: cloned}: + case <-ctx.Done(): + return + } } if errScan := scanner.Err(); errScan != nil { helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: errScan} + reporter.PublishFailure(ctx, errScan) + select { + case out <- cliproxyexecutor.StreamChunk{Err: errScan}: + case <-ctx.Done(): + } } return } @@ -503,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A if detail, ok := helps.ParseClaudeStreamUsage(line); ok { reporter.Publish(ctx, detail) } - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) - } - if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped { - line = reverseRemapOAuthToolNamesFromStreamLine(line) - } + line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap) chunks := sdktranslator.TranslateStream( ctx, to, @@ -520,18 +503,83 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A ¶m, ) for i 
:= range chunks { - out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} + select { + case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}: + case <-ctx.Done(): + return + } } } if errScan := scanner.Err(); errScan != nil { helps.RecordAPIResponseError(ctx, e.cfg, errScan) - reporter.PublishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: errScan} + reporter.PublishFailure(ctx, errScan) + select { + case out <- cliproxyexecutor.StreamChunk{Err: errScan}: + case <-ctx.Done(): + } } }() return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil } +func validateClaudeStreamingResponse(data []byte) error { + scanner := bufio.NewScanner(bytes.NewReader(data)) + scanner.Buffer(nil, 52_428_800) + + hasData := false + hasMessageStart := false + hasMessageDelta := false + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) { + continue + } + payload := bytes.TrimSpace(line[len("data:"):]) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + continue + } + hasData = true + if !gjson.ValidBytes(payload) { + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"} + } + + root := gjson.ParseBytes(payload) + switch root.Get("type").String() { + case "error": + message := strings.TrimSpace(root.Get("error.message").String()) + if message == "" { + message = strings.TrimSpace(root.Get("error.type").String()) + } + if message == "" { + message = "unknown upstream error" + } + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message} + case "message_start": + message := root.Get("message") + if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" { + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"} + } + 
hasMessageStart = true + case "message_delta": + hasMessageDelta = true + } + } + if errScan := scanner.Err(); errScan != nil { + return errScan + } + if !hasData { + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"} + } + if !hasMessageStart { + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"} + } + if !hasMessageDelta { + return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"} + } + return nil +} + func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { baseModel := thinking.ParseSuffix(req.Model).ModelName @@ -558,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut // Extract betas from body and convert to header (for count_tokens too) var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) - if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { - body = applyClaudeToolPrefix(body, claudeToolPrefix) - } - // Remap tool names for OAuth token requests to avoid third-party fingerprinting. 
if isClaudeOAuthToken(apiKey) { - body, _ = remapOAuthToolNames(body) + body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled()) } url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL) @@ -647,6 +691,9 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("claude executor: refresh called") + if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled { + return refreshed, err + } if auth == nil { return nil, fmt.Errorf("claude executor: auth is nil") } @@ -660,7 +707,7 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) ( return auth, nil } svc := claudeauth.NewClaudeAuthWithProxyURL(e.cfg, auth.ProxyURL) - td, err := svc.RefreshTokens(ctx, refreshToken) + td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3) if err != nil { return nil, err } @@ -911,15 +958,8 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, baseBetas += ",interleaved-thinking-2025-05-14" } - hasClaude1MHeader := false - if ginHeaders != nil { - if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok { - hasClaude1MHeader = true - } - } - // Merge extra betas from request body and request flags. 
- if len(extraBetas) > 0 || hasClaude1MHeader { + if len(extraBetas) > 0 { existingSet := make(map[string]bool) for _, b := range strings.Split(baseBetas, ",") { betaName := strings.TrimSpace(b) @@ -934,9 +974,6 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, existingSet[beta] = true } } - if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] { - baseBetas += ",context-1m-2025-08-07" - } } r.Header.Set("Anthropic-Beta", baseBetas) @@ -1013,6 +1050,36 @@ func isClaudeOAuthToken(apiKey string) bool { return strings.Contains(apiKey, "sk-ant-oat") } +// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name +// transforms in the same order across request paths. Remap runs before prefixing +// so any future non-empty prefix still composes correctly with the per-request +// reverse map. +func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) { + body, reverseMap := remapOAuthToolNames(body) + if !prefixDisabled { + body = applyClaudeToolPrefix(body, prefix) + } + return body, reverseMap +} + +// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name +// transforms for non-stream responses in reverse order. +func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte { + if !prefixDisabled { + body = stripClaudeToolPrefixFromResponse(body, prefix) + } + return reverseRemapOAuthToolNames(body, reverseMap) +} + +// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name +// transforms for SSE lines in reverse order. 
+func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte { + if !prefixDisabled { + line = stripClaudeToolPrefixFromStreamLine(line, prefix) + } + return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap) +} + // remapOAuthToolNames renames third-party tool names to Claude Code equivalents // and removes tools without an official counterpart. This prevents Anthropic from // fingerprinting the request as a third-party client via tool naming patterns. @@ -1020,8 +1087,25 @@ func isClaudeOAuthToken(apiKey string) bool { // It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference // references in messages. Removed tools' corresponding tool_result blocks are preserved // (they just become orphaned, which is safe for Claude). -func remapOAuthToolNames(body []byte) ([]byte, bool) { - renamed := false +// +// The returned map is keyed on the upstream (TitleCase) name and maps to the +// client-supplied original name. Callers MUST pass this map to the reverse +// functions so only names the client actually caused us to rewrite are restored +// on the response. A global reverse map (the previous implementation) incorrectly +// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`) +// when any OTHER tool in the same request triggered a forward rename (e.g. +// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash` +// regardless of what the client originally sent. +func remapOAuthToolNames(body []byte) ([]byte, map[string]string) { + reverseMap := make(map[string]string, len(oauthToolRenameMap)) + recordRename := func(original, renamed string) { + // Preserve the first-seen original name if the same upstream name is + // produced from multiple call sites; they all map back identically. + if _, exists := reverseMap[renamed]; !exists { + reverseMap[renamed] = original + } + } + // 1. 
Rewrite tools array in a single pass (if present). // IMPORTANT: do not mutate names first and then rebuild from an older gjson // snapshot. gjson results are snapshots of the original bytes; rebuilding from a @@ -1054,7 +1138,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { updatedTool, err := sjson.Set(toolJSON, "name", newName) if err == nil { toolJSON = updatedTool - renamed = true + recordRename(name, newName) } } @@ -1079,7 +1163,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { body, _ = sjson.DeleteBytes(body, "tool_choice") } else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName { body, _ = sjson.SetBytes(body, "tool_choice.name", newName) - renamed = true + recordRename(tcName, newName) } } @@ -1099,14 +1183,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { if newName, ok := oauthToolRenameMap[name]; ok && newName != name { path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) body, _ = sjson.SetBytes(body, path, newName) - renamed = true + recordRename(name, newName) } case "tool_reference": toolName := part.Get("tool_name").String() if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName { path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int()) body, _ = sjson.SetBytes(body, path, newName) - renamed = true + recordRename(toolName, newName) } case "tool_result": // Handle nested tool_reference blocks inside tool_result.content[] @@ -1120,7 +1204,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName { nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int()) body, _ = sjson.SetBytes(body, nestedPath, newName) - renamed = true + recordRename(nestedToolName, newName) } } return true @@ -1133,13 +1217,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) { 
}) } - return body, renamed + return body, reverseMap } -// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses. -// It maps Claude Code TitleCase names back to the original lowercase names so the -// downstream client receives tool names it recognizes. -func reverseRemapOAuthToolNames(body []byte) []byte { +// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses +// using the per-request map produced by remapOAuthToolNames. Names the client sent +// that were NOT forward-renamed are passed through unchanged. +func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte { + if len(reverseMap) == 0 { + return body + } content := gjson.GetBytes(body, "content") if !content.Exists() || !content.IsArray() { return body @@ -1149,13 +1236,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte { switch partType { case "tool_use": name := part.Get("name").String() - if origName, ok := oauthToolRenameReverseMap[name]; ok { + if origName, ok := reverseMap[name]; ok { path := fmt.Sprintf("content.%d.name", index.Int()) body, _ = sjson.SetBytes(body, path, origName) } case "tool_reference": toolName := part.Get("tool_name").String() - if origName, ok := oauthToolRenameReverseMap[toolName]; ok { + if origName, ok := reverseMap[toolName]; ok { path := fmt.Sprintf("content.%d.tool_name", index.Int()) body, _ = sjson.SetBytes(body, path, origName) } @@ -1165,8 +1252,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte { return body } -// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines. -func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { +// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE +// stream lines, using the per-request reverseMap produced by remapOAuthToolNames. 
+func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte { + if len(reverseMap) == 0 { + return line + } payload := helps.JSONPayload(line) if len(payload) == 0 || !gjson.ValidBytes(payload) { return line @@ -1184,7 +1275,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { switch blockType { case "tool_use": name := contentBlock.Get("name").String() - if origName, ok := oauthToolRenameReverseMap[name]; ok { + if origName, ok := reverseMap[name]; ok { updated, err = sjson.SetBytes(payload, "content_block.name", origName) if err != nil { return line @@ -1194,7 +1285,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte { } case "tool_reference": toolName := contentBlock.Get("tool_name").String() - if origName, ok := oauthToolRenameReverseMap[toolName]; ok { + if origName, ok := reverseMap[toolName]; ok { updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName) if err != nil { return line diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index f456064dc6..f5bca55ab7 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -17,12 +17,12 @@ import ( "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" xxHash64 "github.com/pierrec/xxHash/xxHash64" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" - cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + "github.com/router-for-me/CLIProxyAPI/v7/internal/registry" + 
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) { } } +func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) { + _, err := executeOpenAIChatCompletionThroughClaude(t, "") + if err == nil { + t.Fatal("Execute error = nil, want empty stream error") + } + assertStatusErr(t, err, http.StatusBadGateway) + if !strings.Contains(err.Error(), "empty stream response") { + t.Fatalf("Execute error = %q, want empty stream response", err.Error()) + } +} + +func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) { + body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n" + _, err := executeOpenAIChatCompletionThroughClaude(t, body) + if err == nil { + t.Fatal("Execute error = nil, want upstream error event") + } + assertStatusErr(t, err, http.StatusBadGateway) + if !strings.Contains(err.Error(), "upstream overloaded") { + t.Fatalf("Execute error = %q, want upstream overloaded", err.Error()) + } +} + +func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) { + body := strings.Join([]string{ + `data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`, + `data: {"type":"message_stop"}`, + ``, + }, "\n") + + _, err := executeOpenAIChatCompletionThroughClaude(t, body) + if err == nil { + t.Fatal("Execute error = nil, want incomplete stream error") + } + assertStatusErr(t, err, http.StatusBadGateway) + if !strings.Contains(err.Error(), "ended before message completion") { + t.Fatalf("Execute error = %q, want 
incomplete stream error", err.Error()) + } +} + +func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) { + body := strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`, + `event: content_block_delta`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n") + + resp, err := executeOpenAIChatCompletionThroughClaude(t, body) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" { + t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload)) + } + if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" { + t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got) + } + if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" { + t.Fatalf("response content = %q, want ok", got) + } + if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 { + t.Fatalf("usage.total_tokens = %d, want 3", got) + } +} + +func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte(upstreamBody)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := 
[]byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`) + + return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + }) +} + +func assertStatusErr(t *testing.T, err error, want int) { + t.Helper() + + status, ok := err.(interface{ StatusCode() int }) + if !ok { + t.Fatalf("error %T does not expose StatusCode", err) + } + if got := status.StatusCode(); got != want { + t.Fatalf("StatusCode() = %d, want %d", got, want) + } +} + func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) { input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`) out := stripClaudeToolPrefixFromResponse(input, "proxy_") @@ -1714,7 +1821,27 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity } } -// Test case 1: String system prompt is preserved and converted to a content block +func expectedClaudeCodeStaticPrompt() string { + return strings.Join([]string{ + helps.ClaudeCodeIntro, + helps.ClaudeCodeSystem, + helps.ClaudeCodeDoingTasks, + helps.ClaudeCodeToneAndStyle, + helps.ClaudeCodeOutputEfficiency, + }, "\n\n") +} + +func expectedForwardedSystemReminder(text string) string { + return fmt.Sprintf(`You can close this window and return to the terminal.
` + + `Credential captured, syncing. Please return to the command line.
` + + `You can close this window and return to the terminal.
` + + `
+ {revealed ? value : maskKey(value)}
+
+
+
+ {revealed ? value : maskKey(value)}
+
+
+ {message}
++ Simple API key strings for authentication +
+ ++ Gemini API key entries with optional prefix and proxy +
+ ++ Claude API key entries with model aliases +
+ ++ Codex API key entries with WebSocket support +
+ ++ Vertex AI API key entries with service account import +
++ OpenAI-compatible provider entries +
+ ++ No model mappings configured +
+ ) : ( ++ No upstream API keys configured +
+ ) : ( ++ No auth files found. Upload an auth file to get started. +
++ Request logging is active. Individual request logs can be + downloaded by their request ID. +
+
+ {JSON.stringify(requestLog, null, 2)}
+
+ + Model aliases map model names to alternative names for routing. +
+ ++ No model aliases configured. +
++ Manage models excluded from provider routing. +
+ ++ No excluded models configured. +
++ No OAuth providers available. +
+