Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions src/acp/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use crate::acp::connection::AcpConnection;
use crate::config::AgentConfig;
use anyhow::{anyhow, Result};
use std::collections::HashMap;
use tokio::sync::RwLock;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio::time::Instant;
use tracing::{info, warn};

pub struct SessionPool {
connections: RwLock<HashMap<String, AcpConnection>>,
connections: RwLock<HashMap<String, Arc<Mutex<AcpConnection>>>>,
config: AgentConfig,
max_sessions: usize,
}
Expand All @@ -22,22 +23,22 @@ impl SessionPool {
}

pub async fn get_or_create(&self, thread_id: &str) -> Result<()> {
// Check if alive connection exists
// Fast path: alive connection exists — only the read lock is needed.
{
let conns = self.connections.read().await;
if let Some(conn) = conns.get(thread_id) {
if conn.alive() {
if let Some(conn_arc) = conns.get(thread_id) {
if conn_arc.lock().await.alive() {
return Ok(());
}
}
}

// Need to create or rebuild
// Slow path: create or rebuild.
let mut conns = self.connections.write().await;

// Double-check after acquiring write lock
if let Some(conn) = conns.get(thread_id) {
if conn.alive() {
// Double-check after acquiring the write lock.
if let Some(conn_arc) = conns.get(thread_id) {
if conn_arc.lock().await.alive() {
return Ok(());
}
warn!(thread_id, "stale connection, rebuilding");
Expand All @@ -64,30 +65,59 @@ impl SessionPool {
conn.session_reset = true;
}

conns.insert(thread_id.to_string(), conn);
conns.insert(thread_id.to_string(), Arc::new(Mutex::new(conn)));
Ok(())
}

/// Run `f` against a mutable connection reference. Only this connection's
/// per-session mutex is held for the callback's duration — the pool lock
/// is released immediately, so concurrent sessions are not blocked.
/// Caller must have called `get_or_create` first.
///
/// # Errors
/// Returns an error if no connection exists for `thread_id`.
pub async fn with_connection<F, R>(&self, thread_id: &str, f: F) -> Result<R>
where
    F: FnOnce(&mut AcpConnection) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<R>> + Send + '_>>,
{
    // Clone the Arc under a short-lived read lock; the pool map is
    // free again for other threads before we ever await on `f`.
    let conn_arc = {
        let conns = self.connections.read().await;
        conns
            .get(thread_id)
            .cloned()
            .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))?
    };
    // Serialize work on this one session; other sessions proceed freely.
    let mut conn = conn_arc.lock().await;
    f(&mut conn).await
}

pub async fn cleanup_idle(&self, ttl_secs: u64) {
let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs);

// Snapshot the Arcs under the read lock, then release it before
// awaiting any per-connection mutex. Otherwise a long-running
// `session_prompt` would block `cleanup_idle` on the connection
// mutex while it still held the pool write lock, re-introducing
// exactly the starvation this refactor is meant to eliminate.
let snapshot: Vec<(String, Arc<Mutex<AcpConnection>>)> = {
let conns = self.connections.read().await;
conns.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
};

// Probe each connection under its own mutex. `try_lock` skips
// connections that are currently in use — they are by definition
// not idle, so there is nothing to clean up for them this round.
let mut stale = Vec::new();
for (key, conn_arc) in &snapshot {
let Ok(conn) = conn_arc.try_lock() else { continue };
if conn.last_active < cutoff || !conn.alive() {
stale.push(key.clone());
}
}

if stale.is_empty() {
return;
}

// Only now take the pool write lock to remove the stale entries.
let mut conns = self.connections.write().await;
let stale: Vec<String> = conns
.iter()
.filter(|(_, c)| c.last_active < cutoff || !c.alive())
.map(|(k, _)| k.clone())
.collect();
for key in stale {
info!(thread_id = %key, "cleaning up idle session");
conns.remove(&key);
Expand Down