diff --git a/crates/taskito-core/src/worker.rs b/crates/taskito-core/src/worker.rs
index fd975ce..e9b451d 100644
--- a/crates/taskito-core/src/worker.rs
+++ b/crates/taskito-core/src/worker.rs
@@ -13,4 +13,13 @@ pub trait WorkerDispatcher: Send + Sync {
     /// Signal the pool to stop accepting new work.
     fn shutdown(&self);
+
+    /// Notify the pool that a running job should be cancelled.
+    ///
+    /// Pools that run tasks in-process (e.g. the thread pool) can rely on the
+    /// storage cancel flag and provide a no-op. Pools that execute tasks in a
+    /// separate process (e.g. the prefork pool) must use this hook to deliver
+    /// a side-channel signal so the worker observes the cancel without polling
+    /// storage.
+    fn notify_cancel(&self, _job_id: &str) {}
 }
diff --git a/crates/taskito-python/src/prefork/child.rs b/crates/taskito-python/src/prefork/child.rs
index 00b449e..a25332c 100644
--- a/crates/taskito-python/src/prefork/child.rs
+++ b/crates/taskito-python/src/prefork/child.rs
@@ -29,6 +29,15 @@ impl ChildWriter {
     pub fn send_shutdown(&mut self) {
         let _ = self.send(&ParentMessage::Shutdown);
     }
+
+    /// Send a cooperative-cancel request for `job_id`. Returns the underlying
+    /// I/O error if the pipe is broken so the caller can decide whether to
+    /// retry or drop the request.
+    pub fn send_cancel(&mut self, job_id: &str) -> std::io::Result<()> {
+        self.send(&ParentMessage::Cancel {
+            job_id: job_id.to_string(),
+        })
+    }
 }
 
 /// Reader half — reads result messages from the child process via stdout.
diff --git a/crates/taskito-python/src/prefork/mod.rs b/crates/taskito-python/src/prefork/mod.rs
index ce2bab4..61623b4 100644
--- a/crates/taskito-python/src/prefork/mod.rs
+++ b/crates/taskito-python/src/prefork/mod.rs
@@ -10,6 +10,8 @@
 //! - N reader threads: one per child, reads results from stdout, sends to `result_tx`
 //! - One watchdog thread: enforces per-job timeouts by `SIGKILL`-ing children
 //!   whose deadlines pass
+//! - One cancel-router thread: forwards cooperative-cancel requests from
+//!   `notify_cancel` to the child currently running the named job
 //! - Child processes: run `python -m taskito.prefork <app_path>`
 
 mod child;
@@ -24,7 +26,7 @@ use std::thread::{self, JoinHandle};
 use std::time::{Duration, Instant};
 
 use async_trait::async_trait;
-use crossbeam_channel::Sender;
+use crossbeam_channel::{Receiver, Sender, TrySendError};
 
 use taskito_core::job::Job;
 use taskito_core::scheduler::JobResult;
@@ -38,12 +40,30 @@ use slot::{ActiveJob, SlotState};
 /// sending `SIGKILL`.
 const SHUTDOWN_DRAIN: Duration = Duration::from_secs(30);
 
+/// Bounded capacity for the cancel side-channel. Cancel requests are tiny
+/// and always make progress on the router thread, so this buffer absorbs
+/// realistic bursts (workflow-cascade cancels, retried clicks) without ever
+/// back-pressuring the caller.
+const CANCEL_CHANNEL_CAPACITY: usize = 1024;
+
+/// Per-child writer collection shared between the dispatch thread, the
+/// cancel router, and the restart path. Each slot mirrors the `processes`
+/// vector — `None` while a child is being respawned, `Some(writer)` while
+/// the child is live.
+type WriterPool = Arc<Vec<Mutex<Option<ChildWriter>>>>;
+type ProcessPool = Arc<Vec<Mutex<Option<ChildProcess>>>>;
+type InFlightCounters = Arc<Vec<AtomicU32>>;
+
 /// Multi-process worker pool that dispatches jobs to child Python processes.
 pub struct PreforkPool {
     num_workers: usize,
     app_path: String,
     python: String,
     shutdown: AtomicBool,
+    /// Side-channel for cooperative cancellation. The dispatch loop installs
+    /// the sender when `run()` starts and clears it on shutdown so
+    /// `notify_cancel` becomes a no-op once the pool is no longer running.
+    cancel_tx: Mutex<Option<Sender<String>>>,
 }
 
 impl PreforkPool {
@@ -55,6 +75,7 @@ impl PreforkPool {
             app_path,
             python,
             shutdown: AtomicBool::new(false),
+            cancel_tx: Mutex::new(None),
         }
     }
 }
@@ -68,24 +89,19 @@ impl WorkerDispatcher for PreforkPool {
     ) {
         let num_workers = self.num_workers;
 
-        // Shared per-child state.
         let slots: SlotState = slot::new_slots(num_workers);
-        let in_flight: Arc<Vec<AtomicU32>> =
+        let in_flight: InFlightCounters =
             Arc::new((0..num_workers).map(|_| AtomicU32::new(0)).collect());
-        let processes: Arc<Vec<Mutex<Option<ChildProcess>>>> =
-            Arc::new((0..num_workers).map(|_| Mutex::new(None)).collect());
-
-        // Per-child writers stay on the dispatch thread.
-        let mut writers: Vec<Option<ChildWriter>> = (0..num_workers).map(|_| None).collect();
+        let processes: ProcessPool = Arc::new((0..num_workers).map(|_| Mutex::new(None)).collect());
+        let writers: WriterPool = Arc::new((0..num_workers).map(|_| Mutex::new(None)).collect());
 
         let mut reader_handles: Vec<JoinHandle<()>> = Vec::new();
 
-        // Initial spawn.
         for idx in 0..num_workers {
             if let Some(handle) = start_child(
                 idx,
                 &self.python,
                 &self.app_path,
-                &mut writers,
+                &writers,
                 &processes,
                 &slots,
                 &in_flight,
@@ -95,17 +111,17 @@
             ) {
                 reader_handles.push(handle);
             }
         }
 
-        if writers.iter().all(Option::is_none) {
+        let live_children = count_live_writers(&writers);
+        if live_children == 0 {
             log::error!("[taskito] no prefork children started, aborting");
             return;
         }
+        log::info!("[taskito] prefork pool running with {live_children} children");
 
-        log::info!(
-            "[taskito] prefork pool running with {} children",
-            writers.iter().filter(|w| w.is_some()).count()
-        );
+        let (cancel_tx, cancel_rx) = crossbeam_channel::bounded::<String>(CANCEL_CHANNEL_CAPACITY);
+        self.set_cancel_sender(Some(cancel_tx));
+        let cancel_router = spawn_cancel_router(slots.clone(), writers.clone(), cancel_rx);
 
-        // Watchdog: kills children that exceed their per-job timeout.
         let watchdog_shutdown = Arc::new(AtomicBool::new(false));
         let watchdog_handle = watchdog::spawn(
             slots.clone(),
@@ -115,41 +131,32 @@
             watchdog_shutdown.clone(),
         );
 
-        // Dispatch loop.
         let mut restart_count: u64 = 0;
         while let Some(job) = job_rx.recv().await {
            if self.shutdown.load(Ordering::Relaxed) {
                break;
            }
 
-            // Bring back any children that have exited (crashed, killed by
-            // watchdog, OOM, etc.).
             for idx in 0..num_workers {
-                let dead = match processes[idx].lock() {
-                    Ok(mut guard) => match guard.as_mut() {
-                        Some(p) => !p.is_alive(),
-                        None => true,
-                    },
-                    Err(_) => false,
-                };
-                if dead {
-                    log::warn!("[taskito] prefork child {idx} died, restarting");
-                    restart_count += 1;
-                    if let Some(handle) = start_child(
-                        idx,
-                        &self.python,
-                        &self.app_path,
-                        &mut writers,
-                        &processes,
-                        &slots,
-                        &in_flight,
-                        &result_tx,
-                    ) {
-                        reader_handles.push(handle);
-                        log::info!(
-                            "[taskito] prefork child {idx} restarted (total restarts: {restart_count})"
-                        );
-                    }
+                if !is_child_dead(&processes, idx) {
+                    continue;
+                }
+                log::warn!("[taskito] prefork child {idx} died, restarting");
+                restart_count += 1;
+                if let Some(handle) = start_child(
+                    idx,
+                    &self.python,
+                    &self.app_path,
+                    &writers,
+                    &processes,
+                    &slots,
+                    &in_flight,
+                    &result_tx,
+                ) {
+                    reader_handles.push(handle);
+                    log::info!(
+                        "[taskito] prefork child {idx} restarted (total restarts: {restart_count})"
+                    );
                 }
             }
 
@@ -159,59 +166,26 @@
                 .collect();
 
             let idx = dispatch::least_loaded(&counts);
-            let Some(writer) = writers[idx].as_mut() else {
-                log::error!(
-                    "[taskito] no live writer for child {idx}, dropping job {}; will be reaped",
-                    job.id
-                );
-                continue;
-            };
-
-            let active = ActiveJob {
-                job_id: job.id.clone(),
-                task_name: job.task_name.clone(),
-                retry_count: job.retry_count,
-                max_retries: job.max_retries,
-                timeout_ms: job.timeout_ms,
-                started_at: Instant::now(),
-                deadline: deadline_from_timeout(job.timeout_ms),
-            };
-
-            // Register *before* sending so a fast child can never publish a
-            // result the reader can't pair with a slot entry.
-            slot::set(&slots, idx, active);
-
-            let msg = ParentMessage::from(&job);
-            match writer.send(&msg) {
-                Ok(()) => {
-                    in_flight[idx].fetch_add(1, Ordering::Relaxed);
-                }
-                Err(e) => {
-                    // Roll back the slot install — neither reader nor watchdog
-                    // should fire for this aborted dispatch.
-                    let _ = slot::take(&slots, idx);
-                    log::error!(
-                        "[taskito] failed to send job {} to child {idx}: {e}",
-                        job.id
-                    );
-                    // Job will be reaped by the scheduler's stale-job reaper.
-                }
-            }
+            dispatch_job(idx, job, &writers, &slots, &in_flight);
         }
 
+        // Stop accepting new cancel requests so the router can drain and exit
+        // cleanly while writers are still alive.
+        self.set_cancel_sender(None);
+
         // Stop the watchdog before sending shutdown so it doesn't race with
         // children draining their final results.
         watchdog_shutdown.store(true, Ordering::SeqCst);
 
-        // Graceful shutdown: tell all live children to stop.
-        for (idx, writer) in writers.iter_mut().enumerate() {
-            if let Some(w) = writer.as_mut() {
-                w.send_shutdown();
-                log::info!("[taskito] sent shutdown to prefork child {idx}");
+        for idx in 0..num_workers {
+            if let Ok(mut guard) = writers[idx].lock() {
+                if let Some(w) = guard.as_mut() {
+                    w.send_shutdown();
+                    log::info!("[taskito] sent shutdown to prefork child {idx}");
+                }
             }
         }
 
-        // Wait for children to exit (or kill after the drain timeout).
         for idx in 0..num_workers {
             if let Ok(mut guard) = processes[idx].lock() {
                 if let Some(process) = guard.as_mut() {
@@ -221,16 +195,124 @@
             }
         }
 
-        // Reader threads exit when their child closes stdout.
+        // The cancel sender was cleared above, so the router drains any queued
+        // requests and exits once `cancel_rx` disconnects. Clear the writer
+        // slots here so the children's stdin pipes close and no stale handles
+        // outlive the pool.
+        for slot in writers.iter() {
+            if let Ok(mut guard) = slot.lock() {
+                *guard = None;
+            }
+        }
+
         for handle in reader_handles {
             let _ = handle.join();
         }
         let _ = watchdog_handle.join();
+        let _ = cancel_router.join();
     }
 
     fn shutdown(&self) {
         self.shutdown.store(true, Ordering::SeqCst);
     }
+
+    fn notify_cancel(&self, job_id: &str) {
+        let Ok(guard) = self.cancel_tx.lock() else {
+            return;
+        };
+        let Some(tx) = guard.as_ref() else {
+            return;
+        };
+        match tx.try_send(job_id.to_string()) {
+            Ok(()) | Err(TrySendError::Disconnected(_)) => {}
+            Err(TrySendError::Full(_)) => {
+                log::warn!("[taskito] prefork cancel channel full, dropping cancel for {job_id}");
+            }
+        }
+    }
 }
+
+impl PreforkPool {
+    fn set_cancel_sender(&self, tx: Option<Sender<String>>) {
+        if let Ok(mut guard) = self.cancel_tx.lock() {
+            *guard = tx;
+        }
+    }
+}
+
+/// Count children with a live writer (i.e. successfully spawned and not
+/// torn down). Used at startup to fail fast if every spawn attempt failed.
+fn count_live_writers(writers: &WriterPool) -> usize {
+    writers
+        .iter()
+        .filter(|slot| slot.lock().map(|g| g.is_some()).unwrap_or(false))
+        .count()
+}
+
+/// Whether the child at `idx` has exited (or never spawned successfully).
+fn is_child_dead(processes: &ProcessPool, idx: usize) -> bool {
+    match processes[idx].lock() {
+        Ok(mut guard) => match guard.as_mut() {
+            Some(p) => !p.is_alive(),
+            None => true,
+        },
+        Err(_) => false,
+    }
+}
+
+/// Push a job to child `idx`. The slot is registered before sending so a
+/// fast child cannot publish a result the reader can't pair with a slot
+/// entry; on send failure the slot is rolled back so neither the reader
+/// nor the watchdog will fire for this aborted dispatch.
+fn dispatch_job(
+    idx: usize,
+    job: Job,
+    writers: &WriterPool,
+    slots: &SlotState,
+    in_flight: &InFlightCounters,
+) {
+    let active = ActiveJob {
+        job_id: job.id.clone(),
+        task_name: job.task_name.clone(),
+        retry_count: job.retry_count,
+        max_retries: job.max_retries,
+        timeout_ms: job.timeout_ms,
+        started_at: Instant::now(),
+        deadline: deadline_from_timeout(job.timeout_ms),
+    };
+    slot::set(slots, idx, active);
+
+    let msg = ParentMessage::from(&job);
+    let send_result = match writers[idx].lock() {
+        Ok(mut guard) => match guard.as_mut() {
+            Some(writer) => writer.send(&msg),
+            None => {
+                drop(guard);
+                let _ = slot::take(slots, idx);
+                log::error!(
+                    "[taskito] no live writer for child {idx}, dropping job {}; will be reaped",
+                    job.id
+                );
+                return;
+            }
+        },
+        Err(_) => {
+            let _ = slot::take(slots, idx);
+            log::error!("[taskito] writer mutex poisoned for child {idx}");
+            return;
+        }
+    };
+
+    match send_result {
+        Ok(()) => {
+            in_flight[idx].fetch_add(1, Ordering::Relaxed);
+        }
+        Err(e) => {
+            let _ = slot::take(slots, idx);
+            log::error!(
+                "[taskito] failed to send job {} to child {idx}: {e}",
+                job.id
+            );
+        }
+    }
+}
 
 /// Spawn child `idx` and its reader thread, plumbing the writer + process into
@@ -241,18 +323,20 @@ fn start_child(
     idx: usize,
     python: &str,
     app_path: &str,
-    writers: &mut [Option<ChildWriter>],
-    processes: &Arc<Vec<Mutex<Option<ChildProcess>>>>,
+    writers: &WriterPool,
+    processes: &ProcessPool,
     slots: &SlotState,
-    in_flight: &Arc<Vec<AtomicU32>>,
+    in_flight: &InFlightCounters,
     result_tx: &Sender<JobResult>,
 ) -> Option<JoinHandle<()>> {
     match spawn_child(python, app_path) {
         Ok((writer, reader, process)) => {
             log::info!("[taskito] prefork child {idx} ready");
-            writers[idx] = Some(writer);
-            if let Ok(mut slot) = processes[idx].lock() {
-                *slot = Some(process);
+            if let Ok(mut guard) = writers[idx].lock() {
+                *guard = Some(writer);
+            }
+            if let Ok(mut guard) = processes[idx].lock() {
+                *guard = Some(process);
             }
             // Reset the slot for the new child — the killed/dead one's job (if
             // any) was already completed by the watchdog or shutdown path.
@@ -284,7 +368,7 @@ fn spawn_reader_thread(
     idx: usize,
     mut reader: ChildReader,
     slots: SlotState,
-    in_flight: Arc<Vec<AtomicU32>>,
+    in_flight: InFlightCounters,
     result_tx: Sender<JobResult>,
 ) -> JoinHandle<()> {
     thread::Builder::new()
@@ -302,7 +386,7 @@
                     }
                     in_flight[idx].fetch_sub(1, Ordering::Relaxed);
                     if result_tx.send(job_result).is_err() {
-                        break; // result channel closed
+                        break;
                     }
                 }
                 Err(e) => {
@@ -314,6 +398,43 @@
         .expect("failed to spawn prefork reader thread")
 }
 
+/// Cancel router: forwards cooperative-cancel requests to the child
+/// currently running the named job.
+///
+/// The router never owns the slot — it only consults `find_by_job_id` to
+/// route the message. Result/timeout completion still owns the slot
+/// `take()`. If the job is no longer running (already completed, never
+/// dispatched, or just finished between `notify_cancel` and the router
+/// pick-up), the request is dropped silently — the storage cancel flag
+/// set by `Storage::request_cancel` already handles those cases.
+fn spawn_cancel_router(
+    slots: SlotState,
+    writers: WriterPool,
+    cancel_rx: Receiver<String>,
+) -> JoinHandle<()> {
+    thread::Builder::new()
+        .name("taskito-prefork-cancel-router".into())
+        .spawn(move || {
+            for job_id in cancel_rx.iter() {
+                let Some(idx) = slot::find_by_job_id(&slots, &job_id) else {
+                    continue;
+                };
+                let Ok(mut guard) = writers[idx].lock() else {
+                    continue;
+                };
+                let Some(writer) = guard.as_mut() else {
+                    continue;
+                };
+                if let Err(e) = writer.send_cancel(&job_id) {
+                    log::warn!(
+                        "[taskito] failed to forward cancel for {job_id} to child {idx}: {e}"
+                    );
+                }
+            }
+        })
+        .expect("failed to spawn prefork cancel-router thread")
+}
+
 /// Convert a per-task timeout in milliseconds to an absolute `Instant` deadline.
 /// Returns `None` for `timeout_ms <= 0` (no timeout configured) so the watchdog
 /// skips the slot.
diff --git a/crates/taskito-python/src/prefork/protocol.rs b/crates/taskito-python/src/prefork/protocol.rs
index fbbc96e..3405a89 100644
--- a/crates/taskito-python/src/prefork/protocol.rs
+++ b/crates/taskito-python/src/prefork/protocol.rs
@@ -23,6 +23,12 @@ pub enum ParentMessage {
         timeout_ms: i64,
         namespace: Option<String>,
     },
+    /// Cooperative-cancel request for a job currently running (or queued in
+    /// the child's stdin buffer). The child marks `job_id` as cancelled so
+    /// `current_job.check_cancelled()` raises `TaskCancelledError`.
+    Cancel {
+        job_id: String,
+    },
     Shutdown,
 }
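
For reference, everything the parent writes to a child's stdin is line-delimited JSON, one message per line, keyed on "type" exactly as the child's reader expects. A minimal sketch of the traffic during a cooperative cancel; the job message is abbreviated and the task name is made up, only "type" and "job_id" are load-bearing here:

import json

# Hypothetical stdin traffic for one child. A real "job" message carries the
# full ParentMessage field set (payload, retry_count, timeout_ms, namespace, ...).
lines = [
    json.dumps({"type": "job", "job_id": "j-1", "task_name": "demo.loop"}),
    json.dumps({"type": "cancel", "job_id": "j-1"}),  # forwarded by the cancel router
    json.dumps({"type": "shutdown"}),
]

for line in lines:
    msg = json.loads(line)
    print(msg["type"], msg.get("job_id", "-"))  # the child demuxes on "type"
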
diff --git a/crates/taskito-python/src/prefork/slot.rs b/crates/taskito-python/src/prefork/slot.rs
index e17b8bd..3d93be6 100644
--- a/crates/taskito-python/src/prefork/slot.rs
+++ b/crates/taskito-python/src/prefork/slot.rs
@@ -60,3 +60,18 @@ pub fn take_if_expired(slots: &SlotState, idx: usize, now: Instant) -> Option<ActiveJob> {
+
+/// Find the index of the child slot currently running `job_id`, if any.
+pub fn find_by_job_id(slots: &SlotState, job_id: &str) -> Option<usize> {
+    for (idx, slot) in slots.iter().enumerate() {
+        let guard = slot.lock().expect("slot mutex poisoned");
+        if guard.as_ref().is_some_and(|j| j.job_id == job_id) {
+            return Some(idx);
+        }
+    }
+    None
+}
diff --git a/crates/taskito-python/src/py_queue/mod.rs b/crates/taskito-python/src/py_queue/mod.rs
index 7190bda..aac4eff 100644
--- a/crates/taskito-python/src/py_queue/mod.rs
+++ b/crates/taskito-python/src/py_queue/mod.rs
@@ -7,7 +7,7 @@ mod worker;
 mod workflow_ops;
 
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use pyo3::prelude::*;
 use pyo3::types::PyDict;
@@ -21,6 +21,7 @@ use taskito_core::storage::postgres::PostgresStorage;
 use taskito_core::storage::redis_backend::RedisStorage;
 use taskito_core::storage::sqlite::SqliteStorage;
 use taskito_core::storage::{Storage, StorageBackend};
+use taskito_core::worker::WorkerDispatcher;
 
 use crate::py_job::PyJob;
 
@@ -39,6 +40,11 @@ pub struct PyQueue {
     pub(crate) scheduler_reap_interval: u32,
     pub(crate) scheduler_cleanup_interval: u32,
     pub(crate) namespace: Option<String>,
+    /// Active worker dispatcher, set while `run_worker` is executing. Used by
+    /// `request_cancel` to deliver a side-channel signal to pools that run
+    /// tasks out-of-process (prefork). For in-process pools the trait's
+    /// default no-op makes this a free notification.
+    pub(crate) dispatcher: Arc<Mutex<Option<Arc<dyn WorkerDispatcher>>>>,
     /// Cached workflow storage handle. Lazily initialized on first workflow API
     /// call; migrations run exactly once per `PyQueue` instance instead of
     /// per-call.
@@ -137,6 +143,7 @@ impl PyQueue {
             scheduler_reap_interval,
             scheduler_cleanup_interval,
             namespace,
+            dispatcher: Arc::new(Mutex::new(None)),
             #[cfg(feature = "workflows")]
             workflow_storage: std::sync::OnceLock::new(),
         })
@@ -362,10 +369,19 @@ impl PyQueue {
     }
 
     /// Request cancellation of a running job. Returns True if cancel was requested.
+    ///
+    /// Sets the storage cancel flag and, if a worker pool is currently running,
+    /// notifies it via a side channel so out-of-process pools (prefork) can
+    /// observe the cancel without polling storage.
     pub fn request_cancel(&self, job_id: &str) -> PyResult<bool> {
-        self.storage
+        let requested = self
+            .storage
             .request_cancel(job_id)
-            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
+            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+        if requested {
+            self.notify_dispatcher_cancel(job_id);
+        }
+        Ok(requested)
     }
 
     /// Check if cancellation has been requested for a job.
@@ -602,3 +618,24 @@
         )
     }
 }
+
+impl PyQueue {
+    /// Install the active worker dispatcher. Called by `run_worker` before
+    /// the dispatch loop starts so `request_cancel` can deliver out-of-band
+    /// cancel signals to the running pool. Pass `None` on shutdown.
+    pub(crate) fn set_dispatcher(&self, dispatcher: Option<Arc<dyn WorkerDispatcher>>) {
+        if let Ok(mut guard) = self.dispatcher.lock() {
+            *guard = dispatcher;
+        }
+    }
+
+    fn notify_dispatcher_cancel(&self, job_id: &str) {
+        let dispatcher = match self.dispatcher.lock() {
+            Ok(guard) => guard.clone(),
+            Err(_) => return,
+        };
+        if let Some(d) = dispatcher {
+            d.notify_cancel(job_id);
+        }
+    }
+}
diff --git a/crates/taskito-python/src/py_queue/worker.rs b/crates/taskito-python/src/py_queue/worker.rs
index e7c4c0a..df28658 100644
--- a/crates/taskito-python/src/py_queue/worker.rs
+++ b/crates/taskito-python/src/py_queue/worker.rs
@@ -326,7 +326,7 @@ impl PyQueue {
         let (result_tx, result_rx) = crossbeam_channel::bounded(self.num_workers * 2);
 
         let registry_arc = Arc::new(task_registry);
-        let filters_arc = Arc::new(retry_filters.into());
+        let filters_arc: Arc = Arc::new(retry_filters.into());
 
         let scheduler_arc = Arc::new(scheduler);
         let scheduler_for_dispatch = scheduler_arc.clone();
@@ -368,10 +368,49 @@
             })?
         };
 
-        // Create multi-threaded tokio runtime for scheduler + worker pool
+        // Build the dispatcher up front for the prefork case so we can install
+        // it on the queue before the run loop starts — request_cancel relies on
+        // the install to deliver out-of-band cancel signals to child processes.
+        //
+        // For in-process pools (native-async, classic async) `notify_cancel` is
+        // a no-op — running tasks observe cancellation via the storage flag —
+        // so we deliberately do NOT install the dispatcher on `self`.
+        // Installing it would keep an `Arc` reference (the async
+        // executor / PyResultSender chain) alive on the parent thread until
+        // `set_dispatcher(None)` runs after `runtime_handle.join()`, deadlocking
+        // shutdown: the drain loop waits for the result channel to disconnect,
+        // which can't happen until PyResultSender drops, which can't happen
+        // until both `async_executor` Arc refs drop — and the second ref is
+        // exactly what `self.dispatcher` was holding.
         let num_workers = self.num_workers;
         let use_prefork = pool.as_deref() == Some("prefork");
-        let prefork_app_path = app_path;
+        let dispatcher_for_run: Arc<dyn WorkerDispatcher> = if use_prefork {
+            let pool_arc: Arc<dyn WorkerDispatcher> = Arc::new(
+                crate::prefork::PreforkPool::new(num_workers, app_path.unwrap_or_default()),
+            );
+            self.set_dispatcher(Some(pool_arc.clone()));
+            pool_arc
+        } else {
+            #[cfg(feature = "native-async")]
+            {
+                let pool_arc: Arc<dyn WorkerDispatcher> =
+                    Arc::new(taskito_async::NativeAsyncPool::new(
+                        num_workers,
+                        registry_arc.clone(),
+                        filters_arc.clone(),
+                        async_executor,
+                    ));
+                pool_arc
+            }
+            #[cfg(not(feature = "native-async"))]
+            {
+                let pool_arc: Arc<dyn WorkerDispatcher> = Arc::new(
+                    AsyncWorkerPool::new(num_workers, registry_arc.clone(), filters_arc.clone()),
+                );
+                pool_arc
+            }
+        };
+
         // Move result_tx into the runtime — don't keep a copy in the main thread
         // so result_rx disconnects when all workers are done.
         let runtime_handle = std::thread::spawn(move || {
@@ -393,30 +432,7 @@
             });
 
             let worker_task = tokio::spawn(async move {
-                use taskito_core::worker::WorkerDispatcher;
-
-                if use_prefork {
-                    let app = prefork_app_path.unwrap_or_default();
-                    let pool = crate::prefork::PreforkPool::new(num_workers, app);
-                    pool.run(job_rx, result_tx).await;
-                } else {
-                    #[cfg(feature = "native-async")]
-                    {
-                        let pool = taskito_async::NativeAsyncPool::new(
-                            num_workers,
-                            registry_arc,
-                            filters_arc,
-                            async_executor,
-                        );
-                        pool.run(job_rx, result_tx).await;
-                    }
-
-                    #[cfg(not(feature = "native-async"))]
-                    {
-                        let pool = AsyncWorkerPool::new(num_workers, registry_arc, filters_arc);
-                        pool.run(job_rx, result_tx).await;
-                    }
-                }
+                dispatcher_for_run.run(job_rx, result_tx).await;
             });
 
             let _ = tokio::join!(scheduler_task, worker_task);
@@ -508,6 +524,10 @@
 
         let _ = runtime_handle.join();
 
+        // Clear the dispatcher reference so post-shutdown cancel requests
+        // become no-ops instead of forwarding to a torn-down pool.
+        self.set_dispatcher(None);
+
         // Unregister worker on shutdown
         let _ = self.storage.unregister_worker(&worker_id);
 
diff --git a/py_src/taskito/context.py b/py_src/taskito/context.py
index af26b18..2dc55e4 100644
--- a/py_src/taskito/context.py
+++ b/py_src/taskito/context.py
@@ -6,6 +6,7 @@
 import logging
 import threading
 import time
+from collections.abc import Callable
 from typing import TYPE_CHECKING, Any
 
 from taskito._active_context import _ActiveContext
@@ -20,6 +21,30 @@
 _local = threading.local()
 _queue_ref: Queue | None = None
 
+# Optional in-process cancel signal. When set, ``check_cancelled()`` consults
+# this hook before falling back to storage. The prefork child installs one
+# that reads a local cancel set populated by the IPC reader thread; the
+# default (None) preserves storage-backed behaviour for the thread pool.
+_local_cancel_check: Callable[[str], bool] | None = None
+
+
+def set_local_cancel_check(fn: Callable[[str], bool]) -> None:
+    """Install a process-local cancel check used by ``check_cancelled()``.
+
+    The callable receives the current ``job_id`` and returns ``True`` if the
+    job has been cancelled. Installed by out-of-process workers (prefork)
+    that receive cancel signals over IPC and want ``check_cancelled()`` to
+    react without polling storage.
+    """
+    global _local_cancel_check
+    _local_cancel_check = fn
+
+
+def clear_local_cancel_check() -> None:
+    """Remove a previously installed local cancel check."""
+    global _local_cancel_check
+    _local_cancel_check = None
+
 
 class JobContext:
     """Provides access to the currently executing job's metadata and controls.
@@ -117,6 +142,12 @@ def check_cancelled(self) -> None:
             TaskCancelledError: If the job has been marked for cancellation.
         """
         ctx = self._require_context()
+        # Fast path: a worker that received an out-of-band cancel signal
+        # (e.g. the prefork child via IPC) installs a local check so we can
+        # observe the cancel without a storage round-trip.
+        local_check = _local_cancel_check
+        if local_check is not None and local_check(ctx.job_id):
+            raise TaskCancelledError(f"Job {ctx.job_id} was cancelled")
         if _queue_ref is None:
             raise RuntimeError("Queue reference not set.")
         if _queue_ref._inner.is_cancel_requested(ctx.job_id):
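
As a usage sketch of the new hook (everything here is hypothetical except ``set_local_cancel_check`` and ``clear_local_cancel_check`` themselves), an out-of-process worker only needs a thread-safe set and two calls:

import threading

from taskito.context import clear_local_cancel_check, set_local_cancel_check

_lock = threading.Lock()
_cancelled: set[str] = set()

def on_ipc_cancel(job_id: str) -> None:
    # Called from whatever thread receives cancel signals (e.g. an IPC reader).
    with _lock:
        _cancelled.add(job_id)

def is_cancelled(job_id: str) -> bool:
    with _lock:
        return job_id in _cancelled

set_local_cancel_check(is_cancelled)
try:
    ...  # run jobs; current_job.check_cancelled() now sees IPC cancels first
finally:
    clear_local_cancel_check()
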
diff --git a/py_src/taskito/prefork/child.py b/py_src/taskito/prefork/child.py
index b7486ad..c48c65c 100644
--- a/py_src/taskito/prefork/child.py
+++ b/py_src/taskito/prefork/child.py
@@ -3,7 +3,12 @@
 Each child is an independent Python interpreter that:
 1. Imports the app module and builds the task registry.
 2. Initializes resources (if any).
-3. Reads JSON job messages from stdin, executes tasks, writes JSON results to stdout.
+3. Runs a stdin reader thread that demultiplexes ``job``, ``cancel``, and
+   ``shutdown`` messages from the parent. Jobs go on an internal queue;
+   cancels populate a local set that ``current_job.check_cancelled()`` reads
+   via a registered hook.
+4. Pulls jobs off the internal queue on the main thread, executes them,
+   and writes JSON results to stdout.
 
 Spawned by the Rust ``PreforkPool`` via ``python -m taskito.prefork <app_path>``.
 """
@@ -15,17 +20,29 @@
 import json
 import logging
 import os
+import queue as _queue_mod
 import sys
+import threading
 import time
 import traceback
 from typing import Any
 
 from taskito.async_support.helpers import run_maybe_async
-from taskito.context import _clear_context, _set_context, _set_queue_ref
+from taskito.context import (
+    _clear_context,
+    _set_context,
+    _set_queue_ref,
+    clear_local_cancel_check,
+    set_local_cancel_check,
+)
 from taskito.exceptions import TaskCancelledError
 
 logger = logging.getLogger("taskito.prefork.child")
 
+# Sentinel pushed onto the internal job queue when the parent requests
+# shutdown so the main loop can terminate without polling.
+_SHUTDOWN_SENTINEL: dict[str, Any] = {"__shutdown__": True}
+
 
 def _import_queue(app_path: str) -> Any:
     """Import and return the Queue instance from a dotted path like 'myapp:queue'."""
@@ -43,6 +60,31 @@ def _write_message(msg: dict[str, Any]) -> None:
     sys.stdout.flush()
 
 
+class _CancelSignal:
+    """Thread-safe set of job IDs the parent has asked us to cancel.
+
+    Cancel messages may arrive before, during, or after the job they target;
+    keeping the IDs around until the corresponding result is written means
+    a cancel that races a job's start still fires deterministically.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._ids: set[str] = set()
+
+    def request(self, job_id: str) -> None:
+        with self._lock:
+            self._ids.add(job_id)
+
+    def is_requested(self, job_id: str) -> bool:
+        with self._lock:
+            return job_id in self._ids
+
+    def discard(self, job_id: str) -> None:
+        with self._lock:
+            self._ids.discard(job_id)
+
+
 def _execute_job(
     queue: Any,
     job: dict[str, Any],
@@ -69,18 +111,12 @@
         "timed_out": False,
     }
 
-    # Set job context
     _set_context(job_id, task_name, retry_count, job.get("queue", "default"))
 
     start_ns = time.monotonic_ns()
     try:
-        # Deserialize payload
         args, kwargs = queue._deserialize_payload(task_name, payload)
-
-        # Call the wrapped task function (handles middleware, resources, proxies)
         result = run_maybe_async(wrapper(*args, **kwargs))
-
-        # Serialize result
         result_bytes = queue._serializer.dumps(result) if result is not None else None
 
         wall_time_ns = time.monotonic_ns() - start_ns
@@ -106,7 +142,6 @@
         error_msg = traceback.format_exc()
         logger.error("task %s[%s] failed: %s", task_name, job_id, error_msg.splitlines()[-1])
 
-        # Check retry filters
         should_retry = True
         filters = queue._task_retry_filters.get(task_name)
         if filters:
@@ -137,6 +172,53 @@
     _clear_context()
 
 
+def _spawn_stdin_reader(
+    job_queue: _queue_mod.Queue[dict[str, Any]],
+    cancels: _CancelSignal,
+) -> threading.Thread:
+    """Run a background thread that demultiplexes parent → child messages.
+
+    The main thread is blocked inside ``_execute_job`` while a job is
+    running, so reading stdin must happen elsewhere. This thread converts
+    the line-delimited JSON stream into queue items + cancel-set updates.
+    """
+
+    def reader() -> None:
+        try:
+            for line in sys.stdin:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    msg = json.loads(line)
+                except json.JSONDecodeError as e:
+                    logger.warning("invalid IPC message from parent: %s", e)
+                    continue
+
+                msg_type = msg.get("type")
+                if msg_type == "shutdown":
+                    job_queue.put(_SHUTDOWN_SENTINEL)
+                    return
+                if msg_type == "job":
+                    job_queue.put(msg)
+                elif msg_type == "cancel":
+                    job_id = msg.get("job_id")
+                    if isinstance(job_id, str):
+                        cancels.request(job_id)
+                else:
+                    logger.warning("unknown IPC message type: %r", msg_type)
+        except (BrokenPipeError, EOFError, KeyboardInterrupt):
+            logger.debug("child stdin closed")
+        finally:
+            # Ensure the main loop wakes up even if stdin closed without a
+            # shutdown message (e.g. the parent died).
+            job_queue.put(_SHUTDOWN_SENTINEL)
+
+    thread = threading.Thread(target=reader, name="taskito-prefork-stdin", daemon=True)
+    thread.start()
+    return thread
+
+
 def main() -> None:
     """Child process main loop. Called via ``python -m taskito.prefork <app_path>``."""
     if len(sys.argv) < 2:
@@ -151,41 +233,37 @@
     if cwd not in sys.path:
         sys.path.insert(0, cwd)
 
-    # Import the queue and set up context
     queue = _import_queue(app_path)
     _set_queue_ref(queue)
 
-    # Initialize resources if any are defined
     runtime = queue._resource_runtime
     if runtime is not None:
         runtime.initialize()
 
-    # Signal readiness
+    job_queue: _queue_mod.Queue[dict[str, Any]] = _queue_mod.Queue()
+    cancels = _CancelSignal()
+    set_local_cancel_check(cancels.is_requested)
+    _spawn_stdin_reader(job_queue, cancels)
+
     _write_message({"type": "ready"})
     logger.info("child ready (app=%s, pid=%d)", app_path, os.getpid())
 
-    # Main loop: read jobs from stdin, execute, write results to stdout
     try:
-        for line in sys.stdin:
-            line = line.strip()
-            if not line:
-                continue
-
-            msg = json.loads(line)
-
-            if msg.get("type") == "shutdown":
-                sys.stdout.flush()
+        while True:
+            msg = job_queue.get()
+            if msg is _SHUTDOWN_SENTINEL:
                 break
-
-            if msg.get("type") == "job":
-                result = _execute_job(queue, msg)
-                _write_message(result)
-
+            result = _execute_job(queue, msg)
+            _write_message(result)
+            # Drop the cancel marker once the result is written so a future
+            # job with the same ID (extremely unlikely, but possible across
+            # ID-reuse boundaries) does not auto-cancel.
+            cancels.discard(result.get("job_id", ""))
     except (BrokenPipeError, EOFError, KeyboardInterrupt):
-        logger.debug("child pipe closed or interrupted")
+        logger.debug("child output pipe closed or interrupted")
    finally:
-        # Teardown resources
+        clear_local_cancel_check()
         if runtime is not None:
             try:
                 runtime.teardown()
diff --git a/tests/python/prefork_apps/cancel_app.py b/tests/python/prefork_apps/cancel_app.py
new file mode 100644
index 0000000..20bf770
--- /dev/null
+++ b/tests/python/prefork_apps/cancel_app.py
@@ -0,0 +1,33 @@
+"""Module-level Queue + tasks used by the prefork cancel regression tests.
+
+The Queue inside this module must be importable both in the parent test
+process and inside each prefork child interpreter. The DB path comes from
+``TASKITO_CANCEL_TEST_DB`` so each test run can use its own tmp file while
+still letting the parent and child build identical Queue instances from
+the same module path.
+""" + +from __future__ import annotations + +import os +import time + +from taskito import Queue +from taskito.context import current_job + +queue = Queue(db_path=os.environ.get("TASKITO_CANCEL_TEST_DB", "/tmp/taskito-cancel.db")) + + +@queue.task(timeout=30, max_retries=0) +def cooperative_loop(max_iters: int = 600) -> int: + """Loop calling ``check_cancelled()`` so cancel can stop the task quickly.""" + for _ in range(max_iters): + current_job.check_cancelled() + time.sleep(0.05) + return max_iters + + +@queue.task(max_retries=0) +def quick(x: int) -> int: + """Returns immediately — used to verify the child still serves jobs after a cancel.""" + return x * 2 diff --git a/tests/python/test_prefork.py b/tests/python/test_prefork.py index 02897a6..8277db1 100644 --- a/tests/python/test_prefork.py +++ b/tests/python/test_prefork.py @@ -213,3 +213,104 @@ def test_prefork_finishes_before_deadline(timeout_app: object) -> None: result = job.result(timeout=15) assert result == "done" + + +# --------------------------------------------------------------------------- +# Cooperative cancellation propagation (issue #82) +# --------------------------------------------------------------------------- + +CANCEL_APP_PATH = "prefork_apps.cancel_app:queue" + + +@pytest.fixture +def cancel_app(tmp_path: Path) -> Iterator[object]: + """Set up the module-level cancel-test app with a per-test DB path. + + Mirrors ``timeout_app`` — must set the env var before the import so the + parent's Queue construction and the child's re-import see the same DB. + """ + db_path = str(tmp_path / "cancel.db") + prev_db = os.environ.get("TASKITO_CANCEL_TEST_DB") + prev_pythonpath = os.environ.get("PYTHONPATH") + + os.environ["TASKITO_CANCEL_TEST_DB"] = db_path + os.environ["PYTHONPATH"] = ( + f"{PREFORK_APP_DIR}{os.pathsep}{prev_pythonpath}" if prev_pythonpath else PREFORK_APP_DIR + ) + if PREFORK_APP_DIR not in sys.path: + sys.path.insert(0, PREFORK_APP_DIR) + + sys.modules.pop("prefork_apps.cancel_app", None) + sys.modules.pop("prefork_apps", None) + module = importlib.import_module("prefork_apps.cancel_app") + + try: + yield module + finally: + with contextlib.suppress(Exception): + module.queue._inner.request_shutdown() + if prev_db is None: + os.environ.pop("TASKITO_CANCEL_TEST_DB", None) + else: + os.environ["TASKITO_CANCEL_TEST_DB"] = prev_db + if prev_pythonpath is None: + os.environ.pop("PYTHONPATH", None) + else: + os.environ["PYTHONPATH"] = prev_pythonpath + + +def _start_cancel_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread( + target=queue.run_worker, + kwargs={"pool": "prefork", "app": CANCEL_APP_PATH}, + daemon=True, + ) + thread.start() + return thread + + +@prefork_unix_only +def test_prefork_cancel_running_job_stops_quickly(cancel_app: object) -> None: + """``cancel_running_job`` propagates to the prefork child and stops a + cooperative task within a small budget — the regression test for #82.""" + queue: Queue = cancel_app.queue # type: ignore[attr-defined] + + cancels_seen: list[str] = [] + + class CancelSpy(TaskMiddleware): + def on_cancel(self, ctx: JobContext) -> None: + cancels_seen.append(ctx.id) + + queue._global_middleware.append(CancelSpy()) + + job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined] + _start_cancel_worker(queue) + + # Give the worker time to dispatch the job to a child and let the loop + # start spinning before we cancel. 
+    time.sleep(1.0)
+
+    assert queue.cancel_running_job(job.id) is True
+
+    status = _wait_for_terminal(job, timeout=10)
+    assert status == "cancelled", f"expected 'cancelled', got {status!r} (error={job.error!r})"
+    assert job.id in cancels_seen, "on_cancel middleware did not fire"
+
+
+@prefork_unix_only
+def test_prefork_cancel_does_not_kill_child(cancel_app: object) -> None:
+    """A cancel must stop the running task without killing the child — the
+    next job dispatched to the same pool should still complete normally."""
+    queue: Queue = cancel_app.queue  # type: ignore[attr-defined]
+
+    long_job = cancel_app.cooperative_loop.delay(600)  # type: ignore[attr-defined]
+    _start_cancel_worker(queue)
+
+    time.sleep(1.0)
+    assert queue.cancel_running_job(long_job.id) is True
+    status = _wait_for_terminal(long_job, timeout=10)
+    assert status == "cancelled"
+
+    follow_up = cancel_app.quick.delay(21)  # type: ignore[attr-defined]
+    result = follow_up.result(timeout=15)
+    assert result == 42
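
Closing note on the child-side ordering guarantee: ``_CancelSignal`` keeps an ID in its set until the job's result is written, so a cancel that lands before the job is dequeued still fires. A stand-in illustration of that race (not the shipped class, just a minimal sketch of the property):

import queue
import threading

class CancelSet:
    """Illustrative stand-in for child.py's _CancelSignal."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._ids: set[str] = set()

    def request(self, job_id: str) -> None:
        with self._lock:
            self._ids.add(job_id)

    def is_requested(self, job_id: str) -> bool:
        with self._lock:
            return job_id in self._ids

jobs: queue.Queue[dict] = queue.Queue()
cancels = CancelSet()

# The cancel arrives before the main loop dequeues the job it targets.
jobs.put({"type": "job", "job_id": "j-1"})
cancels.request("j-1")

msg = jobs.get()
assert cancels.is_requested(msg["job_id"])  # the job's first check_cancelled() would raise
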