Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,17 @@ impl BufferReader {
#[inline]
fn ensure_bytes(&self, count: usize) -> Result<()> {
if self.data.remaining() < count {
return Err(ReplicationError::protocol(format!(
"Not enough bytes remaining. Need {}, have {}",
count,
self.data.remaining()
)));
return Self::short_buffer_err(count, self.data.remaining());
}
Ok(())
}

/// Skip the message type byte and return current position
#[inline]
pub fn skip_message_type(&mut self) -> Result<usize> {
self.ensure_bytes(1)?;
self.data.advance(1);
Ok(self.data.len())
#[cold]
#[inline(never)]
fn short_buffer_err(needed: usize, have: usize) -> Result<()> {
Err(ReplicationError::protocol(format!(
"Not enough bytes remaining. Need {needed}, have {have}"
)))
}

/// Read a single byte
Expand Down Expand Up @@ -601,15 +597,6 @@ mod tests {
assert_eq!(reader.remaining(), 2);
}

#[test]
fn test_buffer_reader_skip_message_type() {
let data = [0x42, 0x01, 0x02, 0x03]; // 'B' message type
let mut reader = BufferReader::new(&data);

reader.skip_message_type().unwrap();
assert_eq!(reader.read_u8().unwrap(), 0x01);
}

#[test]
fn test_buffer_writer_signed_integers() {
let mut writer = BufferWriter::new();
Expand Down Expand Up @@ -868,10 +855,21 @@ mod tests {
assert!(reader.read_bytes_buf(5).is_err());
}

/// Pin the error message format produced by the `#[cold]`
/// `short_buffer_err` helper. Several layers above (parser, stream)
/// surface this string in logs, so format regressions would silently
/// degrade diagnostics.
#[test]
fn test_buffer_reader_skip_message_type_empty() {
let data: &[u8] = &[];
let mut reader = BufferReader::new(data);
assert!(reader.skip_message_type().is_err());
fn test_buffer_reader_short_buffer_err_message_format() {
let data = [0x01, 0x02];
let mut reader = BufferReader::new(&data);
let err = reader.read_bytes_buf(5).unwrap_err();
let s = err.to_string();
assert!(
s.contains("Not enough bytes remaining"),
"expected 'Not enough bytes remaining' in error, got: {s}"
);
assert!(s.contains("Need 5"), "expected 'Need 5', got: {s}");
assert!(s.contains("have 2"), "expected 'have 2', got: {s}");
}
}
105 changes: 99 additions & 6 deletions src/lsn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ impl SharedLsnFeedback {
///
/// This should be called when data has been written/flushed to the destination
/// database, but not yet committed (e.g., during batch writes).
#[inline]
pub fn update_flushed_lsn(&self, lsn: XLogRecPtr) {
if lsn == 0 {
return;
}

let mut current = self.flushed_lsn.load(Ordering::Acquire);
loop {
let current = self.flushed_lsn.load(Ordering::Acquire);
if lsn <= current {
return;
}
Expand All @@ -124,7 +125,7 @@ impl SharedLsnFeedback {
);
return;
}
Err(_) => continue,
Err(actual) => current = actual,
}
}
}
Expand All @@ -134,12 +135,15 @@ impl SharedLsnFeedback {
/// This should be called when a transaction has been successfully committed
/// to the destination database. This is the most important LSN as PostgreSQL
/// uses it to determine which WAL can be recycled.
#[inline]
pub fn update_applied_lsn(&self, lsn: XLogRecPtr) {
if lsn == 0 {
return;
}

let mut current = self.applied_lsn.load(Ordering::Acquire);
let mut advanced = false;
loop {
let current = self.applied_lsn.load(Ordering::Acquire);
if lsn <= current {
break;
}
Expand All @@ -154,14 +158,17 @@ impl SharedLsnFeedback {
"SharedLsnFeedback: Updated applied LSN from {} to {}",
current, lsn
);
advanced = true;
break;
}
Err(_) => continue,
Err(actual) => current = actual,
}
}

// Applied data is implicitly flushed, update flushed as well
self.update_flushed_lsn(lsn);
// Applied data is implicitly flushed. Only chase the flushed CAS when we actually moved applied forward; otherwise flushed cannot be behind.
if advanced {
self.update_flushed_lsn(lsn);
}
}

/// Get the current flushed LSN
Expand Down Expand Up @@ -483,4 +490,90 @@ mod tests {
feedback.update_applied_lsn(50);
assert_eq!(feedback.get_applied_lsn(), 50);
}

/// Deterministic stress test that forces the `Err(actual) => current = actual`
/// branch in `update_flushed_lsn` and `update_applied_lsn` to fire.
///
/// `compare_exchange_weak` may also fail spuriously on weak memory models
/// (e.g. ARM) even without contention, so on those targets the loop body
/// is exercised even by a single-threaded run. On x86 we additionally
/// fan out across many threads so the lost-update path is always touched
/// regardless of the CAS strength.
#[test]
fn test_concurrent_cas_retry_path() {
use std::sync::Arc;
use std::sync::Barrier;
use std::thread;

const THREADS: usize = 16;
const ITERS: u64 = 5_000;

let feedback = SharedLsnFeedback::new_shared();
let barrier = Arc::new(Barrier::new(THREADS));

let mut handles = Vec::with_capacity(THREADS);
for tid in 0..THREADS as u64 {
let fb = Arc::clone(&feedback);
let bar = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
bar.wait();
// Each thread proposes a strictly increasing sequence of LSNs;
// collisions between threads guarantee CAS contention and force
// the `Err(actual)` branch to update `current` and retry.
for i in 1..=ITERS {
let lsn = i * THREADS as u64 + tid;
fb.update_flushed_lsn(lsn);
fb.update_applied_lsn(lsn);
}
}));
}
for h in handles {
h.join().unwrap();
}

