diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index e1550de1..06c27fec 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -2728,6 +2728,7 @@ name = "opencodeui" version = "0.6.10" dependencies = [ "futures-util", + "libc", "log", "papaya", "rapidhash", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 8d94c63d..d9525e66 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -14,6 +14,7 @@ name = "app_lib" [dependencies] futures-util = "0.3" +libc = "0.2" log = "0.4" papaya = "0.2.3" rapidhash = { version = "4.4.1", features = ["unsafe"] } diff --git a/src-tauri/src/app/commands/opencode.rs b/src-tauri/src/app/commands/opencode.rs index 183032d1..99e858f5 100644 --- a/src-tauri/src/app/commands/opencode.rs +++ b/src-tauri/src/app/commands/opencode.rs @@ -4,13 +4,42 @@ // ============================================ use crate::app::service::ServiceState; +use reqwest::StatusCode; +use serde::Serialize; use std::{ + collections::HashMap, + future::Future, + pin::Pin, process::{Command, Stdio}, sync::atomic::Ordering, time::Duration, }; use tauri::State; +type HealthCheckFuture = Pin + Send>>; +type HealthCheckFn = dyn Fn() -> HealthCheckFuture + Send + Sync; +type SpawnFn = dyn Fn(&str, &HashMap) -> Result + Send + Sync; + +const START_READINESS_ATTEMPTS: usize = 30; +const START_READINESS_DELAY: Duration = Duration::from_millis(500); + +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub struct StartOpencodeServiceResult { + pub spawned_now: bool, + pub app_owned: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CloseServiceAction { + Stop { pid: Option }, + KeepRunning, +} + +fn is_running_health_status(status: StatusCode) -> bool { + status.is_success() || status == StatusCode::UNAUTHORIZED +} + /// 检查 opencode 服务是否在运行(通过 health endpoint) pub async fn is_service_running(url: &str) -> bool { let health_url = format!("{}/global/health", url.trim_end_matches('/')); @@ -23,7 +52,7 @@ pub async fn is_service_running(url: &str) -> bool { .timeout(Duration::from_secs(5)) .send() .await - .map(|r| r.status().is_success()) + .map(|r| is_running_health_status(r.status())) .unwrap_or(false), Err(_) => false, } @@ -32,7 +61,7 @@ pub async fn is_service_running(url: &str) -> bool { /// 启动 opencode serve 进程 fn spawn_opencode_serve( binary_path: &str, - env_vars: &std::collections::HashMap, + env_vars: &HashMap, ) -> Result { log::info!("Starting opencode serve with binary: {}", binary_path); if !env_vars.is_empty() { @@ -54,6 +83,12 @@ fn spawn_opencode_serve( cmd.creation_flags(CREATE_NO_WINDOW); } + #[cfg(unix)] + { + use std::os::unix::process::CommandExt; + cmd.process_group(0); + } + cmd.spawn().map_err(|e| { format!( "Failed to start '{}': {}. Check that the path is correct.", @@ -62,6 +97,105 @@ fn spawn_opencode_serve( }) } +async fn wait_for_service( + health_check: &HealthCheckFn, + readiness_attempts: usize, + readiness_delay: Duration, +) -> bool { + for _ in 0..readiness_attempts { + tokio::time::sleep(readiness_delay).await; + if health_check().await { + return true; + } + } + + false +} + +async fn start_opencode_service_inner( + state: &ServiceState, + url: &str, + binary_path: &str, + env_vars: &HashMap, + health_check: &HealthCheckFn, + spawn: &SpawnFn, + readiness_attempts: usize, + readiness_delay: Duration, +) -> Result { + let mut process = state.process.lock().await; + let owned_pid = process.child_pid; + let healthy = health_check().await; + + if healthy { + if let Some(pid) = owned_pid { + log::info!( + "opencode service already running at {} with app-owned PID {}", + url, + pid + ); + state.we_started.store(true, Ordering::SeqCst); + } else { + log::info!("opencode service already running externally at {}", url); + state.we_started.store(false, Ordering::SeqCst); + } + + return Ok(StartOpencodeServiceResult { + spawned_now: false, + app_owned: owned_pid.is_some(), + }); + } + + if let Some(pid) = owned_pid { + log::info!( + "opencode service already started by app with PID {} but health is not ready yet", + pid + ); + state.we_started.store(true, Ordering::SeqCst); + + if wait_for_service(health_check, readiness_attempts, readiness_delay).await { + log::info!("opencode service is ready at {}", url); + } else { + log::warn!( + "opencode service still not healthy for existing app-owned PID {}", + pid + ); + } + + return Ok(StartOpencodeServiceResult { + spawned_now: false, + app_owned: true, + }); + } + + let pid = spawn(binary_path, env_vars)?; + + if let Err(error) = state.register_spawned_pid_locked(&mut process, pid) { + drop(process); + log::warn!( + "Failed to persist app-owned PID {} after spawn, stopping backend again: {}", + pid, + error + ); + kill_process_by_pid(pid); + return Err(error); + } + + drop(process); + + log::info!("Started opencode serve, PID: {}", pid); + + if wait_for_service(health_check, readiness_attempts, readiness_delay).await { + log::info!("opencode service is ready at {}", url); + } else { + log::warn!("opencode service started but health check not passing yet"); + } + + Ok(StartOpencodeServiceResult { + spawned_now: true, + app_owned: true, + }) +} + /// 跨平台杀进程 pub fn kill_process_by_pid(pid: u32) { #[cfg(target_os = "windows")] @@ -78,11 +212,63 @@ pub fn kill_process_by_pid(pid: u32) { #[cfg(not(target_os = "windows"))] { - let _ = Command::new("kill") - .arg(pid.to_string()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn(); + let signal_pid = match unix_signal_pid(pid) { + Ok(signal_pid) => signal_pid, + Err(error) => { + log::warn!("Refusing to signal PID {}: {}", pid, error); + return; + } + }; + + let process_group_pid = signal_pid + .checked_neg() + .expect("validated signal pid should always be positive"); + + if let Err(group_error) = send_sigterm(process_group_pid) { + if group_error.raw_os_error() == Some(libc::ESRCH) { + log::info!( + "Process group {} not found for PID {}, falling back to direct SIGTERM", + process_group_pid, + pid + ); + + if let Err(single_error) = send_sigterm(signal_pid) { + log::warn!( + "Failed to SIGTERM legacy PID {} after process-group ESRCH fallback: {}", + pid, + single_error + ); + } + + return; + } + + log::warn!( + "Failed to SIGTERM process group {} for PID {}: {}", + process_group_pid, + pid, + group_error + ); + } + } +} + +#[cfg(not(target_os = "windows"))] +fn unix_signal_pid(pid: u32) -> Result { + if pid == 0 { + return Err("PID 0 is reserved and would signal the current process group"); + } + + libc::pid_t::try_from(pid).map_err(|_| "PID does not fit into libc::pid_t") +} + +#[cfg(not(target_os = "windows"))] +fn send_sigterm(target_pid: libc::pid_t) -> std::io::Result<()> { + let result = unsafe { libc::kill(target_pid, libc::SIGTERM) }; + if result == 0 { + Ok(()) + } else { + Err(std::io::Error::last_os_error()) } } @@ -98,39 +284,47 @@ pub async fn start_opencode_service( state: State<'_, ServiceState>, url: String, binary_path: String, - env_vars: std::collections::HashMap, -) -> Result { - if is_service_running(&url).await { - log::info!("opencode service already running at {}", url); - return Ok(false); - } - - let child = spawn_opencode_serve(&binary_path, &env_vars)?; - let pid = child.id(); - log::info!("Started opencode serve, PID: {}", pid); + env_vars: HashMap, +) -> Result { + let health_url = url.clone(); - state.child_pid.store(pid, Ordering::SeqCst); - state.we_started.store(true, Ordering::SeqCst); + start_opencode_service_inner( + state.inner(), + &url, + &binary_path, + &env_vars, + &move || { + let health_url = health_url.clone(); + Box::pin(async move { is_service_running(&health_url).await }) + }, + &|spawn_binary_path, spawn_env_vars| { + spawn_opencode_serve(spawn_binary_path, spawn_env_vars).map(|child| child.id()) + }, + START_READINESS_ATTEMPTS, + START_READINESS_DELAY, + ) + .await +} - for _ in 0..30 { - tokio::time::sleep(Duration::from_millis(500)).await; - if is_service_running(&url).await { - log::info!("opencode service is ready at {}", url); - return Ok(true); +async fn prepare_close_service_action( + state: &ServiceState, + stop_service: bool, +) -> CloseServiceAction { + if stop_service { + CloseServiceAction::Stop { + pid: state.take_owned_pid_for_shutdown().await, } + } else { + CloseServiceAction::KeepRunning } - - log::warn!("opencode service started but health check not passing yet"); - Ok(true) } /// 停止 opencode serve #[tauri::command] pub async fn stop_opencode_service(state: State<'_, ServiceState>) -> Result<(), String> { - let pid = state.child_pid.swap(0, Ordering::SeqCst); - state.we_started.store(false, Ordering::SeqCst); + let pid = state.take_owned_pid_for_shutdown().await; - if pid > 0 { + if let Some(pid) = pid { log::info!("Stopping opencode serve, PID: {}", pid); kill_process_by_pid(pid); } @@ -151,16 +345,523 @@ pub async fn confirm_close_app( state: State<'_, ServiceState>, stop_service: bool, ) -> Result<(), String> { - if stop_service { - let pid = state.child_pid.swap(0, Ordering::SeqCst); - if pid > 0 { - log::info!("Closing app and stopping opencode serve, PID: {}", pid); - kill_process_by_pid(pid); + match prepare_close_service_action(state.inner(), stop_service).await { + CloseServiceAction::Stop { pid } => { + if let Some(pid) = pid { + log::info!("Closing app and stopping opencode serve, PID: {}", pid); + kill_process_by_pid(pid); + } + } + CloseServiceAction::KeepRunning => { + log::info!("Closing app, keeping opencode serve running"); } - state.we_started.store(false, Ordering::SeqCst); - } else { - log::info!("Closing app, keeping opencode serve running"); } window.destroy().map_err(|e| e.to_string()) } + +#[cfg(test)] +mod tests { + use super::{ + is_running_health_status, prepare_close_service_action, start_opencode_service_inner, + CloseServiceAction, HealthCheckFuture, StartOpencodeServiceResult, + }; + use crate::app::service::ServiceState; + use reqwest::StatusCode; + use std::{ + collections::HashMap, + fs, + path::PathBuf, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::{Duration, SystemTime, UNIX_EPOCH}, + }; + use tokio::sync::Barrier; + + #[cfg(unix)] + use super::{kill_process_by_pid, unix_signal_pid}; + + #[cfg(unix)] + use std::{os::unix::process::CommandExt, process::{Child, Command, Stdio}, thread}; + + const TEST_URL: &str = "http://127.0.0.1:4096"; + const TEST_BINARY_PATH: &str = "opencode"; + + fn unique_marker_path(test_name: &str) -> PathBuf { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_nanos(); + + std::env::temp_dir().join(format!( + "opencodeui-opencode-tests-{test_name}-{}-{timestamp}.json", + std::process::id() + )) + } + + #[cfg(unix)] + fn unique_pid_path(test_name: &str) -> PathBuf { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_nanos(); + + std::env::temp_dir().join(format!( + "opencodeui-opencode-tests-{test_name}-{}-{timestamp}.pid", + std::process::id() + )) + } + + #[cfg(unix)] + fn spawn_waiting_shell(pid_file_path: &PathBuf, create_process_group: bool) -> Child { + let mut command = Command::new("sh"); + command + .args([ + "-c", + "sleep 600 & child=$!; printf \"%s\" \"$child\" > \"$1\"; wait \"$child\"", + "sh", + pid_file_path + .to_str() + .expect("temporary pid file path should be valid UTF-8"), + ]) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + if create_process_group { + command.process_group(0); + } + + command.spawn().expect("shell process should spawn") + } + + #[cfg(unix)] + fn read_pid_with_retry(pid_file_path: &PathBuf) -> u32 { + for _ in 0..100 { + if let Ok(contents) = fs::read_to_string(pid_file_path) { + if let Ok(pid) = contents.trim().parse::() { + return pid; + } + } + + thread::sleep(Duration::from_millis(20)); + } + + panic!( + "child pid file '{}' was not populated in time", + pid_file_path.display() + ); + } + + #[cfg(unix)] + fn pid_exists(pid: u32) -> bool { + let Ok(signal_pid) = unix_signal_pid(pid) else { + return false; + }; + + let result = unsafe { libc::kill(signal_pid, 0) }; + if result == 0 { + true + } else { + matches!(std::io::Error::last_os_error().raw_os_error(), Some(libc::EPERM)) + } + } + + #[cfg(unix)] + fn wait_for_process_exit(process: &mut Child) { + for _ in 0..200 { + if process + .try_wait() + .expect("waiting for process status should succeed") + .is_some() + { + return; + } + + thread::sleep(Duration::from_millis(20)); + } + + panic!("process {} did not exit before timeout", process.id()); + } + + #[cfg(unix)] + fn wait_for_pid_gone(pid: u32) { + for _ in 0..200 { + if !pid_exists(pid) { + return; + } + + thread::sleep(Duration::from_millis(20)); + } + + panic!("pid {} still appears alive after timeout", pid); + } + + #[test] + fn unauthorized_health_response_counts_as_running() { + assert!(is_running_health_status(StatusCode::OK)); + assert!(is_running_health_status(StatusCode::UNAUTHORIZED)); + assert!(!is_running_health_status(StatusCode::FORBIDDEN)); + assert!(!is_running_health_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + #[cfg(unix)] + #[test] + fn unix_process_group_shutdown_terminates_child_process() { + let pid_file_path = unique_pid_path("unix-process-group-shutdown"); + let mut parent = spawn_waiting_shell(&pid_file_path, true); + let parent_pid = parent.id(); + let child_pid = read_pid_with_retry(&pid_file_path); + + assert!(pid_exists(parent_pid)); + assert!(pid_exists(child_pid)); + + kill_process_by_pid(parent_pid); + + wait_for_process_exit(&mut parent); + wait_for_pid_gone(parent_pid); + wait_for_pid_gone(child_pid); + + let _ = fs::remove_file(pid_file_path); + } + + #[cfg(unix)] + #[test] + fn legacy_single_pid_kill_fallback_terminates_process_after_esrch() { + let mut process = Command::new("sleep") + .arg("600") + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("sleep process should spawn"); + let pid = process.id(); + + assert!(pid_exists(pid)); + + kill_process_by_pid(pid); + + wait_for_process_exit(&mut process); + wait_for_pid_gone(pid); + } + + #[cfg(unix)] + #[test] + fn rejects_zero_pid_before_signal() { + assert_eq!(unix_signal_pid(0), Err("PID 0 is reserved and would signal the current process group")); + assert_eq!(unix_signal_pid(u32::MAX), Err("PID does not fit into libc::pid_t")); + } + + #[tokio::test] + async fn concurrent_start_spawns_once() { + let state = Arc::new(ServiceState::default()); + let barrier = Arc::new(Barrier::new(8)); + let health_call_count = Arc::new(AtomicUsize::new(0)); + let spawn_count = Arc::new(AtomicUsize::new(0)); + + let handles: Vec<_> = (0..8) + .map(|_| { + let state = Arc::clone(&state); + let barrier = Arc::clone(&barrier); + let health_call_count = Arc::clone(&health_call_count); + let spawn_count = Arc::clone(&spawn_count); + + tokio::spawn(async move { + let env_vars = HashMap::new(); + let health_check = move || -> HealthCheckFuture { + let health_call_count = Arc::clone(&health_call_count); + Box::pin(async move { + let call_index = health_call_count.fetch_add(1, Ordering::SeqCst); + call_index >= 2 + }) + }; + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + let spawn_index = spawn_count.fetch_add(1, Ordering::SeqCst); + Ok(10_000 + spawn_index as u32) + }; + + barrier.wait().await; + + start_opencode_service_inner( + state.as_ref(), + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 4, + Duration::from_millis(1), + ) + .await + }) + }) + .collect(); + + let mut spawned_results = 0; + + for handle in handles { + let result = handle.await.expect("task should join").expect("start should succeed"); + if result.spawned_now { + spawned_results += 1; + } + + assert!(result.app_owned); + } + + assert_eq!(spawned_results, 1); + assert_eq!(spawn_count.load(Ordering::SeqCst), 1); + assert_eq!(state.process.lock().await.child_pid, Some(10_000)); + assert!(state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn external_running_service_not_owned() { + let state = ServiceState::default(); + let env_vars = HashMap::new(); + let spawn_count = Arc::new(AtomicUsize::new(0)); + let health_check = || -> HealthCheckFuture { Box::pin(async { true }) }; + let spawn_count_for_closure = Arc::clone(&spawn_count); + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + spawn_count_for_closure.fetch_add(1, Ordering::SeqCst); + Ok(20_000) + }; + + let result = start_opencode_service_inner( + &state, + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 2, + Duration::from_millis(1), + ) + .await + .expect("start should succeed"); + + assert_eq!( + result, + StartOpencodeServiceResult { + spawned_now: false, + app_owned: false, + } + ); + assert_eq!(spawn_count.load(Ordering::SeqCst), 0); + assert_eq!(state.process.lock().await.child_pid, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn existing_owned_unhealthy_service_does_not_respawn() { + let state = ServiceState::default(); + state.set_child_pid(30_000).await; + + let env_vars = HashMap::new(); + let health_call_count = Arc::new(AtomicUsize::new(0)); + let spawn_count = Arc::new(AtomicUsize::new(0)); + let health_check = move || -> HealthCheckFuture { + let health_call_count = Arc::clone(&health_call_count); + Box::pin(async move { + let call_index = health_call_count.fetch_add(1, Ordering::SeqCst); + call_index >= 2 + }) + }; + let spawn_count_for_closure = Arc::clone(&spawn_count); + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + spawn_count_for_closure.fetch_add(1, Ordering::SeqCst); + Ok(30_001) + }; + + let result = start_opencode_service_inner( + &state, + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 4, + Duration::from_millis(1), + ) + .await + .expect("start should succeed"); + + assert_eq!( + result, + StartOpencodeServiceResult { + spawned_now: false, + app_owned: true, + } + ); + assert_eq!(spawn_count.load(Ordering::SeqCst), 0); + assert_eq!(state.process.lock().await.child_pid, Some(30_000)); + assert!(state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn spawn_timeout_preserves_pid_and_returns_true() { + let state = ServiceState::default(); + let env_vars = HashMap::new(); + let spawn_count = Arc::new(AtomicUsize::new(0)); + let health_check = || -> HealthCheckFuture { Box::pin(async { false }) }; + let spawn_count_for_closure = Arc::clone(&spawn_count); + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + let spawn_index = spawn_count_for_closure.fetch_add(1, Ordering::SeqCst); + Ok(40_000 + spawn_index as u32) + }; + + let result = start_opencode_service_inner( + &state, + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 2, + Duration::from_millis(1), + ) + .await + .expect("start should succeed"); + + assert_eq!( + result, + StartOpencodeServiceResult { + spawned_now: true, + app_owned: true, + } + ); + assert_eq!(spawn_count.load(Ordering::SeqCst), 1); + assert_eq!(state.process.lock().await.child_pid, Some(40_000)); + assert!(state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn healthy_owned_service_reports_app_owned_without_respawn() { + let state = ServiceState::default(); + state.set_child_pid(45_000).await; + let env_vars = HashMap::new(); + let spawn_count = Arc::new(AtomicUsize::new(0)); + let health_check = || -> HealthCheckFuture { Box::pin(async { true }) }; + let spawn_count_for_closure = Arc::clone(&spawn_count); + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + spawn_count_for_closure.fetch_add(1, Ordering::SeqCst); + Ok(45_001) + }; + + let result = start_opencode_service_inner( + &state, + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 2, + Duration::from_millis(1), + ) + .await + .expect("start should succeed"); + + assert_eq!( + result, + StartOpencodeServiceResult { + spawned_now: false, + app_owned: true, + } + ); + assert_eq!(spawn_count.load(Ordering::SeqCst), 0); + assert_eq!(state.process.lock().await.child_pid, Some(45_000)); + assert!(state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn reopened_keep_running_backend_remains_app_owned_for_final_close() { + let marker_path = unique_marker_path("reopen-keep-running"); + let owned_pid = std::process::id(); + let initial_state = ServiceState::new(marker_path.clone()); + + initial_state + .register_spawned_pid(owned_pid) + .await + .expect("initial ownership marker should persist"); + + let keep_running_action = prepare_close_service_action(&initial_state, false).await; + assert_eq!(keep_running_action, CloseServiceAction::KeepRunning); + assert_eq!(initial_state.process.lock().await.child_pid, Some(owned_pid)); + assert!(initial_state.we_started.load(Ordering::SeqCst)); + + let reopened_state = ServiceState::new(marker_path.clone()); + assert_eq!(reopened_state.process.lock().await.child_pid, Some(owned_pid)); + assert!(reopened_state.we_started.load(Ordering::SeqCst)); + + let env_vars = HashMap::new(); + let spawn_count = Arc::new(AtomicUsize::new(0)); + let health_check = || -> HealthCheckFuture { Box::pin(async { true }) }; + let spawn_count_for_closure = Arc::clone(&spawn_count); + let spawn = move |_binary_path: &str, _env_vars: &HashMap| { + spawn_count_for_closure.fetch_add(1, Ordering::SeqCst); + Ok(owned_pid + 1) + }; + + let result = start_opencode_service_inner( + &reopened_state, + TEST_URL, + TEST_BINARY_PATH, + &env_vars, + &health_check, + &spawn, + 2, + Duration::from_millis(1), + ) + .await + .expect("reopened start should succeed"); + + assert_eq!( + result, + StartOpencodeServiceResult { + spawned_now: false, + app_owned: true, + } + ); + assert_eq!(spawn_count.load(Ordering::SeqCst), 0); + assert!(reopened_state.we_started.load(Ordering::SeqCst)); + + let final_close_action = prepare_close_service_action(&reopened_state, true).await; + assert_eq!( + final_close_action, + CloseServiceAction::Stop { + pid: Some(owned_pid), + } + ); + assert!(!reopened_state.we_started.load(Ordering::SeqCst)); + assert!(!marker_path.exists()); + + let _ = fs::remove_file(marker_path); + } + + #[tokio::test] + async fn stop_takes_owned_pid_once() { + let state = ServiceState::default(); + state.set_child_pid(50_000).await; + state.we_started.store(true, Ordering::SeqCst); + + let first_stop_pid = state.take_owned_pid_for_shutdown().await; + let second_stop_pid = state.take_owned_pid_for_shutdown().await; + + assert_eq!(first_stop_pid, Some(50_000)); + assert_eq!(second_stop_pid, None); + assert_eq!(state.process.lock().await.child_pid, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn confirm_close_keep_service_preserves_pid() { + let state = ServiceState::default(); + state.set_child_pid(60_000).await; + state.we_started.store(true, Ordering::SeqCst); + + let action = prepare_close_service_action(&state, false).await; + + assert_eq!(action, CloseServiceAction::KeepRunning); + assert_eq!(state.process.lock().await.child_pid, Some(60_000)); + assert!(state.we_started.load(Ordering::SeqCst)); + } +} diff --git a/src-tauri/src/app/mod.rs b/src-tauri/src/app/mod.rs index 1abace0e..e7ff3996 100644 --- a/src-tauri/src/app/mod.rs +++ b/src-tauri/src/app/mod.rs @@ -188,6 +188,9 @@ pub fn run() { #[cfg(not(target_os = "android"))] { + let service_state_path = app.path().app_data_dir()?.join("owned-opencode-service.json"); + app.manage(service::ServiceState::new(service_state_path)); + let main_window = create_main_window(&app.handle())?; finish_desktop_window_setup(&main_window); @@ -221,7 +224,6 @@ pub fn run() { // Desktop: 注册 service management commands + 窗口关闭拦截 #[cfg(not(target_os = "android"))] let builder = builder - .manage(service::ServiceState::default()) .on_window_event(|window, event| { match event { tauri::WindowEvent::CloseRequested { api, .. } => { diff --git a/src-tauri/src/app/service.rs b/src-tauri/src/app/service.rs index 23bbb1ae..e2b1c74a 100644 --- a/src-tauri/src/app/service.rs +++ b/src-tauri/src/app/service.rs @@ -1,18 +1,643 @@ -use std::sync::atomic::{AtomicBool, AtomicU32}; +use serde::{Deserialize, Serialize}; +use std::{ + fs, + path::{Path, PathBuf}, + process::Command, + sync::atomic::{AtomicBool, Ordering}, +}; + +use tokio::sync::Mutex; + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct ServiceProcess { + pub child_pid: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +struct OwnedServiceMarker { + pid: u32, + process_key: String, +} + +#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] +struct OwnedServiceMarkerOnDisk { + pid: u32, + process_key: Option, + started_at: Option, +} /// 跟踪我们是否启动了 opencode serve 进程 pub struct ServiceState { - /// 我们启动的子进程 PID - pub child_pid: AtomicU32, + /// App 自己持有的后端进程状态必须经过此 mutex 串行化,避免并发命令覆盖彼此的 PID。 + pub process: Mutex, /// 是否由我们启动(用于关闭时判断是否需要询问) pub we_started: AtomicBool, + ownership_marker_path: Option, +} + +impl ServiceState { + pub fn new(ownership_marker_path: PathBuf) -> Self { + let restored_pid = Self::restore_owned_pid(&ownership_marker_path); + + Self { + process: Mutex::new(ServiceProcess { + child_pid: restored_pid, + }), + we_started: AtomicBool::new(restored_pid.is_some()), + ownership_marker_path: Some(ownership_marker_path), + } + } + + fn restore_owned_pid(ownership_marker_path: &Path) -> Option { + Self::restore_owned_pid_with( + ownership_marker_path, + current_process_identity_key, + current_legacy_process_started_at, + ) + } + + fn restore_owned_pid_with( + ownership_marker_path: &Path, + current_process_identity_key: IdentityKeyFn, + current_legacy_process_started_at: LegacyStartedAtFn, + ) -> Option + where + IdentityKeyFn: Fn(u32) -> Option, + LegacyStartedAtFn: Fn(u32) -> Option, + { + let marker = match Self::load_ownership_marker(ownership_marker_path) { + Some(marker) => marker, + None => return None, + }; + + if let Some(stored_process_key) = marker.process_key.as_deref() { + let Some(current_process_key) = current_process_identity_key(marker.pid) else { + Self::clear_ownership_marker_file(ownership_marker_path); + return None; + }; + + if stored_process_key == current_process_key { + return Some(marker.pid); + } + + Self::clear_ownership_marker_file(ownership_marker_path); + return None; + } + + if let Some(legacy_started_at) = marker.started_at.as_deref() { + if current_legacy_process_started_at(marker.pid).as_deref() == Some(legacy_started_at) { + if let Some(current_process_key) = current_process_identity_key(marker.pid) { + let migrated_marker = OwnedServiceMarker { + pid: marker.pid, + process_key: current_process_key, + }; + + if let Err(error) = Self::persist_ownership_marker_file( + ownership_marker_path, + &migrated_marker, + ) { + log::warn!( + "Failed to migrate legacy service ownership marker '{}': {}", + ownership_marker_path.display(), + error + ); + + Self::clear_ownership_marker_file(ownership_marker_path); + return None; + } + } + + return Some(marker.pid); + } + + // Legacy started_at markers depended on locale/timezone-sensitive output on Unix, + // so drifted values are unrecoverable and must be cleared instead of guessed. + Self::clear_ownership_marker_file(ownership_marker_path); + return None; + } + + Self::clear_ownership_marker_file(ownership_marker_path); + None + } + + fn load_ownership_marker(ownership_marker_path: &Path) -> Option { + let marker_json = fs::read_to_string(ownership_marker_path).ok()?; + serde_json::from_str(&marker_json).ok() + } + + fn clear_ownership_marker_file(ownership_marker_path: &Path) { + if let Err(error) = fs::remove_file(ownership_marker_path) { + if error.kind() != std::io::ErrorKind::NotFound { + log::warn!( + "Failed to clear service ownership marker '{}': {}", + ownership_marker_path.display(), + error + ); + } + } + } + + fn persist_ownership_marker(&self, marker: &OwnedServiceMarker) -> Result<(), String> { + let Some(ownership_marker_path) = self.ownership_marker_path.as_ref() else { + return Ok(()); + }; + + Self::persist_ownership_marker_file(ownership_marker_path, marker) + } + + fn persist_ownership_marker_file( + ownership_marker_path: &Path, + marker: &OwnedServiceMarker, + ) -> Result<(), String> { + if let Some(parent) = ownership_marker_path.parent() { + fs::create_dir_all(parent).map_err(|error| { + format!( + "Failed to create service ownership directory '{}': {}", + parent.display(), + error + ) + })?; + } + + let marker_json = serde_json::to_vec(marker) + .map_err(|error| format!("Failed to serialize service ownership marker: {}", error))?; + + fs::write(ownership_marker_path, marker_json).map_err(|error| { + format!( + "Failed to persist service ownership marker '{}': {}", + ownership_marker_path.display(), + error + ) + }) + } + + fn clear_ownership_marker(&self) { + if let Some(ownership_marker_path) = self.ownership_marker_path.as_ref() { + Self::clear_ownership_marker_file(ownership_marker_path); + } + } + + pub fn register_spawned_pid_locked( + &self, + process: &mut ServiceProcess, + pid: u32, + ) -> Result<(), String> { + let marker = if self.ownership_marker_path.is_some() { + Some(OwnedServiceMarker { + pid, + process_key: current_process_identity_key(pid).ok_or_else(|| { + format!("Failed to capture spawned process identity key for PID {}", pid) + })?, + }) + } else { + None + }; + + process.child_pid = Some(pid); + + if let Some(marker) = marker.as_ref() { + if let Err(error) = self.persist_ownership_marker(marker) { + process.child_pid = None; + self.we_started.store(false, Ordering::SeqCst); + return Err(error); + } + } + + self.we_started.store(true, Ordering::SeqCst); + Ok(()) + } + + #[cfg(test)] + pub async fn register_spawned_pid(&self, pid: u32) -> Result<(), String> { + let mut process = self.process.lock().await; + self.register_spawned_pid_locked(&mut process, pid) + } + + #[cfg(test)] + pub async fn set_child_pid(&self, pid: u32) { + self.process.lock().await.child_pid = Some(pid); + } + + #[cfg(test)] + pub async fn take_child_pid(&self) -> Option { + self.process.lock().await.child_pid.take() + } + + pub async fn take_owned_pid_for_shutdown(&self) -> Option { + let mut process = self.process.lock().await; + let pid = process.child_pid.take(); + self.we_started.store(false, Ordering::SeqCst); + self.clear_ownership_marker(); + drop(process); + pid + } + + #[cfg(test)] + pub async fn clear_child_pid(&self) { + self.process.lock().await.child_pid = None; + } } impl Default for ServiceState { fn default() -> Self { Self { - child_pid: AtomicU32::new(0), + process: Mutex::new(ServiceProcess::default()), we_started: AtomicBool::new(false), + ownership_marker_path: None, } } } + +fn current_process_identity_key(pid: u32) -> Option { + #[cfg(target_os = "windows")] + { + let started_at = current_windows_process_started_at_utc(pid)?; + Some(format!("windows:{started_at}")) + } + + #[cfg(target_os = "linux")] + { + current_linux_process_identity_key(pid) + } + + #[cfg(target_os = "macos")] + { + current_macos_process_identity_key(pid) + } + + #[cfg(all(unix, not(target_os = "linux"), not(target_os = "macos")))] + { + current_unix_legacy_process_started_at(pid) + } +} + +fn current_legacy_process_started_at(pid: u32) -> Option { + #[cfg(target_os = "windows")] + { + current_windows_process_started_at_utc(pid) + } + + #[cfg(target_os = "linux")] + { + current_unix_legacy_process_started_at(pid) + } + + #[cfg(target_os = "macos")] + { + current_unix_legacy_process_started_at(pid) + } + + #[cfg(all(unix, not(target_os = "linux"), not(target_os = "macos")))] + { + current_unix_legacy_process_started_at(pid) + } +} + +#[cfg(target_os = "windows")] +fn current_windows_process_started_at_utc(pid: u32) -> Option { + use std::os::windows::process::CommandExt; + + const CREATE_NO_WINDOW: u32 = 0x08000000; + + let mut command = Command::new("powershell"); + command.creation_flags(CREATE_NO_WINDOW); + + command + .args([ + "-NoLogo", + "-NoProfile", + "-NonInteractive", + "-Command", + &format!( + "$process = Get-Process -Id {} -ErrorAction SilentlyContinue; if ($process) {{ $process.StartTime.ToUniversalTime().ToString('o') }}", + pid + ), + ]) + .output() + .ok() + .filter(|output| output.status.success()) + .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_string()) + .filter(|started_at| !started_at.is_empty()) +} + +#[cfg(unix)] +fn current_unix_legacy_process_started_at(pid: u32) -> Option { + Command::new("ps") + .args(["-p", &pid.to_string(), "-o", "lstart="]) + .output() + .ok() + .filter(|output| output.status.success()) + .map(|output| String::from_utf8_lossy(&output.stdout).trim().to_string()) + .filter(|started_at| !started_at.is_empty()) +} + +#[cfg(target_os = "macos")] +fn current_macos_process_identity_key(pid: u32) -> Option { + let pid = i32::try_from(pid).ok()?; + let mut mib = [libc::CTL_KERN, libc::KERN_PROC, libc::KERN_PROC_PID, pid]; + let mut info = std::mem::MaybeUninit::::uninit(); + let mut info_len = std::mem::size_of::(); + + let sysctl_result = unsafe { + libc::sysctl( + mib.as_mut_ptr(), + mib.len() as libc::c_uint, + info.as_mut_ptr().cast(), + &mut info_len, + std::ptr::null_mut(), + 0, + ) + }; + + if sysctl_result != 0 || info_len < std::mem::size_of::() { + return None; + } + + let start_time = unsafe { info.assume_init().kp_proc.p_un.p_starttime }; + format_macos_process_identity_key(start_time.tv_sec, start_time.tv_usec) +} + +#[cfg(target_os = "macos")] +fn format_macos_process_identity_key( + seconds: libc::time_t, + microseconds: libc::suseconds_t, +) -> Option { + if seconds == 0 && microseconds == 0 { + return None; + } + + Some(format!("macos:{seconds}:{microseconds}")) +} + +#[cfg(target_os = "linux")] +fn current_linux_process_identity_key(pid: u32) -> Option { + let boot_id = fs::read_to_string("/proc/sys/kernel/random/boot_id") + .ok()? + .trim() + .to_string(); + + if boot_id.is_empty() { + return None; + } + + let stat = fs::read_to_string(format!("/proc/{pid}/stat")).ok()?; + let starttime_ticks = parse_linux_proc_stat_starttime(&stat)?; + + Some(format!("linux:{boot_id}:{starttime_ticks}")) +} + +#[cfg(target_os = "linux")] +fn parse_linux_proc_stat_starttime(stat: &str) -> Option<&str> { + let stat = stat.trim(); + let comm_end = stat.rfind(')')?; + let remainder = stat.get(comm_end + 1..)?.trim_start(); + let fields: Vec<&str> = remainder.split_whitespace().collect(); + + fields.get(19).copied() +} + +#[cfg(test)] +mod tests { + use std::{ + fs, + path::PathBuf, + sync::atomic::Ordering, + time::{SystemTime, UNIX_EPOCH}, + }; + + use super::ServiceState; + + fn missing_process_identity_key(_: u32) -> Option { + None + } + + fn matching_legacy_started_at(_: u32) -> Option { + Some("legacy-start-time".to_string()) + } + + fn unique_marker_path(test_name: &str) -> PathBuf { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_nanos(); + + std::env::temp_dir().join(format!( + "opencodeui-{test_name}-{}-{timestamp}.json", + std::process::id() + )) + } + + #[tokio::test] + async fn service_state_default() { + let state = ServiceState::default(); + + assert_eq!(state.process.lock().await.child_pid, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn service_state_take_pid() { + let state = ServiceState::default(); + + state.set_child_pid(1234).await; + assert_eq!(state.process.lock().await.child_pid, Some(1234)); + + assert_eq!(state.take_child_pid().await, Some(1234)); + assert_eq!(state.process.lock().await.child_pid, None); + + state.set_child_pid(5678).await; + state.clear_child_pid().await; + assert_eq!(state.take_child_pid().await, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn service_state_restores_owned_pid_from_marker() { + let marker_path = unique_marker_path("restore-owned-pid-canonical"); + let pid = std::process::id(); + let state = ServiceState::new(marker_path.clone()); + + state + .register_spawned_pid(pid) + .await + .expect("marker should be written"); + + let state = ServiceState::new(marker_path.clone()); + + assert_eq!(state.process.lock().await.child_pid, Some(pid)); + assert!(state.we_started.load(Ordering::SeqCst)); + + let _ = fs::remove_file(marker_path); + } + + #[tokio::test] + async fn service_state_discards_stale_marker() { + let marker_path = unique_marker_path("discard-stale-canonical-marker"); + let pid = std::process::id(); + + fs::write( + &marker_path, + format!("{{\"pid\":{pid},\"process_key\":\"stale-process-key\"}}"), + ) + .expect("marker should be written"); + + let state = ServiceState::new(marker_path.clone()); + + assert_eq!(state.process.lock().await.child_pid, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + assert!(!marker_path.exists()); + } + + #[tokio::test] + async fn register_spawned_pid_persists_marker() { + let marker_path = unique_marker_path("persist-process-key-marker"); + let pid = std::process::id(); + let state = ServiceState::new(marker_path.clone()); + + state + .register_spawned_pid(pid) + .await + .expect("marker persistence should succeed"); + + assert_eq!(state.process.lock().await.child_pid, Some(pid)); + assert!(state.we_started.load(Ordering::SeqCst)); + let marker = fs::read_to_string(&marker_path).expect("marker should exist"); + assert!(marker.contains(&format!("\"pid\":{pid}"))); + assert!(marker.contains("\"process_key\":\"")); + assert!(!marker.contains("\"started_at\":")); + + let taken_pid = state.take_owned_pid_for_shutdown().await; + assert_eq!(taken_pid, Some(pid)); + assert!(!marker_path.exists()); + } + + #[tokio::test] + async fn service_state_migrates_legacy_started_at_marker_to_canonical_process_key() { + let marker_path = unique_marker_path("migrate-legacy-started-at-marker"); + let pid = std::process::id(); + let legacy_started_at = super::current_legacy_process_started_at(pid) + .expect("legacy process identity should be readable for migration test"); + + fs::write( + &marker_path, + format!("{{\"pid\":{pid},\"started_at\":\"{legacy_started_at}\"}}"), + ) + .expect("legacy marker should be written"); + + let state = ServiceState::new(marker_path.clone()); + + assert_eq!(state.process.lock().await.child_pid, Some(pid)); + assert!(state.we_started.load(Ordering::SeqCst)); + + let migrated_marker = fs::read_to_string(&marker_path).expect("marker should be rewritten"); + assert!(migrated_marker.contains(&format!("\"pid\":{pid}"))); + assert!(migrated_marker.contains("\"process_key\":\"")); + assert!(!migrated_marker.contains("\"started_at\":")); + + let _ = fs::remove_file(marker_path); + } + + #[tokio::test] + async fn service_state_clears_unrecoverable_legacy_started_at_marker_after_locale_drift() { + let marker_path = unique_marker_path("clear-drifted-legacy-started-at-marker"); + let pid = std::process::id(); + + fs::write( + &marker_path, + format!("{{\"pid\":{pid},\"started_at\":\"locale-drifted-start-time\"}}"), + ) + .expect("legacy marker should be written"); + + let state = ServiceState::new(marker_path.clone()); + + assert_eq!(state.process.lock().await.child_pid, None); + assert!(!state.we_started.load(Ordering::SeqCst)); + assert!(!marker_path.exists()); + } + + #[test] + fn service_state_restores_legacy_marker_when_canonical_identity_is_unavailable_before_migration() { + let marker_path = unique_marker_path("restore-legacy-marker-without-canonical-identity"); + let pid = std::process::id(); + + fs::write( + &marker_path, + format!("{{\"pid\":{pid},\"started_at\":\"legacy-start-time\"}}"), + ) + .expect("legacy marker should be written"); + + let restored_pid = ServiceState::restore_owned_pid_with( + &marker_path, + missing_process_identity_key, + matching_legacy_started_at, + ); + + assert_eq!(restored_pid, Some(pid)); + let legacy_marker = fs::read_to_string(&marker_path).expect("legacy marker should remain for future migration"); + assert!(legacy_marker.contains(&format!("\"pid\":{pid}"))); + assert!(legacy_marker.contains("\"started_at\":\"legacy-start-time\"")); + assert!(!legacy_marker.contains("\"process_key\":")); + + let _ = fs::remove_file(marker_path); + } + + #[cfg(target_os = "linux")] + #[test] + fn parse_linux_proc_stat_starttime_simple_comm() { + let stat = "1234 (bash) S 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 987654 20 21 22"; + + assert_eq!(super::parse_linux_proc_stat_starttime(stat), Some("987654")); + } + + #[cfg(target_os = "linux")] + #[test] + fn parse_linux_proc_stat_starttime_comm_with_spaces() { + let stat = "1234 (code helper) S 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 7654321 20 21 22"; + + assert_eq!(super::parse_linux_proc_stat_starttime(stat), Some("7654321")); + } + + #[cfg(target_os = "linux")] + #[test] + fn parse_linux_proc_stat_starttime_comm_with_parentheses() { + let stat = "1234 (worker (beta) task) S 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 13579 20 21 22"; + + assert_eq!(super::parse_linux_proc_stat_starttime(stat), Some("13579")); + } + + #[cfg(target_os = "linux")] + #[test] + fn linux_current_process_identity_is_stable() { + let pid = std::process::id(); + let first = super::current_linux_process_identity_key(pid) + .expect("linux current process identity should be readable"); + let second = super::current_linux_process_identity_key(pid) + .expect("linux current process identity should remain readable"); + + assert!(first.starts_with("linux:")); + assert_eq!(first, second); + } + + #[cfg(target_os = "macos")] + #[test] + fn format_macos_process_identity_key() { + assert_eq!( + super::format_macos_process_identity_key(1_715_138_925, 42), + Some("macos:1715138925:42".to_string()) + ); + assert_eq!(super::format_macos_process_identity_key(0, 0), None); + } + + #[cfg(target_os = "macos")] + #[test] + fn macos_current_process_identity_is_stable() { + let pid = std::process::id(); + let first = super::current_macos_process_identity_key(pid) + .expect("macOS current process identity should be readable"); + let second = super::current_macos_process_identity_key(pid) + .expect("macOS current process identity should remain readable"); + + assert!(first.starts_with("macos:")); + assert_eq!(first, second); + } +} diff --git a/src/features/chat/ChatArea.test.ts b/src/features/chat/ChatArea.test.ts index 62725688..77476f52 100644 --- a/src/features/chat/ChatArea.test.ts +++ b/src/features/chat/ChatArea.test.ts @@ -1,7 +1,8 @@ import { describe, expect, it } from 'vitest' -import { buildTurnDurationMap } from './ChatArea' +import { buildMessageIdMap, buildTurnDurationMap } from './ChatArea' import { buildVisibleMessageEntries, getVisibleMessageForkTargetId } from './chatAreaVisibility' import type { Message, MessageError, Part, ToolPart, ReasoningPart } from '../../types/message' +import { isAssistantMessage } from '../../types/message' function createUserMessage(id: string, created: number): Message { return { @@ -157,3 +158,28 @@ describe('buildTurnDurationMap', () => { expect(durationMap.has('assistant-1')).toBe(false) }) }) + +describe('buildMessageIdMap', () => { + it('resolves assistant parent user messages without relying on adjacency', () => { + const parentUser = createUserMessage('user-1', 1000) + const otherUser = createUserMessage('user-2', 2000) + const assistant = createAssistantMessage('assistant-1', [], 2001) + expect(isAssistantMessage(assistant.info)).toBe(true) + if (!isAssistantMessage(assistant.info)) throw new Error('Expected assistant message') + + const messageIdMap = buildMessageIdMap([parentUser, otherUser, assistant]) + + expect(messageIdMap.get(assistant.info.parentID)).toBe(parentUser) + }) + + it('returns undefined safely when an assistant parent message is missing', () => { + const assistant = createAssistantMessage('assistant-missing-parent', [], 2001) + expect(isAssistantMessage(assistant.info)).toBe(true) + if (!isAssistantMessage(assistant.info)) throw new Error('Expected assistant message') + assistant.info.parentID = 'missing-user' + + const messageIdMap = buildMessageIdMap([assistant]) + + expect(messageIdMap.get(assistant.info.parentID)).toBeUndefined() + }) +}) diff --git a/src/features/chat/ChatArea.tsx b/src/features/chat/ChatArea.tsx index 999045f6..3e1d7855 100644 --- a/src/features/chat/ChatArea.tsx +++ b/src/features/chat/ChatArea.tsx @@ -96,6 +96,10 @@ export function buildTurnDurationMap(messages: Message[], visibleMessages: Messa return map } +export function buildMessageIdMap(messages: Message[]): Map { + return new Map(messages.map(message => [message.info.id, message])) +} + interface ChatAreaProps { messages: Message[] sessionId?: string | null @@ -169,6 +173,8 @@ export const ChatArea = memo( // ---- Data ---- const visibleMessageEntries = useMemo(() => buildVisibleMessageEntries(messages), [messages]) const visibleMessages = useMemo(() => visibleMessageEntries.map(e => e.message), [visibleMessageEntries]) + const visibleMessageCount = visibleMessages.length + const messageIdMap = useMemo(() => buildMessageIdMap(messages), [messages]) const forkTargetIdMap = useMemo( () => new Map(visibleMessageEntries.map(entry => [entry.message.info.id, getVisibleMessageForkTargetId(entry)])), @@ -282,7 +288,7 @@ export const ChatArea = memo( observer.observe(sentinel) return () => observer.disconnect() - }, [sessionId, visibleMessages]) + }, [sessionId]) // column-reverse 下 prepend 在负方向远端,scrollTop 不变,视口自然不跳。 // 不需要手动补偿。 @@ -299,6 +305,10 @@ export const ChatArea = memo( useEffect(() => { const root = scrollRef.current if (!root) return + if (visibleMessageCount === 0) { + onVisibleIdsChangeRef.current?.([]) + return + } const visibleIds = new Set() const observer = new IntersectionObserver( @@ -326,10 +336,12 @@ export const ChatArea = memo( // Observe all current message elements const elements = root.querySelectorAll('[data-message-id]') - elements.forEach(el => observer.observe(el)) + elements.forEach(el => { + observer.observe(el) + }) return () => observer.disconnect() - }, [visibleMessages]) + }, [visibleMessageCount]) // ============================================ // Imperative Handle @@ -415,6 +427,9 @@ export const ChatArea = memo( > = {}): ApiSession { + return { + id: 'session-1', + title: 'Active session title', + directory: '/workspace/project', + ...overrides, + } as ApiSession +} + +function renderItem(entry: ActiveSessionTreeEntry, resolvedSession?: ApiSession) { + const onSelect = vi.fn() + const view = render( + , + ) + + return { + ...view, + onSelect, + } +} + +describe('ActiveSessionItem', () => { + it('shows a neutral descendant-only label without self-active status labels or pulse animation', () => { + const descendantOnlyEntry = { + sessionId: 'session-descendant', + title: 'Ancestor row', + directory: '/workspace/project', + activitySource: 'descendant', + } satisfies ActiveSessionTreeEntry + + const { container } = renderItem(descendantOnlyEntry, createResolvedSession({ id: 'session-descendant', title: 'Ancestor row' })) + + expect(screen.getByText('Ancestor row')).toBeInTheDocument() + expect(screen.getByText('Active below')).toBeInTheDocument() + expect(screen.queryByText('Working')).not.toBeInTheDocument() + expect(screen.queryByText('Retrying')).not.toBeInTheDocument() + expect(screen.queryByText('Awaiting Permission')).not.toBeInTheDocument() + expect(screen.queryByText('Awaiting Answer')).not.toBeInTheDocument() + expect(container.querySelector('.animate-ping')).not.toBeInTheDocument() + }) + + it('keeps the working label and pulse animation for self-active busy entries', () => { + const busyEntry = { + sessionId: 'session-working', + activitySource: 'self', + status: { type: 'busy' }, + title: 'Working session', + directory: '/workspace/project', + } satisfies ActiveSessionTreeEntry + + const { container } = renderItem(busyEntry, createResolvedSession({ id: 'session-working', title: 'Working session' })) + + expect(screen.getByText('Working session')).toBeInTheDocument() + expect(screen.getByText('Working')).toBeInTheDocument() + expect(container.querySelector('.animate-ping')).toBeInTheDocument() + }) + + it('keeps unresolved descendant-only rows non-interactive without a resolved session', () => { + const descendantOnlyEntry = { + sessionId: 'session-descendant', + title: 'Ancestor row', + directory: '/workspace/project', + activitySource: 'descendant', + } satisfies ActiveSessionTreeEntry + + const onSelect = vi.fn() + render() + + const button = screen.getByRole('button', { name: /ancestor row/i }) + expect(button).toBeDisabled() + expect(button.draggable).toBe(false) + + fireEvent.click(button) + + expect(onSelect).not.toHaveBeenCalled() + }) + + it('keeps the pending permission label for self-active entries waiting on user input', () => { + const pendingEntry = { + sessionId: 'session-permission', + activitySource: 'self', + status: { type: 'busy' }, + title: 'Permission session', + directory: '/workspace/project', + pendingAction: { + type: 'permission', + description: 'write /workspace/project/file.ts', + }, + } satisfies ActiveSessionTreeEntry + + renderItem(pendingEntry, createResolvedSession({ id: 'session-permission', title: 'Permission session' })) + + expect(screen.getByText('Awaiting Permission')).toBeInTheDocument() + expect(screen.getByText('write /workspace/project/file.ts')).toBeInTheDocument() + }) + + it('keeps the retry label and attempt count for self-active retry entries', () => { + const retryEntry = { + sessionId: 'session-retry', + activitySource: 'self', + status: { type: 'retry', attempt: 3, next: Date.now() + 1000, message: 'Temporary failure' }, + title: 'Retry session', + directory: '/workspace/project', + } satisfies ActiveSessionTreeEntry + + renderItem(retryEntry, createResolvedSession({ id: 'session-retry', title: 'Retry session' })) + + expect(screen.getByText('Retrying')).toBeInTheDocument() + expect(screen.getByText('attempt 3')).toBeInTheDocument() + }) +}) diff --git a/src/features/chat/sidebar/ActiveSessionItem.tsx b/src/features/chat/sidebar/ActiveSessionItem.tsx index bd3204e6..1e833f41 100644 --- a/src/features/chat/sidebar/ActiveSessionItem.tsx +++ b/src/features/chat/sidebar/ActiveSessionItem.tsx @@ -1,9 +1,9 @@ import { useTranslation } from 'react-i18next' -import type { ActiveSessionEntry } from '../../../store/activeSessionStore' import type { ApiSession } from '../../../api' +import type { ActiveSessionTreeEntry } from './activeSessionTree' interface ActiveSessionItemProps { - entry: ActiveSessionEntry + entry: ActiveSessionTreeEntry /** 从 sessions 列表或 API 拉取到的完整 session 对象 */ resolvedSession?: ApiSession isSelected: boolean @@ -12,16 +12,17 @@ interface ActiveSessionItemProps { export function ActiveSessionItem({ entry, resolvedSession, isSelected, onSelect }: ActiveSessionItemProps) { const { t } = useTranslation(['chat', 'common']) - const isRetry = entry.status.type === 'retry' - const pending = entry.pendingAction + const isSelfActive = entry.activitySource === 'self' + const isRetry = isSelfActive && entry.status.type === 'retry' + const pending = isSelfActive ? entry.pendingAction : undefined // 标题优先从 resolvedSession 取,然后 fallback 到 entry.title(sessionMeta),最后截取 ID const displayTitle = resolvedSession?.title || entry.title || entry.sessionId.slice(0, 12) + '...' // 目录优先从 resolvedSession 取 const directory = resolvedSession?.directory || entry.directory // 状态显示:permission > question > retry > working - const statusConfig = - pending?.type === 'permission' + const statusConfig = isSelfActive + ? pending?.type === 'permission' ? { label: t('activeSession.awaitingPermission'), color: 'text-warning-100', @@ -33,13 +34,17 @@ export function ActiveSessionItem({ entry, resolvedSession, isSelected, onSelect : isRetry ? { label: t('activeSession.retrying'), color: 'text-warning-100', dotColor: 'bg-warning-100', pulse: false } : { label: t('activeSession.working'), color: 'text-success-100', dotColor: 'bg-success-100', pulse: true } + : { + label: t('activeSession.activeBelow'), + color: 'text-text-400', + dotColor: 'bg-text-400', + pulse: false, + } const handleClick = () => { if (resolvedSession) { onSelect(resolvedSession) } - // 如果没有 resolvedSession(极端情况:API 拉取失败),不做任何事 - // 用户可以等 session 数据加载完,或从 Recents tab 找到 } // 拖拽到主信息流进行分屏 / 替换会话 @@ -58,11 +63,13 @@ export function ActiveSessionItem({ entry, resolvedSession, isSelected, onSelect } return ( -
@@ -91,7 +98,7 @@ export function ActiveSessionItem({ entry, resolvedSession, isSelected, onSelect {pending.description} )} - {isRetry && entry.status.type === 'retry' && ( + {isSelfActive && isRetry && entry.status.type === 'retry' && ( <> · @@ -109,6 +116,6 @@ export function ActiveSessionItem({ entry, resolvedSession, isSelected, onSelect )}
- + ) } diff --git a/src/features/chat/sidebar/SessionChildrenSlot.test.tsx b/src/features/chat/sidebar/SessionChildrenSlot.test.tsx new file mode 100644 index 00000000..d9f12552 --- /dev/null +++ b/src/features/chat/sidebar/SessionChildrenSlot.test.tsx @@ -0,0 +1,216 @@ +import { act, render, screen, waitFor } from '@testing-library/react' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import type { ApiSession } from '../../../api' +import { SessionChildrenSlot } from './SessionChildrenSlot' + +const { getSessionChildrenMock, layoutState, busySessionsState } = vi.hoisted(() => ({ + getSessionChildrenMock: vi.fn(), + layoutState: { sidebarSubSessionSortOrder: 'activeAsc' as 'activeAsc' | 'activeDesc' }, + busySessionsState: [] as Array<{ sessionId: string }>, +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('../../../api', () => ({ + getSessionChildren: getSessionChildrenMock, + updateSession: vi.fn(), + deleteSession: vi.fn(), +})) + +vi.mock('../../../hooks/useInputCapabilities', () => ({ + useInputCapabilities: () => ({ preferTouchUi: false }), +})) + +vi.mock('../../../store', () => ({ + useLayoutStore: () => layoutState, + useBusySessions: () => busySessionsState, +})) + +vi.mock('../../../components/ui/ConfirmDialog', () => ({ + ConfirmDialog: () => null, +})) + +vi.mock('../../sessions', () => ({ + SessionListItem: ({ session }: { session: ApiSession }) =>
{session.title}
, +})) + +function createSession(id: string, title: string, created: number, updated?: number): ApiSession { + return { + id, + title, + directory: '/workspace/project', + time: { + created, + ...(updated === undefined ? {} : { updated }), + }, + } as ApiSession +} + +function createParentSession(): ApiSession { + return createSession('parent-1', 'Parent session', 50) +} + +function getRenderedTitles() { + return screen.getAllByTestId('child-session-row').map(node => node.textContent) +} + +describe('SessionChildrenSlot', () => { + beforeEach(() => { + layoutState.sidebarSubSessionSortOrder = 'activeAsc' + busySessionsState.length = 0 + getSessionChildrenMock.mockReset() + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('sorts provided child sessions by active time ascending without mutating the input array', () => { + const providedChildren = [ + createSession('child-3', 'Third active', 30, 300), + createSession('child-1', 'First active', 10, 100), + createSession('child-2', 'Second active', 20, 200), + ] + const originalTitles = providedChildren.map(session => session.title) + + render( + + {providedChildren} + , + ) + + expect(getRenderedTitles()).toEqual(['First active', 'Second active', 'Third active']) + expect(providedChildren.map(session => session.title)).toEqual(originalTitles) + }) + + it('sorts fetched child sessions by active time descending', async () => { + layoutState.sidebarSubSessionSortOrder = 'activeDesc' + getSessionChildrenMock.mockResolvedValue([ + createSession('child-1', 'First active', 10, 100), + createSession('child-3', 'Third active', 30, 300), + createSession('child-2', 'Second active', 20, 200), + ]) + + render( + , + ) + + await waitFor(() => { + expect(getRenderedTitles()).toEqual(['Third active', 'Second active', 'First active']) + }) + }) + + it('preserves deterministic relative order when active times fall into the just-now bucket', () => { + const now = Date.now() + const providedChildren = [ + createSession('child-a', 'Alpha', 20, now - 10_000), + createSession('child-b', 'Bravo', 20, now - 20_000), + createSession('child-c', 'Charlie', 20, now - 30_000), + ] + const originalTitles = providedChildren.map(session => session.title) + + render( + + {providedChildren} + , + ) + + expect(getRenderedTitles()).toEqual(['Alpha', 'Bravo', 'Charlie']) + expect(providedChildren.map(session => session.title)).toEqual(originalTitles) + }) + + it('ranks busy just-now sessions ahead of non-busy just-now sessions', () => { + const now = Date.now() + const providedChildren = [ + createSession('child-a', 'Inactive now', 20, now - 10_000), + createSession('child-b', 'Busy now', 20, now - 20_000), + createSession('child-c', 'Inactive now 2', 20, now - 30_000), + ] + + busySessionsState.push({ sessionId: 'child-b' }) + + render( + + {providedChildren} + , + ) + + expect(getRenderedTitles()).toEqual(['Busy now', 'Inactive now', 'Inactive now 2']) + }) + + it('does not mutate the fetched child array while sorting by active time', async () => { + layoutState.sidebarSubSessionSortOrder = 'activeDesc' + const fetchedChildren = [ + createSession('child-1', 'First active', 10, 100), + createSession('child-3', 'Third active', 30, 300), + createSession('child-2', 'Second active', 20, 200), + ] + const originalTitles = fetchedChildren.map(session => session.title) + getSessionChildrenMock.mockResolvedValue(fetchedChildren) + + render( + , + ) + + await waitFor(() => { + expect(getRenderedTitles()).toEqual(['Third active', 'Second active', 'First active']) + }) + + expect(fetchedChildren.map(session => session.title)).toEqual(originalTitles) + }) + + it('re-sorts by active time after sessions age out of the just-now bucket', () => { + vi.useFakeTimers() + vi.setSystemTime(new Date('2026-05-19T00:00:00.000Z')) + + const providedChildren = [ + createSession('child-a', 'Alpha', 20, Date.now() - 10_000), + createSession('child-b', 'Bravo', 20, Date.now() - 30_000), + createSession('child-c', 'Charlie', 20, Date.now() - 20_000), + ] + + render( + + {providedChildren} + , + ) + + expect(getRenderedTitles()).toEqual(['Alpha', 'Bravo', 'Charlie']) + + act(() => { + vi.advanceTimersByTime(61_000) + }) + + expect(getRenderedTitles()).toEqual(['Bravo', 'Charlie', 'Alpha']) + }) +}) diff --git a/src/features/chat/sidebar/SessionChildrenSlot.tsx b/src/features/chat/sidebar/SessionChildrenSlot.tsx index 91edb4ea..64d1482d 100644 --- a/src/features/chat/sidebar/SessionChildrenSlot.tsx +++ b/src/features/chat/sidebar/SessionChildrenSlot.tsx @@ -2,12 +2,15 @@ // fetchAll=true → /children 拉全量,children 有值 → 直接渲染 // 删除/重命名自己管自己的状态,和主列表行为完全一致 -import { useState, useEffect, useCallback } from 'react' +import { useState, useEffect, useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' +import { MS_PER_MINUTE } from '../../../constants' import { getSessionChildren, updateSession, deleteSession as apiDeleteSession, type ApiSession } from '../../../api' import { SpinnerIcon } from '../../../components/Icons' import { ConfirmDialog } from '../../../components/ui/ConfirmDialog' import { useInputCapabilities } from '../../../hooks/useInputCapabilities' +import { useNow } from '../../../hooks/useNow' +import { useBusySessions, useLayoutStore } from '../../../store' import { uiErrorHandler } from '../../../utils' import { SessionListItem } from '../../sessions' @@ -38,6 +41,9 @@ export function SessionChildrenSlot({ }: SessionChildrenSlotProps) { const { t } = useTranslation(['chat', 'common']) const { preferTouchUi } = useInputCapabilities() + const { sidebarSubSessionSortOrder } = useLayoutStore() + const busySessions = useBusySessions() + const now = useNow(MS_PER_MINUTE) const [fetched, setFetched] = useState([]) const [loading, setLoading] = useState(false) const [deleteConfirm, setDeleteConfirm] = useState<{ isOpen: boolean; sessionId: string | null }>({ @@ -92,7 +98,34 @@ export function SessionChildrenSlot({ } }, [deleteConfirm.sessionId, selectedSessionId, onDeleteSelected]) - const list = fetchAll ? fetched : givenChildren + const list = useMemo(() => { + const source = fetchAll ? fetched : givenChildren + if (!source) return source + const busySessionIds = new Set(busySessions.map(entry => entry.sessionId)) + + return [...source].sort((left, right) => { + const leftActive = left.time.updated ?? left.time.created + const rightActive = right.time.updated ?? right.time.created + const leftIsJustNow = now - leftActive < MS_PER_MINUTE + const rightIsJustNow = now - rightActive < MS_PER_MINUTE + const leftIsBusy = busySessionIds.has(left.id) + const rightIsBusy = busySessionIds.has(right.id) + + if (leftIsJustNow && rightIsJustNow) { + if (leftIsBusy !== rightIsBusy) { + return leftIsBusy ? -1 : 1 + } + + return 0 + } + + if (sidebarSubSessionSortOrder === 'activeDesc') { + return rightActive - leftActive + } + + return leftActive - rightActive + }) + }, [busySessions, fetchAll, fetched, givenChildren, now, sidebarSubSessionSortOrder]) if (!list?.length && !loading) return null diff --git a/src/features/chat/sidebar/SidePanel.tsx b/src/features/chat/sidebar/SidePanel.tsx index 1354638c..61d9cb56 100644 --- a/src/features/chat/sidebar/SidePanel.tsx +++ b/src/features/chat/sidebar/SidePanel.tsx @@ -7,7 +7,8 @@ import { ConfirmDialog } from '../../../components/ui/ConfirmDialog' import { ActiveSessionItem } from './ActiveSessionItem' import { NotificationItem } from './NotificationItem' import { SidebarFooter } from './SidebarFooter' -import { buildActiveSessionTree } from './activeSessionTree' +import { buildActiveSessionTree, type ActiveSessionTreeEntry } from './activeSessionTree' +import { buildActiveTreeSessionTargets } from './activeSessionTargets' import { getParentPath } from './sidebarUtils' import { SidebarIcon, @@ -75,6 +76,27 @@ interface ProjectItem { sectionKind?: 'project' | 'workspace' } +let childSessionStoreVersion = 0 + +function subscribeToChildSessionStoreVersion(onStoreChange: () => void) { + return childSessionStore.subscribe(() => { + childSessionStoreVersion += 1 + onStoreChange() + }) +} + +function getChildSessionStoreVersion() { + return childSessionStoreVersion +} + +function useChildSessionStoreVersion() { + return useSyncExternalStore( + subscribeToChildSessionStoreVersion, + getChildSessionStoreVersion, + getChildSessionStoreVersion, + ) +} + function getSelectionRange(visibleIds: string[], anchorId: string, targetId: string) { const startIndex = visibleIds.indexOf(anchorId) const endIndex = visibleIds.indexOf(targetId) @@ -261,11 +283,7 @@ export function SidePanel({ // Active sessions const busySessions = useBusySessions() const busyCount = useBusyCount() - const childSessionVersion = useSyncExternalStore( - childSessionStore.subscribe.bind(childSessionStore), - childSessionStore.getVersion, - childSessionStore.getVersion, - ) + const childSessionMetadataVersion = useChildSessionStoreVersion() // Notification history const notifications = useNotifications() const unreadNotificationCount = useUnreadNotificationCount() @@ -296,51 +314,47 @@ export function SidePanel({ return map }, [sessions, fetchedSessions]) - // 异步拉取不在 lookup 中的 active/notification/selected session - useEffect(() => { - const allNeeded = [ - ...busySessions.map(e => ({ sessionId: e.sessionId, directory: e.directory })), - ...notifications.map(e => ({ sessionId: e.sessionId, directory: e.directory })), - ] - if (selectedSessionId && !sessionLookup.has(selectedSessionId)) { - allNeeded.push({ sessionId: selectedSessionId, directory: currentDirectory || '' }) - } - const missing = allNeeded.filter(entry => !sessionLookup.has(entry.sessionId)) - if (missing.length === 0) return - - let cancelled = false - const fetchMissing = async () => { - const results: Record = {} - await Promise.allSettled( - missing.map(async entry => { - try { - const session = await getSession(entry.sessionId, entry.directory) - if (!cancelled) results[session.id] = session - } catch { - /* ignore */ - } - }), - ) - if (!cancelled && Object.keys(results).length > 0) { - setFetchedSessions(prev => ({ ...prev, ...results })) - } - } - fetchMissing() - return () => { - cancelled = true - } - }, [busySessions, notifications, sessionLookup, selectedSessionId, currentDirectory]) - // ---- 子 session 展示数据 ---- const rootSessionIds = useMemo(() => new Set(sessions.map(s => s.id)), [sessions]) + const getChildSessionInfo = useCallback( + (sessionId: string) => { + void childSessionMetadataVersion + return childSessionStore.getSessionInfo(sessionId) + }, + [childSessionMetadataVersion], + ) + const findParentId = useCallback( (id: string) => { const s = sessionLookup.get(id) if (s?.parentID) return s.parentID - return childSessionStore.getSessionInfo(id)?.parentID + return getChildSessionInfo(id)?.parentID + }, + [sessionLookup, getChildSessionInfo], + ) + + const resolveActiveTreeEntry = useCallback( + (sessionId: string) => { + const resolvedSession = sessionLookup.get(sessionId) + if (resolvedSession) { + return { + title: resolvedSession.title, + directory: resolvedSession.directory, + } + } + + const childSessionInfo = getChildSessionInfo(sessionId) + if (childSessionInfo) { + return { + title: childSessionInfo.title, + directory: childSessionInfo.directory, + } + } + + return undefined }, - [sessionLookup, childSessionVersion], + [sessionLookup, getChildSessionInfo], ) // 开关开 → 拉 /children 全量:选中的 root 或选中子 session 时保持其父展开 @@ -392,10 +406,50 @@ export function SidePanel({ ]) const activeSessionTree = useMemo( - () => buildActiveSessionTree(busySessions, findParentId), - [busySessions, findParentId], + () => buildActiveSessionTree(busySessions, findParentId, resolveActiveTreeEntry), + [busySessions, findParentId, resolveActiveTreeEntry], + ) + const activeTreeSessionTargets = useMemo( + () => buildActiveTreeSessionTargets(activeSessionTree), + [activeSessionTree], ) + // 异步拉取不在 lookup 中的 active/notification/selected session + useEffect(() => { + const allNeeded = [ + ...busySessions.map(e => ({ sessionId: e.sessionId, directory: e.directory })), + ...activeTreeSessionTargets, + ...notifications.map(e => ({ sessionId: e.sessionId, directory: e.directory })), + ] + if (selectedSessionId && !sessionLookup.has(selectedSessionId)) { + allNeeded.push({ sessionId: selectedSessionId, directory: currentDirectory || '' }) + } + const missing = allNeeded.filter(entry => !sessionLookup.has(entry.sessionId)) + if (missing.length === 0) return + + let cancelled = false + const fetchMissing = async () => { + const results: Record = {} + await Promise.allSettled( + missing.map(async entry => { + try { + const session = await getSession(entry.sessionId, entry.directory) + if (!cancelled) results[session.id] = session + } catch { + /* ignore */ + } + }), + ) + if (!cancelled && Object.keys(results).length > 0) { + setFetchedSessions(prev => ({ ...prev, ...results })) + } + } + fetchMissing() + return () => { + cancelled = true + } + }, [busySessions, activeTreeSessionTargets, notifications, sessionLookup, selectedSessionId, currentDirectory]) + const buildProjectGroups = useCallback( (directories: typeof savedDirectories): ProjectItem[] => { const savedNameByPath = new Map( @@ -603,7 +657,9 @@ export function SidePanel({ const handleRemoveProject = useCallback( (projectId: string) => { - getProjectDirectoriesToRemove(projectId).forEach(directory => removeDirectory(directory)) + getProjectDirectoriesToRemove(projectId).forEach(directory => { + removeDirectory(directory) + }) }, [getProjectDirectoriesToRemove, removeDirectory], ) @@ -649,7 +705,7 @@ export function SidePanel({ ) const renderActiveSessionNode = useCallback( - function renderActiveSessionNode(entry: (typeof busySessions)[number], level = 0): ReactNode { + function renderActiveSessionNode(entry: ActiveSessionTreeEntry, level = 0): ReactNode { const resolvedSession = sessionLookup.get(entry.sessionId) const childEntries = activeSessionTree.childrenByParent.get(entry.sessionId) ?? [] @@ -760,7 +816,9 @@ export function SidePanel({ const handleBatchRemoveProjects = useCallback(() => { if (selectedProjectIds.size === 0) return for (const projectId of selectedProjectIds) { - getProjectDirectoriesToRemove(projectId).forEach(directory => removeDirectory(directory)) + getProjectDirectoriesToRemove(projectId).forEach(directory => { + removeDirectory(directory) + }) } setSelectedProjectIds(new Set()) projectSelectionAnchorIdRef.current = null @@ -827,6 +885,7 @@ export function SidePanel({ style={{ justifyContent: showLabels ? 'flex-end' : 'center', paddingRight: showLabels ? 8 : 0 }} > )} + + + {floatingMenu} diff --git a/src/features/chat/sidebar/activeSessionTargets.test.ts b/src/features/chat/sidebar/activeSessionTargets.test.ts new file mode 100644 index 00000000..4d1ee4f7 --- /dev/null +++ b/src/features/chat/sidebar/activeSessionTargets.test.ts @@ -0,0 +1,100 @@ +import { describe, expect, it } from 'vitest' +import { buildActiveTreeSessionTargets } from './activeSessionTargets' +import type { ActiveSessionTree } from './activeSessionTree' + +describe('buildActiveTreeSessionTargets', () => { + it('propagates a child directory upward so inferred ancestors become hydratable', () => { + const tree: ActiveSessionTree = { + rootEntries: [ + { + sessionId: 'root-1', + activitySource: 'descendant', + title: 'Root', + }, + ], + childrenByParent: new Map([ + [ + 'root-1', + [ + { + sessionId: 'parent-1', + activitySource: 'descendant', + title: 'Parent', + }, + ], + ], + [ + 'parent-1', + [ + { + sessionId: 'child-1', + activitySource: 'self', + status: { type: 'busy' }, + directory: '/workspace/demo', + }, + ], + ], + ]), + } + + expect(buildActiveTreeSessionTargets(tree)).toEqual([ + { sessionId: 'child-1', directory: '/workspace/demo' }, + { sessionId: 'parent-1', directory: '/workspace/demo' }, + { sessionId: 'root-1', directory: '/workspace/demo' }, + ]) + }) + + it('keeps a known ancestor directory instead of overwriting it from descendants', () => { + const tree: ActiveSessionTree = { + rootEntries: [ + { + sessionId: 'root-1', + activitySource: 'descendant', + directory: '/workspace/root', + }, + ], + childrenByParent: new Map([ + [ + 'root-1', + [ + { + sessionId: 'child-1', + activitySource: 'self', + status: { type: 'busy' }, + directory: '/workspace/child', + }, + ], + ], + ]), + } + + expect(buildActiveTreeSessionTargets(tree)).toEqual([ + { sessionId: 'child-1', directory: '/workspace/child' }, + { sessionId: 'root-1', directory: '/workspace/root' }, + ]) + }) + + it('skips targets that still have no usable directory anywhere in the branch', () => { + const tree: ActiveSessionTree = { + rootEntries: [ + { + sessionId: 'root-1', + activitySource: 'descendant', + }, + ], + childrenByParent: new Map([ + [ + 'root-1', + [ + { + sessionId: 'child-1', + activitySource: 'descendant', + }, + ], + ], + ]), + } + + expect(buildActiveTreeSessionTargets(tree)).toEqual([]) + }) +}) diff --git a/src/features/chat/sidebar/activeSessionTargets.ts b/src/features/chat/sidebar/activeSessionTargets.ts new file mode 100644 index 00000000..5d978ede --- /dev/null +++ b/src/features/chat/sidebar/activeSessionTargets.ts @@ -0,0 +1,35 @@ +import type { ActiveSessionTree, ActiveSessionTreeEntry } from './activeSessionTree' + +export interface ActiveSessionFetchTarget { + sessionId: string + directory: string +} + +export function buildActiveTreeSessionTargets(activeSessionTree: ActiveSessionTree): ActiveSessionFetchTarget[] { + const targets = new Map() + + const visit = (entry: ActiveSessionTreeEntry): string | undefined => { + const children = activeSessionTree.childrenByParent.get(entry.sessionId) ?? [] + + let inheritedDirectory = entry.directory + + for (const child of children) { + const childDirectory = visit(child) + if (!inheritedDirectory && childDirectory) { + inheritedDirectory = childDirectory + } + } + + if (inheritedDirectory) { + targets.set(entry.sessionId, inheritedDirectory) + } + + return inheritedDirectory + } + + for (const rootEntry of activeSessionTree.rootEntries) { + visit(rootEntry) + } + + return Array.from(targets, ([sessionId, directory]) => ({ sessionId, directory })) +} diff --git a/src/features/chat/sidebar/activeSessionTree.test.ts b/src/features/chat/sidebar/activeSessionTree.test.ts index 1015033f..2a6079a5 100644 --- a/src/features/chat/sidebar/activeSessionTree.test.ts +++ b/src/features/chat/sidebar/activeSessionTree.test.ts @@ -2,6 +2,11 @@ import { describe, expect, it } from 'vitest' import type { ActiveSessionEntry } from '../../../store/activeSessionStore' import { buildActiveSessionTree } from './activeSessionTree' +type DisplayEntry = { + sessionId: string + activitySource: 'self' | 'descendant' +} + function makeEntry(sessionId: string): ActiveSessionEntry { return { sessionId, @@ -9,33 +14,214 @@ function makeEntry(sessionId: string): ActiveSessionEntry { } } +function displayEntry(sessionId: string, activitySource: 'self' | 'descendant'): DisplayEntry { + return { sessionId, activitySource } +} + +function mapTree(tree: ReturnType) { + return { + rootEntries: tree.rootEntries.map(entry => displayEntry(entry.sessionId, entry.activitySource)), + childrenByParent: new Map( + Array.from(tree.childrenByParent.entries(), ([parentId, entries]) => [ + parentId, + entries.map(entry => displayEntry(entry.sessionId, entry.activitySource)), + ]), + ), + } +} + +function noResolvedEntry() { + return undefined +} + describe('buildActiveSessionTree', () => { - it('nests active children under an active parent', () => { - const root = makeEntry('root') - const child = makeEntry('child') - const grandchild = makeEntry('grandchild') + it('nests an active child under an active parent with self activity sources', () => { + const root = makeEntry('root-1') + const child = makeEntry('parent-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, child], + sessionId => { + if (sessionId === 'parent-1') return 'root-1' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self')]) + expect(tree.childrenByParent).toEqual(new Map([['root-1', [displayEntry('parent-1', 'self')]]])) + }) + + it('injects inactive ancestors to preserve a deep active chain', () => { + const root = makeEntry('root-1') + const leaf = makeEntry('leaf-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, leaf], + sessionId => { + if (sessionId === 'leaf-1') return 'child-1' + if (sessionId === 'child-1') return 'parent-1' + if (sessionId === 'parent-1') return 'root-1' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self')]) + expect(tree.childrenByParent).toEqual( + new Map([ + ['root-1', [displayEntry('parent-1', 'descendant')]], + ['parent-1', [displayEntry('child-1', 'descendant')]], + ['child-1', [displayEntry('leaf-1', 'self')]], + ]), + ) + }) + + it('keeps sibling order while injecting a shared inactive ancestor once', () => { + const root = makeEntry('root-1') + const child = makeEntry('child-1') + const sibling = makeEntry('sibling-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, child, sibling], + sessionId => { + if (sessionId === 'child-1' || sessionId === 'sibling-1') return 'parent-1' + if (sessionId === 'parent-1') return 'root-1' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self')]) + expect(tree.childrenByParent).toEqual( + new Map([ + ['root-1', [displayEntry('parent-1', 'descendant')]], + ['parent-1', [displayEntry('child-1', 'self'), displayEntry('sibling-1', 'self')]], + ]), + ) + }) + + it('prefers a self-active parent entry over descendant-only injection', () => { + const root = makeEntry('root-1') + const parent = makeEntry('parent-1') + const child = makeEntry('child-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, parent, child], + sessionId => { + if (sessionId === 'parent-1') return 'root-1' + if (sessionId === 'child-1') return 'parent-1' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self')]) + expect(tree.childrenByParent).toEqual( + new Map([ + ['root-1', [displayEntry('parent-1', 'self')]], + ['parent-1', [displayEntry('child-1', 'self')]], + ]), + ) + }) + + it('replaces a descendant placeholder when a busy parent is encountered after its child', () => { + const root = makeEntry('root-1') + const parent = makeEntry('parent-1') + const child = makeEntry('child-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, child, parent], + sessionId => { + if (sessionId === 'parent-1') return 'root-1' + if (sessionId === 'child-1') return 'parent-1' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self')]) + expect(tree.childrenByParent).toEqual( + new Map([ + ['root-1', [displayEntry('parent-1', 'self')]], + ['parent-1', [displayEntry('child-1', 'self')]], + ]), + ) + }) + + it('keeps active children visible when an ancestor chain cannot be fully resolved', () => { + const root = makeEntry('root-1') + const orphanChild = makeEntry('orphan-child-1') + + const tree = mapTree( + buildActiveSessionTree( + [root, orphanChild], + sessionId => { + if (sessionId === 'orphan-child-1') return 'missing-parent' + return undefined + }, + noResolvedEntry, + ), + ) + + expect(tree.rootEntries).toEqual([displayEntry('root-1', 'self'), displayEntry('missing-parent', 'descendant')]) + expect(tree.childrenByParent).toEqual( + new Map([['missing-parent', [displayEntry('orphan-child-1', 'self')]]]), + ) + }) + + it('breaks parent cycles by falling back to top-level self entries', () => { + const child = makeEntry('child-1') + const leaf = makeEntry('leaf-1') - const tree = buildActiveSessionTree([root, child, grandchild], sessionId => { - if (sessionId === 'child') return 'root' - if (sessionId === 'grandchild') return 'child' - return undefined - }) + const tree = mapTree( + buildActiveSessionTree( + [child, leaf], + sessionId => { + if (sessionId === 'child-1') return 'leaf-1' + if (sessionId === 'leaf-1') return 'child-1' + return undefined + }, + noResolvedEntry, + ), + ) - expect(tree.rootEntries).toEqual([root]) - expect(tree.childrenByParent.get('root')).toEqual([child]) - expect(tree.childrenByParent.get('child')).toEqual([grandchild]) + expect(tree.rootEntries).toEqual([displayEntry('child-1', 'self'), displayEntry('leaf-1', 'self')]) + expect(tree.childrenByParent).toEqual(new Map()) }) - it('promotes an active child to the top level when its parent is not active', () => { - const sibling = makeEntry('sibling') - const childOnly = makeEntry('child-only') + it('does not promote an active child when its inactive parent should be preserved', () => { + const sibling = makeEntry('sibling-1') + const childOnly = makeEntry('child-1') - const tree = buildActiveSessionTree([sibling, childOnly], sessionId => { - if (sessionId === 'child-only') return 'idle-parent' - return undefined - }) + const tree = mapTree( + buildActiveSessionTree( + [sibling, childOnly], + sessionId => { + if (sessionId === 'child-1') return 'parent-1' + if (sessionId === 'parent-1') return 'root-1' + return undefined + }, + noResolvedEntry, + ), + ) - expect(tree.rootEntries).toEqual([sibling, childOnly]) - expect(tree.childrenByParent.size).toBe(0) + expect(tree.rootEntries).toEqual([displayEntry('sibling-1', 'self'), displayEntry('root-1', 'descendant')]) + expect(tree.childrenByParent).toEqual( + new Map([ + ['root-1', [displayEntry('parent-1', 'descendant')]], + ['parent-1', [displayEntry('child-1', 'self')]], + ]), + ) }) }) diff --git a/src/features/chat/sidebar/activeSessionTree.ts b/src/features/chat/sidebar/activeSessionTree.ts index 2b5dd78d..1b3adf19 100644 --- a/src/features/chat/sidebar/activeSessionTree.ts +++ b/src/features/chat/sidebar/activeSessionTree.ts @@ -1,31 +1,152 @@ import type { ActiveSessionEntry } from '../../../store/activeSessionStore' +export type ActiveSessionTreeEntry = + | (ActiveSessionEntry & { activitySource: 'self' }) + | { + sessionId: string + activitySource: 'descendant' + title?: string + directory?: string + } + export interface ActiveSessionTree { - rootEntries: ActiveSessionEntry[] - childrenByParent: Map + rootEntries: ActiveSessionTreeEntry[] + childrenByParent: Map } export function buildActiveSessionTree( busySessions: ActiveSessionEntry[], findParentId: (sessionId: string) => string | undefined, + resolveEntry: (sessionId: string) => { title?: string; directory?: string } | undefined, ): ActiveSessionTree { - const busySessionIds = new Set(busySessions.map(entry => entry.sessionId)) - const rootEntries: ActiveSessionEntry[] = [] - const childrenByParent = new Map() + const displayEntries = new Map() + const rootEntryIds: string[] = [] + const childrenByParentIds = new Map() + const rootEntryIdSet = new Set() + const childEntryIdsByParent = new Map>() - for (const entry of busySessions) { - const parentId = findParentId(entry.sessionId) + const upsertEntry = (entry: ActiveSessionTreeEntry): ActiveSessionTreeEntry => { + const existing = displayEntries.get(entry.sessionId) - if (!parentId || !busySessionIds.has(parentId)) { - rootEntries.push(entry) - continue + if (!existing) { + displayEntries.set(entry.sessionId, entry) + return entry + } + + if (existing.activitySource === 'self') { + return existing + } + + if (entry.activitySource === 'self') { + displayEntries.set(entry.sessionId, entry) + return entry + } + + const mergedEntry: ActiveSessionTreeEntry = { + sessionId: existing.sessionId, + activitySource: 'descendant', + title: existing.title ?? entry.title, + directory: existing.directory ?? entry.directory, + } + + displayEntries.set(entry.sessionId, mergedEntry) + return mergedEntry + } + + const appendRoot = (entryId: string) => { + if (rootEntryIdSet.has(entryId)) { + return } - const siblings = childrenByParent.get(parentId) + rootEntryIdSet.add(entryId) + rootEntryIds.push(entryId) + } + + const appendChild = (parentId: string, childId: string) => { + let childIds = childEntryIdsByParent.get(parentId) + if (!childIds) { + childIds = new Set() + childEntryIdsByParent.set(parentId, childIds) + } + + if (childIds.has(childId)) { + return + } + + childIds.add(childId) + const siblings = childrenByParentIds.get(parentId) if (siblings) { - siblings.push(entry) + siblings.push(childId) } else { - childrenByParent.set(parentId, [entry]) + childrenByParentIds.set(parentId, [childId]) + } + } + + for (const entry of busySessions) { + const selfEntry = upsertEntry({ + ...entry, + activitySource: 'self', + }) + const visitedIds = new Set([entry.sessionId]) + const chain: ActiveSessionTreeEntry[] = [selfEntry] + + let currentId = entry.sessionId + let cycleDetected = false + + while (true) { + const parentId = findParentId(currentId) + + if (!parentId) { + break + } + + if (visitedIds.has(parentId)) { + cycleDetected = true + break + } + + visitedIds.add(parentId) + + const parentEntry = displayEntries.get(parentId) + const resolvedMeta = resolveEntry(parentId) + const ancestorEntry = upsertEntry( + parentEntry ?? { + sessionId: parentId, + activitySource: 'descendant', + title: resolvedMeta?.title, + directory: resolvedMeta?.directory, + }, + ) + + chain.unshift(ancestorEntry) + currentId = parentId + } + + if (cycleDetected) { + appendRoot(selfEntry.sessionId) + continue + } + + appendRoot(chain[0].sessionId) + + for (let index = 1; index < chain.length; index += 1) { + appendChild(chain[index - 1].sessionId, chain[index].sessionId) + } + } + + const rootEntries = rootEntryIds + .map(sessionId => displayEntries.get(sessionId)) + .filter((entry): entry is ActiveSessionTreeEntry => entry !== undefined) + + const childrenByParent = new Map() + + for (const [parentId, childIds] of childrenByParentIds.entries()) { + const children = childIds + .map(sessionId => displayEntries.get(sessionId)) + .filter((entry): entry is ActiveSessionTreeEntry => entry !== undefined) + + if (children.length > 0) { + childrenByParent.set(parentId, children) } } diff --git a/src/features/message/MessageRenderer.test.tsx b/src/features/message/MessageRenderer.test.tsx index d5f90d1d..50f767a1 100644 --- a/src/features/message/MessageRenderer.test.tsx +++ b/src/features/message/MessageRenderer.test.tsx @@ -1,8 +1,13 @@ import type { ReactNode } from 'react' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { MessageRenderer } from './MessageRenderer' -import type { Message } from '../../types/message' +import type { Message, UserMessageInfo } from '../../types/message' + +const { useModelsMock, useThemeMock } = vi.hoisted(() => ({ + useModelsMock: vi.fn(), + useThemeMock: vi.fn(), +})) vi.mock('motion/mini', () => ({ animate: () => Promise.resolve(), @@ -10,16 +15,11 @@ vi.mock('motion/mini', () => ({ vi.mock('../../hooks', () => ({ useDelayedRender: (show: boolean) => show, + useModels: useModelsMock, })) vi.mock('../../hooks/useTheme', () => ({ - useTheme: () => ({ - collapseUserMessages: false, - stepFinishDisplay: { turnDuration: false }, - descriptiveToolSteps: false, - inlineToolRequests: false, - immersiveMode: false, - }), + useTheme: useThemeMock, })) vi.mock('../../components/ui', () => ({ @@ -34,13 +34,38 @@ vi.mock('./parts', () => ({ FilePartView: () => null, AgentPartView: () => null, SyntheticTextPartView: () => null, - StepFinishPartView: () => null, + StepFinishPartView: ({ modelLabel }: { modelLabel?: string }) => ( +
{modelLabel ? `step-finish-model:${modelLabel}` : 'step-finish-model:undefined'}
+ ), SubtaskPartView: () => null, RetryPartView: () => null, CompactionPartView: () =>
History compacted
, MessageErrorView: () => null, })) +function createThemeOverrides(overrides?: Record) { + return { + collapseUserMessages: false, + stepFinishDisplay: { + agent: false, + model: false, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + completedAtFormat: 'absolute', + modelLabelFormat: 'code', + showModelVariant: false, + descriptiveToolSteps: false, + inlineToolRequests: false, + immersiveMode: false, + ...overrides, + } +} + function createAssistantMessage(): Message { return { info: { @@ -75,6 +100,42 @@ function createAssistantMessage(): Message { } } +function createStepFinishPart() { + return { + id: 'step-finish-1', + sessionID: 'session-1', + messageID: 'assistant-1', + type: 'step-finish' as const, + reason: 'stop', + cost: 0, + tokens: { + input: 0, + output: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + } +} + +function createToolPart() { + return { + id: 'tool-1', + sessionID: 'session-1', + messageID: 'assistant-1', + type: 'tool' as const, + callID: 'call-1', + tool: 'bash', + state: { + status: 'completed' as const, + input: { command: 'pwd' }, + output: '/workspace', + title: 'Ran bash', + metadata: {}, + time: { start: 1, end: 2 }, + }, + } +} + function createUserMessage(): Message { return { info: { @@ -90,7 +151,23 @@ function createUserMessage(): Message { } } +function createUserMessageWithVariant(variant?: string): Message { + const message = createUserMessage() + if (variant !== undefined) { + ;(message.info as UserMessageInfo).model.variant = variant + } + return message +} + describe('MessageRenderer assistant fork', () => { + beforeEach(() => { + useModelsMock.mockReset() + useThemeMock.mockReset() + + useModelsMock.mockReturnValue({ models: [], isLoading: false, error: null, refetch: vi.fn() }) + useThemeMock.mockImplementation(() => createThemeOverrides()) + }) + it('passes the explicit fork target id when forking an assistant message', async () => { const onFork = vi.fn() const message = createAssistantMessage() @@ -139,4 +216,271 @@ describe('MessageRenderer assistant fork', () => { expect(screen.getByText('History compacted')).toBeInTheDocument() }) + + it('uses raw assistantInfo.modelID for step-finish model label in code mode', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'code', + }), + ) + useModelsMock.mockReturnValue({ + models: [{ id: 'model-1', providerId: 'provider-1', name: 'Resolved Name' }], + isLoading: false, + error: null, + refetch: vi.fn(), + }) + + render() + + expect(screen.getByText('step-finish-model:model-1')).toBeInTheDocument() + }) + + it('keeps the existing model label unchanged when showModelVariant is disabled', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'code', + showModelVariant: false, + }), + ) + + render() + + expect(screen.getByText('step-finish-model:model-1')).toBeInTheDocument() + expect(screen.queryByText('step-finish-model:model-1 · X High')).toBeNull() + }) + + it('uses resolved model.name for step-finish model label in name mode', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'name', + }), + ) + useModelsMock.mockReturnValue({ + models: [{ id: 'model-1', providerId: 'provider-1', name: 'Resolved Name' }], + isLoading: false, + error: null, + refetch: vi.fn(), + }) + + render() + + expect(screen.getByText('step-finish-model:Resolved Name')).toBeInTheDocument() + }) + + it('appends the formatted requested variant in code mode when enabled', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'code', + showModelVariant: true, + }), + ) + + render() + + expect(screen.getByText('step-finish-model:model-1 · X High')).toBeInTheDocument() + }) + + it('appends the formatted requested variant in name mode when enabled', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'name', + showModelVariant: true, + }), + ) + useModelsMock.mockReturnValue({ + models: [{ id: 'model-1', providerId: 'provider-1', name: 'Resolved Name' }], + isLoading: false, + error: null, + refetch: vi.fn(), + }) + + render() + + expect(screen.getByText('step-finish-model:Resolved Name · Xhigh')).toBeInTheDocument() + }) + + it('passes the resolved model label through the grouped tool footer step-finish path', () => { + const message = createAssistantMessage() + message.parts = [createToolPart(), createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'name', + showModelVariant: true, + }), + ) + useModelsMock.mockReturnValue({ + models: [{ id: 'model-1', providerId: 'provider-1', name: 'Grouped Tool Model' }], + isLoading: false, + error: null, + refetch: vi.fn(), + }) + + render() + + expect(screen.getByText('step-finish-model:Grouped Tool Model · X High')).toBeInTheDocument() + }) + + it('uses the provider-specific model name when duplicate model ids exist in name mode', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'name', + }), + ) + useModelsMock.mockReturnValue({ + models: [ + { id: 'model-1', providerId: 'provider-2', name: 'Wrong Provider Name' }, + { id: 'model-1', providerId: 'provider-1', name: 'Correct Provider Name' }, + ], + isLoading: false, + error: null, + refetch: vi.fn(), + }) + + render() + + expect(screen.getByText('step-finish-model:Correct Provider Name')).toBeInTheDocument() + expect(screen.queryByText('step-finish-model:Wrong Provider Name')).toBeNull() + }) + + it('falls back to assistantInfo.modelID when name mode has an empty model list', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'name', + }), + ) + useModelsMock.mockReturnValue({ models: [], isLoading: false, error: null, refetch: vi.fn() }) + + render() + + expect(screen.getByText('step-finish-model:model-1')).toBeInTheDocument() + }) + + it('leaves the model label unchanged when the parent message has no variant', () => { + const message = createAssistantMessage() + message.parts = [createStepFinishPart()] + + useThemeMock.mockImplementation(() => + createThemeOverrides({ + stepFinishDisplay: { + agent: false, + model: true, + tokens: false, + cache: false, + cost: false, + duration: false, + turnDuration: false, + completedAt: false, + }, + modelLabelFormat: 'code', + showModelVariant: true, + }), + ) + + render() + + expect(screen.getByText('step-finish-model:model-1')).toBeInTheDocument() + expect(screen.queryByText(/step-finish-model:model-1 ·/)).toBeNull() + }) }) diff --git a/src/features/message/MessageRenderer.tsx b/src/features/message/MessageRenderer.tsx index 0db992f9..6d98e146 100644 --- a/src/features/message/MessageRenderer.tsx +++ b/src/features/message/MessageRenderer.tsx @@ -4,7 +4,7 @@ import { diffLines } from 'diff' import { animate } from 'motion/mini' import { ChevronDownIcon, ChevronRightIcon, SplitIcon, SpinnerIcon, UndoIcon } from '../../components/Icons' import { CopyButton, SmoothHeight } from '../../components/ui' -import { useDelayedRender } from '../../hooks' +import { useDelayedRender, useModels } from '../../hooks' import { useTheme } from '../../hooks/useTheme' import { useInlineToolRequests, @@ -41,6 +41,7 @@ import { formatDuration, formatCompletedAt, formatDetailedDateTime } from '../.. interface MessageRendererProps { message: Message + parentMessage?: Message allowStreamingLayoutAnimation?: boolean /** 回合总时长(毫秒),仅在回合最后一条 assistant 消息上有值 */ turnDuration?: number @@ -53,6 +54,7 @@ interface MessageRendererProps { export const MessageRenderer = memo(function MessageRenderer({ message, + parentMessage, allowStreamingLayoutAnimation = true, turnDuration, onUndo, @@ -79,6 +81,7 @@ export const MessageRenderer = memo(function MessageRenderer({ return ( {showCollapse && (