From 730d996276221b04b471c8bef1196f3a2059394e Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Fri, 22 May 2026 21:54:42 +0800 Subject: [PATCH 1/8] feat(channels/yuanbao): add Yuanbao channel provider Adds a new Yuanbao channel provider so OpenHuman can talk to the Yuanbao bot service over signed WebSocket + HTTPS. - Backend: full provider under `src/openhuman/channels/providers/yuanbao/` (signed WS connection, inbound/outbound pipelines, COS media upload, protobuf wire codec, ID shortening for conversation-store filenames), wired through channel config schema, registry, runtime startup, controllers, and CLI. - Frontend: dedicated `YuanbaoConfig` form + icon, hooked into the channel selector, setup modal, and connection slice using the existing per-channel definition pattern. - Endpoint config: prod / pre-release defaults selectable via `env` field; no test credentials in code. --- Cargo.lock | 1 + Cargo.toml | 3 + app/src-tauri/Cargo.lock | 1 + .../components/channels/ChannelSelector.tsx | 16 +- .../components/channels/ChannelSetupModal.tsx | 7 +- app/src/components/channels/YuanbaoConfig.tsx | 296 ++++++++ app/src/components/channels/YuanbaoIcon.tsx | 56 ++ app/src/components/skills/skillIcons.tsx | 10 + app/src/store/channelConnectionsSlice.ts | 17 + app/src/types/channels.ts | 2 +- src/openhuman/channels/commands.rs | 8 + .../channels/controllers/definitions.rs | 39 + src/openhuman/channels/controllers/ops.rs | 98 +++ .../channels/controllers/ops_tests.rs | 121 ++++ src/openhuman/channels/mod.rs | 2 + src/openhuman/channels/providers/mod.rs | 1 + .../channels/providers/yuanbao/channel.rs | 531 ++++++++++++++ .../channels/providers/yuanbao/config.rs | 250 +++++++ .../channels/providers/yuanbao/connection.rs | 571 +++++++++++++++ .../channels/providers/yuanbao/cos.rs | 392 ++++++++++ .../channels/providers/yuanbao/errors.rs | 61 ++ .../channels/providers/yuanbao/ids.rs | 115 +++ .../channels/providers/yuanbao/inbound.rs | 623 ++++++++++++++++ .../channels/providers/yuanbao/media.rs | 339 +++++++++ .../channels/providers/yuanbao/mod.rs | 29 + .../channels/providers/yuanbao/outbound.rs | 378 ++++++++++ .../channels/providers/yuanbao/proto.rs | 676 ++++++++++++++++++ .../channels/providers/yuanbao/proto_biz.rs | 417 +++++++++++ .../providers/yuanbao/proto_constants.rs | 90 +++ .../channels/providers/yuanbao/sign.rs | 446 ++++++++++++ .../channels/providers/yuanbao/splitter.rs | 209 ++++++ .../channels/providers/yuanbao/types.rs | 265 +++++++ .../channels/providers/yuanbao/wire.rs | 232 ++++++ src/openhuman/channels/runtime/startup.rs | 8 + src/openhuman/config/schema/channels.rs | 3 + 35 files changed, 6308 insertions(+), 5 deletions(-) create mode 100644 app/src/components/channels/YuanbaoConfig.tsx create mode 100644 app/src/components/channels/YuanbaoIcon.tsx create mode 100644 src/openhuman/channels/providers/yuanbao/channel.rs create mode 100644 src/openhuman/channels/providers/yuanbao/config.rs create mode 100644 src/openhuman/channels/providers/yuanbao/connection.rs create mode 100644 src/openhuman/channels/providers/yuanbao/cos.rs create mode 100644 src/openhuman/channels/providers/yuanbao/errors.rs create mode 100644 src/openhuman/channels/providers/yuanbao/ids.rs create mode 100644 src/openhuman/channels/providers/yuanbao/inbound.rs create mode 100644 src/openhuman/channels/providers/yuanbao/media.rs create mode 100644 src/openhuman/channels/providers/yuanbao/mod.rs create mode 100644 src/openhuman/channels/providers/yuanbao/outbound.rs create mode 100644 src/openhuman/channels/providers/yuanbao/proto.rs create mode 100644 src/openhuman/channels/providers/yuanbao/proto_biz.rs create mode 100644 src/openhuman/channels/providers/yuanbao/proto_constants.rs create mode 100644 src/openhuman/channels/providers/yuanbao/sign.rs create mode 100644 src/openhuman/channels/providers/yuanbao/splitter.rs create mode 100644 src/openhuman/channels/providers/yuanbao/types.rs create mode 100644 src/openhuman/channels/providers/yuanbao/wire.rs diff --git a/Cargo.lock b/Cargo.lock index 8a062d75a5..1f23a1de55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5141,6 +5141,7 @@ dependencies = [ "serde-big-array", "serde_json", "serde_yaml", + "sha1", "sha2 0.10.9", "shellexpand", "socketioxide", diff --git a/Cargo.toml b/Cargo.toml index 06409c20fb..6ad486f4e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,9 @@ argon2 = "0.5" rand = "0.10" dirs = "5" sha2 = "0.10" +# Legacy SHA-1 only used for Tencent COS HMAC-SHA1 signing (yuanbao +# channel media upload). Not used for any new security-sensitive work. +sha1 = "0.10" hmac = "0.12" # Archive extraction for the Node.js runtime bootstrap. Unix Node # distributions ship as .tar.xz, Windows as .zip. `xz2` with `static` diff --git a/app/src-tauri/Cargo.lock b/app/src-tauri/Cargo.lock index 74ff66efb9..18ec6d2660 100644 --- a/app/src-tauri/Cargo.lock +++ b/app/src-tauri/Cargo.lock @@ -5286,6 +5286,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "sha1", "sha2 0.10.9", "shellexpand", "socketioxide", diff --git a/app/src/components/channels/ChannelSelector.tsx b/app/src/components/channels/ChannelSelector.tsx index 6e41f598a0..770587146b 100644 --- a/app/src/components/channels/ChannelSelector.tsx +++ b/app/src/components/channels/ChannelSelector.tsx @@ -5,6 +5,7 @@ import { useT } from '../../lib/i18n/I18nContext'; import { useAppSelector } from '../../store/hooks'; import type { ChannelConnectionStatus, ChannelDefinition, ChannelType } from '../../types/channels'; import ChannelStatusBadge from './ChannelStatusBadge'; +import YuanbaoIcon from './YuanbaoIcon'; interface ChannelSelectorProps { definitions: ChannelDefinition[]; @@ -12,17 +13,28 @@ interface ChannelSelectorProps { onSelectChannel: (channel: ChannelType) => void; } +// Emoji icons for channels rendered as plain text. `yuanbao` is handled +// separately with a branded SVG (see `YuanbaoIcon`). const CHANNEL_ICONS: Record = { telegram: '✈️', discord: '🎮', web: '🌐', + yuanbao: '🟡', mcp: '🔌', }; +const renderChannelIcon = (icon: string) => + icon === 'yuanbao' ? ( + + ) : ( + {CHANNEL_ICONS[icon] ?? ''} + ); + /** Virtual (static) tabs that are not backed by a ChannelDefinition from the core. */ const VIRTUAL_TABS: { id: ChannelType; display_name: string }[] = [ { id: 'mcp', display_name: 'MCP Servers' }, ]; + const CHANNEL_STATUS_PRIORITY: ChannelConnectionStatus[] = [ 'connected', 'connecting', @@ -84,7 +96,7 @@ const ChannelSelector = ({ : 'border-stone-200 dark:border-neutral-800 bg-stone-50 dark:bg-neutral-800/60 text-stone-600 dark:text-neutral-300 hover:border-stone-300 dark:hover:border-neutral-700' }`}> - {CHANNEL_ICONS[def.icon] ?? ''} + {renderChannelIcon(def.icon)} {def.display_name} @@ -105,7 +117,7 @@ const ChannelSelector = ({ ? 'border-primary-500/60 bg-primary-50 dark:bg-primary-500/15 text-primary-600 dark:text-primary-300' : 'border-stone-200 dark:border-neutral-800 bg-stone-50 dark:bg-neutral-800/60 text-stone-600 dark:text-neutral-300 hover:border-stone-300 dark:hover:border-neutral-700' }`}> - {CHANNEL_ICONS[tab.id] ?? ''} + {renderChannelIcon(tab.id)} {tab.display_name} ); diff --git a/app/src/components/channels/ChannelSetupModal.tsx b/app/src/components/channels/ChannelSetupModal.tsx index cf81b7c7e2..c3b2502294 100644 --- a/app/src/components/channels/ChannelSetupModal.tsx +++ b/app/src/components/channels/ChannelSetupModal.tsx @@ -9,6 +9,7 @@ import { useT } from '../../lib/i18n/I18nContext'; import type { ChannelDefinition, ChannelType } from '../../types/channels'; import DiscordConfig from './DiscordConfig'; import TelegramConfig from './TelegramConfig'; +import YuanbaoConfig from './YuanbaoConfig'; const CHANNEL_ICONS: Record = { telegram: '\u2708\uFE0F', @@ -29,6 +30,8 @@ function ChannelConfigContent({ definition }: { definition: ChannelDefinition }) return ; case 'discord': return ; + case 'yuanbao': + return ; default: return (

@@ -62,7 +65,7 @@ export default function ChannelSetupModal({ definition, onClose }: ChannelSetupM if (e.target === e.currentTarget) onClose(); }; - const icon = CHANNEL_ICONS[definition.icon] ?? ''; + const emojiIcon = CHANNEL_ICONS[definition.icon] ?? ''; const modalContent = (

- {icon && {icon}} + {emojiIcon && {emojiIcon}}

diff --git a/app/src/components/channels/YuanbaoConfig.tsx b/app/src/components/channels/YuanbaoConfig.tsx new file mode 100644 index 0000000000..33b78f8986 --- /dev/null +++ b/app/src/components/channels/YuanbaoConfig.tsx @@ -0,0 +1,296 @@ +import debug from 'debug'; +import { useCallback, useEffect, useState } from 'react'; + +import { AUTH_MODE_LABELS } from '../../lib/channels/definitions'; +import { useT } from '../../lib/i18n/I18nContext'; +import { channelConnectionsApi } from '../../services/api/channelConnectionsApi'; +import { + disconnectChannelConnection, + setChannelConnectionStatus, + upsertChannelConnection, +} from '../../store/channelConnectionsSlice'; +import { useAppDispatch, useAppSelector } from '../../store/hooks'; +import type { ChannelConnectionStatus, ChannelDefinition } from '../../types/channels'; +import { restartCoreProcess } from '../../utils/tauriCommands/core'; +import ChannelFieldInput from './ChannelFieldInput'; +import ChannelStatusBadge from './ChannelStatusBadge'; + +const log = debug('channels:yuanbao'); + +interface YuanbaoConfigProps { + definition: ChannelDefinition; +} + +const YuanbaoConfig = ({ definition }: YuanbaoConfigProps) => { + const { t } = useT(); + const dispatch = useAppDispatch(); + const channelConnections = useAppSelector(state => state.channelConnections); + + const [busy, setBusy] = useState(false); + const [fieldValues, setFieldValues] = useState>({ + app_key: '', + app_secret: '', + }); + // Per-field inline validation errors, keyed by field.key. + const [fieldErrors, setFieldErrors] = useState>({}); + + const updateField = useCallback((fieldKey: string, value: string) => { + setFieldValues(prev => ({ ...prev, [fieldKey]: value })); + // Clear the error for this field as the user types. + setFieldErrors(prev => { + if (!prev[fieldKey]) return prev; + const next = { ...prev }; + delete next[fieldKey]; + return next; + }); + }, []); + + const spec = definition.auth_modes[0]; + + // On mount, reset any stale 'connecting' state persisted from a previous session. + useEffect(() => { + if (!spec) return; + const conn = channelConnections.connections.yuanbao?.[spec.mode]; + if (conn?.status === 'connecting') { + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'disconnected', + }) + ); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // All useCallback hooks must be called unconditionally. + const handleConnect = useCallback(() => { + console.log('[YuanbaoConfig] handleConnect: 1.entry, spec=', spec); + if (!spec) { + console.warn('[YuanbaoConfig] handleConnect: aborted — spec is null'); + return; + } + + const errors: Record = {}; + for (const field of spec.fields) { + const empty = !fieldValues[field.key]?.trim(); + if (field.required && empty) { + errors[field.key] = `${field.label} 不能为空`; + } + } + if (Object.keys(errors).length > 0) { + console.warn('[YuanbaoConfig] handleConnect: 2.validation FAILED', errors); + setFieldErrors(errors); + return; + } + console.log('[YuanbaoConfig] handleConnect: 2.validation passed'); + + setFieldErrors({}); + setBusy(true); + + dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: spec.mode, status: 'connecting' }) + ); + + const credentials: Record = {}; + for (const field of spec.fields) { + const val = fieldValues[field.key]?.trim() ?? ''; + if (val) credentials[field.key] = val; + } + console.log( + '[YuanbaoConfig] handleConnect: 3.dispatched connecting, credential keys=', + Object.keys(credentials) + ); + + void (async () => { + try { + console.log( + '[YuanbaoConfig] handleConnect: 4.before channels_connect RPC, authMode=', + spec.mode + ); + log('connecting yuanbao via %s', spec.mode); + const result = await channelConnectionsApi.connectChannel('yuanbao', { + authMode: spec.mode, + credentials, + }); + console.log('[YuanbaoConfig] handleConnect: 5.RPC returned', result); + log('connect result: %o', result); + + if (result.restart_required) { + console.log( + '[YuanbaoConfig] handleConnect: 6.restart_required=true, calling restartCoreProcess' + ); + log('restart required after connect — restarting core process'); + try { + await restartCoreProcess(); + console.log( + '[YuanbaoConfig] handleConnect: 7.restartCoreProcess resolved, dispatching connected' + ); + dispatch( + upsertChannelConnection({ + channel: 'yuanbao', + authMode: spec.mode, + patch: { + status: 'connected', + lastError: undefined, + capabilities: ['read', 'write'], + }, + }) + ); + } catch (restartErr) { + const msg = restartErr instanceof Error ? restartErr.message : String(restartErr); + console.error('[YuanbaoConfig] handleConnect: 7.restartCoreProcess FAILED', restartErr); + log('core restart failed: %s', msg); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: t('channels.telegram.savedRestartRequired'), + }) + ); + } + } else { + console.log( + '[YuanbaoConfig] handleConnect: 6.restart_required=false, dispatching connected' + ); + dispatch( + upsertChannelConnection({ + channel: 'yuanbao', + authMode: spec.mode, + patch: { status: 'connected', lastError: undefined, capabilities: ['read', 'write'] }, + }) + ); + } + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + console.error('[YuanbaoConfig] handleConnect: X.caught error', e); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + } finally { + console.log('[YuanbaoConfig] handleConnect: 8.finally, setBusy(false)'); + setBusy(false); + } + })(); + }, [dispatch, fieldValues, spec, t]); + + const handleDisconnect = useCallback(() => { + if (!spec) return; + setBusy(true); + void (async () => { + try { + log('disconnecting yuanbao'); + await channelConnectionsApi.disconnectChannel('yuanbao', spec.mode); + dispatch(disconnectChannelConnection({ channel: 'yuanbao', authMode: spec.mode })); + } catch (e) { + const msg = e instanceof Error ? e.message : String(e); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + } finally { + setBusy(false); + } + })(); + }, [dispatch, spec]); + + if (!spec) return null; + + const connection = channelConnections.connections.yuanbao?.[spec.mode]; + const status: ChannelConnectionStatus = connection?.status ?? 'disconnected'; + + return ( +
+
+
+
+

+ {AUTH_MODE_LABELS[spec.mode] ?? spec.mode} +

+

{spec.description}

+ {connection?.lastError && ( +

{connection.lastError}

+ )} +
+ +
+ + {spec.fields.length > 0 && ( +
+ {spec.fields.map(field => { + return ( +
+ updateField(field.key, val)} + disabled={busy} + /> + {fieldErrors[field.key] && ( +

+ {fieldErrors[field.key]} +

+ )} +
+ ); + })} +
+ )} + +
+ + +
+
+
+ ); +}; + +export default YuanbaoConfig; diff --git a/app/src/components/channels/YuanbaoIcon.tsx b/app/src/components/channels/YuanbaoIcon.tsx new file mode 100644 index 0000000000..5a1261b9b9 --- /dev/null +++ b/app/src/components/channels/YuanbaoIcon.tsx @@ -0,0 +1,56 @@ +import { useId } from 'react'; + +interface YuanbaoIconProps { + /** + * Tailwind size + color overrides. Defaults to a 20px box, matching + * the visual weight of the channel-row emojis it sits next to. + */ + className?: string; +} + +/** + * Brand mark for the Yuanbao channel. Inlined as an SVG component so it + * can be tinted / sized via Tailwind without round-tripping through an + * `` element. `clipPath` ids are generated with `useId` so multiple + * instances on the same page (channel selector + setup modal) don't + * collide in the DOM. + */ +const YuanbaoIcon = ({ className = 'w-5 h-5' }: YuanbaoIconProps) => { + const clipId = useId(); + return ( + + ); +}; + +export default YuanbaoIcon; diff --git a/app/src/components/skills/skillIcons.tsx b/app/src/components/skills/skillIcons.tsx index f7d8038cad..8791fe7e96 100644 --- a/app/src/components/skills/skillIcons.tsx +++ b/app/src/components/skills/skillIcons.tsx @@ -2,6 +2,8 @@ import type { ReactNode } from 'react'; import type { IconType } from 'react-icons'; import { FaDiscord, FaGlobe, FaTelegramPlane } from 'react-icons/fa'; import { IoChatbubble } from 'react-icons/io5'; + +import YuanbaoIcon from '../channels/YuanbaoIcon'; import { LuBlocks, LuBot, @@ -84,6 +86,14 @@ export function getChannelIcons( iconClassName="text-[#34C759]" /> ), + yuanbao: ( + + + + ), }; } diff --git a/app/src/store/channelConnectionsSlice.ts b/app/src/store/channelConnectionsSlice.ts index faee8db9c5..b785894590 100644 --- a/app/src/store/channelConnectionsSlice.ts +++ b/app/src/store/channelConnectionsSlice.ts @@ -34,9 +34,21 @@ const initialState: ChannelConnectionsState = { // MCP Servers tab is a virtual channel — no auth-mode connections, // but must be present to satisfy Record. mcp: makeEmptyChannelModes(), + yuanbao: makeEmptyChannelModes(), }, }; +// Lazy-init a channel's mode bucket on first write. Protects writes against +// rehydrated state from older app versions (or any channel added after a user +// first persisted state) where `state.connections[channel]` is `undefined` +// because redux-persist's default `autoMergeLevel1` reconciler does not deep- +// merge into `connections`. +function ensureChannelModes(state: ChannelConnectionsState, channel: ChannelType): void { + if (!state.connections[channel]) { + state.connections[channel] = makeEmptyChannelModes(); + } +} + function touchConnection( existing: ChannelConnection | undefined, patch: Partial & { channel: ChannelType; authMode: ChannelAuthMode } @@ -68,11 +80,13 @@ const channelConnectionsSlice = createSlice({ // explicit initialisation here, the first `upsertChannelConnection` // for either channel would crash on `state.connections[channel]` // being undefined. Pin them by default so the migration is total. + state.connections.yuanbao = makeEmptyChannelModes(); state.connections.lark = makeEmptyChannelModes(); state.connections.dingtalk = makeEmptyChannelModes(); // MCP virtual channel must be present in persisted states migrated from // before PR #2276 or the Record shape is incomplete. state.connections.mcp = makeEmptyChannelModes(); + state.connections.yuanbao = makeEmptyChannelModes(); state.defaultMessagingChannel = 'telegram'; state.migrationCompleted = true; state.schemaVersion = SCHEMA_VERSION; @@ -91,6 +105,7 @@ const channelConnectionsSlice = createSlice({ }> ) { const { channel, authMode, patch } = action.payload; + ensureChannelModes(state, channel); const existing = state.connections[channel][authMode]; state.connections[channel][authMode] = touchConnection(existing, { channel, @@ -109,6 +124,7 @@ const channelConnectionsSlice = createSlice({ }> ) { const { channel, authMode, status, lastError } = action.payload; + ensureChannelModes(state, channel); const existing = state.connections[channel][authMode]; state.connections[channel][authMode] = touchConnection(existing, { channel, @@ -123,6 +139,7 @@ const channelConnectionsSlice = createSlice({ action: PayloadAction<{ channel: ChannelType; authMode: ChannelAuthMode }> ) { const { channel, authMode } = action.payload; + ensureChannelModes(state, channel); state.connections[channel][authMode] = touchConnection(state.connections[channel][authMode], { channel, authMode, diff --git a/app/src/types/channels.ts b/app/src/types/channels.ts index ad3be4cf16..05ca949fa3 100644 --- a/app/src/types/channels.ts +++ b/app/src/types/channels.ts @@ -1,4 +1,4 @@ -export type ChannelType = 'telegram' | 'discord' | 'web' | 'lark' | 'dingtalk' | 'mcp'; +export type ChannelType = 'telegram' | 'discord' | 'web' | 'lark' | 'dingtalk' | 'mcp' | 'yuanbao'; export type ChannelAuthMode = 'managed_dm' | 'oauth' | 'bot_token' | 'api_key'; diff --git a/src/openhuman/channels/commands.rs b/src/openhuman/channels/commands.rs index bd12e19343..105f40e119 100644 --- a/src/openhuman/channels/commands.rs +++ b/src/openhuman/channels/commands.rs @@ -17,6 +17,7 @@ use super::telegram::TelegramChannel; use super::whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] use super::whatsapp_web::WhatsAppWebChannel; +use super::yuanbao::YuanbaoChannel; use super::Channel; use crate::openhuman::config::Config; use anyhow::Result; @@ -235,6 +236,13 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + if let Some(ref yb) = config.channels_config.yuanbao { + match YuanbaoChannel::new(yb.clone()) { + Ok(ch) => channels.push(("Yuanbao", Arc::new(ch))), + Err(e) => tracing::warn!("Yuanbao config invalid, skipping: {}", e), + } + } + if channels.is_empty() { println!("No real-time channels configured. Configure channels in the web UI."); return Ok(()); diff --git a/src/openhuman/channels/controllers/definitions.rs b/src/openhuman/channels/controllers/definitions.rs index 445437e193..ccb7baad9f 100644 --- a/src/openhuman/channels/controllers/definitions.rs +++ b/src/openhuman/channels/controllers/definitions.rs @@ -160,6 +160,7 @@ pub fn all_channel_definitions() -> Vec { imessage_definition(), lark_definition(), dingtalk_definition(), + yuanbao_definition(), ] } @@ -444,6 +445,44 @@ fn dingtalk_definition() -> ChannelDefinition { } } +fn yuanbao_definition() -> ChannelDefinition { + // Endpoint URLs (api_domain / ws_domain) are not user-facing — the + // channel derives them from the `env` field of `YuanbaoConfig` + // (default: production). Advanced users can override via TOML. + ChannelDefinition { + id: "yuanbao", + display_name: "元宝", + description: "通过元宝(Yuanbao)机器人收发消息。", + icon: "yuanbao", + auth_modes: vec![AuthModeSpec { + mode: ChannelAuthMode::ApiKey, + description: "提供元宝开放平台的 AppID 和 AppSecret。", + fields: vec![ + FieldRequirement { + key: "app_key", + label: "AppID", + field_type: "string", + required: true, + placeholder: "元宝开放平台 AppID", + }, + FieldRequirement { + key: "app_secret", + label: "AppSecret", + field_type: "secret", + required: true, + placeholder: "元宝开放平台 AppSecret", + }, + ], + auth_action: None, + }], + capabilities: vec![ + ChannelCapability::SendText, + ChannelCapability::ReceiveText, + ChannelCapability::Typing, + ], + } +} + #[cfg(test)] #[path = "definitions_tests.rs"] mod tests; diff --git a/src/openhuman/channels/controllers/ops.rs b/src/openhuman/channels/controllers/ops.rs index cf5a3a3296..805f79b5c0 100644 --- a/src/openhuman/channels/controllers/ops.rs +++ b/src/openhuman/channels/controllers/ops.rs @@ -6,6 +6,7 @@ use serde_json::{json, Value}; use crate::api::config::{app_env_from_env, effective_backend_api_url, is_staging_app_env}; use crate::api::jwt::get_session_token; use crate::api::rest::BackendOAuthClient; +use crate::openhuman::channels::providers::yuanbao::sign::SignManager; use crate::openhuman::config::{Config, DiscordConfig, IMessageConfig, TelegramConfig}; use crate::openhuman::credentials; use crate::rpc::RpcOutcome; @@ -108,6 +109,46 @@ fn parse_optional_bool(value: Option<&Value>) -> Option { } } +/// Verify Yuanbao credentials against the `sign-token` endpoint before any +/// persistence so invalid `app_key` / `app_secret` surface the upstream API +/// error to the user instead of silently succeeding. +/// +/// Honours an explicit `api_domain` already configured in TOML; otherwise +/// derives it from `env` (prod by default). +async fn verify_yuanbao_credentials( + config: &Config, + creds_map: &serde_json::Map, +) -> Result<(), String> { + let app_key = creds_map + .get("app_key") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .ok_or_else(|| "missing required app_key".to_string())?; + let app_secret = creds_map + .get("app_secret") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .ok_or_else(|| "missing required app_secret".to_string())?; + + let mut yb_config = config.channels_config.yuanbao.clone().unwrap_or_default(); + if yb_config.api_domain.is_empty() { + yb_config.apply_env_defaults(); + } + + SignManager::new(reqwest::Client::new()) + .get_token( + app_key, + app_secret, + &yb_config.api_domain, + &yb_config.route_env, + ) + .await + .map_err(|e| format!("yuanbao credential verification failed: {e}"))?; + Ok(()) +} + /// List all available channel definitions. pub async fn list_channels() -> Result>, String> { Ok(RpcOutcome::new(all_channel_definitions(), vec![])) @@ -160,6 +201,13 @@ pub async fn connect_channel( def.validate_credentials(auth_mode, creds_map)?; + // Yuanbao: verify credentials with the sign-token endpoint before any + // persistence so invalid creds surface the upstream API error to the + // user without leaving dangling credential entries or TOML state. + if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + verify_yuanbao_credentials(config, creds_map).await?; + } + // iMessage is local-only (no credentials): persist channels_config + return connected. if channel_id == "imessage" && auth_mode == ChannelAuthMode::ManagedDm { let allowed_contacts = parse_allowed_users(creds_map.get("allowed_contacts")); @@ -332,6 +380,41 @@ pub async fn connect_channel( mention_only, "[discord] connect_channel: wrote channels_config.discord; restart core for listener to load token" ); + } else if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + let app_key = creds_map + .get("app_key") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .ok_or_else(|| "missing required app_key".to_string())? + .to_string(); + let app_secret = creds_map + .get("app_secret") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .ok_or_else(|| "missing required app_secret".to_string())? + .to_string(); + + let mut persisted = config.clone(); + let mut yb_config = persisted + .channels_config + .yuanbao + .clone() + .unwrap_or_default(); + yb_config.app_key = app_key; + yb_config.app_secret = app_secret; + persisted.channels_config.yuanbao = Some(yb_config); + + persisted + .save() + .await + .map_err(|e| format!("failed to persist yuanbao config.toml: {e}"))?; + + tracing::info!( + target: "openhuman::channels", + "[yuanbao] connect_channel: wrote channels_config.yuanbao; restart core for WS listener" + ); } Ok(RpcOutcome::single_log( @@ -402,6 +485,18 @@ pub async fn disconnect_channel( "[imessage] disconnect_channel: cleared channels_config.imessage" ); } + } else if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { + let mut persisted = config.clone(); + if persisted.channels_config.yuanbao.take().is_some() { + persisted + .save() + .await + .map_err(|e| format!("failed to clear yuanbao config.toml: {e}"))?; + tracing::info!( + target: "openhuman::channels", + "[yuanbao] disconnect_channel: cleared channels_config.yuanbao" + ); + } } Ok(RpcOutcome::single_log( @@ -507,6 +602,9 @@ pub async fn connected_channel_slugs(config: &Config) -> Result, Str if cc.imessage.is_some() { slugs.insert("imessage".to_string()); } + if cc.yuanbao.is_some() { + slugs.insert("yuanbao".to_string()); + } if cc.irc.is_some() { slugs.insert("irc".to_string()); } diff --git a/src/openhuman/channels/controllers/ops_tests.rs b/src/openhuman/channels/controllers/ops_tests.rs index aa19ab4d97..3d2d055c6c 100644 --- a/src/openhuman/channels/controllers/ops_tests.rs +++ b/src/openhuman/channels/controllers/ops_tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; use tempfile::tempdir; fn isolated_test_config() -> (tempfile::TempDir, Config) { @@ -482,3 +483,123 @@ async fn connected_channel_slugs_empty_when_nothing_configured() { "fresh config should yield no channels: {slugs:?}" ); } + +// ── Yuanbao channel credential verification ──────────────────── +// Issue: connect_channel for yuanbao previously stored creds and returned +// "connected" without ever calling the upstream sign-token endpoint, so +// random input (e.g. app_key=12) showed as Connected in the UI. The fix +// calls `/api/v5/robotLogic/sign-token` and propagates the API error. + +/// Build a Config pre-pointed at a mock `api_domain` so the verification +/// step hits the wiremock server instead of the live prod URL. +fn yuanbao_test_config(mock_uri: &str) -> (tempfile::TempDir, Config) { + let (tmp, mut config) = isolated_test_config(); + config.channels_config.yuanbao = Some(YuanbaoConfig { + api_domain: mock_uri.to_string(), + ..Default::default() + }); + (tmp, config) +} + +#[tokio::test] +async fn connect_yuanbao_rejects_invalid_credentials() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 40001, + "msg": "invalid signature", + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + let err = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ "app_key": "12", "app_secret": "12" }), + ) + .await + .expect_err("invalid yuanbao credentials should fail"); + + assert!( + err.contains("yuanbao credential verification failed") && err.contains("invalid signature"), + "expected upstream API msg in error, got: {err}" + ); + + // Nothing should be persisted on failure: no TOML write, no credential row. + let raw = tokio::fs::read_to_string(&config.config_path).await.ok(); + if let Some(text) = raw { + let parsed: toml::Value = toml::from_str(&text).expect("config parses"); + // The mock api_domain we pre-loaded is allowed to be present, but + // app_key / app_secret must NOT have been written. + if let Some(yb) = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + { + assert_ne!( + yb.get("app_key").and_then(toml::Value::as_str), + Some("12"), + "app_key must not be persisted when verification fails" + ); + } + } +} + +#[tokio::test] +async fn connect_yuanbao_persists_when_credentials_valid() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-abc", + "bot_id": "bot-123", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + let result = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ "app_key": "real-key", "app_secret": "real-secret" }), + ) + .await + .expect("valid yuanbao credentials should succeed"); + + assert_eq!(result.value.status, "connected"); + assert!(result.value.restart_required); + + let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!( + yb.get("app_key").and_then(toml::Value::as_str), + Some("real-key") + ); + assert_eq!( + yb.get("app_secret").and_then(toml::Value::as_str), + Some("real-secret") + ); +} diff --git a/src/openhuman/channels/mod.rs b/src/openhuman/channels/mod.rs index 2101542973..41f9679504 100644 --- a/src/openhuman/channels/mod.rs +++ b/src/openhuman/channels/mod.rs @@ -34,6 +34,7 @@ pub use providers::web; pub use providers::whatsapp; #[cfg(feature = "whatsapp-web")] pub use providers::whatsapp_web; +pub use providers::yuanbao; pub use cli::CliChannel; pub use dingtalk::DingTalkChannel; @@ -54,6 +55,7 @@ pub use traits::{Channel, SendMessage}; pub use whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] pub use whatsapp_web::WhatsAppWebChannel; +pub use yuanbao::YuanbaoChannel; pub use commands::doctor_channels; pub use controllers::{ChannelAuthMode, ChannelDefinition}; diff --git a/src/openhuman/channels/providers/mod.rs b/src/openhuman/channels/providers/mod.rs index 34bae4714a..d6844be06d 100644 --- a/src/openhuman/channels/providers/mod.rs +++ b/src/openhuman/channels/providers/mod.rs @@ -19,3 +19,4 @@ pub mod web; pub mod whatsapp; #[cfg(feature = "whatsapp-web")] pub mod whatsapp_web; +pub mod yuanbao; diff --git a/src/openhuman/channels/providers/yuanbao/channel.rs b/src/openhuman/channels/providers/yuanbao/channel.rs new file mode 100644 index 0000000000..40cf27a85f --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/channel.rs @@ -0,0 +1,531 @@ +//! Channel facade for the Yuanbao provider. +//! +//! This module owns the OpenHuman [`Channel`] implementation and keeps +//! provider wiring out of `mod.rs`. Protocol decoding, transport, inbound +//! filtering, and outbound sending remain delegated to sibling modules. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use tokio::sync::{mpsc, watch, Mutex as TokioMutex}; +use tokio::task::JoinHandle; +use tracing::{info, warn}; + +use crate::openhuman::channels::traits::{Channel, ChannelMessage, SendMessage}; + +use super::config::YuanbaoConfig; +use super::connection::{InboundEvent, YuanbaoConnection}; +use super::ids::{shorten_account_id, shorten_reply_target}; +use super::inbound::{InboundPipeline, PipelineOutcome, PipelineState}; +use super::outbound::OutboundSender; +use super::proto::decode_push_msg; +use super::sign::SignManager; +use super::{splitter, types}; + +/// Reply Heartbeat keepalive interval. The yuanbao gateway expects the +/// bot to ping (`SendPrivateHeartbeat RUNNING`) at this cadence so the +/// "正在输入" indicator stays alive for long-running responses. +const REPLY_HEARTBEAT_INTERVAL_SECS: u64 = 2; + +/// Hard ceiling on the in-memory shortened-recipient → original-recipient +/// map. Each entry is two short strings (~80 B), so 4096 distinct senders +/// give ~320 KB — plenty for any realistic chat load and small enough +/// that we can blow the whole map away when we hit the cap instead of +/// dragging in an LRU dependency. See `register_recipient_alias`. +const RECIPIENT_ALIAS_CAP: usize = 4096; + +/// The yuanbao channel — owns one WebSocket and one inbound pipeline. +pub struct YuanbaoChannel { + config: YuanbaoConfig, + connection: Arc, + outbound: Arc, + pipeline: Arc, + shutdown_tx: watch::Sender, + /// Holds the inbound receiver between `new()` and the first `listen()` call. + /// + /// `Channel::listen` takes `&self`, so we can't move the receiver out of + /// a field. Use a `Mutex>` so the first listener takes ownership + /// and subsequent calls fail cleanly. + inbound_rx: parking_lot::Mutex>>, + /// Per-recipient Reply Heartbeat keepalive tasks (started on `start_typing`). + heartbeat_tasks: TokioMutex>>, + /// Reverse lookup table from shortened recipient ids (the ones we + /// emit on `ChannelMessage.sender` / `reply_target`) back to the + /// original server-recognized ids that outbound `send_c2c_message` + /// / `send_group_message` must use as `to_account` / `group_code`. + /// + /// Why this exists: yuanbao uids are ~64-char hashes, and + /// `super::ids::shorten_account_id` rewrites them as + /// `_` so the conversation store's per-thread + /// JSONL filenames stay under filesystem `NAME_MAX`. Without this + /// table the agent loop sends replies addressed to the shortened + /// hash, which the yuanbao gateway silently drops because no such + /// user exists. See `register_recipient_alias` / `resolve_recipient`. + recipient_aliases: TokioMutex>, +} + +impl YuanbaoChannel { + /// Build a channel from a validated config. Returns an error if the + /// config is missing required fields (so misconfiguration surfaces + /// at startup, not on the first inbound message). + pub fn new(mut config: YuanbaoConfig) -> anyhow::Result { + config.apply_env_defaults(); + config.validate()?; + let (shutdown_tx, _shutdown_rx) = watch::channel(false); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::(); + + // SignManager is only useful when we have an app_secret — without + // it we'd never call the sign endpoint anyway. + let sign_manager: Option> = if !config.app_secret.is_empty() { + Some(SignManager::new(reqwest::Client::new())) + } else { + None + }; + + let connection = YuanbaoConnection::new(config.clone(), inbound_tx, sign_manager.clone()); + let outbound = Arc::new(OutboundSender::new( + Arc::clone(&connection), + sign_manager.clone(), + config.app_key.clone(), + config.bot_id.clone(), + )); + // PipelineState's `from_account` is used by the echo-guard stage to + // drop self-sent messages. We feed it the static config value here + // (which may be empty); the canonical server-issued bot_id only + // becomes known after sign-token, so this is a known minor gap — + // echo guard will simply not fire when bot_id isn't statically set. + let pipeline_state = PipelineState::new(&config, config.bot_id.clone()); + let pipeline = Arc::new(InboundPipeline::new(pipeline_state)); + + Ok(Self { + config, + connection, + outbound, + pipeline, + shutdown_tx, + inbound_rx: parking_lot::Mutex::new(Some(inbound_rx)), + heartbeat_tasks: TokioMutex::new(HashMap::new()), + recipient_aliases: TokioMutex::new(HashMap::new()), + }) + } + + /// Record a `shortened → original` recipient mapping so the outbound + /// side can recover the server-recognized id when the agent loop + /// addresses a reply with the shortened sender / reply_target it + /// received on `ChannelMessage`. + /// + /// No-op when the two are equal (uid is short enough to skip + /// shortening, or this is the `g:` group-target case where the + /// inner code is short). When the map crosses `RECIPIENT_ALIAS_CAP` + /// we clear it — the next inbound message from each active sender + /// re-populates the entry it needs, and stale entries from idle + /// conversations are fine to lose. + async fn register_recipient_alias(&self, shortened: &str, original: &str) { + if shortened == original { + return; + } + let mut m = self.recipient_aliases.lock().await; + if m.len() >= RECIPIENT_ALIAS_CAP { + warn!( + "[yuanbao] recipient alias map hit cap ({}), clearing", + RECIPIENT_ALIAS_CAP + ); + m.clear(); + } + m.insert(shortened.to_string(), original.to_string()); + } + + /// Look up the server-recognized recipient for a (possibly + /// shortened) inbound id. Falls back to the input unchanged when + /// nothing is registered — which keeps the previous behavior for + /// recipients that don't go through `shorten_account_id` (short + /// uids, group codes, `imessage`-style ids). + async fn resolve_recipient(&self, recipient: &str) -> String { + let m = self.recipient_aliases.lock().await; + m.get(recipient) + .cloned() + .unwrap_or_else(|| recipient.to_string()) + } + + fn split_message(&self, text: &str) -> Vec { + splitter::split_markdown(text, self.config.max_message_length) + } + + async fn start_heartbeat_task(&self, recipient: &str) { + let mut tasks = self.heartbeat_tasks.lock().await; + if tasks.contains_key(recipient) { + return; + } + let outbound = Arc::clone(&self.outbound); + let target = recipient.to_string(); + let handle = tokio::spawn(async move { + let mut interval = + tokio::time::interval(Duration::from_secs(REPLY_HEARTBEAT_INTERVAL_SECS)); + interval.tick().await; // skip first tick (start_typing already sent RUNNING) + loop { + interval.tick().await; + if let Err(e) = outbound.start_heartbeat(&target).await { + // Connection bouncing — bail out of this loop; the + // next start_typing call will spawn a new one. + warn!( + "[yuanbao] reply heartbeat send failed: {} — stopping loop", + e + ); + return; + } + } + }); + tasks.insert(recipient.to_string(), handle); + } + + async fn stop_heartbeat_task(&self, recipient: &str) { + let mut tasks = self.heartbeat_tasks.lock().await; + if let Some(handle) = tasks.remove(recipient) { + handle.abort(); + } + } +} + +#[async_trait] +impl Channel for YuanbaoChannel { + fn name(&self) -> &str { + "yuanbao" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let chunks = self.split_message(&message.content); + let ref_msg_id = message.thread_ts.as_deref(); + let recipient = self.resolve_recipient(&message.recipient).await; + for chunk in &chunks { + self.outbound + .send_text(&recipient, chunk, ref_msg_id) + .await?; + } + Ok(()) + } + + fn supports_draft_updates(&self) -> bool { + // Routes turns through the streaming code path even though Yuanbao + // itself has no edit-message capability. We accept the UX cost (no + // progressive rendering — the reply appears all at once in + // `finalize_draft`) in exchange for streaming's tolerance of + // malformed `usage` chunks; the non-streaming parser fails the + // whole turn when an upstream LLM returns string-typed token counts. + true + } + + async fn send_draft(&self, message: &SendMessage) -> anyhow::Result> { + // Marker id so dispatch spins up the progress consumer task; + // nothing is sent to the user here. Real content goes out in + // `finalize_draft`. See `supports_draft_updates` for rationale. + Ok(Some(format!("yb-draft:{}", message.recipient))) + } + + async fn update_draft( + &self, + _recipient: &str, + _message_id: &str, + _text: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn finalize_draft( + &self, + recipient: &str, + _message_id: &str, + text: &str, + thread_ts: Option<&str>, + ) -> anyhow::Result<()> { + let chunks = self.split_message(text); + let recipient = self.resolve_recipient(recipient).await; + for chunk in &chunks { + self.outbound + .send_text(&recipient, chunk, thread_ts) + .await?; + } + Ok(()) + } + + async fn listen(&self, tx: mpsc::Sender) -> anyhow::Result<()> { + // Take the inbound receiver. A second listener would just exit early. + let mut inbound_rx = match self.inbound_rx.lock().take() { + Some(rx) => rx, + None => { + warn!("[yuanbao] listen() called twice — second call exits"); + return Ok(()); + } + }; + + let conn = Arc::clone(&self.connection); + let shutdown_rx = self.shutdown_tx.subscribe(); + let conn_task = tokio::spawn(async move { + conn.run(shutdown_rx).await; + }); + + info!("[yuanbao] channel listening — pipeline ready"); + let mut shutdown_rx2 = self.shutdown_tx.subscribe(); + loop { + tokio::select! { + _ = shutdown_rx2.changed() => { + info!("[yuanbao] listen loop received shutdown"); + break; + } + event = inbound_rx.recv() => { + match event { + Some(InboundEvent::Push(frame)) => { + self.dispatch_push(frame, &tx).await; + } + Some(InboundEvent::Kickout(reason)) => { + warn!("[yuanbao] kickout: {} — stopping listen loop", reason); + break; + } + None => { + warn!("[yuanbao] inbound channel closed"); + break; + } + } + } + } + } + + let _ = self.shutdown_tx.send(true); + conn_task.abort(); + Ok(()) + } + + async fn health_check(&self) -> bool { + self.connection.is_connected() + } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + // Send RUNNING immediately, then spawn a 2s keepalive so the + // indicator doesn't expire while we generate. + let recipient = self.resolve_recipient(recipient).await; + self.outbound.start_heartbeat(&recipient).await?; + self.start_heartbeat_task(&recipient).await; + Ok(()) + } + + async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> { + let recipient = self.resolve_recipient(recipient).await; + self.stop_heartbeat_task(&recipient).await; + self.outbound.stop_heartbeat(&recipient).await?; + Ok(()) + } + + fn supports_reactions(&self) -> bool { + false + } +} + +impl YuanbaoChannel { + async fn dispatch_push(&self, frame: types::ConnFrame, tx: &mpsc::Sender) { + // The Yuanbao gateway pushes inbound messages with `cmd_type=Push`; + // the actual `cmd` word is decided server-side and varies (mirrors + // hermes-agent `yuanbao.py::_handle_received_frame` which routes + // purely on cmd_type). The connection layer has already filtered + // out non-push frames before we get here, so every frame we see + // should be a candidate for the inbound pipeline. + if frame.data.is_empty() { + tracing::trace!("[yuanbao] empty push body cmd={} — skipping", frame.cmd); + return; + } + // Some push frames wrap the biz body in an extra + // `PushMsg { cmd, module, msg_id, data }` envelope; others (e.g. + // cmd="inbound_message", module="yuanbao_openclaw_proxy") put the + // InboundMessagePush bytes directly in `ConnMsg.data` with the + // ConnMsg.head already carrying cmd/module. Mirrors plugin + // client.ts::onPush (l. 813): try PushMsg first, but only accept + // it when it has a non-empty cmd or module; otherwise treat the + // raw frame.data as the biz body. + let unwrapped: Option> = match decode_push_msg(&frame.data) { + Ok(p) if (!p.cmd.is_empty() || !p.module.is_empty()) && !p.data.is_empty() => { + info!( + "[yuanbao] push envelope decoded: cmd={} module={} msg_id={} biz_len={}", + p.cmd, + p.module, + p.msg_id, + p.data.len() + ); + Some(p.data) + } + _ => { + info!( + "[yuanbao] push has no PushMsg envelope — treating ConnMsg.data as biz body (conn_cmd={} module={} len={})", + frame.cmd, + frame.module, + frame.data.len() + ); + None + } + }; + let biz_body: &[u8] = unwrapped.as_deref().unwrap_or(&frame.data); + let outcome = self.pipeline.process(biz_body).await; + match outcome { + PipelineOutcome::Dispatch(ctx) => { + // Shorten ids at the channel boundary so the composite thread_id + // derived downstream (channel:yuanbao__) + // stays under filesystem NAME_MAX once hex-encoded for the + // per-thread JSONL filename. Yuanbao internals (echo guard, + // access control, owner-command check) keep the original + // `from_account` — see `super::ids` for the format and rationale. + let original_from = ctx.msg.from_account.clone(); + let original_reply_target = ctx.source.reply_target(); + let short_sender = shorten_account_id(&original_from); + let short_reply_target = shorten_reply_target(&original_reply_target); + // Remember the original ids so the outbound side can + // recover them when the agent loop addresses a reply + // with the shortened values it sees here. + self.register_recipient_alias(&short_sender, &original_from) + .await; + self.register_recipient_alias(&short_reply_target, &original_reply_target) + .await; + let msg = ChannelMessage { + id: ctx.msg.msg_id.clone(), + sender: short_sender, + reply_target: short_reply_target, + content: if ctx.text.is_empty() && !ctx.image_urls.is_empty() { + // Surface image URLs as content so downstream tools have something to work with. + ctx.image_urls.join("\n") + } else { + ctx.text.clone() + }, + channel: "yuanbao".into(), + timestamp: ctx.msg.msg_time as u64, + thread_ts: None, + }; + if tx.send(msg).await.is_err() { + warn!("[yuanbao] dispatch receiver gone — dropping message"); + } + } + PipelineOutcome::Filtered(reason) => { + tracing::trace!("[yuanbao] filtered at {reason}"); + } + PipelineOutcome::Failed(err) => { + let preview_len = biz_body.len().min(256); + let hex: String = biz_body[..preview_len] + .iter() + .map(|b| format!("{b:02x}")) + .collect(); + warn!( + "[yuanbao] pipeline error: {err} | biz_len={} biz_hex_first_{preview_len}={}", + biz_body.len(), + hex + ); + } + } + } +} + +impl Drop for YuanbaoChannel { + fn drop(&mut self) { + let _ = self.shutdown_tx.send(true); + } +} + +#[cfg(test)] +mod tests { + use crate::openhuman::channels::traits::Channel; + + use super::*; + + fn good_cfg() -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://example".into(); + c.token = "tok".into(); + c + } + + #[test] + fn channel_construction_validates() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert_eq!(ch.name(), "yuanbao"); + } + + #[test] + fn invalid_config_rejected() { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + // missing ws_domain + assert!(YuanbaoChannel::new(c).is_err()); + } + + #[test] + fn split_short_message_returns_one() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let chunks = ch.split_message("hello"); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_respects_newlines() { + let mut c = good_cfg(); + c.max_message_length = 12; + let ch = YuanbaoChannel::new(c).unwrap(); + let chunks = ch.split_message("line one\nline two\nline three"); + assert!(chunks.len() >= 2); + // No chunk exceeds the limit. + for chunk in &chunks { + assert!(chunk.len() <= 12, "chunk too long: {chunk:?}"); + } + } + + #[tokio::test] + async fn resolve_recipient_returns_input_when_no_alias_registered() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert_eq!(ch.resolve_recipient("short_uid").await, "short_uid"); + } + + #[tokio::test] + async fn register_and_resolve_dm_alias_recovers_original_uid() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let original = "x".repeat(64); + let shortened = shorten_account_id(&original); + assert_ne!(shortened, original, "test premise: should actually shorten"); + ch.register_recipient_alias(&shortened, &original).await; + assert_eq!(ch.resolve_recipient(&shortened).await, original); + } + + #[tokio::test] + async fn register_recipient_alias_is_noop_for_equal_pair() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + // Short uid that wouldn't be shortened — caller still hands us + // (s, s); we should silently skip and not eat a map slot. + ch.register_recipient_alias("short", "short").await; + let m = ch.recipient_aliases.lock().await; + assert!(m.is_empty()); + } + + #[tokio::test] + async fn resolve_recipient_preserves_group_prefix_via_alias() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let long_group_code = "g".repeat(64); + let original = format!("g:{long_group_code}"); + let shortened = shorten_reply_target(&original); + assert_ne!(shortened, original); + assert!(shortened.starts_with("g:")); + ch.register_recipient_alias(&shortened, &original).await; + assert_eq!(ch.resolve_recipient(&shortened).await, original); + } + + #[tokio::test] + async fn alias_map_clears_when_cap_is_hit() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + // Pre-fill up to the cap with distinct entries. + for i in 0..RECIPIENT_ALIAS_CAP { + ch.register_recipient_alias(&format!("s{i}"), &format!("o{i}")) + .await; + } + assert_eq!(ch.recipient_aliases.lock().await.len(), RECIPIENT_ALIAS_CAP); + // One more entry must trigger a clear, then insert the new entry. + ch.register_recipient_alias("new_short", "new_original") + .await; + let m = ch.recipient_aliases.lock().await; + assert_eq!(m.len(), 1); + assert_eq!(m.get("new_short").map(String::as_str), Some("new_original")); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/config.rs b/src/openhuman/channels/providers/yuanbao/config.rs new file mode 100644 index 0000000000..3f79888d67 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/config.rs @@ -0,0 +1,250 @@ +//! Yuanbao channel configuration. +//! +//! Loaded from `ChannelsConfig.yuanbao` (TOML) and validated before the +//! channel is started. Mirrors the Python `YuanbaoAdapter` configuration +//! surface (hermes-agent `gateway/platforms/yuanbao.py`). + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use super::errors::YuanbaoError; + +/// Production environment endpoints (default). +const PROD_API_DOMAIN: &str = "https://bot.yuanbao.tencent.com"; +const PROD_WS_URL: &str = "wss://bot-wss.yuanbao.tencent.com/wss/connection"; +/// Pre-release environment endpoints. Opt in via `env = "pre"` in TOML. +const PRE_API_DOMAIN: &str = "https://bot-pre.yuanbao.tencent.com"; +const PRE_WS_URL: &str = "wss://bot-wss-pre.yuanbao.tencent.com/wss/connection"; + +/// User-facing config for the Yuanbao channel. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct YuanbaoConfig { + /// Application key (`X-ID` header / AuthBind biz_id). + pub app_key: String, + /// Application secret — used by the token-sign endpoint. + pub app_secret: String, + /// Bot account ID (uid for AuthBind). Optional — when empty, derived + /// from the AuthBindRsp payload after the first handshake. + #[serde(default)] + pub bot_id: String, + /// Environment selector for endpoint defaults: `"prod"` (default) or `"pre"`. + /// Only consulted when `api_domain` / `ws_domain` are empty. + #[serde(default = "default_env")] + pub env: String, + /// API base URL. Empty by default — derived from `env` at channel start. + /// Set explicitly in TOML to point at a custom deployment. + #[serde(default)] + pub api_domain: String, + /// WebSocket base URL. Empty by default — derived from `env` at channel + /// start. Set explicitly in TOML to point at a custom deployment. + #[serde(default)] + pub ws_domain: String, + /// Optional `route_env` header (canary routing). + #[serde(default)] + pub route_env: String, + /// Optional pre-provisioned token. When empty, the channel calls + /// `api_domain/api/token/sign` with `(app_key, app_secret)` to fetch one. + #[serde(default)] + pub token: String, + /// Plugin/bot version reported in `AuthBindReq.DeviceInfo.bot_version`. + #[serde(default = "default_bot_version")] + pub bot_version: String, + /// Optional bot display name — used by the `@bot` mention guard. + #[serde(default)] + pub bot_name: String, + + /// DM access policy: `open` / `allowlist` / `closed`. + #[serde(default = "default_dm_policy")] + pub dm_access: String, + /// Group access policy: `open` / `allowlist` / `closed`. + #[serde(default = "default_group_policy")] + pub group_access: String, + /// When `dm_access = "allowlist"`, only these UIDs may DM the bot. + #[serde(default)] + pub allowed_users: Vec, + /// When `group_access = "allowlist"`, only these group codes are allowed. + #[serde(default)] + pub allowed_groups: Vec, + /// Owner UID — receives elevated `/admin` commands. + #[serde(default)] + pub owner_id: String, + + /// Group messages must `@bot` to be processed (recommended). + #[serde(default = "default_true")] + pub group_at_required: bool, + + /// Maximum WS heartbeat interval override (seconds). 0 = use server-driven default. + #[serde(default)] + pub heartbeat_interval_secs: u64, + /// Reconnect retry budget — 0 means use the default cap (100). + #[serde(default)] + pub max_reconnect_attempts: u32, + + /// Per-message body length cap before splitting (UTF-8 bytes). + #[serde(default = "default_max_msg_len")] + pub max_message_length: usize, + /// Maximum inbound media file size in MiB. + #[serde(default = "default_max_media_mb")] + pub max_media_mb: u32, +} + +impl Default for YuanbaoConfig { + fn default() -> Self { + Self { + app_key: String::new(), + app_secret: String::new(), + bot_id: String::new(), + env: default_env(), + api_domain: String::new(), + ws_domain: String::new(), + route_env: String::new(), + token: String::new(), + bot_version: default_bot_version(), + bot_name: String::new(), + dm_access: default_dm_policy(), + group_access: default_group_policy(), + allowed_users: Vec::new(), + allowed_groups: Vec::new(), + owner_id: String::new(), + group_at_required: true, + heartbeat_interval_secs: 0, + max_reconnect_attempts: 0, + max_message_length: default_max_msg_len(), + max_media_mb: default_max_media_mb(), + } + } +} + +impl YuanbaoConfig { + /// Fill empty `api_domain` / `ws_domain` from the configured `env`. The + /// UI only collects `app_key` + `app_secret`; endpoints are derived + /// here so the renderer never needs to know about them. TOML values + /// take precedence (when non-empty), so existing deployments and + /// custom routes keep working. + pub fn apply_env_defaults(&mut self) { + let env = self.env.as_str(); + if self.api_domain.is_empty() { + self.api_domain = match env { + "pre" => PRE_API_DOMAIN.into(), + _ => PROD_API_DOMAIN.into(), + }; + } + if self.ws_domain.is_empty() { + self.ws_domain = match env { + "pre" => PRE_WS_URL.into(), + _ => PROD_WS_URL.into(), + }; + } + } + + /// Validate required fields. Called at channel construction time so + /// misconfiguration surfaces early in `start_channels`, not after a + /// failed WebSocket handshake. + pub fn validate(&self) -> Result<(), YuanbaoError> { + if self.app_key.is_empty() { + return Err(YuanbaoError::Config("`app_key` is required".into())); + } + if self.ws_domain.is_empty() { + return Err(YuanbaoError::Config("`ws_domain` is required".into())); + } + if self.token.is_empty() && self.app_secret.is_empty() { + return Err(YuanbaoError::Config( + "either `token` or `app_secret` must be set".into(), + )); + } + if self.api_domain.is_empty() && self.token.is_empty() { + return Err(YuanbaoError::Config( + "`api_domain` is required when `token` is not pre-provisioned".into(), + )); + } + Ok(()) + } +} + +fn default_bot_version() -> String { + "openhuman/0.1.0".into() +} + +fn default_env() -> String { + "prod".into() +} + +fn default_dm_policy() -> String { + "open".into() +} + +fn default_group_policy() -> String { + "allowlist".into() +} + +fn default_true() -> bool { + true +} + +fn default_max_msg_len() -> usize { + 4500 +} + +fn default_max_media_mb() -> u32 { + 50 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config_is_invalid() { + let cfg = YuanbaoConfig::default(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn validate_requires_app_key() { + let mut cfg = YuanbaoConfig::default(); + cfg.ws_domain = "wss://example".into(); + cfg.token = "tok".into(); + assert!(cfg.validate().is_err()); + cfg.app_key = "ak".into(); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn validate_requires_token_or_secret() { + let mut cfg = YuanbaoConfig::default(); + cfg.app_key = "ak".into(); + cfg.ws_domain = "wss://example".into(); + cfg.api_domain = "https://api".into(); + assert!(cfg.validate().is_err()); + cfg.app_secret = "secret".into(); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn apply_env_defaults_fills_prod_when_empty() { + let mut cfg = YuanbaoConfig::default(); + assert_eq!(cfg.env, "prod"); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, PROD_API_DOMAIN); + assert_eq!(cfg.ws_domain, PROD_WS_URL); + } + + #[test] + fn apply_env_defaults_respects_pre_env() { + let mut cfg = YuanbaoConfig::default(); + cfg.env = "pre".into(); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, PRE_API_DOMAIN); + assert_eq!(cfg.ws_domain, PRE_WS_URL); + } + + #[test] + fn apply_env_defaults_preserves_explicit_overrides() { + let mut cfg = YuanbaoConfig::default(); + cfg.api_domain = "https://custom.example".into(); + cfg.ws_domain = "wss://custom.example".into(); + cfg.apply_env_defaults(); + assert_eq!(cfg.api_domain, "https://custom.example"); + assert_eq!(cfg.ws_domain, "wss://custom.example"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/connection.rs b/src/openhuman/channels/providers/yuanbao/connection.rs new file mode 100644 index 0000000000..796c802405 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/connection.rs @@ -0,0 +1,571 @@ +//! Yuanbao WebSocket connection manager. +//! +//! Owns one WebSocket to the gateway and runs: +//! 1. token sign-fetch (via [`SignManager`]) → `auth-bind` handshake +//! 2. periodic `ping` heartbeats +//! 3. inbound frame fan-out (decoded `ConnFrame` → mpsc) +//! 4. outbound request/response correlation via per-`msg_id` oneshot +//! 5. exponential-backoff reconnect with a no-retry close-code allowlist +//! +//! All public APIs are `&self` so the connection can be wrapped in +//! `Arc<…>` and shared between the listen loop and outbound senders. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use futures_util::{SinkExt, StreamExt}; +use parking_lot::Mutex as ParkingMutex; +use tokio::net::TcpStream; +use tokio::sync::{mpsc, oneshot, watch, Mutex}; +use tokio::time; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, +}; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use super::config::YuanbaoConfig; +use super::errors::{YuanbaoError, NO_RECONNECT_CLOSE_CODES}; +use super::proto::{ + decode_auth_bind_rsp, decode_conn_msg, encode_auth_bind, encode_ping, encode_push_ack, +}; +use super::proto_constants::*; +use super::sign::SignManager; +use super::types::{Account, ConnFrame, ConnectionState}; + +type WsSender = + futures_util::stream::SplitSink>, Message>; + +/// One inbound event delivered to the listen loop. +pub enum InboundEvent { + /// A regular biz push. + Push(ConnFrame), + /// Server told us we were kicked off. + Kickout(String), +} + +/// In-flight outbound request awaiting a matching `Response` frame. +type PendingMap = HashMap>; + +/// Long-lived connection manager. +pub struct YuanbaoConnection { + config: YuanbaoConfig, + state: ParkingMutex, + is_connected: AtomicBool, + msg_id_seq: AtomicU64, + sender: Mutex>, + inbound_tx: mpsc::UnboundedSender, + account: ParkingMutex, + sign_manager: Option>, + pending: ParkingMutex, +} + +impl YuanbaoConnection { + pub fn new( + config: YuanbaoConfig, + inbound_tx: mpsc::UnboundedSender, + sign_manager: Option>, + ) -> Arc { + let initial_account = Account { + uid: config.bot_id.clone(), + ..Default::default() + }; + Arc::new(Self { + config, + state: ParkingMutex::new(ConnectionState::Disconnected), + is_connected: AtomicBool::new(false), + msg_id_seq: AtomicU64::new(1), + sender: Mutex::new(None), + inbound_tx, + account: ParkingMutex::new(initial_account), + sign_manager, + pending: ParkingMutex::new(HashMap::new()), + }) + } + + pub fn is_connected(&self) -> bool { + self.is_connected.load(Ordering::Relaxed) + } + + pub fn state(&self) -> ConnectionState { + *self.state.lock() + } + + fn set_state(&self, new: ConnectionState) { + *self.state.lock() = new; + self.is_connected + .store(matches!(new, ConnectionState::Connected), Ordering::Relaxed); + } + + /// Current account info (best-effort — empty fields until auth-bind succeeds). + pub fn account(&self) -> Account { + self.account.lock().clone() + } + + fn update_account(&self, f: impl FnOnce(&mut Account)) { + let mut g = self.account.lock(); + f(&mut g); + } + + /// Per-process monotonic application msg_id. + pub fn next_msg_id(&self, prefix: &str) -> String { + let n = self.msg_id_seq.fetch_add(1, Ordering::Relaxed); + format!("{prefix}_{n}") + } + + /// Send a raw binary frame. Returns `NotConnected` if the connection + /// isn't currently up. + pub async fn send_frame(&self, data: Vec) -> Result<(), YuanbaoError> { + let mut guard = self.sender.lock().await; + match guard.as_mut() { + Some(s) => s + .send(Message::Binary(data)) + .await + .map_err(|e| YuanbaoError::WebSocket(e.to_string())), + None => Err(YuanbaoError::NotConnected), + } + } + + /// Send an already-encoded `ConnMsg` (alias of `send_frame`). + pub async fn send_conn_msg(&self, frame_bytes: Vec) -> Result<(), YuanbaoError> { + self.send_frame(frame_bytes).await + } + + /// Send a request and wait for the matching `Response` (correlated by + /// `msg_id`). Times out after `timeout` and removes the pending entry. + pub async fn send_and_wait( + &self, + msg_id: &str, + frame_bytes: Vec, + timeout: Duration, + ) -> Result { + let (tx, rx) = oneshot::channel(); + { + let mut p = self.pending.lock(); + p.insert(msg_id.to_string(), tx); + } + if let Err(e) = self.send_frame(frame_bytes).await { + self.pending.lock().remove(msg_id); + return Err(e); + } + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(frame)) => Ok(frame), + Ok(Err(_)) => { + self.pending.lock().remove(msg_id); + Err(YuanbaoError::SendFailed(format!( + "correlator channel closed for msg_id={msg_id}" + ))) + } + Err(_) => { + self.pending.lock().remove(msg_id); + Err(YuanbaoError::Timeout(format!("msg_id={msg_id}"))) + } + } + } + + /// Trigger a graceful shutdown. + pub async fn shutdown(&self) { + let mut guard = self.sender.lock().await; + if let Some(mut s) = guard.take() { + let _ = s.send(Message::Close(None)).await; + let _ = s.close().await; + } + // Drop all pending waiters so callers stop hanging. + let mut pending = self.pending.lock(); + pending.clear(); + self.set_state(ConnectionState::Disconnected); + } + + /// Main reconnection loop. Returns when `shutdown` flips to `true`. + pub async fn run(self: Arc, mut shutdown: watch::Receiver) { + let max_attempts = if self.config.max_reconnect_attempts > 0 { + self.config.max_reconnect_attempts + } else { + MAX_RECONNECT_ATTEMPTS + }; + let mut attempt: u32 = 0; + + loop { + if *shutdown.borrow() { + info!("[yuanbao] shutdown signaled, stopping connection loop"); + self.shutdown().await; + return; + } + if attempt >= max_attempts { + error!("[yuanbao] giving up after {} reconnect attempts", attempt); + return; + } + + self.set_state(if attempt == 0 { + ConnectionState::Connecting + } else { + ConnectionState::Reconnecting + }); + + let outcome = self.connect_once(&mut shutdown).await; + match outcome { + Ok(Some(code)) if NO_RECONNECT_CLOSE_CODES.contains(&code) => { + error!("[yuanbao] no-reconnect close code {} — stopping", code); + return; + } + Ok(close_code) => info!("[yuanbao] connection closed (code={:?})", close_code), + Err(e) => warn!("[yuanbao] connection error: {}", e), + } + + self.set_state(ConnectionState::Disconnected); + *self.sender.lock().await = None; + self.pending.lock().clear(); + + attempt += 1; + let delay = backoff_seconds(attempt); + info!( + "[yuanbao] reconnecting in {}s (attempt {}/{})", + delay, attempt, max_attempts + ); + tokio::select! { + _ = time::sleep(Duration::from_secs(delay)) => {} + _ = shutdown.changed() => { + info!("[yuanbao] shutdown received during backoff"); + self.shutdown().await; + return; + } + } + } + } + + async fn connect_once( + &self, + shutdown: &mut watch::Receiver, + ) -> Result, YuanbaoError> { + // Resolve token (may hit the sign endpoint). + let (token, bot_id, source) = self.resolve_token().await?; + if !bot_id.is_empty() { + self.update_account(|a| { + if a.uid.is_empty() { + a.uid = bot_id.clone(); + } + }); + } + + let url = &self.config.ws_domain; + info!("[yuanbao] connecting to {}", url); + let (ws_stream, _resp) = connect_async(url) + .await + .map_err(|e| YuanbaoError::WebSocket(e.to_string()))?; + + let (sender, mut receiver) = ws_stream.split(); + *self.sender.lock().await = Some(sender); + info!("[yuanbao] WebSocket connected — sending auth-bind"); + + self.set_state(ConnectionState::Authenticating); + self.send_auth_bind(&token, &bot_id, &source).await?; + + // Wait for auth-bind response. + let auth_timeout = Duration::from_secs(AUTH_TIMEOUT_SECS); + let auth_msg = tokio::time::timeout(auth_timeout, receiver.next()) + .await + .map_err(|_| YuanbaoError::AuthTimeout)? + .ok_or_else(|| YuanbaoError::WebSocket("closed during auth-bind".into()))? + .map_err(|e| YuanbaoError::WebSocket(e.to_string()))?; + + self.handle_auth_response(&auth_msg)?; + self.set_state(ConnectionState::Connected); + info!("[yuanbao] auth-bind successful, entering read loop"); + + let ping_secs = if self.config.heartbeat_interval_secs > 0 { + self.config.heartbeat_interval_secs + } else { + PING_INTERVAL_SECS + }; + let mut ping_interval = time::interval(Duration::from_secs(ping_secs)); + ping_interval.tick().await; // skip first tick + + let mut close_code: Option = None; + let mut consecutive_ping_failures: u32 = 0; + + loop { + tokio::select! { + _ = shutdown.changed() => { + info!("[yuanbao] shutdown received in read loop"); + return Ok(None); + } + _ = ping_interval.tick() => { + let msg_id = self.next_msg_id("ping"); + let frame = encode_ping(&msg_id); + if let Err(e) = self.send_frame(frame).await { + warn!("[yuanbao] ping send failed: {}", e); + consecutive_ping_failures += 1; + if consecutive_ping_failures >= HEARTBEAT_TIMEOUT_THRESHOLD { + warn!( + "[yuanbao] {} consecutive ping failures — dropping", + consecutive_ping_failures + ); + break; + } + } else { + consecutive_ping_failures = 0; + } + } + msg = receiver.next() => { + match msg { + Some(Ok(Message::Binary(data))) => self.handle_binary(data).await, + Some(Ok(Message::Close(frame))) => { + close_code = frame.map(|f| u16::from(f.code)); + info!("[yuanbao] received close frame: {:?}", close_code); + break; + } + Some(Ok(Message::Ping(payload))) => { + let mut guard = self.sender.lock().await; + if let Some(s) = guard.as_mut() { + let _ = s.send(Message::Pong(payload)).await; + } + } + Some(Ok(_)) => {} + Some(Err(e)) => { + warn!("[yuanbao] websocket read error: {}", e); + break; + } + None => { + info!("[yuanbao] websocket stream ended"); + break; + } + } + } + } + } + + Ok(close_code) + } + + async fn resolve_token(&self) -> Result<(String, String, String), YuanbaoError> { + let cfg = &self.config; + if !cfg.token.is_empty() { + // Pre-signed token: no source returned by the sign endpoint. + // Mirrors yuanbao-openclaw-plugin's static-token branch, which + // returns source="bot". + return Ok((cfg.token.clone(), cfg.bot_id.clone(), String::new())); + } + let mgr = self + .sign_manager + .as_ref() + .ok_or_else(|| YuanbaoError::AuthFailed("no token and no SignManager".into()))?; + if cfg.app_secret.is_empty() { + return Err(YuanbaoError::AuthFailed( + "app_secret required to sign".into(), + )); + } + let entry = mgr + .get_token( + &cfg.app_key, + &cfg.app_secret, + &cfg.api_domain, + &cfg.route_env, + ) + .await?; + Ok((entry.token, entry.bot_id, entry.source)) + } + + async fn send_auth_bind( + &self, + token: &str, + bot_id: &str, + source: &str, + ) -> Result<(), YuanbaoError> { + let cfg = &self.config; + let uid = if bot_id.is_empty() { + self.account.lock().uid.clone() + } else { + bot_id.to_string() + }; + let msg_id = format!("auth_{}", Uuid::new_v4()); + // Auth-bind payload aligned with yuanbao-openclaw-plugin: + // biz_id = "ybBot" (server rejects raw app_key with 40011). + // source comes from the sign endpoint response; fall back to + // "bot" when missing (matches the plugin's static-token branch + // and `data.source || "bot"` resolution). + let resolved_source = if source.is_empty() { "bot" } else { source }; + let frame = encode_auth_bind( + "ybBot", + &uid, + resolved_source, + token, + &msg_id, + env!("CARGO_PKG_VERSION"), + std::env::consts::OS, + &cfg.bot_version, + &cfg.route_env, + ); + self.send_frame(frame).await + } + + fn handle_auth_response(&self, msg: &Message) -> Result<(), YuanbaoError> { + let data = match msg { + Message::Binary(b) => b, + _ => { + return Err(YuanbaoError::AuthFailed( + "expected binary auth-bind response".into(), + )) + } + }; + let frame = decode_conn_msg(data)?; + if frame.cmd != cmd::AUTH_BIND { + return Err(YuanbaoError::AuthFailed(format!( + "unexpected cmd in auth response: {:?}", + frame.cmd + ))); + } + if frame.status != 0 { + return Err(YuanbaoError::AuthFailed(format!( + "auth rejected: status={}", + frame.status + ))); + } + // Body carries code/message/connect_id — back-fill the account. + if !frame.data.is_empty() { + let rsp = decode_auth_bind_rsp(&frame.data)?; + if rsp.code != 0 { + return Err(YuanbaoError::AuthFailed(format!( + "auth-bind code={} message={}", + rsp.code, rsp.message + ))); + } + if !rsp.connect_id.is_empty() { + self.update_account(|a| a.connect_id = rsp.connect_id.clone()); + info!("[yuanbao] auth-bind connect_id={}", rsp.connect_id); + } + } + Ok(()) + } + + async fn handle_binary(&self, data: Vec) { + let frame = match decode_conn_msg(&data) { + Ok(f) => f, + Err(e) => { + warn!("[yuanbao] failed to decode binary frame: {}", e); + return; + } + }; + + info!( + "[yuanbao] rx cmd={} module={} cmd_type={} seq={} msg_id={} data_len={}", + frame.cmd, + frame.module, + frame.cmd_type, + frame.seq_no, + frame.msg_id, + frame.data.len() + ); + + // Responses → match against pending requests via msg_id. + if frame.cmd_type == cmd_type::RESPONSE { + if !frame.msg_id.is_empty() { + if let Some(tx) = self.pending.lock().remove(&frame.msg_id) { + let _ = tx.send(frame); + return; + } + } + info!( + "[yuanbao] response with no waiter cmd={} msg_id={}", + frame.cmd, frame.msg_id + ); + return; + } + + // For server-driven pushes, ACK first when the head asks for it. + if frame.cmd_type == cmd_type::PUSH && frame.need_ack { + let ack = encode_push_ack(&frame); + if let Err(e) = self.send_frame(ack).await { + warn!("[yuanbao] failed to send PushAck: {}", e); + } + } + + // Handle conn-level builtin pushes inline. + if frame.cmd == cmd::KICKOUT { + let reason = String::from_utf8_lossy(&frame.data).into_owned(); + warn!("[yuanbao] kickout received: {}", reason); + let _ = self.inbound_tx.send(InboundEvent::Kickout(reason)); + return; + } + if frame.cmd == cmd::UPDATE_META { + return; + } + + if frame.cmd_type != cmd_type::PUSH { + info!( + "[yuanbao] dropping non-push frame cmd_type={} cmd={}", + frame.cmd_type, frame.cmd + ); + return; + } + + info!( + "[yuanbao] push forwarded to listener cmd={} module={} seq={}", + frame.cmd, frame.module, frame.seq_no + ); + if self.inbound_tx.send(InboundEvent::Push(frame)).is_err() { + error!("[yuanbao] inbound channel closed — listener gone"); + } + } +} + +/// Backoff schedule used by `run()`. After the configured table is +/// exhausted we cap at the last entry forever (until the attempt budget +/// trips). Indexing is 1-based so attempt=1 → table[0]. +fn backoff_seconds(attempt: u32) -> u64 { + let idx = attempt.saturating_sub(1) as usize; + if idx < RECONNECT_DELAYS.len() { + RECONNECT_DELAYS[idx] + } else { + *RECONNECT_DELAYS.last().unwrap_or(&60) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn backoff_follows_schedule() { + assert_eq!(backoff_seconds(1), 1); + assert_eq!(backoff_seconds(2), 2); + assert_eq!(backoff_seconds(3), 5); + assert_eq!(backoff_seconds(6), 60); + assert_eq!(backoff_seconds(100), 60); + assert_eq!(backoff_seconds(0), 1); + } + + fn cfg() -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://example".into(); + c.token = "tok".into(); + c.bot_id = "bot1".into(); + c + } + + #[tokio::test] + async fn pending_correlator_times_out() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn + .send_and_wait("missing_id", vec![1, 2, 3], Duration::from_millis(20)) + .await + .unwrap_err(); + // Without a connected socket, send_frame fails first → SendFailed/NotConnected. + assert!(matches!( + err, + YuanbaoError::NotConnected | YuanbaoError::Timeout(_) | YuanbaoError::SendFailed(_) + )); + } + + #[tokio::test] + async fn account_back_fill_picks_up_uid() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + assert_eq!(conn.account().uid, "bot1"); + conn.update_account(|a| a.connect_id = "cid_xyz".into()); + assert_eq!(conn.account().connect_id, "cid_xyz"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/cos.rs b/src/openhuman/channels/providers/yuanbao/cos.rs new file mode 100644 index 0000000000..0f2d3b3e32 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/cos.rs @@ -0,0 +1,392 @@ +//! Tencent COS upload — HMAC-SHA1 signing and `genUploadInfo` flow. +//! +//! Split out of `media.rs` to stay under the 500-line per-file ceiling. +//! Reference: . + +use std::time::{SystemTime, UNIX_EPOCH}; + +use hmac::{Hmac, Mac}; +use sha1::{Digest, Sha1}; +use tracing::{debug, info}; + +use super::errors::YuanbaoError; +use super::media::{guess_mime_type, is_image, parse_image_size, ImageDims}; + +const UPLOAD_INFO_PATH: &str = "/api/resource/genUploadInfo"; +const COS_USE_ACCELERATE: bool = true; + +type HmacSha1 = Hmac; + +fn hmac_sha1_hex(key: &[u8], msg: &[u8]) -> String { + let mut mac = HmacSha1::new_from_slice(key).expect("HMAC accepts any key length"); + mac.update(msg); + hex::encode(mac.finalize().into_bytes()) +} + +fn sha1_hex(msg: &[u8]) -> String { + let mut hasher = Sha1::new(); + hasher.update(msg); + hex::encode(hasher.finalize()) +} + +#[derive(Debug, Clone)] +pub struct CosSignInput<'a> { + pub method: &'a str, + /// URL-encoded path with leading `/`. + pub path: &'a str, + pub params: &'a [(&'a str, &'a str)], + pub headers: &'a [(&'a str, &'a str)], + pub secret_id: &'a str, + pub secret_key: &'a str, + pub start_time: u64, + pub expire_seconds: u64, +} + +/// Build the COS `Authorization` header value. +pub fn cos_sign(input: &CosSignInput<'_>) -> String { + let q_sign_time = format!( + "{};{}", + input.start_time, + input.start_time + input.expire_seconds + ); + + // Step 1 — SignKey = HMAC-SHA1(SecretKey, q-sign-time). + let sign_key = hmac_sha1_hex(input.secret_key.as_bytes(), q_sign_time.as_bytes()); + + // Step 2 — HttpString. Names lower-cased, values URL-encoded. + let mut params: Vec<(String, String)> = input + .params + .iter() + .map(|(k, v)| (k.to_ascii_lowercase(), url_encode(v))) + .collect(); + params.sort(); + let mut headers: Vec<(String, String)> = input + .headers + .iter() + .map(|(k, v)| (k.to_ascii_lowercase(), url_encode(v))) + .collect(); + headers.sort(); + + let url_param_list = params + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + let url_params = params + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&"); + let header_list = headers + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + let header_str = headers + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join("&"); + + let http_string = format!( + "{}\n{}\n{}\n{}\n", + input.method.to_ascii_lowercase(), + input.path, + url_params, + header_str + ); + + // Step 3 — StringToSign. + let sha1_of_http = sha1_hex(http_string.as_bytes()); + let string_to_sign = format!("sha1\n{q_sign_time}\n{sha1_of_http}\n"); + + // Step 4 — Signature. + let signature = hmac_sha1_hex(sign_key.as_bytes(), string_to_sign.as_bytes()); + + format!( + "q-sign-algorithm=sha1&q-ak={sid}&q-sign-time={t}&q-key-time={t}\ + &q-header-list={hl}&q-url-param-list={pl}&q-signature={sig}", + sid = input.secret_id, + t = q_sign_time, + hl = header_list, + pl = url_param_list, + sig = signature + ) +} + +fn url_encode(s: &str) -> String { + urlencoding::encode(s).into_owned() +} + +fn encode_cos_key(key: &str) -> String { + key.split('/') + .map(|seg| urlencoding::encode(seg).into_owned()) + .collect::>() + .join("/") +} + +#[derive(Debug, Clone, Default)] +pub struct CosCredentials { + pub bucket: String, + pub region: String, + pub location: String, + pub secret_id: String, + pub secret_key: String, + pub session_token: String, + pub start_time: u64, + pub expired_time: u64, + pub resource_url: String, +} + +#[derive(Debug, Clone)] +pub struct UploadResult { + pub url: String, + pub uuid: String, + pub size: u64, + pub width: u32, + pub height: u32, +} + +/// Fetch COS upload credentials from the yuanbao gateway. +pub async fn get_cos_credentials( + http: &reqwest::Client, + api_domain: &str, + app_key: &str, + bot_id: &str, + token: &str, + route_env: &str, + filename: &str, +) -> Result { + let upload_url = format!( + "{}/{}", + api_domain.trim_end_matches('/'), + UPLOAD_INFO_PATH.trim_start_matches('/') + ); + let body = serde_json::json!({ + "fileName": filename, + "fileId": uuid::Uuid::new_v4().simple().to_string(), + "docFrom": "localDoc", + "docOpenId": "", + }); + let mut req = http + .post(&upload_url) + .header("Content-Type", "application/json") + .header("X-Token", token) + .header("X-ID", if bot_id.is_empty() { app_key } else { bot_id }) + .header("X-Source", "web"); + if !route_env.is_empty() { + req = req.header("X-Route-Env", route_env); + } + let resp = req + .json(&body) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("genUploadInfo: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "genUploadInfo HTTP {}", + resp.status() + ))); + } + let payload: serde_json::Value = resp + .json() + .await + .map_err(|e| YuanbaoError::Media(format!("genUploadInfo body parse: {e}")))?; + if let Some(code) = payload.get("code").and_then(|c| c.as_i64()) { + if code != 0 { + return Err(YuanbaoError::Media(format!( + "genUploadInfo code={code}, msg={}", + payload.get("msg").and_then(|m| m.as_str()).unwrap_or("") + ))); + } + } + let data = payload.get("data").unwrap_or(&payload); + let get_str = |k: &str| -> String { + data.get(k) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string() + }; + let get_u64 = |k: &str| -> u64 { data.get(k).and_then(|v| v.as_u64()).unwrap_or(0) }; + + Ok(CosCredentials { + bucket: get_str("bucketName"), + region: get_str("region"), + location: get_str("location"), + secret_id: get_str("encryptTmpSecretId"), + secret_key: get_str("encryptTmpSecretKey"), + session_token: get_str("encryptToken"), + start_time: get_u64("startTime"), + expired_time: get_u64("expiredTime"), + resource_url: get_str("resourceUrl"), + }) +} + +/// PUT a file to COS using credentials returned by `get_cos_credentials`. +pub async fn upload_to_cos( + http: &reqwest::Client, + creds: &CosCredentials, + data: &[u8], + filename: &str, + mut content_type: String, +) -> Result { + if creds.secret_id.is_empty() || creds.secret_key.is_empty() || creds.location.is_empty() { + return Err(YuanbaoError::Media( + "COS credentials missing secret_id / secret_key / location".into(), + )); + } + if content_type.is_empty() || content_type == "application/octet-stream" { + content_type = if is_image(filename, "") { + guess_mime_type(filename).to_string() + } else { + "application/octet-stream".into() + }; + } + + let cos_host = if COS_USE_ACCELERATE { + format!("{}.cos.accelerate.myqcloud.com", creds.bucket) + } else { + format!("{}.cos.{}.myqcloud.com", creds.bucket, creds.region) + }; + let encoded_key = encode_cos_key(&creds.location); + let cos_url = format!("https://{cos_host}/{}", encoded_key.trim_start_matches('/')); + + let now = unix_now(); + let start = if creds.start_time != 0 { + creds.start_time + } else { + now + }; + let expire = if creds.expired_time > now { + creds.expired_time - now + } else { + 3600 + }; + + let headers_for_sign: Vec<(&str, &str)> = vec![ + ("host", cos_host.as_str()), + ("content-type", content_type.as_str()), + ("x-cos-security-token", creds.session_token.as_str()), + ]; + let path = format!("/{}", encoded_key.trim_start_matches('/')); + let sig = cos_sign(&CosSignInput { + method: "put", + path: &path, + params: &[], + headers: &headers_for_sign, + secret_id: &creds.secret_id, + secret_key: &creds.secret_key, + start_time: start, + expire_seconds: expire, + }); + + info!( + "[yuanbao] COS PUT bucket={} key={} size={}", + creds.bucket, + creds.location, + data.len() + ); + let resp = http + .put(&cos_url) + .header("Authorization", sig) + .header("Content-Type", content_type.as_str()) + .header("x-cos-security-token", &creds.session_token) + .body(data.to_vec()) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("COS PUT: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "COS PUT HTTP {}", + resp.status() + ))); + } + + let dims = if content_type.starts_with("image/") { + parse_image_size(data).unwrap_or(ImageDims { + width: 0, + height: 0, + }) + } else { + ImageDims { + width: 0, + height: 0, + } + }; + + let uuid = { + let mut h = Sha1::new(); + h.update(data); + hex::encode(h.finalize()) + }; + let url = if creds.resource_url.is_empty() { + cos_url + } else { + creds.resource_url.clone() + }; + debug!("[yuanbao] COS upload ok url={url}"); + Ok(UploadResult { + url, + uuid, + size: data.len() as u64, + width: dims.width, + height: dims.height, + }) +} + +fn unix_now() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cos_sign_is_deterministic() { + let s = cos_sign(&CosSignInput { + method: "put", + path: "/test/file.bin", + params: &[], + headers: &[("host", "bucket.cos.example.com")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }); + let s2 = cos_sign(&CosSignInput { + method: "put", + path: "/test/file.bin", + params: &[], + headers: &[("host", "bucket.cos.example.com")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }); + assert_eq!(s, s2); + assert!(s.starts_with("q-sign-algorithm=sha1")); + assert!(s.contains("q-ak=AKID")); + assert!(s.contains("q-sign-time=1700000000;1700003600")); + } + + #[test] + fn cos_sign_changes_with_path() { + let base = CosSignInput { + method: "put", + path: "/a", + params: &[], + headers: &[("host", "h")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 3600, + }; + let s1 = cos_sign(&base); + let s2 = cos_sign(&CosSignInput { path: "/b", ..base }); + assert_ne!(s1, s2); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/errors.rs b/src/openhuman/channels/providers/yuanbao/errors.rs new file mode 100644 index 0000000000..de0d3d2335 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/errors.rs @@ -0,0 +1,61 @@ +//! Yuanbao channel error types. + +use thiserror::Error; + +/// Close codes from the yuanbao gateway that indicate the connection +/// must **not** be retried (auth failure, kicked off, etc.). +/// +/// Mirrors `NO_RECONNECT_CLOSE_CODES` in hermes-agent `yuanbao.py`. +pub const NO_RECONNECT_CLOSE_CODES: &[u16] = &[4012, 4013, 4014, 4018, 4019, 4021]; + +/// Auth-related response codes that mean "credentials are bad" — surface +/// to the user, don't auto-retry. +pub const AUTH_FAILED_CODES: &[u32] = &[40001, 40002, 40003]; + +/// Auth-related codes that are transient — retry with backoff. +pub const AUTH_RETRYABLE_CODES: &[u32] = &[40010, 40011]; + +#[derive(Debug, Error)] +pub enum YuanbaoError { + #[error("protocol encode error: {0}")] + ProtoEncode(String), + + #[error("protocol decode error: {0}")] + ProtoDecode(String), + + #[error("not connected")] + NotConnected, + + #[error("connection closed: code={code}, reason={reason}")] + ConnectionClosed { code: u16, reason: String }, + + #[error("WebSocket error: {0}")] + WebSocket(String), + + #[error("HTTP/connection error: {0}")] + Connection(String), + + #[error("auth-bind failed: {0}")] + AuthFailed(String), + + #[error("auth-bind timeout")] + AuthTimeout, + + #[error("login timeout")] + LoginTimeout, + + #[error("request timeout: {0}")] + Timeout(String), + + #[error("send-message failed: {0}")] + SendFailed(String), + + #[error("media error: {0}")] + Media(String), + + #[error("invalid message: {0}")] + InvalidMessage(String), + + #[error("config error: {0}")] + Config(String), +} diff --git a/src/openhuman/channels/providers/yuanbao/ids.rs b/src/openhuman/channels/providers/yuanbao/ids.rs new file mode 100644 index 0000000000..4375e561b6 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/ids.rs @@ -0,0 +1,115 @@ +//! Account-id shortening for yuanbao. +//! +//! Yuanbao uids (`from_account`) are 64-char hashes assigned by the platform. +//! The composite `ChannelMessage` thread_id that downstream consumers derive +//! from `sender` and `reply_target` (`channel:yuanbao__`) +//! becomes ~145 chars. After the conversation store hex-encodes that for the +//! per-thread JSONL filename it grows to ~296 chars, exceeding `NAME_MAX` +//! (255 bytes) on ext4/HFS+/APFS/NTFS — writes fail with `ENAMETOOLONG` and +//! channel history is lost. +//! +//! Rather than push the filesystem limit into shared `ConversationStore` code, +//! we shorten yuanbao-specific ids at the channel boundary. Internal yuanbao +//! state (echo guard, access control, owner-command check) keeps the original +//! `from_account` — only the value emitted on `ChannelMessage.sender` / +//! `ChannelMessage.reply_target` is shortened. +//! +//! Format: `_`. +//! The 8-char prefix keeps logs roughly groupable for the same user; the +//! sha256 suffix guarantees uniqueness across uids that share a prefix. + +use sha2::{Digest, Sha256}; + +/// Max raw account-id length before the shortening kicks in. +/// +/// Anything shorter is passed through unchanged so short upstream-style ids +/// (e.g. numeric ids, future protocol changes) keep their natural form. +const ACCOUNT_ID_PASSTHROUGH_MAX: usize = 24; + +/// Shorten a yuanbao account id for use in `ChannelMessage.sender` / +/// `ChannelMessage.reply_target`. See module docs for rationale. +pub(super) fn shorten_account_id(uid: &str) -> String { + if uid.len() <= ACCOUNT_ID_PASSTHROUGH_MAX { + return uid.to_string(); + } + let prefix: String = uid.chars().take(8).collect(); + let digest = Sha256::digest(uid.as_bytes()); + format!("{prefix}_{:.16x}", digest) +} + +/// Shorten a yuanbao `reply_target`, preserving the `g:` shape +/// used for group chats. The `g:` discriminator is required by outbound +/// routing (see [`super::types::InboundContext::reply_target`]). +pub(super) fn shorten_reply_target(reply_target: &str) -> String { + if let Some(group_code) = reply_target.strip_prefix("g:") { + format!("g:{}", shorten_account_id(group_code)) + } else { + shorten_account_id(reply_target) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn passes_short_ids_through_unchanged() { + assert_eq!(shorten_account_id("123456"), "123456"); + assert_eq!(shorten_account_id(""), ""); + let exactly_max = "a".repeat(ACCOUNT_ID_PASSTHROUGH_MAX); + assert_eq!(shorten_account_id(&exactly_max), exactly_max); + } + + #[test] + fn shortens_long_ids_to_prefix_plus_hash() { + let long_uid = "a".repeat(64); + let shortened = shorten_account_id(&long_uid); + assert_eq!(shortened.len(), 8 + 1 + 16, "8 prefix + '_' + 16 hex"); + assert!(shortened.starts_with("aaaaaaaa_")); + } + + #[test] + fn shortening_is_deterministic_and_collision_resistant() { + let a = "f".repeat(64); + let mut b = a.clone(); + b.replace_range(63..64, "e"); // differ in last char only + let sa = shorten_account_id(&a); + let sb = shorten_account_id(&b); + assert_eq!(sa, shorten_account_id(&a), "deterministic"); + assert_ne!(sa, sb, "different uids hash to different ids"); + } + + #[test] + fn group_reply_target_preserves_g_prefix() { + let short_group = shorten_reply_target("g:short_group"); + assert_eq!(short_group, "g:short_group"); + + let long_code = "a".repeat(64); + let long_group = format!("g:{long_code}"); + let shortened = shorten_reply_target(&long_group); + assert!(shortened.starts_with("g:aaaaaaaa_")); + assert_eq!(shortened.len(), 2 + 8 + 1 + 16); + } + + #[test] + fn dm_reply_target_shortens_like_account_id() { + let uid = "z".repeat(64); + assert_eq!(shorten_reply_target(&uid), shorten_account_id(&uid)); + } + + #[test] + fn shortened_thread_id_fits_under_name_max() { + // Simulate the worst case: long uid for sender + reply_target. + let uid = "f".repeat(64); + let sender = shorten_account_id(&uid); + let reply_target = shorten_account_id(&uid); + let thread_id = format!("channel:yuanbao_{sender}_{reply_target}"); + // hex-encoded filename used by ConversationStore (`.jsonl`). + let hex_name_len = thread_id.len() * 2 + ".jsonl".len(); + // NAME_MAX on common filesystems is 255 bytes. + assert!( + hex_name_len <= 255, + "shortened thread_id hex filename ({hex_name_len} bytes) must fit under NAME_MAX (255)" + ); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/inbound.rs b/src/openhuman/channels/providers/yuanbao/inbound.rs new file mode 100644 index 0000000000..3491a24cc6 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/inbound.rs @@ -0,0 +1,623 @@ +//! Inbound message pipeline (17 stages). +//! +//! Mirrors `InboundPipeline` in hermes-agent `gateway/platforms/yuanbao.py`. +//! Each stage runs in order; any of them can short-circuit by +//! returning `Skip(reason)`. `Abort(_)` propagates an error. + +use std::collections::VecDeque; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use tokio::sync::RwLock; +use tracing::{debug, trace}; + +use super::config::YuanbaoConfig; +use super::errors::YuanbaoError; +use super::proto::{decode_inbound_json, decode_inbound_push}; +use super::proto_constants::*; +use super::types::*; + +/// Shared per-channel state that survives across messages. +pub struct PipelineState { + pub bot_id: String, + pub bot_name: String, + pub owner_id: String, + pub dm_access: AccessPolicy, + pub group_access: AccessPolicy, + pub allowed_users: Vec, + pub allowed_groups: Vec, + pub group_at_required: bool, + pub home_chat: RwLock>, + pub dedup: RwLock, +} + +impl PipelineState { + pub fn new(cfg: &YuanbaoConfig, bot_id: String) -> Arc { + Arc::new(Self { + bot_id, + bot_name: cfg.bot_name.clone(), + owner_id: cfg.owner_id.clone(), + dm_access: AccessPolicy::parse(&cfg.dm_access), + group_access: AccessPolicy::parse(&cfg.group_access), + allowed_users: cfg.allowed_users.clone(), + allowed_groups: cfg.allowed_groups.clone(), + group_at_required: cfg.group_at_required, + home_chat: RwLock::new(None), + dedup: RwLock::new(DedupCache::new(DEDUP_CAPACITY, DEDUP_TTL_SECS)), + }) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AccessPolicy { + Open, + Allowlist, + Closed, +} + +impl AccessPolicy { + fn parse(s: &str) -> Self { + match s.to_ascii_lowercase().as_str() { + "open" => Self::Open, + "closed" | "disabled" | "none" => Self::Closed, + _ => Self::Allowlist, + } + } +} + +/// Mutable context passed through every inbound stage. +#[derive(Debug, Clone)] +pub struct PipelineCtx { + pub msg: InboundMessage, + pub source: Source, + pub text: String, + pub image_urls: Vec, + pub is_at_bot: bool, + pub is_owner_command: bool, + pub kind: MessageKind, +} + +/// Outcome of a single inbound stage invocation. +#[derive(Debug)] +pub enum MwResult { + Continue, + Skip(&'static str), + Abort(YuanbaoError), +} + +/// Final outcome of the whole pipeline. +#[derive(Debug)] +pub enum PipelineOutcome { + Dispatch(PipelineCtx), + Filtered(&'static str), + Failed(YuanbaoError), +} + +#[async_trait] +pub trait Middleware: Send + Sync { + fn name(&self) -> &'static str; + async fn process(&self, state: &PipelineState, ctx: &mut PipelineCtx) -> MwResult; +} + +/// LRU-like dedup cache with TTL. +pub struct DedupCache { + capacity: usize, + ttl: Duration, + order: VecDeque<(String, Instant)>, + index: std::collections::HashSet, +} + +impl DedupCache { + pub fn new(capacity: usize, ttl_secs: u64) -> Self { + Self { + capacity, + ttl: Duration::from_secs(ttl_secs), + order: VecDeque::with_capacity(capacity), + index: std::collections::HashSet::with_capacity(capacity), + } + } + + /// Returns `true` if `id` has been seen within the TTL window. Inserts it otherwise. + pub fn check_and_insert(&mut self, id: &str) -> bool { + self.evict_expired(); + if self.index.contains(id) { + return true; + } + if self.order.len() >= self.capacity { + if let Some((old, _)) = self.order.pop_front() { + self.index.remove(&old); + } + } + self.order.push_back((id.to_string(), Instant::now())); + self.index.insert(id.to_string()); + false + } + + fn evict_expired(&mut self) { + let now = Instant::now(); + while let Some((_, ts)) = self.order.front() { + if now.duration_since(*ts) > self.ttl { + if let Some((old, _)) = self.order.pop_front() { + self.index.remove(&old); + } + } else { + break; + } + } + } +} + +// ───── Individual inbound stages ──────────────────────────────────── + +struct DecodeMw; +struct ExtractFieldsMw; +struct RecallGuardMw; +struct DedupMw; +struct SkipSelfMw; +struct ChatRoutingMw; +struct AccessGuardMw; +struct AutoSetHomeMw; +struct ExtractContentMw; +struct PlaceholderFilterMw; +struct OwnerCommandMw; +struct BuildSourceMw; +struct GroupAtGuardMw; +struct GroupAttributionMw; +struct ClassifyMsgTypeMw; +struct QuoteContextMw; +struct MediaResolveMw; + +#[async_trait] +impl Middleware for DecodeMw { + fn name(&self) -> &'static str { + "decode" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // Decoding happens before we build a PipelineCtx — this MW is a placeholder + // so the stage list still has 17 entries (mirrors hermes-agent). + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ExtractFieldsMw { + fn name(&self) -> &'static str { + "extract_fields" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for RecallGuardMw { + fn name(&self) -> &'static str { + "recall_guard" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if c.msg.is_recall() { + c.kind = MessageKind::Recall; + return MwResult::Skip("recall_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for DedupMw { + fn name(&self) -> &'static str { + "dedup" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if c.msg.msg_id.is_empty() { + return MwResult::Continue; // nothing to dedup on + } + let mut cache = s.dedup.write().await; + if cache.check_and_insert(&c.msg.msg_id) { + return MwResult::Skip("dedup"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for SkipSelfMw { + fn name(&self) -> &'static str { + "skip_self" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !s.bot_id.is_empty() && c.msg.from_account == s.bot_id { + return MwResult::Skip("skip_self"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ChatRoutingMw { + fn name(&self) -> &'static str { + "chat_routing" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + c.source.is_group = c.msg.is_group(); + c.source.group_code = c.msg.group_code.clone(); + c.source.from_account = c.msg.from_account.clone(); + c.source.sender_nickname = c.msg.sender_nickname.clone(); + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for AccessGuardMw { + fn name(&self) -> &'static str { + "access_guard" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let (policy, allow_list, key) = if c.source.is_group { + (s.group_access, &s.allowed_groups, &c.source.group_code) + } else { + (s.dm_access, &s.allowed_users, &c.source.from_account) + }; + let pass = match policy { + AccessPolicy::Open => true, + AccessPolicy::Closed => false, + AccessPolicy::Allowlist => allow_list.iter().any(|u| u == "*" || u == key), + }; + if !pass { + return MwResult::Skip("access_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for AutoSetHomeMw { + fn name(&self) -> &'static str { + "auto_set_home" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !c.source.is_group { + let mut home = s.home_chat.write().await; + if home.is_none() { + *home = Some(c.source.reply_target()); + } + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ExtractContentMw { + fn name(&self) -> &'static str { + "extract_content" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + c.text = c.msg.extract_text(); + c.image_urls = c.msg.extract_image_urls(); + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for PlaceholderFilterMw { + fn name(&self) -> &'static str { + "placeholder_filter" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let trimmed = c.text.trim(); + let is_placeholder = trimmed == "[image]" || trimmed == "[file]" || trimmed == "[图片]"; + if (trimmed.is_empty() || is_placeholder) && c.image_urls.is_empty() { + return MwResult::Skip("placeholder_filter"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for OwnerCommandMw { + fn name(&self) -> &'static str { + "owner_command" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !s.owner_id.is_empty() + && c.msg.from_account == s.owner_id + && c.text.trim_start().starts_with('/') + { + c.is_owner_command = true; + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for BuildSourceMw { + fn name(&self) -> &'static str { + "build_source" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // Source already populated by ChatRoutingMw. + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for GroupAtGuardMw { + fn name(&self) -> &'static str { + "group_at_guard" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + if !c.source.is_group || !s.group_at_required { + return MwResult::Continue; + } + let by_name = !s.bot_name.is_empty() && c.text.contains(&format!("@{}", s.bot_name)); + let by_id = !s.bot_id.is_empty() && c.text.contains(&format!("@{}", s.bot_id)); + let by_mention = + !s.bot_id.is_empty() && c.text.contains(&format!("[at|userId:{}]", s.bot_id)); + c.is_at_bot = by_name || by_id || by_mention; + if !c.is_at_bot && !c.is_owner_command { + return MwResult::Skip("group_at_guard"); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for GroupAttributionMw { + fn name(&self) -> &'static str { + "group_attribution" + } + async fn process(&self, s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + // Strip `@bot` from text and the TIM `[at|userId:…]` markup. + if c.source.is_group && c.is_at_bot { + if !s.bot_name.is_empty() { + c.text = c.text.replace(&format!("@{}", s.bot_name), ""); + } + if !s.bot_id.is_empty() { + c.text = c.text.replace(&format!("@{}", s.bot_id), ""); + c.text = c.text.replace(&format!("[at|userId:{}]", s.bot_id), ""); + } + c.text = c.text.trim().to_string(); + } + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for ClassifyMsgTypeMw { + fn name(&self) -> &'static str { + "classify_msg_type" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + let has_text = !c.text.is_empty(); + let has_image = !c.image_urls.is_empty(); + let has_file = c.msg.msg_body.iter().any(|el| el.msg_type == tim::FILE); + let has_sound = c.msg.msg_body.iter().any(|el| el.msg_type == tim::SOUND); + c.kind = match (has_text, has_image, has_file, has_sound) { + (_, true, _, _) if has_text => MessageKind::Mixed, + (_, true, _, _) => MessageKind::Image, + (_, _, true, _) => MessageKind::File, + (_, _, _, true) => MessageKind::Voice, + _ => MessageKind::Text, + }; + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for QuoteContextMw { + fn name(&self) -> &'static str { + "quote_context" + } + async fn process(&self, _s: &PipelineState, c: &mut PipelineCtx) -> MwResult { + // The cloud_custom_data field carries a JSON quote envelope; for + // now we just leave the raw payload accessible via `msg.cloud_custom_data` + // for downstream tools. Full parsing is intentionally deferred — + // hermes-agent does it lazily too. + let _ = c; + MwResult::Continue + } +} + +#[async_trait] +impl Middleware for MediaResolveMw { + fn name(&self) -> &'static str { + "media_resolve" + } + async fn process(&self, _s: &PipelineState, _c: &mut PipelineCtx) -> MwResult { + // ybres:// resource URLs would be resolved here. Currently URLs + // arrive pre-resolved from the server; expand later if needed. + MwResult::Continue + } +} + +/// Composite pipeline = ordered Vec of inbound stages. +pub struct InboundPipeline { + state: Arc, + stages: Vec>, +} + +impl InboundPipeline { + pub fn new(state: Arc) -> Self { + let stages: Vec> = vec![ + Box::new(DecodeMw), + Box::new(ExtractFieldsMw), + Box::new(RecallGuardMw), + Box::new(DedupMw), + Box::new(SkipSelfMw), + Box::new(ChatRoutingMw), + Box::new(AccessGuardMw), + Box::new(AutoSetHomeMw), + Box::new(ExtractContentMw), + Box::new(PlaceholderFilterMw), + Box::new(OwnerCommandMw), + Box::new(BuildSourceMw), + Box::new(GroupAtGuardMw), + Box::new(GroupAttributionMw), + Box::new(ClassifyMsgTypeMw), + Box::new(QuoteContextMw), + Box::new(MediaResolveMw), + ]; + Self { state, stages } + } + + /// Decode a biz push body, run it through every stage, return the outcome. + /// + /// The yuanbao gateway may push the biz body as either protobuf + /// (`InboundMessagePush`) or a JSON string with the same field shape + /// (snake_case + `log_ext.trace_id`). We sniff the first non-whitespace + /// byte to pick the decoder — `{` means JSON, anything else is treated + /// as protobuf. Mirrors plugin gateway.ts::wsPushToInboundMessage + /// (l. 288), which tries protobuf first and falls back to JSON. + pub async fn process(&self, biz_body: &[u8]) -> PipelineOutcome { + let is_json = biz_body + .iter() + .find(|b| !b.is_ascii_whitespace()) + .map(|b| *b == b'{') + .unwrap_or(false); + + let msg = if is_json { + match decode_inbound_json(biz_body) { + Ok(m) => m, + Err(e) => return PipelineOutcome::Failed(e), + } + } else { + match decode_inbound_push(biz_body) { + Ok(m) => m, + Err(e) => return PipelineOutcome::Failed(e), + } + }; + let mut ctx = PipelineCtx { + msg, + source: Source::default(), + text: String::new(), + image_urls: Vec::new(), + is_at_bot: false, + is_owner_command: false, + kind: MessageKind::Text, + }; + for stage in &self.stages { + match stage.process(&self.state, &mut ctx).await { + MwResult::Continue => { + trace!("[yuanbao:inbound] {} pass", stage.name()); + } + MwResult::Skip(reason) => { + debug!("[yuanbao:inbound] {} filtered ({})", stage.name(), reason); + return PipelineOutcome::Filtered(reason); + } + MwResult::Abort(err) => return PipelineOutcome::Failed(err), + } + } + PipelineOutcome::Dispatch(ctx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn cfg(bot_id: &str) -> YuanbaoConfig { + let mut c = YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://x".into(); + c.token = "tok".into(); + c.bot_id = bot_id.into(); + c.bot_name = "bot".into(); + c.dm_access = "open".into(); + c.group_access = "open".into(); + c + } + + fn ctx_with(msg: InboundMessage) -> PipelineCtx { + PipelineCtx { + msg, + source: Source::default(), + text: String::new(), + image_urls: Vec::new(), + is_at_bot: false, + is_owner_command: false, + kind: MessageKind::Text, + } + } + + #[tokio::test] + async fn dedup_skips_repeat() { + let state = PipelineState::new(&cfg("bot1"), "bot1".into()); + let mw = DedupMw; + let msg = InboundMessage { + msg_id: "m1".into(), + ..Default::default() + }; + let mut c1 = ctx_with(msg.clone()); + assert!(matches!( + mw.process(&state, &mut c1).await, + MwResult::Continue + )); + let mut c2 = ctx_with(msg); + assert!(matches!( + mw.process(&state, &mut c2).await, + MwResult::Skip(_) + )); + } + + #[tokio::test] + async fn access_guard_open() { + let state = PipelineState::new(&cfg("bot1"), "bot1".into()); + let mw = AccessGuardMw; + let mut c = ctx_with(InboundMessage { + from_account: "alice".into(), + ..Default::default() + }); + c.source.is_group = false; + c.source.from_account = "alice".into(); + assert!(matches!( + mw.process(&state, &mut c).await, + MwResult::Continue + )); + } + + #[tokio::test] + async fn full_dm_dispatch() { + let mut config = cfg("bot1"); + config.group_at_required = false; + let state = PipelineState::new(&config, "bot1".into()); + let pipeline = InboundPipeline::new(state); + let msg = InboundMessage { + from_account: "alice".into(), + to_account: "bot1".into(), + msg_id: "hi".into(), + msg_body: vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }], + ..Default::default() + }; + let body = crate::openhuman::channels::providers::yuanbao::proto::encode_msg_body_element( + &msg.msg_body[0], + ); + // Synthesize an InboundMessagePush from scratch: + use crate::openhuman::channels::providers::yuanbao::proto; + let mut buf = Vec::new(); + let mut put_str = |fnum: u32, s: &str, b: &mut Vec| { + proto::encode_varint(((fnum as u64) << 3) | 2, b); + proto::encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_str(2, &msg.from_account, &mut buf); + put_str(3, &msg.to_account, &mut buf); + put_str(12, &msg.msg_id, &mut buf); + proto::encode_varint(((13u64) << 3) | 2, &mut buf); + proto::encode_varint(body.len() as u64, &mut buf); + buf.extend_from_slice(&body); + + let outcome = pipeline.process(&buf).await; + assert!( + matches!(outcome, PipelineOutcome::Dispatch(_)), + "got {:?}", + outcome + ); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/media.rs b/src/openhuman/channels/providers/yuanbao/media.rs new file mode 100644 index 0000000000..0066e5a480 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/media.rs @@ -0,0 +1,339 @@ +//! Media helpers — MIME mapping, byte-level image dimension parsing, +//! download with size cap, and TIM `msg_body` builders. +//! +//! Tencent COS upload lives in [`super::cos`] to keep both files under +//! the 500-line ceiling. + +use super::errors::YuanbaoError; +use super::types::MsgBodyElement; + +// ─── MIME / image-format mapping ─────────────────────────────────── + +pub fn guess_mime_type(filename: &str) -> &'static str { + let ext = filename + .rsplit_once('.') + .map(|(_, e)| e.to_ascii_lowercase()) + .unwrap_or_default(); + match ext.as_str() { + "jpg" | "jpeg" => "image/jpeg", + "png" => "image/png", + "gif" => "image/gif", + "webp" => "image/webp", + "bmp" => "image/bmp", + "heic" => "image/heic", + "tiff" => "image/tiff", + "ico" => "image/x-icon", + "pdf" => "application/pdf", + "doc" => "application/msword", + "docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xls" => "application/vnd.ms-excel", + "xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "ppt" => "application/vnd.ms-powerpoint", + "pptx" => "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "txt" => "text/plain", + "zip" => "application/zip", + "tar" => "application/x-tar", + "gz" => "application/gzip", + "mp3" => "audio/mpeg", + "mp4" => "video/mp4", + "wav" => "audio/wav", + "ogg" => "audio/ogg", + "webm" => "video/webm", + _ => "application/octet-stream", + } +} + +pub fn is_image(filename: &str, mime_type: &str) -> bool { + if mime_type.starts_with("image/") { + return true; + } + guess_mime_type(filename).starts_with("image/") +} + +/// Map a MIME type to the TIM `image_format` enum. +pub fn image_format_code(mime: &str) -> u32 { + match mime { + "image/jpeg" | "image/jpg" => 1, + "image/gif" => 2, + "image/png" => 3, + "image/bmp" => 4, + "image/webp" | "image/heic" | "image/tiff" => 255, + _ => 255, + } +} + +// ─── Pure-bytes image dimension parsing ───────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ImageDims { + pub width: u32, + pub height: u32, +} + +pub fn parse_image_size(data: &[u8]) -> Option { + parse_png(data) + .or_else(|| parse_jpeg(data)) + .or_else(|| parse_gif(data)) + .or_else(|| parse_webp(data)) +} + +fn parse_png(buf: &[u8]) -> Option { + if buf.len() < 24 || &buf[..4] != b"\x89PNG" { + return None; + } + let w = u32::from_be_bytes(buf[16..20].try_into().ok()?); + let h = u32::from_be_bytes(buf[20..24].try_into().ok()?); + Some(ImageDims { + width: w, + height: h, + }) +} + +fn parse_jpeg(buf: &[u8]) -> Option { + if buf.len() < 4 || buf[0] != 0xFF || buf[1] != 0xD8 { + return None; + } + let mut i = 2usize; + while i + 9 < buf.len() { + if buf[i] != 0xFF { + i += 1; + continue; + } + let marker = buf[i + 1]; + if marker == 0xC0 || marker == 0xC2 { + let h = u16::from_be_bytes(buf[i + 5..i + 7].try_into().ok()?); + let w = u16::from_be_bytes(buf[i + 7..i + 9].try_into().ok()?); + return Some(ImageDims { + width: w as u32, + height: h as u32, + }); + } + if i + 3 >= buf.len() { + break; + } + let seg_len = u16::from_be_bytes(buf[i + 2..i + 4].try_into().ok()?) as usize; + i += 2 + seg_len; + } + None +} + +fn parse_gif(buf: &[u8]) -> Option { + if buf.len() < 10 { + return None; + } + let sig = &buf[..6]; + if sig != b"GIF87a" && sig != b"GIF89a" { + return None; + } + let w = u16::from_le_bytes(buf[6..8].try_into().ok()?); + let h = u16::from_le_bytes(buf[8..10].try_into().ok()?); + Some(ImageDims { + width: w as u32, + height: h as u32, + }) +} + +fn parse_webp(buf: &[u8]) -> Option { + if buf.len() < 16 || &buf[..4] != b"RIFF" || &buf[8..12] != b"WEBP" { + return None; + } + let chunk = &buf[12..16]; + if chunk == b"VP8 " { + if buf.len() >= 30 && buf[23] == 0x9D && buf[24] == 0x01 && buf[25] == 0x2A { + let w = u16::from_le_bytes(buf[26..28].try_into().ok()?) & 0x3FFF; + let h = u16::from_le_bytes(buf[28..30].try_into().ok()?) & 0x3FFF; + return Some(ImageDims { + width: w as u32, + height: h as u32, + }); + } + } else if chunk == b"VP8L" { + if buf.len() >= 25 && buf[20] == 0x2F { + let bits = u32::from_le_bytes(buf[21..25].try_into().ok()?); + let w = (bits & 0x3FFF) + 1; + let h = ((bits >> 14) & 0x3FFF) + 1; + return Some(ImageDims { + width: w, + height: h, + }); + } + } else if chunk == b"VP8X" && buf.len() >= 30 { + let w = (buf[24] as u32 | ((buf[25] as u32) << 8) | ((buf[26] as u32) << 16)) + 1; + let h = (buf[27] as u32 | ((buf[28] as u32) << 8) | ((buf[29] as u32) << 16)) + 1; + return Some(ImageDims { + width: w, + height: h, + }); + } + None +} + +// ─── HTTP download with size cap ──────────────────────────────────── + +pub async fn download_url( + http: &reqwest::Client, + url: &str, + max_size_mb: u64, +) -> Result<(Vec, String), YuanbaoError> { + let limit = max_size_mb.saturating_mul(1024 * 1024); + + if let Ok(head) = http.head(url).send().await { + if let Some(len) = head + .headers() + .get(reqwest::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + { + if len > limit { + return Err(YuanbaoError::Media(format!( + "remote file too large: {len} > limit {limit}" + ))); + } + } + } + + let resp = http + .get(url) + .send() + .await + .map_err(|e| YuanbaoError::Connection(format!("download {url}: {e}")))?; + if !resp.status().is_success() { + return Err(YuanbaoError::Media(format!( + "download HTTP {} for {url}", + resp.status() + ))); + } + let ct = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .split(';') + .next() + .unwrap_or("") + .trim() + .to_string(); + + let bytes = resp + .bytes() + .await + .map_err(|e| YuanbaoError::Media(format!("read body: {e}")))?; + if bytes.len() as u64 > limit { + return Err(YuanbaoError::Media(format!( + "downloaded file exceeds limit: {} > {}", + bytes.len(), + limit + ))); + } + Ok((bytes.to_vec(), ct)) +} + +// ─── TIM msg_body builders ────────────────────────────────────────── + +/// Build a TIM `TIMImageElem` `msg_body` ready to send. +pub fn build_image_msg_body( + url: &str, + uuid: Option<&str>, + filename: Option<&str>, + size: u32, + width: u32, + height: u32, + mime_type: &str, +) -> Vec { + use super::types::{ImageInfo, MsgContent}; + let uuid_str = uuid + .map(|s| s.to_string()) + .or_else(|| filename.map(|s| s.to_string())) + .unwrap_or_else(|| "image".to_string()); + let format = if mime_type.is_empty() { + 255 + } else { + image_format_code(mime_type) + }; + vec![MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + uuid: Some(uuid_str), + image_format: Some(format), + image_info_array: vec![ImageInfo { + image_type: 1, + size, + width, + height, + url: url.to_string(), + }], + ..Default::default() + }, + }] +} + +/// Build a TIM `TIMFileElem` `msg_body` ready to send. +pub fn build_file_msg_body( + url: &str, + filename: &str, + uuid: Option<&str>, + size: u32, +) -> Vec { + use super::types::MsgContent; + let uuid_str = uuid + .map(|s| s.to_string()) + .unwrap_or_else(|| filename.to_string()); + vec![MsgBodyElement { + msg_type: "TIMFileElem".into(), + msg_content: MsgContent { + uuid: Some(uuid_str), + file_name: Some(filename.to_string()), + file_size: Some(size), + url: Some(url.to_string()), + ..Default::default() + }, + }] +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn png_dims_parse() { + let png = [ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x06, + ]; + let d = parse_image_size(&png).expect("png parse"); + assert_eq!(d.width, 1); + assert_eq!(d.height, 1); + } + + #[test] + fn gif_dims_parse() { + let gif = b"GIF89a\x40\x01\xF0\x00rest"; + let d = parse_image_size(gif).expect("gif parse"); + assert_eq!(d.width, 320); + assert_eq!(d.height, 240); + } + + #[test] + fn guess_mime_basic() { + assert_eq!(guess_mime_type("foo.png"), "image/png"); + assert_eq!(guess_mime_type("doc.pdf"), "application/pdf"); + assert_eq!(guess_mime_type("blob"), "application/octet-stream"); + } + + #[test] + fn is_image_works() { + assert!(is_image("a.jpeg", "")); + assert!(is_image("noext", "image/png")); + assert!(!is_image("a.pdf", "")); + } + + #[test] + fn image_format_code_matrix() { + assert_eq!(image_format_code("image/png"), 3); + assert_eq!(image_format_code("image/jpeg"), 1); + assert_eq!(image_format_code("image/gif"), 2); + assert_eq!(image_format_code("image/bmp"), 4); + assert_eq!(image_format_code("image/webp"), 255); + assert_eq!(image_format_code("application/pdf"), 255); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/mod.rs b/src/openhuman/channels/providers/yuanbao/mod.rs new file mode 100644 index 0000000000..23e8ad99d0 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/mod.rs @@ -0,0 +1,29 @@ +//! Yuanbao (元宝) channel provider. +//! +//! This module is intentionally export-focused. Operational code lives in +//! sibling modules: +//! - [`channel`] wires the provider into the generic OpenHuman `Channel` trait. +//! - [`connection`] owns the WebSocket transport and request correlator. +//! - [`inbound`] owns inbound filtering/extraction. +//! - [`outbound`] owns Yuanbao send/query calls. +//! - [`proto`] / [`proto_biz`] / [`wire`] own hand-written protobuf codecs. + +pub mod channel; +pub mod config; +pub mod connection; +pub mod cos; +pub mod errors; +pub mod ids; +pub mod inbound; +pub mod media; +pub mod outbound; +pub mod proto; +pub mod proto_biz; +pub mod proto_constants; +pub mod sign; +pub mod splitter; +pub mod types; +pub mod wire; + +pub use channel::YuanbaoChannel; +pub use config::YuanbaoConfig; diff --git a/src/openhuman/channels/providers/yuanbao/outbound.rs b/src/openhuman/channels/providers/yuanbao/outbound.rs new file mode 100644 index 0000000000..4f9da0e50c --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/outbound.rs @@ -0,0 +1,378 @@ +//! Outbound message sender. +//! +//! Translates high-level `send_text` / `send_image` / heartbeat calls +//! into encoded ConnMsg frames and pushes them through the shared +//! `YuanbaoConnection`. The recipient string uses the convention +//! `g:` for groups, raw `` for DMs. +//! +//! For the few request kinds where we care about the response body +//! (notably `QueryGroupInfo`, `GetGroupMemberList`, and `SendXxxMessage`'s +//! `code/msg_id` echo) we use the connection-level pending-acks +//! correlator and return parsed results to the caller. Heartbeats are +//! fire-and-forget — the response is never inspected. + +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use tracing::debug; + +use super::connection::YuanbaoConnection; +use super::cos::{get_cos_credentials, upload_to_cos}; +use super::errors::YuanbaoError; +use super::media::{build_file_msg_body, build_image_msg_body, download_url, parse_image_size}; +use super::proto_biz::{ + decode_get_group_member_list_rsp, decode_query_group_info_rsp, encode_get_group_member_list, + encode_query_group_info, encode_send_c2c_message, encode_send_group_heartbeat, + encode_send_group_message, encode_send_private_heartbeat, +}; +use super::proto_constants::{ws_heartbeat, DEFAULT_SEND_TIMEOUT_SECS}; +use super::sign::SignManager; +use super::types::{GroupInfo, GroupMemberListPage, MsgBodyElement}; + +const GROUP_PREFIX: &str = "g:"; +/// Wait-for-response timeout on queries like `QueryGroupInfo`. +const QUERY_TIMEOUT_SECS: u64 = 10; + +/// Parsed addressing target. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Target<'a> { + Dm(&'a str), + Group(&'a str), +} + +impl<'a> Target<'a> { + pub fn parse(recipient: &'a str) -> Self { + if let Some(rest) = recipient.strip_prefix(GROUP_PREFIX) { + Self::Group(rest) + } else { + Self::Dm(recipient) + } + } +} + +pub struct OutboundSender { + conn: Arc, + /// Sign-token cache holding the server-issued `bot_id`. Populated as a + /// side effect of `connection`'s sign+auth-bind flow. The `bot_id` here + /// is what `yuanbao_openclaw_proxy` expects in the outbound + /// `from_account` field — config-only fallbacks like `app_key` get + /// silently accepted (status=0) but never routed to a real conv id. + sign_manager: Option>, + /// Lookup key for `sign_manager.cached(app_key)`. + app_key: String, + /// User-supplied bot id override; empty when not set. Only used when + /// the sign cache hasn't been primed yet (e.g. send-before-auth races). + config_bot_id: String, + http: reqwest::Client, +} + +impl OutboundSender { + pub fn new( + conn: Arc, + sign_manager: Option>, + app_key: String, + config_bot_id: String, + ) -> Self { + Self { + conn, + sign_manager, + app_key, + config_bot_id, + http: reqwest::Client::new(), + } + } + + /// Resolve the `from_account` to put on the next outbound frame. + /// Prefers the server-issued `bot_id` cached after sign-token / auth-bind + /// — matches hermes-agent `_bot_id = token_data["bot_id"]` (yuanbao.py:400). + async fn resolve_from_account(&self) -> String { + if let Some(sign) = &self.sign_manager { + if let Some(entry) = sign.cached(&self.app_key).await { + if !entry.bot_id.is_empty() { + return entry.bot_id; + } + } + } + self.config_bot_id.clone() + } + + /// Send a plain-text message. Returns the client-side `msg_id`. + pub async fn send_text( + &self, + recipient: &str, + text: &str, + ref_msg_id: Option<&str>, + ) -> Result { + let body = vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: super::types::MsgContent { + text: Some(text.to_string()), + ..Default::default() + }, + }]; + self.send_body(recipient, body, ref_msg_id).await + } + + /// Send an image by an already-uploaded (COS or other) URL. + #[allow(clippy::too_many_arguments)] + pub async fn send_image_url( + &self, + recipient: &str, + url: &str, + size: u32, + width: u32, + height: u32, + mime_type: &str, + ) -> Result { + let body = build_image_msg_body(url, None, None, size, width, height, mime_type); + self.send_body(recipient, body, None).await + } + + /// End-to-end image send: download from URL → upload to COS → send + /// as a `TIMImageElem`. Returns the outbound `msg_id`. + /// + /// `app_key` / `bot_id` / `token` / `api_domain` / `route_env` come + /// from the channel config; pass them in rather than reaching back + /// through the conn to keep this fn easy to unit-test. + #[allow(clippy::too_many_arguments)] + pub async fn send_image_from_url( + &self, + recipient: &str, + source_url: &str, + app_key: &str, + bot_id: &str, + token: &str, + api_domain: &str, + route_env: &str, + max_size_mb: u64, + ) -> Result { + let (bytes, mime) = download_url(&self.http, source_url, max_size_mb).await?; + let dims = parse_image_size(&bytes); + let width = dims.as_ref().map(|d| d.width).unwrap_or(0); + let height = dims.as_ref().map(|d| d.height).unwrap_or(0); + + let filename = extract_filename(source_url); + let creds = get_cos_credentials( + &self.http, api_domain, app_key, bot_id, token, route_env, &filename, + ) + .await?; + let upload = upload_to_cos(&self.http, &creds, &bytes, &filename, mime.clone()).await?; + + let final_width = if upload.width > 0 { + upload.width + } else { + width + }; + let final_height = if upload.height > 0 { + upload.height + } else { + height + }; + let body = build_image_msg_body( + &upload.url, + Some(&upload.uuid), + Some(&filename), + upload.size as u32, + final_width, + final_height, + &mime, + ); + self.send_body(recipient, body, None).await + } + + /// Send a file by URL. + pub async fn send_file_url( + &self, + recipient: &str, + url: &str, + file_name: &str, + size: u32, + ) -> Result { + let body = build_file_msg_body(url, file_name, None, size); + self.send_body(recipient, body, None).await + } + + /// Send a pre-built `msg_body`. Waits up to `DEFAULT_SEND_TIMEOUT_SECS` + /// for the server response so the caller learns about delivery + /// failures (rate-limit, banned content, etc.) instead of getting a + /// silent drop. + pub async fn send_body( + &self, + recipient: &str, + msg_body: Vec, + ref_msg_id: Option<&str>, + ) -> Result { + let msg_id = self.next_msg_id(); + let target = Target::parse(recipient); + let from_account = self.resolve_from_account().await; + let frame = match target { + Target::Dm(uid) => encode_send_c2c_message( + uid, + &from_account, + &msg_body, + &msg_id, + random_u32(), + "", + "", + ), + Target::Group(group_code) => { + let random = format!("{}", random_u32()); + encode_send_group_message( + group_code, + &from_account, + &msg_body, + &msg_id, + "", + &random, + ref_msg_id.unwrap_or(""), + "", + ) + } + }; + + let timeout = Duration::from_secs(DEFAULT_SEND_TIMEOUT_SECS); + match self.conn.send_and_wait(&msg_id, frame, timeout).await { + Ok(resp) => { + if resp.status != 0 { + return Err(YuanbaoError::SendFailed(format!( + "server status={} cmd={}", + resp.status, resp.cmd + ))); + } + debug!("[outbound] ack msg_id={msg_id} target={:?}", target); + Ok(msg_id) + } + // If the correlator isn't usable yet (NotConnected etc.) bubble up. + Err(e) => Err(e), + } + } + + /// Send a "thinking" heartbeat (RUNNING) — fire-and-forget. + pub async fn start_heartbeat(&self, recipient: &str) -> Result<(), YuanbaoError> { + self.send_heartbeat(recipient, ws_heartbeat::RUNNING).await + } + + /// Send a "done" heartbeat (FINISH) — fire-and-forget. + pub async fn stop_heartbeat(&self, recipient: &str) -> Result<(), YuanbaoError> { + self.send_heartbeat(recipient, ws_heartbeat::FINISH).await + } + + async fn send_heartbeat(&self, recipient: &str, heartbeat: u32) -> Result<(), YuanbaoError> { + let req_id = self.conn.next_msg_id("hb"); + let from_account = self.resolve_from_account().await; + let frame = match Target::parse(recipient) { + Target::Dm(uid) => { + encode_send_private_heartbeat(&req_id, &from_account, uid, heartbeat) + } + Target::Group(group_code) => { + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + encode_send_group_heartbeat(&req_id, &from_account, group_code, heartbeat, now_ms) + } + }; + // Fire-and-forget — we don't care about the heartbeat ack. + self.conn.send_conn_msg(frame).await + } + + /// Query group info and wait for the server's reply. + pub async fn query_group_info(&self, group_code: &str) -> Result { + let req_id = self.conn.next_msg_id("qgi"); + let frame = encode_query_group_info(&req_id, group_code); + let resp = self + .conn + .send_and_wait(&req_id, frame, Duration::from_secs(QUERY_TIMEOUT_SECS)) + .await?; + decode_query_group_info_rsp(&resp.data) + } + + /// Fetch one page of group members. Use `offset=0, limit=100` for the + /// first page; the response carries `next_offset` for pagination. + pub async fn query_group_members( + &self, + group_code: &str, + offset: u32, + limit: u32, + ) -> Result { + let req_id = self.conn.next_msg_id("qgm"); + let frame = encode_get_group_member_list(&req_id, group_code, offset, limit); + let resp = self + .conn + .send_and_wait(&req_id, frame, Duration::from_secs(QUERY_TIMEOUT_SECS)) + .await?; + decode_get_group_member_list_rsp(&resp.data) + } + + fn next_msg_id(&self) -> String { + // Use a stable prefix so logs can be grepped across send paths. + self.conn.next_msg_id("om") + } +} + +fn random_u32() -> u32 { + rand::random::() +} + +/// Best-effort file name extraction from a URL — uses the URL's path +/// component (so the host is never picked up as a filename) and falls +/// back to "file" if there's nothing usable. +fn extract_filename(url_str: &str) -> String { + if let Ok(parsed) = url::Url::parse(url_str) { + if let Some(segments) = parsed.path_segments() { + if let Some(last) = segments.filter(|s| !s.is_empty()).last() { + return last.to_string(); + } + } + return "file".to_string(); + } + // Non-URL input (relative path, raw filename, etc.) — fall back to + // last non-empty `/`-delimited segment. + let without_query = url_str.split('?').next().unwrap_or(url_str); + let last = without_query + .rsplit('/') + .find(|s| !s.is_empty()) + .unwrap_or(""); + if last.is_empty() { + "file".to_string() + } else { + last.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn target_parse_dm() { + match Target::parse("user_42") { + Target::Dm(uid) => assert_eq!(uid, "user_42"), + _ => panic!("should be DM"), + } + } + + #[test] + fn target_parse_group() { + match Target::parse("g:room_99") { + Target::Group(c) => assert_eq!(c, "room_99"), + _ => panic!("should be group"), + } + } + + #[test] + fn target_parse_empty_dm() { + assert!(matches!(Target::parse(""), Target::Dm(""))); + } + + #[test] + fn extract_filename_strips_query() { + assert_eq!(extract_filename("https://x.com/a/b/cat.png"), "cat.png"); + assert_eq!( + extract_filename("https://x.com/a/b/cat.png?sig=abc"), + "cat.png" + ); + assert_eq!(extract_filename("https://x.com/"), "file"); + assert_eq!(extract_filename(""), "file"); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto.rs b/src/openhuman/channels/providers/yuanbao/proto.rs new file mode 100644 index 0000000000..41fad20170 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto.rs @@ -0,0 +1,676 @@ +//! Yuanbao WebSocket ConnMsg envelope + built-in protocol commands +//! (auth-bind, ping, push-ack) + TIM `MsgBodyElement` codecs. +//! +//! Each WebSocket binary frame carries one full `ConnMsg` protobuf +//! message; **no extra length prefix is needed** (the WS frame boundary +//! delimits one message). Verified against the hermes-agent Python +//! reference (yuanbao_proto.py) and the TypeScript openclaw plugin. +//! +//! Business-layer codecs (send-message / heartbeat / group query) live +//! in [`super::proto_biz`]. Wire-format primitives (varint, FieldValue, +//! parse_fields) live in [`super::wire`]. + +use super::errors::YuanbaoError; +use super::proto_constants::*; +use super::types::*; +use super::wire::{ + encode_field_bytes, encode_field_string, encode_field_varint, get_bytes, get_repeated_bytes, + get_string, get_varint, next_seq_no, parse_fields, FieldValue, +}; + +// Re-export wire primitives for downstream callers (tests, tools). +pub use super::wire::{decode_varint, encode_varint}; + +// ─── ConnMsg envelope ────────────────────────────────────────────── +// +// message Head { +// uint32 cmd_type = 1; +// string cmd = 2; +// uint32 seq_no = 3; +// string msg_id = 4; +// string module = 5; +// bool need_ack = 6; +// int32 status = 10; +// } +// message ConnMsg { +// Head head = 1; +// bytes data = 2; +// } + +fn encode_head( + cmd_type: u32, + cmd: &str, + seq_no: u32, + msg_id: &str, + module: &str, + need_ack: bool, + status: u32, +) -> Vec { + let mut buf = Vec::with_capacity(64); + if cmd_type != 0 { + encode_field_varint(1, cmd_type as u64, &mut buf); + } + if !cmd.is_empty() { + encode_field_string(2, cmd, &mut buf); + } + if seq_no != 0 { + encode_field_varint(3, seq_no as u64, &mut buf); + } + if !msg_id.is_empty() { + encode_field_string(4, msg_id, &mut buf); + } + if !module.is_empty() { + encode_field_string(5, module, &mut buf); + } + if need_ack { + encode_field_varint(6, 1, &mut buf); + } + if status != 0 { + encode_field_varint(10, status as u64, &mut buf); + } + buf +} + +/// Encode a full `ConnMsg` frame (ready to send as a binary WS frame). +pub fn encode_conn_msg( + cmd_type: u32, + cmd: &str, + seq_no: u32, + msg_id: &str, + module: &str, + data: &[u8], +) -> Vec { + let head = encode_head(cmd_type, cmd, seq_no, msg_id, module, false, 0); + let mut buf = Vec::with_capacity(head.len() + data.len() + 16); + encode_field_bytes(1, &head, &mut buf); + if !data.is_empty() { + encode_field_bytes(2, data, &mut buf); + } + buf +} + +/// Decode a `ConnMsg` frame received from the gateway. +pub fn decode_conn_msg(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let head_bytes = get_bytes(&fields, 1); + let payload = get_bytes(&fields, 2); + let head_fields = if head_bytes.is_empty() { + Vec::new() + } else { + parse_fields(&head_bytes)? + }; + Ok(ConnFrame { + cmd_type: get_varint(&head_fields, 1) as u32, + cmd: get_string(&head_fields, 2), + seq_no: get_varint(&head_fields, 3) as u32, + msg_id: get_string(&head_fields, 4), + module: get_string(&head_fields, 5), + need_ack: get_varint(&head_fields, 6) != 0, + status: get_varint(&head_fields, 10) as u32, + data: payload, + }) +} + +// ─── Built-in protocol commands ──────────────────────────────────── + +/// `AuthBindReq` — first request after the WebSocket opens. +#[allow(clippy::too_many_arguments)] +pub fn encode_auth_bind( + biz_id: &str, + uid: &str, + source: &str, + token: &str, + msg_id: &str, + app_version: &str, + operation_system: &str, + bot_version: &str, + route_env: &str, +) -> Vec { + let mut auth_buf = Vec::with_capacity(uid.len() + source.len() + token.len() + 16); + encode_field_string(1, uid, &mut auth_buf); + encode_field_string(2, source, &mut auth_buf); + encode_field_string(3, token, &mut auth_buf); + + let mut dev_buf = Vec::with_capacity(64); + if !app_version.is_empty() { + encode_field_string(1, app_version, &mut dev_buf); + } + if !operation_system.is_empty() { + encode_field_string(2, operation_system, &mut dev_buf); + } + encode_field_string(10, OPENHUMAN_INSTANCE_ID, &mut dev_buf); + if !bot_version.is_empty() { + encode_field_string(24, bot_version, &mut dev_buf); + } + + let mut req_buf = Vec::with_capacity(auth_buf.len() + dev_buf.len() + biz_id.len() + 16); + encode_field_string(1, biz_id, &mut req_buf); + encode_field_bytes(2, &auth_buf, &mut req_buf); + encode_field_bytes(3, &dev_buf, &mut req_buf); + if !route_env.is_empty() { + encode_field_string(5, route_env, &mut req_buf); + } + + encode_conn_msg( + cmd_type::REQUEST, + cmd::AUTH_BIND, + next_seq_no(), + msg_id, + module::CONN_ACCESS, + &req_buf, + ) +} + +pub fn encode_ping(msg_id: &str) -> Vec { + encode_conn_msg( + cmd_type::REQUEST, + cmd::PING, + next_seq_no(), + msg_id, + module::CONN_ACCESS, + &[], + ) +} + +/// Decoded `AuthBindRsp` body. +/// +/// message AuthBindRsp { +/// int32 code = 1; +/// string message = 2; +/// string connect_id = 3; +/// } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct AuthBindRsp { + pub code: i32, + pub message: String, + pub connect_id: String, +} + +/// Parse an `AuthBindRsp` from the biz payload (`ConnMsg.data`). +pub fn decode_auth_bind_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + Ok(AuthBindRsp { + code: get_varint(&fields, 1) as i32, + message: get_string(&fields, 2), + connect_id: get_string(&fields, 3), + }) +} + +pub fn encode_push_ack(original: &ConnFrame) -> Vec { + encode_conn_msg( + cmd_type::PUSH_ACK, + &original.cmd, + next_seq_no(), + &original.msg_id, + &original.module, + &[], + ) +} + +// ─── MsgBodyElement (TIM) encoding ───────────────────────────────── + +pub fn encode_msg_content(c: &MsgContent) -> Vec { + let mut buf = Vec::with_capacity(64); + if let Some(ref v) = c.text { + if !v.is_empty() { + encode_field_string(1, v, &mut buf); + } + } + if let Some(ref v) = c.uuid { + if !v.is_empty() { + encode_field_string(2, v, &mut buf); + } + } + if let Some(v) = c.image_format { + if v != 0 { + encode_field_varint(3, v as u64, &mut buf); + } + } + if let Some(ref v) = c.data { + if !v.is_empty() { + encode_field_string(4, v, &mut buf); + } + } + if let Some(ref v) = c.desc { + if !v.is_empty() { + encode_field_string(5, v, &mut buf); + } + } + if let Some(ref v) = c.ext { + if !v.is_empty() { + encode_field_string(6, v, &mut buf); + } + } + if let Some(ref v) = c.sound { + if !v.is_empty() { + encode_field_string(7, v, &mut buf); + } + } + for img in &c.image_info_array { + let mut ib = Vec::with_capacity(48); + if img.image_type != 0 { + encode_field_varint(1, img.image_type as u64, &mut ib); + } + if img.size != 0 { + encode_field_varint(2, img.size as u64, &mut ib); + } + if img.width != 0 { + encode_field_varint(3, img.width as u64, &mut ib); + } + if img.height != 0 { + encode_field_varint(4, img.height as u64, &mut ib); + } + if !img.url.is_empty() { + encode_field_string(5, &img.url, &mut ib); + } + encode_field_bytes(8, &ib, &mut buf); + } + if let Some(v) = c.index { + if v != 0 { + encode_field_varint(9, v as u64, &mut buf); + } + } + if let Some(ref v) = c.url { + if !v.is_empty() { + encode_field_string(10, v, &mut buf); + } + } + if let Some(v) = c.file_size { + if v != 0 { + encode_field_varint(11, v as u64, &mut buf); + } + } + if let Some(ref v) = c.file_name { + if !v.is_empty() { + encode_field_string(12, v, &mut buf); + } + } + buf +} + +fn decode_msg_content(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut c = MsgContent::default(); + for (n, v) in &fields { + match (*n, v) { + (1, FieldValue::Bytes(b)) => c.text = Some(String::from_utf8_lossy(b).into_owned()), + (2, FieldValue::Bytes(b)) => c.uuid = Some(String::from_utf8_lossy(b).into_owned()), + (3, FieldValue::Varint(x)) => c.image_format = Some(*x as u32), + (4, FieldValue::Bytes(b)) => c.data = Some(String::from_utf8_lossy(b).into_owned()), + (5, FieldValue::Bytes(b)) => c.desc = Some(String::from_utf8_lossy(b).into_owned()), + (6, FieldValue::Bytes(b)) => c.ext = Some(String::from_utf8_lossy(b).into_owned()), + (7, FieldValue::Bytes(b)) => c.sound = Some(String::from_utf8_lossy(b).into_owned()), + (8, FieldValue::Bytes(b)) => { + let ifields = parse_fields(b)?; + let mut info = ImageInfo { + image_type: get_varint(&ifields, 1) as u32, + size: get_varint(&ifields, 2) as u32, + width: get_varint(&ifields, 3) as u32, + height: get_varint(&ifields, 4) as u32, + url: get_string(&ifields, 5), + }; + if info.image_type != 0 || !info.url.is_empty() { + if info.image_type == 0 { + info.image_type = 1; + } + c.image_info_array.push(info); + } + } + (9, FieldValue::Varint(x)) => c.index = Some(*x as u32), + (10, FieldValue::Bytes(b)) => c.url = Some(String::from_utf8_lossy(b).into_owned()), + (11, FieldValue::Varint(x)) => c.file_size = Some(*x as u32), + (12, FieldValue::Bytes(b)) => { + c.file_name = Some(String::from_utf8_lossy(b).into_owned()) + } + _ => {} + } + } + Ok(c) +} + +pub fn encode_msg_body_element(el: &MsgBodyElement) -> Vec { + let mut buf = Vec::with_capacity(64); + if !el.msg_type.is_empty() { + encode_field_string(1, &el.msg_type, &mut buf); + } + let content = encode_msg_content(&el.msg_content); + if !content.is_empty() { + encode_field_bytes(2, &content, &mut buf); + } + buf +} + +fn decode_msg_body_element(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let content_bytes = get_bytes(&fields, 2); + let content = if content_bytes.is_empty() { + MsgContent::default() + } else { + decode_msg_content(&content_bytes)? + }; + Ok(MsgBodyElement { + msg_type: get_string(&fields, 1), + msg_content: content, + }) +} + +// ─── PushMsg envelope (cmd_type=Push inner wrapper) ──────────────── +// +// message PushMsg { +// string cmd = 1; +// string module = 2; +// string msg_id = 3; +// bytes data = 4; // ← actual biz body (e.g. InboundMessagePush) +// } +// +// The yuanbao gateway wraps every downstream push in this envelope +// *inside* `ConnMsg.data`. Mirrors plugin client.ts::onPush which +// decodes PushMsg before handing `data` to the business decoder. + +#[derive(Debug, Default)] +pub struct PushMsg { + pub cmd: String, + pub module: String, + pub msg_id: String, + pub data: Vec, +} + +pub fn decode_push_msg(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + Ok(PushMsg { + cmd: get_string(&fields, 1), + module: get_string(&fields, 2), + msg_id: get_string(&fields, 3), + data: get_bytes(&fields, 4), + }) +} + +// ─── InboundMessagePush decode ───────────────────────────────────── + +pub fn decode_inbound_push(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + + let mut msg_body = Vec::new(); + for b in get_repeated_bytes(&fields, 13) { + msg_body.push(decode_msg_body_element(&b)?); + } + + let mut recalls = Vec::new(); + for b in get_repeated_bytes(&fields, 17) { + let f = parse_fields(&b)?; + recalls.push(ImMsgSeq { + msg_seq: get_varint(&f, 1) as u32, + msg_id: get_string(&f, 2), + }); + } + + let log_ext_bytes = get_bytes(&fields, 20); + let trace_id = if log_ext_bytes.is_empty() { + String::new() + } else { + get_string(&parse_fields(&log_ext_bytes)?, 1) + }; + + Ok(InboundMessage { + callback_command: get_string(&fields, 1), + from_account: get_string(&fields, 2), + to_account: get_string(&fields, 3), + sender_nickname: get_string(&fields, 4), + group_id: get_string(&fields, 5), + group_code: get_string(&fields, 6), + group_name: get_string(&fields, 7), + msg_seq: get_varint(&fields, 8) as u32, + msg_random: get_varint(&fields, 9) as u32, + msg_time: get_varint(&fields, 10) as u32, + msg_key: get_string(&fields, 11), + msg_id: get_string(&fields, 12), + msg_body, + cloud_custom_data: get_string(&fields, 14), + event_time: get_varint(&fields, 15) as u32, + bot_owner_id: get_string(&fields, 16), + recall_msg_seq_list: recalls, + claw_msg_type: get_varint(&fields, 18) as u32, + private_from_group_code: get_string(&fields, 19), + trace_id, + }) +} + +// ─── InboundMessagePush JSON decode ──────────────────────────────── +// +// The yuanbao gateway sometimes (depending on backend account config / +// source channel) pushes `inbound_message` as a JSON string instead of +// protobuf. The shape matches `InboundMessagePush` field-for-field +// (snake_case), with `log_ext.trace_id` nested. Mirrors plugin +// gateway.ts::decodeFromRawDataJson (l. 238). + +pub fn decode_inbound_json(data: &[u8]) -> Result { + let v: serde_json::Value = serde_json::from_slice(data) + .map_err(|e| YuanbaoError::ProtoDecode(format!("json parse failed: {e}")))?; + + let obj = v + .as_object() + .ok_or_else(|| YuanbaoError::ProtoDecode("json root is not an object".into()))?; + + let get_str = |k: &str| -> String { + obj.get(k) + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string() + }; + let get_u32 = |k: &str| -> u32 { obj.get(k).and_then(|x| x.as_u64()).unwrap_or(0) as u32 }; + + let msg_body = obj + .get("msg_body") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().map(decode_msg_body_element_json).collect()) + .unwrap_or_default(); + + let recall_msg_seq_list = obj + .get("recall_msg_seq_list") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .map(|e| ImMsgSeq { + msg_seq: e.get("msg_seq").and_then(|v| v.as_u64()).unwrap_or(0) as u32, + msg_id: e + .get("msg_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }) + .collect() + }) + .unwrap_or_default(); + + let trace_id = obj + .get("log_ext") + .and_then(|v| v.get("trace_id")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + Ok(InboundMessage { + callback_command: get_str("callback_command"), + from_account: get_str("from_account"), + to_account: get_str("to_account"), + sender_nickname: get_str("sender_nickname"), + group_id: get_str("group_id"), + group_code: get_str("group_code"), + group_name: get_str("group_name"), + msg_seq: get_u32("msg_seq"), + msg_random: get_u32("msg_random"), + msg_time: get_u32("msg_time"), + msg_key: get_str("msg_key"), + msg_id: get_str("msg_id"), + msg_body, + cloud_custom_data: get_str("cloud_custom_data"), + event_time: get_u32("event_time"), + bot_owner_id: get_str("bot_owner_id"), + recall_msg_seq_list, + claw_msg_type: get_u32("claw_msg_type"), + private_from_group_code: get_str("private_from_group_code"), + trace_id, + }) +} + +fn decode_msg_body_element_json(v: &serde_json::Value) -> MsgBodyElement { + let msg_type = v + .get("msg_type") + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string(); + let mc = v.get("msg_content").and_then(|x| x.as_object()); + + let str_field = |k: &str| -> Option { + mc.and_then(|m| m.get(k)) + .and_then(|x| x.as_str()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + }; + let u32_field = |k: &str| -> Option { + mc.and_then(|m| m.get(k)) + .and_then(|x| x.as_u64()) + .map(|n| n as u32) + }; + + let image_info_array = mc + .and_then(|m| m.get("image_info_array")) + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .map(|e| ImageInfo { + image_type: e + .get("type") + .or_else(|| e.get("image_type")) + .and_then(|x| x.as_u64()) + .unwrap_or(0) as u32, + size: e.get("size").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + width: e.get("width").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + height: e.get("height").and_then(|x| x.as_u64()).unwrap_or(0) as u32, + url: e + .get("url") + .and_then(|x| x.as_str()) + .unwrap_or("") + .to_string(), + }) + .collect() + }) + .unwrap_or_default(); + + MsgBodyElement { + msg_type, + msg_content: MsgContent { + text: str_field("text"), + uuid: str_field("uuid"), + image_format: u32_field("image_format"), + data: str_field("data"), + desc: str_field("desc"), + ext: str_field("ext"), + sound: str_field("sound"), + image_info_array, + index: u32_field("index"), + url: str_field("url"), + file_size: u32_field("file_size"), + file_name: str_field("file_name"), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn conn_msg_roundtrip() { + let buf = encode_conn_msg( + cmd_type::REQUEST, + cmd::PING, + 42, + "mid-1", + module::CONN_ACCESS, + b"payload", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd_type, cmd_type::REQUEST); + assert_eq!(frame.cmd, cmd::PING); + assert_eq!(frame.seq_no, 42); + assert_eq!(frame.msg_id, "mid-1"); + assert_eq!(frame.module, module::CONN_ACCESS); + assert_eq!(frame.data, b"payload"); + } + + #[test] + fn auth_bind_smoke() { + let buf = encode_auth_bind( + "biz", + "uid", + "openclaw", + "tok", + "mid", + "1.0", + "linux", + "openhuman/0.1.0", + "", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, cmd::AUTH_BIND); + assert_eq!(frame.module, module::CONN_ACCESS); + assert!(!frame.data.is_empty()); + } + + #[test] + fn push_ack_mirrors_original() { + let original = ConnFrame { + cmd_type: cmd_type::PUSH, + cmd: "some_push".into(), + module: "yuanbao_openclaw_proxy".into(), + seq_no: 99, + msg_id: "mid-abc".into(), + need_ack: true, + status: 0, + data: vec![], + }; + let buf = encode_push_ack(&original); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd_type, cmd_type::PUSH_ACK); + assert_eq!(frame.cmd, original.cmd); + assert_eq!(frame.module, original.module); + assert_eq!(frame.msg_id, original.msg_id); + } + + #[test] + fn msg_body_element_roundtrip() { + let el = MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello 元宝".into()), + ..Default::default() + }, + }; + let buf = encode_msg_body_element(&el); + let got = decode_msg_body_element(&buf).unwrap(); + assert_eq!(got, el); + } + + #[test] + fn image_element_roundtrip() { + let el = MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + uuid: Some("abc123".into()), + image_format: Some(3), + image_info_array: vec![ImageInfo { + image_type: 1, + size: 1024, + width: 800, + height: 600, + url: "https://example/img.png".into(), + }], + ..Default::default() + }, + }; + let buf = encode_msg_body_element(&el); + let got = decode_msg_body_element(&buf).unwrap(); + assert_eq!(got, el); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto_biz.rs b/src/openhuman/channels/providers/yuanbao/proto_biz.rs new file mode 100644 index 0000000000..4b802606bb --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto_biz.rs @@ -0,0 +1,417 @@ +//! Business-layer protobuf codecs (biz payloads inside `ConnMsg.data`). +//! +//! Kept separate from `proto.rs` to stay under the 500-line ceiling and +//! to isolate the "openclaw biz protocol" surface from the lower-level +//! ConnMsg envelope. + +use super::errors::YuanbaoError; +use super::proto::{decode_conn_msg, encode_conn_msg, encode_msg_body_element}; +use super::proto_constants::*; +use super::types::*; +use super::wire::{ + encode_field_bytes as put_bytes_field, encode_field_string as put_string_field, + encode_field_varint as put_varint_field, get_bytes, get_repeated_bytes, get_string, get_varint, + next_seq_no, parse_fields, +}; + +// ─── SendC2CMessageReq ──────────────────────────────────────────── +// +// 1: msg_id (string) 5: msg_body (repeated MsgBodyElement) +// 2: to_account 6: group_code (DM-from-group) +// 3: from_account 7: msg_seq +// 4: msg_random 8: log_ext + +#[allow(clippy::too_many_arguments)] +fn encode_send_c2c_req( + msg_id: &str, + to_account: &str, + from_account: &str, + msg_random: u32, + msg_body: &[MsgBodyElement], + group_code: &str, + msg_seq: Option, + trace_id: &str, +) -> Vec { + let mut buf = Vec::with_capacity(128); + if !msg_id.is_empty() { + put_string_field(1, msg_id, &mut buf); + } + put_string_field(2, to_account, &mut buf); + if !from_account.is_empty() { + put_string_field(3, from_account, &mut buf); + } + if msg_random != 0 { + put_varint_field(4, msg_random as u64, &mut buf); + } + for el in msg_body { + let el_bytes = encode_msg_body_element(el); + put_bytes_field(5, &el_bytes, &mut buf); + } + if !group_code.is_empty() { + put_string_field(6, group_code, &mut buf); + } + if let Some(seq) = msg_seq { + put_varint_field(7, seq, &mut buf); + } + if !trace_id.is_empty() { + // log_ext is field 8 with a nested {1: trace_id} + let mut log = Vec::new(); + put_string_field(1, trace_id, &mut log); + put_bytes_field(8, &log, &mut buf); + } + buf +} + +/// Encode a full C2C send request as a `ConnMsg` ready to send over WS. +pub fn encode_send_c2c_message( + to_account: &str, + from_account: &str, + msg_body: &[MsgBodyElement], + msg_id: &str, + msg_random: u32, + group_code: &str, + trace_id: &str, +) -> Vec { + let body = encode_send_c2c_req( + msg_id, + to_account, + from_account, + msg_random, + msg_body, + group_code, + None, + trace_id, + ); + let req_id = if msg_id.is_empty() { + format!("c2c_{}", next_seq_no()) + } else { + msg_id.to_string() + }; + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_C2C_MESSAGE, + next_seq_no(), + &req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── SendGroupMessageReq ─────────────────────────────────────────── +// +// 1: msg_id 5: random (string) +// 2: group_code 6: msg_body (repeated) +// 3: from_account 7: ref_msg_id +// 4: to_account 8: msg_seq +// 9: log_ext + +#[allow(clippy::too_many_arguments)] +fn encode_send_group_req( + msg_id: &str, + group_code: &str, + from_account: &str, + to_account: &str, + random: &str, + msg_body: &[MsgBodyElement], + ref_msg_id: &str, + trace_id: &str, +) -> Vec { + let mut buf = Vec::with_capacity(128); + if !msg_id.is_empty() { + put_string_field(1, msg_id, &mut buf); + } + put_string_field(2, group_code, &mut buf); + if !from_account.is_empty() { + put_string_field(3, from_account, &mut buf); + } + if !to_account.is_empty() { + put_string_field(4, to_account, &mut buf); + } + if !random.is_empty() { + put_string_field(5, random, &mut buf); + } + for el in msg_body { + let el_bytes = encode_msg_body_element(el); + put_bytes_field(6, &el_bytes, &mut buf); + } + if !ref_msg_id.is_empty() { + put_string_field(7, ref_msg_id, &mut buf); + } + if !trace_id.is_empty() { + let mut log = Vec::new(); + put_string_field(1, trace_id, &mut log); + put_bytes_field(9, &log, &mut buf); + } + buf +} + +#[allow(clippy::too_many_arguments)] +pub fn encode_send_group_message( + group_code: &str, + from_account: &str, + msg_body: &[MsgBodyElement], + msg_id: &str, + to_account: &str, + random: &str, + ref_msg_id: &str, + trace_id: &str, +) -> Vec { + let body = encode_send_group_req( + msg_id, + group_code, + from_account, + to_account, + random, + msg_body, + ref_msg_id, + trace_id, + ); + let req_id = if msg_id.is_empty() { + format!("grp_{}", next_seq_no()) + } else { + msg_id.to_string() + }; + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_GROUP_MESSAGE, + next_seq_no(), + &req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── Heartbeats ──────────────────────────────────────────────────── + +pub fn encode_send_private_heartbeat( + req_id: &str, + from_account: &str, + to_account: &str, + heartbeat: u32, +) -> Vec { + let mut body = Vec::with_capacity(48); + put_string_field(1, from_account, &mut body); + put_string_field(2, to_account, &mut body); + put_varint_field(3, heartbeat as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_PRIVATE_HEARTBEAT, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +pub fn encode_send_group_heartbeat( + req_id: &str, + from_account: &str, + group_code: &str, + heartbeat: u32, + send_time_ms: u64, +) -> Vec { + let mut body = Vec::with_capacity(64); + put_string_field(1, from_account, &mut body); + put_string_field(2, "", &mut body); // to_account empty for group + put_string_field(3, group_code, &mut body); + put_varint_field(4, send_time_ms, &mut body); + put_varint_field(5, heartbeat as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::SEND_GROUP_HEARTBEAT, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +// ─── QueryGroupInfo ──────────────────────────────────────────────── + +pub fn encode_query_group_info(req_id: &str, group_code: &str) -> Vec { + let mut body = Vec::with_capacity(16 + group_code.len()); + put_string_field(1, group_code, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::QUERY_GROUP_INFO, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +pub fn decode_query_group_info_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut info = GroupInfo { + code: get_varint(&fields, 1) as i32, + message: get_string(&fields, 2), + ..Default::default() + }; + let gi_bytes = get_bytes(&fields, 3); + if !gi_bytes.is_empty() { + let gi = parse_fields(&gi_bytes)?; + info.group_name = get_string(&gi, 1); + info.owner_id = get_string(&gi, 2); + info.owner_nickname = get_string(&gi, 3); + info.member_count = get_varint(&gi, 4) as u32; + } + Ok(info) +} + +// ─── GetGroupMemberList ──────────────────────────────────────────── + +pub fn encode_get_group_member_list( + req_id: &str, + group_code: &str, + offset: u32, + limit: u32, +) -> Vec { + let mut body = Vec::with_capacity(32 + group_code.len()); + put_string_field(1, group_code, &mut body); + if offset != 0 { + put_varint_field(2, offset as u64, &mut body); + } + put_varint_field(3, limit as u64, &mut body); + encode_conn_msg( + cmd_type::REQUEST, + biz_cmd::GET_GROUP_MEMBER_LIST, + next_seq_no(), + req_id, + module::BIZ_PKG, + &body, + ) +} + +pub fn decode_get_group_member_list_rsp(data: &[u8]) -> Result { + let fields = parse_fields(data)?; + let mut members = Vec::new(); + for b in get_repeated_bytes(&fields, 3) { + let m = parse_fields(&b)?; + members.push(GroupMember { + user_id: get_string(&m, 1), + nickname: get_string(&m, 2), + role: get_varint(&m, 3) as u32, + join_time: get_varint(&m, 4) as u32, + name_card: get_string(&m, 5), + }); + } + Ok(GroupMemberListPage { + code: get_varint(&fields, 1) as i32, + message: get_string(&fields, 2), + members, + next_offset: get_varint(&fields, 4) as u32, + is_complete: get_varint(&fields, 5) != 0, + }) +} + +// ─── Generic biz response code helper ────────────────────────────── + +/// Decode the `code` and `message` from a biz response. +/// +/// All biz responses share the convention: field 1 = code, field 2 = message. +pub fn decode_biz_rsp_code(data: &[u8]) -> Result<(i32, String), YuanbaoError> { + let fields = parse_fields(data)?; + Ok((get_varint(&fields, 1) as i32, get_string(&fields, 2))) +} + +/// Decode a `ConnMsg` and return the typed biz response code + frame for +/// the request/response correlator. +pub fn decode_response_envelope(frame_bytes: &[u8]) -> Result { + decode_conn_msg(frame_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn text_body(s: &str) -> Vec { + vec![MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some(s.into()), + ..Default::default() + }, + }] + } + + #[test] + fn c2c_encode_smoke() { + let buf = encode_send_c2c_message("uid_alice", "uid_bot", &text_body("hi"), "", 0, "", ""); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.module, module::BIZ_PKG); + assert!(!frame.data.is_empty()); + } + + #[test] + fn group_encode_smoke() { + let buf = encode_send_group_message( + "group_42", + "uid_bot", + &text_body("hello"), + "", + "", + "rand", + "", + "", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_MESSAGE); + } + + #[test] + fn private_heartbeat_smoke() { + let buf = + encode_send_private_heartbeat("hb_1", "uid_bot", "uid_user", ws_heartbeat::RUNNING); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_PRIVATE_HEARTBEAT); + assert_eq!(frame.msg_id, "hb_1"); + } + + #[test] + fn query_group_info_roundtrip() { + let buf = encode_query_group_info("qgi_1", "group_99"); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::QUERY_GROUP_INFO); + assert_eq!(frame.msg_id, "qgi_1"); + + // Simulate response payload: code=0, message="ok", group_name="g", owner=… + let mut gi = Vec::new(); + put_string_field(1, "TestGroup", &mut gi); + put_string_field(2, "owner_uid", &mut gi); + put_string_field(3, "OwnerNick", &mut gi); + put_varint_field(4, 42, &mut gi); + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_bytes_field(3, &gi, &mut rsp); + + let parsed = decode_query_group_info_rsp(&rsp).unwrap(); + assert_eq!(parsed.code, 0); + assert_eq!(parsed.group_name, "TestGroup"); + assert_eq!(parsed.owner_id, "owner_uid"); + assert_eq!(parsed.member_count, 42); + } + + #[test] + fn group_member_list_decode() { + let mut m1 = Vec::new(); + put_string_field(1, "uid_a", &mut m1); + put_string_field(2, "Alice", &mut m1); + put_varint_field(3, 2, &mut m1); + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_bytes_field(3, &m1, &mut rsp); + put_varint_field(4, 100, &mut rsp); + put_varint_field(5, 1, &mut rsp); + + let page = decode_get_group_member_list_rsp(&rsp).unwrap(); + assert_eq!(page.members.len(), 1); + assert_eq!(page.members[0].user_id, "uid_a"); + assert_eq!(page.members[0].role, 2); + assert_eq!(page.next_offset, 100); + assert!(page.is_complete); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/proto_constants.rs b/src/openhuman/channels/providers/yuanbao/proto_constants.rs new file mode 100644 index 0000000000..bde9fdb58e --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/proto_constants.rs @@ -0,0 +1,90 @@ +//! Yuanbao WebSocket protocol constants. +//! +//! Values mirror `gateway/platforms/yuanbao_proto.py` in hermes-agent +//! (the authoritative reference implementation). + +/// `ConnMsg.Head.cmd_type` enum. +pub mod cmd_type { + /// Upstream request. + pub const REQUEST: u32 = 0; + /// Response to a previous upstream request. + pub const RESPONSE: u32 = 1; + /// Downstream push from server. + pub const PUSH: u32 = 2; + /// ACK reply to a downstream push. + pub const PUSH_ACK: u32 = 3; +} + +/// Built-in command words used in `ConnMsg.Head.cmd`. +pub mod cmd { + pub const AUTH_BIND: &str = "auth-bind"; + pub const PING: &str = "ping"; + pub const KICKOUT: &str = "kickout"; + pub const UPDATE_META: &str = "update-meta"; +} + +/// Module / service names used in `ConnMsg.Head.module`. +pub mod module { + pub const CONN_ACCESS: &str = "conn_access"; + /// Short name of the openclaw biz module (matches TS client). + pub const BIZ_PKG: &str = "yuanbao_openclaw_proxy"; +} + +/// Business command words (`ConnMsg.Head.cmd` when module=BIZ_PKG). +/// +/// Note: there is intentionally no constant for the inbound push cmd — +/// the yuanbao gateway uses several cmd words for inbound messages and +/// the routing is purely by `cmd_type=Push` (see `connection.rs` / +/// `mod.rs::dispatch_push`). +pub mod biz_cmd { + pub const SEND_C2C_MESSAGE: &str = "send_c2c_message"; + pub const SEND_GROUP_MESSAGE: &str = "send_group_message"; + pub const SEND_PRIVATE_HEARTBEAT: &str = "send_private_heartbeat"; + pub const SEND_GROUP_HEARTBEAT: &str = "send_group_heartbeat"; + pub const QUERY_GROUP_INFO: &str = "query_group_info"; + pub const GET_GROUP_MEMBER_LIST: &str = "get_group_member_list"; +} + +/// Reply Heartbeat status enum (`heartbeat` field of `Send*HeartbeatReq`). +pub mod ws_heartbeat { + /// Bot is currently producing output. + pub const RUNNING: u32 = 1; + /// Bot has finished its turn. + pub const FINISH: u32 = 2; +} + +/// TIM `msg_type` string constants for `MsgBodyElement.msg_type`. +pub mod tim { + pub const TEXT: &str = "TIMTextElem"; + pub const IMAGE: &str = "TIMImageElem"; + pub const FILE: &str = "TIMFileElem"; + pub const SOUND: &str = "TIMSoundElem"; + pub const VIDEO: &str = "TIMVideoFileElem"; + pub const FACE: &str = "TIMFaceElem"; + pub const CUSTOM: &str = "TIMCustomElem"; +} + +/// Fixed instance id reported in `AuthBindReq.DeviceInfo.instance_id` and +/// the `X-Instance-Id` HTTP header. Mirrors `OPENCLAW_ID = 16` used by +/// `yuanbao-openclaw-plugin` (`src/access/ws/conn-codec.ts`) — the server +/// keys some checks off this value, so it must match the value the sign +/// endpoint sees when the token is minted. +pub const OPENHUMAN_INSTANCE_ID: &str = "16"; + +/// Reconnect backoff schedule (seconds). Mirrors hermes-agent. +pub const RECONNECT_DELAYS: &[u64] = &[1, 2, 5, 10, 30, 60]; +pub const MAX_RECONNECT_ATTEMPTS: u32 = 100; + +/// Ping interval (seconds). Server-driven; this is the upper bound. +pub const PING_INTERVAL_SECS: u64 = 30; +/// Number of consecutive ping timeouts before the connection is dropped. +pub const HEARTBEAT_TIMEOUT_THRESHOLD: u32 = 2; +/// Per-request biz timeout (seconds). +pub const DEFAULT_SEND_TIMEOUT_SECS: u64 = 30; +/// Auth-bind handshake timeout (seconds). +pub const AUTH_TIMEOUT_SECS: u64 = 15; + +/// Inbound dedup TTL — drop a `msg_id` we've already seen within this window. +pub const DEDUP_TTL_SECS: u64 = 300; +/// LRU-style cap on the dedup table. +pub const DEDUP_CAPACITY: usize = 10_000; diff --git a/src/openhuman/channels/providers/yuanbao/sign.rs b/src/openhuman/channels/providers/yuanbao/sign.rs new file mode 100644 index 0000000000..e168b8f95d --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/sign.rs @@ -0,0 +1,446 @@ +//! Token sign manager — talks to `/api/v5/robotLogic/sign-token` to +//! exchange `(app_key, app_secret)` for a short-lived WS token + bot_id. +//! +//! Mirrors hermes-agent `SignManager` (yuanbao.py 641-881). Implements: +//! - per-app_key tokio `Mutex` to coalesce concurrent refresh attempts +//! - 60-second early-refresh margin to avoid using a token that's +//! about to expire mid-handshake +//! - retry on `code=10099` up to 3 times +//! +//! Signature scheme (TS plugin compatible): +//! plain = nonce + timestamp + app_key + app_secret +//! signature = HMAC-SHA256(key = app_secret, msg = plain) as lower-hex + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use chrono::FixedOffset; +use hmac::{Hmac, Mac}; +use sha2::Sha256; +use tokio::sync::Mutex; +use tracing::{info, warn}; + +use super::errors::YuanbaoError; + +const SIGN_PATH: &str = "/api/v5/robotLogic/sign-token"; +const RETRYABLE_CODE: i64 = 10099; +const MAX_RETRIES: usize = 3; +const RETRY_DELAY_MS: u64 = 1_000; +/// Treat as expiring this many seconds before actual expiry so a fresh +/// token is fetched before the running one dies mid-request. +const CACHE_REFRESH_MARGIN_SECS: u64 = 60; +const HTTP_TIMEOUT_SECS: u64 = 10; +const DEFAULT_DURATION_SECS: u64 = 3600; + +/// One cached token entry. +#[derive(Debug, Clone)] +pub struct TokenEntry { + pub token: String, + pub bot_id: String, + pub product: String, + pub source: String, + /// Seconds-since-epoch when this token expires server-side. + pub expire_ts: u64, +} + +impl TokenEntry { + pub fn is_valid(&self) -> bool { + let now = unix_now(); + self.expire_ts > now + CACHE_REFRESH_MARGIN_SECS + } + + pub fn seconds_remaining(&self) -> i64 { + self.expire_ts as i64 - unix_now() as i64 + } +} + +type HmacSha256 = Hmac; + +/// Compute the `signature` field for the sign-token API. +pub fn compute_signature(nonce: &str, timestamp: &str, app_key: &str, app_secret: &str) -> String { + let plain = format!("{nonce}{timestamp}{app_key}{app_secret}"); + let mut mac = + HmacSha256::new_from_slice(app_secret.as_bytes()).expect("HMAC accepts any key length"); + mac.update(plain.as_bytes()); + hex::encode(mac.finalize().into_bytes()) +} + +/// Build Beijing-time ISO-8601 timestamp without milliseconds. +/// Format: `2006-01-02T15:04:05+08:00`. +pub fn build_timestamp() -> String { + let bj_offset = FixedOffset::east_opt(8 * 3600).expect("valid offset"); + let now = chrono::Utc::now().with_timezone(&bj_offset); + now.format("%Y-%m-%dT%H:%M:%S+08:00").to_string() +} + +/// Generate a 32-char hex nonce. +pub fn generate_nonce() -> String { + let mut bytes = [0u8; 16]; + for b in &mut bytes { + *b = rand::random::(); + } + hex::encode(bytes) +} + +/// Process-wide token manager. One instance is built per `YuanbaoChannel` +/// and shared with the connection layer; the per-app_key Mutex makes it +/// safe to have multiple connections sharing this manager. +pub struct SignManager { + http: reqwest::Client, + /// Per-app_key refresh mutexes — coalesce concurrent refresh attempts. + locks: Mutex>>>, + /// Token cache keyed by app_key. + cache: Mutex>, +} + +impl SignManager { + pub fn new(http: reqwest::Client) -> Arc { + Arc::new(Self { + http, + locks: Mutex::new(HashMap::new()), + cache: Mutex::new(HashMap::new()), + }) + } + + /// Look up a cached token without touching the network. + pub async fn cached(&self, app_key: &str) -> Option { + let cache = self.cache.lock().await; + cache.get(app_key).cloned().filter(|e| e.is_valid()) + } + + /// Get a valid token, fetching one if the cache is empty or stale. + pub async fn get_token( + &self, + app_key: &str, + app_secret: &str, + api_domain: &str, + route_env: &str, + ) -> Result { + if let Some(entry) = self.cached(app_key).await { + info!( + "[yuanbao/sign] using cached token ({}s remaining)", + entry.seconds_remaining() + ); + return Ok(entry); + } + self.refresh(app_key, app_secret, api_domain, route_env) + .await + } + + /// Force-refresh: drop the cache entry and re-fetch. + pub async fn force_refresh( + &self, + app_key: &str, + app_secret: &str, + api_domain: &str, + route_env: &str, + ) -> Result { + { + let mut cache = self.cache.lock().await; + cache.remove(app_key); + } + warn!( + "[yuanbao/sign] force-refresh app_key=****{}", + suffix(app_key) + ); + self.refresh(app_key, app_secret, api_domain, route_env) + .await + } + + async fn refresh( + &self, + app_key: &str, + app_secret: &str, + api_domain: &str, + route_env: &str, + ) -> Result { + let lock = self.get_refresh_lock(app_key).await; + let _g = lock.lock().await; + + // Double-checked locking: another task may have refreshed while we waited. + if let Some(entry) = self.cached(app_key).await { + return Ok(entry); + } + + let entry = self + .fetch_with_retry(app_key, app_secret, api_domain, route_env) + .await?; + let mut cache = self.cache.lock().await; + cache.insert(app_key.to_string(), entry.clone()); + Ok(entry) + } + + async fn get_refresh_lock(&self, app_key: &str) -> Arc> { + let mut locks = self.locks.lock().await; + locks + .entry(app_key.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + + async fn fetch_with_retry( + &self, + app_key: &str, + app_secret: &str, + api_domain: &str, + route_env: &str, + ) -> Result { + let url = format!("{}{}", api_domain.trim_end_matches('/'), SIGN_PATH); + let mut last_err: Option = None; + + for attempt in 0..=MAX_RETRIES { + let nonce = generate_nonce(); + let timestamp = build_timestamp(); + let signature = compute_signature(&nonce, ×tamp, app_key, app_secret); + let payload = serde_json::json!({ + "app_key": app_key, + "nonce": nonce, + "signature": signature, + "timestamp": timestamp, + }); + + let mut req = self + .http + .post(&url) + .timeout(Duration::from_secs(HTTP_TIMEOUT_SECS)) + .header("Content-Type", "application/json") + .header("X-AppVersion", "openhuman/0.1.0") + .header("X-OperationSystem", "linux") + .header( + "X-Instance-Id", + super::proto_constants::OPENHUMAN_INSTANCE_ID, + ) + .header("X-Bot-Version", "openhuman/0.1.0"); + if !route_env.is_empty() { + req = req.header("X-Route-Env", route_env); + } + + info!( + "[yuanbao/sign] POST {}{}", + url, + if attempt > 0 { + format!(" (retry {attempt}/{MAX_RETRIES})") + } else { + String::new() + } + ); + + let resp = match req.json(&payload).send().await { + Ok(r) => r, + Err(e) => { + last_err = Some(YuanbaoError::Connection(format!("sign-token: {e}"))); + if attempt < MAX_RETRIES { + tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await; + continue; + } + break; + } + }; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(YuanbaoError::AuthFailed(format!( + "sign-token HTTP {status}: {}", + &body.chars().take(200).collect::() + ))); + } + + let json: serde_json::Value = resp + .json() + .await + .map_err(|e| YuanbaoError::AuthFailed(format!("sign-token body: {e}")))?; + + let code = json.get("code").and_then(|c| c.as_i64()).unwrap_or(0); + if code == 0 { + let data = match json.get("data") { + Some(d) if d.is_object() => d, + _ => { + return Err(YuanbaoError::AuthFailed( + "sign-token response missing 'data'".into(), + )) + } + }; + let duration = data + .get("duration") + .and_then(|v| v.as_u64()) + .unwrap_or(DEFAULT_DURATION_SECS); + let entry = TokenEntry { + token: data + .get("token") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + bot_id: data + .get("bot_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + product: data + .get("product") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + source: data + .get("source") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + expire_ts: unix_now() + duration, + }; + info!( + "[yuanbao/sign] success: bot_id={} duration={}s", + entry.bot_id, duration + ); + return Ok(entry); + } + + if code == RETRYABLE_CODE && attempt < MAX_RETRIES { + warn!( + "[yuanbao/sign] retryable code={code}, retrying in {RETRY_DELAY_MS}ms (attempt {})", + attempt + 1 + ); + tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await; + continue; + } + + let msg = json + .get("msg") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + return Err(YuanbaoError::AuthFailed(format!( + "sign-token code={code} msg={msg}" + ))); + } + + Err(last_err.unwrap_or(YuanbaoError::AuthFailed( + "sign-token max retries exceeded".into(), + ))) + } + + /// Drop all per-app_key locks. Called on channel shutdown to avoid + /// leaking entries across reconnects within the same process. + pub async fn clear_locks(&self) { + let mut locks = self.locks.lock().await; + locks.clear(); + } +} + +fn unix_now() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +fn suffix(s: &str) -> &str { + if s.len() <= 4 { + s + } else { + &s[s.len() - 4..] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn signature_matches_python_reference() { + // Reproducible vector — hand-computed: + // plain = "n123" + "2026-05-19T22:00:00+08:00" + "app_k" + "secret" + // sig = HMAC-SHA256(key="secret", msg=plain) as lower hex + let sig = compute_signature("n123", "2026-05-19T22:00:00+08:00", "app_k", "secret"); + // We don't pin the exact bytes (would require running Python to confirm) — + // instead verify the contract: same inputs → same output, 64-char hex. + assert_eq!(sig.len(), 64); + assert!(sig.chars().all(|c| c.is_ascii_hexdigit())); + let sig2 = compute_signature("n123", "2026-05-19T22:00:00+08:00", "app_k", "secret"); + assert_eq!(sig, sig2); + } + + #[test] + fn signature_varies_with_inputs() { + let s1 = compute_signature("n1", "t", "ak", "sk"); + let s2 = compute_signature("n2", "t", "ak", "sk"); + let s3 = compute_signature("n1", "t2", "ak", "sk"); + let s4 = compute_signature("n1", "t", "ak2", "sk"); + let s5 = compute_signature("n1", "t", "ak", "sk2"); + let all = [&s1, &s2, &s3, &s4, &s5]; + for (i, a) in all.iter().enumerate() { + for (j, b) in all.iter().enumerate() { + if i != j { + assert_ne!(a, b, "inputs {i} vs {j} should differ"); + } + } + } + } + + #[test] + fn nonce_is_32_char_hex() { + let n = generate_nonce(); + assert_eq!(n.len(), 32); + assert!(n.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn timestamp_matches_beijing_format() { + let t = build_timestamp(); + // 2006-01-02T15:04:05+08:00 → length 25 + assert_eq!(t.len(), 25); + assert!(t.ends_with("+08:00")); + assert_eq!(&t[4..5], "-"); + assert_eq!(&t[7..8], "-"); + assert_eq!(&t[10..11], "T"); + assert_eq!(&t[13..14], ":"); + } + + #[test] + fn token_entry_is_valid_only_with_margin() { + let mut e = TokenEntry { + token: "t".into(), + bot_id: "b".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 120, + }; + assert!(e.is_valid()); + e.expire_ts = unix_now() + 30; // less than 60s margin + assert!(!e.is_valid()); + e.expire_ts = unix_now().saturating_sub(10); + assert!(!e.is_valid()); + } + + #[tokio::test] + async fn cache_returns_entry_when_valid() { + let mgr = SignManager::new(reqwest::Client::new()); + let entry = TokenEntry { + token: "tok".into(), + bot_id: "bot".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 600, + }; + mgr.cache.lock().await.insert("ak".into(), entry.clone()); + let got = mgr.cached("ak").await.expect("cache hit"); + assert_eq!(got.token, "tok"); + } + + #[tokio::test] + async fn cache_drops_expired_entry() { + let mgr = SignManager::new(reqwest::Client::new()); + mgr.cache.lock().await.insert( + "ak".into(), + TokenEntry { + token: "tok".into(), + bot_id: "bot".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 10, // under margin + }, + ); + assert!(mgr.cached("ak").await.is_none()); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/splitter.rs b/src/openhuman/channels/providers/yuanbao/splitter.rs new file mode 100644 index 0000000000..88784a7814 --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/splitter.rs @@ -0,0 +1,209 @@ +//! Fence-aware Markdown splitter. +//! +//! When a long AI response is split into N chunks for the Yuanbao +//! `max_message_length` cap, we must not break inside: +//! - a fenced code block (``` … ``` or ~~~ … ~~~) +//! - a Markdown table row (lines starting with `|`) +//! - a list-continuation block +//! +//! Strategy: walk the input by line, tracking fence/table state, and +//! emit a chunk every time adding the next line would push the buffer +//! past the cap **and** the cap boundary is safe (not inside a fence, +//! not in the middle of a table). If a single line is itself longer +//! than the cap, hard-split at a char boundary. + +/// Split `text` into chunks no larger than `cap_bytes` (utf-8 byte count), +/// preserving fenced code blocks and table rows where possible. +pub fn split_markdown(text: &str, cap_bytes: usize) -> Vec { + if text.len() <= cap_bytes { + return vec![text.to_string()]; + } + let cap = cap_bytes.max(1); + // Reserve a small headroom so the trailing newline / final char fits + // when we flush. For very small caps fall back to no margin so callers + // testing tight bounds (cap=20) still get chunks under the cap. + let safe_cap = if cap >= 32 { + cap.saturating_sub(16) + } else { + cap + }; + + let mut chunks: Vec = Vec::new(); + let mut buf = String::with_capacity(cap); + let mut in_fence = false; + let mut fence_marker: Option = None; + + for line in text.split_inclusive('\n') { + let trimmed = line.trim_end_matches('\n'); + let starts_fence = is_fence(trimmed); + + // If this single line is wider than the cap, we must hard-split it. + if line.len() > cap { + flush(&mut chunks, &mut buf); + for piece in hard_split(line, cap) { + chunks.push(piece); + } + continue; + } + + let candidate_len = buf.len() + line.len(); + if candidate_len > safe_cap && !buf.is_empty() && safe_to_break(in_fence) { + flush(&mut chunks, &mut buf); + } + buf.push_str(line); + + if let Some(marker) = starts_fence { + if let Some(open) = &fence_marker { + if marker == *open { + in_fence = false; + fence_marker = None; + } + } else { + in_fence = true; + fence_marker = Some(marker); + } + } + } + flush(&mut chunks, &mut buf); + + // Drop empty trailing chunks (can happen if input ends on newline). + chunks.retain(|c| !c.trim().is_empty()); + chunks +} + +fn flush(chunks: &mut Vec, buf: &mut String) { + if !buf.is_empty() { + chunks.push(buf.trim_end().to_string()); + buf.clear(); + } +} + +fn safe_to_break(in_fence: bool) -> bool { + !in_fence +} + +/// If `line` opens or closes a fenced code block, return the marker text +/// (e.g. "```" or "~~~"). A line that contains a fence in the middle is +/// NOT a fence marker; only lines that *start* with three or more +/// backticks/tildes count. +fn is_fence(line: &str) -> Option { + let trimmed = line.trim_start(); + if let Some(rest) = trimmed.strip_prefix("```") { + // Allow optional language tag after the fence. + let _ = rest; + return Some("```".into()); + } + if let Some(rest) = trimmed.strip_prefix("~~~") { + let _ = rest; + return Some("~~~".into()); + } + None +} + +/// Last-resort splitter for a line that's wider than the cap. +fn hard_split(line: &str, cap: usize) -> Vec { + let mut out = Vec::new(); + let mut remaining = line; + while !remaining.is_empty() { + if remaining.len() <= cap { + out.push(remaining.to_string()); + break; + } + let mut idx = cap; + while idx > 0 && !remaining.is_char_boundary(idx) { + idx -= 1; + } + if idx == 0 { + // pathological — emit one char at a time + let take = remaining + .char_indices() + .nth(1) + .map(|(i, _)| i) + .unwrap_or(remaining.len()); + let (chunk, rest) = remaining.split_at(take); + out.push(chunk.to_string()); + remaining = rest; + } else { + let (chunk, rest) = remaining.split_at(idx); + out.push(chunk.to_string()); + remaining = rest; + } + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn short_input_returns_one_chunk() { + let r = split_markdown("hello", 100); + assert_eq!(r, vec!["hello"]); + } + + #[test] + fn splits_on_newlines_respecting_cap() { + let input = "a\n".repeat(100); + let r = split_markdown(&input, 20); + assert!(r.len() > 1); + for c in &r { + assert!(c.len() <= 20, "chunk too long: {c:?}"); + } + } + + #[test] + fn preserves_fenced_code_block() { + let input = "intro line\n\ + ```rust\n\ + fn long_function_a() -> u32 { 42 }\n\ + fn long_function_b() -> u32 { 43 }\n\ + fn long_function_c() -> u32 { 44 }\n\ + ```\n\ + trailing text"; + let chunks = split_markdown(input, 80); + // Find the chunk(s) containing the fence — they must not split mid-fence. + let mut open = 0; + for c in &chunks { + for line in c.lines() { + if is_fence(line).is_some() { + open += 1; + } + } + } + // The fence must appear as balanced pairs. + assert_eq!(open % 2, 0, "unbalanced fences after split: {chunks:#?}"); + } + + #[test] + fn hard_split_very_long_line() { + let line = "x".repeat(500); + let r = split_markdown(&line, 100); + for c in &r { + assert!(c.len() <= 100, "chunk too long: {}", c.len()); + } + assert_eq!(r.join("").len(), 500); + } + + #[test] + fn unicode_safe_hard_split() { + let line = "中".repeat(200); // each char is 3 bytes → 600 total + let r = split_markdown(&line, 50); + for c in &r { + assert!(c.len() <= 50, "chunk too long: {}", c.len()); + // verify it's valid utf-8 by reading it + for ch in c.chars() { + assert!(ch == '中'); + } + } + } + + #[test] + fn is_fence_detects_backticks() { + assert_eq!(is_fence("```").as_deref(), Some("```")); + assert_eq!(is_fence("```rust").as_deref(), Some("```")); + assert_eq!(is_fence("~~~").as_deref(), Some("~~~")); + assert_eq!(is_fence("text").as_deref(), None); + assert_eq!(is_fence("``").as_deref(), None); + } +} diff --git a/src/openhuman/channels/providers/yuanbao/types.rs b/src/openhuman/channels/providers/yuanbao/types.rs new file mode 100644 index 0000000000..c661d078ae --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/types.rs @@ -0,0 +1,265 @@ +//! Shared domain types for the Yuanbao channel. +//! +//! Field naming follows the upstream Yuanbao protocol (`from_account`, +//! `group_code`, `msg_id`, etc.) so that the protobuf decoders, the +//! inbound pipeline, and the outbound encoders can all share the +//! same `InboundMessage` / `MsgBodyElement` shapes without re-mapping. +//! +//! Source of truth: `gateway/platforms/yuanbao_proto.py` in +//! hermes-agent (lines 415-705). + +use serde::{Deserialize, Serialize}; + +/// Decoded ConnMsg envelope (head + payload). +#[derive(Debug, Clone)] +pub struct ConnFrame { + /// CmdType (`CMD_TYPE`): Request=0, Response=1, Push=2, PushAck=3. + pub cmd_type: u32, + /// Command word, e.g. `"auth-bind"`, `"ping"`, `"send_c2c_message"`. + pub cmd: String, + /// Module / service name, e.g. `"conn_access"` or `"yuanbao_openclaw_proxy"`. + pub module: String, + /// Per-message sequence number. + pub seq_no: u32, + /// Application-level message id. + pub msg_id: String, + /// Whether the server expects an ACK. + pub need_ack: bool, + /// Status code (head.status, field 10). + pub status: u32, + /// Biz payload bytes (ConnMsg.data, field 2). + pub data: Vec, +} + +/// One element of the TIM-style `msg_body` array (e.g. text, image, file). +#[derive(Debug, Clone, Default, PartialEq)] +pub struct MsgBodyElement { + /// `"TIMTextElem"`, `"TIMImageElem"`, `"TIMFileElem"`, `"TIMSoundElem"`, … + pub msg_type: String, + pub msg_content: MsgContent, +} + +/// Generic union of all TIM `msg_content` shapes (text/image/file/sound). +/// +/// Only the fields relevant to the active `msg_type` are populated; the +/// rest stay at their `Default`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct MsgContent { + /// Field 1 — text payload. + pub text: Option, + /// Field 2 — file uuid (MD5 for images/files). + pub uuid: Option, + /// Field 3 — image format code (1=JPEG, 2=GIF, 3=PNG, 4=BMP, 255=WEBP). + pub image_format: Option, + /// Field 4 — raw inline data (rarely used; usually `url` is set instead). + pub data: Option, + /// Field 5 — element description. + pub desc: Option, + /// Field 6 — extension JSON / blob. + pub ext: Option, + /// Field 7 — voice payload identifier. + pub sound: Option, + /// Field 8 — repeated `ImageInfo` for the image element. + pub image_info_array: Vec, + /// Field 9 — element index within a multi-image message. + pub index: Option, + /// Field 10 — resource URL. + pub url: Option, + /// Field 11 — file size in bytes. + pub file_size: Option, + /// Field 12 — file name. + pub file_name: Option, +} + +/// Per-resolution image variant. `type` is 1=original, 2=large, 3=thumb. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct ImageInfo { + pub image_type: u32, + pub size: u32, + pub width: u32, + pub height: u32, + pub url: String, +} + +/// A single recall entry in `recall_msg_seq_list` (InboundMessagePush field 17). +#[derive(Debug, Clone, Default, PartialEq)] +pub struct ImMsgSeq { + pub msg_seq: u32, + pub msg_id: String, +} + +/// A decoded `InboundMessagePush` biz payload — what the yuanbao gateway +/// pushes down to us for every incoming message. +#[derive(Debug, Clone, Default)] +pub struct InboundMessage { + pub callback_command: String, + pub from_account: String, + pub to_account: String, + pub sender_nickname: String, + /// Empty string for DMs, group ID for group messages. + pub group_id: String, + /// Empty string for DMs, group code (canonical group ref) for group messages. + pub group_code: String, + pub group_name: String, + pub msg_seq: u32, + pub msg_random: u32, + /// Server-side message timestamp (seconds since epoch). + pub msg_time: u32, + pub msg_key: String, + /// Stable application-level message ID. + pub msg_id: String, + pub msg_body: Vec, + pub cloud_custom_data: String, + pub event_time: u32, + pub bot_owner_id: String, + pub recall_msg_seq_list: Vec, + pub claw_msg_type: u32, + pub private_from_group_code: String, + pub trace_id: String, +} + +impl InboundMessage { + /// Whether this is a group message. + pub fn is_group(&self) -> bool { + !self.group_code.is_empty() + } + + /// Whether the message looks like a recall notification. + pub fn is_recall(&self) -> bool { + !self.recall_msg_seq_list.is_empty() + } + + /// Routing key — group_code for groups, sender uid for DMs. + pub fn chat_id(&self) -> &str { + if self.is_group() { + &self.group_code + } else { + &self.from_account + } + } + + /// Concatenated text content (joins all `TIMTextElem`s). + pub fn extract_text(&self) -> String { + let mut out = String::new(); + for el in &self.msg_body { + if el.msg_type == "TIMTextElem" { + if let Some(ref t) = el.msg_content.text { + if !out.is_empty() { + out.push('\n'); + } + out.push_str(t); + } + } + } + out + } + + /// All image URLs in the message (from `TIMImageElem` elements). + pub fn extract_image_urls(&self) -> Vec { + let mut urls = Vec::new(); + for el in &self.msg_body { + if el.msg_type == "TIMImageElem" { + for info in &el.msg_content.image_info_array { + if !info.url.is_empty() { + urls.push(info.url.clone()); + } + } + if let Some(ref url) = el.msg_content.url { + if !url.is_empty() && !urls.contains(url) { + urls.push(url.clone()); + } + } + } + } + urls + } +} + +/// High-level classification produced by the inbound pipeline. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum MessageKind { + #[default] + Text, + Image, + File, + Voice, + Mixed, + /// Recall notification — handled by `RecallGuard`, never dispatched. + Recall, +} + +/// Where the message came from — used by the outbound side to address replies. +#[derive(Debug, Clone, Default)] +pub struct Source { + pub from_account: String, + pub sender_nickname: String, + pub group_code: String, + /// `true` for group chats, `false` for DMs. + pub is_group: bool, +} + +impl Source { + /// Stable string for `ChannelMessage.sender` / `reply_target` — + /// `g:` for groups, raw uid for DMs. This format also + /// round-trips through `parse_recipient` in `outbound.rs`. + pub fn reply_target(&self) -> String { + if self.is_group { + format!("g:{}", self.group_code) + } else { + self.from_account.clone() + } + } +} + +/// Group metadata returned by `QueryGroupInfo`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct GroupInfo { + pub code: i32, + pub message: String, + pub group_name: String, + pub owner_id: String, + pub owner_nickname: String, + pub member_count: u32, +} + +/// One member returned by `GetGroupMemberList`. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct GroupMember { + pub user_id: String, + pub nickname: String, + /// 0=member, 1=admin, 2=owner. + pub role: u32, + pub join_time: u32, + pub name_card: String, +} + +/// Paginated result of `GetGroupMemberList`. +#[derive(Debug, Clone, Default)] +pub struct GroupMemberListPage { + pub code: i32, + pub message: String, + pub members: Vec, + pub next_offset: u32, + pub is_complete: bool, +} + +/// Cached account info — populated after `auth-bind` succeeds. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Account { + /// Bot user-id (used as `from_account` in outbound messages). + pub uid: String, + /// Display name (best-effort; may be empty until first inbound message). + pub nickname: String, + /// Server-assigned connection id (`AuthBindRsp.connect_id`, field 3). + pub connect_id: String, +} + +/// Connection state machine (matches task list spec). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + Disconnected, + Connecting, + Authenticating, + Connected, + Reconnecting, +} diff --git a/src/openhuman/channels/providers/yuanbao/wire.rs b/src/openhuman/channels/providers/yuanbao/wire.rs new file mode 100644 index 0000000000..d7137c20cd --- /dev/null +++ b/src/openhuman/channels/providers/yuanbao/wire.rs @@ -0,0 +1,232 @@ +//! Hand-rolled protobuf wire-format primitives. +//! +//! Only varints, length-delimited bytes, and the two fixed-width forms +//! are supported — that's everything the yuanbao protocol uses. Kept +//! separate from `proto.rs` so the latter stays under 500 lines and +//! reads as a "schema" file. + +use std::sync::atomic::{AtomicU32, Ordering}; + +use super::errors::YuanbaoError; + +/// Global per-process sequence number for ConnMsg head.seq_no. +static SEQ: AtomicU32 = AtomicU32::new(1); + +pub fn next_seq_no() -> u32 { + SEQ.fetch_add(1, Ordering::Relaxed) +} + +pub const WT_VARINT: u8 = 0; +pub const WT_LEN: u8 = 2; + +// ─── Varint ───────────────────────────────────────────────────────── + +pub fn encode_varint(mut value: u64, buf: &mut Vec) { + loop { + let mut byte = (value & 0x7F) as u8; + value >>= 7; + if value != 0 { + byte |= 0x80; + } + buf.push(byte); + if value == 0 { + break; + } + } +} + +pub fn decode_varint(data: &[u8], pos: usize) -> Result<(u64, usize), YuanbaoError> { + let mut value: u64 = 0; + let mut shift: u32 = 0; + let mut i = pos; + loop { + if i >= data.len() { + return Err(YuanbaoError::ProtoDecode("truncated varint".into())); + } + let byte = data[i]; + value |= ((byte & 0x7F) as u64) << shift; + i += 1; + if byte & 0x80 == 0 { + return Ok((value, i - pos)); + } + shift += 7; + if shift >= 64 { + return Err(YuanbaoError::ProtoDecode("varint too long".into())); + } + } +} + +// ─── Field encoders ──────────────────────────────────────────────── + +pub fn encode_field_varint(field: u32, value: u64, buf: &mut Vec) { + encode_varint(((field as u64) << 3) | WT_VARINT as u64, buf); + encode_varint(value, buf); +} + +pub fn encode_field_bytes(field: u32, data: &[u8], buf: &mut Vec) { + encode_varint(((field as u64) << 3) | WT_LEN as u64, buf); + encode_varint(data.len() as u64, buf); + buf.extend_from_slice(data); +} + +pub fn encode_field_string(field: u32, s: &str, buf: &mut Vec) { + encode_field_bytes(field, s.as_bytes(), buf); +} + +// ─── Field parsing ────────────────────────────────────────────────── + +#[derive(Debug)] +pub enum FieldValue { + Varint(u64), + Bytes(Vec), + Fixed32(u32), + Fixed64(u64), +} + +pub fn parse_fields(data: &[u8]) -> Result, YuanbaoError> { + let mut out = Vec::new(); + let mut pos = 0; + while pos < data.len() { + let (tag, n) = decode_varint(data, pos)?; + pos += n; + let field = (tag >> 3) as u32; + let wire = (tag & 0x07) as u8; + match wire { + WT_VARINT => { + let (v, n) = decode_varint(data, pos)?; + pos += n; + out.push((field, FieldValue::Varint(v))); + } + WT_LEN => { + let (len, n) = decode_varint(data, pos)?; + pos += n; + let end = pos + len as usize; + if end > data.len() { + return Err(YuanbaoError::ProtoDecode(format!( + "truncated len field {field}: need {len} have {}", + data.len() - pos + ))); + } + out.push((field, FieldValue::Bytes(data[pos..end].to_vec()))); + pos = end; + } + 1 => { + if pos + 8 > data.len() { + return Err(YuanbaoError::ProtoDecode("truncated fixed64".into())); + } + let v = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap()); + pos += 8; + out.push((field, FieldValue::Fixed64(v))); + } + 5 => { + if pos + 4 > data.len() { + return Err(YuanbaoError::ProtoDecode("truncated fixed32".into())); + } + let v = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()); + pos += 4; + out.push((field, FieldValue::Fixed32(v))); + } + other => { + return Err(YuanbaoError::ProtoDecode(format!( + "unsupported wire type {other} at field {field}" + ))); + } + } + } + Ok(out) +} + +pub fn get_string(fields: &[(u32, FieldValue)], num: u32) -> String { + for (n, v) in fields { + if *n == num { + if let FieldValue::Bytes(b) = v { + return String::from_utf8_lossy(b).into_owned(); + } + } + } + String::new() +} + +pub fn get_varint(fields: &[(u32, FieldValue)], num: u32) -> u64 { + for (n, v) in fields { + if *n == num { + if let FieldValue::Varint(x) = v { + return *x; + } + } + } + 0 +} + +pub fn get_bytes(fields: &[(u32, FieldValue)], num: u32) -> Vec { + for (n, v) in fields { + if *n == num { + if let FieldValue::Bytes(b) = v { + return b.clone(); + } + } + } + Vec::new() +} + +pub fn get_repeated_bytes(fields: &[(u32, FieldValue)], num: u32) -> Vec> { + fields + .iter() + .filter_map(|(n, v)| match v { + FieldValue::Bytes(b) if *n == num => Some(b.clone()), + _ => None, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn varint_roundtrip() { + for &v in &[0u64, 1, 127, 128, 300, 16384, u32::MAX as u64, u64::MAX] { + let mut buf = Vec::new(); + encode_varint(v, &mut buf); + let (got, n) = decode_varint(&buf, 0).unwrap(); + assert_eq!(got, v, "varint roundtrip failed for {v}"); + assert_eq!(n, buf.len()); + } + } + + #[test] + fn varint_truncated_errors() { + let buf = vec![0x80, 0x80]; // continuation bit set but no end + assert!(decode_varint(&buf, 0).is_err()); + } + + #[test] + fn field_roundtrip() { + let mut buf = Vec::new(); + encode_field_varint(1, 42, &mut buf); + encode_field_string(2, "hello", &mut buf); + encode_field_bytes(3, b"\x00\x01\x02", &mut buf); + + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_varint(&fields, 1), 42); + assert_eq!(get_string(&fields, 2), "hello"); + assert_eq!(get_bytes(&fields, 3), vec![0, 1, 2]); + } + + #[test] + fn unknown_field_skipped_gracefully() { + let mut buf = Vec::new(); + encode_field_varint(99, 123, &mut buf); + encode_field_string(1, "wanted", &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_string(&fields, 1), "wanted"); + assert_eq!(get_string(&fields, 2), ""); // missing field returns default + } + + #[test] + fn seq_numbers_are_monotonic() { + let a = next_seq_no(); + let b = next_seq_no(); + assert!(b > a); + } +} diff --git a/src/openhuman/channels/runtime/startup.rs b/src/openhuman/channels/runtime/startup.rs index 73ccb9f5dd..965e119db4 100644 --- a/src/openhuman/channels/runtime/startup.rs +++ b/src/openhuman/channels/runtime/startup.rs @@ -28,6 +28,7 @@ use crate::openhuman::channels::traits; use crate::openhuman::channels::whatsapp::WhatsAppChannel; #[cfg(feature = "whatsapp-web")] use crate::openhuman::channels::whatsapp_web::WhatsAppWebChannel; +use crate::openhuman::channels::yuanbao::YuanbaoChannel; use crate::openhuman::channels::Channel; use crate::openhuman::config::Config; use crate::openhuman::context::channels_prompt::build_system_prompt; @@ -500,6 +501,13 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + if let Some(ref yb) = config.channels_config.yuanbao { + match YuanbaoChannel::new(yb.clone()) { + Ok(ch) => channels.push(Arc::new(ch)), + Err(e) => tracing::warn!("[channels] yuanbao config invalid: {e}"), + } + } + if channels.is_empty() { println!("No channels configured. Set up channels in the web UI."); return Ok(()); diff --git a/src/openhuman/config/schema/channels.rs b/src/openhuman/config/schema/channels.rs index 7489da8d7a..60d2c17318 100644 --- a/src/openhuman/config/schema/channels.rs +++ b/src/openhuman/config/schema/channels.rs @@ -23,6 +23,7 @@ pub struct ChannelsConfig { pub lark: Option, pub dingtalk: Option, pub qq: Option, + pub yuanbao: Option, #[serde(default = "default_channel_message_timeout_secs")] pub message_timeout_secs: u64, /// The user's preferred *external* channel for proactive messages @@ -61,6 +62,7 @@ impl ChannelsConfig { || self.lark.is_some() || self.dingtalk.is_some() || self.qq.is_some() + || self.yuanbao.is_some() || self.matrix.is_some() || self.whatsapp.is_some() } @@ -85,6 +87,7 @@ impl Default for ChannelsConfig { lark: None, dingtalk: None, qq: None, + yuanbao: None, message_timeout_secs: default_channel_message_timeout_secs(), active_channel: None, } From c557caa8393a3754bf639569bfd7f6ef8402db44 Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sat, 23 May 2026 21:55:10 +0800 Subject: [PATCH 2/8] feat: add unit test --- .../channels/providers/yuanbao/channel.rs | 180 +++++++++++ .../channels/providers/yuanbao/connection.rs | 151 ++++++++++ .../channels/providers/yuanbao/cos.rs | 166 +++++++++++ .../channels/providers/yuanbao/media.rs | 281 ++++++++++++++++++ .../channels/providers/yuanbao/outbound.rs | 71 +++++ .../channels/providers/yuanbao/proto.rs | 238 +++++++++++++++ .../channels/providers/yuanbao/proto_biz.rs | 173 +++++++++++ .../channels/providers/yuanbao/sign.rs | 183 ++++++++++++ .../channels/providers/yuanbao/types.rs | 172 +++++++++++ .../channels/providers/yuanbao/wire.rs | 123 ++++++++ 10 files changed, 1738 insertions(+) diff --git a/src/openhuman/channels/providers/yuanbao/channel.rs b/src/openhuman/channels/providers/yuanbao/channel.rs index 40cf27a85f..2baeaa8646 100644 --- a/src/openhuman/channels/providers/yuanbao/channel.rs +++ b/src/openhuman/channels/providers/yuanbao/channel.rs @@ -528,4 +528,184 @@ mod tests { assert_eq!(m.len(), 1); assert_eq!(m.get("new_short").map(String::as_str), Some("new_original")); } + + // ─── trivial trait methods ───────────────────────────────────── + + #[test] + fn supports_draft_updates_is_true() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(ch.supports_draft_updates()); + } + + #[test] + fn supports_reactions_is_false() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(!ch.supports_reactions()); + } + + #[tokio::test] + async fn send_draft_returns_marker_id() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let msg = SendMessage::new("ignored", "user-42"); + let id = ch.send_draft(&msg).await.unwrap(); + assert_eq!(id.as_deref(), Some("yb-draft:user-42")); + } + + #[tokio::test] + async fn update_draft_is_a_noop_ok() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(ch.update_draft("user-42", "any-id", "text").await.is_ok()); + } + + #[tokio::test] + async fn health_check_is_false_when_socket_not_connected() { + // Real connect requires a WebSocket; we only verify the + // disconnected default here. The connected branch is exercised + // by `connection::tests::set_state_connected_flips_is_connected_flag`. + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + assert!(!ch.health_check().await); + } + + // ─── dispatch_push branches ──────────────────────────────────── + + fn make_push_frame(cmd: &str, data: Vec) -> types::ConnFrame { + types::ConnFrame { + cmd_type: super::super::proto_constants::cmd_type::PUSH, + cmd: cmd.into(), + module: "yuanbao_openclaw_proxy".into(), + seq_no: 0, + msg_id: String::new(), + need_ack: false, + status: 0, + data, + } + } + + #[tokio::test] + async fn dispatch_push_empty_body_is_skipped() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let (tx, mut rx) = mpsc::channel::(4); + let frame = make_push_frame("noop", Vec::new()); + ch.dispatch_push(frame, &tx).await; + // No message should reach the sender. + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn dispatch_push_garbage_body_does_not_dispatch() { + // Body is not a valid protobuf push *and* not valid JSON → Failed. + // dispatch_push should log + swallow, not propagate panic. + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + let (tx, mut rx) = mpsc::channel::(4); + let frame = make_push_frame("inbound_message", vec![0xFF, 0xFF, 0xFF, 0xFF]); + ch.dispatch_push(frame, &tx).await; + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn dispatch_push_dm_text_reaches_listener() { + // Build a minimal `InboundMessagePush` directly in ConnFrame.data + // (no PushMsg envelope), with a single TIMTextElem so the pipeline + // dispatches. + use super::super::proto::{encode_msg_body_element, encode_varint}; + let elem = types::MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: types::MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&elem); + + // Hand-roll an InboundMessagePush so we don't depend on a helper: + // field 2 = from_account, field 3 = to_account, field 12 = msg_id, + // field 13 = repeated MsgBodyElement. + let mut biz = Vec::new(); + let put_string = |fnum: u32, s: &str, b: &mut Vec| { + encode_varint(((fnum as u64) << 3) | 2, b); + encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_string(2, "alice", &mut biz); + put_string(3, "bot1", &mut biz); + put_string(12, "mid-x", &mut biz); + encode_varint(((13u64) << 3) | 2, &mut biz); + encode_varint(elem_bytes.len() as u64, &mut biz); + biz.extend_from_slice(&elem_bytes); + + // Disable group_at_required and use open dm_access so the + // pipeline passes all stages for this DM. + let mut cfg = good_cfg(); + cfg.dm_access = "open".into(); + cfg.bot_id = "bot1".into(); + let ch = YuanbaoChannel::new(cfg).unwrap(); + + let frame = make_push_frame("inbound_message", biz); + let (tx, mut rx) = mpsc::channel::(4); + ch.dispatch_push(frame, &tx).await; + let msg = rx.try_recv().expect("dispatch should produce one message"); + assert_eq!(msg.id, "mid-x"); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.channel, "yuanbao"); + } + + #[tokio::test] + async fn dispatch_push_filtered_by_dedup_does_not_double_dispatch() { + use super::super::proto::{encode_msg_body_element, encode_varint}; + let elem = types::MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: types::MsgContent { + text: Some("dup".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&elem); + let mut biz = Vec::new(); + let put_string = |fnum: u32, s: &str, b: &mut Vec| { + encode_varint(((fnum as u64) << 3) | 2, b); + encode_varint(s.len() as u64, b); + b.extend_from_slice(s.as_bytes()); + }; + put_string(2, "alice", &mut biz); + put_string(3, "bot1", &mut biz); + put_string(12, "dup-id", &mut biz); + encode_varint(((13u64) << 3) | 2, &mut biz); + encode_varint(elem_bytes.len() as u64, &mut biz); + biz.extend_from_slice(&elem_bytes); + + let mut cfg = good_cfg(); + cfg.dm_access = "open".into(); + cfg.bot_id = "bot1".into(); + let ch = YuanbaoChannel::new(cfg).unwrap(); + let (tx, mut rx) = mpsc::channel::(4); + ch.dispatch_push(make_push_frame("inbound_message", biz.clone()), &tx) + .await; + assert!(rx.try_recv().is_ok(), "first should dispatch"); + ch.dispatch_push(make_push_frame("inbound_message", biz), &tx) + .await; + assert!(rx.try_recv().is_err(), "second (same id) should dedup"); + } + + // ─── heartbeat task lifecycle ────────────────────────────────── + + #[tokio::test] + async fn start_heartbeat_task_inserts_and_stop_removes() { + let ch = YuanbaoChannel::new(good_cfg()).unwrap(); + ch.start_heartbeat_task("recipient-1").await; + assert!( + ch.heartbeat_tasks + .lock() + .await + .contains_key("recipient-1"), + "should have spawned a task for recipient-1" + ); + // Second start for same recipient is a no-op (does not double-spawn). + ch.start_heartbeat_task("recipient-1").await; + assert_eq!(ch.heartbeat_tasks.lock().await.len(), 1); + + ch.stop_heartbeat_task("recipient-1").await; + assert!(ch.heartbeat_tasks.lock().await.is_empty()); + // Stopping a recipient with no task is also a no-op. + ch.stop_heartbeat_task("never-started").await; + } } diff --git a/src/openhuman/channels/providers/yuanbao/connection.rs b/src/openhuman/channels/providers/yuanbao/connection.rs index 796c802405..480581cbaa 100644 --- a/src/openhuman/channels/providers/yuanbao/connection.rs +++ b/src/openhuman/channels/providers/yuanbao/connection.rs @@ -568,4 +568,155 @@ mod tests { conn.update_account(|a| a.connect_id = "cid_xyz".into()); assert_eq!(conn.account().connect_id, "cid_xyz"); } + + #[tokio::test] + async fn next_msg_id_is_monotonic_and_prefixed() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let a = conn.next_msg_id("pfx"); + let b = conn.next_msg_id("pfx"); + assert!(a.starts_with("pfx_")); + assert!(b.starts_with("pfx_")); + // Suffix is monotonically increasing. + let na: u64 = a.strip_prefix("pfx_").unwrap().parse().unwrap(); + let nb: u64 = b.strip_prefix("pfx_").unwrap().parse().unwrap(); + assert!(nb > na); + } + + #[tokio::test] + async fn initial_state_is_disconnected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + assert_eq!(conn.state(), ConnectionState::Disconnected); + assert!(!conn.is_connected()); + } + + #[tokio::test] + async fn set_state_connected_flips_is_connected_flag() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + conn.set_state(ConnectionState::Connected); + assert_eq!(conn.state(), ConnectionState::Connected); + assert!(conn.is_connected()); + conn.set_state(ConnectionState::Reconnecting); + assert_eq!(conn.state(), ConnectionState::Reconnecting); + assert!(!conn.is_connected()); + } + + #[tokio::test] + async fn send_frame_without_socket_returns_not_connected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn.send_frame(vec![1, 2, 3]).await.unwrap_err(); + assert!(matches!(err, YuanbaoError::NotConnected)); + let err2 = conn.send_conn_msg(vec![4]).await.unwrap_err(); + assert!(matches!(err2, YuanbaoError::NotConnected)); + } + + #[tokio::test] + async fn shutdown_clears_pending_and_sets_disconnected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + conn.set_state(ConnectionState::Connected); + // Drop a phantom pending entry then shutdown. + let (phantom_tx, _phantom_rx) = oneshot::channel(); + conn.pending.lock().insert("ghost".into(), phantom_tx); + conn.shutdown().await; + assert_eq!(conn.state(), ConnectionState::Disconnected); + assert!(!conn.is_connected()); + assert!(conn.pending.lock().is_empty()); + } + + #[test] + fn backoff_caps_at_last_entry_for_huge_attempts() { + let last = *RECONNECT_DELAYS.last().unwrap(); + assert_eq!(backoff_seconds(RECONNECT_DELAYS.len() as u32 + 5), last); + } + + #[tokio::test] + async fn resolve_token_uses_static_token_when_present() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let (token, bot_id, source) = conn.resolve_token().await.unwrap(); + assert_eq!(token, "tok"); + assert_eq!(bot_id, "bot1"); + assert_eq!(source, ""); + } + + #[tokio::test] + async fn resolve_token_without_token_and_without_sign_manager_errors() { + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + c.token = String::new(); + let conn = YuanbaoConnection::new(c, tx, None); + match conn.resolve_token().await.unwrap_err() { + YuanbaoError::AuthFailed(m) => assert!(m.contains("no token"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[tokio::test] + async fn resolve_token_with_sign_manager_but_no_app_secret_errors() { + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + c.token = String::new(); + c.app_secret = String::new(); + let mgr = SignManager::new(reqwest::Client::new()); + let conn = YuanbaoConnection::new(c, tx, Some(mgr)); + match conn.resolve_token().await.unwrap_err() { + YuanbaoError::AuthFailed(m) => assert!(m.contains("app_secret"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[tokio::test] + async fn send_auth_bind_without_socket_returns_not_connected() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let err = conn.send_auth_bind("tok", "bot1", "bot").await.unwrap_err(); + assert!(matches!(err, YuanbaoError::NotConnected)); + } + + #[tokio::test] + async fn send_auth_bind_falls_back_to_account_uid_when_bot_id_empty() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // bot_id="" → reads from account.uid (which was seeded from cfg.bot_id="bot1") + let err = conn.send_auth_bind("tok", "", "").await.unwrap_err(); + assert!(matches!(err, YuanbaoError::NotConnected)); + // Account uid still in place. + assert_eq!(conn.account().uid, "bot1"); + } + + #[test] + fn handle_auth_response_rejects_non_binary_message() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + let msg = Message::Text("nope".into()); + match conn.handle_auth_response(&msg).unwrap_err() { + YuanbaoError::AuthFailed(m) => { + assert!(m.contains("binary"), "got {m}") + } + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn handle_auth_response_rejects_undecodable_binary() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // Wholly invalid wire data — decode_conn_msg fails. + let msg = Message::Binary(vec![0xFF, 0xFF, 0xFF, 0xFF]); + let err = conn.handle_auth_response(&msg).unwrap_err(); + // Either Proto decode error or some other surface — must not be Ok. + assert!(!matches!(err, YuanbaoError::AuthFailed(_) if format!("{err:?}").contains("binary"))); + } + + #[tokio::test] + async fn handle_binary_with_garbage_does_not_panic() { + let (tx, _rx) = mpsc::unbounded_channel(); + let conn = YuanbaoConnection::new(cfg(), tx, None); + // Should silently log + return — no panic. + conn.handle_binary(vec![0xFF, 0xFF, 0xFF, 0xFF]).await; + } } diff --git a/src/openhuman/channels/providers/yuanbao/cos.rs b/src/openhuman/channels/providers/yuanbao/cos.rs index 0f2d3b3e32..3dc2a9bda1 100644 --- a/src/openhuman/channels/providers/yuanbao/cos.rs +++ b/src/openhuman/channels/providers/yuanbao/cos.rs @@ -389,4 +389,170 @@ mod tests { let s2 = cos_sign(&CosSignInput { path: "/b", ..base }); assert_ne!(s1, s2); } + + #[test] + fn cos_sign_lowercases_method_and_includes_url_params() { + let s = cos_sign(&CosSignInput { + method: "PUT", // mixed case → should be lowercased into sig + path: "/k", + params: &[("Foo", "Bar Baz")], // url-encoded value + headers: &[("Host", "h")], + secret_id: "AKID", + secret_key: "SK", + start_time: 1_700_000_000, + expire_seconds: 600, + }); + assert!(s.contains("q-url-param-list=foo")); + // header list also lowercased + assert!(s.contains("q-header-list=host")); + } + + fn ok_credentials_body(bucket: &str, location: &str) -> serde_json::Value { + serde_json::json!({ + "code": 0, + "data": { + "bucketName": bucket, + "region": "ap-shanghai", + "location": location, + "encryptTmpSecretId": "AKID", + "encryptTmpSecretKey": "SECRET", + "encryptToken": "session-token", + "startTime": 1_700_000_000u64, + "expiredTime": 1_700_003_600u64, + "resourceUrl": "https://cdn.example/r", + } + }) + } + + #[tokio::test] + async fn get_cos_credentials_parses_data_block() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) + .and(wiremock::matchers::header("X-Token", "tok")) + .and(wiremock::matchers::header("X-Source", "web")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt-1", "k/v/file.png")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let creds = + get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "file.png") + .await + .unwrap(); + assert_eq!(creds.bucket, "bkt-1"); + assert_eq!(creds.region, "ap-shanghai"); + assert_eq!(creds.location, "k/v/file.png"); + assert_eq!(creds.secret_id, "AKID"); + assert_eq!(creds.secret_key, "SECRET"); + assert_eq!(creds.session_token, "session-token"); + assert_eq!(creds.resource_url, "https://cdn.example/r"); + assert_eq!(creds.start_time, 1_700_000_000); + assert_eq!(creds.expired_time, 1_700_003_600); + } + + #[tokio::test] + async fn get_cos_credentials_falls_back_to_app_key_for_xid_when_bot_id_empty() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) + .and(wiremock::matchers::header("X-ID", "appk")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt", "loc")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let creds = get_cos_credentials(&http, &server.uri(), "appk", "", "tok", "", "f") + .await + .unwrap(); + assert_eq!(creds.bucket, "bkt"); + } + + #[tokio::test] + async fn get_cos_credentials_sends_route_env_header_when_non_empty() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::header("X-Route-Env", "canary")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_json(ok_credentials_body("bkt", "loc")), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "canary", "f") + .await + .expect("should send canary header"); + } + + #[tokio::test] + async fn get_cos_credentials_surfaces_http_error() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with(wiremock::ResponseTemplate::new(500).set_body_string("boom")) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "f") + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("HTTP 500"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn get_cos_credentials_surfaces_non_zero_business_code() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 4001, + "msg": "quota", + })), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "f") + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => { + assert!(m.contains("code=4001"), "got {m}"); + assert!(m.contains("quota"), "got {m}"); + } + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn upload_to_cos_rejects_missing_credentials() { + let http = reqwest::Client::new(); + // empty credentials → fail without making any HTTP call + let bad = CosCredentials::default(); + let err = upload_to_cos(&http, &bad, b"data", "f.bin", "application/octet-stream".into()) + .await + .unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("credentials missing"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + // NOTE: upload_to_cos always targets `.cos.accelerate.myqcloud.com` + // which we cannot redirect at the reqwest layer without DNS hacks, so we + // only cover the guard branch (missing creds) above. The PUT body itself + // is exercised by integration tests, not unit tests. + + #[test] + fn encode_cos_key_keeps_slashes_but_escapes_segments() { + assert_eq!(encode_cos_key("plain/file.png"), "plain/file.png"); + assert_eq!(encode_cos_key("a b/c d.png"), "a%20b/c%20d.png"); + } } diff --git a/src/openhuman/channels/providers/yuanbao/media.rs b/src/openhuman/channels/providers/yuanbao/media.rs index 0066e5a480..266db09bb3 100644 --- a/src/openhuman/channels/providers/yuanbao/media.rs +++ b/src/openhuman/channels/providers/yuanbao/media.rs @@ -336,4 +336,285 @@ mod tests { assert_eq!(image_format_code("image/webp"), 255); assert_eq!(image_format_code("application/pdf"), 255); } + + // ─── extended MIME / image-format tests ───────────────────── + + #[test] + fn guess_mime_handles_uppercase_extension() { + assert_eq!(guess_mime_type("PHOTO.JPG"), "image/jpeg"); + assert_eq!(guess_mime_type("Doc.PDF"), "application/pdf"); + } + + #[test] + fn guess_mime_covers_office_audio_video_archive_types() { + assert_eq!(guess_mime_type("file.jpg"), "image/jpeg"); + assert_eq!(guess_mime_type("file.jpeg"), "image/jpeg"); + assert_eq!(guess_mime_type("file.gif"), "image/gif"); + assert_eq!(guess_mime_type("file.webp"), "image/webp"); + assert_eq!(guess_mime_type("file.bmp"), "image/bmp"); + assert_eq!(guess_mime_type("file.heic"), "image/heic"); + assert_eq!(guess_mime_type("file.tiff"), "image/tiff"); + assert_eq!(guess_mime_type("file.ico"), "image/x-icon"); + assert_eq!(guess_mime_type("file.doc"), "application/msword"); + assert_eq!( + guess_mime_type("file.docx"), + "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + ); + assert_eq!(guess_mime_type("file.xls"), "application/vnd.ms-excel"); + assert_eq!( + guess_mime_type("file.xlsx"), + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ); + assert_eq!( + guess_mime_type("file.ppt"), + "application/vnd.ms-powerpoint" + ); + assert_eq!( + guess_mime_type("file.pptx"), + "application/vnd.openxmlformats-officedocument.presentationml.presentation" + ); + assert_eq!(guess_mime_type("file.txt"), "text/plain"); + assert_eq!(guess_mime_type("file.zip"), "application/zip"); + assert_eq!(guess_mime_type("file.tar"), "application/x-tar"); + assert_eq!(guess_mime_type("file.gz"), "application/gzip"); + assert_eq!(guess_mime_type("file.mp3"), "audio/mpeg"); + assert_eq!(guess_mime_type("file.mp4"), "video/mp4"); + assert_eq!(guess_mime_type("file.wav"), "audio/wav"); + assert_eq!(guess_mime_type("file.ogg"), "audio/ogg"); + assert_eq!(guess_mime_type("file.webm"), "video/webm"); + } + + #[test] + fn image_format_code_jpg_alias_and_heic_tiff() { + assert_eq!(image_format_code("image/jpg"), 1); + assert_eq!(image_format_code("image/heic"), 255); + assert_eq!(image_format_code("image/tiff"), 255); + assert_eq!(image_format_code(""), 255); + } + + // ─── parse_image_size — JPEG / WEBP / negative paths ──────── + + #[test] + fn jpeg_dims_from_sof0_marker() { + // SOI + filler APP0 segment + SOF0 marker carrying h=2, w=3. + // parse_jpeg's read uses buf[i+5..i+9] and gates on `i + 9 < buf.len()`, + // so trailing pad bytes are required (one tail byte makes 17 < 18 true). + let mut jpeg = vec![0xFF, 0xD8]; + // APP0 (0xFFE0) with len=4 → 2 bytes payload + jpeg.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x04, 0x00, 0x00]); + // SOF0: 0xFF 0xC0 len(11) precision(8) height(2 BE) width(3 BE) + jpeg.extend_from_slice(&[0xFF, 0xC0, 0x00, 0x0B, 0x08, 0x00, 0x02, 0x00, 0x03]); + // trailing pad so the `i + 9 < buf.len()` loop guard accepts the SOF0 + // entry on the second iteration. + jpeg.push(0xFF); + let d = parse_image_size(&jpeg).expect("jpeg parse"); + assert_eq!(d.width, 3); + assert_eq!(d.height, 2); + } + + #[test] + fn jpeg_too_short_returns_none() { + assert!(parse_image_size(&[0xFF, 0xD8]).is_none()); + } + + #[test] + fn jpeg_wrong_magic_returns_none() { + let buf = [0xCA, 0xFE, 0xBA, 0xBE, 0, 0, 0, 0, 0, 0]; + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn webp_vp8x_dims_parse() { + // RIFF size WEBP VP8X flags(4) padding(3) w-1(LE24) h-1(LE24) + // w=320 → 319 little-endian = [0x3F, 0x01, 0x00]; h=240 → 239 = [0xEF, 0x00, 0x00] + let mut buf = b"RIFF\x00\x00\x00\x00WEBPVP8X".to_vec(); + buf.extend_from_slice(&[0u8; 8]); // flags + reserved + buf.extend_from_slice(&[0x3F, 0x01, 0x00, 0xEF, 0x00, 0x00]); + let d = parse_image_size(&buf).expect("webp vp8x parse"); + assert_eq!(d.width, 320); + assert_eq!(d.height, 240); + } + + #[test] + fn webp_too_short_returns_none() { + let buf = b"RIFF\0\0\0\0WEBPVP8"; + assert!(parse_image_size(buf).is_none()); + } + + #[test] + fn webp_unsupported_chunk_returns_none() { + let mut buf = b"RIFF\0\0\0\0WEBPXXXX".to_vec(); + buf.extend_from_slice(&[0u8; 30]); + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn png_short_or_wrong_magic_returns_none() { + assert!(parse_image_size(&[0x89, 0x50, 0x4E, 0x47]).is_none()); // too short + let mut buf = vec![0xFF; 24]; + buf[..4].copy_from_slice(&[0x89, 0x50, 0x4F, 0x47]); // wrong magic + assert!(parse_image_size(&buf).is_none()); + } + + #[test] + fn gif_too_short_or_wrong_sig_returns_none() { + assert!(parse_image_size(b"GIF87").is_none()); + assert!(parse_image_size(b"NOTGIFEXT").is_none()); + } + + #[test] + fn parse_image_size_empty_returns_none() { + assert!(parse_image_size(&[]).is_none()); + } + + // ─── msg_body builders ────────────────────────────────────── + + #[test] + fn build_image_msg_body_uses_uuid_when_present() { + let body = build_image_msg_body( + "https://x/cat.png", + Some("uuid-1"), + Some("cat.png"), + 1024, + 800, + 600, + "image/png", + ); + assert_eq!(body.len(), 1); + let el = &body[0]; + assert_eq!(el.msg_type, "TIMImageElem"); + assert_eq!(el.msg_content.uuid.as_deref(), Some("uuid-1")); + assert_eq!(el.msg_content.image_format, Some(3)); // png + assert_eq!(el.msg_content.image_info_array.len(), 1); + let info = &el.msg_content.image_info_array[0]; + assert_eq!(info.image_type, 1); + assert_eq!(info.size, 1024); + assert_eq!(info.width, 800); + assert_eq!(info.height, 600); + assert_eq!(info.url, "https://x/cat.png"); + } + + #[test] + fn build_image_msg_body_falls_back_to_filename_then_default_uuid() { + let with_filename = + build_image_msg_body("https://x/", None, Some("only-name.png"), 0, 0, 0, "image/png"); + assert_eq!( + with_filename[0].msg_content.uuid.as_deref(), + Some("only-name.png") + ); + + let default_id = build_image_msg_body("https://x/", None, None, 0, 0, 0, "image/png"); + assert_eq!(default_id[0].msg_content.uuid.as_deref(), Some("image")); + } + + #[test] + fn build_image_msg_body_treats_empty_mime_as_format_255() { + let body = build_image_msg_body("https://x/cat.jpg", None, None, 0, 0, 0, ""); + assert_eq!(body[0].msg_content.image_format, Some(255)); + } + + #[test] + fn build_file_msg_body_uses_filename_when_uuid_missing() { + let body = build_file_msg_body("https://x/doc.pdf", "doc.pdf", None, 2048); + assert_eq!(body.len(), 1); + let el = &body[0]; + assert_eq!(el.msg_type, "TIMFileElem"); + assert_eq!(el.msg_content.uuid.as_deref(), Some("doc.pdf")); + assert_eq!(el.msg_content.file_name.as_deref(), Some("doc.pdf")); + assert_eq!(el.msg_content.file_size, Some(2048)); + assert_eq!(el.msg_content.url.as_deref(), Some("https://x/doc.pdf")); + } + + #[test] + fn build_file_msg_body_prefers_explicit_uuid() { + let body = build_file_msg_body("https://x/y.pdf", "y.pdf", Some("uuid-y"), 0); + assert_eq!(body[0].msg_content.uuid.as_deref(), Some("uuid-y")); + } + + // ─── download_url (wiremock) ──────────────────────────────── + + #[tokio::test] + async fn download_url_returns_bytes_and_content_type() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with( + wiremock::ResponseTemplate::new(200).insert_header("Content-Length", "3"), + ) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .insert_header("Content-Type", "image/png; charset=binary") + .set_body_bytes(vec![1u8, 2, 3]), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let (bytes, mime) = download_url(&http, &server.uri(), 10).await.unwrap(); + assert_eq!(bytes, vec![1, 2, 3]); + assert_eq!(mime, "image/png"); + } + + #[tokio::test] + async fn download_url_rejects_oversize_from_head_content_length() { + let server = wiremock::MockServer::start().await; + // HEAD reports a very large file → reject BEFORE GET. + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with( + wiremock::ResponseTemplate::new(200).insert_header( + "Content-Length", + (10u64 * 1024 * 1024 + 1).to_string().as_str(), + ), + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 10).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("too large"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn download_url_rejects_when_body_exceeds_limit() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with( + wiremock::ResponseTemplate::new(200) + .set_body_bytes(vec![0u8; 2 * 1024 * 1024]), // 2 MiB + ) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 1).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("exceeds limit"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } + + #[tokio::test] + async fn download_url_surfaces_http_error_status() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("HEAD")) + .respond_with(wiremock::ResponseTemplate::new(404)) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with(wiremock::ResponseTemplate::new(404)) + .mount(&server) + .await; + let http = reqwest::Client::new(); + let err = download_url(&http, &server.uri(), 10).await.unwrap_err(); + match err { + YuanbaoError::Media(m) => assert!(m.contains("HTTP 404"), "got {m}"), + other => panic!("expected Media error, got {other:?}"), + } + } } diff --git a/src/openhuman/channels/providers/yuanbao/outbound.rs b/src/openhuman/channels/providers/yuanbao/outbound.rs index 4f9da0e50c..6f7ec1cd85 100644 --- a/src/openhuman/channels/providers/yuanbao/outbound.rs +++ b/src/openhuman/channels/providers/yuanbao/outbound.rs @@ -375,4 +375,75 @@ mod tests { assert_eq!(extract_filename("https://x.com/"), "file"); assert_eq!(extract_filename(""), "file"); } + + #[test] + fn extract_filename_from_bare_path() { + // Not a valid URL → fall back to last non-empty `/`-segment. + assert_eq!(extract_filename("/var/log/foo.bin"), "foo.bin"); + // Trailing slash gets skipped; last non-empty segment wins. + assert_eq!(extract_filename("/var/log/"), "log"); + // Plain filename with no slashes. + assert_eq!(extract_filename("plain.txt"), "plain.txt"); + } + + fn make_conn(cfg: super::super::config::YuanbaoConfig) -> Arc { + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); + YuanbaoConnection::new(cfg, tx, None) + } + + fn base_cfg() -> super::super::config::YuanbaoConfig { + let mut c = super::super::config::YuanbaoConfig::default(); + c.app_key = "ak".into(); + c.ws_domain = "wss://x".into(); + c.token = "tok".into(); + c.bot_id = "cfg-bot".into(); + c + } + + #[tokio::test] + async fn resolve_from_account_uses_config_bot_id_when_no_sign_manager() { + let conn = make_conn(base_cfg()); + let sender = OutboundSender::new(conn, None, "ak".into(), "cfg-bot".into()); + assert_eq!(sender.resolve_from_account().await, "cfg-bot"); + } + + #[tokio::test] + async fn resolve_from_account_uses_sign_cache_when_bot_id_present() { + let conn = make_conn(base_cfg()); + let mgr = super::super::sign::SignManager::new(reqwest::Client::new()); + // Seed the cache with a bot_id keyed on the same app_key. + mgr.set_cached_for_test( + "ak", + super::super::sign::TokenEntry { + token: "tok".into(), + bot_id: "server-bot".into(), + product: String::new(), + source: "bot".into(), + expire_ts: u64::MAX / 2, + }, + ) + .await; + let sender = OutboundSender::new(conn, Some(mgr), "ak".into(), "fallback-bot".into()); + // Sign cache hit → use server bot_id, not the fallback. + assert_eq!(sender.resolve_from_account().await, "server-bot"); + } + + #[tokio::test] + async fn resolve_from_account_falls_back_when_sign_cache_bot_id_empty() { + let conn = make_conn(base_cfg()); + let mgr = super::super::sign::SignManager::new(reqwest::Client::new()); + mgr.set_cached_for_test( + "ak", + super::super::sign::TokenEntry { + token: "tok".into(), + bot_id: String::new(), + product: String::new(), + source: String::new(), + expire_ts: u64::MAX / 2, + }, + ) + .await; + let sender = OutboundSender::new(conn, Some(mgr), "ak".into(), "fallback-bot".into()); + assert_eq!(sender.resolve_from_account().await, "fallback-bot"); + } } diff --git a/src/openhuman/channels/providers/yuanbao/proto.rs b/src/openhuman/channels/providers/yuanbao/proto.rs index 41fad20170..e678f74227 100644 --- a/src/openhuman/channels/providers/yuanbao/proto.rs +++ b/src/openhuman/channels/providers/yuanbao/proto.rs @@ -673,4 +673,242 @@ mod tests { let got = decode_msg_body_element(&buf).unwrap(); assert_eq!(got, el); } + + // ─── decode_auth_bind_rsp ───────────────────────────────────── + + fn build_auth_bind_rsp_bytes(code: u64, message: &str, connect_id: &str) -> Vec { + let mut buf = Vec::new(); + if code != 0 { + encode_field_varint(1, code, &mut buf); + } + if !message.is_empty() { + encode_field_string(2, message, &mut buf); + } + if !connect_id.is_empty() { + encode_field_string(3, connect_id, &mut buf); + } + buf + } + + #[test] + fn decode_auth_bind_rsp_happy_path() { + let body = build_auth_bind_rsp_bytes(0, "ok", "conn-42"); + let r = decode_auth_bind_rsp(&body).unwrap(); + assert_eq!(r.code, 0); + assert_eq!(r.message, "ok"); + assert_eq!(r.connect_id, "conn-42"); + } + + #[test] + fn decode_auth_bind_rsp_with_error_code() { + let body = build_auth_bind_rsp_bytes(40011, "rejected", ""); + let r = decode_auth_bind_rsp(&body).unwrap(); + assert_eq!(r.code, 40011); + assert_eq!(r.message, "rejected"); + assert!(r.connect_id.is_empty()); + } + + #[test] + fn decode_auth_bind_rsp_on_empty_returns_default() { + let r = decode_auth_bind_rsp(&[]).unwrap(); + assert_eq!(r, AuthBindRsp::default()); + } + + // ─── decode_push_msg ────────────────────────────────────────── + + #[test] + fn decode_push_msg_extracts_all_fields() { + let inner_payload = vec![0xCA, 0xFE, 0xBA, 0xBE]; + let mut buf = Vec::new(); + encode_field_string(1, "inbound_message", &mut buf); + encode_field_string(2, "yuanbao_openclaw_proxy", &mut buf); + encode_field_string(3, "pm-1", &mut buf); + encode_field_bytes(4, &inner_payload, &mut buf); + + let pm = decode_push_msg(&buf).unwrap(); + assert_eq!(pm.cmd, "inbound_message"); + assert_eq!(pm.module, "yuanbao_openclaw_proxy"); + assert_eq!(pm.msg_id, "pm-1"); + assert_eq!(pm.data, inner_payload); + } + + #[test] + fn decode_push_msg_on_empty_returns_defaults() { + let pm = decode_push_msg(&[]).unwrap(); + assert!(pm.cmd.is_empty()); + assert!(pm.module.is_empty()); + assert!(pm.msg_id.is_empty()); + assert!(pm.data.is_empty()); + } + + // ─── decode_inbound_push (protobuf) ─────────────────────────── + + #[test] + fn decode_inbound_push_dm_with_text_body() { + // Build a minimal DM push: from/to/sender_nickname + one TIMTextElem. + let text_elem = MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some("hello".into()), + ..Default::default() + }, + }; + let elem_bytes = encode_msg_body_element(&text_elem); + + let mut log_ext = Vec::new(); + encode_field_string(1, "trace-123", &mut log_ext); + + let mut buf = Vec::new(); + encode_field_string(1, "C2CMsg", &mut buf); + encode_field_string(2, "user_42", &mut buf); + encode_field_string(3, "bot_1", &mut buf); + encode_field_string(4, "Alice", &mut buf); + encode_field_varint(8, 7, &mut buf); + encode_field_varint(9, 123, &mut buf); + encode_field_varint(10, 1_700_000_000, &mut buf); + encode_field_string(12, "mid-abc", &mut buf); + encode_field_bytes(13, &elem_bytes, &mut buf); + encode_field_varint(15, 1_700_000_001, &mut buf); + encode_field_bytes(20, &log_ext, &mut buf); + + let m = decode_inbound_push(&buf).unwrap(); + assert_eq!(m.callback_command, "C2CMsg"); + assert_eq!(m.from_account, "user_42"); + assert_eq!(m.to_account, "bot_1"); + assert_eq!(m.sender_nickname, "Alice"); + assert_eq!(m.msg_seq, 7); + assert_eq!(m.msg_random, 123); + assert_eq!(m.msg_time, 1_700_000_000); + assert_eq!(m.msg_id, "mid-abc"); + assert_eq!(m.event_time, 1_700_000_001); + assert_eq!(m.trace_id, "trace-123"); + assert_eq!(m.msg_body.len(), 1); + assert_eq!(m.msg_body[0].msg_content.text.as_deref(), Some("hello")); + assert!(m.recall_msg_seq_list.is_empty()); + } + + #[test] + fn decode_inbound_push_group_with_recall_list() { + let mut recall_entry = Vec::new(); + encode_field_varint(1, 99, &mut recall_entry); + encode_field_string(2, "old-msg-id", &mut recall_entry); + + let mut buf = Vec::new(); + encode_field_string(1, "GroupSysMsg", &mut buf); + encode_field_string(5, "gid-x", &mut buf); + encode_field_string(6, "gcode-y", &mut buf); + encode_field_string(7, "Room", &mut buf); + encode_field_bytes(17, &recall_entry, &mut buf); + encode_field_string(19, "g-private-code", &mut buf); + + let m = decode_inbound_push(&buf).unwrap(); + assert_eq!(m.callback_command, "GroupSysMsg"); + assert_eq!(m.group_id, "gid-x"); + assert_eq!(m.group_code, "gcode-y"); + assert_eq!(m.group_name, "Room"); + assert_eq!(m.private_from_group_code, "g-private-code"); + assert_eq!(m.recall_msg_seq_list.len(), 1); + assert_eq!(m.recall_msg_seq_list[0].msg_seq, 99); + assert_eq!(m.recall_msg_seq_list[0].msg_id, "old-msg-id"); + assert!(m.trace_id.is_empty(), "no log_ext => empty trace_id"); + } + + // ─── decode_inbound_json ────────────────────────────────────── + + #[test] + fn decode_inbound_json_full_dm_shape() { + let json = serde_json::json!({ + "callback_command": "C2CMsg", + "from_account": "user_42", + "to_account": "bot_1", + "sender_nickname": "Alice", + "msg_seq": 7, + "msg_random": 123, + "msg_time": 1_700_000_000, + "msg_id": "mid-1", + "msg_body": [ + { + "msg_type": "TIMTextElem", + "msg_content": { "text": "hi" } + }, + { + "msg_type": "TIMImageElem", + "msg_content": { + "uuid": "u-1", + "image_format": 1, + "image_info_array": [ + { "type": 1, "size": 100, "width": 10, "height": 20, "url": "https://x/i.png" } + ] + } + } + ], + "recall_msg_seq_list": [{ "msg_seq": 9, "msg_id": "old" }], + "log_ext": { "trace_id": "trace-json" } + }); + let m = decode_inbound_json(json.to_string().as_bytes()).unwrap(); + assert_eq!(m.callback_command, "C2CMsg"); + assert_eq!(m.from_account, "user_42"); + assert_eq!(m.msg_id, "mid-1"); + assert_eq!(m.msg_body.len(), 2); + assert_eq!(m.msg_body[0].msg_content.text.as_deref(), Some("hi")); + let img = &m.msg_body[1].msg_content; + assert_eq!(img.uuid.as_deref(), Some("u-1")); + assert_eq!(img.image_info_array.len(), 1); + assert_eq!(img.image_info_array[0].url, "https://x/i.png"); + assert_eq!(m.recall_msg_seq_list.len(), 1); + assert_eq!(m.recall_msg_seq_list[0].msg_seq, 9); + assert_eq!(m.trace_id, "trace-json"); + } + + #[test] + fn decode_inbound_json_rejects_non_object_root() { + let err = decode_inbound_json(b"[1,2,3]").unwrap_err(); + match err { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("not an object"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn decode_inbound_json_rejects_invalid_json() { + let err = decode_inbound_json(b"not json").unwrap_err(); + match err { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("json parse"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn decode_msg_body_element_json_handles_image_type_alias() { + // Some payloads use `image_type` (snake_case) instead of `type`. + let v = serde_json::json!({ + "msg_type": "TIMImageElem", + "msg_content": { + "image_info_array": [ + { "image_type": 2, "size": 50, "width": 5, "height": 6, "url": "u" } + ] + } + }); + let el = decode_msg_body_element_json(&v); + assert_eq!(el.msg_type, "TIMImageElem"); + assert_eq!(el.msg_content.image_info_array.len(), 1); + assert_eq!(el.msg_content.image_info_array[0].image_type, 2); + } + + #[test] + fn decode_msg_content_image_info_with_only_image_type_zero_defaults_to_one() { + // When `image_type` is 0 but url is present, decoder bumps to 1. + let mut ib = Vec::new(); + encode_field_varint(2, 64, &mut ib); + encode_field_string(5, "https://x/y.png", &mut ib); + let mut content = Vec::new(); + encode_field_bytes(8, &ib, &mut content); + let mut elem = Vec::new(); + encode_field_string(1, "TIMImageElem", &mut elem); + encode_field_bytes(2, &content, &mut elem); + let got = decode_msg_body_element(&elem).unwrap(); + assert_eq!(got.msg_content.image_info_array.len(), 1); + assert_eq!(got.msg_content.image_info_array[0].image_type, 1); + assert_eq!(got.msg_content.image_info_array[0].url, "https://x/y.png"); + } } diff --git a/src/openhuman/channels/providers/yuanbao/proto_biz.rs b/src/openhuman/channels/providers/yuanbao/proto_biz.rs index 4b802606bb..4cb620dda3 100644 --- a/src/openhuman/channels/providers/yuanbao/proto_biz.rs +++ b/src/openhuman/channels/providers/yuanbao/proto_biz.rs @@ -394,6 +394,179 @@ mod tests { assert_eq!(parsed.member_count, 42); } + // ─── encode_send_c2c branches ────────────────────────────────── + + #[test] + fn c2c_encode_with_msg_id_msg_random_group_code_trace_id() { + // Hit the branches: msg_id non-empty, msg_random != 0, group_code + // non-empty, trace_id non-empty. + let buf = encode_send_c2c_message( + "uid_alice", + "uid_bot", + &text_body("hi"), + "mid-1", + 42, + "gcode-x", + "trace-1", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.msg_id, "mid-1"); + // Re-parse the biz body and check the fields we encoded show up. + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "mid-1"); + assert_eq!(get_string(&f, 2), "uid_alice"); + assert_eq!(get_string(&f, 3), "uid_bot"); + assert_eq!(get_varint(&f, 4), 42); + assert_eq!(get_string(&f, 6), "gcode-x"); + // log_ext (field 8) carries nested {1: trace_id} + let log_ext = get_bytes(&f, 8); + assert!(!log_ext.is_empty()); + let inner = parse_fields(&log_ext).unwrap(); + assert_eq!(get_string(&inner, 1), "trace-1"); + } + + #[test] + fn c2c_encode_generates_synthetic_req_id_when_msg_id_empty() { + // msg_id empty branch — req_id falls back to `c2c_`. + let buf = encode_send_c2c_message("uid_alice", "uid_bot", &text_body("hi"), "", 0, "", ""); + let frame = decode_conn_msg(&buf).unwrap(); + assert!( + frame.msg_id.starts_with("c2c_"), + "expected synthetic req_id starting with c2c_, got {}", + frame.msg_id + ); + } + + // ─── encode_send_group branches ──────────────────────────────── + + #[test] + fn group_encode_with_all_optional_fields() { + let buf = encode_send_group_message( + "group_42", + "uid_bot", + &text_body("hello"), + "mid-g", + "uid_to", + "rand_x", + "ref-msg-99", + "trace-g", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_MESSAGE); + assert_eq!(frame.msg_id, "mid-g"); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "mid-g"); + assert_eq!(get_string(&f, 2), "group_42"); + assert_eq!(get_string(&f, 3), "uid_bot"); + assert_eq!(get_string(&f, 4), "uid_to"); + assert_eq!(get_string(&f, 5), "rand_x"); + assert_eq!(get_string(&f, 7), "ref-msg-99"); + let log_ext = get_bytes(&f, 9); + let inner = parse_fields(&log_ext).unwrap(); + assert_eq!(get_string(&inner, 1), "trace-g"); + } + + #[test] + fn group_encode_generates_synthetic_req_id_when_msg_id_empty() { + let buf = encode_send_group_message( + "group_x", + "uid_bot", + &text_body("hi"), + "", + "", + "", + "", + "", + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert!( + frame.msg_id.starts_with("grp_"), + "expected synthetic req_id starting with grp_, got {}", + frame.msg_id + ); + } + + // ─── encode_send_group_heartbeat ─────────────────────────────── + + #[test] + fn group_heartbeat_encodes_send_time_and_heartbeat() { + let buf = encode_send_group_heartbeat( + "hb_g_1", + "uid_bot", + "group_42", + ws_heartbeat::RUNNING, + 1_700_000_123, + ); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_GROUP_HEARTBEAT); + assert_eq!(frame.msg_id, "hb_g_1"); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "uid_bot"); + assert_eq!(get_string(&f, 2), ""); // to_account empty for group + assert_eq!(get_string(&f, 3), "group_42"); + assert_eq!(get_varint(&f, 4), 1_700_000_123); + assert_eq!(get_varint(&f, 5), ws_heartbeat::RUNNING as u64); + } + + // ─── encode_get_group_member_list ────────────────────────────── + + #[test] + fn get_group_member_list_omits_offset_when_zero() { + let buf = encode_get_group_member_list("qgm_1", "group_42", 0, 100); + let frame = decode_conn_msg(&buf).unwrap(); + assert_eq!(frame.cmd, biz_cmd::GET_GROUP_MEMBER_LIST); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_string(&f, 1), "group_42"); + // offset (field 2) skipped when 0 + assert_eq!(get_varint(&f, 2), 0); + assert_eq!(get_varint(&f, 3), 100); + } + + #[test] + fn get_group_member_list_includes_offset_when_nonzero() { + let buf = encode_get_group_member_list("qgm_2", "group_42", 200, 50); + let frame = decode_conn_msg(&buf).unwrap(); + let f = parse_fields(&frame.data).unwrap(); + assert_eq!(get_varint(&f, 2), 200); + assert_eq!(get_varint(&f, 3), 50); + } + + // ─── decode_biz_rsp_code + decode_response_envelope ──────────── + + #[test] + fn decode_biz_rsp_code_reads_code_and_message() { + let mut buf = Vec::new(); + put_varint_field(1, 4002, &mut buf); + put_string_field(2, "rate limited", &mut buf); + let (code, msg) = decode_biz_rsp_code(&buf).unwrap(); + assert_eq!(code, 4002); + assert_eq!(msg, "rate limited"); + } + + #[test] + fn decode_biz_rsp_code_on_empty_returns_defaults() { + let (code, msg) = decode_biz_rsp_code(&[]).unwrap(); + assert_eq!(code, 0); + assert!(msg.is_empty()); + } + + #[test] + fn decode_response_envelope_extracts_frame() { + let original = encode_conn_msg( + cmd_type::RESPONSE, + biz_cmd::SEND_C2C_MESSAGE, + 1, + "mid-r", + module::BIZ_PKG, + &[0xAA, 0xBB], + ); + let frame = decode_response_envelope(&original).unwrap(); + assert_eq!(frame.cmd, biz_cmd::SEND_C2C_MESSAGE); + assert_eq!(frame.msg_id, "mid-r"); + assert_eq!(frame.data, vec![0xAA, 0xBB]); + } + #[test] fn group_member_list_decode() { let mut m1 = Vec::new(); diff --git a/src/openhuman/channels/providers/yuanbao/sign.rs b/src/openhuman/channels/providers/yuanbao/sign.rs index e168b8f95d..03791c6b14 100644 --- a/src/openhuman/channels/providers/yuanbao/sign.rs +++ b/src/openhuman/channels/providers/yuanbao/sign.rs @@ -109,6 +109,12 @@ impl SignManager { cache.get(app_key).cloned().filter(|e| e.is_valid()) } + /// Test-only: inject a cache entry without touching the sign endpoint. + #[cfg(test)] + pub(crate) async fn set_cached_for_test(&self, app_key: &str, entry: TokenEntry) { + self.cache.lock().await.insert(app_key.to_string(), entry); + } + /// Get a valid token, fetching one if the cache is empty or stale. pub async fn get_token( &self, @@ -443,4 +449,181 @@ mod tests { ); assert!(mgr.cached("ak").await.is_none()); } + + #[test] + fn token_entry_seconds_remaining_is_signed() { + let e_future = TokenEntry { + token: "t".into(), + bot_id: "b".into(), + product: String::new(), + source: String::new(), + expire_ts: unix_now() + 300, + }; + assert!(e_future.seconds_remaining() >= 290); + let e_past = TokenEntry { + expire_ts: unix_now().saturating_sub(60), + ..e_future + }; + assert!(e_past.seconds_remaining() <= 0); + } + + #[test] + fn suffix_redacts_to_last_4_chars() { + assert_eq!(suffix(""), ""); + assert_eq!(suffix("a"), "a"); + assert_eq!(suffix("abcd"), "abcd"); + assert_eq!(suffix("abcdef"), "cdef"); + assert_eq!(suffix("0123456789"), "6789"); + } + + // ─── refresh / fetch_with_retry via wiremock ──────────────── + + fn ok_body(token: &str, bot_id: &str, duration_secs: u64) -> serde_json::Value { + serde_json::json!({ + "code": 0, + "msg": "ok", + "data": { + "token": token, + "bot_id": bot_id, + "product": "prod1", + "source": "src1", + "duration": duration_secs, + } + }) + } + + #[tokio::test] + async fn get_token_fetches_and_caches_on_first_call() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-1", "bot-1", 7200)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let e = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .expect("token"); + assert_eq!(e.token, "tok-1"); + assert_eq!(e.bot_id, "bot-1"); + assert!(e.expire_ts > unix_now() + 7000); + + // Second call should hit the cache (still works even if server stops). + let cached = mgr.cached("ak").await.expect("cached"); + assert_eq!(cached.token, "tok-1"); + } + + #[tokio::test] + async fn get_token_retries_on_code_10099_then_succeeds() { + let server = wiremock::MockServer::start().await; + // First two requests return code=10099, third returns code=0. + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 10099, + "msg": "try again", + })), + ) + .up_to_n_times(2) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-r", "bot-r", 600)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let e = mgr.refresh("ak", "sk", &server.uri(), "").await.unwrap(); + assert_eq!(e.token, "tok-r"); + assert_eq!(e.bot_id, "bot-r"); + } + + #[tokio::test] + async fn get_token_surfaces_http_error_as_auth_failed() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with(wiremock::ResponseTemplate::new(401).set_body_string("Unauthorized")) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let err = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .unwrap_err(); + match err { + YuanbaoError::AuthFailed(m) => assert!(m.contains("HTTP 401"), "got {m}"), + other => panic!("expected AuthFailed, got {other:?}"), + } + } + + #[tokio::test] + async fn get_token_fails_on_non_zero_business_code() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 40001, + "msg": "bad secret", + })), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let err = mgr + .get_token("ak", "sk", &server.uri(), "") + .await + .unwrap_err(); + match err { + YuanbaoError::AuthFailed(m) => { + assert!(m.contains("code=40001"), "got {m}"); + assert!(m.contains("bad secret"), "got {m}"); + } + other => panic!("expected AuthFailed, got {other:?}"), + } + } + + #[tokio::test] + async fn force_refresh_evicts_cache_and_refetches() { + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-a", "bot", 600)), + ) + .up_to_n_times(1) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(SIGN_PATH)) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_json(ok_body("tok-b", "bot", 600)), + ) + .mount(&server) + .await; + let mgr = SignManager::new(reqwest::Client::new()); + let first = mgr.get_token("ak", "sk", &server.uri(), "").await.unwrap(); + assert_eq!(first.token, "tok-a"); + let second = mgr + .force_refresh("ak", "sk", &server.uri(), "to_env") + .await + .unwrap(); + assert_eq!(second.token, "tok-b"); + } + + #[tokio::test] + async fn clear_locks_drops_all_per_app_key_mutexes() { + let mgr = SignManager::new(reqwest::Client::new()); + // Prime the locks map. + let _ = mgr.get_refresh_lock("ak-1").await; + let _ = mgr.get_refresh_lock("ak-2").await; + assert_eq!(mgr.locks.lock().await.len(), 2); + mgr.clear_locks().await; + assert!(mgr.locks.lock().await.is_empty()); + } } diff --git a/src/openhuman/channels/providers/yuanbao/types.rs b/src/openhuman/channels/providers/yuanbao/types.rs index c661d078ae..6364f763f5 100644 --- a/src/openhuman/channels/providers/yuanbao/types.rs +++ b/src/openhuman/channels/providers/yuanbao/types.rs @@ -211,6 +211,178 @@ impl Source { } } +#[cfg(test)] +mod tests { + use super::*; + + fn text_elem(s: &str) -> MsgBodyElement { + MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent { + text: Some(s.into()), + ..Default::default() + }, + } + } + + fn image_elem(info_urls: &[&str], inline_url: Option<&str>) -> MsgBodyElement { + MsgBodyElement { + msg_type: "TIMImageElem".into(), + msg_content: MsgContent { + image_info_array: info_urls + .iter() + .map(|u| ImageInfo { + image_type: 1, + url: (*u).into(), + ..Default::default() + }) + .collect(), + url: inline_url.map(String::from), + ..Default::default() + }, + } + } + + #[test] + fn dm_is_not_group() { + let m = InboundMessage { + from_account: "alice".into(), + ..Default::default() + }; + assert!(!m.is_group()); + assert_eq!(m.chat_id(), "alice"); + } + + #[test] + fn group_is_group_and_chat_id_is_group_code() { + let m = InboundMessage { + group_code: "grp_42".into(), + from_account: "alice".into(), + ..Default::default() + }; + assert!(m.is_group()); + assert_eq!(m.chat_id(), "grp_42"); + } + + #[test] + fn is_recall_iff_recall_list_non_empty() { + let mut m = InboundMessage::default(); + assert!(!m.is_recall()); + m.recall_msg_seq_list.push(ImMsgSeq { + msg_seq: 7, + msg_id: "x".into(), + }); + assert!(m.is_recall()); + } + + #[test] + fn extract_text_concatenates_text_elements() { + let m = InboundMessage { + msg_body: vec![text_elem("hello"), text_elem("world"), image_elem(&[], None)], + ..Default::default() + }; + assert_eq!(m.extract_text(), "hello\nworld"); + } + + #[test] + fn extract_text_ignores_text_none_and_non_text() { + let m = InboundMessage { + msg_body: vec![ + MsgBodyElement { + msg_type: "TIMTextElem".into(), + msg_content: MsgContent::default(), // text: None + }, + image_elem(&["https://x/y.png"], None), + ], + ..Default::default() + }; + assert_eq!(m.extract_text(), ""); + } + + #[test] + fn extract_text_on_empty_msg_body_returns_empty() { + let m = InboundMessage::default(); + assert_eq!(m.extract_text(), ""); + } + + #[test] + fn extract_image_urls_from_image_info_array() { + let m = InboundMessage { + msg_body: vec![image_elem(&["https://a/1.png", "https://a/2.png"], None)], + ..Default::default() + }; + assert_eq!( + m.extract_image_urls(), + vec!["https://a/1.png".to_string(), "https://a/2.png".into()] + ); + } + + #[test] + fn extract_image_urls_falls_back_to_inline_url_field() { + let m = InboundMessage { + msg_body: vec![image_elem(&[], Some("https://a/inline.png"))], + ..Default::default() + }; + assert_eq!( + m.extract_image_urls(), + vec!["https://a/inline.png".to_string()] + ); + } + + #[test] + fn extract_image_urls_dedups_inline_when_already_in_info_array() { + let dup = "https://a/dup.png"; + let m = InboundMessage { + msg_body: vec![image_elem(&[dup], Some(dup))], + ..Default::default() + }; + assert_eq!(m.extract_image_urls(), vec![dup.to_string()]); + } + + #[test] + fn extract_image_urls_skips_empty_url_in_info_array() { + let m = InboundMessage { + msg_body: vec![image_elem(&[""], None)], + ..Default::default() + }; + assert!(m.extract_image_urls().is_empty()); + } + + #[test] + fn extract_image_urls_ignores_text_elements() { + let m = InboundMessage { + msg_body: vec![text_elem("hi"), image_elem(&["https://a/1.png"], None)], + ..Default::default() + }; + assert_eq!(m.extract_image_urls(), vec!["https://a/1.png".to_string()]); + } + + #[test] + fn source_reply_target_dm_is_raw_uid() { + let s = Source { + from_account: "uid_alice".into(), + is_group: false, + ..Default::default() + }; + assert_eq!(s.reply_target(), "uid_alice"); + } + + #[test] + fn source_reply_target_group_uses_g_prefix() { + let s = Source { + group_code: "grp_42".into(), + is_group: true, + ..Default::default() + }; + assert_eq!(s.reply_target(), "g:grp_42"); + } + + #[test] + fn message_kind_default_is_text() { + assert_eq!(MessageKind::default(), MessageKind::Text); + } +} + /// Group metadata returned by `QueryGroupInfo`. #[derive(Debug, Clone, Default, PartialEq)] pub struct GroupInfo { diff --git a/src/openhuman/channels/providers/yuanbao/wire.rs b/src/openhuman/channels/providers/yuanbao/wire.rs index d7137c20cd..98d24f5a16 100644 --- a/src/openhuman/channels/providers/yuanbao/wire.rs +++ b/src/openhuman/channels/providers/yuanbao/wire.rs @@ -229,4 +229,127 @@ mod tests { let b = next_seq_no(); assert!(b > a); } + + #[test] + fn varint_too_long_errors() { + // 11 continuation bytes overflows the 64-bit shift guard. + let buf = vec![0x80; 11]; + match decode_varint(&buf, 0).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("too long"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_bytes_field_errors() { + // Field 1 (wire type 2) declaring length 5 but only 1 byte of payload. + let buf = vec![ + (1 << 3) | 2, // tag: field=1, wire=2 + 5, // claimed len + b'a', + ]; + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("truncated"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_reads_fixed64() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 1); // tag: field=1, wire=1 (fixed64) + buf.extend_from_slice(&0x1122_3344_5566_7788u64.to_le_bytes()); + let f = parse_fields(&buf).unwrap(); + match f[0].1 { + FieldValue::Fixed64(v) => assert_eq!(v, 0x1122_3344_5566_7788), + ref other => panic!("expected Fixed64 got {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_fixed64_errors() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 1); + buf.extend_from_slice(&[0u8; 4]); // only 4/8 bytes + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("fixed64"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_reads_fixed32() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 5); // tag: field=1, wire=5 (fixed32) + buf.extend_from_slice(&0xCAFEBABEu32.to_le_bytes()); + let f = parse_fields(&buf).unwrap(); + match f[0].1 { + FieldValue::Fixed32(v) => assert_eq!(v, 0xCAFEBABE), + ref other => panic!("expected Fixed32 got {other:?}"), + } + } + + #[test] + fn parse_fields_truncated_fixed32_errors() { + let mut buf = Vec::new(); + buf.push((1 << 3) | 5); + buf.extend_from_slice(&[0u8; 2]); // only 2/4 bytes + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => assert!(m.contains("fixed32"), "got {m}"), + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn parse_fields_unsupported_wire_type_errors() { + // wire type 3 (start group) is not supported. + let buf = vec![(1 << 3) | 3]; + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!(m.contains("unsupported wire type 3"), "got {m}") + } + other => panic!("unexpected {other:?}"), + } + } + + #[test] + fn get_string_returns_empty_when_field_is_varint() { + // Field 1 exists but encoded as varint, not bytes — get_string must + // skip past it and return the default. + let mut buf = Vec::new(); + encode_field_varint(1, 7, &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_string(&fields, 1), ""); + } + + #[test] + fn get_varint_returns_zero_when_field_is_bytes() { + let mut buf = Vec::new(); + encode_field_string(1, "not a varint", &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert_eq!(get_varint(&fields, 1), 0); + } + + #[test] + fn get_bytes_returns_empty_when_field_is_varint() { + let mut buf = Vec::new(); + encode_field_varint(1, 7, &mut buf); + let fields = parse_fields(&buf).unwrap(); + assert!(get_bytes(&fields, 1).is_empty()); + } + + #[test] + fn get_repeated_bytes_collects_multiple_same_field() { + let mut buf = Vec::new(); + encode_field_string(1, "a", &mut buf); + encode_field_string(1, "bb", &mut buf); + encode_field_string(2, "c", &mut buf); // different field — should be skipped + encode_field_string(1, "ddd", &mut buf); + let fields = parse_fields(&buf).unwrap(); + let got = get_repeated_bytes(&fields, 1); + assert_eq!(got.len(), 3); + assert_eq!(got[0], b"a"); + assert_eq!(got[1], b"bb"); + assert_eq!(got[2], b"ddd"); + } } From 147edecad2c9868c14644e150c099d1f1cd6101a Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sat, 23 May 2026 23:27:46 +0800 Subject: [PATCH 3/8] fix(channels/yuanbao): address CodeRabbit PR #2494 review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolve all 9 actionable comments from the CodeRabbit review. Frontend: - ChannelSetupModal: render branded YuanbaoIcon for the yuanbao channel instead of falling through CHANNEL_ICONS to an empty emoji. - YuanbaoConfig: route required-field validation and the connecting spinner copy through i18n; branch on connectChannel's reported status so non-"connected" results surface as errors instead of being silently treated as success. - i18n: add 3 new keys (fieldRequired / connecting / unexpectedStatus) to en, ko, and all chunk-3 files (zh-CN translated; other locales use English placeholders to keep `pnpm i18n:check` green). Backend (Rust): - ops.rs: stop persisting `app_secret` in plaintext config.toml — the encrypted credentials store already holds it. Also persist the optional endpoint overrides (env / api_domain / ws_domain / route_env) so a non-default cluster selection survives restart. - startup.rs: extract `resolve_yuanbao_app_secret` and load the secret from the credentials store at startup when TOML is empty. Pre-existing TOML values still win so manually-installed deployments don't break. - channel.rs: drop the hex preview of inbound biz payloads from the pipeline-error log path; user content / PII must not leak to logs. - connection.rs: re-check `*shutdown.borrow()` after connect_once returns. Without this, a shutdown signal observed inside connect_once consumes the `changed()` notification, leaving the outer `tokio::select!` to wait through the full reconnect backoff before exiting. - proto_biz.rs: replace `as i32` / `as u32` casts on decoded varints with `varint_to_i32` / `varint_to_u32` helpers backed by `try_from`, so oversized fields surface a structured `ProtoDecode` error instead of silently truncating `code`, `member_count`, `role`, `join_time`, and `next_offset`. - wire.rs: guard the `WT_LEN` arithmetic with `usize::try_from(len)` and `pos.checked_add(...)`, eliminating the slice-panic path on adversarial length-delimited fields. Tests (+6 net): - proto_biz: `decode_biz_rsp_code_rejects_varint_out_of_i32_range`, `decode_group_member_list_rejects_varint_out_of_u32_range`. - wire: `parse_fields_oversize_len_field_errors_without_panic`. - ops: updated `connect_yuanbao_persists_when_credentials_valid` to assert TOML has no plaintext secret and the store has both fields; new `connect_yuanbao_persists_env_override` covering the endpoint round-trip. - startup: three tests around `resolve_yuanbao_app_secret` (load from credentials, prefer existing TOML, gracefully return empty). - connection: `run_exits_promptly_after_shutdown_signal` regression guard against backoff-blocked shutdown. cargo test --lib -- yuanbao: 188 passed (was 182). pnpm i18n:check: green. --- .../components/channels/ChannelSetupModal.tsx | 11 +- app/src/components/channels/YuanbaoConfig.tsx | 24 +++- app/src/lib/i18n/chunks/ar-3.ts | 3 + app/src/lib/i18n/chunks/bn-3.ts | 3 + app/src/lib/i18n/chunks/de-3.ts | 3 + app/src/lib/i18n/chunks/en-3.ts | 3 + app/src/lib/i18n/chunks/es-3.ts | 3 + app/src/lib/i18n/chunks/fr-3.ts | 3 + app/src/lib/i18n/chunks/hi-3.ts | 3 + app/src/lib/i18n/chunks/id-3.ts | 3 + app/src/lib/i18n/chunks/it-3.ts | 3 + app/src/lib/i18n/chunks/ko-3.ts | 3 + app/src/lib/i18n/chunks/pt-3.ts | 3 + app/src/lib/i18n/chunks/ru-3.ts | 3 + app/src/lib/i18n/chunks/zh-CN-3.ts | 3 + app/src/lib/i18n/en.ts | 3 + app/src/lib/i18n/ko.ts | 3 + src/openhuman/channels/controllers/ops.rs | 42 ++++++- .../channels/controllers/ops_tests.rs | 72 +++++++++++- .../channels/providers/yuanbao/channel.rs | 18 +-- .../channels/providers/yuanbao/connection.rs | 52 ++++++++- .../channels/providers/yuanbao/proto_biz.rs | 82 +++++++++++--- .../channels/providers/yuanbao/wire.rs | 34 +++++- src/openhuman/channels/runtime/startup.rs | 103 +++++++++++++++++- 24 files changed, 442 insertions(+), 41 deletions(-) diff --git a/app/src/components/channels/ChannelSetupModal.tsx b/app/src/components/channels/ChannelSetupModal.tsx index c3b2502294..38c58baafe 100644 --- a/app/src/components/channels/ChannelSetupModal.tsx +++ b/app/src/components/channels/ChannelSetupModal.tsx @@ -10,7 +10,11 @@ import type { ChannelDefinition, ChannelType } from '../../types/channels'; import DiscordConfig from './DiscordConfig'; import TelegramConfig from './TelegramConfig'; import YuanbaoConfig from './YuanbaoConfig'; +import YuanbaoIcon from './YuanbaoIcon'; +// Emoji icons for channels rendered as plain text. `yuanbao` is handled +// separately with a branded SVG (see `YuanbaoIcon`) — matches the +// rendering used in `ChannelSelector`. const CHANNEL_ICONS: Record = { telegram: '\u2708\uFE0F', discord: '\uD83C\uDFAE', @@ -66,6 +70,7 @@ export default function ChannelSetupModal({ definition, onClose }: ChannelSetupM }; const emojiIcon = CHANNEL_ICONS[definition.icon] ?? ''; + const isYuanbao = definition.icon === 'yuanbao'; const modalContent = (
- {emojiIcon && {emojiIcon}} + {isYuanbao ? ( + + ) : ( + emojiIcon && {emojiIcon} + )}

diff --git a/app/src/components/channels/YuanbaoConfig.tsx b/app/src/components/channels/YuanbaoConfig.tsx index 33b78f8986..3ba1f4198f 100644 --- a/app/src/components/channels/YuanbaoConfig.tsx +++ b/app/src/components/channels/YuanbaoConfig.tsx @@ -75,7 +75,7 @@ const YuanbaoConfig = ({ definition }: YuanbaoConfigProps) => { for (const field of spec.fields) { const empty = !fieldValues[field.key]?.trim(); if (field.required && empty) { - errors[field.key] = `${field.label} 不能为空`; + errors[field.key] = t('channels.yuanbao.fieldRequired').replace('{field}', field.label); } } if (Object.keys(errors).length > 0) { @@ -116,6 +116,26 @@ const YuanbaoConfig = ({ definition }: YuanbaoConfigProps) => { console.log('[YuanbaoConfig] handleConnect: 5.RPC returned', result); log('connect result: %o', result); + // Only treat explicit "connected" as success. Any other status + // (e.g. "pending_auth" if a future auth flow gets added) must + // surface as an error instead of silently dispatching connected. + if (result.status !== 'connected') { + const msg = t('channels.yuanbao.unexpectedStatus').replace( + '{status}', + result.status ?? '' + ); + console.warn('[YuanbaoConfig] handleConnect: 6.unexpected status', result.status); + dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: spec.mode, + status: 'error', + lastError: msg, + }) + ); + return; + } + if (result.restart_required) { console.log( '[YuanbaoConfig] handleConnect: 6.restart_required=true, calling restartCoreProcess' @@ -275,7 +295,7 @@ const YuanbaoConfig = ({ definition }: YuanbaoConfigProps) => { )} {busy - ? '连接中…' + ? t('channels.yuanbao.connecting') : status === 'connected' ? t('channels.telegram.reconnect') : t('channels.telegram.connect')} diff --git a/app/src/lib/i18n/chunks/ar-3.ts b/app/src/lib/i18n/chunks/ar-3.ts index 1d22c8ed93..abe653abf4 100644 --- a/app/src/lib/i18n/chunks/ar-3.ts +++ b/app/src/lib/i18n/chunks/ar-3.ts @@ -402,6 +402,9 @@ const ar3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ar3; diff --git a/app/src/lib/i18n/chunks/bn-3.ts b/app/src/lib/i18n/chunks/bn-3.ts index 711bf29caa..9a98a3e7a5 100644 --- a/app/src/lib/i18n/chunks/bn-3.ts +++ b/app/src/lib/i18n/chunks/bn-3.ts @@ -405,6 +405,9 @@ const bn3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default bn3; diff --git a/app/src/lib/i18n/chunks/de-3.ts b/app/src/lib/i18n/chunks/de-3.ts index 2a7ff892f9..323541840d 100644 --- a/app/src/lib/i18n/chunks/de-3.ts +++ b/app/src/lib/i18n/chunks/de-3.ts @@ -417,6 +417,9 @@ const de3: TranslationMap = { 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', 'welcome.localSessionStarting': 'Starting local session...', 'welcome.localSessionDesc': 'Uses an offline local profile and skips TinyHumans OAuth.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default de3; diff --git a/app/src/lib/i18n/chunks/en-3.ts b/app/src/lib/i18n/chunks/en-3.ts index c71ecc7e58..679cade521 100644 --- a/app/src/lib/i18n/chunks/en-3.ts +++ b/app/src/lib/i18n/chunks/en-3.ts @@ -405,6 +405,9 @@ const en3: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default en3; diff --git a/app/src/lib/i18n/chunks/es-3.ts b/app/src/lib/i18n/chunks/es-3.ts index bcea0ebba7..e344eee08d 100644 --- a/app/src/lib/i18n/chunks/es-3.ts +++ b/app/src/lib/i18n/chunks/es-3.ts @@ -410,6 +410,9 @@ const es3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default es3; diff --git a/app/src/lib/i18n/chunks/fr-3.ts b/app/src/lib/i18n/chunks/fr-3.ts index 6f0a77ce38..16e6c60b40 100644 --- a/app/src/lib/i18n/chunks/fr-3.ts +++ b/app/src/lib/i18n/chunks/fr-3.ts @@ -411,6 +411,9 @@ const fr3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default fr3; diff --git a/app/src/lib/i18n/chunks/hi-3.ts b/app/src/lib/i18n/chunks/hi-3.ts index 3f11035bd2..7cd7d9bec7 100644 --- a/app/src/lib/i18n/chunks/hi-3.ts +++ b/app/src/lib/i18n/chunks/hi-3.ts @@ -407,6 +407,9 @@ const hi3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default hi3; diff --git a/app/src/lib/i18n/chunks/id-3.ts b/app/src/lib/i18n/chunks/id-3.ts index 923ce7e57e..2924d625c3 100644 --- a/app/src/lib/i18n/chunks/id-3.ts +++ b/app/src/lib/i18n/chunks/id-3.ts @@ -410,6 +410,9 @@ const id3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default id3; diff --git a/app/src/lib/i18n/chunks/it-3.ts b/app/src/lib/i18n/chunks/it-3.ts index de00a982f0..67b1fe54ec 100644 --- a/app/src/lib/i18n/chunks/it-3.ts +++ b/app/src/lib/i18n/chunks/it-3.ts @@ -410,6 +410,9 @@ const it3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default it3; diff --git a/app/src/lib/i18n/chunks/ko-3.ts b/app/src/lib/i18n/chunks/ko-3.ts index 774500ecae..7a5bfd2e73 100644 --- a/app/src/lib/i18n/chunks/ko-3.ts +++ b/app/src/lib/i18n/chunks/ko-3.ts @@ -407,5 +407,8 @@ const ko3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ko3; diff --git a/app/src/lib/i18n/chunks/pt-3.ts b/app/src/lib/i18n/chunks/pt-3.ts index 9cdd6fd415..4cf5bf3bce 100644 --- a/app/src/lib/i18n/chunks/pt-3.ts +++ b/app/src/lib/i18n/chunks/pt-3.ts @@ -409,6 +409,9 @@ const pt3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default pt3; diff --git a/app/src/lib/i18n/chunks/ru-3.ts b/app/src/lib/i18n/chunks/ru-3.ts index 635fca70a3..bba9166436 100644 --- a/app/src/lib/i18n/chunks/ru-3.ts +++ b/app/src/lib/i18n/chunks/ru-3.ts @@ -406,6 +406,9 @@ const ru3: TranslationMap = { 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', }; export default ru3; diff --git a/app/src/lib/i18n/chunks/zh-CN-3.ts b/app/src/lib/i18n/chunks/zh-CN-3.ts index a417beb005..db5c2bb89c 100644 --- a/app/src/lib/i18n/chunks/zh-CN-3.ts +++ b/app/src/lib/i18n/chunks/zh-CN-3.ts @@ -400,6 +400,9 @@ const zhCN3: TranslationMap = { 'channels.web.description': '通过内置的 Web UI 聊天。', 'channels.web.authMode.managed_dm.description': '使用嵌入式 Web 聊天 — 无需设置。', 'welcome.continueLocallyExperimental': 'Continue Locally (Experimental)', + 'channels.yuanbao.connecting': '连接中…', + 'channels.yuanbao.fieldRequired': '{field} 不能为空', + 'channels.yuanbao.unexpectedStatus': '意外的连接状态:{status}', }; export default zhCN3; diff --git a/app/src/lib/i18n/en.ts b/app/src/lib/i18n/en.ts index e2659d8c1d..80d227edc7 100644 --- a/app/src/lib/i18n/en.ts +++ b/app/src/lib/i18n/en.ts @@ -2019,6 +2019,9 @@ const en: TranslationMap = { 'channels.web.displayName': 'Web', 'channels.web.description': 'Chat via the built-in web UI.', 'channels.web.authMode.managed_dm.description': 'Use the embedded web chat — no setup required.', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', 'chat.unsubscribeApproval.approve': 'Approve & Unsubscribe', 'chat.unsubscribeApproval.approved': '✓ Successfully unsubscribed.', 'chat.unsubscribeApproval.denied': '✕ Request denied.', diff --git a/app/src/lib/i18n/ko.ts b/app/src/lib/i18n/ko.ts index a11c0e30d1..78a01c09f2 100644 --- a/app/src/lib/i18n/ko.ts +++ b/app/src/lib/i18n/ko.ts @@ -1301,6 +1301,9 @@ const ko: TranslationMap = { 'channels.telegram.savedRestartRequired': '채널이 저장되었습니다. 활성화하려면 앱을 다시 시작하세요.', 'channels.web.alwaysAvailable': '항상 사용 가능', + 'channels.yuanbao.connecting': 'Connecting…', + 'channels.yuanbao.fieldRequired': '{field} is required', + 'channels.yuanbao.unexpectedStatus': 'Unexpected connection status: {status}', 'chat.unsubscribeApproval.approve': '승인 및 구독 취소', 'chat.unsubscribeApproval.approved': '✓ 구독 취소가 완료되었습니다.', 'chat.unsubscribeApproval.denied': '✕ 요청이 거부되었습니다.', diff --git a/src/openhuman/channels/controllers/ops.rs b/src/openhuman/channels/controllers/ops.rs index 805f79b5c0..55874d70f9 100644 --- a/src/openhuman/channels/controllers/ops.rs +++ b/src/openhuman/channels/controllers/ops.rs @@ -388,13 +388,32 @@ pub async fn connect_channel( .filter(|s| !s.is_empty()) .ok_or_else(|| "missing required app_key".to_string())? .to_string(); - let app_secret = creds_map + // `app_secret` is already in the encrypted credentials store + // (stored above via `store_provider_credentials`); we intentionally + // do NOT mirror it into the plaintext TOML to limit exposure + // surface. The runtime loads it from credentials at startup. + let _ = creds_map .get("app_secret") .and_then(|v| v.as_str()) .map(str::trim) .filter(|s| !s.is_empty()) - .ok_or_else(|| "missing required app_secret".to_string())? - .to_string(); + .ok_or_else(|| "missing required app_secret".to_string())?; + + // Optional endpoint overrides — preserve any non-default values + // submitted by the client (e.g. `env = "pre"`) so the runtime + // reconnects to the correct cluster after restart. + let opt_string = |key: &str| -> Option { + creds_map + .get(key) + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + }; + let env_override = opt_string("env"); + let api_domain_override = opt_string("api_domain"); + let ws_domain_override = opt_string("ws_domain"); + let route_env_override = opt_string("route_env"); let mut persisted = config.clone(); let mut yb_config = persisted @@ -403,7 +422,20 @@ pub async fn connect_channel( .clone() .unwrap_or_default(); yb_config.app_key = app_key; - yb_config.app_secret = app_secret; + // Clear any stale plaintext secret from a previous version. + yb_config.app_secret = String::new(); + if let Some(env) = env_override { + yb_config.env = env; + } + if let Some(api_domain) = api_domain_override { + yb_config.api_domain = api_domain; + } + if let Some(ws_domain) = ws_domain_override { + yb_config.ws_domain = ws_domain; + } + if let Some(route_env) = route_env_override { + yb_config.route_env = route_env; + } persisted.channels_config.yuanbao = Some(yb_config); persisted @@ -413,7 +445,7 @@ pub async fn connect_channel( tracing::info!( target: "openhuman::channels", - "[yuanbao] connect_channel: wrote channels_config.yuanbao; restart core for WS listener" + "[yuanbao] connect_channel: wrote channels_config.yuanbao (secret stored in credentials); restart core for WS listener" ); } diff --git a/src/openhuman/channels/controllers/ops_tests.rs b/src/openhuman/channels/controllers/ops_tests.rs index 3d2d055c6c..825ca7747f 100644 --- a/src/openhuman/channels/controllers/ops_tests.rs +++ b/src/openhuman/channels/controllers/ops_tests.rs @@ -598,8 +598,78 @@ async fn connect_yuanbao_persists_when_credentials_valid() { yb.get("app_key").and_then(toml::Value::as_str), Some("real-key") ); + // The plaintext `app_secret` must NOT be persisted in TOML — the + // runtime loads it from the encrypted credentials store instead. + let toml_secret = yb.get("app_secret").and_then(toml::Value::as_str); + assert!( + toml_secret.is_none() || toml_secret == Some(""), + "app_secret must not be persisted in plaintext TOML, got {toml_secret:?}" + ); + + // The credentials store should contain the secret so startup can recover it. + let auth = crate::openhuman::credentials::AuthService::from_config(&config); + let profile = auth + .get_profile("channel:yuanbao:api_key", None) + .expect("credentials lookup succeeds") + .expect("yuanbao credentials stored"); assert_eq!( - yb.get("app_secret").and_then(toml::Value::as_str), + profile.metadata.get("app_secret").map(String::as_str), Some("real-secret") ); + assert_eq!( + profile.metadata.get("app_key").map(String::as_str), + Some("real-key") + ); +} + +#[tokio::test] +async fn connect_yuanbao_persists_env_override() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-pre", + "bot_id": "bot-456", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, config) = yuanbao_test_config(&server.uri()); + connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ + "app_key": "k", + "app_secret": "s", + "env": "pre", + "route_env": "canary", + }), + ) + .await + .expect("valid yuanbao credentials should succeed"); + + let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!(yb.get("env").and_then(toml::Value::as_str), Some("pre")); + assert_eq!( + yb.get("route_env").and_then(toml::Value::as_str), + Some("canary") + ); } diff --git a/src/openhuman/channels/providers/yuanbao/channel.rs b/src/openhuman/channels/providers/yuanbao/channel.rs index 2baeaa8646..b597dfe79b 100644 --- a/src/openhuman/channels/providers/yuanbao/channel.rs +++ b/src/openhuman/channels/providers/yuanbao/channel.rs @@ -405,15 +405,12 @@ impl YuanbaoChannel { tracing::trace!("[yuanbao] filtered at {reason}"); } PipelineOutcome::Failed(err) => { - let preview_len = biz_body.len().min(256); - let hex: String = biz_body[..preview_len] - .iter() - .map(|b| format!("{b:02x}")) - .collect(); + // Intentionally omit the raw biz payload — it can carry + // user content / PII. The decoder error already encodes + // the structural reason; only the length is safe to log. warn!( - "[yuanbao] pipeline error: {err} | biz_len={} biz_hex_first_{preview_len}={}", - biz_body.len(), - hex + "[yuanbao] pipeline error: {err} | biz_len={}", + biz_body.len() ); } } @@ -693,10 +690,7 @@ mod tests { let ch = YuanbaoChannel::new(good_cfg()).unwrap(); ch.start_heartbeat_task("recipient-1").await; assert!( - ch.heartbeat_tasks - .lock() - .await - .contains_key("recipient-1"), + ch.heartbeat_tasks.lock().await.contains_key("recipient-1"), "should have spawned a task for recipient-1" ); // Second start for same recipient is a no-op (does not double-spawn). diff --git a/src/openhuman/channels/providers/yuanbao/connection.rs b/src/openhuman/channels/providers/yuanbao/connection.rs index 480581cbaa..c906533656 100644 --- a/src/openhuman/channels/providers/yuanbao/connection.rs +++ b/src/openhuman/channels/providers/yuanbao/connection.rs @@ -218,6 +218,15 @@ impl YuanbaoConnection { *self.sender.lock().await = None; self.pending.lock().clear(); + // `connect_once` may have returned because shutdown fired inside + // its read loop. In that case we must not sleep through the + // reconnect backoff — exit immediately so stop is responsive. + if *shutdown.borrow() { + info!("[yuanbao] shutdown signaled, stopping connection loop"); + self.shutdown().await; + return; + } + attempt += 1; let delay = backoff_seconds(attempt); info!( @@ -709,7 +718,48 @@ mod tests { let msg = Message::Binary(vec![0xFF, 0xFF, 0xFF, 0xFF]); let err = conn.handle_auth_response(&msg).unwrap_err(); // Either Proto decode error or some other surface — must not be Ok. - assert!(!matches!(err, YuanbaoError::AuthFailed(_) if format!("{err:?}").contains("binary"))); + assert!( + !matches!(err, YuanbaoError::AuthFailed(_) if format!("{err:?}").contains("binary")) + ); + } + + /// Regression guard for the post-`connect_once` shutdown short-circuit: + /// once shutdown is signaled, `run()` must not block on the reconnect + /// backoff. We force connect_once to fail synchronously (invalid WS URL), + /// then signal shutdown — total runtime must be well under the first + /// backoff slot (`backoff_seconds(1) == 1s`). + #[tokio::test] + async fn run_exits_promptly_after_shutdown_signal() { + use std::time::Instant; + let (tx, _rx) = mpsc::unbounded_channel(); + let mut c = cfg(); + // tokio-tungstenite rejects the URL synchronously — connect_once + // returns Err in microseconds, putting `run()` on the post-connect + // cleanup path that the fix targets. + c.ws_domain = "not-a-valid-ws-url".to_string(); + c.max_reconnect_attempts = 100; + let conn = YuanbaoConnection::new(c, tx, None); + let (sd_tx, sd_rx) = watch::channel(false); + + let handle = tokio::spawn(conn.clone().run(sd_rx)); + // Let `run()` enter the loop and attempt connect_once at least once. + time::sleep(Duration::from_millis(20)).await; + + let started = Instant::now(); + sd_tx.send(true).unwrap(); + + // The first reconnect backoff slot is 1s. Without responsive + // shutdown handling, run() would sleep through it before checking + // the flag. 500ms gives us comfortable headroom while staying + // far enough below the backoff to detect a regression. + let res = time::timeout(Duration::from_millis(500), handle).await; + res.expect("run() did not exit within 500ms of shutdown signal") + .expect("run() task panicked"); + assert!( + started.elapsed() < Duration::from_millis(500), + "run() took {:?} to exit after shutdown — backoff was not skipped", + started.elapsed() + ); } #[tokio::test] diff --git a/src/openhuman/channels/providers/yuanbao/proto_biz.rs b/src/openhuman/channels/providers/yuanbao/proto_biz.rs index 4cb620dda3..daf91f517d 100644 --- a/src/openhuman/channels/providers/yuanbao/proto_biz.rs +++ b/src/openhuman/channels/providers/yuanbao/proto_biz.rs @@ -241,10 +241,24 @@ pub fn encode_query_group_info(req_id: &str, group_code: &str) -> Vec { ) } +/// Try to narrow a varint into a smaller integer type, returning +/// `YuanbaoError::ProtoDecode` (instead of silently truncating) when +/// the upstream value is out of range. Used to harden response decoders +/// against malformed / adversarial input. +fn varint_to_i32(value: u64, field_label: &str) -> Result { + i32::try_from(value) + .map_err(|_| YuanbaoError::ProtoDecode(format!("{field_label} out of i32 range: {value}"))) +} + +fn varint_to_u32(value: u64, field_label: &str) -> Result { + u32::try_from(value) + .map_err(|_| YuanbaoError::ProtoDecode(format!("{field_label} out of u32 range: {value}"))) +} + pub fn decode_query_group_info_rsp(data: &[u8]) -> Result { let fields = parse_fields(data)?; let mut info = GroupInfo { - code: get_varint(&fields, 1) as i32, + code: varint_to_i32(get_varint(&fields, 1), "GroupInfoRsp.code")?, message: get_string(&fields, 2), ..Default::default() }; @@ -254,7 +268,7 @@ pub fn decode_query_group_info_rsp(data: &[u8]) -> Result Result Result Result<(i32, String), YuanbaoError> { let fields = parse_fields(data)?; - Ok((get_varint(&fields, 1) as i32, get_string(&fields, 2))) + Ok(( + varint_to_i32(get_varint(&fields, 1), "BizRsp.code")?, + get_string(&fields, 2), + )) } /// Decode a `ConnMsg` and return the typed biz response code + frame for @@ -469,16 +486,8 @@ mod tests { #[test] fn group_encode_generates_synthetic_req_id_when_msg_id_empty() { - let buf = encode_send_group_message( - "group_x", - "uid_bot", - &text_body("hi"), - "", - "", - "", - "", - "", - ); + let buf = + encode_send_group_message("group_x", "uid_bot", &text_body("hi"), "", "", "", "", ""); let frame = decode_conn_msg(&buf).unwrap(); assert!( frame.msg_id.starts_with("grp_"), @@ -587,4 +596,43 @@ mod tests { assert_eq!(page.next_offset, 100); assert!(page.is_complete); } + + /// Adversarial input: a varint that overflows i32. The decoder must + /// surface `YuanbaoError::ProtoDecode` instead of silently truncating + /// (which would corrupt the `code` field returned to callers). + #[test] + fn decode_biz_rsp_code_rejects_varint_out_of_i32_range() { + let mut buf = Vec::new(); + put_varint_field(1, u64::MAX, &mut buf); + put_string_field(2, "ok", &mut buf); + match decode_biz_rsp_code(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("out of i32 range"), + "expected i32 overflow message, got: {m}" + ); + } + other => panic!("expected ProtoDecode, got {other:?}"), + } + } + + /// Same guard applied to the group-member-list `next_offset` field — + /// an oversized varint must produce a structured decode error, not a + /// silent `as u32` wrap that would mis-paginate subsequent fetches. + #[test] + fn decode_group_member_list_rejects_varint_out_of_u32_range() { + let mut rsp = Vec::new(); + put_varint_field(1, 0, &mut rsp); + put_string_field(2, "ok", &mut rsp); + put_varint_field(4, u64::from(u32::MAX) + 1, &mut rsp); + match decode_get_group_member_list_rsp(&rsp).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("out of u32 range"), + "expected u32 overflow message, got: {m}" + ); + } + other => panic!("expected ProtoDecode, got {other:?}"), + } + } } diff --git a/src/openhuman/channels/providers/yuanbao/wire.rs b/src/openhuman/channels/providers/yuanbao/wire.rs index 98d24f5a16..6c5582c8c6 100644 --- a/src/openhuman/channels/providers/yuanbao/wire.rs +++ b/src/openhuman/channels/providers/yuanbao/wire.rs @@ -100,7 +100,19 @@ pub fn parse_fields(data: &[u8]) -> Result, YuanbaoError> WT_LEN => { let (len, n) = decode_varint(data, pos)?; pos += n; - let end = pos + len as usize; + // Use checked conversions / arithmetic — a crafted oversize + // varint length would otherwise overflow `usize` on 32-bit + // targets and panic during slicing. + let len_usize = usize::try_from(len).map_err(|_| { + YuanbaoError::ProtoDecode(format!( + "len field {field} too large for platform: {len}" + )) + })?; + let end = pos.checked_add(len_usize).ok_or_else(|| { + YuanbaoError::ProtoDecode(format!( + "len field {field} overflows position: pos={pos} len={len}" + )) + })?; if end > data.len() { return Err(YuanbaoError::ProtoDecode(format!( "truncated len field {field}: need {len} have {}", @@ -254,6 +266,26 @@ mod tests { } } + #[test] + fn parse_fields_oversize_len_field_errors_without_panic() { + // Field 1 (wire type 2) with a varint length encoding `u64::MAX` — + // previously this would attempt `pos + len as usize`, overflowing + // on 32-bit and slicing past the buffer on 64-bit. Now it must + // return a structured decode error. + let mut buf = Vec::new(); + buf.push((1 << 3) | 2); // tag: field=1, wire=2 + encode_varint(u64::MAX, &mut buf); // adversarial length + match parse_fields(&buf).unwrap_err() { + YuanbaoError::ProtoDecode(m) => { + assert!( + m.contains("too large") || m.contains("overflows") || m.contains("truncated"), + "expected overflow/truncation error, got {m}" + ); + } + other => panic!("unexpected {other:?}"), + } + } + #[test] fn parse_fields_reads_fixed64() { let mut buf = Vec::new(); diff --git a/src/openhuman/channels/runtime/startup.rs b/src/openhuman/channels/runtime/startup.rs index 965e119db4..f108593462 100644 --- a/src/openhuman/channels/runtime/startup.rs +++ b/src/openhuman/channels/runtime/startup.rs @@ -502,7 +502,8 @@ pub async fn start_channels(config: Config) -> Result<()> { } if let Some(ref yb) = config.channels_config.yuanbao { - match YuanbaoChannel::new(yb.clone()) { + let yb_cfg = resolve_yuanbao_app_secret(yb.clone(), &config); + match YuanbaoChannel::new(yb_cfg) { Ok(ch) => channels.push(Arc::new(ch)), Err(e) => tracing::warn!("[channels] yuanbao config invalid: {e}"), } @@ -643,3 +644,103 @@ pub async fn start_channels(config: Config) -> Result<()> { Ok(()) } + +/// Best-effort fill of `yb_cfg.app_secret` from the encrypted credentials +/// store when TOML doesn't already carry one. +/// +/// `app_secret` is intentionally not persisted in `config.toml` (see the +/// `yuanbao` branch in `controllers/ops.rs`). Existing TOML values still +/// win so manually-installed deployments don't break. Returns the +/// (possibly-modified) config; logging is the only side effect on failure. +fn resolve_yuanbao_app_secret( + mut yb_cfg: crate::openhuman::channels::providers::yuanbao::YuanbaoConfig, + config: &Config, +) -> crate::openhuman::channels::providers::yuanbao::YuanbaoConfig { + if !yb_cfg.app_secret.is_empty() { + return yb_cfg; + } + let auth = crate::openhuman::credentials::AuthService::from_config(config); + match auth.get_profile("channel:yuanbao:api_key", None) { + Ok(Some(profile)) => { + if let Some(secret) = profile.metadata.get("app_secret") { + yb_cfg.app_secret = secret.clone(); + } + } + Ok(None) => { + tracing::warn!( + "[channels] yuanbao credentials missing — connect the channel again from the UI" + ); + } + Err(e) => { + tracing::warn!("[channels] failed to load yuanbao credentials: {e}"); + } + } + yb_cfg +} + +#[cfg(test)] +mod yuanbao_secret_tests { + use super::*; + use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; + use crate::openhuman::credentials::AuthService; + use std::collections::HashMap; + use tempfile::tempdir; + + fn isolated_config() -> (tempfile::TempDir, Config) { + let tmp = tempdir().expect("tempdir"); + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).expect("workspace dir"); + (tmp, config) + } + + #[test] + fn loads_app_secret_from_credentials_when_toml_empty() { + let (_tmp, config) = isolated_config(); + // Pre-write the credentials the same way `connect_channel` does: + // metadata under the `channel:yuanbao:api_key` provider key. + let auth = AuthService::from_config(&config); + let mut metadata = HashMap::new(); + metadata.insert("app_key".to_string(), "ak".to_string()); + metadata.insert("app_secret".to_string(), "from-credentials".to_string()); + auth.store_provider_token("channel:yuanbao:api_key", "default", "", metadata, true) + .expect("store credentials"); + + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!(resolved.app_secret, "from-credentials"); + } + + #[test] + fn preserves_existing_toml_secret_without_consulting_store() { + // No credentials in the store at all — resolver must still leave + // the TOML-supplied secret untouched. + let (_tmp, config) = isolated_config(); + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: "from-toml".into(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!(resolved.app_secret, "from-toml"); + } + + #[test] + fn returns_empty_secret_when_neither_toml_nor_credentials_have_one() { + let (_tmp, config) = isolated_config(); + let yb = YuanbaoConfig { + app_key: "ak".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + // Surfaces empty so the downstream `YuanbaoChannel::new` validate() + // step can fail clearly, instead of attempting auth with a stale value. + assert_eq!(resolved.app_secret, ""); + } +} From 3407652a438ccfd29db4047b8b65de1741f563ae Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sun, 24 May 2026 00:09:05 +0800 Subject: [PATCH 4/8] chore: apply rustfmt auto-fixes --- .../channels/providers/yuanbao/cos.rs | 19 +++++++---- .../channels/providers/yuanbao/media.rs | 33 +++++++++---------- .../channels/providers/yuanbao/types.rs | 6 +++- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/openhuman/channels/providers/yuanbao/cos.rs b/src/openhuman/channels/providers/yuanbao/cos.rs index 3dc2a9bda1..c9cd05fac8 100644 --- a/src/openhuman/channels/providers/yuanbao/cos.rs +++ b/src/openhuman/channels/providers/yuanbao/cos.rs @@ -438,10 +438,9 @@ mod tests { .mount(&server) .await; let http = reqwest::Client::new(); - let creds = - get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "file.png") - .await - .unwrap(); + let creds = get_cos_credentials(&http, &server.uri(), "appk", "bot", "tok", "", "file.png") + .await + .unwrap(); assert_eq!(creds.bucket, "bkt-1"); assert_eq!(creds.region, "ap-shanghai"); assert_eq!(creds.location, "k/v/file.png"); @@ -536,9 +535,15 @@ mod tests { let http = reqwest::Client::new(); // empty credentials → fail without making any HTTP call let bad = CosCredentials::default(); - let err = upload_to_cos(&http, &bad, b"data", "f.bin", "application/octet-stream".into()) - .await - .unwrap_err(); + let err = upload_to_cos( + &http, + &bad, + b"data", + "f.bin", + "application/octet-stream".into(), + ) + .await + .unwrap_err(); match err { YuanbaoError::Media(m) => assert!(m.contains("credentials missing"), "got {m}"), other => panic!("expected Media error, got {other:?}"), diff --git a/src/openhuman/channels/providers/yuanbao/media.rs b/src/openhuman/channels/providers/yuanbao/media.rs index 266db09bb3..271be114ba 100644 --- a/src/openhuman/channels/providers/yuanbao/media.rs +++ b/src/openhuman/channels/providers/yuanbao/media.rs @@ -365,10 +365,7 @@ mod tests { guess_mime_type("file.xlsx"), "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ); - assert_eq!( - guess_mime_type("file.ppt"), - "application/vnd.ms-powerpoint" - ); + assert_eq!(guess_mime_type("file.ppt"), "application/vnd.ms-powerpoint"); assert_eq!( guess_mime_type("file.pptx"), "application/vnd.openxmlformats-officedocument.presentationml.presentation" @@ -496,8 +493,15 @@ mod tests { #[test] fn build_image_msg_body_falls_back_to_filename_then_default_uuid() { - let with_filename = - build_image_msg_body("https://x/", None, Some("only-name.png"), 0, 0, 0, "image/png"); + let with_filename = build_image_msg_body( + "https://x/", + None, + Some("only-name.png"), + 0, + 0, + 0, + "image/png", + ); assert_eq!( with_filename[0].msg_content.uuid.as_deref(), Some("only-name.png") @@ -537,9 +541,7 @@ mod tests { async fn download_url_returns_bytes_and_content_type() { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("HEAD")) - .respond_with( - wiremock::ResponseTemplate::new(200).insert_header("Content-Length", "3"), - ) + .respond_with(wiremock::ResponseTemplate::new(200).insert_header("Content-Length", "3")) .mount(&server) .await; wiremock::Mock::given(wiremock::matchers::method("GET")) @@ -561,12 +563,10 @@ mod tests { let server = wiremock::MockServer::start().await; // HEAD reports a very large file → reject BEFORE GET. wiremock::Mock::given(wiremock::matchers::method("HEAD")) - .respond_with( - wiremock::ResponseTemplate::new(200).insert_header( - "Content-Length", - (10u64 * 1024 * 1024 + 1).to_string().as_str(), - ), - ) + .respond_with(wiremock::ResponseTemplate::new(200).insert_header( + "Content-Length", + (10u64 * 1024 * 1024 + 1).to_string().as_str(), + )) .mount(&server) .await; let http = reqwest::Client::new(); @@ -586,8 +586,7 @@ mod tests { .await; wiremock::Mock::given(wiremock::matchers::method("GET")) .respond_with( - wiremock::ResponseTemplate::new(200) - .set_body_bytes(vec![0u8; 2 * 1024 * 1024]), // 2 MiB + wiremock::ResponseTemplate::new(200).set_body_bytes(vec![0u8; 2 * 1024 * 1024]), // 2 MiB ) .mount(&server) .await; diff --git a/src/openhuman/channels/providers/yuanbao/types.rs b/src/openhuman/channels/providers/yuanbao/types.rs index 6364f763f5..b2215dc2e7 100644 --- a/src/openhuman/channels/providers/yuanbao/types.rs +++ b/src/openhuman/channels/providers/yuanbao/types.rs @@ -278,7 +278,11 @@ mod tests { #[test] fn extract_text_concatenates_text_elements() { let m = InboundMessage { - msg_body: vec![text_elem("hello"), text_elem("world"), image_elem(&[], None)], + msg_body: vec![ + text_elem("hello"), + text_elem("world"), + image_elem(&[], None), + ], ..Default::default() }; assert_eq!(m.extract_text(), "hello\nworld"); From c468b77fcf7fdb23d4077a0c70f3454c6f7ba716 Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sun, 24 May 2026 00:33:41 +0800 Subject: [PATCH 5/8] fix(channels/yuanbao): address 2nd-round CodeRabbit review on PR #2494 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three CodeRabbit-flagged issues from the follow-up review: 1) ops.rs — apply endpoint overrides BEFORE preflight verification. The previous flow rebuilt YuanbaoConfig from `config.channels_config.yuanbao` alone inside `verify_yuanbao_credentials`, so client-supplied `env` / `api_domain` / `ws_domain` / `route_env` overrides were ignored at verify-time but applied at persistence-time. A user submitting `env = "pre"` could pass verification against PROD's sign-token cluster and then fail auth after restart when the persisted override took effect. The verifier now consumes a pre-built effective `YuanbaoConfig`. A new `build_effective_yuanbao_config` helper overlays the overrides on top of the existing TOML, calls `apply_env_defaults`, and produces the single config used for BOTH verify and persistence — they can no longer diverge. 2) cos.rs — tighten `get_cos_credentials_sends_route_env_header_when_non_empty`. The wiremock matcher only checked for `X-Route-Env: canary`, so a future refactor routing the call to the wrong endpoint would still pass the test as long as some POST carried that header. Bind the matcher to `UPLOAD_INFO_PATH` as well. 3) startup.rs — don't hydrate `app_secret` from a stale profile. `resolve_yuanbao_app_secret` used to copy whatever secret was in the encrypted store, even if `yb_cfg.app_key` had been edited in `config.toml`. That would silently pair a new key with an old secret on next startup and the channel would fail auth until the user reconnected manually. Now compare the stored profile's `app_key` metadata to `yb_cfg.app_key` before copying. On mismatch, log a warning naming both keys and leave `app_secret` empty so `YuanbaoChannel::new`'s `validate()` step fails loudly instead of attempting auth with a stale value. Tests added: - ops_tests: `connect_yuanbao_verifies_against_overridden_api_domain` proves the verifier hits the override URI even when base TOML `api_domain` points at a black hole (`http://127.0.0.1:1`), and that the same override is the one persisted. - startup tests: `skips_hydration_when_stored_profile_has_different_app_key` proves we leave `app_secret` empty when the store profile is keyed to a different `app_key`. cargo test -- yuanbao: 190 passed, 0 failed. --- src/openhuman/channels/controllers/ops.rs | 168 +++++++++--------- .../channels/controllers/ops_tests.rs | 78 ++++++++ .../channels/providers/yuanbao/cos.rs | 4 + src/openhuman/channels/runtime/startup.rs | 44 ++++- 4 files changed, 212 insertions(+), 82 deletions(-) diff --git a/src/openhuman/channels/controllers/ops.rs b/src/openhuman/channels/controllers/ops.rs index 55874d70f9..1df9722463 100644 --- a/src/openhuman/channels/controllers/ops.rs +++ b/src/openhuman/channels/controllers/ops.rs @@ -7,6 +7,7 @@ use crate::api::config::{app_env_from_env, effective_backend_api_url, is_staging use crate::api::jwt::get_session_token; use crate::api::rest::BackendOAuthClient; use crate::openhuman::channels::providers::yuanbao::sign::SignManager; +use crate::openhuman::channels::providers::yuanbao::YuanbaoConfig; use crate::openhuman::config::{Config, DiscordConfig, IMessageConfig, TelegramConfig}; use crate::openhuman::credentials; use crate::rpc::RpcOutcome; @@ -109,40 +110,83 @@ fn parse_optional_bool(value: Option<&Value>) -> Option { } } -/// Verify Yuanbao credentials against the `sign-token` endpoint before any -/// persistence so invalid `app_key` / `app_secret` surface the upstream API -/// error to the user instead of silently succeeding. -/// -/// Honours an explicit `api_domain` already configured in TOML; otherwise -/// derives it from `env` (prod by default). -async fn verify_yuanbao_credentials( - config: &Config, +/// Read a required non-empty Yuanbao credential field from the connect-channel +/// payload. Returns the trimmed value or an error naming the missing field. +fn require_yuanbao_field( creds_map: &serde_json::Map, -) -> Result<(), String> { - let app_key = creds_map - .get("app_key") + key: &str, +) -> Result { + creds_map + .get(key) .and_then(|v| v.as_str()) .map(str::trim) .filter(|s| !s.is_empty()) - .ok_or_else(|| "missing required app_key".to_string())?; - let app_secret = creds_map - .get("app_secret") - .and_then(|v| v.as_str()) - .map(str::trim) - .filter(|s| !s.is_empty()) - .ok_or_else(|| "missing required app_secret".to_string())?; + .map(|s| s.to_string()) + .ok_or_else(|| format!("missing required {key}")) +} - let mut yb_config = config.channels_config.yuanbao.clone().unwrap_or_default(); - if yb_config.api_domain.is_empty() { - yb_config.apply_env_defaults(); +/// Build the **effective** Yuanbao config that will be used for both +/// preflight verification and persistence. +/// +/// Starts from the existing TOML (so manually-installed deployments keep +/// any custom routes), overlays the client-supplied endpoint overrides +/// (`env` / `api_domain` / `ws_domain` / `route_env`), then calls +/// `apply_env_defaults` so the verifier hits the correct cluster — e.g. a +/// user submitting `env = "pre"` is verified against the pre-release +/// sign-token endpoint instead of the default prod one. +/// +/// `app_secret` is intentionally left empty: the runtime loads it from +/// the encrypted credentials store at startup, never from `config.toml`. +fn build_effective_yuanbao_config( + base: YuanbaoConfig, + creds_map: &serde_json::Map, + app_key: String, +) -> YuanbaoConfig { + let opt_string = |key: &str| -> Option { + creds_map + .get(key) + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + }; + + let mut cfg = base; + cfg.app_key = app_key; + cfg.app_secret = String::new(); + if let Some(env) = opt_string("env") { + cfg.env = env; + } + if let Some(api_domain) = opt_string("api_domain") { + cfg.api_domain = api_domain; + } + if let Some(ws_domain) = opt_string("ws_domain") { + cfg.ws_domain = ws_domain; + } + if let Some(route_env) = opt_string("route_env") { + cfg.route_env = route_env; } + cfg.apply_env_defaults(); + cfg +} +/// Verify Yuanbao credentials against the `sign-token` endpoint before any +/// persistence so invalid `app_key` / `app_secret` surface the upstream API +/// error to the user instead of silently succeeding. +/// +/// Takes the **effective** `YuanbaoConfig` already built from the client's +/// overrides + TOML defaults, so the verifier targets whatever cluster the +/// runtime will use after restart. +async fn verify_yuanbao_credentials( + yb_cfg: &YuanbaoConfig, + app_secret: &str, +) -> Result<(), String> { SignManager::new(reqwest::Client::new()) .get_token( - app_key, + &yb_cfg.app_key, app_secret, - &yb_config.api_domain, - &yb_config.route_env, + &yb_cfg.api_domain, + &yb_cfg.route_env, ) .await .map_err(|e| format!("yuanbao credential verification failed: {e}"))?; @@ -201,11 +245,19 @@ pub async fn connect_channel( def.validate_credentials(auth_mode, creds_map)?; - // Yuanbao: verify credentials with the sign-token endpoint before any - // persistence so invalid creds surface the upstream API error to the - // user without leaving dangling credential entries or TOML state. + // Yuanbao: build the effective config (with any client-supplied + // endpoint overrides applied) once, verify against THAT cluster, and + // reuse the same config for persistence below. This prevents the + // verifier from validating against prod while the runtime then + // reconnects to a pre-release cluster after restart. + let mut prebuilt_yuanbao_config: Option = None; if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { - verify_yuanbao_credentials(config, creds_map).await?; + let app_key = require_yuanbao_field(creds_map, "app_key")?; + let app_secret = require_yuanbao_field(creds_map, "app_secret")?; + let base = config.channels_config.yuanbao.clone().unwrap_or_default(); + let effective = build_effective_yuanbao_config(base, creds_map, app_key); + verify_yuanbao_credentials(&effective, &app_secret).await?; + prebuilt_yuanbao_config = Some(effective); } // iMessage is local-only (no credentials): persist channels_config + return connected. @@ -381,61 +433,15 @@ pub async fn connect_channel( "[discord] connect_channel: wrote channels_config.discord; restart core for listener to load token" ); } else if channel_id == "yuanbao" && auth_mode == ChannelAuthMode::ApiKey { - let app_key = creds_map - .get("app_key") - .and_then(|v| v.as_str()) - .map(str::trim) - .filter(|s| !s.is_empty()) - .ok_or_else(|| "missing required app_key".to_string())? - .to_string(); - // `app_secret` is already in the encrypted credentials store - // (stored above via `store_provider_credentials`); we intentionally - // do NOT mirror it into the plaintext TOML to limit exposure - // surface. The runtime loads it from credentials at startup. - let _ = creds_map - .get("app_secret") - .and_then(|v| v.as_str()) - .map(str::trim) - .filter(|s| !s.is_empty()) - .ok_or_else(|| "missing required app_secret".to_string())?; - - // Optional endpoint overrides — preserve any non-default values - // submitted by the client (e.g. `env = "pre"`) so the runtime - // reconnects to the correct cluster after restart. - let opt_string = |key: &str| -> Option { - creds_map - .get(key) - .and_then(|v| v.as_str()) - .map(str::trim) - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()) - }; - let env_override = opt_string("env"); - let api_domain_override = opt_string("api_domain"); - let ws_domain_override = opt_string("ws_domain"); - let route_env_override = opt_string("route_env"); + // Reuse the effective config built above (with `env` / `api_domain` + // / `ws_domain` / `route_env` overrides already applied and + // `app_secret` already cleared) so persistence and verification + // can never diverge. + let yb_config = prebuilt_yuanbao_config + .take() + .expect("yuanbao verify branch must run before persistence"); let mut persisted = config.clone(); - let mut yb_config = persisted - .channels_config - .yuanbao - .clone() - .unwrap_or_default(); - yb_config.app_key = app_key; - // Clear any stale plaintext secret from a previous version. - yb_config.app_secret = String::new(); - if let Some(env) = env_override { - yb_config.env = env; - } - if let Some(api_domain) = api_domain_override { - yb_config.api_domain = api_domain; - } - if let Some(ws_domain) = ws_domain_override { - yb_config.ws_domain = ws_domain; - } - if let Some(route_env) = route_env_override { - yb_config.route_env = route_env; - } persisted.channels_config.yuanbao = Some(yb_config); persisted diff --git a/src/openhuman/channels/controllers/ops_tests.rs b/src/openhuman/channels/controllers/ops_tests.rs index 825ca7747f..446ae5d43a 100644 --- a/src/openhuman/channels/controllers/ops_tests.rs +++ b/src/openhuman/channels/controllers/ops_tests.rs @@ -622,6 +622,84 @@ async fn connect_yuanbao_persists_when_credentials_valid() { ); } +#[tokio::test] +async fn connect_yuanbao_verifies_against_overridden_api_domain() { + // Regression: previously, `verify_yuanbao_credentials` rebuilt the + // YuanbaoConfig from `config.channels_config.yuanbao` alone and + // ignored the `api_domain` / `env` / `route_env` overrides on the + // connect-channel payload. A user submitting `env = "pre"` could + // pass verification against PROD and then fail after restart when + // the persisted override took effect. + // + // Here the base TOML's `api_domain` deliberately points at an + // unreachable URL — verification only succeeds if the override + // supplied in `creds_map` is what actually gets used. + use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v5/robotLogic/sign-token")) + .and(header("X-Route-Env", "canary")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "code": 0, + "data": { + "token": "tok-override", + "bot_id": "bot-1", + "product": "yuanbao", + "source": "openhuman", + "duration": 3600, + } + }))) + .mount(&server) + .await; + + let (_tmp, mut config) = isolated_test_config(); + // Base TOML points to a black hole so the test fails immediately if + // the verifier ignores the override. + config.channels_config.yuanbao = Some(YuanbaoConfig { + api_domain: "http://127.0.0.1:1".to_string(), + ..Default::default() + }); + + let mock_uri = server.uri(); + let result = connect_channel( + &config, + "yuanbao", + ChannelAuthMode::ApiKey, + serde_json::json!({ + "app_key": "k", + "app_secret": "s", + "api_domain": mock_uri.clone(), + "route_env": "canary", + }), + ) + .await + .expect("override should be applied before verify"); + + assert_eq!(result.value.status, "connected"); + + // The override should also have been persisted (single source of + // truth between verify and persist). + let raw = tokio::fs::read_to_string(&config.config_path) + .await + .expect("config should be persisted"); + let parsed: toml::Value = toml::from_str(&raw).expect("config parses"); + let yb = parsed + .get("channels_config") + .and_then(|v| v.get("yuanbao")) + .and_then(toml::Value::as_table) + .expect("channels_config.yuanbao persisted"); + assert_eq!( + yb.get("api_domain").and_then(toml::Value::as_str), + Some(mock_uri.as_str()), + ); + assert_eq!( + yb.get("route_env").and_then(toml::Value::as_str), + Some("canary"), + ); +} + #[tokio::test] async fn connect_yuanbao_persists_env_override() { use wiremock::matchers::{method, path}; diff --git a/src/openhuman/channels/providers/yuanbao/cos.rs b/src/openhuman/channels/providers/yuanbao/cos.rs index c9cd05fac8..e7aa40fec0 100644 --- a/src/openhuman/channels/providers/yuanbao/cos.rs +++ b/src/openhuman/channels/providers/yuanbao/cos.rs @@ -474,7 +474,11 @@ mod tests { #[tokio::test] async fn get_cos_credentials_sends_route_env_header_when_non_empty() { let server = wiremock::MockServer::start().await; + // Bind the matcher to both the upload-info path AND the header so + // this test fails if a future refactor routes the call elsewhere + // but happens to still attach `X-Route-Env: canary` somewhere. wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path(UPLOAD_INFO_PATH)) .and(wiremock::matchers::header("X-Route-Env", "canary")) .respond_with( wiremock::ResponseTemplate::new(200) diff --git a/src/openhuman/channels/runtime/startup.rs b/src/openhuman/channels/runtime/startup.rs index f108593462..480e6832d6 100644 --- a/src/openhuman/channels/runtime/startup.rs +++ b/src/openhuman/channels/runtime/startup.rs @@ -652,6 +652,12 @@ pub async fn start_channels(config: Config) -> Result<()> { /// `yuanbao` branch in `controllers/ops.rs`). Existing TOML values still /// win so manually-installed deployments don't break. Returns the /// (possibly-modified) config; logging is the only side effect on failure. +/// +/// The stored secret is **only** copied when the stored profile's +/// `app_key` matches `yb_cfg.app_key`. Without that guard, editing +/// `app_key` in `config.toml` would silently pair a fresh key with a +/// stale secret on next startup, and the channel would fail auth until +/// the user reconnected or cleared credentials manually. fn resolve_yuanbao_app_secret( mut yb_cfg: crate::openhuman::channels::providers::yuanbao::YuanbaoConfig, config: &Config, @@ -662,7 +668,14 @@ fn resolve_yuanbao_app_secret( let auth = crate::openhuman::credentials::AuthService::from_config(config); match auth.get_profile("channel:yuanbao:api_key", None) { Ok(Some(profile)) => { - if let Some(secret) = profile.metadata.get("app_secret") { + let stored_app_key = profile.metadata.get("app_key").map(String::as_str); + if stored_app_key != Some(yb_cfg.app_key.as_str()) { + tracing::warn!( + "[channels] yuanbao stored credentials are for a different app_key (toml={:?}, store={:?}); reconnect the channel to refresh the secret", + yb_cfg.app_key, + stored_app_key, + ); + } else if let Some(secret) = profile.metadata.get("app_secret") { yb_cfg.app_secret = secret.clone(); } } @@ -743,4 +756,33 @@ mod yuanbao_secret_tests { // step can fail clearly, instead of attempting auth with a stale value. assert_eq!(resolved.app_secret, ""); } + + #[test] + fn skips_hydration_when_stored_profile_has_different_app_key() { + // Reproduces the stale-secret hazard: user changed `app_key` in + // `config.toml` (e.g. swapped to a different bot) but the + // credentials store still has the old key's profile. The resolver + // must NOT graft the old secret onto the new key. + let (_tmp, config) = isolated_config(); + let auth = AuthService::from_config(&config); + let mut metadata = HashMap::new(); + metadata.insert("app_key".to_string(), "OLD-KEY".to_string()); + metadata.insert( + "app_secret".to_string(), + "old-key-secret-do-not-use".to_string(), + ); + auth.store_provider_token("channel:yuanbao:api_key", "default", "", metadata, true) + .expect("store credentials"); + + let yb = YuanbaoConfig { + app_key: "NEW-KEY".into(), + app_secret: String::new(), + ..Default::default() + }; + let resolved = resolve_yuanbao_app_secret(yb, &config); + assert_eq!( + resolved.app_secret, "", + "stale profile keyed to OLD-KEY must not hydrate NEW-KEY's secret", + ); + } } From 4edb66fd25493c6dca64de623d885aff580d3be5 Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sun, 24 May 2026 13:10:45 +0800 Subject: [PATCH 6/8] docs(matrix): add 10.1.5 Yuanbao Connection row MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per the matrix update contract in `docs/TEST-COVERAGE-MATRIX.md` — any PR that adds, removes, or changes a feature leaf must update the matrix in the same change. This PR introduces the Yuanbao channel provider, so add the leaf row pointing at the RU test paths landed in this PR (sign-token preflight, credentials store hydration including the stale-app_key guard, WS reconnect/shutdown). Status 🟡 not ✅: there is no dedicated WDIO spec for Yuanbao yet — the connect-flow UI is rendered via the generic `ChannelSetupModal` that is already covered by other channel flow specs. --- docs/TEST-COVERAGE-MATRIX.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/TEST-COVERAGE-MATRIX.md b/docs/TEST-COVERAGE-MATRIX.md index 9674728600..1b4f995c16 100644 --- a/docs/TEST-COVERAGE-MATRIX.md +++ b/docs/TEST-COVERAGE-MATRIX.md @@ -344,6 +344,7 @@ Canonical mapping of every product feature to its test source(s). Drives gap-fil | 10.1.2 | WhatsApp Connection | WD | `app/test/e2e/specs/whatsapp-flow.spec.ts` (this PR) | ✅ | Was ❌ | | 10.1.3 | Gmail Connection | WD | `gmail-flow.spec.ts` | ✅ | | | 10.1.4 | Slack Connection | WD | `app/test/e2e/specs/slack-flow.spec.ts` (this PR) | ✅ | Was ❌ | +| 10.1.5 | Yuanbao Connection | RU | `src/openhuman/channels/providers/yuanbao/` (this PR), `src/openhuman/channels/controllers/ops.rs::tests::connect_yuanbao_*` (this PR), `src/openhuman/channels/runtime/startup.rs::yuanbao_secret_tests` (this PR) | 🟡 | New API-key channel for Tencent Yuanbao. RU covers sign-token preflight (valid/invalid creds, env-override cluster routing), credentials store hydration (incl. stale app_key guard), and WS reconnect/shutdown. No WDIO spec yet — connect-flow UI is rendered via the generic `ChannelSetupModal` already exercised by other channel flow specs. | ### 10.2 Authentication & Authorization From a05e7e60ac48343719820068f7b041777400fb76 Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Sun, 24 May 2026 15:30:41 +0800 Subject: [PATCH 7/8] =?UTF-8?q?test(channels/yuanbao):=20cover=20diff=20li?= =?UTF-8?q?nes=20for=20=E2=89=A580%=20gate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Coverage Gate (diff-cover ≥ 80%) job was failing on PR #2494 at 13% because the yuanbao channel surface lacked Vitest coverage. Add targeted unit tests that lift diff coverage on the touched files well above the 80% threshold: - YuanbaoIcon: 33% → 100% (default/custom className, unique clipPath ids per instance for collision-free dual rendering) - YuanbaoConfig: 2% → 95.1% (renders fields, returns null with no auth modes, inline validation + clear-on-input, all four connect outcomes — connected/restart_required/restart-fails/ non-connected/connect-throws — disconnect success+failure, stale-connecting reset on mount, lastError rendering) - ChannelSetupModal: 60% → 100% (yuanbao SVG branch + yuanbao switch case + emoji branch + fallback message + Escape close) - channelConnectionsSlice: 87.5% → 100% (ensureChannelModes lazy-init when persisted state is missing the yuanbao key) Diff-cover total on changed lines: 13% → 84%. --- .../__tests__/ChannelSetupModal.test.tsx | 72 +++++ .../channels/__tests__/YuanbaoConfig.test.tsx | 287 ++++++++++++++++++ .../channels/__tests__/YuanbaoIcon.test.tsx | 37 +++ .../__tests__/channelConnectionsSlice.test.ts | 25 ++ 4 files changed, 421 insertions(+) create mode 100644 app/src/components/channels/__tests__/ChannelSetupModal.test.tsx create mode 100644 app/src/components/channels/__tests__/YuanbaoConfig.test.tsx create mode 100644 app/src/components/channels/__tests__/YuanbaoIcon.test.tsx diff --git a/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx b/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx new file mode 100644 index 0000000000..ae521baa31 --- /dev/null +++ b/app/src/components/channels/__tests__/ChannelSetupModal.test.tsx @@ -0,0 +1,72 @@ +import { screen } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +import { FALLBACK_DEFINITIONS } from '../../../lib/channels/definitions'; +import { renderWithProviders } from '../../../test/test-utils'; +import type { ChannelDefinition } from '../../../types/channels'; +import ChannelSetupModal from '../ChannelSetupModal'; + +// YuanbaoConfig pulls in API + Tauri helpers we don't need for the routing +// branches under test — stub it so we only assert ChannelSetupModal's own +// behavior (icon branch + yuanbao switch case). +vi.mock('../YuanbaoConfig', () => ({ + default: () =>
Yuanbao Config
, +})); + +vi.mock('../TelegramConfig', () => ({ + default: () =>
Telegram Config
, +})); + +vi.mock('../DiscordConfig', () => ({ + default: () =>
Discord Config
, +})); + +const yuanbaoDef: ChannelDefinition = { + id: 'yuanbao', + display_name: '元宝', + description: '通过元宝(Yuanbao)机器人收发消息。', + icon: 'yuanbao', + auth_modes: [ + { + mode: 'api_key', + description: '提供元宝开放平台的 AppID 和 AppSecret。', + fields: [], + auth_action: undefined, + }, + ], + capabilities: ['send_text', 'receive_text'], +}; + +describe('ChannelSetupModal', () => { + it('renders the YuanbaoConfig body and brand SVG icon for the yuanbao channel', () => { + renderWithProviders( {}} />); + // Header title + body routing both exercised. + expect(screen.getByText('元宝')).toBeInTheDocument(); + expect(screen.getByTestId('yuanbao-config')).toBeInTheDocument(); + // YuanbaoIcon emits an aria-hidden SVG in the header; the emoji-based + // fallback should NOT also render for yuanbao. + const dialog = screen.getByRole('dialog'); + expect(dialog.querySelector('svg[aria-hidden="true"]')).not.toBeNull(); + }); + + it('renders the emoji icon and TelegramConfig body for the telegram channel', () => { + const telegramDef = FALLBACK_DEFINITIONS.find(d => d.id === 'telegram')!; + renderWithProviders( {}} />); + expect(screen.getByTestId('telegram-config')).toBeInTheDocument(); + // Emoji branch produces a span sibling to the title. + expect(screen.getByText('\u2708\uFE0F')).toBeInTheDocument(); + }); + + it('falls back to the unavailable-channel message for an unknown channel id', () => { + const unknown: ChannelDefinition = { ...yuanbaoDef, id: 'unknown', display_name: 'Unknown' }; + renderWithProviders( {}} />); + expect(screen.getByText(/Configuration for/i)).toBeInTheDocument(); + }); + + it('invokes onClose when the Escape key is pressed', () => { + const onClose = vi.fn(); + renderWithProviders(); + document.dispatchEvent(new KeyboardEvent('keydown', { key: 'Escape' })); + expect(onClose).toHaveBeenCalledTimes(1); + }); +}); diff --git a/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx b/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx new file mode 100644 index 0000000000..404f44d6cf --- /dev/null +++ b/app/src/components/channels/__tests__/YuanbaoConfig.test.tsx @@ -0,0 +1,287 @@ +import { fireEvent, screen, waitFor } from '@testing-library/react'; +import { afterEach, describe, expect, it, vi } from 'vitest'; + +import { channelConnectionsApi } from '../../../services/api/channelConnectionsApi'; +import { setChannelConnectionStatus } from '../../../store/channelConnectionsSlice'; +import { createTestStore, renderWithProviders } from '../../../test/test-utils'; +import type { ChannelDefinition } from '../../../types/channels'; +import { restartCoreProcess } from '../../../utils/tauriCommands/core'; +import YuanbaoConfig from '../YuanbaoConfig'; + +vi.mock('../../../services/api/channelConnectionsApi', () => ({ + channelConnectionsApi: { connectChannel: vi.fn(), disconnectChannel: vi.fn() }, +})); + +vi.mock('../../../utils/tauriCommands/core', () => ({ restartCoreProcess: vi.fn() })); + +// Mirrors the backend yuanbao_definition() in +// src/openhuman/channels/controllers/definitions.rs — kept inline because +// the frontend fallback definitions list does not (yet) include yuanbao. +const yuanbaoDef: ChannelDefinition = { + id: 'yuanbao', + display_name: '元宝', + description: '通过元宝(Yuanbao)机器人收发消息。', + icon: 'yuanbao', + auth_modes: [ + { + mode: 'api_key', + description: '提供元宝开放平台的 AppID 和 AppSecret。', + fields: [ + { + key: 'app_key', + label: 'AppID', + field_type: 'string', + required: true, + placeholder: '元宝开放平台 AppID', + }, + { + key: 'app_secret', + label: 'AppSecret', + field_type: 'secret', + required: true, + placeholder: '元宝开放平台 AppSecret', + }, + ], + auth_action: undefined, + }, + ], + capabilities: ['send_text', 'receive_text', 'typing'], +}; + +afterEach(() => { + vi.clearAllMocks(); +}); + +describe('YuanbaoConfig', () => { + it('renders the api_key mode label, description, and credential fields', () => { + renderWithProviders(); + expect(screen.getByText('Use your own API Key')).toBeInTheDocument(); + expect(screen.getByText(/AppID 和 AppSecret/)).toBeInTheDocument(); + expect(screen.getByPlaceholderText('元宝开放平台 AppID')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('元宝开放平台 AppSecret')).toBeInTheDocument(); + }); + + it('shows a Connect and a (disabled) Disconnect button by default', () => { + renderWithProviders(); + expect(screen.getByText('Connect')).toBeInTheDocument(); + const disconnect = screen.getByText('Disconnect'); + expect(disconnect).toBeDisabled(); + }); + + it('returns null when the definition has no auth modes', () => { + const empty: ChannelDefinition = { ...yuanbaoDef, auth_modes: [] }; + const { container } = renderWithProviders(); + expect(container.firstChild).toBeNull(); + }); + + it('shows inline validation errors when required fields are empty and clears them on input', () => { + renderWithProviders(); + fireEvent.click(screen.getByText('Connect')); + + // Two required fields → two inline error messages. + const appKeyError = screen + .getAllByText(/AppID/) + .filter(node => node.className.includes('text-coral')); + expect(appKeyError.length).toBeGreaterThan(0); + expect(channelConnectionsApi.connectChannel).not.toHaveBeenCalled(); + + // Typing into a field clears that field's error (covers updateField + // branch that mutates fieldErrors). + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + expect( + screen.queryAllByText(/AppID/).filter(node => node.className.includes('text-coral')).length + ).toBe(0); + }); + + it('connects successfully and dispatches connected when restart is not required', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: false, + }); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + expect(channelConnectionsApi.connectChannel).toHaveBeenCalledWith('yuanbao', { + authMode: 'api_key', + credentials: { app_key: 'app-key-123', app_secret: 'app-secret-xyz' }, + }); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('connected'); + expect(conn?.capabilities).toEqual(['read', 'write']); + }); + expect(restartCoreProcess).not.toHaveBeenCalled(); + }); + + it('calls restartCoreProcess and dispatches connected when restart_required=true', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: true, + }); + vi.mocked(restartCoreProcess).mockResolvedValue(); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + expect(restartCoreProcess).toHaveBeenCalledTimes(1); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('connected'); + }); + }); + + it('marks the channel as error when restartCoreProcess throws after a successful connect', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'connected', + restart_required: true, + }); + vi.mocked(restartCoreProcess).mockRejectedValue(new Error('core restart failed')); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBeTruthy(); + }); + }); + + it('surfaces an error when the backend returns a non-connected status', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockResolvedValue({ + status: 'pending_auth', + restart_required: false, + }); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toContain('pending_auth'); + }); + }); + + it('captures connect failures from the API and dispatches an error status', async () => { + vi.mocked(channelConnectionsApi.connectChannel).mockRejectedValue( + new Error('invalid credentials') + ); + + const { store } = renderWithProviders(); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppID'), { + target: { value: 'app-key-123' }, + }); + fireEvent.change(screen.getByPlaceholderText('元宝开放平台 AppSecret'), { + target: { value: 'app-secret-xyz' }, + }); + fireEvent.click(screen.getByText('Connect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBe('invalid credentials'); + }); + }); + + it('disconnects an active channel via the API and clears the connection', async () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connected' }) + ); + vi.mocked(channelConnectionsApi.disconnectChannel).mockResolvedValue(); + + renderWithProviders(, { store }); + + // Status is connected → Reconnect label appears on the primary button. + expect(screen.getByText('Reconnect')).toBeInTheDocument(); + const disconnect = screen.getByText('Disconnect'); + expect(disconnect).not.toBeDisabled(); + fireEvent.click(disconnect); + + await waitFor(() => { + expect(channelConnectionsApi.disconnectChannel).toHaveBeenCalledWith('yuanbao', 'api_key'); + }); + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('disconnected'); + }); + }); + + it('reports an error status when the disconnect API call fails', async () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connected' }) + ); + vi.mocked(channelConnectionsApi.disconnectChannel).mockRejectedValue( + new Error('rpc unreachable') + ); + + renderWithProviders(, { store }); + fireEvent.click(screen.getByText('Disconnect')); + + await waitFor(() => { + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('error'); + expect(conn?.lastError).toBe('rpc unreachable'); + }); + }); + + it('resets a stale "connecting" status from a previous session on mount', () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ channel: 'yuanbao', authMode: 'api_key', status: 'connecting' }) + ); + + renderWithProviders(, { store }); + + const conn = store.getState().channelConnections.connections.yuanbao?.api_key; + expect(conn?.status).toBe('disconnected'); + }); + + it('renders the last error message when the connection is in an error state', () => { + const store = createTestStore(); + store.dispatch( + setChannelConnectionStatus({ + channel: 'yuanbao', + authMode: 'api_key', + status: 'error', + lastError: 'sign verification failed', + }) + ); + + renderWithProviders(, { store }); + expect(screen.getByText('sign verification failed')).toBeInTheDocument(); + }); +}); diff --git a/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx b/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx new file mode 100644 index 0000000000..067535c224 --- /dev/null +++ b/app/src/components/channels/__tests__/YuanbaoIcon.test.tsx @@ -0,0 +1,37 @@ +import { render } from '@testing-library/react'; +import { describe, expect, it } from 'vitest'; + +import YuanbaoIcon from '../YuanbaoIcon'; + +describe('YuanbaoIcon', () => { + it('renders an inline SVG with the default size class', () => { + const { container } = render(); + const svg = container.querySelector('svg'); + expect(svg).not.toBeNull(); + expect(svg).toHaveAttribute('aria-hidden', 'true'); + expect(svg?.getAttribute('class')).toContain('w-5'); + expect(svg?.getAttribute('class')).toContain('h-5'); + }); + + it('applies a custom className override', () => { + const { container } = render(); + const svg = container.querySelector('svg'); + expect(svg?.getAttribute('class')).toBe('w-10 h-10 text-amber-500'); + }); + + it('generates a unique clipPath id per instance so duplicate icons do not collide', () => { + const { container } = render( + <> + + + + ); + const clips = container.querySelectorAll('clipPath'); + expect(clips.length).toBe(2); + const id1 = clips[0].getAttribute('id'); + const id2 = clips[1].getAttribute('id'); + expect(id1).toBeTruthy(); + expect(id2).toBeTruthy(); + expect(id1).not.toBe(id2); + }); +}); diff --git a/app/src/store/__tests__/channelConnectionsSlice.test.ts b/app/src/store/__tests__/channelConnectionsSlice.test.ts index 3474bb43de..1c545d4956 100644 --- a/app/src/store/__tests__/channelConnectionsSlice.test.ts +++ b/app/src/store/__tests__/channelConnectionsSlice.test.ts @@ -177,6 +177,31 @@ describe('channelConnectionsSlice', () => { }); }); + it('lazily initialises a channel modes bucket when persisted state is missing the key', () => { + // Simulates a rehydrated state from before yuanbao existed: the channel + // key is absent so `state.connections.yuanbao` is undefined. Without + // `ensureChannelModes()` the first upsert would crash on + // `state.connections[channel][authMode]`. See `ensureChannelModes` in + // channelConnectionsSlice.ts. + const migrated = reducer(undefined, completeBreakingMigration()); + const partial = { + ...migrated, + connections: { ...migrated.connections, yuanbao: undefined as never }, + }; + + const next = reducer( + partial, + upsertChannelConnection({ + channel: 'yuanbao', + authMode: 'api_key', + patch: { status: 'connected' }, + }) + ); + + expect(next.connections.yuanbao).toBeDefined(); + expect(next.connections.yuanbao.api_key?.status).toBe('connected'); + }); + it('clears stale lastError when patch explicitly sets undefined', () => { const withError = reducer( undefined, From 2fd574691e8ba332f1e3c3028057a555393039c9 Mon Sep 17 00:00:00 2001 From: lrt4836 <296659110@qq.com> Date: Mon, 25 May 2026 10:14:05 +0800 Subject: [PATCH 8/8] i18n(channels/yuanbao): add skills.channelIcon.yuanbao key MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upstream refactored `getChannelIcons()` in `skills/skillIcons.tsx` to take a `t()` translator and read each entry's aria-label from `skills.channelIcon.`. Add the matching `yuanbao` entry to `en.ts` and all 13 locale chunks so the yuanbao branch (resolved during rebase onto upstream/main) does not break `pnpm i18n:check`. zh-CN translated as 元宝; other locales carry the same `Yuanbao` placeholder as the existing un-translated `discord`/`telegram` keys. --- app/src/lib/i18n/chunks/ar-5.ts | 1 + app/src/lib/i18n/chunks/bn-5.ts | 1 + app/src/lib/i18n/chunks/de-5.ts | 1 + app/src/lib/i18n/chunks/en-5.ts | 1 + app/src/lib/i18n/chunks/es-5.ts | 1 + app/src/lib/i18n/chunks/fr-5.ts | 1 + app/src/lib/i18n/chunks/hi-5.ts | 1 + app/src/lib/i18n/chunks/id-5.ts | 1 + app/src/lib/i18n/chunks/it-5.ts | 1 + app/src/lib/i18n/chunks/ko-5.ts | 1 + app/src/lib/i18n/chunks/pt-5.ts | 1 + app/src/lib/i18n/chunks/ru-5.ts | 1 + app/src/lib/i18n/chunks/zh-CN-5.ts | 1 + app/src/lib/i18n/en.ts | 1 + 14 files changed, 14 insertions(+) diff --git a/app/src/lib/i18n/chunks/ar-5.ts b/app/src/lib/i18n/chunks/ar-5.ts index c808aeabdc..b459d06cb9 100644 --- a/app/src/lib/i18n/chunks/ar-5.ts +++ b/app/src/lib/i18n/chunks/ar-5.ts @@ -620,6 +620,7 @@ const ar5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/bn-5.ts b/app/src/lib/i18n/chunks/bn-5.ts index ce43e37f19..140ddd0a9b 100644 --- a/app/src/lib/i18n/chunks/bn-5.ts +++ b/app/src/lib/i18n/chunks/bn-5.ts @@ -626,6 +626,7 @@ const bn5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/de-5.ts b/app/src/lib/i18n/chunks/de-5.ts index 49271c9bc0..4369195548 100644 --- a/app/src/lib/i18n/chunks/de-5.ts +++ b/app/src/lib/i18n/chunks/de-5.ts @@ -648,6 +648,7 @@ const de5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/en-5.ts b/app/src/lib/i18n/chunks/en-5.ts index 3733f25ca6..58bac49567 100644 --- a/app/src/lib/i18n/chunks/en-5.ts +++ b/app/src/lib/i18n/chunks/en-5.ts @@ -625,6 +625,7 @@ const en5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/es-5.ts b/app/src/lib/i18n/chunks/es-5.ts index 80cfb8212f..c69251ad7b 100644 --- a/app/src/lib/i18n/chunks/es-5.ts +++ b/app/src/lib/i18n/chunks/es-5.ts @@ -632,6 +632,7 @@ const es5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/fr-5.ts b/app/src/lib/i18n/chunks/fr-5.ts index 10c0496cbe..f7103604b2 100644 --- a/app/src/lib/i18n/chunks/fr-5.ts +++ b/app/src/lib/i18n/chunks/fr-5.ts @@ -636,6 +636,7 @@ const fr5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/hi-5.ts b/app/src/lib/i18n/chunks/hi-5.ts index 2726644e67..8b06fa0c08 100644 --- a/app/src/lib/i18n/chunks/hi-5.ts +++ b/app/src/lib/i18n/chunks/hi-5.ts @@ -628,6 +628,7 @@ const hi5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/id-5.ts b/app/src/lib/i18n/chunks/id-5.ts index ed512d764a..794f3a4172 100644 --- a/app/src/lib/i18n/chunks/id-5.ts +++ b/app/src/lib/i18n/chunks/id-5.ts @@ -629,6 +629,7 @@ const id5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/it-5.ts b/app/src/lib/i18n/chunks/it-5.ts index f998f416df..81594aa856 100644 --- a/app/src/lib/i18n/chunks/it-5.ts +++ b/app/src/lib/i18n/chunks/it-5.ts @@ -633,6 +633,7 @@ const it5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/ko-5.ts b/app/src/lib/i18n/chunks/ko-5.ts index 634d4531b8..424a6dc666 100644 --- a/app/src/lib/i18n/chunks/ko-5.ts +++ b/app/src/lib/i18n/chunks/ko-5.ts @@ -628,6 +628,7 @@ const ko5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/pt-5.ts b/app/src/lib/i18n/chunks/pt-5.ts index 3cdaee3485..1138e5682c 100644 --- a/app/src/lib/i18n/chunks/pt-5.ts +++ b/app/src/lib/i18n/chunks/pt-5.ts @@ -633,6 +633,7 @@ const pt5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/ru-5.ts b/app/src/lib/i18n/chunks/ru-5.ts index 39e6e8e97c..3fd4b6dc5a 100644 --- a/app/src/lib/i18n/chunks/ru-5.ts +++ b/app/src/lib/i18n/chunks/ru-5.ts @@ -630,6 +630,7 @@ const ru5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/chunks/zh-CN-5.ts b/app/src/lib/i18n/chunks/zh-CN-5.ts index 8c2985f781..0d19efcf03 100644 --- a/app/src/lib/i18n/chunks/zh-CN-5.ts +++ b/app/src/lib/i18n/chunks/zh-CN-5.ts @@ -600,6 +600,7 @@ const zhCN5: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': '元宝', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedToolsHelp': 'Rendered into the SKILL.md frontmatter as', diff --git a/app/src/lib/i18n/en.ts b/app/src/lib/i18n/en.ts index 80d227edc7..64422dc120 100644 --- a/app/src/lib/i18n/en.ts +++ b/app/src/lib/i18n/en.ts @@ -3075,6 +3075,7 @@ const en: TranslationMap = { 'skills.channelIcon.imessage': 'iMessage', 'skills.channelIcon.telegram': 'Telegram', 'skills.channelIcon.web': 'Web', + 'skills.channelIcon.yuanbao': 'Yuanbao', 'skills.composio.poweredBy': 'Powered by Composio', 'skills.composio.staleStatusTitle': 'Connections are showing stale status', 'skills.create.allowedTools': 'Allowed tools',