// Final value must be at least the highest proposed LSN (monotonic).
let max_lsn = ITERS * THREADS as u64 + (THREADS as u64 - 1);
assert!(
feedback.get_applied_lsn() >= max_lsn - (THREADS as u64 - 1),
"applied LSN regressed under contention: got {}, want >= {}",
feedback.get_applied_lsn(),
max_lsn - (THREADS as u64 - 1)
);
// Applied advancing must drag flushed along.
assert!(feedback.get_flushed_lsn() >= feedback.get_applied_lsn());
}

/// `update_applied_lsn` must NOT touch `flushed_lsn` when the proposed
/// applied LSN does not advance the current value. This is the
/// optimization that saves a CAS-loop per consumer event when LSN is
/// unchanged.
#[test]
fn test_applied_no_advance_does_not_modify_flushed() {
let feedback = SharedLsnFeedback::new();
feedback.update_flushed_lsn(1000);
feedback.update_applied_lsn(500);
assert_eq!(feedback.get_flushed_lsn(), 1000);
assert_eq!(feedback.get_applied_lsn(), 500);

// Stale applied LSN must not regress flushed.
feedback.update_applied_lsn(400);
assert_eq!(feedback.get_flushed_lsn(), 1000);
assert_eq!(feedback.get_applied_lsn(), 500);

// Equal applied LSN must not touch flushed either.
feedback.update_applied_lsn(500);
assert_eq!(feedback.get_flushed_lsn(), 1000);
assert_eq!(feedback.get_applied_lsn(), 500);
}

/// `update_applied_lsn` advancing past the current `flushed_lsn` must
/// also pull `flushed_lsn` forward (applied data is implicitly flushed).
#[test]
fn test_applied_advance_drags_flushed_forward() {
let feedback = SharedLsnFeedback::new();
feedback.update_flushed_lsn(100);
feedback.update_applied_lsn(500);
assert_eq!(feedback.get_applied_lsn(), 500);
assert_eq!(feedback.get_flushed_lsn(), 500);
}
}
55 changes: 51 additions & 4 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1365,10 +1365,7 @@ impl LogicalReplicationParser {
ColumnData::binary_bytes(data)
}
_ => {
return Err(ReplicationError::protocol(format!(
"Unknown column data type: '{}'",
column_type as char
)));
return Self::unknown_column_type_err(column_type);
}
};

Expand All @@ -1377,6 +1374,15 @@ impl LogicalReplicationParser {

Ok(TupleData::from_smallvec(columns))
}

#[cold]
#[inline(never)]
fn unknown_column_type_err(column_type: u8) -> Result<TupleData> {
Err(ReplicationError::protocol(format!(
"Unknown column data type: '{}'",
column_type as char
)))
}
}

/// Parse keepalive message from the replication stream
Expand Down Expand Up @@ -3146,4 +3152,45 @@ mod tests {
_ => panic!("Expected Relation message"),
}
}

/// Covers the cold error path `LogicalReplicationParser::unknown_column_type_err`.
///
/// Crafts a synthetic `INSERT` WAL message whose tuple contains a column
/// with an unrecognised type byte (`b'x'`). `parse_tuple_data` must reject
/// it via the `#[cold]` helper.
#[test]
fn test_parse_tuple_data_unknown_column_type() {
let mut parser = LogicalReplicationParser::with_protocol_version(1);

// INSERT message: 'I' + relation_id (u32) + 'N' + tuple_data
// tuple_data: column_count (u16) + per-column { type (u8) [+ length + payload] }
let mut msg = Vec::new();
msg.push(b'I');
msg.extend_from_slice(&42u32.to_be_bytes()); // relation_id
msg.push(b'N'); // new tuple marker
msg.extend_from_slice(&1u16.to_be_bytes()); // 1 column
msg.push(b'x'); // unknown column type → triggers the cold helper

let err = parser.parse_wal_message(&msg).unwrap_err();
let msg_str = err.to_string();
assert!(
msg_str.contains("Unknown column data type"),
"expected 'Unknown column data type' in error, got: {msg_str}"
);
assert!(
msg_str.contains("'x'"),
"expected the offending byte to be reported, got: {msg_str}"
);
}

/// Direct unit test for the cold error helper itself, so it is reachable
/// without parser-level setup. Pins the message format that callers may
/// rely on for log-grep diagnostics.
#[test]
fn test_unknown_column_type_err_message_format() {
let err = LogicalReplicationParser::unknown_column_type_err(b'?').unwrap_err();
let s = err.to_string();
assert!(s.contains("Unknown column data type"));
assert!(s.contains("'?'"));
}
}
Loading
Loading