diff --git a/docs/session-persistence.md b/docs/session-persistence.md new file mode 100644 index 00000000..4e33837d --- /dev/null +++ b/docs/session-persistence.md @@ -0,0 +1,199 @@ +# Session Persistence + +## 背景與問題 + +openab 的每個 Discord thread 對應一個 ACP agent process(Claude Code、Gemini、Codex、Kiro 等)。原始設計將所有 session 狀態放在記憶體的 `HashMap` 裡,這代表: + +- **Pod 重啟後**,所有進行中的對話會消失,user 必須從頭開始 +- **Agent process crash 後**,session 不會自動還原 +- **未來接其他平台**(Slack、Telegram)時,沒有統一的 session identity 格式 + +這份文件說明 `feat: add file-based session persistence` 這個 commit 的設計決策與實作細節。 + +--- + +## 設計目標 + +1. **重開機還原** — Pod 重啟後,user 發訊息時能自動還原上一次的對話脈絡 +2. **最小依賴** — 不引入 Redis、PostgreSQL 等外部服務;利用已有的 PVC mount(`/data`) +3. **Crash-safe** — 中途斷電或 OOM kill 不會損壞已儲存的資料 +4. **平台無關** — Session identity 不綁定 Discord,未來接 Slack/Telegram 不需要改核心邏輯 + +--- + +## 為什麼選擇檔案系統而不是 SQLite 或 Redis? + +| 選項 | 優點 | 缺點 | 結論 | +|------|------|------|------| +| 檔案系統 (JSONL) | 零依賴、人可讀、append-only crash-safe、PVC 已有 | 無 index query、單 pod only | **採用** | +| SQLite | 有 index、WAL crash-safe | 單 pod only、需要額外 crate | 過度設計 | +| Redis | 多 pod、有 TTL | 需要額外服務、Redis 本身也要 persist | 目前不需要 | + +目前 openab 是單 pod 部署,session 數量少(預設 max 10),不需要複雜查詢。等真的需要橫向擴展時,session store 已經是獨立的 `SessionStore` 層,換掉後端不影響其他邏輯。 + +--- + +## 架構 + +``` +/data/sessions/ +├── index.json ← 所有 session 的 metadata(atomic write) +├── discord_987654321.jsonl ← session "discord:987654321" 的對話記錄 +└── discord_111222333.jsonl +``` + +### `index.json` 格式 + +```json +{ + "sessions": { + "discord:987654321": { + "key": "discord:987654321", + "platform": "discord", + "agent": "claude-code", + "created_at": 1712345678, + "last_active": 1712399999 + } + } +} +``` + +### JSONL transcript 格式 + +每行一個 JSON,append-only: + +```jsonl +{"role":"user","content":"幫我寫一個 hello world","ts":1712345680} +{"role":"assistant","content":"好的,這是 hello world 範例...","ts":1712345682} +``` + +--- + +## Session Key 設計 + +原本的 key 是 `thread_id: u64`(Discord 特有)。 + +現在改成平台無關的字串格式: + +``` 
+"{platform}:{thread_id}" +``` + +範例: +- `"discord:987654321"` — Discord thread +- `"slack:C012AB3CD:1234567890.123456"` — Slack thread(未來) +- `"telegram:-100123:42"` — Telegram thread(未來) + +這讓 `SessionPool` 和 `SessionStore` 完全不知道平台是誰,只管 key 字串。 + +--- + +## 還原流程 + +``` +Pod 重啟 → 記憶體清空 + +User 在 Discord thread 發訊息 + ↓ +discord.rs: session_key = "discord:{thread_id}" + ↓ +pool.get_or_create("discord:987654321") + ↓ + ├─ 記憶體有,且 alive → 直接回傳(正常路徑) + │ + ├─ 記憶體沒有,但 index.json 有記錄 + │ ↓ + │ spawn 新 agent process + │ ↓ + │ initialize() + session/new() + │ ↓ + │ 讀 discord_987654321.jsonl(最多最近 20 條) + │ ↓ + │ session_prime_context(history) + │ → 把歷史對話送進 agent(silently drain 回應) + │ ↓ + │ 回傳,conn.session_reset = true + │ → discord.rs 顯示 "⚠️ Session expired, starting fresh..." + │ + └─ 都沒有 → 全新 session,寫入 index.json +``` + +--- + +## Crash-safe 保證 + +**`index.json`(atomic write)** + +```rust +// 寫到 .tmp 再 rename,rename 是 POSIX atomic 操作 +tokio::fs::write(&tmp, content).await?; +tokio::fs::rename(&tmp, index_path).await?; +``` + +即使在 write 過程中 crash,`.tmp` 會留著,但 `index.json` 還是舊的完整版本,下次啟動不會讀到半寫的資料。 + +**JSONL transcript(append-only)** + +每行獨立,寫到一半只會損壞最後一行。`load_transcript` 使用 `serde_json::from_str(l).ok()` 忽略無法解析的行: + +```rust +.filter_map(|l| serde_json::from_str(l).ok()) +``` + +--- + +## 還原的 context 限制 + +`store.rs` 中的常數: + +```rust +const MAX_RESTORE_ENTRIES: usize = 20; +``` + +只還原最近 20 條訊息。原因: +- Agent process 初始化時,`session_prime_context` 會把歷史送進去,等 agent 回應才算完成 +- 太長的 context 會超過 timeout(`session_prime_context` 有 60s 上限) +- 大部分對話的連貫性在最近 20 條以內就夠了 + +--- + +## 新增的設定 + +`config.toml` 新增可選的 `[session]` 區段: + +```toml +[session] +dir = "/data/sessions" # 預設值,通常不需要改 +``` + +若不設定,預設使用 `/data/sessions`(在 Helm chart 的 PVC mount 路徑 `/data` 之下)。 + +--- + +## 檔案結構 + +``` +src/ +├── session/ +│ ├── mod.rs ← pub use +│ ├── key.rs ← SessionKey struct +│ └── store.rs ← SessionStore(load/upsert/append/remove) +├── acp/ +│ ├── connection.rs ← 新增 session_prime_context() +│ └── pool.rs ← 
整合 SessionStore +├── discord.rs ← 使用 SessionKey,記錄 transcript +├── config.rs ← 新增 SessionConfig +└── main.rs ← 初始化 SessionStore +``` + +--- + +## 未來擴充 + +接新平台(Slack、Telegram)時,只需要: + +1. 新增 `SessionKey::slack(...)` / `SessionKey::telegram(...)` 等 constructor +2. 實作各自的 platform handler(類似 `discord.rs`),使用同樣的 `SessionStore` 和 `SessionPool` +3. Session 核心邏輯完全不動 + +接新 agent 時,只需要修改 `config.toml` 的 `[agent]` 區段,session store 格式與 agent 無關。 diff --git a/src/acp/connection.rs b/src/acp/connection.rs index 53770509..b1419a88 100644 --- a/src/acp/connection.rs +++ b/src/acp/connection.rs @@ -316,4 +316,48 @@ impl AcpConnection { pub fn alive(&self) -> bool { !self._reader_handle.is_finished() } + + /// Inject previous conversation history into a freshly created session. + /// + /// This is called after [`session_new`] when restoring a session from the + /// on-disk store (e.g. after a pod restart). The transcript is sent as a + /// single prompt; the agent's acknowledgment response is silently drained + /// so nothing is forwarded to the platform layer. + /// + /// `history` is a slice of `(role, content)` pairs where `role` is either + /// `"user"` or `"assistant"`. At most the last 20 messages are meaningful + /// (the store already trims to that limit). + pub async fn session_prime_context(&mut self, history: &[(String, String)]) -> Result<()> { + if history.is_empty() { + return Ok(()); + } + + let mut ctx = String::from( + "[Context restoration: the following is the previous conversation history. \ + Continue from where it left off.]\n\n", + ); + for (role, content) in history { + let label = if role == "user" { "User" } else { "Assistant" }; + ctx.push_str(&format!("{label}: {content}\n\n")); + } + ctx.push_str("[End of history.]"); + + let (mut rx, _) = self.session_prompt(vec![ContentBlock::Text { text: ctx }]).await?; + + // Drain all streaming events silently; stop on final response (id set). 
+ let drain = async { + while let Some(msg) = rx.recv().await { + if msg.id.is_some() { + break; + } + } + }; + // 60-second guard so a non-responsive agent doesn't block the pool forever. + if tokio::time::timeout(std::time::Duration::from_secs(60), drain).await.is_err() { + tracing::warn!("timeout waiting for context prime response; continuing anyway"); + } + + self.prompt_done().await; + Ok(()) + } } diff --git a/src/acp/pool.rs b/src/acp/pool.rs index a2c8a06c..1f36ae78 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -1,7 +1,9 @@ use crate::acp::connection::AcpConnection; use crate::config::AgentConfig; +use crate::session::{SessionMeta, SessionStore}; use anyhow::{anyhow, Result}; use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::RwLock; use tokio::time::Instant; use tracing::{info, warn}; @@ -10,44 +12,66 @@ pub struct SessionPool { connections: RwLock>, config: AgentConfig, max_sessions: usize, + store: Arc, } impl SessionPool { - pub fn new(config: AgentConfig, max_sessions: usize) -> Self { + pub fn new(config: AgentConfig, max_sessions: usize, store: Arc) -> Self { Self { connections: RwLock::new(HashMap::new()), config, max_sessions, + store, } } - pub async fn get_or_create(&self, thread_id: &str) -> Result<()> { - // Check if alive connection exists + /// Return a reference to the session store (used by the platform adapter to + /// record transcript entries). 
+ pub fn store(&self) -> &Arc { + &self.store + } + + pub async fn get_or_create(&self, session_key: &str) -> Result<()> { + // ── fast path: alive connection already in memory ──────────────────── { let conns = self.connections.read().await; - if let Some(conn) = conns.get(thread_id) { + if let Some(conn) = conns.get(session_key) { if conn.alive() { return Ok(()); } } } - // Need to create or rebuild + // ── check persistent store before acquiring the write lock ─────────── + // Load metadata + transcript while we're NOT holding the lock, so the + // (potentially slow) file I/O doesn't block other readers. + let all_meta = self.store.load_all().await; + let stored_meta = all_meta.get(session_key).cloned(); + let transcript = if stored_meta.is_some() { + self.store.load_transcript(session_key).await + } else { + vec![] + }; + + // ── acquire write lock and create / rebuild ────────────────────────── let mut conns = self.connections.write().await; - // Double-check after acquiring write lock - if let Some(conn) = conns.get(thread_id) { + // Double-check: another task may have created the connection while we + // were loading from disk. + if let Some(conn) = conns.get(session_key) { if conn.alive() { return Ok(()); } - warn!(thread_id, "stale connection, rebuilding"); - conns.remove(thread_id); + warn!(session_key, "stale connection, rebuilding"); + conns.remove(session_key); } if conns.len() >= self.max_sessions { return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); } + let is_restore = stored_meta.is_some(); + let mut conn = AcpConnection::spawn( &self.config.command, &self.config.args, @@ -59,39 +83,78 @@ impl SessionPool { conn.initialize().await?; conn.session_new(&self.config.working_dir).await?; - let is_rebuild = conns.contains_key(thread_id); - if is_rebuild { - conn.session_reset = true; + // If restoring an existing session, replay history into the agent so it + // has context about the previous conversation. 
+ if !transcript.is_empty() { + let history: Vec<(String, String)> = transcript + .into_iter() + .map(|e| (e.role, e.content)) + .collect(); + if let Err(e) = conn.session_prime_context(&history).await { + warn!(error = %e, session_key, "failed to prime context; continuing without history"); + } else { + info!(session_key, messages = history.len(), "context primed from transcript"); + } + } + + // Mark the connection so the platform adapter can show a "restored" or + // "starting fresh" notice to the user. + conn.session_reset = is_restore; + + // Persist metadata (create or update last_active). + let now = now_secs(); + let meta = SessionMeta { + key: session_key.to_string(), + platform: session_key.split(':').next().unwrap_or("unknown").to_string(), + agent: self.config.command.clone(), + created_at: stored_meta.map(|m| m.created_at).unwrap_or(now), + last_active: now, + }; + if let Err(e) = self.store.upsert(meta).await { + warn!(error = %e, session_key, "failed to persist session metadata"); } - conns.insert(thread_id.to_string(), conn); + conns.insert(session_key.to_string(), conn); Ok(()) } - /// Get mutable access to a connection. Caller must have called get_or_create first. - pub async fn with_connection(&self, thread_id: &str, f: F) -> Result + /// Get mutable access to a connection via a closure. + /// The caller must have called `get_or_create` first. 
+ pub async fn with_connection(&self, session_key: &str, f: F) -> Result where F: FnOnce(&mut AcpConnection) -> std::pin::Pin> + Send + '_>>, { let mut conns = self.connections.write().await; let conn = conns - .get_mut(thread_id) - .ok_or_else(|| anyhow!("no connection for thread {thread_id}"))?; + .get_mut(session_key) + .ok_or_else(|| anyhow!("no connection for session {session_key}"))?; f(conn).await } pub async fn cleanup_idle(&self, ttl_secs: u64) { let cutoff = Instant::now() - std::time::Duration::from_secs(ttl_secs); + let stale: Vec = { + let conns = self.connections.read().await; + conns + .iter() + .filter(|(_, c)| c.last_active < cutoff || !c.alive()) + .map(|(k, _)| k.clone()) + .collect() + }; + + if stale.is_empty() { + return; + } + let mut conns = self.connections.write().await; - let stale: Vec = conns - .iter() - .filter(|(_, c)| c.last_active < cutoff || !c.alive()) - .map(|(k, _)| k.clone()) - .collect(); for key in stale { - info!(thread_id = %key, "cleaning up idle session"); + info!(session_key = %key, "cleaning up idle session"); conns.remove(&key); - // Child process killed via kill_on_drop when AcpConnection drops + // Child process killed via kill_on_drop when AcpConnection drops. + // Remove from persistent store so it is not "restored" after cleanup. 
+ if let Err(e) = self.store.remove(&key).await { + warn!(error = %e, session_key = %key, "failed to remove session from store"); + } } } @@ -102,3 +165,10 @@ impl SessionPool { info!(count, "pool shutdown complete"); } } + +fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} diff --git a/src/config.rs b/src/config.rs index c4ed3d30..269a104c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,6 +10,8 @@ pub struct Config { #[serde(default)] pub pool: PoolConfig, #[serde(default)] + pub session: SessionConfig, + #[serde(default)] pub reactions: ReactionsConfig, #[serde(default)] pub stt: SttConfig, @@ -140,6 +142,22 @@ impl Default for PoolConfig { } } +#[derive(Debug, Deserialize)] +pub struct SessionConfig { + /// Directory where session metadata and transcripts are stored. + /// Defaults to `/data/sessions` which is inside the PVC mount. + #[serde(default = "default_session_dir")] + pub dir: String, +} + +fn default_session_dir() -> String { "/data/sessions".into() } + +impl Default for SessionConfig { + fn default() -> Self { + Self { dir: default_session_dir() } + } +} + impl Default for ReactionsConfig { fn default() -> Self { Self { diff --git a/src/discord.rs b/src/discord.rs index e267064e..d6a312c9 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -1,5 +1,6 @@ use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool}; use crate::config::{ReactionsConfig, SttConfig}; +use crate::session::{SessionKey, SessionStore}; use crate::error_display::{format_coded_error, format_user_error}; use crate::format; use crate::reactions::StatusReactionController; @@ -29,6 +30,7 @@ static HTTP_CLIENT: LazyLock = LazyLock::new(|| { pub struct Handler { pub pool: Arc, + pub store: Arc, pub allowed_channels: HashSet, pub allowed_users: HashSet, pub reactions_config: ReactionsConfig, @@ -181,14 +183,26 @@ impl EventHandler for Handler { } }; - let thread_key = 
thread_id.to_string(); - if let Err(e) = self.pool.get_or_create(&thread_key).await { + let session_key = SessionKey::discord(thread_id).to_string(); + if let Err(e) = self.pool.get_or_create(&session_key).await { let msg = format_user_error(&e.to_string()); let _ = edit(&ctx, thread_channel, thinking_msg.id, &format!("⚠️ {}", msg)).await; error!("pool error: {e}"); return; } + // Record the user's message to the persistent transcript. + // We use `prompt` (mention-stripped) rather than `prompt_with_sender` + // so the stored text is clean and readable. + let store = self.store.clone(); + let key_for_record = session_key.clone(); + let user_text = prompt.clone(); + tokio::spawn(async move { + if let Err(e) = store.append_message(&key_for_record, "user", &user_text).await { + tracing::warn!(error = %e, "failed to record user message to transcript"); + } + }); + // Create reaction controller on the user's original message let reactions = Arc::new(StatusReactionController::new( self.reactions_config.enabled, @@ -200,10 +214,11 @@ impl EventHandler for Handler { )); reactions.set_queued().await; - // Stream prompt with live edits (pass content blocks instead of just text) + // Stream prompt with live edits (pass content blocks instead of just text). + // Returns the final response text so we can persist it. let result = stream_prompt( &self.pool, - &thread_key, + &session_key, content_blocks, &ctx, thread_channel, @@ -212,13 +227,30 @@ impl EventHandler for Handler { ) .await; - match &result { - Ok(()) => reactions.set_done().await, - Err(_) => reactions.set_error().await, + let succeeded = result.is_ok(); + + if succeeded { + reactions.set_done().await; + } else { + reactions.set_error().await; + } + + // Record the assistant's response to the persistent transcript. 
+ if let Ok(ref response_text) = result { + if !response_text.is_empty() { + let store = self.store.clone(); + let key_for_record = session_key.clone(); + let text = response_text.clone(); + tokio::spawn(async move { + if let Err(e) = store.append_message(&key_for_record, "assistant", &text).await { + tracing::warn!(error = %e, "failed to record assistant message to transcript"); + } + }); + } } // Hold emoji briefly then clear - let hold_ms = if result.is_ok() { + let hold_ms = if succeeded { self.reactions_config.timing.done_hold_ms } else { self.reactions_config.timing.error_hold_ms @@ -416,16 +448,16 @@ async fn edit(ctx: &Context, ch: ChannelId, msg_id: MessageId, content: &str) -> async fn stream_prompt( pool: &SessionPool, - thread_key: &str, + session_key: &str, content_blocks: Vec, ctx: &Context, channel: ChannelId, msg_id: MessageId, reactions: Arc, -) -> anyhow::Result<()> { +) -> anyhow::Result { let reactions = reactions.clone(); - pool.with_connection(thread_key, |conn| { + pool.with_connection(session_key, |conn| { let content_blocks = content_blocks.clone(); let ctx = ctx.clone(); let reactions = reactions.clone(); @@ -598,7 +630,10 @@ async fn stream_prompt( } } - Ok(()) + // Return the plain text portion of the response for transcript recording. + // We return text_buf (not final_content which includes tool lines) so the + // stored transcript contains clean, readable assistant text. 
/// Platform-agnostic session key.
///
/// Format: `"{platform}:{thread_id}"`
///
/// Examples:
/// - `"discord:987654321"` (Discord thread)
/// - `"slack:T01234:thread_ts"` (Slack thread, future)
/// - `"telegram:-100123:42"` (Telegram thread, future)
///
/// Thread IDs within Discord are globally unique, so no parent channel is needed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionKey(String);

impl SessionKey {
    /// Platform reported by [`SessionKey::platform`] when the key carries no
    /// `:` separator.
    const UNKNOWN_PLATFORM: &'static str = "unknown";

    /// Build the key for a Discord thread: `"discord:{thread_id}"`.
    pub fn discord(thread_id: u64) -> Self {
        Self(format!("discord:{thread_id}"))
    }

    /// Borrow the full key string, e.g. `"discord:987654321"`.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Returns a filesystem-safe version (colons replaced with underscores).
    /// Used as the JSONL transcript filename.
    pub fn to_filename(&self) -> String {
        self.0.replace(':', "_")
    }

    /// Platform prefix of the key: everything before the first `:`.
    ///
    /// Returns `"unknown"` when the key contains no `:` at all. (A plain
    /// `split(':').next()` always yields the whole string in that case, so
    /// the fallback would never apply.)
    pub fn platform(&self) -> &str {
        self.0
            .split_once(':')
            .map_or(Self::UNKNOWN_PLATFORM, |(platform, _)| platform)
    }
}

impl std::fmt::Display for SessionKey {
    /// Writes the raw key string, honouring any formatter width/padding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl From<SessionKey> for String {
    /// Unwrap the key into its owned string without copying.
    fn from(k: SessionKey) -> Self {
        k.0
    }
}
`index.json` is written atomically (write-to-tmp + rename) so it is never +//! left in a half-written state after a crash or restart. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use tokio::io::AsyncWriteExt; +use tracing::warn; + +/// Metadata for a single session, persisted in `index.json`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionMeta { + /// Full session key, e.g. `"discord:987654321"`. + pub key: String, + /// Platform name extracted from the key, e.g. `"discord"`. + pub platform: String, + /// Agent command used for this session, e.g. `"claude-code"`. + pub agent: String, + /// Unix timestamp (seconds) when the session was first created. + pub created_at: u64, + /// Unix timestamp (seconds) of the last activity. + pub last_active: u64, +} + +/// A single message in the transcript. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TranscriptEntry { + /// `"user"` or `"assistant"`. + pub role: String, + /// Plain-text content of the message. + pub content: String, + /// Unix timestamp (seconds). + pub ts: u64, +} + +/// Maximum number of transcript entries returned for context restoration. +/// Older messages beyond this limit are ignored to keep the context prompt short. +const MAX_RESTORE_ENTRIES: usize = 20; + +#[derive(Serialize, Deserialize, Default)] +struct Index { + sessions: HashMap, +} + +pub struct SessionStore { + base_dir: PathBuf, +} + +impl SessionStore { + pub fn new(base_dir: impl Into) -> Self { + Self { base_dir: base_dir.into() } + } + + /// Create the storage directory if it does not already exist. 
+ pub async fn init(&self) -> anyhow::Result<()> { + tokio::fs::create_dir_all(&self.base_dir).await?; + Ok(()) + } + + // ── index helpers ──────────────────────────────────────────────────────── + + fn index_path(&self) -> PathBuf { + self.base_dir.join("index.json") + } + + fn transcript_path(&self, key: &str) -> PathBuf { + // Replace colons so the key is a valid filename on all platforms. + let filename = key.replace(':', "_"); + self.base_dir.join(format!("{filename}.jsonl")) + } + + async fn load_index(&self) -> Index { + match tokio::fs::read_to_string(self.index_path()).await { + Ok(s) => serde_json::from_str(&s).unwrap_or_default(), + Err(_) => Index::default(), + } + } + + async fn write_index(&self, idx: &Index) -> anyhow::Result<()> { + let content = serde_json::to_string_pretty(idx)?; + // Atomic write: write to a temp file then rename. + let tmp = self.index_path().with_extension("tmp"); + tokio::fs::write(&tmp, content.as_bytes()).await?; + tokio::fs::rename(&tmp, self.index_path()).await?; + Ok(()) + } + + // ── public API ─────────────────────────────────────────────────────────── + + /// Return metadata for all known sessions. + pub async fn load_all(&self) -> HashMap { + self.load_index().await.sessions + } + + /// Insert or update session metadata in `index.json`. + pub async fn upsert(&self, meta: SessionMeta) -> anyhow::Result<()> { + let mut idx = self.load_index().await; + idx.sessions.insert(meta.key.clone(), meta); + self.write_index(&idx).await + } + + /// Remove a session from `index.json` and delete its transcript file. + pub async fn remove(&self, key: &str) -> anyhow::Result<()> { + let mut idx = self.load_index().await; + idx.sessions.remove(key); + self.write_index(&idx).await?; + let _ = tokio::fs::remove_file(self.transcript_path(key)).await; + Ok(()) + } + + /// Append a single message line to the session's JSONL transcript. + /// + /// The file is created if it does not exist. 
Because lines are appended + /// one at a time, a crash can only corrupt the last (incomplete) line, + /// which `load_transcript` will silently skip. + pub async fn append_message(&self, key: &str, role: &str, content: &str) -> anyhow::Result<()> { + let entry = TranscriptEntry { + role: role.to_string(), + content: content.to_string(), + ts: now_secs(), + }; + let mut line = serde_json::to_string(&entry)?; + line.push('\n'); + + let mut file = tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(self.transcript_path(key)) + .await?; + file.write_all(line.as_bytes()).await?; + Ok(()) + } + + /// Load the most recent transcript entries for context restoration. + /// + /// Returns at most [`MAX_RESTORE_ENTRIES`] entries (the tail of the file), + /// keeping the context prompt small enough to avoid timeouts. + pub async fn load_transcript(&self, key: &str) -> Vec { + match tokio::fs::read_to_string(self.transcript_path(key)).await { + Ok(s) => { + let all: Vec = s + .lines() + .filter_map(|l| serde_json::from_str(l).ok()) + .collect(); + // Return only the tail so context restoration stays concise. + let skip = all.len().saturating_sub(MAX_RESTORE_ENTRIES); + all.into_iter().skip(skip).collect() + } + Err(_) => vec![], + } + } +} + +fn now_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +}