From 2af1504452c7b70c8a76392c5bdd75c4d0addd31 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Wed, 6 May 2026 10:18:20 -0700 Subject: [PATCH] ts_tunnel: add replay protection for incoming packets Rejects packets for an established session that have excessively old counter values. To allow for some packet reordering on the wire, we maintain a sliding window of acceptable counter values, and track already received counter values in that window with a bit ring. Fixes #17 Signed-off-by: David Anderson Change-Id: Ib71e5ea78078bb2e84314bc0f2fee94e6a6a6964 --- Cargo.lock | 1 + ts_tunnel/Cargo.toml | 1 + ts_tunnel/src/endpoint.rs | 12 +- ts_tunnel/src/handshake.rs | 5 +- ts_tunnel/src/lib.rs | 1 + ts_tunnel/src/replay.rs | 258 +++++++++++++++++++++++++++++++++++++ ts_tunnel/src/session.rs | 35 +++-- 7 files changed, 294 insertions(+), 19 deletions(-) create mode 100644 ts_tunnel/src/replay.rs diff --git a/Cargo.lock b/Cargo.lock index f4b3a4aa..a41af53e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4614,6 +4614,7 @@ dependencies = [ "hex", "hkdf", "itertools", + "proptest", "rand 0.10.1", "tokio", "tracing", diff --git a/ts_tunnel/Cargo.toml b/ts_tunnel/Cargo.toml index 282d3fab..439d1662 100644 --- a/ts_tunnel/Cargo.toml +++ b/ts_tunnel/Cargo.toml @@ -35,6 +35,7 @@ clap = { workspace = true, features = ["derive", "env"] } hex = "0.4" tokio = { workspace = true, features = ["full", "macros"] } ts_cli_util.workspace = true +proptest.workspace = true [lints] workspace = true diff --git a/ts_tunnel/src/endpoint.rs b/ts_tunnel/src/endpoint.rs index f8bd141c..749eb114 100644 --- a/ts_tunnel/src/endpoint.rs +++ b/ts_tunnel/src/endpoint.rs @@ -74,8 +74,8 @@ enum SessionState { None(Queue), /// Active session available. 
Active { - recv: ReceiveSession, - recv_prev: Option, + recv: Box, + recv_prev: Option>, send: TransmitSession, }, } @@ -99,7 +99,7 @@ impl SessionState { next.send.encrypt(&mut ret); *self = SessionState::Active { send: next.send, - recv: next.recv, + recv: Box::new(next.recv), recv_prev: None, }; ret @@ -114,7 +114,7 @@ impl SessionState { .inspect(|recv_prev| endpoint.ids.remove_session(recv_prev.id())); *self = SessionState::Active { send: next.send, - recv: next.recv, + recv: Box::new(next.recv), recv_prev: Some(recv), }; vec![] @@ -177,7 +177,7 @@ impl SessionState { } /// Get the receive session matching the given ID, if any. - fn get_recv(&self, id: SessionId) -> Option<&ReceiveSession> { + fn get_recv(&mut self, id: SessionId) -> Option<&mut ReceiveSession> { match self { SessionState::None(_) => None, SessionState::Active { @@ -185,7 +185,7 @@ impl SessionState { } => { if recv.id() == id && !recv.expired(Instant::now()) { Some(recv) - } else if let Some(recv_prev) = recv_prev.as_ref() + } else if let Some(recv_prev) = recv_prev.as_mut() && recv_prev.id() == id && !recv.expired(Instant::now()) { diff --git a/ts_tunnel/src/handshake.rs b/ts_tunnel/src/handshake.rs index 56488d09..49b62830 100644 --- a/ts_tunnel/src/handshake.rs +++ b/ts_tunnel/src/handshake.rs @@ -532,13 +532,14 @@ mod tests { assert_eq!(b_handshake.peer_static, a_static.public); assert_eq!(b_handshake.timestamp, a_init_time); let b_session = SessionId::random(); // B wants to receive at this ID - let (b_session, response_pkt) = + let (mut b_session, response_pkt) = b_handshake.respond(b_session, &psk, &b_mac_send, Instant::now()); // Peer A receives response let response_pkt = HandshakeResponse::try_ref_from_bytes(response_pkt.as_ref()) .expect("response_pkt should be a valid handshake response message"); - let Some(a_session) = a_handshake.finish(response_pkt, &psk, &a_mac_recv, Instant::now()) + let Some(mut a_session) = + a_handshake.finish(response_pkt, &psk, &a_mac_recv, 
Instant::now()) else { panic!("failed to process handshake response from peer B"); }; diff --git a/ts_tunnel/src/lib.rs b/ts_tunnel/src/lib.rs index 75b9d7f4..ea2ee8f1 100644 --- a/ts_tunnel/src/lib.rs +++ b/ts_tunnel/src/lib.rs @@ -5,6 +5,7 @@ mod endpoint; mod handshake; mod macs; mod messages; +mod replay; mod session; mod time; diff --git a/ts_tunnel/src/replay.rs b/ts_tunnel/src/replay.rs new file mode 100644 index 00000000..d8f276b1 --- /dev/null +++ b/ts_tunnel/src/replay.rs @@ -0,0 +1,258 @@ +//! Implementation of the packet replay protection algorithm from RFC 6479. +//! +//! The overall goal of replay protection is to only accept new packets in an established session, +//! and reject attempts at playing back older packets. +//! +//! We could naively do this by tracking the highest packet counter we've seen on a valid packet, +//! and reject all packets presenting an older counter. However, this is overly conservative in +//! the face of packet reordering on the network, wherein a burst of packets may arrive slightly +//! out of order. +//! +//! Precisely tracking all previously seen packet IDs for all time is prohibitively expensive, so +//! practical systems compromise and track both the highest counter seen so far, and a sliding +//! window of the N packets prior to the latest. Packets in that window can be received out +//! of order while still rejecting replays. Packets that fall earlier than the window are rejected +//! unconditionally, on the assumption that sufficiently old packets have all been received or lost +//! permanently. +//! +//! The window can be implemented with a regular bitset, with each bit tracking one packet in the +//! window of recent counters. The downside of the naive implementation is that whenever a newer +//! packet is accepted, sliding the window forward involves doing a bit shift operation on the +//! entire bitset. This is fairly expensive to do at high packet line rates. +//! +//! 
The first idea of RFC 6479 is that, if we make the window a power of two, we can directly map +//! a counter value to a bit index by masking the higher order bits of the counter. This turns +//! the bitset into a ring buffer, where the bit position of the highest seen counter is the head +//! pointer. As the highest seen counter value increments when receiving packets, the window's head +//! position automatically slides forward. +//! +//! Here's a visual representation of what that looks like in a small 32-bit window: +//! +//! | 0 0 0 1 1 0 1 0 1 0 0 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 | +//! ^ ^ ^ +//! | | \ +//! | | Current tail: 144_844 +//! | | Bit index after masking: 12 +//! | \ +//! | Current head: 144_875 +//! | Bit index after masking: 11 +//! \ +//! Counter 144_872 has already been received +//! Bit index after masking: 8 +//! +//! This approach introduces a new issue: when advancing the head of the window, we have to take +//! care to zero out bits that have wrapped around from the window's tail. We want this operation +//! to be cheaper than bit shifting, since that's what we've been trying to avoid this whole time. +//! +//! RFC 6479's second idea is to observe that replay windows usually span several machine words. +//! The window is represented as an array of blocks, for example a `[u64; 8]` for 512 bits total. +//! If we shrink the usable window to leave one of those blocks unused, then the ring's head and +//! tail pointers never occupy the same block. +//! +//! This lets us advance the head pointer very cheaply: whenever the head position crosses over +//! into a new block, we zero that block entirely. This may result in zeroing several consecutive +//! blocks if the head advances by a large amount, or even the entire ring if the head advances +//! more than the window size. Finally, once the appropriate blocks have been zeroed, the bit +//! corresponding to the new highest counter is set. +//! +//! 
The resulting window after sliding has exactly the same content as in the bit-shift +//! implementation, but the cost of advancing has been reduced to zeroing a few machine words. +//! Similarly, the cost of setting a bit within the window is a clean bit masking operation +//! (because the overall ring size is a power of 2), followed by a bit set operation within a +//! single machine word. The cost of checking an arbitrary counter value consists of a few +//! comparisons to check if the counter is before or after the current window, and a mask+bit test +//! for counters within the window. + +use std::fmt::Debug; + +/// A packet replay tracker. +/// +/// In the abstract, the tracker rejects previously seen counter values. However, to +/// do this perfectly would require a large amount of storage. Instead, the tracker assumes +/// that counter values are seen mostly in ascending order, and only explicitly tracks seen +/// counter values in a short window behind the latest seen value. +/// +/// Values that fall before this window are unconditionally rejected; values larger than any seen +/// so far are unconditionally accepted (and advance the tracker's sliding window); values that +/// fall within the window are tracked explicitly with a bitset, to ensure they are accepted once +/// only.
+#[derive(Default)] +pub struct ReplayWindow { + // Highest counter accepted so far; marks the newest end of the sliding window. + last: u64, + blocks: [u64; ReplayWindow::N_BLOCKS as usize], +} + +impl Debug for ReplayWindow { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "ReplayWindow{{\n last: {}\n bits:", self.last)?; + for block in self.blocks { + write!(f, " ")?; + for octet in 0usize..8 { + let v = ((block >> (octet * 8)) & 0xff) as u8; + write!(f, "{:08b} ", v.reverse_bits())?; + } + writeln!(f)?; + } + write!(f, "}}") + } +} + +impl ReplayWindow { + const TOTAL_BITS: u64 = 256; + const N_BLOCKS: u64 = Self::TOTAL_BITS / u64::BITS as u64; + const BIT_IDX_BITMASK: u64 = (u64::BITS - 1) as u64; + const BIT_IDX_SHIFT: u32 = u64::BITS.ilog2(); + const BLOCK_IDX_BITMASK: u64 = Self::N_BLOCKS - 1; + + pub const WINDOW_SIZE: u64 = (Self::N_BLOCKS - 1) * u64::BITS as u64; + + fn smallest_valid(&self) -> u64 { + self.last.saturating_sub(Self::WINDOW_SIZE - 1) + } + + fn block_idx_unbounded(&self, counter: u64) -> u64 { + counter >> Self::BIT_IDX_SHIFT + } + + fn bit_idx(&self, counter: u64) -> u64 { + counter & Self::BIT_IDX_BITMASK + } + + fn block_idx_and_bit_mask(&self, counter: u64) -> (usize, u64) { + let block_idx = self.block_idx_unbounded(counter) & Self::BLOCK_IDX_BITMASK; + (block_idx as usize, 1 << self.bit_idx(counter)) + } + + /// Report whether `counter` is a new value that can be processed. + /// + /// Does not update the replay window state, so should be called prior to doing + /// expensive processing. After processing, you must call [`ReplayWindow::set`] to + /// update the replay window state.
+ pub fn check(&self, counter: u64) -> bool { + if counter > self.last { + return true; + } + if counter < self.smallest_valid() { + return false; + } + let (block_idx, bit_mask) = self.block_idx_and_bit_mask(counter); + self.blocks[block_idx] & bit_mask == 0 + } + + /// Update the replay window to mark the given counter as seen and accepted. + /// + /// # Panics + /// + /// If [`ReplayWindow::check`] is false for `counter` (older than the window or already set). + pub fn set(&mut self, counter: u64) { + if counter < self.smallest_valid() { + panic!( + "invalid set: counter {} is older than smallest valid {}", + counter, + self.smallest_valid() + ); + } + if counter > self.last { + let cur_block = self.block_idx_unbounded(self.last); + let new_block = self.block_idx_unbounded(counter); + let delta = new_block - cur_block; + if delta >= Self::N_BLOCKS { + self.blocks = [0; Self::N_BLOCKS as usize]; + } else { + for i in cur_block..new_block { + let idx = (i + 1) & Self::BLOCK_IDX_BITMASK; + self.blocks[idx as usize] = 0; + } + } + self.last = counter; + } + let (block_idx, bit_mask) = self.block_idx_and_bit_mask(counter); + if self.blocks[block_idx] & bit_mask != 0 { + panic!( + "invalid set: counter {} was already set previously", + counter + ); + } + self.blocks[block_idx] |= bit_mask; + } + + #[cfg(test)] + fn check_and_set(&mut self, counter: u64) -> bool { + let accept = self.check(counter); + if accept { + self.set(counter); + } + accept + } + + #[cfg(test)] + fn received_in_window(&self) -> u64 { + let counters = self.smallest_valid()..self.last + 1; + counters + .map(|ctr| if self.check(ctr) { 0 } else { 1 }) + .sum() + } +} + +#[cfg(test)] +mod tests { + use std::{cmp::max, collections::HashSet}; + + use super::*; + + #[test] + fn just_advance() { + let mut window = ReplayWindow::default(); + + for counter in 0..600 { + assert!(window.check_and_set(counter)); + assert_eq!( + window.received_in_window(), + (counter + 1).clamp(0, ReplayWindow::WINDOW_SIZE) + ); + } + } + + #[test] + fn out_of_order() {
+ let mut window = ReplayWindow::default(); + + assert!(window.check_and_set(500)); + assert!(!window.check(500)); + assert!(!window.check(100)); + assert_eq!(window.received_in_window(), 1); + for (i, counter) in (400..450).rev().enumerate() { + assert!(window.check_and_set(counter)); + assert_eq!(window.received_in_window(), (i + 2) as u64); + } + for (i, counter) in (451..500).enumerate() { + assert!(window.check_and_set(counter)); + assert_eq!(window.received_in_window(), (i + 52) as u64); + } + } + + proptest::proptest! { + #[test] + fn any_order(counters in proptest::collection::vec(0u64..1000, 0..2000)) { + let mut seen = HashSet::new(); + let mut latest = None; + let mut window = ReplayWindow::default(); + for counter in counters { + let accepted = window.check_and_set(counter); + if accepted { + assert!(!seen.contains(&counter)); + if let Some(latest_ctr) = latest { + assert!(counter >= window.smallest_valid()); + latest = Some(max(latest_ctr, counter)) + } else { + latest = Some(counter); + } + seen.insert(counter); + } else { + assert!(seen.contains(&counter) || counter < window.smallest_valid()); + } + } + } + } +} diff --git a/ts_tunnel/src/session.rs b/ts_tunnel/src/session.rs index d902ec67..8bcd6cdb 100644 --- a/ts_tunnel/src/session.rs +++ b/ts_tunnel/src/session.rs @@ -12,7 +12,10 @@ use zerocopy::{ little_endian::{U32, U64}, }; -use crate::messages::{SessionId, TransportDataHeader}; +use crate::{ + messages::{SessionId, TransportDataHeader}, + replay::ReplayWindow, +}; type SessionKey = chacha20poly1305::Key; @@ -151,7 +154,7 @@ pub struct ReceiveSession { cipher: ChaCha20Poly1305, id: SessionId, created: Instant, - // TODO: nonce sliding window for replay protection + window: ReplayWindow, } impl Debug for ReceiveSession { @@ -168,13 +171,14 @@ impl ReceiveSession { cipher: ChaCha20Poly1305::new(&key), id, created: now, + window: ReplayWindow::default(), } } /// Decrypt wireguard transport data messages in place. 
/// /// Returns the packets which successfully decrypted. - pub fn decrypt(&self, mut packets: Vec) -> Vec { + pub fn decrypt(&mut self, mut packets: Vec) -> Vec { packets.retain_mut(|packet| self.decrypt_one(packet)); packets } @@ -182,7 +186,7 @@ impl ReceiveSession { /// Decrypt a wireguard transport data message in place. #[tracing::instrument(skip_all, fields(session_id = ?self.id))] #[must_use] - fn decrypt_one(&self, pkt: &mut PacketMut) -> bool { + fn decrypt_one(&mut self, pkt: &mut PacketMut) -> bool { let Ok((header, _)) = TransportDataHeader::try_ref_from_prefix(pkt.as_ref()) else { tracing::warn!("decode as transport packet failed"); return false; @@ -209,16 +213,25 @@ impl ReceiveSession { return false; } + let counter = header.nonce.into(); + if !self.window.check(counter) { + tracing::trace!("reject old/replayed packet"); + return false; + } + let nonce = Nonce::from(header.nonce); pkt.truncate_front(size_of::()); - let result = self.cipher.decrypt_in_place(nonce.as_ref(), &[], pkt); - - if let Err(e) = &result { - tracing::error!(err = %e, "decryption failed"); + match self.cipher.decrypt_in_place(nonce.as_ref(), &[], pkt) { + Ok(_) => { + self.window.set(counter); + true + } + Err(e) => { + tracing::error!(err = %e, "decryption failed"); + false + } } - - result.is_ok() } pub fn id(&self) -> SessionId { @@ -241,7 +254,7 @@ mod tests { let session = SessionId::random(); let now = Instant::now(); let send = TransmitSession::new(k.into(), session, now); - let recv = ReceiveSession::new(k.into(), session, now); + let mut recv = ReceiveSession::new(k.into(), session, now); const CLEARTEXT: &[u8] = b"foobar"; let mut pkt = [PacketMut::from(CLEARTEXT)];