diff --git a/Cargo.lock b/Cargo.lock index fffa5a7c9..64e86efce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -504,12 +504,6 @@ dependencies = [ "shlex", ] -[[package]] -name = "cesu8" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" - [[package]] name = "cexpr" version = "0.6.0" @@ -1604,25 +1598,52 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jni" -version = "0.21.1" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" dependencies = [ - "cesu8", "cfg-if", "combine", + "jni-macros", "jni-sys", "log", - "thiserror 1.0.69", + "simd_cesu8", + "thiserror 2.0.17", "walkdir", - "windows-sys 0.45.0", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn", ] [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn", +] [[package]] name = "jobserver" @@ -2454,9 +2475,9 @@ dependencies = [ [[package]] name = "rustls-platform-verifier" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be59af91596cac372a6942530653ad0c3a246cdd491aaa9dcaee47f88d67d5a0" +checksum = "26d1e2536ce4f35f4846aa13bff16bd0ff40157cdb14cc056c7b14ba41233ba0" dependencies = [ "core-foundation", "core-foundation-sys", @@ -2470,7 +2491,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -2481,9 +2502,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -2653,6 +2674,22 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +[[package]] +name = "simd_cesu8" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "siphasher" version = "1.0.1" @@ -3451,15 +3488,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" -[[package]] -name = "windows-sys" -version = "0.45.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" -dependencies = [ - "windows-targets 0.42.2", -] - [[package]] name = "windows-sys" version = "0.52.0" @@ -3496,21 +3524,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-targets" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -3544,12 +3557,6 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" - [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -3562,12 +3569,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" -[[package]] -name = "windows_aarch64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -3580,12 +3581,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" -[[package]] -name = "windows_i686_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" - [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -3610,12 +3605,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" -[[package]] -name = "windows_i686_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" - [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -3628,12 +3617,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" -[[package]] -name = "windows_x86_64_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -3646,12 +3629,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -3664,12 +3641,6 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" -[[package]] -name = "windows_x86_64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 0dd38c829..2c4f5d1a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ rustc-hash = "2" rustls = { version = "0.23.33", default-features = false, features = ["std"] } rustls-pemfile = "2" rustls-pki-types = "1.7" -rustls-platform-verifier = "0.6" +rustls-platform-verifier = "0.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1" slab = "0.4.9" diff --git a/deny.toml b/deny.toml index 494d96553..1354e0db2 100644 --- a/deny.toml +++ b/deny.toml @@ -13,7 +13,7 @@ allow = [ "ISC", "MIT", "NCSA", - "OpenSSL", + "OpenSSL", # aws-lc-fips-sys "Unicode-3.0", "Zlib", # foldhash, dependency of fastbloom ] @@ -28,9 +28,6 @@ skip = [ # ring uses getrandom 0.2, newer crates use 0.3 { crate = "getrandom", reason = "ring depends on 0.2, newer ecosystem uses 0.3" }, { crate = "r-efi", reason = "proptest dev-dependency pulls in old getrandom" }, - # jni and redox_users use thiserror 1.x - { crate = "thiserror", reason = "transitive deps use thiserror 1.x" }, - { crate = "thiserror-impl", reason = "follows thiserror" }, # follows getrandom versions { crate = "wasi", reason = "follows getrandom version split" }, # various transitive deps require different windows-sys versions diff --git a/docs/book/book.toml b/docs/book/book.toml index 8cc483200..359d09d9f 100644 --- a/docs/book/book.toml +++ b/docs/book/book.toml @@ -1,6 +1,5 @@ [book] authors = ["Timon Post"] language = "en" -multilingual = false src = "src" title = "noq" diff --git a/noq-proto/src/config/mod.rs b/noq-proto/src/config/mod.rs index 8be51a764..2cbd5b484 100644 --- a/noq-proto/src/config/mod.rs +++ b/noq-proto/src/config/mod.rs @@ -80,11 +80,11 @@ impl EndpointConfig { /// information in local connection IDs, e.g. to support stateless packet-level load balancers. /// /// Defaults to [`HashedConnectionIdGenerator`]. - pub fn cid_generator Box + Send + Sync + 'static>( + pub fn cid_generator( &mut self, - factory: F, + factory: Arc Box + Send + Sync>, ) -> &mut Self { - self.connection_id_generator_factory = Arc::new(factory); + self.connection_id_generator_factory = factory; self } diff --git a/noq-proto/src/config/transport.rs b/noq-proto/src/config/transport.rs index 0152d409a..a72df6393 100644 --- a/noq-proto/src/config/transport.rs +++ b/noq-proto/src/config/transport.rs @@ -60,6 +60,7 @@ pub struct TransportConfig { pub(crate) mtu_discovery_config: Option, pub(crate) pad_to_mtu: bool, pub(crate) ack_frequency_config: Option, + pub(crate) max_outgoing_bytes_per_second: Option, pub(crate) persistent_congestion_threshold: u32, pub(crate) keep_alive_interval: Option, @@ -272,6 +273,14 @@ impl TransportConfig { self } + /// Configures an outbound rate limit (in bytes per second) for each connection. + /// + /// Defaults to `None`, which disables rate limiting. + pub fn max_outgoing_bytes_per_second(&mut self, value: Option) -> &mut Self { + self.max_outgoing_bytes_per_second = value; + self + } + /// Number of consecutive PTOs after which network is considered to be experiencing persistent congestion. pub fn persistent_congestion_threshold(&mut self, value: u32) -> &mut Self { self.persistent_congestion_threshold = value; @@ -558,6 +567,7 @@ impl Default for TransportConfig { mtu_discovery_config: Some(MtuDiscoveryConfig::default()), pad_to_mtu: false, ack_frequency_config: None, + max_outgoing_bytes_per_second: None, persistent_congestion_threshold: 3, keep_alive_interval: None, @@ -606,6 +616,7 @@ impl fmt::Debug for TransportConfig { mtu_discovery_config, pad_to_mtu, ack_frequency_config, + max_outgoing_bytes_per_second, persistent_congestion_threshold, keep_alive_interval, crypto_buffer_size, @@ -641,6 +652,10 @@ impl fmt::Debug for TransportConfig { .field("mtu_discovery_config", mtu_discovery_config) .field("pad_to_mtu", pad_to_mtu) .field("ack_frequency_config", ack_frequency_config) + .field( + "max_outgoing_bytes_per_second", + max_outgoing_bytes_per_second, + ) .field( "persistent_congestion_threshold", persistent_congestion_threshold, diff --git a/noq-proto/src/connection/mod.rs b/noq-proto/src/connection/mod.rs index f22291b88..b92892ab8 100644 --- a/noq-proto/src/connection/mod.rs +++ b/noq-proto/src/connection/mod.rs @@ -2613,8 +2613,10 @@ impl Connection { /// Whether the connection is in the process of being established /// - /// If this returns `false`, the connection may be either established or closed, - /// signaled by the emission of a `Connected` or `ConnectionLost` message respectively. + /// If this returns `false`, the connection may be either established or closed, signaled by the + /// emission of a [`Connected`](Event::Connected) or [`ConnectionLost`](Event::ConnectionLost) + /// event respectively. Note that locally-initiated closes via [`close()`](Self::close) do not + /// emit a `ConnectionLost` event. /// /// For an established connection this essentially means the handshake is **completed**, /// but not necessarily yet confirmed. @@ -2628,7 +2630,10 @@ impl Connection { /// either peer application intentionally closes it, or when either transport layer detects an /// error such as a time-out or certificate validation failure. /// - /// A `ConnectionLost` event is emitted with details when the connection becomes closed. + /// A [`ConnectionLost`](Event::ConnectionLost) event is emitted with details when the + /// connection is closed by the peer or due to an error. When the local application closes + /// the connection via [`close()`](Self::close), no `ConnectionLost` event is emitted; + /// instead, pending operations fail with [`ConnectionError::LocallyClosed`]. pub fn is_closed(&self) -> bool { self.state.is_closed() } @@ -2836,6 +2841,7 @@ impl Connection { }; if self.detect_spurious_loss(&ack, space, path) { + self.path_stats.for_path(path).spurious_congestion_events += 1; self.path_data_mut(path) .congestion .on_spurious_congestion_event(); @@ -7460,7 +7466,10 @@ pub enum Event { HandshakeConfirmed, /// The connection was lost /// - /// Emitted if the peer closes the connection or an error is encountered. + /// Emitted when the connection is closed due to an error, a timeout, or the peer closing it. + /// This is **not** emitted when the local application closes the connection via + /// [`Connection::close()`](crate::Connection::close). In that case, pending operations will + /// fail with [`ConnectionError::LocallyClosed`]. ConnectionLost { /// Reason that the connection was closed reason: ConnectionError, @@ -7637,6 +7646,9 @@ impl SentFrames { MaxStreams(max_streams) => { self.retransmits_mut().max_stream_id[max_streams.dir as usize] = true } + StreamsBlocked(streams_blocked) => { + self.retransmits_mut().streams_blocked[streams_blocked.dir as usize] = true + } } } } diff --git a/noq-proto/src/connection/pacing.rs b/noq-proto/src/connection/pacing.rs index 794b8070c..cb601ef97 100644 --- a/noq-proto/src/connection/pacing.rs +++ b/noq-proto/src/connection/pacing.rs @@ -18,22 +18,36 @@ pub(super) struct Pacer { last_window: u64, last_mtu: u16, tokens: u64, + max_bytes_per_second: Option, prev: Instant, } impl Pacer { /// Obtains a new [`Pacer`]. - pub(super) fn new(smoothed_rtt: Duration, window: u64, mtu: u16, now: Instant) -> Self { + pub(super) fn new( + smoothed_rtt: Duration, + window: u64, + mtu: u16, + max_bytes_per_second: Option, + now: Instant, + ) -> Self { + let window = rate_limited_window(smoothed_rtt, window, max_bytes_per_second); let capacity = optimal_capacity(smoothed_rtt, window, mtu); Self { capacity, last_window: window, last_mtu: mtu, tokens: capacity, + max_bytes_per_second, prev: now, } } + /// Obtains the `max_bytes_per_second` used when this [`Pacer`] was constructed. + pub(crate) fn max_bytes_per_second(&self) -> Option { + self.max_bytes_per_second + } + /// Record that a packet has been transmitted. pub(super) fn on_transmit(&mut self, packet_length: u16) { self.tokens = self.tokens.saturating_sub(packet_length.into()) @@ -60,6 +74,7 @@ impl Pacer { "zero-sized congestion control window is nonsense" ); + let window = rate_limited_window(smoothed_rtt, window, self.max_bytes_per_second); if window != self.last_window || mtu != self.last_mtu { self.capacity = optimal_capacity(smoothed_rtt, window, mtu); @@ -149,6 +164,27 @@ fn optimal_capacity(smoothed_rtt: Duration, window: u64, mtu: u16) -> u64 { ) } +/// Clamps the window to limit the sending rate to `max_bytes_per_second`. +/// +/// If `max_bytes_per_second` is `None`, the original window is returned. +fn rate_limited_window( + smoothed_rtt: Duration, + window: u64, + max_bytes_per_second: Option, +) -> u64 { + let Some(max_bytes_per_second) = max_bytes_per_second else { + return window; + }; + + let rate_window = max_bytes_per_second as f64 * smoothed_rtt.as_secs_f64(); + + // the pacer refills tokens at x1.25 speed, so we shrink the window to cancel out the speedup + // (otherwise the actual sending rate could be higher than `max_bytes_per_second`) + let adjusted_rate_window = (rate_window / 1.25).round(); + + Ord::min(window, Ord::max(adjusted_rate_window as u64, 1)) +} + /// Period of traffic to batch together on a reasonably fast connection const TARGET_BURST_INTERVAL: Duration = Duration::from_millis(2); @@ -175,17 +211,17 @@ mod tests { let rtt = Duration::from_micros(400); assert!( - Pacer::new(rtt, 30000, 1500, new_instant) + Pacer::new(rtt, 30000, 1500, None, new_instant) .delay(Duration::from_micros(0), 0, 1500, 1, old_instant) .is_none() ); assert!( - Pacer::new(rtt, 30000, 1500, new_instant) + Pacer::new(rtt, 30000, 1500, None, new_instant) .delay(Duration::from_micros(0), 1600, 1500, 1, old_instant) .is_none() ); assert!( - Pacer::new(rtt, 30000, 1500, new_instant) + Pacer::new(rtt, 30000, 1500, None, new_instant) .delay(Duration::from_micros(0), 1500, 1500, 3000, old_instant) .is_none() ); @@ -198,18 +234,18 @@ mod tests { let rtt = Duration::from_millis(50); let now = Instant::now(); - let pacer = Pacer::new(rtt, window, mtu, now); + let pacer = Pacer::new(rtt, window, mtu, None, now); assert_eq!( pacer.capacity, (window as u128 * TARGET_BURST_INTERVAL.as_nanos() / rtt.as_nanos()) as u64 ); assert_eq!(pacer.tokens, pacer.capacity); - let pacer = Pacer::new(Duration::from_millis(0), window, mtu, now); + let pacer = Pacer::new(Duration::from_millis(0), window, mtu, None, now); assert_eq!(pacer.capacity, MAX_BURST_SIZE * mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); - let pacer = Pacer::new(rtt, 1, mtu, now); + let pacer = Pacer::new(rtt, 1, mtu, None, now); assert_eq!(pacer.capacity, mtu as u64); assert_eq!(pacer.tokens, pacer.capacity); } @@ -221,7 +257,7 @@ mod tests { let rtt = Duration::from_millis(50); let now = Instant::now(); - let mut pacer = Pacer::new(rtt, window, mtu, now); + let mut pacer = Pacer::new(rtt, window, mtu, None, now); assert_eq!( pacer.capacity, (window as u128 * TARGET_BURST_INTERVAL.as_nanos() / rtt.as_nanos()) as u64 @@ -260,7 +296,7 @@ mod tests { let rtt = Duration::from_millis(50); let old_instant = Instant::now(); - let mut pacer = Pacer::new(rtt, window, mtu, old_instant); + let mut pacer = Pacer::new(rtt, window, mtu, None, old_instant); let packet_capacity = pacer.capacity / mtu as u64; for _ in 0..packet_capacity { @@ -322,4 +358,37 @@ mod tests { ); assert_eq!(pacer.tokens, pacer.capacity); } + + #[test] + fn computes_pause_correctly_for_rate_limited() { + let window = 2_000_000u64; + let mtu = 1000; + let rtt = Duration::from_millis(50); + let old_instant = Instant::now(); + + let mut pacer = Pacer::new(rtt, window, mtu, Some(2_000), old_instant); + assert_eq!( + pacer.delay(rtt, 1_000, mtu, window, old_instant), + None, + "When capacity is available packets should be sent immediately" + ); + pacer.on_transmit(mtu); + + let actual_delay = pacer + .delay(rtt, 1_000, mtu, window, old_instant) + .expect("Send must be delayed"); + + let expected_delay = Duration::from_millis(500); + let diff = actual_delay.abs_diff(expected_delay); + + // Allow up to 2ns difference due to rounding + assert!( + diff < Duration::from_nanos(2), + "expected ≈ {expected_delay:?}, got {actual_delay:?} (diff {diff:?})" + ); + + // Should be able to send after a while + let now = old_instant + expected_delay / 2; + assert_eq!(pacer.delay(rtt, 500, mtu, window, now), None); + } } diff --git a/noq-proto/src/connection/paths.rs b/noq-proto/src/connection/paths.rs index ceeefc006..f26f23926 100644 --- a/noq-proto/src/connection/paths.rs +++ b/noq-proto/src/connection/paths.rs @@ -317,6 +317,7 @@ impl PathData { config.initial_rtt, congestion.initial_window(), config.get_initial_mtu(), + config.max_outgoing_bytes_per_second, now, ), congestion, @@ -375,7 +376,13 @@ impl PathData { Self { network_path, rtt: prev.rtt, - pacing: Pacer::new(smoothed_rtt, congestion.window(), prev.current_mtu(), now), + pacing: Pacer::new( + smoothed_rtt, + congestion.window(), + prev.current_mtu(), + prev.pacing.max_bytes_per_second(), + now, + ), sending_ecn: true, congestion, app_limited: false, diff --git a/noq-proto/src/connection/qlog.rs b/noq-proto/src/connection/qlog.rs index 5c618aca0..040c7b5c7 100644 --- a/noq-proto/src/connection/qlog.rs +++ b/noq-proto/src/connection/qlog.rs @@ -729,6 +729,17 @@ impl ToQlog for frame::MaxStreams { } } +#[cfg(feature = "qlog")] +impl ToQlog for frame::StreamsBlocked { + fn to_qlog(&self) -> QuicFrame { + QuicFrame::StreamsBlocked { + stream_type: self.dir.into(), + limit: self.limit, + raw: None, + } + } +} + #[cfg(feature = "qlog")] impl ToQlog for frame::NewConnectionId { fn to_qlog(&self) -> QuicFrame { diff --git a/noq-proto/src/connection/spaces.rs b/noq-proto/src/connection/spaces.rs index 8cb3f89a1..39983c16b 100644 --- a/noq-proto/src/connection/spaces.rs +++ b/noq-proto/src/connection/spaces.rs @@ -509,6 +509,7 @@ pub(super) struct LostPacket { pub struct Retransmits { pub(super) max_data: bool, pub(super) max_stream_id: [bool; 2], + pub(super) streams_blocked: [bool; 2], pub(super) reset_stream: Vec<(StreamId, VarInt)>, pub(super) stop_sending: Vec, pub(super) max_stream_data: FxHashSet, @@ -562,6 +563,7 @@ impl Retransmits { let Self { max_data, max_stream_id, + streams_blocked, reset_stream, stop_sending, max_stream_data, @@ -583,6 +585,7 @@ impl Retransmits { } = &self; !max_data && !max_stream_id.iter().any(|x| *x) + && !streams_blocked.iter().any(|x| *x) && reset_stream.is_empty() && stop_sending.is_empty() && max_stream_data @@ -611,6 +614,7 @@ impl ::std::ops::BitOrAssign for Retransmits { let Self { max_data, max_stream_id, + streams_blocked, reset_stream, stop_sending, max_stream_data, @@ -636,6 +640,7 @@ impl ::std::ops::BitOrAssign for Retransmits { self.max_data |= max_data; for dir in Dir::iter() { self.max_stream_id[dir as usize] |= max_stream_id[dir as usize]; + self.streams_blocked[dir as usize] |= streams_blocked[dir as usize]; } self.reset_stream.extend_from_slice(&reset_stream); self.stop_sending.extend_from_slice(&stop_sending); diff --git a/noq-proto/src/connection/stats.rs b/noq-proto/src/connection/stats.rs index 45d5812ca..dfaf02748 100644 --- a/noq-proto/src/connection/stats.rs +++ b/noq-proto/src/connection/stats.rs @@ -227,6 +227,8 @@ pub struct PathStats { pub cwnd: u64, /// Congestion events on the connection. pub congestion_events: u64, + /// Spurious congestion events on the connection. + pub spurious_congestion_events: u64, /// The number of packets lost on this path. pub lost_packets: u64, /// The number of bytes lost on this path. @@ -280,6 +282,7 @@ impl std::ops::Add for ConnectionStats { frame_rx, cwnd: _, congestion_events: _, + spurious_congestion_events: _, lost_packets, lost_bytes, sent_plpmtud_probes: _, @@ -310,6 +313,7 @@ impl std::ops::AddAssign for ConnectionStats { frame_rx: path_frame_rx, cwnd: _, congestion_events: _, + spurious_congestion_events: _, lost_packets: path_lost_packets, lost_bytes: path_lost_bytes, sent_plpmtud_probes: _, diff --git a/noq-proto/src/connection/streams/mod.rs b/noq-proto/src/connection/streams/mod.rs index 208187f7f..f09598884 100644 --- a/noq-proto/src/connection/streams/mod.rs +++ b/noq-proto/src/connection/streams/mod.rs @@ -48,8 +48,8 @@ impl<'a> Streams<'a> { return None; } - // TODO: Queue STREAM_ID_BLOCKED if this fails if self.state.next[dir as usize] >= self.state.max[dir as usize] { + self.state.streams_blocked[dir as usize] = true; return None; } diff --git a/noq-proto/src/connection/streams/state.rs b/noq-proto/src/connection/streams/state.rs index bfa9e1474..771c5984e 100644 --- a/noq-proto/src/connection/streams/state.rs +++ b/noq-proto/src/connection/streams/state.rs @@ -135,6 +135,8 @@ pub struct StreamsState { /// The shrink to be applied to local_max_data when receive_window is shrunk receive_window_shrink_debt: u64, + /// Whether the locally-initiated stream limit has been hit, per direction + pub(super) streams_blocked: [bool; 2], } impl StreamsState { @@ -179,6 +181,7 @@ impl StreamsState { initial_max_stream_data_bidi_local: 0u32.into(), initial_max_stream_data_bidi_remote: 0u32.into(), receive_window_shrink_debt: 0, + streams_blocked: [false, false], }; for dir in Dir::iter() { @@ -491,6 +494,22 @@ impl StreamsState { let count = self.max_remote[dir as usize]; builder.write_frame(frame::MaxStreams { dir, count }, stats); } + + // STREAMS_BLOCKED + for dir in Dir::iter() { + if self.streams_blocked[dir as usize] { + pending.streams_blocked[dir as usize] = true; + self.streams_blocked[dir as usize] = false; + } + + if !pending.streams_blocked[dir as usize] || builder.frame_space_remaining() <= 9 { + continue; + } + + pending.streams_blocked[dir as usize] = false; + let limit = self.max[dir as usize]; + builder.write_frame(frame::StreamsBlocked { dir, limit }, stats); + } } pub(in crate::connection) fn write_stream_frames<'a, 'b>( @@ -655,6 +674,7 @@ impl StreamsState { let current = &mut self.max[dir as usize]; if count > *current { *current = count; + self.streams_blocked[dir as usize] = false; self.events.push_back(StreamEvent::Available { dir }); } diff --git a/noq-proto/src/frame.rs b/noq-proto/src/frame.rs index 29339a882..f27ff71f7 100644 --- a/noq-proto/src/frame.rs +++ b/noq-proto/src/frame.rs @@ -216,6 +216,7 @@ pub(super) enum EncodableFrame<'a> { MaxData(MaxData), MaxStreamData(MaxStreamData), MaxStreams(MaxStreams), + StreamsBlocked(StreamsBlocked), } impl<'a> EncodableFrame<'a> { @@ -249,7 +250,8 @@ impl<'a> EncodableFrame<'a> { | EncodableFrame::StreamMeta(_) | EncodableFrame::MaxData(_) | EncodableFrame::MaxStreamData(_) - | EncodableFrame::MaxStreams(_) => true, + | EncodableFrame::MaxStreams(_) + | EncodableFrame::StreamsBlocked(_) => true, } } } diff --git a/noq-proto/src/tests/mod.rs b/noq-proto/src/tests/mod.rs index 289b78a45..2a088b048 100644 --- a/noq-proto/src/tests/mod.rs +++ b/noq-proto/src/tests/mod.rs @@ -196,7 +196,9 @@ fn server_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Arc::new(move || { + Box::new(HashedConnectionIdGenerator::from_key(0)) + })); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -225,7 +227,9 @@ fn client_stateless_reset() { rng.fill_bytes(&mut key_material); let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key)); - endpoint_config.cid_generator(move || Box::new(HashedConnectionIdGenerator::from_key(0))); + endpoint_config.cid_generator(Arc::new(move || { + Box::new(HashedConnectionIdGenerator::from_key(0)) + })); let endpoint_config = Arc::new(endpoint_config); let mut pair = Pair::new(endpoint_config.clone(), server_config()); @@ -253,7 +257,9 @@ fn stateless_reset_limit() { let _guard = subscribe(); let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42); let mut endpoint_config = EndpointConfig::default(); - endpoint_config.cid_generator(move || Box::new(RandomConnectionIdGenerator::new(8))); + endpoint_config.cid_generator(Arc::new(move || { + Box::new(RandomConnectionIdGenerator::new(8)) + })); let endpoint_config = Arc::new(endpoint_config); let mut endpoint = Endpoint::new( endpoint_config.clone(), @@ -1033,6 +1039,69 @@ fn stream_id_limit() { let _ = chunks.finalize(); } +#[test] +fn streams_blocked() { + let _guard = subscribe(); + let server = ServerConfig { + transport: Arc::new(TransportConfig { + max_concurrent_uni_streams: 1u32.into(), + ..TransportConfig::default() + }), + ..server_config() + }; + let mut pair = Pair::new(Default::default(), server); + let (client_ch, server_ch) = pair.connect(); + + // Use up the only stream slot, then try to open another + let s = pair + .client_streams(client_ch) + .open(Dir::Uni) + .expect("first uni stream"); + assert_eq!(pair.client_streams(client_ch).open(Dir::Uni), None); + + // Send data so the STREAMS_BLOCKED piggybacks on an outgoing packet + pair.client_send(client_ch, s).write(b"hi").unwrap(); + pair.drive(); + + assert_eq!( + pair.client_conn_mut(client_ch) + .stats() + .frame_tx + .streams_blocked_uni, + 1 + ); + assert_eq!( + pair.server_conn_mut(server_ch) + .stats() + .frame_rx + .streams_blocked_uni, + 1 + ); +} + +#[test] +fn streams_blocked_not_sent_under_limit() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let (client_ch, _server_ch) = pair.connect(); + + // Default config allows many streams; opening one should not trigger STREAMS_BLOCKED + let s = pair + .client_streams(client_ch) + .open(Dir::Uni) + .expect("open stream"); + pair.client_send(client_ch, s).write(b"hi").unwrap(); + pair.drive(); + + assert_eq!( + pair.client_conn_mut(client_ch) + .stats() + .frame_tx + .streams_blocked_uni, + 0 + ); +} + #[test] fn key_update_simple() { let _guard = subscribe(); diff --git a/noq-proto/src/tests/multipath.rs b/noq-proto/src/tests/multipath.rs index 23dae419f..2357e6dfe 100644 --- a/noq-proto/src/tests/multipath.rs +++ b/noq-proto/src/tests/multipath.rs @@ -80,7 +80,7 @@ fn non_zero_length_cids() { } let mut ep_config = EndpointConfig::default(); - ep_config.cid_generator(|| Box::new(ZeroLenCidGenerator)); + ep_config.cid_generator(Arc::new(|| Box::new(ZeroLenCidGenerator))); let client = Endpoint::new(Arc::new(ep_config), None, true); let mut pair = Pair::new_from_endpoint(client, server); diff --git a/noq-udp/Cargo.toml b/noq-udp/Cargo.toml index cb748797d..feaa680bc 100644 --- a/noq-udp/Cargo.toml +++ b/noq-udp/Cargo.toml @@ -17,7 +17,7 @@ default = ["tracing", "tracing-log"] # Configure `tracing` to log events via `log` if no `tracing` subscriber exists. tracing-log = ["tracing/log"] log = ["dep:log"] -# Use private Apple APIs to send multiple packets in a single syscall. +# Support private Apple APIs to send multiple packets in a single syscall. fast-apple-datapath = [] [dependencies] diff --git a/noq-udp/src/lib.rs b/noq-udp/src/lib.rs index 27486e9a3..8dd717198 100644 --- a/noq-udp/src/lib.rs +++ b/noq-udp/src/lib.rs @@ -162,6 +162,7 @@ impl Transmit<'_> { /// This case is actually quite common when splitting up a prepared GSO batch /// again after GSO has been disabled because the last datagram in a GSO /// batch is allowed to be smaller than the segment size. + #[cfg_attr(apple_fast, allow(dead_code))] // Used by prepare_msg, which is unused when apple_fast fn effective_segment_size(&self) -> Option { match self.segment_size? { size if size >= self.contents.len() => None, diff --git a/noq-udp/src/unix.rs b/noq-udp/src/unix.rs index db86bfb03..71ccdc821 100644 --- a/noq-udp/src/unix.rs +++ b/noq-udp/src/unix.rs @@ -5,7 +5,7 @@ use std::{ mem::{self, MaybeUninit}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, num::NonZeroUsize, - os::unix::io::AsRawFd, + os::fd::AsRawFd, sync::{ Mutex, atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -34,23 +34,6 @@ pub(crate) struct msghdr_x { pub msg_datalen: usize, } -#[cfg(apple_fast)] -unsafe extern "C" { - fn recvmsg_x( - s: libc::c_int, - msgp: *const msghdr_x, - cnt: libc::c_uint, - flags: libc::c_int, - ) -> isize; - - fn sendmsg_x( - s: libc::c_int, - msgp: *const msghdr_x, - cnt: libc::c_uint, - flags: libc::c_int, - ) -> isize; -} - #[cfg(target_os = "freebsd")] type IpTosTy = libc::c_uchar; #[cfg(not(any(target_os = "freebsd", target_os = "netbsd")))] @@ -73,6 +56,13 @@ pub struct UdpSocketState { /// In particular, we do not use IP_TOS cmsg_type in this case, /// which is not supported on Linux <3.13 and results in not sending the UDP packet at all. sendmsg_einval: AtomicBool, + + /// Whether to use Apple's fast `sendmsg_x`/`recvmsg_x` APIs. + /// + /// These private APIs provide better performance but may not be available on all + /// Apple OS versions. Callers must verify availability before enabling. + #[cfg(apple_fast)] + apple_fast_path: AtomicBool, } impl UdpSocketState { @@ -120,11 +110,14 @@ impl UdpSocketState { } let mut may_fragment = false; + #[cfg_attr( + not(any(target_os = "linux", target_os = "android")), + expect(unused_mut) + )] + let mut gro_segments = NonZeroUsize::MIN; + #[cfg(any(target_os = "linux", target_os = "android"))] { - // opportunistically try to enable GRO. See gro::gro_segments(). - let _ = set_socket_option(&*io, libc::SOL_UDP, libc::UDP_GRO, OPTION_ON); - // Forbid IPv4 fragmentation. Set even for IPv6 to account for IPv6 mapped IPv4 addresses. // Set `may_fragment` to `true` if this option is not supported on the platform. may_fragment |= !set_socket_option_supported( @@ -145,6 +138,17 @@ impl UdpSocketState { libc::IPV6_PMTUDISC_PROBE, )?; } + + if set_socket_option(&*io, libc::SOL_UDP, libc::UDP_GRO, OPTION_ON).is_ok() { + // As defined in net/ipv4/udp_offload.c + // #define UDP_GRO_CNT_MAX 64 + // + // NOTE: this MUST be set to UDP_GRO_CNT_MAX to ensure that the receive buffer size + // (get_max_udp_payload_size() * gro_segments()) is large enough to hold the largest GRO + // list the kernel might potentially produce. See + // https://github.com/quinn-rs/quinn/pull/1354. + gro_segments = NonZeroUsize::new(64).expect("known"); + } } #[cfg(any(target_os = "freebsd", apple))] { @@ -187,10 +191,12 @@ impl UdpSocketState { let now = Instant::now(); Ok(Self { last_send_error: Mutex::new(now.checked_sub(2 * IO_ERROR_LOG_INTERVAL).unwrap_or(now)), - max_gso_segments: AtomicUsize::new(gso::max_gso_segments()), - gro_segments: gro::gro_segments(), + max_gso_segments: AtomicUsize::new(gso::max_gso_segments(&*io)), + gro_segments, may_fragment, sendmsg_einval: AtomicBool::new(false), + #[cfg(apple_fast)] + apple_fast_path: AtomicBool::new(false), }) } @@ -225,13 +231,50 @@ impl UdpSocketState { send(self, socket.0, transmit) } + #[cfg(not(any( + apple, + target_os = "openbsd", + target_os = "netbsd", + target_os = "dragonfly", + solarish + )))] pub fn recv( &self, socket: UdpSockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta], ) -> io::Result { - recv(socket.0, bufs, meta) + recv_via_recvmmsg(socket.0, bufs, meta) + } + + #[cfg(apple_fast)] + pub fn recv( + &self, + socket: UdpSockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], + ) -> io::Result { + if self.is_apple_fast_path_enabled() { + recv_via_recvmsg_x(self, socket.0, bufs, meta) + } else { + recv_single(socket.0, bufs, meta) + } + } + + #[cfg(any( + target_os = "openbsd", + target_os = "netbsd", + target_os = "dragonfly", + solarish, + apple_slow + ))] + pub fn recv( + &self, + socket: UdpSockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], + ) -> io::Result { + recv_single(socket.0, bufs, meta) } /// The maximum amount of segments which can be transmitted if a platform @@ -298,6 +341,45 @@ impl UdpSocketState { fn set_sendmsg_einval(&self) { self.sendmsg_einval.store(true, Ordering::Relaxed) } + + /// Enables Apple's fast UDP datapath using private `sendmsg_x`/`recvmsg_x` APIs. + /// Once enabled, this also updates [`max_gso_segments`] to allow batched sends. + /// + /// # Safety + /// + /// These APIs may crash on unsupported OS versions, so callers must verify + /// availability before enabling. + /// + /// [`max_gso_segments`]: Self::max_gso_segments + #[cfg(apple_fast)] + pub unsafe fn set_apple_fast_path(&self) { + self.apple_fast_path.store(true, Ordering::Relaxed); + self.max_gso_segments.store(BATCH_SIZE, Ordering::Relaxed); + } + + /// Returns whether Apple's fast UDP datapath is enabled for this socket. + #[cfg(apple_fast)] + pub fn is_apple_fast_path_enabled(&self) -> bool { + self.apple_fast_path.load(Ordering::Relaxed) + } + + /// Disables Apple's fast UDP datapath, reverting to `sendmsg`/`recvmsg`. + #[cfg(apple_fast)] + fn disable_apple_fast_path(&self) { + self.apple_fast_path.store(false, Ordering::Relaxed); + self.max_gso_segments.store(1, Ordering::Relaxed); + } + + /// Resolves an Apple fast-path function pointer via `resolver`, disabling the fast path if + /// the symbol is absent so that future calls use the slow path directly. + #[cfg(apple_fast)] + fn resolve_apple_fast_fn(&self, resolver: fn() -> Option) -> Option { + let f = resolver(); + if f.is_none() { + self.disable_apple_fast_path(); + } + f + } } #[cfg(not(any(apple, target_os = "openbsd", target_os = "netbsd")))] @@ -387,6 +469,20 @@ fn send( #[cfg(apple_fast)] fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> { + if state.is_apple_fast_path_enabled() { + send_via_sendmsg_x(state, io, transmit) + } else { + send_single(state, io, transmit) + } +} + +/// Send using the fast `sendmsg_x` API. +#[cfg(apple_fast)] +fn send_via_sendmsg_x( + state: &UdpSocketState, + io: SockRef<'_>, + transmit: &Transmit<'_>, +) -> io::Result<()> { let mut hdrs = unsafe { mem::zeroed::<[msghdr_x; BATCH_SIZE]>() }; let mut iovs = unsafe { mem::zeroed::<[libc::iovec; BATCH_SIZE]>() }; let mut ctrls = [cmsg::Aligned([0u8; CMSG_LEN]); BATCH_SIZE]; @@ -400,7 +496,7 @@ fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io: .enumerate() .take(BATCH_SIZE) { - prepare_msg( + prepare_msg_x( &Transmit { destination: transmit.destination, ecn: transmit.ecn, @@ -418,24 +514,21 @@ fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io: hdrs[i].msg_datalen = chunk.len(); cnt += 1; } - loop { - let n = unsafe { sendmsg_x(io.as_raw_fd(), hdrs.as_ptr(), cnt as u32, 0) }; - - if n >= 0 { - return Ok(()); - } - - let e = io::Error::last_os_error(); - match e.kind() { - // Retry the transmission - io::ErrorKind::Interrupted => continue, - _ => return Err(e), - } - } + let Some(sendmsg_x) = state.resolve_apple_fast_fn(sendmsg_x_fn) else { + return send_single(state, io, transmit); + }; + retry_if_interrupted(|| unsafe { sendmsg_x(io.as_raw_fd(), hdrs.as_ptr(), cnt as u32, 0) })?; + Ok(()) } #[cfg(any(target_os = "openbsd", target_os = "netbsd", apple_slow))] fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> { + send_single(state, io, transmit) +} + +#[cfg(any(target_os = "openbsd", target_os = "netbsd", apple))] +#[cfg_attr(apple_fast, allow(dead_code))] // Unused when apple_fast is enabled +fn send_single(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io::Result<()> { let mut hdr: libc::msghdr = unsafe { mem::zeroed() }; let mut iov: libc::iovec = unsafe { mem::zeroed() }; let mut ctrl = cmsg::Aligned([0u8; CMSG_LEN]); @@ -449,22 +542,11 @@ fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io: cfg!(apple) || cfg!(target_os = "openbsd") || cfg!(target_os = "netbsd"), state.sendmsg_einval(), ); - loop { - let n = unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) }; - - if n >= 0 { - return Ok(()); - } - - let e = io::Error::last_os_error(); - match e.kind() { - // Retry the transmission - io::ErrorKind::Interrupted => continue, - _ => return Err(e), - } - } + retry_if_interrupted(|| unsafe { libc::sendmsg(io.as_raw_fd(), &hdr, 0) })?; + Ok(()) } +/// Receive using the batched `recvmmsg` syscall. #[cfg(not(any( apple, target_os = "openbsd", @@ -472,7 +554,11 @@ fn send(state: &UdpSocketState, io: SockRef<'_>, transmit: &Transmit<'_>) -> io: target_os = "dragonfly", solarish )))] -fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result { +fn recv_via_recvmmsg( + io: SockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], +) -> io::Result { let mut names = [MaybeUninit::::uninit(); BATCH_SIZE]; let mut ctrls = [cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); BATCH_SIZE]; let mut hdrs = unsafe { mem::zeroed::<[libc::mmsghdr; BATCH_SIZE]>() }; @@ -485,36 +571,29 @@ fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> &mut hdrs[i].msg_hdr, ); } - let msg_count = loop { - let n = unsafe { - libc::recvmmsg( - io.as_raw_fd(), - hdrs.as_mut_ptr(), - bufs.len().min(BATCH_SIZE) as _, - 0, - ptr::null_mut::(), - ) - }; - - if n >= 0 { - break n; - } - - let e = io::Error::last_os_error(); - match e.kind() { - // Retry receiving - io::ErrorKind::Interrupted => continue, - _ => return Err(e), - } - }; + let msg_count = retry_if_interrupted(|| unsafe { + libc::recvmmsg( + io.as_raw_fd(), + hdrs.as_mut_ptr(), + bufs.len().min(BATCH_SIZE) as _, + 0, + ptr::null_mut::(), + ) as isize + })?; for i in 0..(msg_count as usize) { meta[i] = decode_recv(&names[i], &hdrs[i].msg_hdr, hdrs[i].msg_len as usize)?; } Ok(msg_count as usize) } +/// Receive using the fast `recvmsg_x` API. #[cfg(apple_fast)] -fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result { +fn recv_via_recvmsg_x( + state: &UdpSocketState, + io: SockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], +) -> io::Result { let mut names = [MaybeUninit::::uninit(); BATCH_SIZE]; // MacOS 10.15 `recvmsg_x` does not override the `msghdr_x` // `msg_controllen`. Thus, after the call to `recvmsg_x`, one does not know @@ -526,36 +605,74 @@ fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> let mut hdrs = unsafe { mem::zeroed::<[msghdr_x; BATCH_SIZE]>() }; let max_msg_count = bufs.len().min(BATCH_SIZE); for i in 0..max_msg_count { - prepare_recv(&mut bufs[i], &mut names[i], &mut ctrls[i], &mut hdrs[i]); + prepare_recv_x(&mut bufs[i], &mut names[i], &mut ctrls[i], &mut hdrs[i]); } - let msg_count = loop { - let n = unsafe { recvmsg_x(io.as_raw_fd(), hdrs.as_mut_ptr(), max_msg_count as _, 0) }; - - if n >= 0 { - break n; - } - - let e = io::Error::last_os_error(); - match e.kind() { - // Retry receiving - io::ErrorKind::Interrupted => continue, - _ => return Err(e), - } + let Some(recvmsg_x) = state.resolve_apple_fast_fn(recvmsg_x_fn) else { + return recv_single(io, bufs, meta); }; + let msg_count = retry_if_interrupted(|| unsafe { + recvmsg_x(io.as_raw_fd(), hdrs.as_mut_ptr(), max_msg_count as _, 0) + })?; for i in 0..(msg_count as usize) { meta[i] = decode_recv(&names[i], &hdrs[i], hdrs[i].msg_datalen as usize)?; } Ok(msg_count as usize) } +/// Returns the `sendmsg_x` function pointer, resolving it via `dlsym` on first call. +/// +/// Returns `None` if the symbol is not available on the current OS version. +#[cfg(apple_fast)] +fn sendmsg_x_fn() -> Option { + static ADDR: std::sync::OnceLock = std::sync::OnceLock::new(); + // SAFETY: `resolve_symbol` only returns non-zero addresses obtained from `dlsym`, which + // guarantees a callable symbol whose type matches the declaration above. + resolve_symbol(&ADDR, c"sendmsg_x") + .map(|addr| unsafe { std::mem::transmute::(addr) }) +} + +/// Returns the `recvmsg_x` function pointer, resolving it via `dlsym` on first call. +/// +/// Returns `None` if the symbol is not available on the current OS version. +#[cfg(apple_fast)] +fn recvmsg_x_fn() -> Option { + static ADDR: std::sync::OnceLock = std::sync::OnceLock::new(); + // SAFETY: `resolve_symbol` only returns non-zero addresses obtained from `dlsym`, which + // guarantees a callable symbol whose type matches the declaration above. + resolve_symbol(&ADDR, c"recvmsg_x") + .map(|addr| unsafe { std::mem::transmute::(addr) }) +} + +#[cfg(apple_fast)] +type SendmsgXFn = + unsafe extern "C" fn(libc::c_int, *const msghdr_x, libc::c_uint, libc::c_int) -> isize; +#[cfg(apple_fast)] +type RecvmsgXFn = + unsafe extern "C" fn(libc::c_int, *mut msghdr_x, libc::c_uint, libc::c_int) -> isize; + +/// Resolves a symbol via `dlsym` on first call, caching the result. +/// +/// Returns `None` if the symbol is not available on the current OS version. +#[cfg(apple_fast)] +fn resolve_symbol(lock: &std::sync::OnceLock, name: &std::ffi::CStr) -> Option { + let addr = + *lock.get_or_init(|| unsafe { libc::dlsym(libc::RTLD_DEFAULT, name.as_ptr()) as usize }); + (addr != 0).then_some(addr) +} + #[cfg(any( target_os = "openbsd", target_os = "netbsd", target_os = "dragonfly", solarish, - apple_slow + apple ))] -fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> io::Result { +#[cfg_attr(apple_fast, allow(dead_code))] // Unused when apple_fast is enabled +fn recv_single( + io: SockRef<'_>, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [RecvMeta], +) -> io::Result { let mut name = MaybeUninit::::uninit(); let mut ctrl = cmsg::Aligned(MaybeUninit::<[u8; CMSG_LEN]>::uninit()); let mut hdr = unsafe { mem::zeroed::() }; @@ -584,11 +701,11 @@ fn recv(io: SockRef<'_>, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta]) -> const CMSG_LEN: usize = 88; +#[cfg_attr(apple_fast, allow(dead_code))] // Unused when apple_fast is enabled fn prepare_msg( transmit: &Transmit<'_>, dst_addr: &socket2::SockAddr, - #[cfg(not(apple_fast))] hdr: &mut libc::msghdr, - #[cfg(apple_fast)] hdr: &mut msghdr_x, + hdr: &mut libc::msghdr, iov: &mut libc::iovec, ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>, #[allow(unused_variables)] // only used on FreeBSD & macOS @@ -628,6 +745,10 @@ fn prepare_msg( encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn); } + // On apple_fast, prepare_msg is only compiled for send_single (fallback path), while the main + // send path uses prepare_msg_x with msghdr_x. gso::set_segment_size has a different signature + // when apple_fast is enabled, and it's a no-op on non-Linux platforms anyway. + #[cfg(not(apple_fast))] if let Some(segment_size) = transmit.effective_segment_size() { gso::set_segment_size(&mut encoder, segment_size as u16); } @@ -671,7 +792,67 @@ fn prepare_msg( encoder.finish(); } -#[cfg(not(apple_fast))] +/// Prepares an `msghdr_x` for use with `sendmsg_x`. +#[cfg(apple_fast)] +fn prepare_msg_x( + transmit: &Transmit<'_>, + dst_addr: &socket2::SockAddr, + hdr: &mut msghdr_x, + iov: &mut libc::iovec, + ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>, + #[allow(unused_variables)] encode_src_ip: bool, + sendmsg_einval: bool, +) { + iov.iov_base = transmit.contents.as_ptr() as *const _ as *mut _; + iov.iov_len = transmit.contents.len(); + + let name = dst_addr.as_ptr() as *mut libc::c_void; + let namelen = dst_addr.len(); + hdr.msg_name = name as *mut _; + hdr.msg_namelen = namelen; + hdr.msg_iov = iov; + hdr.msg_iovlen = 1; + + hdr.msg_control = ctrl.0.as_mut_ptr() as _; + hdr.msg_controllen = CMSG_LEN as _; + let mut encoder = unsafe { cmsg::Encoder::new(hdr) }; + let ecn = transmit.ecn.map_or(0, |x| x as libc::c_int); + let is_ipv4 = transmit.destination.is_ipv4() + || matches!(transmit.destination.ip(), IpAddr::V6(addr) if addr.to_ipv4_mapped().is_some()); + if is_ipv4 { + if !sendmsg_einval { + encoder.push(libc::IPPROTO_IP, libc::IP_TOS, ecn as IpTosTy); + } + } else { + encoder.push(libc::IPPROTO_IPV6, libc::IPV6_TCLASS, ecn); + } + + if let Some(ip) = &transmit.src_ip { + match ip { + IpAddr::V4(v4) => { + if encode_src_ip { + let addr = libc::in_addr { + s_addr: u32::from_ne_bytes(v4.octets()), + }; + encoder.push(libc::IPPROTO_IP, libc::IP_RECVDSTADDR, addr); + } + } + IpAddr::V6(v6) => { + let pktinfo = libc::in6_pktinfo { + ipi6_ifindex: 0, + ipi6_addr: libc::in6_addr { + s6_addr: v6.octets(), + }, + }; + encoder.push(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO, pktinfo); + } + } + } + + encoder.finish(); +} + +#[cfg_attr(apple_fast, allow(dead_code))] // Unused when apple_fast is enabled fn prepare_recv( buf: &mut IoSliceMut<'_>, name: &mut MaybeUninit, @@ -687,8 +868,9 @@ fn prepare_recv( hdr.msg_flags = 0; } +/// Prepares an `msghdr_x` for receiving with `recvmsg_x`. #[cfg(apple_fast)] -fn prepare_recv( +fn prepare_recv_x( buf: &mut IoSliceMut<'_>, name: &mut MaybeUninit, ctrl: &mut cmsg::Aligned<[u8; CMSG_LEN]>, @@ -710,17 +892,42 @@ fn decode_recv>( len: usize, ) -> io::Result { let name = unsafe { name.assume_init() }; - let mut ecn_bits = 0; - let mut dst_ip = None; - let mut interface_index = None; - #[allow(unused_mut)] // only mutable on Linux - let mut stride = len; + let mut ctrl = ControlMetadata { + ecn_bits: 0, + dst_ip: None, + interface_index: None, + stride: len, + }; let cmsg_iter = unsafe { cmsg::Iter::new(hdr) }; for cmsg in cmsg_iter { + ctrl.decode(cmsg); + } + + Ok(RecvMeta { + len, + stride: ctrl.stride, + addr: decode_socket_addr(&name)?, + ecn: EcnCodepoint::from_bits(ctrl.ecn_bits), + dst_ip: ctrl.dst_ip, + interface_index: ctrl.interface_index, + }) +} + +/// Metadata decoded from control messages +struct ControlMetadata { + ecn_bits: u8, + dst_ip: Option, + interface_index: Option, + stride: usize, +} + +impl ControlMetadata { + /// Decodes a control message and updates the metadata state + fn decode(&mut self, cmsg: &libc::cmsghdr) { match (cmsg.cmsg_level, cmsg.cmsg_type) { (libc::IPPROTO_IP, libc::IP_TOS) => unsafe { - ecn_bits = cmsg::decode::(cmsg); + self.ecn_bits = cmsg::decode::(cmsg); }, // FreeBSD uses IP_RECVTOS here, and we can be liberal because cmsgs are opt-in. #[cfg(not(any( @@ -730,7 +937,7 @@ fn decode_recv>( solarish )))] (libc::IPPROTO_IP, libc::IP_RECVTOS) => unsafe { - ecn_bits = cmsg::decode::(cmsg); + self.ecn_bits = cmsg::decode::(cmsg); }, (libc::IPPROTO_IPV6, libc::IPV6_TCLASS) => unsafe { // Temporary hack around broken macos ABI. Remove once upstream fixes it. @@ -739,76 +946,68 @@ fn decode_recv>( if cfg!(apple) && cmsg.cmsg_len as usize == libc::CMSG_LEN(mem::size_of::() as _) as usize { - ecn_bits = cmsg::decode::(cmsg); + self.ecn_bits = cmsg::decode::(cmsg); } else { - ecn_bits = cmsg::decode::(cmsg) as u8; + self.ecn_bits = cmsg::decode::(cmsg) as u8; } }, #[cfg(any(target_os = "linux", target_os = "android"))] (libc::IPPROTO_IP, libc::IP_PKTINFO) => { let pktinfo = unsafe { cmsg::decode::(cmsg) }; - dst_ip = Some(IpAddr::V4(Ipv4Addr::from( + self.dst_ip = Some(IpAddr::V4(Ipv4Addr::from( pktinfo.ipi_addr.s_addr.to_ne_bytes(), ))); - interface_index = Some(pktinfo.ipi_ifindex as u32); + self.interface_index = Some(pktinfo.ipi_ifindex as u32); } #[cfg(any(bsd, apple))] (libc::IPPROTO_IP, libc::IP_RECVDSTADDR) => { let in_addr = unsafe { cmsg::decode::(cmsg) }; - dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes()))); + self.dst_ip = Some(IpAddr::V4(Ipv4Addr::from(in_addr.s_addr.to_ne_bytes()))); } (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { let pktinfo = unsafe { cmsg::decode::(cmsg) }; - dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr))); + self.dst_ip = Some(IpAddr::V6(Ipv6Addr::from(pktinfo.ipi6_addr.s6_addr))); #[cfg_attr(not(target_os = "android"), expect(clippy::unnecessary_cast))] { - interface_index = Some(pktinfo.ipi6_ifindex as u32); + self.interface_index = Some(pktinfo.ipi6_ifindex as u32); } } #[cfg(any(target_os = "linux", target_os = "android"))] (libc::SOL_UDP, libc::UDP_GRO) => unsafe { - stride = cmsg::decode::(cmsg) as usize; + self.stride = cmsg::decode::(cmsg) as usize; }, _ => {} } } +} - let addr = match libc::c_int::from(name.ss_family) { +/// Decodes a `sockaddr_storage` into a `SocketAddr` +fn decode_socket_addr(name: &libc::sockaddr_storage) -> io::Result { + match libc::c_int::from(name.ss_family) { libc::AF_INET => { // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in. let addr: &libc::sockaddr_in = - unsafe { &*(&name as *const _ as *const libc::sockaddr_in) }; - SocketAddr::V4(SocketAddrV4::new( + unsafe { &*(name as *const _ as *const libc::sockaddr_in) }; + Ok(SocketAddr::V4(SocketAddrV4::new( Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()), u16::from_be(addr.sin_port), - )) + ))) } libc::AF_INET6 => { // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6. let addr: &libc::sockaddr_in6 = - unsafe { &*(&name as *const _ as *const libc::sockaddr_in6) }; - SocketAddr::V6(SocketAddrV6::new( + unsafe { &*(name as *const _ as *const libc::sockaddr_in6) }; + Ok(SocketAddr::V6(SocketAddrV6::new( Ipv6Addr::from(addr.sin6_addr.s6_addr), u16::from_be(addr.sin6_port), addr.sin6_flowinfo, addr.sin6_scope_id, - )) - } - f => { - return Err(io::Error::other(format!( - "expected AF_INET or AF_INET6, got {f} in decode_recv" - ))); + ))) } - }; - - Ok(RecvMeta { - len, - stride, - addr, - ecn: EcnCodepoint::from_bits(ecn_bits), - dst_ip, - interface_index, - }) + f => Err(io::Error::other(format!( + "expected AF_INET or AF_INET6, got {f}" + ))), + } } #[cfg(not(apple_slow))] @@ -832,23 +1031,25 @@ mod gso { /// Checks whether GSO support is available by checking the kernel version followed by setting /// the UDP_SEGMENT option on a socket - pub(crate) fn max_gso_segments() -> usize { + pub(crate) fn max_gso_segments(socket: &impl AsRawFd) -> usize { const GSO_SIZE: libc::c_int = 1500; if !SUPPORTED_BY_CURRENT_KERNEL.get_or_init(supported_by_current_kernel) { return 1; } - let Ok(socket) = std::net::UdpSocket::bind("[::]:0") - .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))) - else { - return 1; - }; - // As defined in linux/udp.h // #define UDP_MAX_SEGMENTS (1 << 6UL) - match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_SEGMENT, GSO_SIZE) { - Ok(()) => 64, + match set_socket_option(socket, libc::SOL_UDP, libc::UDP_SEGMENT, GSO_SIZE) { + Ok(()) => { + // Disable GSO again globally to ensure we can selectively enable it via cmsg. + // See: + // - https://github.com/quinn-rs/quinn/issues/2575 + // - https://man7.org/linux/man-pages/man7/udp.7.html + let _ = set_socket_option(socket, libc::SOL_UDP, libc::UDP_SEGMENT, 0); + + 64 + } Err(_e) => { crate::log::debug!( "failed to set `UDP_SEGMENT` socket option ({_e}); setting `max_gso_segments = 1`" @@ -980,21 +1181,17 @@ mod gso { // On Apple platforms using the `sendmsg_x` call, UDP datagram segmentation is not // offloaded to the NIC or even the kernel, but instead done here in user space in // [`send`]) and then passed to the OS as individual `iovec`s (up to `BATCH_SIZE`). +// The initial value is 1 (no batching); callers can enable batching via +// `UdpSocketState::set_apple_fast_path()` which updates `max_gso_segments`. #[cfg(not(any(target_os = "linux", target_os = "android")))] mod gso { use super::*; - pub(super) fn max_gso_segments() -> usize { - #[cfg(apple_fast)] - { - BATCH_SIZE - } - #[cfg(not(apple_fast))] - { - 1 - } + pub(super) fn max_gso_segments(_socket: &impl AsRawFd) -> usize { + 1 } + #[cfg_attr(apple_fast, allow(dead_code))] // Unused when apple_fast is enabled pub(super) fn set_segment_size( #[cfg(not(apple_fast))] _encoder: &mut cmsg::Encoder<'_, libc::msghdr>, #[cfg(apple_fast)] _encoder: &mut cmsg::Encoder<'_, msghdr_x>, @@ -1003,31 +1200,6 @@ mod gso { } } -#[cfg(any(target_os = "linux", target_os = "android"))] -mod gro { - use super::*; - - pub(crate) fn gro_segments() -> NonZeroUsize { - let Ok(socket) = std::net::UdpSocket::bind("[::]:0") - .or_else(|_| std::net::UdpSocket::bind((Ipv4Addr::LOCALHOST, 0))) - else { - return NonZeroUsize::MIN; - }; - - // As defined in net/ipv4/udp_offload.c - // #define UDP_GRO_CNT_MAX 64 - // - // NOTE: this MUST be set to UDP_GRO_CNT_MAX to ensure that the receive buffer size - // (get_max_udp_payload_size() * gro_segments()) is large enough to hold the largest GRO - // list the kernel might potentially produce. See - // https://github.com/quinn-rs/quinn/pull/1354. - match set_socket_option(&socket, libc::SOL_UDP, libc::UDP_GRO, OPTION_ON) { - Ok(()) => NonZeroUsize::new(64).expect("known"), - Err(_) => NonZeroUsize::MIN, - } - } -} - /// Returns whether the given socket option is supported on the current platform /// /// Yields `Ok(true)` if the option was set successfully, `Ok(false)` if setting @@ -1070,11 +1242,17 @@ fn set_socket_option( const OPTION_ON: libc::c_int = 1; -#[cfg(not(any(target_os = "linux", target_os = "android")))] -mod gro { - use std::num::NonZeroUsize; - - pub(super) fn gro_segments() -> NonZeroUsize { - NonZeroUsize::MIN +/// Calls `f` in a loop, retrying on `EINTR`, and returns the non-negative result or the first +/// non-`EINTR` error. +fn retry_if_interrupted(mut f: impl FnMut() -> isize) -> io::Result { + loop { + let n = f(); + if n >= 0 { + return Ok(n); + } + let e = io::Error::last_os_error(); + if e.kind() != io::ErrorKind::Interrupted { + return Err(e); + } } } diff --git a/noq-udp/tests/tests.rs b/noq-udp/tests/tests.rs index 9bc2014a1..fa0d374c0 100644 --- a/noq-udp/tests/tests.rs +++ b/noq-udp/tests/tests.rs @@ -416,3 +416,95 @@ fn ip_to_v6_mapped(x: IpAddr) -> IpAddr { IpAddr::V6(_) => x, } } + +/// Test Apple fast datapath enable/disable functionality. +/// +/// This test verifies that: +/// 1. `max_gso_segments()` returns 1 by default (fast path disabled) +/// 2. After calling `set_apple_fast_path()`, `max_gso_segments()` returns `BATCH_SIZE` +/// 3. Send/recv still works correctly with the fast path enabled +#[test] +#[cfg(apple_fast)] +fn apple_fast_datapath() { + let send = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + let recv = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + let dst_addr = recv.local_addr().unwrap(); + + let send_state = UdpSocketState::new((&send).into()).unwrap(); + let recv_state = UdpSocketState::new((&recv).into()).unwrap(); + + // Initially, fast path should be disabled and max_gso_segments should be 1 + assert!( + !send_state.is_apple_fast_path_enabled(), + "fast path should be disabled initially" + ); + assert_eq!( + send_state.max_gso_segments().get(), + 1, + "max_gso_segments should be 1 before enabling fast path" + ); + + // Enable the fast path + // SAFETY: Assume that sendmsg_x/recvmsg_x are available on the macOS test host. + unsafe { + send_state.set_apple_fast_path(); + recv_state.set_apple_fast_path(); + } + + // After enabling, fast path should be enabled and max_gso_segments should be BATCH_SIZE + assert!( + send_state.is_apple_fast_path_enabled(), + "fast path should be enabled after calling set_apple_fast_path()" + ); + assert_eq!( + send_state.max_gso_segments().get(), + noq_udp::BATCH_SIZE, + "max_gso_segments should be BATCH_SIZE after enabling fast path" + ); + + // Verify send/recv still works with fast path enabled + recv.set_nonblocking(false).unwrap(); + + const SEGMENT_SIZE: usize = 128; + let segments = send_state.max_gso_segments().get(); + let msg = vec![0xAB; SEGMENT_SIZE * segments]; + + send_state + .try_send( + (&send).into(), + &Transmit { + destination: dst_addr, + ecn: None, + contents: &msg, + segment_size: Some(SEGMENT_SIZE), + src_ip: None, + }, + ) + .unwrap(); + + // Receive all segments + let mut buf = [0u8; u16::MAX as usize]; + let mut total_received = 0; + while total_received < segments { + let mut meta = RecvMeta::default(); + let n = recv_state + .recv( + (&recv).into(), + &mut [IoSliceMut::new(&mut buf)], + slice::from_mut(&mut meta), + ) + .unwrap(); + assert_eq!(n, 1); + let received_segments = meta.len / meta.stride; + for i in 0..received_segments { + assert_eq!( + &buf[i * meta.stride..(i + 1) * meta.stride], + &msg[(total_received + i) * SEGMENT_SIZE..(total_received + i + 1) * SEGMENT_SIZE], + "segment {} content mismatch", + total_received + i + ); + } + total_received += received_segments; + } + assert_eq!(total_received, segments, "should receive all segments"); +} diff --git a/noq/Cargo.toml b/noq/Cargo.toml index 4e50fa2e9..617078f78 100644 --- a/noq/Cargo.toml +++ b/noq/Cargo.toml @@ -100,7 +100,7 @@ directories-next = { workspace = true } rand = { workspace = true } rcgen = { workspace = true } clap = { workspace = true } -tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time", "macros"] } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time", "macros", "test-util"] } tracing-subscriber = { workspace = true } tracing-futures = { workspace = true } url = { workspace = true } diff --git a/noq/src/connection.rs b/noq/src/connection.rs index 741d7e0e5..f7d9a5dca 100644 --- a/noq/src/connection.rs +++ b/noq/src/connection.rs @@ -6,7 +6,10 @@ use std::{ net::{IpAddr, SocketAddr}, num::NonZeroUsize, pin::Pin, - sync::{Arc, Weak}, + sync::{ + Arc, Weak, + atomic::{AtomicUsize, Ordering}, + }, task::{Context, Poll, Waker, ready}, }; @@ -579,9 +582,14 @@ impl Connection { } } - /// If the connection is closed, the reason why. + /// Whether the connection is closed, and why. /// - /// Returns `None` if the connection is still open. + /// The close_reason is always set to `Some(ConnectionError)` when a socket is + /// closed; whether it was closed manually by calling [`Connection::close()`] or due to + /// an internal error (such as an idle timeout or the peer closing the + /// connection). + /// + /// Note: when the connection is closed, `connection.close_reason().is_some()` will always be true. pub fn close_reason(&self) -> Option { self.0.lock_without_waking("close_reason").error.clone() } @@ -1237,7 +1245,7 @@ pub(crate) struct ConnectionRef(Arc>); impl ConnectionRef { #[allow(clippy::redundant_allocation)] fn from_arc(inner: Arc>) -> Self { - inner.lock_without_waking("from_arc").ref_count += 1; + inner.shared.ref_count.fetch_add(1, Ordering::Relaxed); Self(inner) } @@ -1258,16 +1266,18 @@ impl Clone for ConnectionRef { impl Drop for ConnectionRef { fn drop(&mut self) { + if self.shared.ref_count.fetch_sub(1, Ordering::Relaxed) > 1 { + return; + } + let conn = &mut *self.lock_without_waking("drop"); - if let Some(x) = conn.ref_count.checked_sub(1) { - conn.ref_count = x; - if x == 0 && !conn.inner.is_closed() { - // If the driver is alive, it's just it and us, so we'd better shut it down. If it's - // not, we can't do any harm. If there were any streams being opened, then either - // the connection will be closed for an unrelated reason or a fresh reference will - // be constructed for the newly opened stream. - conn.implicit_close(&self.shared); - } + + if !conn.inner.is_closed() { + // If the driver is alive, it's just it and us, so we'd better shut it down. If it's + // not, we can't do any harm. If there were any streams being opened, then either + // the connection will be closed for an unrelated reason or a fresh reference will + // be constructed for the newly opened stream. + conn.implicit_close(&self.shared); } } } @@ -1372,6 +1382,8 @@ pub(crate) struct Shared { datagram_received: Notify, datagrams_unblocked: Notify, closed: Notify, + /// Number of live handles that can be used to initiate or handle I/O; excludes the driver + ref_count: AtomicUsize, } pub(crate) struct State { @@ -1405,8 +1417,6 @@ pub(crate) struct State { /// When the last reference to a path is dropped via [`Self::decrement_path_refs`] its value is cleared. pub(crate) final_path_stats: FxHashMap, pub(crate) path_events: tokio::sync::broadcast::Sender, - /// Number of live handles that can be used to initiate or handle I/O; excludes the driver - ref_count: usize, sender: Pin>, pub(crate) runtime: Arc, send_buffer: Vec, @@ -1448,7 +1458,6 @@ impl State { stopped: FxHashMap::default(), open_path: FxHashMap::default(), error: None, - ref_count: 0, sender, runtime, send_buffer: Vec::new(), diff --git a/noq/src/endpoint.rs b/noq/src/endpoint.rs index 847ad3da4..d4b20d63c 100644 --- a/noq/src/endpoint.rs +++ b/noq/src/endpoint.rs @@ -8,7 +8,10 @@ use std::{ num::NonZeroUsize, pin::Pin, str, - sync::{Arc, Mutex}, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, }; @@ -105,23 +108,37 @@ impl Endpoint { /// Helper to construct an endpoint for use with both incoming and outgoing connections /// - /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard - /// IPv6 address on Windows will not by default be able to communicate with IPv4 - /// addresses. Portable applications should bind an address that matches the family they wish to - /// communicate within. + /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard + /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or + /// IPv6 address respectively from an OS-assigned port. + /// + /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow + /// communication with both IPv4 and IPv6 clients. As such, calling `Endpoint::server` with + /// the address `[::]:0` is a reasonable default to maximize the ability to accept connections + /// from any address. + /// + /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6 + /// server will only be able to accept connections from IPv6 clients. An IPv4 server is never + /// dual-stack. #[cfg(all( not(wasm_browser), any(feature = "runtime-tokio", feature = "runtime-smol"), any(feature = "aws-lc-rs", feature = "ring"), // `EndpointConfig::default()` is only available with these ))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result { - let socket = std::net::UdpSocket::bind(addr)?; + let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?; + if addr.is_ipv6() + && let Err(e) = socket.set_only_v6(false) + { + tracing::debug!(%e, "unable to make socket dual-stack"); + } + socket.bind(&addr.into())?; let runtime = default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?; Self::new_with_abstract_socket( EndpointConfig::default(), Some(config), - runtime.wrap_udp_socket(socket)?, + runtime.wrap_udp_socket(socket.into())?, runtime, ) } @@ -424,7 +441,8 @@ impl Future for EndpointDriver { // - all `Endpoint` structs are dropped and all connections are drained, // - or `Endpoint::close` has been called and all connections are drained. if endpoint.recv_state.connections.is_empty() - && (endpoint.ref_count == 0 || endpoint.recv_state.connections.close.is_some()) + && (self.0.shared.ref_count.load(Ordering::Relaxed) == 0 + || endpoint.recv_state.connections.close.is_some()) { trace!("endpoint driver stopping"); Poll::Ready(Ok(())) @@ -524,8 +542,6 @@ pub(crate) struct State { driver: Option, ipv6: bool, events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>, - /// Number of live handles that can be used to initiate or handle I/O; excludes the driver - ref_count: usize, driver_lost: bool, runtime: Arc, stats: EndpointStats, @@ -536,6 +552,8 @@ pub(crate) struct State { pub(crate) struct Shared { incoming: Notify, idle: Notify, + /// Number of live handles that can be used to initiate or handle I/O; excludes the driver + ref_count: AtomicUsize, } impl State { @@ -772,6 +790,7 @@ impl EndpointRef { shared: Shared { incoming: Notify::new(), idle: Notify::new(), + ref_count: AtomicUsize::new(0), }, state: Mutex::new(State { socket, @@ -781,7 +800,6 @@ impl EndpointRef { ipv6, events, driver: None, - ref_count: 0, driver_lost: false, recv_state, runtime, @@ -794,23 +812,22 @@ impl EndpointRef { impl Clone for EndpointRef { fn clone(&self) -> Self { - self.0.state.lock().unwrap().ref_count += 1; + self.0.shared.ref_count.fetch_add(1, Ordering::Relaxed); Self(self.0.clone()) } } impl Drop for EndpointRef { fn drop(&mut self) { + if self.0.shared.ref_count.fetch_sub(1, Ordering::Relaxed) > 1 { + return; + } + let endpoint = &mut *self.0.state.lock().unwrap(); - if let Some(x) = endpoint.ref_count.checked_sub(1) { - endpoint.ref_count = x; - if x == 0 { - // If the driver is about to be on its own, ensure it can shut down if the last - // connection is gone. - if let Some(task) = endpoint.driver.take() { - task.wake(); - } - } + // If the driver is about to be on its own, ensure it can shut down if the last + // connection is gone. + if let Some(task) = endpoint.driver.take() { + task.wake(); } } } diff --git a/noq/src/recv_stream.rs b/noq/src/recv_stream.rs index 5467a9dc5..81fc858fc 100644 --- a/noq/src/recv_stream.rs +++ b/noq/src/recv_stream.rs @@ -275,6 +275,9 @@ impl RecvStream { } conn.inner.recv_stream(self.stream).stop(error_code)?; self.all_data_read = true; + // Clean up shared state that might be left over from a cancelled read + // operation, so `drop` doesn't have to + conn.blocked_readers.remove(&self.stream); Ok(()) } @@ -581,6 +584,18 @@ impl tokio::io::AsyncRead for RecvStream { impl Drop for RecvStream { fn drop(&mut self) { + if self.all_data_read { + debug_assert!( + !self + .conn + .lock_without_waking("RecvStream:drop") + .blocked_readers + .contains_key(&self.stream), + "Stream {} should not have a blocked reader when all data read is true", + &self.stream + ); + return; + } let mut conn = self.conn.lock_and_wake("RecvStream::drop"); // clean up any previously registered wakers @@ -590,10 +605,9 @@ impl Drop for RecvStream { conn.skip_waking(); return; } - if !self.all_data_read { - // Ignore ClosedStream errors - let _ = conn.inner.recv_stream(self.stream).stop(0u32.into()); - } + + // Ignore ClosedStream errors + let _ = conn.inner.recv_stream(self.stream).stop(0u32.into()); } } diff --git a/noq/src/tests.rs b/noq/src/tests.rs index 8f5a8c692..b1f8c5381 100755 --- a/noq/src/tests.rs +++ b/noq/src/tests.rs @@ -11,8 +11,13 @@ use std::{ convert::TryInto, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, + pin::pin, str, - sync::Arc, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, }; use crate::runtime::TokioRuntime; @@ -839,7 +844,7 @@ async fn multiple_conns_with_zero_length_cids() { let mut factory = EndpointFactory::new(); factory .endpoint_config - .cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0))); + .cid_generator(Arc::new(|| Box::new(RandomConnectionIdGenerator::new(0)))); let server = factory.endpoint("server"); let server_addr = server.local_addr().unwrap(); @@ -1213,6 +1218,43 @@ async fn weak_connection_handle() { client_res.expect("client task panicked"); } +#[tokio::test(start_paused = true)] +async fn dropped_endpoint_cleans_up() { + let _guard = subscribe(); + + let mut endpoint_factory = EndpointFactory::new(); + let cid_generator = Arc::new(|| -> Box { + Box::::default() + }); + endpoint_factory + .endpoint_config + .cid_generator(cid_generator.clone()); + let endpoint = endpoint_factory.endpoint("endpoint"); + drop(endpoint_factory); + assert_eq!(Arc::strong_count(&cid_generator), 2); + drop(endpoint); + // Let the driver task run; paused runtimes are guaranteed to drain pending work on sleep. + tokio::time::sleep(Duration::from_millis(1)).await; + assert_eq!(Arc::strong_count(&cid_generator), 1); +} + +#[tokio::test] +async fn dropped_connection_cleans_up() { + let _guard = subscribe(); + let endpoint = endpoint(); + tokio::join!( + async { + endpoint + .connect(endpoint.local_addr().unwrap(), "localhost") + .unwrap() + .await + .unwrap() + }, + async { endpoint.accept().await.unwrap().await.unwrap() } + ); + endpoint.wait_idle().await; +} + /// Test that accessing stats from `Path` works as expected. #[tokio::test] async fn path_clone_stats_after_abandon() { @@ -1481,3 +1523,151 @@ async fn nat_traversal_wakes_connection_driver() -> TestResult { tokio::join!(server_task, client_task); Ok(()) } + +#[tokio::test] +async fn stream_drop_removes_blocked_reader() { + let _guard = subscribe(); + + for drop_stream in [false, true] { + let endpoint_factory = EndpointFactory::new(); + let server = endpoint_factory.endpoint("server"); + let server_address = server.local_addr().unwrap(); + let client = endpoint_factory.endpoint("client"); + + let server_task = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await.unwrap(); + let mut stream = conn.accept_uni().await.unwrap(); + + // read "hello" + let mut buf = [0u8; 5]; + stream.read_exact(&mut buf).await.unwrap(); + + let (waker, wake_counter) = new_count_waker(); + let mut cx = Context::from_waker(&waker); + // do a blocking read which will add the stream in conn.blocked_readers + { + let mut buf = [0u8; 64]; + let read_fut = stream.read(&mut buf); + tokio::pin!(read_fut); + assert!(matches!(read_fut.as_mut().poll(&mut cx), Poll::Pending)); + } + + if !drop_stream { + assert_eq!(wake_counter.wakes(), 0); + // We have a blocked reader, closing the connection should wake it. We use this as + // a proxy to assert that the stream is in conn.blocked_readers. + conn.close(0u32.into(), b"done"); + assert_eq!(wake_counter.wakes(), 1); + } else { + // dropping the stream should remove it from conn.blocked_readers, so we don't + // expect any wakeups + drop(stream); + assert_eq!(wake_counter.wakes(), 0, "no wakeups should have occurred"); + conn.close(0u32.into(), b"done"); + assert_eq!(wake_counter.wakes(), 0, "no wakeups should have occurred"); + } + }); + + let conn = client + .connect(server_address, "localhost") + .unwrap() + .await + .unwrap(); + let mut stream = conn.open_uni().await.unwrap(); + // need to send some data to actually start the stream + stream.write_all(b"hello").await.unwrap(); + + server_task.await.unwrap(); + } +} + +/// Test that dropping a `RecvStream` after cancelling a read and then +/// explicitly `stop`ing it doesn't panic. +#[tokio::test] +async fn recv_stream_cancel_stop_drop() { + let _guard = subscribe(); + let factory = EndpointFactory::new(); + let server = factory.endpoint("server"); + let server_addr = server.local_addr().unwrap(); + let client = factory.endpoint("client"); + let recv_dropped = tokio::sync::SetOnce::new(); + tokio::join!( + async { + let conn = server.accept().await.unwrap().await.unwrap(); + let mut recv = conn.accept_uni().await.unwrap(); + // Create a future to read from the stream, poll it once, then immediately drop it + { + let fut = pin!(recv.read_to_end(usize::MAX)); + let mut cx = Context::from_waker(Waker::noop()); + assert!(fut.poll(&mut cx).is_pending()); + } + recv_dropped.set(()).unwrap(); + recv.stop(0u32.into()).unwrap(); + }, + async { + let conn = client + .connect(server_addr, "localhost") + .unwrap() + .await + .unwrap(); + let mut send = conn.open_uni().await.unwrap(); + _ = send.write_all(b"hello").await; + // Don't drop (finish) the send stream until the read has been + // cancelled by the server, ensuring that read_to_end can't complete + // immediately. + recv_dropped.wait().await; + }, + ); +} + +#[derive(Default)] +struct WakeCounter { + wakes: AtomicUsize, +} + +impl WakeCounter { + fn wakes(&self) -> usize { + self.wakes.load(Ordering::SeqCst) + } +} + +fn new_count_waker() -> (Waker, Arc) { + // instance of WakeCounter + let counter = Arc::new(WakeCounter::default()); + + // convert + let waker = unsafe { Waker::from_raw(raw_waker(counter.clone())) }; + (waker, counter) +} + +fn raw_waker(counter: Arc) -> RawWaker { + // Store an Arc behind the raw pointer. + let ptr = Arc::into_raw(counter) as *const (); + RawWaker::new(ptr, &VTABLE) +} + +static VTABLE: RawWakerVTable = + RawWakerVTable::new(clone_waker, wake_waker, wake_by_ref_waker, drop_waker); + +unsafe fn clone_waker(data: *const ()) -> RawWaker { + let arc = unsafe { Arc::::from_raw(data as *const WakeCounter) }; + let cloned = arc.clone(); + std::mem::forget(arc); + raw_waker(cloned) +} + +unsafe fn wake_waker(data: *const ()) { + let arc = unsafe { Arc::::from_raw(data as *const WakeCounter) }; + arc.wakes.fetch_add(1, Ordering::SeqCst); + // arc drops here +} + +unsafe fn wake_by_ref_waker(data: *const ()) { + let arc = unsafe { Arc::::from_raw(data as *const WakeCounter) }; + arc.wakes.fetch_add(1, Ordering::SeqCst); + std::mem::forget(arc); +} + +unsafe fn drop_waker(data: *const ()) { + drop(unsafe { Arc::::from_raw(data as *const WakeCounter) }); +} diff --git a/perf/src/client.rs b/perf/src/client.rs index 999998156..c0a3c24aa 100644 --- a/perf/src/client.rs +++ b/perf/src/client.rs @@ -1,5 +1,5 @@ #[cfg(feature = "json-output")] -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, sync::Arc, @@ -166,15 +166,22 @@ pub async fn run(opt: Opt) -> Result<()> { let stats_fut = async { let interval_duration = Duration::from_secs(opt.interval); + #[cfg(feature = "json-output")] + let allow_table_output = opt.json.clone().is_none_or(|path| path != Path::new("-")); + #[cfg(not(feature = "json-output"))] + let allow_table_output = true; + loop { let start = Instant::now(); tokio::time::sleep(interval_duration).await; { stats.on_interval(start, &stream_stats); - stats.print(); - if opt.common.conn_stats { - println!("{:?}\n", connection.stats()); + if allow_table_output { + stats.print(); + if opt.common.conn_stats { + println!("{:?}\n", connection.stats()); + } } } } diff --git a/perf/src/stats.rs b/perf/src/stats.rs index e21d784f9..0234cfbd1 100644 --- a/perf/src/stats.rs +++ b/perf/src/stats.rs @@ -120,12 +120,11 @@ impl Stats { #[cfg(feature = "json-output")] pub fn print_json(&self, path: &Path) -> io::Result<()> { - match path { - path if path == Path::new("-") => json::print(self, std::io::stdout()), - _ => { - let file = File::create(path)?; - json::print(self, file) - } + if path == Path::new("-") { + json::print(self, std::io::stdout()); + } else { + let file = File::create(path)?; + json::print(self, file) } Ok(()) }