From 69ba18d3922d4ecae2ce01130ea49320065e8c40 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Fri, 24 Apr 2026 01:24:14 +0800 Subject: [PATCH 01/67] fix(image): only price image model use N ratio --- relay/image_handler.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/relay/image_handler.go b/relay/image_handler.go index a4fee7d9e0..e986dd897e 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -122,8 +122,10 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type // calculation (both price-based and ratio-based paths). // Adaptors may have already set a more accurate count from the // upstream response; only set the default when they haven't. - if _, hasN := info.PriceData.OtherRatios["n"]; !hasN { - info.PriceData.AddOtherRatio("n", float64(imageN)) + if info.PriceData.UsePrice { // only price model use N ratio + if _, hasN := info.PriceData.OtherRatios["n"]; !hasN { + info.PriceData.AddOtherRatio("n", float64(imageN)) + } } if usage.(*dto.Usage).TotalTokens == 0 { From df6d8628951894fd741a723fc6ae338ee034a7a4 Mon Sep 17 00:00:00 2001 From: yesone Date: Fri, 24 Apr 2026 09:00:21 +0800 Subject: [PATCH 02/67] fix: correct gpt-5.5 completion ratio --- setting/ratio_setting/model_ratio.go | 3 +++ setting/ratio_setting/model_ratio_test.go | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 setting/ratio_setting/model_ratio_test.go diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 7556fd9482..42040d97df 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -515,6 +515,9 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { } // gpt-5 匹配 if strings.HasPrefix(name, "gpt-5") { + if strings.HasPrefix(name, "gpt-5.5") { + return 6, true + } if strings.HasPrefix(name, "gpt-5.4") { if strings.HasPrefix(name, "gpt-5.4-nano") { return 6.25, true diff --git a/setting/ratio_setting/model_ratio_test.go b/setting/ratio_setting/model_ratio_test.go new file mode 100644 index 0000000000..0b7912f893 --- /dev/null +++ b/setting/ratio_setting/model_ratio_test.go @@ -0,0 +1,22 @@ +package ratio_setting + +import "testing" + +func TestGetCompletionRatioInfoGPT55UsesOfficialOutputMultiplier(t *testing.T) { + info := GetCompletionRatioInfo("gpt-5.5") + + if info.Ratio != 6 { + t.Fatalf("gpt-5.5 completion ratio = %v, want 6", info.Ratio) + } + if !info.Locked { + t.Fatal("gpt-5.5 completion ratio should be locked to the official multiplier") + } +} + +func TestGetCompletionRatioGPT55DatedVariant(t *testing.T) { + got := GetCompletionRatio("gpt-5.5-2026-04-24") + + if got != 6 { + t.Fatalf("gpt-5.5 dated variant completion ratio = %v, want 6", got) + } +} From 63ce2db988887dc760edba50c236b70d34e998ca Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Fri, 24 Apr 2026 13:47:52 +0800 Subject: [PATCH 03/67] fix: model pricing use correct display type --- .../modal/components/DynamicPricingBreakdown.jsx | 9 ++++----- web/src/helpers/utils.jsx | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx b/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx index fd2be3f35f..794627dd5f 100644 --- a/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx +++ b/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx @@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Avatar, Tag, Table, Typography } from '@douyinfe/semi-ui'; import { IconPriceTag } from '@douyinfe/semi-icons'; -import { parseTiersFromExpr } from '../../../../../helpers'; +import { parseTiersFromExpr, getCurrencyConfig } from '../../../../../helpers'; import { BILLING_VARS } from '../../../../../constants'; import { splitBillingExprAndRequestRules, @@ -36,8 +36,6 @@ import { const { Text } = Typography; -const PRICE_SUFFIX = '$/1M tokens'; - const VAR_LABELS = { p: '输入', c: '输出' }; const OP_LABELS = { '<': '<', '<=': '≤', '>': '>', '>=': '≥' }; const TIME_FUNC_LABELS = { hour: '小时', minute: '分钟', weekday: '星期', month: '月份', day: '日期' }; @@ -89,6 +87,7 @@ function describeGroup(group, t) { } export default function DynamicPricingBreakdown({ billingExpr, t }) { + const { symbol, rate } = getCurrencyConfig(); const { billingExpr: baseExpr, requestRuleExpr: ruleExpr } = splitBillingExprAndRequestRules(billingExpr || ''); @@ -132,9 +131,9 @@ export default function DynamicPricingBreakdown({ billingExpr, t }) { ...priceFields .filter(([field]) => hasTiers && tiers.some((tier) => tier[field] > 0)) .map(([field, label]) => ({ - title: `${t(label)} (${PRICE_SUFFIX})`, + title: `${t(label)} (${symbol}/1M tokens)`, dataIndex: field, - render: (v) => v > 0 ? ${v.toFixed(4)} : '-', + render: (v) => v > 0 ? {`${symbol}${(v * rate).toFixed(4)}`} : '-', })), ]; diff --git a/web/src/helpers/utils.jsx b/web/src/helpers/utils.jsx index f73df714b0..a4af4f2431 100644 --- a/web/src/helpers/utils.jsx +++ b/web/src/helpers/utils.jsx @@ -900,6 +900,20 @@ export const getModelPriceItems = ( export const formatDynamicPriceSummary = (billingExpr, t, groupRatio = 1) => { if (!billingExpr) return {t('动态计费')}; + const quotaDisplayType = localStorage.getItem('quota_display_type') || 'USD'; + let symbol = '$'; + let rate = 1; + try { + const s = JSON.parse(localStorage.getItem('status') || '{}'); + if (quotaDisplayType === 'CNY') { + symbol = '¥'; + rate = s?.usd_exchange_rate || 7; + } else if (quotaDisplayType === 'CUSTOM') { + symbol = s?.custom_currency_symbol || '¤'; + rate = s?.custom_currency_exchange_rate || 1; + } + } catch (e) {} + const gr = groupRatio || 1; const exprBody = billingExpr.replace(/^v\d+:/, ''); const tierMatches = exprBody.match(/tier\(/g) || []; @@ -933,7 +947,7 @@ export const formatDynamicPriceSummary = (billingExpr, t, groupRatio = 1) => { {varLabels.map(([key, label]) => key in varCoeffs ? ( - {t(label)} ${(varCoeffs[key] * gr).toFixed(4)}{unitSuffix} + {`${t(label)} ${symbol}${(varCoeffs[key] * gr * rate).toFixed(4)}${unitSuffix}`} ) : null, )} From e3d64cb76defa42193a19edb9a4e726ba8b82462 Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Fri, 24 Apr 2026 16:24:36 +0800 Subject: [PATCH 04/67] Merge pull request #4431 from yyhhyyyyyy/fix/tiered-billing-model-list fix: include tiered billing models in model listing --- controller/model.go | 7 +- controller/model_list_test.go | 242 ++++++++++++++++++++++++++++++++++ model/option.go | 3 + model/pricing.go | 25 +++- relay/helper/price.go | 14 +- 5 files changed, 272 insertions(+), 19 deletions(-) create mode 100644 controller/model_list_test.go diff --git a/controller/model.go b/controller/model.go index aa6c6e2b9d..f22379555a 100644 --- a/controller/model.go +++ b/controller/model.go @@ -17,7 +17,6 @@ import ( relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/operation_setting" - "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" "github.com/samber/lo" @@ -134,8 +133,7 @@ func ListModels(c *gin.Context, modelType int) { } for allowModel, _ := range tokenModelLimit { if !acceptUnsetRatioModel { - _, _, exist := ratio_setting.GetModelRatioOrPrice(allowModel) - if !exist { + if !model.HasModelBillingConfig(allowModel) { continue } } @@ -182,8 +180,7 @@ func ListModels(c *gin.Context, modelType int) { } for _, modelName := range models { if !acceptUnsetRatioModel { - _, _, exist := ratio_setting.GetModelRatioOrPrice(modelName) - if !exist { + if !model.HasModelBillingConfig(modelName) { continue } } diff --git a/controller/model_list_test.go b/controller/model_list_test.go new file mode 100644 index 0000000000..97d27cae5c --- /dev/null +++ b/controller/model_list_test.go @@ -0,0 +1,242 @@ +package controller + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/config" + "github.com/QuantumNous/new-api/setting/operation_setting" + "github.com/gin-gonic/gin" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +type listModelsResponse struct { + Success bool `json:"success"` + Data []dto.OpenAIModels `json:"data"` + Object string `json:"object"` +} + +func setupModelListControllerTestDB(t *testing.T) *gorm.DB { + t.Helper() + + initModelListColumnNames(t) + + gin.SetMode(gin.TestMode) + common.UsingSQLite = true + common.UsingMySQL = false + common.UsingPostgreSQL = false + common.RedisEnabled = false + + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + require.NoError(t, err) + model.DB = db + model.LOG_DB = db + + require.NoError(t, db.AutoMigrate(&model.User{}, &model.Channel{}, &model.Ability{}, &model.Model{}, &model.Vendor{})) + + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) + + return db +} + +func initModelListColumnNames(t *testing.T) { + t.Helper() + + originalIsMasterNode := common.IsMasterNode + originalSQLitePath := common.SQLitePath + originalUsingSQLite := common.UsingSQLite + originalUsingMySQL := common.UsingMySQL + originalUsingPostgreSQL := common.UsingPostgreSQL + originalSQLDSN, hadSQLDSN := os.LookupEnv("SQL_DSN") + defer func() { + common.IsMasterNode = originalIsMasterNode + common.SQLitePath = originalSQLitePath + common.UsingSQLite = originalUsingSQLite + common.UsingMySQL = originalUsingMySQL + common.UsingPostgreSQL = originalUsingPostgreSQL + if hadSQLDSN { + require.NoError(t, os.Setenv("SQL_DSN", originalSQLDSN)) + } else { + require.NoError(t, os.Unsetenv("SQL_DSN")) + } + }() + + common.IsMasterNode = false + common.SQLitePath = fmt.Sprintf("file:%s_init?mode=memory&cache=shared", strings.ReplaceAll(t.Name(), "/", "_")) + common.UsingSQLite = false + common.UsingMySQL = false + common.UsingPostgreSQL = false + require.NoError(t, os.Setenv("SQL_DSN", "local")) + + require.NoError(t, model.InitDB()) + if model.DB != nil { + sqlDB, err := model.DB.DB() + if err == nil { + _ = sqlDB.Close() + } + } +} + +func withTieredBillingConfig(t *testing.T, modes map[string]string, exprs map[string]string) { + t.Helper() + + saved := map[string]string{} + require.NoError(t, config.GlobalConfig.SaveToDB(func(key, value string) error { + if strings.HasPrefix(key, "billing_setting.") { + saved[key] = value + } + return nil + })) + t.Cleanup(func() { + require.NoError(t, config.GlobalConfig.LoadFromDB(saved)) + model.InvalidatePricingCache() + }) + + modeBytes, err := common.Marshal(modes) + require.NoError(t, err) + exprBytes, err := common.Marshal(exprs) + require.NoError(t, err) + + require.NoError(t, config.GlobalConfig.LoadFromDB(map[string]string{ + "billing_setting.billing_mode": string(modeBytes), + "billing_setting.billing_expr": string(exprBytes), + })) + model.InvalidatePricingCache() +} + +func withSelfUseModeDisabled(t *testing.T) { + t.Helper() + + original := operation_setting.SelfUseModeEnabled + operation_setting.SelfUseModeEnabled = false + t.Cleanup(func() { + operation_setting.SelfUseModeEnabled = original + }) +} + +func decodeListModelsResponse(t *testing.T, recorder *httptest.ResponseRecorder) map[string]struct{} { + t.Helper() + + require.Equal(t, http.StatusOK, recorder.Code) + var payload listModelsResponse + require.NoError(t, common.Unmarshal(recorder.Body.Bytes(), &payload)) + require.True(t, payload.Success) + require.Equal(t, "list", payload.Object) + + ids := make(map[string]struct{}, len(payload.Data)) + for _, item := range payload.Data { + ids[item.Id] = struct{}{} + } + return ids +} + +func pricingByModelName(pricings []model.Pricing) map[string]model.Pricing { + byName := make(map[string]model.Pricing, len(pricings)) + for _, pricing := range pricings { + byName[pricing.ModelName] = pricing + } + return byName +} + +func TestListModelsIncludesTieredBillingModel(t *testing.T) { + withSelfUseModeDisabled(t) + withTieredBillingConfig(t, map[string]string{ + "zz-tiered-visible-model": "tiered_expr", + "zz-tiered-empty-expr-model": "tiered_expr", + "zz-tiered-missing-expr-model": "tiered_expr", + }, map[string]string{ + "zz-tiered-visible-model": `tier("base", p * 1 + c * 2)`, + "zz-tiered-empty-expr-model": " ", + }) + + db := setupModelListControllerTestDB(t) + require.NoError(t, db.Create(&model.User{ + Id: 1001, + Username: "model-list-user", + Password: "password", + Group: "default", + Status: common.UserStatusEnabled, + }).Error) + require.NoError(t, db.Create(&[]model.Ability{ + {Group: "default", Model: "zz-tiered-visible-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-tiered-empty-expr-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-tiered-missing-expr-model", ChannelId: 1, Enabled: true}, + {Group: "default", Model: "zz-unpriced-model", ChannelId: 1, Enabled: true}, + }).Error) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + ctx.Set("id", 1001) + + ListModels(ctx, constant.ChannelTypeOpenAI) + + ids := decodeListModelsResponse(t, recorder) + require.Contains(t, ids, "zz-tiered-visible-model") + require.NotContains(t, ids, "zz-tiered-empty-expr-model") + require.NotContains(t, ids, "zz-tiered-missing-expr-model") + require.NotContains(t, ids, "zz-unpriced-model") + + pricingByName := pricingByModelName(model.GetPricing()) + visiblePricing, ok := pricingByName["zz-tiered-visible-model"] + require.True(t, ok) + require.Equal(t, "tiered_expr", visiblePricing.BillingMode) + require.NotEmpty(t, visiblePricing.BillingExpr) + + emptyExprPricing, ok := pricingByName["zz-tiered-empty-expr-model"] + require.True(t, ok) + require.Empty(t, emptyExprPricing.BillingMode) + require.Empty(t, emptyExprPricing.BillingExpr) + + missingExprPricing, ok := pricingByName["zz-tiered-missing-expr-model"] + require.True(t, ok) + require.Empty(t, missingExprPricing.BillingMode) + require.Empty(t, missingExprPricing.BillingExpr) +} + +func TestListModelsTokenLimitIncludesTieredBillingModel(t *testing.T) { + withSelfUseModeDisabled(t) + withTieredBillingConfig(t, map[string]string{ + "zz-token-tiered-visible-model": "tiered_expr", + "zz-token-tiered-empty-expr-model": "tiered_expr", + "zz-token-tiered-missing-expr-model": "tiered_expr", + }, map[string]string{ + "zz-token-tiered-visible-model": `tier("base", p * 1 + c * 2)`, + "zz-token-tiered-empty-expr-model": "", + }) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + common.SetContextKey(ctx, constant.ContextKeyTokenModelLimitEnabled, true) + common.SetContextKey(ctx, constant.ContextKeyTokenModelLimit, map[string]bool{ + "zz-token-tiered-visible-model": true, + "zz-token-tiered-empty-expr-model": true, + "zz-token-tiered-missing-expr-model": true, + "zz-token-unpriced-model": true, + }) + + ListModels(ctx, constant.ChannelTypeOpenAI) + + ids := decodeListModelsResponse(t, recorder) + require.Contains(t, ids, "zz-token-tiered-visible-model") + require.NotContains(t, ids, "zz-token-tiered-empty-expr-model") + require.NotContains(t, ids, "zz-token-tiered-missing-expr-model") + require.NotContains(t, ids, "zz-token-unpriced-model") +} diff --git a/model/option.go b/model/option.go index ae4e5ca36a..871f73a26f 100644 --- a/model/option.go +++ b/model/option.go @@ -578,6 +578,9 @@ func handleConfigUpdate(key, value string) bool { performance_setting.UpdateAndSync() } else if configName == "tool_price_setting" { operation_setting.RebuildToolPriceIndex() + } else if configName == "billing_setting" { + InvalidatePricingCache() + ratio_setting.InvalidateExposedDataCache() } return true // 已处理 diff --git a/model/pricing.go b/model/pricing.go index 0fe235629f..fe92758519 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -77,6 +77,29 @@ func GetPricing() []Pricing { return pricingMap } +func InvalidatePricingCache() { + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + + pricingMap = nil + vendorsList = nil + lastGetPricingTime = time.Time{} +} + +func HasModelBillingConfig(modelName string) bool { + if _, ok := ratio_setting.GetModelPrice(modelName, false); ok { + return true + } + if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok { + return true + } + if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr { + return false + } + expr, ok := billing_setting.GetBillingExpr(modelName) + return ok && strings.TrimSpace(expr) != "" +} + // GetVendors 返回当前定价接口使用到的供应商信息 func GetVendors() []PricingVendor { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { @@ -323,7 +346,7 @@ func updatePricing() { pricing.AudioCompletionRatio = &audioCompletionRatio } if billingMode := billing_setting.GetBillingMode(model); billingMode == "tiered_expr" { - if expr, ok := billing_setting.GetBillingExpr(model); ok && expr != "" { + if expr, ok := billing_setting.GetBillingExpr(model); ok && strings.TrimSpace(expr) != "" { pricing.BillingMode = billingMode pricing.BillingExpr = expr } diff --git a/relay/helper/price.go b/relay/helper/price.go index 52b971c2b3..b4f8e66271 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -224,19 +224,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types } func ContainPriceOrRatio(modelName string) bool { - _, ok := ratio_setting.GetModelPrice(modelName, false) - if ok { - return true - } - _, ok, _ = ratio_setting.GetModelRatio(modelName) - if ok { - return true - } - if billing_setting.GetBillingMode(modelName) == billing_setting.BillingModeTieredExpr { - _, ok = billing_setting.GetBillingExpr(modelName) - return ok - } - return false + return model.HasModelBillingConfig(modelName) } func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) { From 3a2138ba61ce733fc781cef4f982fcb8dd4f1089 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 24 Apr 2026 16:39:12 +0800 Subject: [PATCH 05/67] refactor: rename and relocate HasModelBillingConfig function for clarity --- controller/model.go | 5 +++-- model/pricing.go | 14 -------------- relay/gemini_handler.go | 2 +- relay/helper/price.go | 15 +++++++++++++-- setting/ratio_setting/model_ratio_test.go | 22 ---------------------- 5 files changed, 17 insertions(+), 41 deletions(-) delete mode 100644 setting/ratio_setting/model_ratio_test.go diff --git a/controller/model.go b/controller/model.go index f22379555a..4dbd45838d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -15,6 +15,7 @@ import ( "github.com/QuantumNous/new-api/relay/channel/minimax" "github.com/QuantumNous/new-api/relay/channel/moonshot" relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/QuantumNous/new-api/types" @@ -133,7 +134,7 @@ func ListModels(c *gin.Context, modelType int) { } for allowModel, _ := range tokenModelLimit { if !acceptUnsetRatioModel { - if !model.HasModelBillingConfig(allowModel) { + if !helper.HasModelBillingConfig(allowModel) { continue } } @@ -180,7 +181,7 @@ func ListModels(c *gin.Context, modelType int) { } for _, modelName := range models { if !acceptUnsetRatioModel { - if !model.HasModelBillingConfig(modelName) { + if !helper.HasModelBillingConfig(modelName) { continue } } diff --git a/model/pricing.go b/model/pricing.go index fe92758519..b9574a3885 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -86,20 +86,6 @@ func InvalidatePricingCache() { lastGetPricingTime = time.Time{} } -func HasModelBillingConfig(modelName string) bool { - if _, ok := ratio_setting.GetModelPrice(modelName, false); ok { - return true - } - if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok { - return true - } - if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr { - return false - } - expr, ok := billing_setting.GetBillingExpr(modelName) - return ok && strings.TrimSpace(expr) != "" -} - // GetVendors 返回当前定价接口使用到的供应商信息 func GetVendors() []PricingVendor { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index e663a28baa..3b4bafe2a6 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -77,7 +77,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if !strings.Contains(info.OriginModelName, "-nothinking") { // try to get no thinking model price noThinkingModelName := info.OriginModelName + "-nothinking" - containPrice := helper.ContainPriceOrRatio(noThinkingModelName) + containPrice := helper.HasModelBillingConfig(noThinkingModelName) if containPrice { info.OriginModelName = noThinkingModelName info.UpstreamModelName = noThinkingModelName diff --git a/relay/helper/price.go b/relay/helper/price.go index b4f8e66271..a078325261 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -2,6 +2,7 @@ package helper import ( "fmt" + "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" @@ -223,8 +224,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) (types return priceData, nil } -func ContainPriceOrRatio(modelName string) bool { - return model.HasModelBillingConfig(modelName) +func HasModelBillingConfig(modelName string) bool { + if _, ok := ratio_setting.GetModelPrice(modelName, false); ok { + return true + } + if _, ok, _ := ratio_setting.GetModelRatio(modelName); ok { + return true + } + if billing_setting.GetBillingMode(modelName) != billing_setting.BillingModeTieredExpr { + return false + } + expr, ok := billing_setting.GetBillingExpr(modelName) + return ok && strings.TrimSpace(expr) != "" } func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta, groupRatioInfo types.GroupRatioInfo) (types.PriceData, error) { diff --git a/setting/ratio_setting/model_ratio_test.go b/setting/ratio_setting/model_ratio_test.go deleted file mode 100644 index 0b7912f893..0000000000 --- a/setting/ratio_setting/model_ratio_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package ratio_setting - -import "testing" - -func TestGetCompletionRatioInfoGPT55UsesOfficialOutputMultiplier(t *testing.T) { - info := GetCompletionRatioInfo("gpt-5.5") - - if info.Ratio != 6 { - t.Fatalf("gpt-5.5 completion ratio = %v, want 6", info.Ratio) - } - if !info.Locked { - t.Fatal("gpt-5.5 completion ratio should be locked to the official multiplier") - } -} - -func TestGetCompletionRatioGPT55DatedVariant(t *testing.T) { - got := GetCompletionRatio("gpt-5.5-2026-04-24") - - if got != 6 { - t.Fatalf("gpt-5.5 dated variant completion ratio = %v, want 6", got) - } -} From 435d7ae0dd8cf4e9f00acc4a0920c354cd6a9934 Mon Sep 17 00:00:00 2001 From: HynoR <20227709+HynoR@users.noreply.github.com> Date: Fri, 24 Apr 2026 16:40:07 +0800 Subject: [PATCH 06/67] feat: support DeepSeek V4 reasoning suffix handling --- relay/channel/deepseek/adaptor.go | 77 ++++++++++++++++++++++++++++- relay/channel/deepseek/constants.go | 2 + relay/channel/openai/adaptor.go | 20 ++------ setting/reasoning/suffix.go | 33 ++++++++++++- 4 files changed, 113 insertions(+), 19 deletions(-) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 57fcf3d043..60eaf22be5 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -7,12 +7,14 @@ import ( "net/http" "strings" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" "github.com/QuantumNous/new-api/relay/channel/claude" "github.com/QuantumNous/new-api/relay/channel/openai" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/constant" + "github.com/QuantumNous/new-api/setting/reasoning" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" ) @@ -27,7 +29,18 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { adaptor := claude.Adaptor{} - return adaptor.ConvertClaudeRequest(c, info, req) + convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, req) + if err != nil { + return nil, err + } + claudeRequest, ok := convertedRequest.(*dto.ClaudeRequest) + if !ok { + return convertedRequest, nil + } + if err := applyDeepSeekV4ClaudeThinkingSuffix(info, claudeRequest); err != nil { + return nil, err + } + return claudeRequest, nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -71,9 +84,71 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + if err := applyDeepSeekV4OpenAIThinkingSuffix(info, request); err != nil { + return nil, err + } + return request, nil } +func applyDeepSeekV4OpenAIThinkingSuffix(info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) error { + modelName := request.Model + if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" { + modelName = info.UpstreamModelName + } + baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName) + if !ok { + return nil + } + thinking, err := common.Marshal(map[string]string{ + "type": thinkingType, + }) + if err != nil { + return fmt.Errorf("error marshalling thinking: %w", err) + } + request.Model = baseModel + request.THINKING = thinking + request.ReasoningEffort = effort + if info != nil { + if info.ChannelMeta != nil { + info.UpstreamModelName = baseModel + } + info.ReasoningEffort = effort + } + return nil +} + +func applyDeepSeekV4ClaudeThinkingSuffix(info *relaycommon.RelayInfo, request *dto.ClaudeRequest) error { + modelName := request.Model + if info != nil && info.ChannelMeta != nil && info.UpstreamModelName != "" { + modelName = info.UpstreamModelName + } + baseModel, thinkingType, effort, ok := reasoning.ParseDeepSeekV4ThinkingSuffix(modelName) + if !ok { + return nil + } + request.Model = baseModel + request.Thinking = &dto.Thinking{Type: thinkingType} + if effort == "" { + request.OutputConfig = nil + } else { + outputConfig, err := common.Marshal(map[string]string{ + "effort": effort, + }) + if err != nil { + return fmt.Errorf("error marshalling output_config: %w", err) + } + request.OutputConfig = outputConfig + } + if info != nil { + if info.ChannelMeta != nil { + info.UpstreamModelName = baseModel + } + info.ReasoningEffort = effort + } + return nil +} + func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { return nil, nil } diff --git a/relay/channel/deepseek/constants.go b/relay/channel/deepseek/constants.go index 1d7b1e329e..83f013305e 100644 --- a/relay/channel/deepseek/constants.go +++ b/relay/channel/deepseek/constants.go @@ -2,6 +2,8 @@ package deepseek var ModelList = []string{ "deepseek-chat", "deepseek-reasoner", + "deepseek-v4-flash", "deepseek-v4-flash-none", "deepseek-v4-flash-max", + "deepseek-v4-pro", "deepseek-v4-pro-none", "deepseek-v4-pro-max", } var ChannelName = "deepseek" diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 56a58f2865..3fd8d76a49 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -28,6 +28,7 @@ import ( relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" + "github.com/QuantumNous/new-api/setting/reasoning" "github.com/QuantumNous/new-api/types" "github.com/samber/lo" @@ -39,21 +40,6 @@ type Adaptor struct { ResponseFormat string } -// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别 -// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc... -// minimal effort only available in gpt-5 -func parseReasoningEffortFromModelSuffix(model string) (string, string) { - effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"} - for _, suffix := range effortSuffixes { - if strings.HasSuffix(model, suffix) { - effort := strings.TrimPrefix(suffix, "-") - originModel := strings.TrimSuffix(model, suffix) - return effort, originModel - } - } - return "", model -} - func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) { // 使用 service.GeminiToOpenAIRequest 转换请求格式 openaiRequest, err := service.GeminiToOpenAIRequest(request, info) @@ -342,7 +328,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn } // 转换模型推理力度后缀 - effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName) + effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(info.UpstreamModelName) if effort != "" { request.ReasoningEffort = effort info.UpstreamModelName = originModel @@ -587,7 +573,7 @@ func detectImageMimeType(filename string) string { func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { // 转换模型推理力度后缀 - effort, originModel := parseReasoningEffortFromModelSuffix(request.Model) + effort, originModel := reasoning.ParseOpenAIReasoningEffortFromModelSuffix(request.Model) if effort != "" { if request.Reasoning == nil { request.Reasoning = &dto.Reasoning{ diff --git a/setting/reasoning/suffix.go b/setting/reasoning/suffix.go index 2b95de6dde..59140a7c8d 100644 --- a/setting/reasoning/suffix.go +++ b/setting/reasoning/suffix.go @@ -8,9 +8,17 @@ import ( var EffortSuffixes = []string{"-max", "-xhigh", "-high", "-medium", "-low", "-minimal"} +var OpenAIEffortSuffixes = []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"} + +var DeepSeekV4EffortSuffixes = []string{"-none", "-max"} + // TrimEffortSuffix -> modelName level(low) exists func TrimEffortSuffix(modelName string) (string, string, bool) { - suffix, found := lo.Find(EffortSuffixes, func(s string) bool { + return TrimEffortSuffixWithSuffixes(modelName, EffortSuffixes) +} + +func TrimEffortSuffixWithSuffixes(modelName string, suffixes []string) (string, string, bool) { + suffix, found := lo.Find(suffixes, func(s string) bool { return strings.HasSuffix(modelName, s) }) if !found { @@ -18,3 +26,26 @@ func TrimEffortSuffix(modelName string) (string, string, bool) { } return strings.TrimSuffix(modelName, suffix), strings.TrimPrefix(suffix, "-"), true } + +func ParseOpenAIReasoningEffortFromModelSuffix(modelName string) (string, string) { + baseModel, effort, ok := TrimEffortSuffixWithSuffixes(modelName, OpenAIEffortSuffixes) + if !ok { + return "", modelName + } + return effort, baseModel +} + +func ParseDeepSeekV4ThinkingSuffix(modelName string) (baseModel string, thinkingType string, effort string, ok bool) { + baseModel, suffix, ok := TrimEffortSuffixWithSuffixes(modelName, DeepSeekV4EffortSuffixes) + if !ok || !strings.HasPrefix(baseModel, "deepseek-v4-") { + return modelName, "", "", false + } + switch suffix { + case "none": + return baseModel, "disabled", "", true + case "max": + return baseModel, "enabled", "max", true + default: + return modelName, "", "", false + } +} From 095e1920f1a07f14d1d54b41a93f7e245d9a44c3 Mon Sep 17 00:00:00 2001 From: Seefs Date: Fri, 24 Apr 2026 17:51:46 +0800 Subject: [PATCH 07/67] fix(channel): load model mapping during upstream model checks --- controller/channel_upstream_update.go | 24 ++++++++++++++++++++-- controller/channel_upstream_update_test.go | 4 ++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/controller/channel_upstream_update.go b/controller/channel_upstream_update.go index 1d851949fd..77a1e3c817 100644 --- a/controller/channel_upstream_update.go +++ b/controller/channel_upstream_update.go @@ -32,6 +32,26 @@ const ( channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10 ) +var channelUpstreamModelUpdateSelectFields = []string{ + "id", + "name", + "type", + "key", + "status", + "base_url", + "models", + "model_mapping", + "settings", + "setting", + "other", + "group", + "priority", + "weight", + "tag", + "channel_info", + "header_override", +} + var ( channelUpstreamModelUpdateTaskOnce sync.Once channelUpstreamModelUpdateTaskRunning atomic.Bool @@ -521,7 +541,7 @@ func runChannelUpstreamModelUpdateTaskOnce() { for { var channels []*model.Channel query := model.DB. - Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override"). + Select(channelUpstreamModelUpdateSelectFields). Where("status = ?", common.ChannelStatusEnabled). Order("id asc"). Limit(channelUpstreamModelUpdateTaskBatchSize) @@ -814,7 +834,7 @@ func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings) func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) { var channels []*model.Channel query := model.DB. - Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override"). + Select(channelUpstreamModelUpdateSelectFields). Where("status = ?", common.ChannelStatusEnabled). Order("id asc"). Limit(batchSize) diff --git a/controller/channel_upstream_update_test.go b/controller/channel_upstream_update_test.go index 52de830b9a..d9890d9120 100644 --- a/controller/channel_upstream_update_test.go +++ b/controller/channel_upstream_update_test.go @@ -81,6 +81,10 @@ func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) { require.Equal(t, []string{"old-model"}, pendingRemoveModels) } +func TestChannelUpstreamModelUpdateSelectFieldsIncludeModelMapping(t *testing.T) { + require.Contains(t, channelUpstreamModelUpdateSelectFields, "model_mapping") +} + func TestNormalizeChannelModelMapping(t *testing.T) { modelMapping := `{ " alias-model ": " upstream-model ", From a7c38ec851a3fbd123dac5b3185f2fed6c179944 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 24 Apr 2026 22:16:16 +0800 Subject: [PATCH 08/67] fix: add PaymentProvider field to prevent cross-gateway callback attacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit EPay allows users to switch payment methods (e.g. wxpay→alipay) during checkout, causing callback rejection. Replace fragile blocklist guard with a PaymentProvider field on TopUp and SubscriptionOrder that identifies which gateway created the order. --- controller/subscription_payment_creem.go | 15 ++-- controller/subscription_payment_epay.go | 21 +++--- controller/subscription_payment_stripe.go | 15 ++-- controller/topup.go | 38 ++++------ controller/topup_creem.go | 17 ++--- controller/topup_epay_guard_test.go | 31 --------- controller/topup_stripe.go | 25 +++---- controller/topup_waffo.go | 17 ++--- controller/topup_waffo_pancake.go | 15 ++-- model/payment_method_guard_test.go | 84 ++++++++++++----------- model/subscription.go | 24 ++++--- model/topup.go | 41 ++++++----- 12 files changed, 163 insertions(+), 180 deletions(-) delete mode 100644 controller/topup_epay_guard_test.go diff --git a/controller/subscription_payment_creem.go b/controller/subscription_payment_creem.go index 935429acfe..18e1a58487 100644 --- a/controller/subscription_payment_creem.go +++ b/controller/subscription_payment_creem.go @@ -83,13 +83,14 @@ func SubscriptionRequestCreemPay(c *gin.Context) { // create pending order first order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodCreem, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodCreem, + PaymentProvider: model.PaymentProviderCreem, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) diff --git a/controller/subscription_payment_epay.go b/controller/subscription_payment_epay.go index 8f7848d5cf..2567654ff4 100644 --- a/controller/subscription_payment_epay.go +++ b/controller/subscription_payment_epay.go @@ -82,13 +82,14 @@ func SubscriptionRequestEpay(c *gin.Context) { } order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: tradeNo, - PaymentMethod: req.PaymentMethod, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: tradeNo, + PaymentMethod: req.PaymentMethod, + PaymentProvider: model.PaymentProviderEpay, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { common.ApiErrorMsg(c, "创建订单失败") @@ -104,7 +105,7 @@ func SubscriptionRequestEpay(c *gin.Context) { ReturnUrl: returnUrl, }) if err != nil { - _ = model.ExpireSubscriptionOrder(tradeNo, req.PaymentMethod) + _ = model.ExpireSubscriptionOrder(tradeNo, model.PaymentProviderEpay) common.ApiErrorMsg(c, "拉起支付失败") return } @@ -156,7 +157,7 @@ func SubscriptionEpayNotify(c *gin.Context) { LockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo) - if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { + if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil { _, _ = c.Writer.Write([]byte("fail")) return } @@ -205,7 +206,7 @@ func SubscriptionEpayReturn(c *gin.Context) { if verifyInfo.TradeStatus == epay.StatusTradeSuccess { LockOrder(verifyInfo.ServiceTradeNo) defer UnlockOrder(verifyInfo.ServiceTradeNo) - if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), verifyInfo.Type); err != nil { + if err := model.CompleteSubscriptionOrder(verifyInfo.ServiceTradeNo, common.GetJsonString(verifyInfo), model.PaymentProviderEpay, verifyInfo.Type); err != nil { c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/topup?pay=fail") return } diff --git a/controller/subscription_payment_stripe.go b/controller/subscription_payment_stripe.go index 9824c90dc4..a5ce4685b4 100644 --- a/controller/subscription_payment_stripe.go +++ b/controller/subscription_payment_stripe.go @@ -84,13 +84,14 @@ func SubscriptionRequestStripePay(c *gin.Context) { } order := &model.SubscriptionOrder{ - UserId: userId, - PlanId: plan.Id, - Money: plan.PriceAmount, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodStripe, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: userId, + PlanId: plan.Id, + Money: plan.PriceAmount, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodStripe, + PaymentProvider: model.PaymentProviderStripe, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := order.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{"message": "error", "data": "创建订单失败"}) diff --git a/controller/topup.go b/controller/topup.go index 86d361a349..a6445b40d6 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -123,17 +123,6 @@ type AmountRequest struct { Amount int64 `json:"amount"` } -var nonEpayPaymentMethodsForCallback = []string{ - model.PaymentMethodStripe, - model.PaymentMethodCreem, - model.PaymentMethodWaffo, - model.PaymentMethodWaffoPancake, -} - -func isNonEpayPaymentMethodForEpayCallback(paymentMethod string) bool { - return lo.Contains(nonEpayPaymentMethodsForCallback, paymentMethod) -} - func GetEpayClient() *epay.Client { if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" { return nil @@ -248,13 +237,14 @@ func RequestEpay(c *gin.Context) { amount = dAmount.Div(dQuotaPerUnit).IntPart() } topUp := &model.TopUp{ - UserId: id, - Amount: amount, - Money: payMoney, - TradeNo: tradeNo, - PaymentMethod: req.PaymentMethod, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: amount, + Money: payMoney, + TradeNo: tradeNo, + PaymentMethod: req.PaymentMethod, + PaymentProvider: model.PaymentProviderEpay, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -379,15 +369,15 @@ func EpayNotify(c *gin.Context) { logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 回调订单不存在 trade_no=%s callback_type=%s client_ip=%s verify_info=%q", verifyInfo.ServiceTradeNo, verifyInfo.Type, c.ClientIP(), common.GetJsonString(verifyInfo))) return } - if isNonEpayPaymentMethodForEpayCallback(topUp.PaymentMethod) { - logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) - return - } - if topUp.PaymentMethod != verifyInfo.Type { - logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付方式不匹配 trade_no=%s order_payment_method=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) + if topUp.PaymentProvider != model.PaymentProviderEpay { + logger.LogWarn(c.Request.Context(), fmt.Sprintf("易支付 订单支付网关不匹配 trade_no=%s order_provider=%s callback_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentProvider, verifyInfo.Type, c.ClientIP())) return } if topUp.Status == common.TopUpStatusPending { + if topUp.PaymentMethod != verifyInfo.Type { + logger.LogInfo(c.Request.Context(), fmt.Sprintf("易支付 实际支付方式与订单不同 trade_no=%s order_payment_method=%s actual_type=%s client_ip=%s", verifyInfo.ServiceTradeNo, topUp.PaymentMethod, verifyInfo.Type, c.ClientIP())) + topUp.PaymentMethod = verifyInfo.Type + } topUp.Status = common.TopUpStatusSuccess err := topUp.Update() if err != nil { diff --git a/controller/topup_creem.go b/controller/topup_creem.go index 139dd43fbe..7472690e22 100644 --- a/controller/topup_creem.go +++ b/controller/topup_creem.go @@ -106,13 +106,14 @@ func (*CreemAdaptor) RequestPay(c *gin.Context, req *CreemPayRequest) { // 先创建订单记录,使用产品配置的金额和充值额度 topUp := &model.TopUp{ - UserId: id, - Amount: selectedProduct.Quota, // 充值额度 - Money: selectedProduct.Price, // 支付金额 - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodCreem, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: selectedProduct.Quota, // 充值额度 + Money: selectedProduct.Price, // 支付金额 + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodCreem, + PaymentProvider: model.PaymentProviderCreem, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -301,7 +302,7 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) { // Try complete subscription order first LockOrder(referenceId) defer UnlockOrder(referenceId) - if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentMethodCreem); err == nil { + if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(event), model.PaymentProviderCreem, ""); err == nil { logger.LogInfo(c.Request.Context(), fmt.Sprintf("Creem 订阅订单处理成功 trade_no=%s creem_order_id=%s", referenceId, event.Object.Order.Id)) c.Status(http.StatusOK) return diff --git a/controller/topup_epay_guard_test.go b/controller/topup_epay_guard_test.go deleted file mode 100644 index 3451266558..0000000000 --- a/controller/topup_epay_guard_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package controller - -import ( - "testing" - - "github.com/QuantumNous/new-api/model" -) - -func TestIsNonEpayPaymentMethodForEpayCallback(t *testing.T) { - testCases := []struct { - name string - paymentMethod string - expectedBlocked bool - }{ - {name: "stripe", paymentMethod: model.PaymentMethodStripe, expectedBlocked: true}, - {name: "creem", paymentMethod: model.PaymentMethodCreem, expectedBlocked: true}, - {name: "waffo", paymentMethod: model.PaymentMethodWaffo, expectedBlocked: true}, - {name: "waffo pancake", paymentMethod: model.PaymentMethodWaffoPancake, expectedBlocked: true}, - {name: "alipay", paymentMethod: "alipay", expectedBlocked: false}, - {name: "wxpay", paymentMethod: "wxpay", expectedBlocked: false}, - {name: "custom epay type", paymentMethod: "custom1", expectedBlocked: false}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if actual := isNonEpayPaymentMethodForEpayCallback(tc.paymentMethod); actual != tc.expectedBlocked { - t.Fatalf("expected blocked=%v, got %v for payment method %q", tc.expectedBlocked, actual, tc.paymentMethod) - } - }) - } -} diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 23ddb3b90e..ceee8ecdd6 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -101,13 +101,14 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { } topUp := &model.TopUp{ - UserId: id, - Amount: req.Amount, - Money: chargedMoney, - TradeNo: referenceId, - PaymentMethod: model.PaymentMethodStripe, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: req.Amount, + Money: chargedMoney, + TradeNo: referenceId, + PaymentMethod: model.PaymentMethodStripe, + PaymentProvider: model.PaymentProviderStripe, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } err = topUp.Insert() if err != nil { @@ -237,8 +238,8 @@ func sessionAsyncPaymentFailed(ctx context.Context, event stripe.Event, callerIp return } - if topUp.PaymentMethod != model.PaymentMethodStripe { - logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付方式不匹配 trade_no=%s payment_method=%s client_ip=%s", referenceId, topUp.PaymentMethod, callerIp)) + if topUp.PaymentProvider != model.PaymentProviderStripe { + logger.LogWarn(ctx, fmt.Sprintf("Stripe 异步支付失败但订单支付网关不匹配 trade_no=%s payment_provider=%s client_ip=%s", referenceId, topUp.PaymentProvider, callerIp)) return } @@ -270,7 +271,7 @@ func fulfillOrder(ctx context.Context, event stripe.Event, referenceId string, c "currency": strings.ToUpper(event.GetObjectValue("currency")), "event_type": string(event.Type), } - if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentMethodStripe); err == nil { + if err := model.CompleteSubscriptionOrder(referenceId, common.GetJsonString(payload), model.PaymentProviderStripe, ""); err == nil { logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单处理成功 trade_no=%s event_type=%s client_ip=%s", referenceId, string(event.Type), callerIp)) return } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { @@ -305,7 +306,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) { // Subscription order expiration LockOrder(referenceId) defer UnlockOrder(referenceId) - if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentMethodStripe); err == nil { + if err := model.ExpireSubscriptionOrder(referenceId, model.PaymentProviderStripe); err == nil { logger.LogInfo(ctx, fmt.Sprintf("Stripe 订阅订单已过期 trade_no=%s", referenceId)) return } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { @@ -313,7 +314,7 @@ func sessionExpired(ctx context.Context, event stripe.Event) { return } - err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentMethodStripe, common.TopUpStatusExpired) + err := model.UpdatePendingTopUpStatus(referenceId, model.PaymentProviderStripe, common.TopUpStatusExpired) if errors.Is(err, model.ErrTopUpNotFound) { logger.LogWarn(ctx, fmt.Sprintf("Stripe 充值订单不存在,无法标记过期 trade_no=%s", referenceId)) return diff --git a/controller/topup_waffo.go b/controller/topup_waffo.go index c006806281..1885c1ded9 100644 --- a/controller/topup_waffo.go +++ b/controller/topup_waffo.go @@ -208,13 +208,14 @@ func RequestWaffoPay(c *gin.Context) { // 创建本地订单 topUp := &model.TopUp{ - UserId: id, - Amount: amount, - Money: payMoney, - TradeNo: merchantOrderId, - PaymentMethod: model.PaymentMethodWaffo, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: amount, + Money: payMoney, + TradeNo: merchantOrderId, + PaymentMethod: model.PaymentMethodWaffo, + PaymentProvider: model.PaymentProviderWaffo, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := topUp.Insert(); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, merchantOrderId, req.Amount, err.Error())) @@ -379,7 +380,7 @@ func handleWaffoPayment(c *gin.Context, wh *core.WebhookHandler, result *core.Pa logger.LogInfo(c.Request.Context(), fmt.Sprintf("Waffo 订单状态非成功,忽略充值 trade_no=%s order_status=%s client_ip=%s", result.MerchantOrderID, result.OrderStatus, c.ClientIP())) // 终态失败订单标记为 failed,避免永远停在 pending if result.MerchantOrderID != "" { - if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentMethodWaffo, common.TopUpStatusFailed); err != nil && + if err := model.UpdatePendingTopUpStatus(result.MerchantOrderID, model.PaymentProviderWaffo, common.TopUpStatusFailed); err != nil && !errors.Is(err, model.ErrTopUpNotFound) && !errors.Is(err, model.ErrTopUpStatusInvalid) { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo 标记失败订单状态失败 trade_no=%s error=%q", result.MerchantOrderID, err.Error())) diff --git a/controller/topup_waffo_pancake.go b/controller/topup_waffo_pancake.go index 81515a56ed..09f1516304 100644 --- a/controller/topup_waffo_pancake.go +++ b/controller/topup_waffo_pancake.go @@ -159,13 +159,14 @@ func RequestWaffoPancakePay(c *gin.Context) { tradeNo := fmt.Sprintf("WAFFO_PANCAKE-%d-%d-%s", id, time.Now().UnixMilli(), randstr.String(6)) topUp := &model.TopUp{ - UserId: id, - Amount: normalizeWaffoPancakeTopUpAmount(req.Amount), - Money: payMoney, - TradeNo: tradeNo, - PaymentMethod: model.PaymentMethodWaffoPancake, - CreateTime: time.Now().Unix(), - Status: common.TopUpStatusPending, + UserId: id, + Amount: normalizeWaffoPancakeTopUpAmount(req.Amount), + Money: payMoney, + TradeNo: tradeNo, + PaymentMethod: model.PaymentMethodWaffoPancake, + PaymentProvider: model.PaymentProviderWaffoPancake, + CreateTime: time.Now().Unix(), + Status: common.TopUpStatusPending, } if err := topUp.Insert(); err != nil { logger.LogError(c.Request.Context(), fmt.Sprintf("Waffo Pancake 创建充值订单失败 user_id=%d trade_no=%s amount=%d error=%q", id, tradeNo, req.Amount, err.Error())) diff --git a/model/payment_method_guard_test.go b/model/payment_method_guard_test.go index 9bc292444f..7f4f15cc34 100644 --- a/model/payment_method_guard_test.go +++ b/model/payment_method_guard_test.go @@ -36,30 +36,32 @@ func insertSubscriptionPlanForPaymentGuardTest(t *testing.T, id int) *Subscripti return plan } -func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentMethod string) { +func insertSubscriptionOrderForPaymentGuardTest(t *testing.T, tradeNo string, userID int, planID int, paymentProvider string) { t.Helper() order := &SubscriptionOrder{ - UserId: userID, - PlanId: planID, - Money: 9.99, - TradeNo: tradeNo, - PaymentMethod: paymentMethod, - Status: common.TopUpStatusPending, - CreateTime: time.Now().Unix(), + UserId: userID, + PlanId: planID, + Money: 9.99, + TradeNo: tradeNo, + PaymentMethod: paymentProvider, + PaymentProvider: paymentProvider, + Status: common.TopUpStatusPending, + CreateTime: time.Now().Unix(), } require.NoError(t, order.Insert()) } -func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentMethod string) { +func insertTopUpForPaymentGuardTest(t *testing.T, tradeNo string, userID int, paymentProvider string) { t.Helper() topUp := &TopUp{ - UserId: userID, - Amount: 2, - Money: 9.99, - TradeNo: tradeNo, - PaymentMethod: paymentMethod, - Status: common.TopUpStatusPending, - CreateTime: time.Now().Unix(), + UserId: userID, + Amount: 2, + Money: 9.99, + TradeNo: tradeNo, + PaymentMethod: paymentProvider, + PaymentProvider: paymentProvider, + Status: common.TopUpStatusPending, + CreateTime: time.Now().Unix(), } require.NoError(t, topUp.Insert()) } @@ -89,7 +91,7 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 101, 0) - insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentMethodStripe) + insertTopUpForPaymentGuardTest(t, "waffo-pancake-guard", 101, PaymentProviderStripe) err := RechargeWaffoPancake("waffo-pancake-guard") require.Error(t, err) @@ -100,27 +102,27 @@ func TestRechargeWaffoPancake_RejectsMismatchedPaymentMethod(t *testing.T) { assert.Equal(t, 0, getUserQuotaForPaymentGuardTest(t, 101)) } -func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentProvider(t *testing.T) { testCases := []struct { - name string - tradeNo string - storedPaymentMethod string - expectedPaymentMethod string - targetStatus string + name string + tradeNo string + storedPaymentProvider string + expectedPaymentProvider string + targetStatus string }{ { - name: "stripe expire", - tradeNo: "stripe-expire-guard", - storedPaymentMethod: PaymentMethodCreem, - expectedPaymentMethod: PaymentMethodStripe, - targetStatus: common.TopUpStatusExpired, + name: "stripe expire", + tradeNo: "stripe-expire-guard", + storedPaymentProvider: PaymentProviderCreem, + expectedPaymentProvider: PaymentProviderStripe, + targetStatus: common.TopUpStatusExpired, }, { - name: "waffo failed", - tradeNo: "waffo-failed-guard", - storedPaymentMethod: PaymentMethodStripe, - expectedPaymentMethod: PaymentMethodWaffo, - targetStatus: common.TopUpStatusFailed, + name: "waffo failed", + tradeNo: "waffo-failed-guard", + storedPaymentProvider: PaymentProviderStripe, + expectedPaymentProvider: PaymentProviderWaffo, + targetStatus: common.TopUpStatusFailed, }, } @@ -128,23 +130,23 @@ func TestUpdatePendingTopUpStatus_RejectsMismatchedPaymentMethod(t *testing.T) { t.Run(tc.name, func(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 150, 0) - insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentMethod) + insertTopUpForPaymentGuardTest(t, tc.tradeNo, 150, tc.storedPaymentProvider) - err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentMethod, tc.targetStatus) + err := UpdatePendingTopUpStatus(tc.tradeNo, tc.expectedPaymentProvider, tc.targetStatus) require.ErrorIs(t, err, ErrPaymentMethodMismatch) assert.Equal(t, common.TopUpStatusPending, getTopUpStatusForPaymentGuardTest(t, tc.tradeNo)) }) } } -func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 202, 0) plan := insertSubscriptionPlanForPaymentGuardTest(t, 301) - insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentMethodStripe) + insertSubscriptionOrderForPaymentGuardTest(t, "sub-guard-order", 202, plan.Id, PaymentProviderStripe) - err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, "alipay") + err := CompleteSubscriptionOrder("sub-guard-order", `{"provider":"epay"}`, PaymentProviderEpay, "alipay") require.ErrorIs(t, err, ErrPaymentMethodMismatch) order := GetSubscriptionOrderByTradeNo("sub-guard-order") @@ -156,14 +158,14 @@ func TestCompleteSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) assert.Nil(t, topUp) } -func TestExpireSubscriptionOrder_RejectsMismatchedPaymentMethod(t *testing.T) { +func TestExpireSubscriptionOrder_RejectsMismatchedPaymentProvider(t *testing.T) { truncateTables(t) insertUserForPaymentGuardTest(t, 303, 0) plan := insertSubscriptionPlanForPaymentGuardTest(t, 401) - insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentMethodStripe) + insertSubscriptionOrderForPaymentGuardTest(t, "sub-expire-guard", 303, plan.Id, PaymentProviderStripe) - err := ExpireSubscriptionOrder("sub-expire-guard", PaymentMethodCreem) + err := ExpireSubscriptionOrder("sub-expire-guard", PaymentProviderCreem) require.ErrorIs(t, err, ErrPaymentMethodMismatch) order := GetSubscriptionOrderByTradeNo("sub-expire-guard") diff --git a/model/subscription.go b/model/subscription.go index 10e750c3f3..da8fdae941 100644 --- a/model/subscription.go +++ b/model/subscription.go @@ -198,11 +198,12 @@ type SubscriptionOrder struct { PlanId int `json:"plan_id" gorm:"index"` Money float64 `json:"money"` - TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` - PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` - Status string `json:"status"` - CreateTime int64 `json:"create_time"` - CompleteTime int64 `json:"complete_time"` + TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` + PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` + PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"` + Status string `json:"status"` + CreateTime int64 `json:"create_time"` + CompleteTime int64 `json:"complete_time"` ProviderPayload string `json:"provider_payload" gorm:"type:text"` } @@ -505,7 +506,9 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio } // Complete a subscription order (idempotent). Creates a UserSubscription snapshot from the plan. -func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentMethod string) error { +// expectedPaymentProvider guards against cross-gateway callback attacks (empty skips the check). +// actualPaymentMethod updates the order's PaymentMethod to reflect the real payment type used (empty skips update). +func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedPaymentProvider string, actualPaymentMethod string) error { if tradeNo == "" { return errors.New("tradeNo is empty") } @@ -523,7 +526,7 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { return ErrSubscriptionOrderNotFound } - if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if order.Status == common.TopUpStatusSuccess { @@ -552,6 +555,9 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string, expectedP if providerPayload != "" { order.ProviderPayload = providerPayload } + if actualPaymentMethod != "" && order.PaymentMethod != actualPaymentMethod { + order.PaymentMethod = actualPaymentMethod + } if err := tx.Save(&order).Error; err != nil { return err } @@ -610,7 +616,7 @@ func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error { return tx.Save(&topup).Error } -func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error { +func ExpireSubscriptionOrder(tradeNo string, expectedPaymentProvider string) error { if tradeNo == "" { return errors.New("tradeNo is empty") } @@ -623,7 +629,7 @@ func ExpireSubscriptionOrder(tradeNo string, expectedPaymentMethod string) error if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { return ErrSubscriptionOrderNotFound } - if expectedPaymentMethod != "" && order.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && order.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if order.Status != common.TopUpStatusPending { diff --git a/model/topup.go b/model/topup.go index c1ac663f75..c071b77b54 100644 --- a/model/topup.go +++ b/model/topup.go @@ -12,15 +12,16 @@ import ( ) type TopUp struct { - Id int `json:"id"` - UserId int `json:"user_id" gorm:"index"` - Amount int64 `json:"amount"` - Money float64 `json:"money"` - TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` - PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` - CreateTime int64 `json:"create_time"` - CompleteTime int64 `json:"complete_time"` - Status string `json:"status"` + Id int `json:"id"` + UserId int `json:"user_id" gorm:"index"` + Amount int64 `json:"amount"` + Money float64 `json:"money"` + TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"` + PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"` + PaymentProvider string `json:"payment_provider" gorm:"type:varchar(50);default:''"` + CreateTime int64 `json:"create_time"` + CompleteTime int64 `json:"complete_time"` + Status string `json:"status"` } const ( @@ -30,6 +31,14 @@ const ( PaymentMethodWaffoPancake = "waffo_pancake" ) +const ( + PaymentProviderEpay = "epay" + PaymentProviderStripe = "stripe" + PaymentProviderCreem = "creem" + PaymentProviderWaffo = "waffo" + PaymentProviderWaffoPancake = "waffo_pancake" +) + var ( ErrPaymentMethodMismatch = errors.New("payment method mismatch") ErrTopUpNotFound = errors.New("topup not found") @@ -68,7 +77,7 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp { return topUp } -func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targetStatus string) error { +func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentProvider string, targetStatus string) error { if tradeNo == "" { return errors.New("未提供支付单号") } @@ -83,7 +92,7 @@ func UpdatePendingTopUpStatus(tradeNo string, expectedPaymentMethod string, targ if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(topUp).Error; err != nil { return ErrTopUpNotFound } - if expectedPaymentMethod != "" && topUp.PaymentMethod != expectedPaymentMethod { + if expectedPaymentProvider != "" && topUp.PaymentProvider != expectedPaymentProvider { return ErrPaymentMethodMismatch } if topUp.Status != common.TopUpStatusPending { @@ -114,7 +123,7 @@ func Recharge(referenceId string, customerId string, callerIp string) (err error return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodStripe { + if topUp.PaymentProvider != PaymentProviderStripe { return ErrPaymentMethodMismatch } @@ -340,7 +349,7 @@ func ManualCompleteTopUp(tradeNo string, callerIp string) error { // 计算应充值额度: // - Stripe 订单:Money 代表经分组倍率换算后的美元数量,直接 * QuotaPerUnit // - 其他订单(如易支付):Amount 为美元数量,* QuotaPerUnit - if topUp.PaymentMethod == PaymentMethodStripe { + if topUp.PaymentProvider == PaymentProviderStripe { dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) quotaToAdd = int(decimal.NewFromFloat(topUp.Money).Mul(dQuotaPerUnit).IntPart()) } else { @@ -397,7 +406,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodCreem { + if topUp.PaymentProvider != PaymentProviderCreem { return ErrPaymentMethodMismatch } @@ -472,7 +481,7 @@ func RechargeWaffo(tradeNo string, callerIp string) (err error) { return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodWaffo { + if topUp.PaymentProvider != PaymentProviderWaffo { return ErrPaymentMethodMismatch } @@ -535,7 +544,7 @@ func RechargeWaffoPancake(tradeNo string) (err error) { return errors.New("充值订单不存在") } - if topUp.PaymentMethod != PaymentMethodWaffoPancake { + if topUp.PaymentProvider != PaymentProviderWaffoPancake { return ErrPaymentMethodMismatch } From 02aacb38a2523411df93077cd52c2ddccdb86560 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Sat, 25 Apr 2026 12:41:14 +0800 Subject: [PATCH 09/67] feat: add user created_at and last_login_at --- controller/user.go | 1 + model/user.go | 8 ++++++++ .../table/users/UsersColumnDefs.jsx | 19 ++++++++++++++++++- 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/controller/user.go b/controller/user.go index d6becdd8f0..b572266863 100644 --- a/controller/user.go +++ b/controller/user.go @@ -91,6 +91,7 @@ func Login(c *gin.Context) { // setup session & cookies and then return user info func setupLogin(user *model.User, c *gin.Context) { + model.UpdateUserLastLoginAt(user.Id) session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) diff --git a/model/user.go b/model/user.go index 79e63e8fd5..b632ef9afa 100644 --- a/model/user.go +++ b/model/user.go @@ -50,6 +50,8 @@ type User struct { Setting string `json:"setting" gorm:"type:text;column:setting"` Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"` + CreatedAt int64 `json:"created_at" gorm:"autoCreateTime;column:created_at"` + LastLoginAt int64 `json:"last_login_at" gorm:"default:0;column:last_login_at"` } func (user *User) ToBaseUser() *UserBase { @@ -951,6 +953,12 @@ func GetRootUser() (user *User) { return user } +func UpdateUserLastLoginAt(id int) { + if err := DB.Model(&User{}).Where("id = ?", id).Update("last_login_at", common.GetTimestamp()).Error; err != nil { + common.SysLog("failed to update user last_login_at: " + err.Error()) + } +} + func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { if common.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) diff --git a/web/src/components/table/users/UsersColumnDefs.jsx b/web/src/components/table/users/UsersColumnDefs.jsx index dc3e6f3413..2e0d171ae4 100644 --- a/web/src/components/table/users/UsersColumnDefs.jsx +++ b/web/src/components/table/users/UsersColumnDefs.jsx @@ -29,7 +29,14 @@ import { Dropdown, } from '@douyinfe/semi-ui'; import { IconMore } from '@douyinfe/semi-icons'; -import { renderGroup, renderNumber, renderQuota } from '../../../helpers'; +import { + renderGroup, + renderNumber, + renderQuota, + timestamp2string, +} from '../../../helpers'; + +const renderTimestamp = (text) => (text ? timestamp2string(text) : '-'); /** * Render user role @@ -350,6 +357,16 @@ export const getUsersColumns = ({ dataIndex: 'invite', render: (text, record, index) => renderInviteInfo(text, record, t), }, + { + title: t('创建时间'), + dataIndex: 'created_at', + render: renderTimestamp, + }, + { + title: t('最后登录'), + dataIndex: 'last_login_at', + render: renderTimestamp, + }, { title: '', dataIndex: 'operate', From f2f3410dcfe7b9974fc167ac6ba033f5e0196c33 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 25 Apr 2026 13:24:20 +0800 Subject: [PATCH 10/67] feat: add `len` variable for tier conditions and LLM prompt helper --- pkg/billingexpr/billingexpr_test.go | 75 ++++++- pkg/billingexpr/compile.go | 1 + pkg/billingexpr/expr.md | 19 +- pkg/billingexpr/run.go | 4 +- pkg/billingexpr/types.go | 5 +- relay/helper/price.go | 5 +- service/quota.go | 5 +- service/tiered_settle.go | 9 + service/tiered_settle_test.go | 91 +++++++++ setting/billing_setting/tiered_billing.go | 8 +- .../components/DynamicPricingBreakdown.jsx | 4 +- web/src/constants/billing.constants.js | 17 +- web/src/helpers/render.jsx | 10 +- web/src/helpers/utils.jsx | 4 +- .../Ratio/components/TieredPricingEditor.jsx | 183 ++++++++++++++++-- 15 files changed, 393 insertions(+), 47 deletions(-) diff --git a/pkg/billingexpr/billingexpr_test.go b/pkg/billingexpr/billingexpr_test.go index fd493232b7..5a5412f085 100644 --- a/pkg/billingexpr/billingexpr_test.go +++ b/pkg/billingexpr/billingexpr_test.go @@ -1000,11 +1000,82 @@ func TestImageAudioZero(t *testing.T) { } } +// --------------------------------------------------------------------------- +// len variable tests — tier conditions based on context length +// --------------------------------------------------------------------------- + +const lenTieredExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6)` + +func TestLen_StandardTier(t *testing.T) { + params := billingexpr.TokenParams{P: 80000, C: 5000, Len: 100000, CR: 20000} + cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params) + if err != nil { + t.Fatal(err) + } + want := 80000*3 + 5000*15 + 20000*0.3 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want standard", trace.MatchedTier) + } +} + +func TestLen_LongContextTier(t *testing.T) { + // p is low (cache subtracted), but len is high (full context) + params := billingexpr.TokenParams{P: 50000, C: 5000, Len: 300000, CR: 250000} + cost, trace, err := billingexpr.RunExpr(lenTieredExpr, params) + if err != nil { + t.Fatal(err) + } + want := 50000*6 + 5000*22.5 + 250000*0.6 + if math.Abs(cost-want) > 1e-6 { + t.Errorf("cost = %f, want %f", cost, want) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want long_context (len=300000 > 200000)", trace.MatchedTier) + } +} + +func TestLen_BoundaryExact(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200000, CR: 100000} + _, trace, err := billingexpr.RunExpr(lenTieredExpr, params) + if err != nil { + t.Fatal(err) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want standard (len=200000 <= 200000)", trace.MatchedTier) + } +} + +func TestLen_BoundaryPlusOne(t *testing.T) { + params := billingexpr.TokenParams{P: 100000, C: 1000, Len: 200001, CR: 100001} + _, trace, err := billingexpr.RunExpr(lenTieredExpr, params) + if err != nil { + t.Fatal(err) + } + if trace.MatchedTier != "long_context" { + t.Errorf("tier = %q, want long_context (len=200001 > 200000)", trace.MatchedTier) + } +} + +func TestLen_ZeroDefaultsToZero(t *testing.T) { + // len defaults to 0 when not set + params := billingexpr.TokenParams{P: 1000, C: 500} + _, trace, err := billingexpr.RunExpr(lenTieredExpr, params) + if err != nil { + t.Fatal(err) + } + if trace.MatchedTier != "standard" { + t.Errorf("tier = %q, want standard (len=0 <= 200000)", trace.MatchedTier) + } +} + // --------------------------------------------------------------------------- // Benchmarks: compile vs cached execution // --------------------------------------------------------------------------- -const benchComplexExpr = `p <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)` +const benchComplexExpr = `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6 + img * 3 + img_o * 30 + ai * 10 + ao * 40) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12 + img * 6 + img_o * 60 + ai * 20 + ao * 80)` func BenchmarkExprCompile(b *testing.B) { for i := 0; i < b.N; i++ { @@ -1015,7 +1086,7 @@ func BenchmarkExprCompile(b *testing.B) { func BenchmarkExprRunCached(b *testing.B) { billingexpr.CompileFromCache(benchComplexExpr) - params := billingexpr.TokenParams{P: 150000, C: 10000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500} + params := billingexpr.TokenParams{P: 150000, C: 10000, Len: 188000, CR: 30000, CC: 5000, Img: 2000, AI: 1000, AO: 500} b.ResetTimer() for i := 0; i < b.N; i++ { billingexpr.RunExpr(benchComplexExpr, params) diff --git a/pkg/billingexpr/compile.go b/pkg/billingexpr/compile.go index 089b75f67f..c41aed75c7 100644 --- a/pkg/billingexpr/compile.go +++ b/pkg/billingexpr/compile.go @@ -41,6 +41,7 @@ var ( var compileEnvPrototypeV1 = map[string]interface{}{ "p": float64(0), "c": float64(0), + "len": float64(0), "cr": float64(0), "cc": float64(0), "cc1h": float64(0), diff --git a/pkg/billingexpr/expr.md b/pkg/billingexpr/expr.md index ab3b716433..89894ab0b4 100644 --- a/pkg/billingexpr/expr.md +++ b/pkg/billingexpr/expr.md @@ -30,7 +30,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are | 变量 | 含义 | |------|------| -| `p` | 输入 token 数。**自动排除**表达式中单独计价的子类别(见下方说明) | +| `p` | 输入 token 数(**计价用**)。**自动排除**表达式中单独计价的子类别(见下方说明) | +| `len` | 输入上下文总长度(**条件判断用**)。不受自动排除影响,始终反映完整输入长度。非 Claude:等于原始 `prompt_tokens`;Claude:等于文本输入 + 缓存读取 + 缓存创建 | | `cr` | 缓存命中(读取)token 数 | | `cc` | 缓存创建 token 数(Claude 5分钟 TTL / 通用) | | `cc1h` | 缓存创建 token 数 — 1小时 TTL(Claude 专用) | @@ -51,6 +52,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are **规则:如果表达式使用了某个子类别变量,对应的 token 就从 `p` 或 `c` 中扣除;如果没使用,那些 token 就留在 `p` 或 `c` 里按基础价格计费。** +> **重要:`len` 不受自动排除影响。** `len` 始终代表完整的输入上下文长度,不管表达式是否单独对缓存/图片/音频定价。因此**阶梯条件应使用 `len` 而非 `p`**,以避免缓存命中导致 `p` 降低而误判档位。 + 举例说明(假设上游返回的原始数据:prompt_tokens=1000,其中包含 200 cache read、100 image): | 表达式 | `p` 的值 | 说明 | @@ -93,8 +96,8 @@ Powered by [expr-lang/expr](https://github.com/expr-lang/expr). Expressions are # Simple flat pricing tier("base", p * 2.5 + c * 15 + cr * 0.25) -# Multi-tier (Claude Sonnet style) -p <= 200000 +# Multi-tier (Claude Sonnet style) — use len for tier conditions +len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6 + cc * 7.5 + cc1h * 12) @@ -199,6 +202,16 @@ Example: `p * 2.5 + c * 15 + cr * 0.25` - Expression uses `cr` → cache read tokens subtracted from `p` - Expression doesn't use `img` → image tokens stay in `p`, priced at $2.50 +### `len` — Context Length Variable + +`len` represents the total input context length, designed for **tier condition evaluation** (e.g. `len <= 200000 ? ...`). Unlike `p`, `len` is never reduced by sub-category exclusion. + +**Computation rules:** +- **Non-Claude (GPT/OpenAI format)**: `len = prompt_tokens` (the raw total from the upstream response) +- **Claude format**: `len = input_tokens + cache_read_tokens + cache_creation_tokens` (since Claude's `input_tokens` is text-only, cache must be added back to reflect full context length) + +This ensures that heavy cache usage doesn't cause the tier condition to incorrectly evaluate to a lower tier. For example, if a request has 300K total context but 250K is cached, `p` with cache subtracted would be only 50K (standard tier), while `len` correctly reports 300K (long-context tier). + ### Quota Conversion Expression coefficients are $/1M tokens. Conversion to internal quota: diff --git a/pkg/billingexpr/run.go b/pkg/billingexpr/run.go index 9df43b39fe..d477d44e91 100644 --- a/pkg/billingexpr/run.go +++ b/pkg/billingexpr/run.go @@ -13,7 +13,8 @@ import ( // RunExpr compiles (with cache) and executes an expression string. // The environment exposes: -// - p, c — prompt / completion tokens +// - p, c — prompt / completion tokens (auto-excluding separately-priced sub-categories) +// - len — total input context length for tier conditions (never reduced by sub-category exclusion) // - cr, cc, cc1h — cache read / creation / creation-1h tokens // - tier(name, value) — trace callback that records which tier matched // - max, min, abs, ceil, floor — standard math helpers @@ -54,6 +55,7 @@ func runProgram(prog *vm.Program, params TokenParams, request RequestInput) (flo env := map[string]interface{}{ "p": params.P, "c": params.C, + "len": params.Len, "cr": params.CR, "cc": params.CC, "cc1h": params.CC1h, diff --git a/pkg/billingexpr/types.go b/pkg/billingexpr/types.go index 5e43339419..12e0d3c685 100644 --- a/pkg/billingexpr/types.go +++ b/pkg/billingexpr/types.go @@ -14,8 +14,9 @@ type RequestInput struct { // Fields beyond P and C are optional — when absent they default to 0, // which means cache-unaware expressions keep working unchanged. type TokenParams struct { - P float64 // prompt tokens (text) - C float64 // completion tokens (text) + P float64 // prompt tokens (text) — auto-excludes sub-categories priced separately + C float64 // completion tokens (text) — auto-excludes sub-categories priced separately + Len float64 // total input context length for tier conditions (non-Claude: raw prompt_tokens; Claude: text + cache read + cache creation) CR float64 // cache read (hit) tokens CC float64 // cache creation tokens (5-min TTL for Claude, generic for others) CC1h float64 // cache creation tokens — 1-hour TTL (Claude only) diff --git a/relay/helper/price.go b/relay/helper/price.go index a078325261..0e68edba20 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -255,8 +255,9 @@ func modelPriceHelperTiered(c *gin.Context, info *relaycommon.RelayInfo, promptT } rawCost, trace, err := billingexpr.RunExprWithRequest(exprStr, billingexpr.TokenParams{ - P: float64(promptTokens), - C: float64(estimatedCompletionTokens), + P: float64(promptTokens), + C: float64(estimatedCompletionTokens), + Len: float64(promptTokens), }, requestInput) if err != nil { return types.PriceData{}, fmt.Errorf("model %s tiered expr run failed: %w", info.OriginModelName, err) diff --git a/service/quota.go b/service/quota.go index 1f1f76aefd..398bd1b792 100644 --- a/service/quota.go +++ b/service/quota.go @@ -160,8 +160,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod var tieredResult *billingexpr.TieredResult tieredOk, tieredQuota, tieredRes := TryTieredSettle(relayInfo, billingexpr.TokenParams{ - P: float64(usage.InputTokens), - C: float64(usage.OutputTokens), + P: float64(usage.InputTokens), + C: float64(usage.OutputTokens), + Len: float64(usage.InputTokens), }) if tieredOk { tieredResult = tieredRes diff --git a/service/tiered_settle.go b/service/tiered_settle.go index fd168ab2ab..a97ec088d0 100644 --- a/service/tiered_settle.go +++ b/service/tiered_settle.go @@ -35,6 +35,14 @@ func BuildTieredTokenParams(usage *dto.Usage, isClaudeUsageSemantic bool, usedVa imgO := float64(usage.CompletionTokenDetails.ImageTokens) ao := float64(usage.CompletionTokenDetails.AudioTokens) + // len = total input context length for tier condition evaluation. + // Non-Claude: prompt_tokens already includes everything. + // Claude: input_tokens is text-only, so add cache read + cache creation. + inputLen := p + if isClaudeUsageSemantic { + inputLen = p + cr + cc5m + cc1h + } + if !isClaudeUsageSemantic { if usedVars["cr"] { p -= cr @@ -69,6 +77,7 @@ func BuildTieredTokenParams(usage *dto.Usage, isClaudeUsageSemantic bool, usedVa return billingexpr.TokenParams{ P: p, C: c, + Len: inputLen, CR: cr, CC: cc5m, CC1h: cc1h, diff --git a/service/tiered_settle_test.go b/service/tiered_settle_test.go index b7ba9f287c..cb6676fcc4 100644 --- a/service/tiered_settle_test.go +++ b/service/tiered_settle_test.go @@ -604,6 +604,97 @@ func TestBuildTieredTokenParams_ParityWithRatio_Image(t *testing.T) { } } +// --------------------------------------------------------------------------- +// BuildTieredTokenParams: Len computation tests +// --------------------------------------------------------------------------- + +func TestBuildTieredTokenParams_Len_GPT(t *testing.T) { + usage := &dto.Usage{ + PromptTokens: 10000, + CompletionTokens: 2000, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 3000, + TextTokens: 7000, + }, + } + expr := `tier("base", p * 2.5 + c * 15 + cr * 0.25)` + usedVars := billingexpr.UsedVars(expr) + params := BuildTieredTokenParams(usage, false, usedVars) + + // Non-Claude: Len = raw PromptTokens + if params.Len != 10000 { + t.Fatalf("Len = %f, want 10000 (raw PromptTokens)", params.Len) + } + // P should be reduced by cache + if params.P != 7000 { + t.Fatalf("P = %f, want 7000 (PromptTokens - CachedTokens)", params.P) + } +} + +func TestBuildTieredTokenParams_Len_Claude(t *testing.T) { + usage := &dto.Usage{ + PromptTokens: 5000, + CompletionTokens: 2000, + UsageSemantic: "anthropic", + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 3000, + TextTokens: 5000, + }, + ClaudeCacheCreation5mTokens: 1000, + ClaudeCacheCreation1hTokens: 500, + } + expr := `tier("base", p * 3 + c * 15 + cr * 0.3 + cc * 3.75 + cc1h * 6)` + usedVars := billingexpr.UsedVars(expr) + params := BuildTieredTokenParams(usage, true, usedVars) + + // Claude: Len = PromptTokens + CachedTokens + CacheCreation5m + CacheCreation1h + wantLen := float64(5000 + 3000 + 1000 + 500) + if params.Len != wantLen { + t.Fatalf("Len = %f, want %f (text + cache read + cache creation)", params.Len, wantLen) + } + // Claude: P is not reduced (isClaudeUsageSemantic = true) + if params.P != 5000 { + t.Fatalf("P = %f, want 5000 (no subtraction for Claude)", params.P) + } +} + +func TestBuildTieredTokenParams_Len_TierCondition(t *testing.T) { + // Test that len-based tier conditions work correctly when p is reduced by cache + usage := &dto.Usage{ + PromptTokens: 300000, + CompletionTokens: 5000, + PromptTokensDetails: dto.InputTokenDetails{ + CachedTokens: 250000, + TextTokens: 50000, + }, + } + expr := `len <= 200000 ? tier("standard", p * 3 + c * 15 + cr * 0.3) : tier("long_context", p * 6 + c * 22.5 + cr * 0.6)` + usedVars := billingexpr.UsedVars(expr) + params := BuildTieredTokenParams(usage, false, usedVars) + + // Len = 300000 (raw prompt), P = 50000 (300000 - 250000 cache) + if params.Len != 300000 { + t.Fatalf("Len = %f, want 300000", params.Len) + } + if params.P != 50000 { + t.Fatalf("P = %f, want 50000", params.P) + } + + // Run expression: len=300000 > 200000, so long_context tier + cost, trace, err := billingexpr.RunExpr(expr, params) + if err != nil { + t.Fatal(err) + } + if trace.MatchedTier != "long_context" { + t.Fatalf("tier = %s, want long_context (len=300000 but p=50000)", trace.MatchedTier) + } + // long_context: 50000*6 + 5000*22.5 + 250000*0.6 + wantCost := 50000.0*6 + 5000*22.5 + 250000*0.6 + if math.Abs(cost-wantCost) > 1e-6 { + t.Fatalf("cost = %f, want %f", cost, wantCost) + } +} + // --------------------------------------------------------------------------- // Stress test: 1000 concurrent goroutines, complex tiered expr vs ratio, // random token counts, verify correctness and measure performance diff --git a/setting/billing_setting/tiered_billing.go b/setting/billing_setting/tiered_billing.go index 65f0ef2da4..8d5b6f0f4a 100644 --- a/setting/billing_setting/tiered_billing.go +++ b/setting/billing_setting/tiered_billing.go @@ -54,10 +54,10 @@ func SmokeTestExpr(exprStr string) error { func smokeTestExpr(exprStr string) error { vectors := []billingexpr.TokenParams{ - {P: 0, C: 0}, - {P: 1000, C: 1000}, - {P: 100000, C: 100000}, - {P: 1000000, C: 1000000}, + {P: 0, C: 0, Len: 0}, + {P: 1000, C: 1000, Len: 1000}, + {P: 100000, C: 100000, Len: 100000}, + {P: 1000000, C: 1000000, Len: 1000000}, } requests := []billingexpr.RequestInput{ {}, diff --git a/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx b/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx index 794627dd5f..23d1712a1b 100644 --- a/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx +++ b/web/src/components/table/model-pricing/modal/components/DynamicPricingBreakdown.jsx @@ -21,7 +21,7 @@ import React from 'react'; import { Avatar, Tag, Table, Typography } from '@douyinfe/semi-ui'; import { IconPriceTag } from '@douyinfe/semi-icons'; import { parseTiersFromExpr, getCurrencyConfig } from '../../../../../helpers'; -import { BILLING_VARS } from '../../../../../constants'; +import { BILLING_PRICING_VARS } from '../../../../../constants'; import { splitBillingExprAndRequestRules, tryParseRequestRuleExpr, @@ -113,7 +113,7 @@ export default function DynamicPricingBreakdown({ billingExpr, t }) { ); } - const priceFields = BILLING_VARS.map((v) => [v.field, v.shortLabel]); + const priceFields = BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]); const tierColumns = [ { diff --git a/web/src/constants/billing.constants.js b/web/src/constants/billing.constants.js index 79ef32866d..28114808bd 100644 --- a/web/src/constants/billing.constants.js +++ b/web/src/constants/billing.constants.js @@ -13,6 +13,7 @@ export const BILLING_VARS = [ { key: 'p', field: 'inputPrice', tierField: 'input_unit_cost', label: '输入价格', shortLabel: '输入', side: 'input', isBase: true }, { key: 'c', field: 'outputPrice', tierField: 'output_unit_cost', label: '补全价格', shortLabel: '补全', side: 'output', isBase: true }, + { key: 'len', field: null, tierField: null, label: '输入长度', shortLabel: '长度', side: 'condition', isConditionOnly: true }, { key: 'cr', field: 'cacheReadPrice', tierField: 'cache_read_unit_cost', label: '缓存读取价格', shortLabel: '缓存读', side: 'input', group: 'cache' }, { key: 'cc', field: 'cacheCreatePrice', tierField: 'cache_create_unit_cost', label: '缓存创建价格', shortLabel: '缓存创建', side: 'input', group: 'cache' }, { key: 'cc1h', field: 'cacheCreate1hPrice', tierField: 'cache_create_1h_unit_cost', label: '1h缓存创建价格', shortLabel: '1h缓存创建', side: 'input', group: 'cache' }, @@ -24,18 +25,20 @@ export const BILLING_VARS = [ export const BILLING_VAR_KEYS = BILLING_VARS.map((v) => v.key); -export const BILLING_EXTRA_VARS = BILLING_VARS.filter((v) => !v.isBase); +export const BILLING_PRICING_VARS = BILLING_VARS.filter((v) => !v.isConditionOnly); + +export const BILLING_EXTRA_VARS = BILLING_VARS.filter((v) => !v.isBase && !v.isConditionOnly); export const BILLING_VAR_KEY_TO_FIELD = Object.fromEntries( - BILLING_VARS.map((v) => [v.key, v.field]), + BILLING_PRICING_VARS.map((v) => [v.key, v.field]), ); export const BILLING_VAR_FIELD_TO_LABEL = Object.fromEntries( - BILLING_VARS.map((v) => [v.field, v.label]), + BILLING_PRICING_VARS.map((v) => [v.field, v.label]), ); export const BILLING_VAR_FIELD_TO_SHORT_LABEL = Object.fromEntries( - BILLING_VARS.map((v) => [v.field, v.shortLabel]), + BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]), ); export const BILLING_CACHE_VAR_MAP = BILLING_EXTRA_VARS.map((v) => ({ @@ -44,6 +47,10 @@ export const BILLING_CACHE_VAR_MAP = BILLING_EXTRA_VARS.map((v) => ({ })); export const BILLING_VAR_REGEX = new RegExp( - `\\b(${BILLING_VAR_KEYS.join('|')})\\s*\\*\\s*([\\d.eE+-]+)`, + `\\b(${BILLING_PRICING_VARS.map((v) => v.key).join('|')})\\s*\\*\\s*([\\d.eE+-]+)`, 'g', ); + +export const BILLING_CONDITION_VARS = BILLING_VARS.filter( + (v) => v.isBase || v.isConditionOnly, +).map((v) => v.key); diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index d7ba654648..f5ce1f0391 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -22,7 +22,7 @@ import { Modal, Tag, Typography, Avatar } from '@douyinfe/semi-ui'; import { copy, showSuccess } from './utils'; import { MOBILE_BREAKPOINT } from '../hooks/common/useIsMobile'; import { - BILLING_VARS, + BILLING_PRICING_VARS, BILLING_VAR_KEY_TO_FIELD, BILLING_VAR_REGEX, } from '../constants'; @@ -2246,7 +2246,7 @@ export function parseTiersFromExpr(exprStr) { if (!exprStr) return []; try { const { body } = stripExprVersion(exprStr); - const condGroup = `((?:(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`; + const condGroup = `((?:(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`; const tierRe = new RegExp(`(?:${condGroup}\\s*\\?\\s*)?tier\\("([^"]*)",\\s*([^)]+)\\)`, 'g'); const tiers = []; let m; @@ -2255,7 +2255,7 @@ export function parseTiersFromExpr(exprStr) { const conditions = []; if (condStr) { for (const cp of condStr.split(/\s*&&\s*/)) { - const cm = cp.trim().match(/^(p|c)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/); + const cm = cp.trim().match(/^(p|c|len)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/); if (cm) conditions.push({ var: cm[1], op: cm[2], value: Number(cm[3]) }); } } @@ -2293,7 +2293,7 @@ export function renderTieredModelPrice(opts) { const { symbol, rate } = getCurrencyConfig(); const gr = groupRatio || 1; - const priceLines = BILLING_VARS.map((v) => [v.field, v.label]); + const priceLines = BILLING_PRICING_VARS.map((v) => [v.field, v.label]); const lines = [ buildBillingText('命中档位:{{tier}}', { tier: matchedTier || tier.label }), @@ -2334,7 +2334,7 @@ export function renderTieredModelPriceSimple(opts) { ]; if (tier && isPriceDisplayMode(displayMode)) { - const priceSegments = BILLING_VARS.map((v) => [v.field, v.shortLabel]); + const priceSegments = BILLING_PRICING_VARS.map((v) => [v.field, v.shortLabel]); for (const [field, label] of priceSegments) { if (tier[field] > 0) { segments.push({ diff --git a/web/src/helpers/utils.jsx b/web/src/helpers/utils.jsx index a4af4f2431..7c7e63c737 100644 --- a/web/src/helpers/utils.jsx +++ b/web/src/helpers/utils.jsx @@ -18,7 +18,7 @@ For commercial licensing, please contact support@quantumnous.com */ import { Toast, Pagination } from '@douyinfe/semi-ui'; -import { toastConstants, BILLING_VARS, BILLING_VAR_REGEX } from '../constants'; +import { toastConstants, BILLING_PRICING_VARS, BILLING_VAR_REGEX } from '../constants'; import React from 'react'; import { toast } from 'react-toastify'; import { @@ -927,7 +927,7 @@ export const formatDynamicPriceSummary = (billingExpr, t, groupRatio = 1) => { } const hasCoeffs = 'p' in varCoeffs || 'c' in varCoeffs; - const varLabels = BILLING_VARS.map((v) => [v.key, v.label]); + const varLabels = BILLING_PRICING_VARS.map((v) => [v.key, v.label]); const hasTimeCondition = /\b(?:hour|minute|weekday|month|day)\(/.test(exprBody); const hasRequestCondition = /\b(?:param|header)\(/.test(exprBody); diff --git a/web/src/pages/Setting/Ratio/components/TieredPricingEditor.jsx b/web/src/pages/Setting/Ratio/components/TieredPricingEditor.jsx index ec06a3409c..4ad94f736c 100644 --- a/web/src/pages/Setting/Ratio/components/TieredPricingEditor.jsx +++ b/web/src/pages/Setting/Ratio/components/TieredPricingEditor.jsx @@ -31,9 +31,10 @@ import { TextArea, Typography, } from '@douyinfe/semi-ui'; -import { IconDelete, IconPlus } from '@douyinfe/semi-icons'; +import { IconCopy, IconDelete, IconPlus } from '@douyinfe/semi-icons'; import { renderQuota } from '../../../../helpers/render'; -import { BILLING_EXTRA_VARS, BILLING_CACHE_VAR_MAP } from '../../../../constants'; +import { copy, showSuccess } from '../../../../helpers'; +import { BILLING_EXTRA_VARS, BILLING_CACHE_VAR_MAP, BILLING_CONDITION_VARS } from '../../../../constants'; import { createEmptyCondition, createEmptyTimeCondition, @@ -70,6 +71,7 @@ function priceToUnitCost(price) { const OPS = ['<', '<=', '>', '>=']; const VAR_OPTIONS = [ + { value: 'len', label: 'len (长度)' }, { value: 'p', label: 'p (输入)' }, { value: 'c', label: 'c (输出)' }, ]; @@ -224,7 +226,7 @@ function tryParseVisualConfig(exprStr) { } // Multi-tier: cond1 ? tier(body) : cond2 ? tier(body) : tier(body) - const condGroup = `((?:(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`; + const condGroup = `((?:(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)(?:\\s*&&\\s*(?:p|c|len)\\s*(?:<|<=|>|>=)\\s*[\\d.eE+]+)*)`; const tierRe = new RegExp( `(?:${condGroup}\\s*\\?\\s*)?tier\\("([^"]*)",\\s*${bodyPat}\\)`, 'g', @@ -237,7 +239,7 @@ function tryParseVisualConfig(exprStr) { if (condStr) { const condParts = condStr.split(/\s*&&\s*/); for (const cp of condParts) { - const cm = cp.trim().match(/^(p|c)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/); + const cm = cp.trim().match(/^(p|c|len)\s*(<|<=|>|>=)\s*([\d.eE+]+)$/); if (cm) { conditions.push({ var: cm[1], op: cm[2], value: Number(cm[3]) }); } @@ -283,7 +285,7 @@ function ConditionRow({ cond, onChange, onRemove, t }) { }}>