diff --git a/Cargo.lock b/Cargo.lock index 258ce02..1f5dc28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -960,7 +960,7 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "openab" -version = "0.7.6" +version = "0.7.7" dependencies = [ "anyhow", "async-trait", diff --git a/src/acp/pool.rs b/src/acp/pool.rs index cff159b..0c7dc93 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -2,18 +2,22 @@ use crate::acp::connection::AcpConnection; use crate::config::AgentConfig; use anyhow::{anyhow, Result}; use std::collections::HashMap; -use tokio::sync::RwLock; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; use tokio::time::Instant; use tracing::{info, warn}; /// Combined state protected by a single lock to prevent deadlocks. -/// Lock ordering: always acquire `state` before any operation on either map. +/// Lock ordering: never await a per-connection mutex while holding `state`. struct PoolState { - /// Active connections: thread_key → AcpConnection. - active: HashMap, + /// Active connections: thread_key → AcpConnection handle. + active: HashMap>>, /// Suspended sessions: thread_key → ACP sessionId. /// Saved on eviction so sessions can be resumed via `session/load`. suspended: HashMap, + /// Serializes create/resume work per thread so rapid same-thread requests + /// cannot race each other into duplicate `session/load` attempts. + creating: HashMap>>, } pub struct SessionPool { @@ -22,12 +26,44 @@ pub struct SessionPool { max_sessions: usize, } +type EvictionCandidate = ( + String, + Arc>, + Instant, + Option, +); + +fn remove_if_same_handle( + map: &mut HashMap>>, + key: &str, + expected: &Arc>, +) -> Option>> { + let should_remove = map + .get(key) + .is_some_and(|current| Arc::ptr_eq(current, expected)); + if should_remove { + map.remove(key) + } else { + None + } +} + +fn get_or_insert_gate( + map: &mut HashMap>>, + key: &str, +) -> Arc> { + map.entry(key.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() +} + impl SessionPool { pub fn new(config: AgentConfig, max_sessions: usize) -> Self { Self { state: RwLock::new(PoolState { active: HashMap::new(), suspended: HashMap::new(), + creating: HashMap::new(), }), config, max_sessions, @@ -35,43 +71,63 @@ impl SessionPool { } pub async fn get_or_create(&self, thread_id: &str) -> Result<()> { - // Check if alive connection exists - { - let state = self.state.read().await; - if let Some(conn) = state.active.get(thread_id) { - if conn.alive() { - return Ok(()); - } - } - } + let create_gate = { + let mut state = self.state.write().await; + get_or_insert_gate(&mut state.creating, thread_id) + }; + let _create_guard = create_gate.lock().await; - // Need to create or rebuild - let mut state = self.state.write().await; + let (existing, saved_session_id) = { + let state = self.state.read().await; + ( + state.active.get(thread_id).cloned(), + state.suspended.get(thread_id).cloned(), + ) + }; - // Double-check after acquiring write lock - if let Some(conn) = state.active.get(thread_id) { + let had_existing = existing.is_some(); + let mut saved_session_id = saved_session_id; + if let Some(conn) = existing.clone() { + let conn = conn.lock().await; if conn.alive() { return Ok(()); } - warn!(thread_id, "stale connection, rebuilding"); - suspend_entry(&mut state, thread_id); + if saved_session_id.is_none() { + saved_session_id = conn.acp_session_id.clone(); + } } - if state.active.len() >= self.max_sessions { - // LRU evict: suspend the oldest idle session to make room - let oldest = state.active + // Snapshot active handles so we can inspect them outside the state lock. + let snapshot: Vec<(String, Arc>)> = { + let state = self.state.read().await; + state + .active .iter() - .min_by_key(|(_, c)| c.last_active) - .map(|(k, _)| k.clone()); - if let Some(key) = oldest { - info!(evicted = %key, "pool full, suspending oldest idle session"); - suspend_entry(&mut state, &key); - } else { - return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + }; + + let mut eviction_candidate: Option = None; + let mut skipped_locked_candidates = 0usize; + for (key, conn) in snapshot { + if key == thread_id { + continue; + } + let conn_handle = Arc::clone(&conn); + let Ok(conn) = conn.try_lock() else { + skipped_locked_candidates += 1; + continue; + }; + let candidate = (key, conn_handle, conn.last_active, conn.acp_session_id.clone()); + match &eviction_candidate { + Some((_, _, oldest_last_active, _)) if candidate.2 >= *oldest_last_active => {} + _ => eviction_candidate = Some(candidate), } } - let mut conn = AcpConnection::spawn( + // Build the replacement connection outside the state lock so one stuck + // initialization does not block all unrelated sessions. + let mut new_conn = AcpConnection::spawn( &self.config.command, &self.config.args, &self.config.working_dir, @@ -79,14 +135,12 @@ impl SessionPool { ) .await?; - conn.initialize().await?; + new_conn.initialize().await?; - // Try to resume a suspended session via session/load - let saved_session_id = state.suspended.remove(thread_id); let mut resumed = false; if let Some(ref sid) = saved_session_id { - if conn.supports_load_session { - match conn.session_load(sid, &self.config.working_dir).await { + if new_conn.supports_load_session { + match new_conn.session_load(sid, &self.config.working_dir).await { Ok(()) => { info!(thread_id, session_id = %sid, "session resumed via session/load"); resumed = true; @@ -99,39 +153,119 @@ impl SessionPool { } if !resumed { - conn.session_new(&self.config.working_dir).await?; - if saved_session_id.is_some() { - conn.session_reset = true; + new_conn.session_new(&self.config.working_dir).await?; + // Surface the reset banner both for restored sessions and for stale + // live entries that died before we could recover a resumable + // session id. In both cases the caller is continuing after an + // unexpected session loss. + if had_existing || saved_session_id.is_some() { + new_conn.session_reset = true; + } + } + + let new_conn = Arc::new(Mutex::new(new_conn)); + + let mut state = self.state.write().await; + + // Another task may have created a healthy connection while we were + // initializing this one. + if let Some(existing) = state.active.get(thread_id).cloned() { + let Ok(existing) = existing.try_lock() else { + return Ok(()); + }; + if existing.alive() { + return Ok(()); + } + warn!(thread_id, "stale connection, rebuilding"); + drop(existing); + state.active.remove(thread_id); + } + + if state.active.len() >= self.max_sessions { + if let Some((key, expected_conn, _, sid)) = eviction_candidate { + if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + info!(evicted = %key, "pool full, suspending oldest idle session"); + if let Some(sid) = sid { + state.suspended.insert(key, sid); + } + } else { + warn!(evicted = %key, "pool full but eviction candidate changed before removal"); + } + } else if skipped_locked_candidates > 0 { + warn!( + max_sessions = self.max_sessions, + skipped_locked_candidates, + "pool full but all other sessions were busy during eviction scan" + ); } } - state.active.insert(thread_id.to_string(), conn); + if state.active.len() >= self.max_sessions { + return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); + } + + state.suspended.remove(thread_id); + state.active.insert(thread_id.to_string(), new_conn); Ok(()) } /// Get mutable access to a connection. Caller must have called get_or_create first. pub async fn with_connection(&self, thread_id: &str, f: F) -> Result where - F: FnOnce(&mut AcpConnection) -> std::pin::Pin> + Send + '_>>, + F: for<'a> FnOnce( + &'a mut AcpConnection, + ) -> std::pin::Pin> + Send + 'a>>, { - let mut state = self.state.write().await; - let conn = state.active - .get_mut(thread_id) - .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))?; - f(conn).await + let conn = { + let state = self.state.read().await; + state + .active + .get(thread_id) + .cloned() + .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))? + }; + + let mut conn = conn.lock().await; + f(&mut conn).await } pub async fn cleanup_idle(&self, ttl_secs: u64) { let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs); + + let snapshot: Vec<(String, Arc>)> = { + let state = self.state.read().await; + state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect() + }; + + let mut stale = Vec::new(); + for (key, conn) in snapshot { + // Skip active sessions for this cleanup round instead of waiting on + // their per-connection mutex. A busy session is not idle. + let conn_handle = Arc::clone(&conn); + let Ok(conn) = conn.try_lock() else { + continue; + }; + if conn.last_active < cutoff || !conn.alive() { + stale.push((key, conn_handle, conn.acp_session_id.clone())); + } + } + + if stale.is_empty() { + return; + } + let mut state = self.state.write().await; - let stale: Vec = state.active - .iter() - .filter(|(_, c)| c.last_active < cutoff || !c.alive()) - .map(|(k, _)| k.clone()) - .collect(); - for key in stale { - info!(thread_id = %key, "cleaning up idle session"); - suspend_entry(&mut state, &key); + for (key, expected_conn, sid) in stale { + if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + info!(thread_id = %key, "cleaning up idle session"); + if let Some(sid) = sid { + state.suspended.insert(key, sid); + } + } } } @@ -143,14 +277,45 @@ impl SessionPool { } } -/// Suspend a connection: save its sessionId to the suspended map and remove -/// from active. The connection is dropped, triggering process group kill. -fn suspend_entry(state: &mut PoolState, thread_id: &str) { - if let Some(conn) = state.active.remove(thread_id) { - if let Some(sid) = &conn.acp_session_id { - info!(thread_id, session_id = %sid, "suspending session"); - state.suspended.insert(thread_id.to_string(), sid.clone()); - } - // conn dropped here → Drop impl kills process group +#[cfg(test)] +mod tests { + use super::{get_or_insert_gate, remove_if_same_handle}; + use std::collections::HashMap; + use std::sync::Arc; + use tokio::sync::Mutex; + + #[test] + fn remove_if_same_handle_removes_matching_entry() { + let expected = Arc::new(Mutex::new(1_u8)); + let mut map = HashMap::from([("thread".to_string(), Arc::clone(&expected))]); + + let removed = remove_if_same_handle(&mut map, "thread", &expected); + + assert!(removed.is_some()); + assert!(map.is_empty()); + } + + #[test] + fn remove_if_same_handle_keeps_replaced_entry() { + let stale = Arc::new(Mutex::new(1_u8)); + let fresh = Arc::new(Mutex::new(2_u8)); + let mut map = HashMap::from([("thread".to_string(), Arc::clone(&fresh))]); + + let removed = remove_if_same_handle(&mut map, "thread", &stale); + + assert!(removed.is_none()); + let current = map.get("thread").expect("entry should remain"); + assert!(Arc::ptr_eq(current, &fresh)); + } + + #[test] + fn get_or_insert_gate_reuses_gate_for_same_thread() { + let mut map = HashMap::new(); + + let first = get_or_insert_gate(&mut map, "thread"); + let second = get_or_insert_gate(&mut map, "thread"); + + assert!(Arc::ptr_eq(&first, &second)); + assert_eq!(map.len(), 1); } }