diff --git a/src/buffer.rs b/src/buffer.rs index f7350c7..245a67e 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -84,21 +84,17 @@ impl BufferReader { #[inline] fn ensure_bytes(&self, count: usize) -> Result<()> { if self.data.remaining() < count { - return Err(ReplicationError::protocol(format!( - "Not enough bytes remaining. Need {}, have {}", - count, - self.data.remaining() - ))); + return Self::short_buffer_err(count, self.data.remaining()); } Ok(()) } - /// Skip the message type byte and return current position - #[inline] - pub fn skip_message_type(&mut self) -> Result { - self.ensure_bytes(1)?; - self.data.advance(1); - Ok(self.data.len()) + #[cold] + #[inline(never)] + fn short_buffer_err(needed: usize, have: usize) -> Result<()> { + Err(ReplicationError::protocol(format!( + "Not enough bytes remaining. Need {needed}, have {have}" + ))) } /// Read a single byte @@ -601,15 +597,6 @@ mod tests { assert_eq!(reader.remaining(), 2); } - #[test] - fn test_buffer_reader_skip_message_type() { - let data = [0x42, 0x01, 0x02, 0x03]; // 'B' message type - let mut reader = BufferReader::new(&data); - - reader.skip_message_type().unwrap(); - assert_eq!(reader.read_u8().unwrap(), 0x01); - } - #[test] fn test_buffer_writer_signed_integers() { let mut writer = BufferWriter::new(); @@ -868,10 +855,21 @@ mod tests { assert!(reader.read_bytes_buf(5).is_err()); } + /// Pin the error message format produced by the `#[cold]` + /// `short_buffer_err` helper. Several layers above (parser, stream) + /// surface this string in logs, so format regressions would silently + /// degrade diagnostics. 
#[test] - fn test_buffer_reader_skip_message_type_empty() { - let data: &[u8] = &[]; - let mut reader = BufferReader::new(data); - assert!(reader.skip_message_type().is_err()); + fn test_buffer_reader_short_buffer_err_message_format() { + let data = [0x01, 0x02]; + let mut reader = BufferReader::new(&data); + let err = reader.read_bytes_buf(5).unwrap_err(); + let s = err.to_string(); + assert!( + s.contains("Not enough bytes remaining"), + "expected 'Not enough bytes remaining' in error, got: {s}" + ); + assert!(s.contains("Need 5"), "expected 'Need 5', got: {s}"); + assert!(s.contains("have 2"), "expected 'have 2', got: {s}"); } } diff --git a/src/lsn.rs b/src/lsn.rs index 4233c90..b5d56cc 100644 --- a/src/lsn.rs +++ b/src/lsn.rs @@ -101,13 +101,14 @@ impl SharedLsnFeedback { /// /// This should be called when data has been written/flushed to the destination /// database, but not yet committed (e.g., during batch writes). + #[inline] pub fn update_flushed_lsn(&self, lsn: XLogRecPtr) { if lsn == 0 { return; } + let mut current = self.flushed_lsn.load(Ordering::Acquire); loop { - let current = self.flushed_lsn.load(Ordering::Acquire); if lsn <= current { return; } @@ -124,7 +125,7 @@ impl SharedLsnFeedback { ); return; } - Err(_) => continue, + Err(actual) => current = actual, } } } @@ -134,12 +135,15 @@ impl SharedLsnFeedback { /// This should be called when a transaction has been successfully committed /// to the destination database. This is the most important LSN as PostgreSQL /// uses it to determine which WAL can be recycled. 
+    #[inline]     pub fn update_applied_lsn(&self, lsn: XLogRecPtr) {         if lsn == 0 {             return;         } +         + let mut current = self.applied_lsn.load(Ordering::Acquire); + let mut advanced = false;         loop { - let current = self.applied_lsn.load(Ordering::Acquire);             if lsn <= current {                 break;             } @@ -154,14 +158,17 @@                         "SharedLsnFeedback: Updated applied LSN from {} to {}",                         current, lsn                     ); + advanced = true;                     break;                 } - Err(_) => continue, + Err(actual) => current = actual,             }         }         - // Applied data is implicitly flushed, update flushed as well - self.update_flushed_lsn(lsn); + // Applied data is implicitly flushed. Only chase the flushed CAS when we actually moved applied forward; otherwise flushed cannot be behind. + if advanced { + self.update_flushed_lsn(lsn); + }     }      /// Get the current flushed LSN @@ -483,4 +490,90 @@ mod tests {         feedback.update_applied_lsn(50);         assert_eq!(feedback.get_applied_lsn(), 50);     } + + /// Deterministic stress test that forces the `Err(actual) => current = actual` + /// branch in `update_flushed_lsn` and `update_applied_lsn` to fire. + /// + /// `compare_exchange_weak` may also fail spuriously on LL/SC architectures + /// (e.g. ARM) even without contention, so on those targets the loop body + /// is exercised even by a single-threaded run. On x86 we additionally + /// fan out across many threads so the lost-update path is always touched + /// regardless of the CAS strength. 
+ #[test] + fn test_concurrent_cas_retry_path() { + use std::sync::Arc; + use std::sync::Barrier; + use std::thread; + + const THREADS: usize = 16; + const ITERS: u64 = 5_000; + + let feedback = SharedLsnFeedback::new_shared(); + let barrier = Arc::new(Barrier::new(THREADS)); + + let mut handles = Vec::with_capacity(THREADS); + for tid in 0..THREADS as u64 { + let fb = Arc::clone(&feedback); + let bar = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + bar.wait(); + // Each thread proposes a strictly increasing sequence of LSNs; + // collisions between threads guarantee CAS contention and force + // the `Err(actual)` branch to update `current` and retry. + for i in 1..=ITERS { + let lsn = i * THREADS as u64 + tid; + fb.update_flushed_lsn(lsn); + fb.update_applied_lsn(lsn); + } + })); + } + for h in handles { + h.join().unwrap(); + } + + // Final value must be at least the highest proposed LSN (monotonic). + let max_lsn = ITERS * THREADS as u64 + (THREADS as u64 - 1); + assert!( + feedback.get_applied_lsn() >= max_lsn - (THREADS as u64 - 1), + "applied LSN regressed under contention: got {}, want >= {}", + feedback.get_applied_lsn(), + max_lsn - (THREADS as u64 - 1) + ); + // Applied advancing must drag flushed along. + assert!(feedback.get_flushed_lsn() >= feedback.get_applied_lsn()); + } + + /// `update_applied_lsn` must NOT touch `flushed_lsn` when the proposed + /// applied LSN does not advance the current value. This is the + /// optimization that saves a CAS-loop per consumer event when LSN is + /// unchanged. + #[test] + fn test_applied_no_advance_does_not_modify_flushed() { + let feedback = SharedLsnFeedback::new(); + feedback.update_flushed_lsn(1000); + feedback.update_applied_lsn(500); + assert_eq!(feedback.get_flushed_lsn(), 1000); + assert_eq!(feedback.get_applied_lsn(), 500); + + // Stale applied LSN must not regress flushed. 
+ feedback.update_applied_lsn(400); + assert_eq!(feedback.get_flushed_lsn(), 1000); + assert_eq!(feedback.get_applied_lsn(), 500); + + // Equal applied LSN must not touch flushed either. + feedback.update_applied_lsn(500); + assert_eq!(feedback.get_flushed_lsn(), 1000); + assert_eq!(feedback.get_applied_lsn(), 500); + } + + /// `update_applied_lsn` advancing past the current `flushed_lsn` must + /// also pull `flushed_lsn` forward (applied data is implicitly flushed). + #[test] + fn test_applied_advance_drags_flushed_forward() { + let feedback = SharedLsnFeedback::new(); + feedback.update_flushed_lsn(100); + feedback.update_applied_lsn(500); + assert_eq!(feedback.get_applied_lsn(), 500); + assert_eq!(feedback.get_flushed_lsn(), 500); + } } diff --git a/src/protocol.rs b/src/protocol.rs index e5897bb..9b52a6c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1365,10 +1365,7 @@ impl LogicalReplicationParser { ColumnData::binary_bytes(data) } _ => { - return Err(ReplicationError::protocol(format!( - "Unknown column data type: '{}'", - column_type as char - ))); + return Self::unknown_column_type_err(column_type); } }; @@ -1377,6 +1374,15 @@ impl LogicalReplicationParser { Ok(TupleData::from_smallvec(columns)) } + + #[cold] + #[inline(never)] + fn unknown_column_type_err(column_type: u8) -> Result { + Err(ReplicationError::protocol(format!( + "Unknown column data type: '{}'", + column_type as char + ))) + } } /// Parse keepalive message from the replication stream @@ -3146,4 +3152,45 @@ mod tests { _ => panic!("Expected Relation message"), } } + + /// Covers the cold error path `LogicalReplicationParser::unknown_column_type_err`. + /// + /// Crafts a synthetic `INSERT` WAL message whose tuple contains a column + /// with an unrecognised type byte (`b'x'`). `parse_tuple_data` must reject + /// it via the `#[cold]` helper. 
+ #[test] + fn test_parse_tuple_data_unknown_column_type() { + let mut parser = LogicalReplicationParser::with_protocol_version(1); + + // INSERT message: 'I' + relation_id (u32) + 'N' + tuple_data + // tuple_data: column_count (u16) + per-column { type (u8) [+ length + payload] } + let mut msg = Vec::new(); + msg.push(b'I'); + msg.extend_from_slice(&42u32.to_be_bytes()); // relation_id + msg.push(b'N'); // new tuple marker + msg.extend_from_slice(&1u16.to_be_bytes()); // 1 column + msg.push(b'x'); // unknown column type → triggers the cold helper + + let err = parser.parse_wal_message(&msg).unwrap_err(); + let msg_str = err.to_string(); + assert!( + msg_str.contains("Unknown column data type"), + "expected 'Unknown column data type' in error, got: {msg_str}" + ); + assert!( + msg_str.contains("'x'"), + "expected the offending byte to be reported, got: {msg_str}" + ); + } + + /// Direct unit test for the cold error helper itself, so it is reachable + /// without parser-level setup. Pins the message format that callers may + /// rely on for log-grep diagnostics. 
+    #[test] +    fn test_unknown_column_type_err_message_format() { +        let err = LogicalReplicationParser::unknown_column_type_err(b'?').unwrap_err(); +        let s = err.to_string(); +        assert!(s.contains("Unknown column data type")); +        assert!(s.contains("'?'")); +    } } diff --git a/src/stream.rs b/src/stream.rs index 6e6bc80..d25537a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -19,12 +19,11 @@ use crate::types::{     ChangeEvent, EventType, Lsn, RelationColumn, ReplicaIdentity, ReplicationSlotOptions, SlotType, }; use crate::{ -    format_lsn, parse_keepalive_message, postgres_timestamp_to_chrono, BufferReader, -    LogicalReplicationMessage, LogicalReplicationParser, PgReplicationConnection, RelationInfo, -    ReplicationConnectionRetry, ReplicationState, RetryConfig, StreamingReplicationMessage, -    XLogRecPtr, INVALID_XLOG_REC_PTR, +    format_lsn, parse_keepalive_message, postgres_timestamp_to_chrono, LogicalReplicationMessage, +    LogicalReplicationParser, PgReplicationConnection, RelationInfo, ReplicationConnectionRetry, +    ReplicationState, RetryConfig, StreamingReplicationMessage, XLogRecPtr, INVALID_XLOG_REC_PTR, }; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use std::sync::Arc;  use std::future::Future; @@ -53,6 +52,12 @@ pub struct LogicalReplicationStream { /// feedback. Must be a power of two so the modulo folds to a bitmask. const FEEDBACK_CHECK_EVENT_INTERVAL: u32 = 128;  +/// How often `next_event_with_retry` performs the connection health check. +/// Must be a power of two. The health check itself is time-gated, but its +/// `Instant::now()` call still costs a vDSO `clock_gettime` read, so we additionally amortize it +/// across this many events on the hot path. 
+const HEALTH_CHECK_EVENT_INTERVAL: u32 = 1024; +  /// Configuration for the replication stream #[derive(Debug, Clone)] pub struct ReplicationStreamConfig { @@ -738,10 +743,14 @@ impl LogicalReplicationStream {         &mut self,         cancellation_token: &CancellationToken,     ) -> Result { -        // Perform periodic health check -        if let Err(e) = self.check_connection_health().await { -            warn!("Health check failed: {}", e); -            // Don't fail immediately, try to continue +        // Perform periodic health check. +        // +        // The check itself is cheap when nothing is wrong (it just compares `Instant::now()` against `last_health_check`), but `Instant::now()` still performs a vDSO `clock_gettime` read on Linux — cheaper than a true syscall, yet not free. At 100k+ rps that adds up, so we gate the health check on the per-event feedback counter and let it amortize across `HEALTH_CHECK_EVENT_INTERVAL` events. The actual time-based interval inside `check_connection_health()` still applies; this just skips the clock read on the inner-loop hot path. +        if self.feedback_check_counter & (HEALTH_CHECK_EVENT_INTERVAL - 1) == 0 { +            if let Err(e) = self.check_connection_health().await { +                warn!("Health check failed: {}", e); +                // Don't fail immediately, try to continue +            }         }          // Try to get the next event with retry logic for transient connection errors @@ -808,40 +817,40 @@ impl LogicalReplicationStream {     }      /// Process a WAL data message (zero-copy: uses Bytes slicing) +    /// +    /// The XLogData header is fixed-layout (25 bytes), so we decode it directly +    /// from the slice instead of constructing a `BufferReader` — mirroring the +    /// keepalive fast path. This avoids per-event bounds-check overhead and an +    /// extra `Bytes` slice for the header region. 
+ #[inline] fn process_wal_message(&mut self, data: impl Into) -> Result> { let data: Bytes = data.into(); - // Check minimum message length (1 + 8 + 8 + 8 = 25 bytes) + // 'w' (1) + start_lsn (8) + end_lsn (8) + send_time (8) = 25 bytes if data.len() < 25 { return Err(ReplicationError::protocol( "WAL message too short".to_string(), )); } - // Use BufferReader with zero-copy Bytes (ref-counted, no data copy) - let mut reader = BufferReader::from_bytes(data); - - // Skip the message type ('w') - let _msg_type = reader.skip_message_type()?; - - // Parse WAL message header - // Format: 'w' + start_lsn (8) + end_lsn (8) + send_time (8) + message_data - let start_lsn = reader.read_u64()?; - let end_lsn = reader.read_u64()?; - let _send_time = reader.read_i64()?; + // Parse WAL message header: Format: 'w' + start_lsn (8) + end_lsn (8) + send_time (8) + message_data + let mut header = &data[1..25]; + let start_lsn = header.get_u64(); + let end_lsn = header.get_u64(); + let _send_time = header.get_i64(); // Update LSN tracking using the server's WAL end position for this message if end_lsn > 0 { self.state.update_received_lsn(end_lsn); } - // Check if there's message data remaining - if reader.remaining() == 0 { + // No payload after the header → nothing to convert. + if data.len() == 25 { return Ok(None); } - // Get the remaining bytes for message parsing (zero-copy Bytes slice) - let message_data = reader.read_bytes_buf(reader.remaining())?; + // Zero-copy slice for the message body — refcount-only, no memcpy. 
+ let message_data = data.slice(25..); let replication_message = self.parser.parse_wal_message_bytes(message_data)?; self.convert_to_change_event(replication_message, start_lsn) } @@ -2209,6 +2218,82 @@ mod tests { ); } + #[test] + fn test_health_check_event_interval_is_power_of_two() { + // `next_event_with_retry` uses the same `counter & (INTERVAL - 1) == 0` + // bitmask trick to gate `Instant::now()` calls, so HEALTH_CHECK_EVENT_INTERVAL + // must be a non-zero power of two. + assert!(HEALTH_CHECK_EVENT_INTERVAL > 0); + assert_eq!( + HEALTH_CHECK_EVENT_INTERVAL & (HEALTH_CHECK_EVENT_INTERVAL - 1), + 0, + "HEALTH_CHECK_EVENT_INTERVAL must be a power of two" + ); + } + + #[test] + fn test_health_check_event_interval_at_least_feedback_interval() { + // The health check interval is intentionally coarser than the feedback + // interval (and is a multiple of it) so feedback never gets starved by + // the health-check throttle. Pin the relationship in case either + // constant is ever bumped without consideration. + assert!( + HEALTH_CHECK_EVENT_INTERVAL >= FEEDBACK_CHECK_EVENT_INTERVAL, + "health-check interval should not be tighter than the feedback interval" + ); + assert_eq!( + HEALTH_CHECK_EVENT_INTERVAL % FEEDBACK_CHECK_EVENT_INTERVAL, + 0, + "HEALTH_CHECK_EVENT_INTERVAL should be a multiple of FEEDBACK_CHECK_EVENT_INTERVAL" + ); + } + + #[test] + fn test_health_check_gate_fires_on_first_event() { + // The gate is `counter & (INTERVAL - 1) == 0`. The counter starts at + // 0, so the very first call to `next_event_with_retry` must run the + // health check — important for catching dead connections at startup + // even on low-throughput streams. + let counter: u32 = 0; + assert_eq!(counter & (HEALTH_CHECK_EVENT_INTERVAL - 1), 0); + } + + #[test] + fn test_health_check_gate_fires_on_multiples() { + // Mirrors `test_feedback_check_gate_semantics` for the health-check gate. 
+ let interval = HEALTH_CHECK_EVENT_INTERVAL; + let mut counter: u32 = 0; + let mut fire_events = Vec::new(); + // Walk through 3 intervals' worth of events and record where the gate fires. + for event_idx in 0..(interval * 3) { + if counter & (interval - 1) == 0 { + fire_events.push(event_idx); + } + counter = counter.wrapping_add(1); + } + assert_eq!( + fire_events, + vec![0, interval, interval * 2], + "health-check gate should fire at counter 0, INTERVAL, 2*INTERVAL" + ); + } + + #[test] + fn test_health_check_gate_handles_counter_wraparound() { + // The counter is `u32` and uses `wrapping_add`. The gate must continue + // to fire correctly when the counter wraps from u32::MAX back to 0. + let interval = HEALTH_CHECK_EVENT_INTERVAL; + // u32::MAX = 0xFFFF_FFFF; with INTERVAL=1024, (INTERVAL - 1) = 0x3FF, + // so u32::MAX & 0x3FF == 0x3FF (gate does NOT fire). + let counter: u32 = u32::MAX; + assert_ne!(counter & (interval - 1), 0); + + // After wrapping_add(1) the counter wraps to 0, and the gate fires. + let next = counter.wrapping_add(1); + assert_eq!(next, 0); + assert_eq!(next & (interval - 1), 0); + } + #[test] fn test_replication_state_should_send_feedback() { let mut state = ReplicationState::new();