diff --git a/bench/src/bin/bulk.rs b/bench/src/bin/bulk.rs index 0c268990b..605fd6a68 100644 --- a/bench/src/bin/bulk.rs +++ b/bench/src/bin/bulk.rs @@ -162,7 +162,7 @@ async fn client( // to `Arc`ing them connection.close(0u32.into(), b"Benchmark done"); - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; if opt.stats { println!("\nClient connection stats:\n{:#?}", connection.stats()); diff --git a/noq-proto/src/connection/mod.rs b/noq-proto/src/connection/mod.rs index 691e6451e..1d356d72c 100644 --- a/noq-proto/src/connection/mod.rs +++ b/noq-proto/src/connection/mod.rs @@ -2386,7 +2386,10 @@ impl Connection { match timer { Timer::Conn(timer) => match timer { ConnTimer::Close => { - self.state.move_to_drained(None); + let was_draining = self.state.move_to_drained(None); + if !was_draining { + self.endpoint_events.push_back(EndpointEventInner::Draining); + } // move_to_drained checks that we weren't in drained before. // Adding events to endpoint_events is only legal if `Drained` was never queued before. self.endpoint_events.push_back(EndpointEventInner::Drained); @@ -4370,7 +4373,10 @@ impl Connection { code: TransportErrorCode::AEAD_LIMIT_REACHED, .. }) => { - self.state.move_to_drained(Some(conn_err)); + let was_draining = self.state.move_to_drained(Some(conn_err)); + if !was_draining { + self.endpoint_events.push_back(EndpointEventInner::Draining); + } } ConnectionError::TimedOut => { unreachable!("timeouts aren't generated by packet processing"); @@ -4381,6 +4387,7 @@ impl Connection { } ConnectionError::VersionMismatch => { self.state.move_to_draining(Some(conn_err)); + self.endpoint_events.push_back(EndpointEventInner::Draining); } ConnectionError::LocallyClosed => { unreachable!("LocallyClosed isn't generated by packet processing"); @@ -4491,6 +4498,8 @@ impl Connection { continue; }; + trace!(?frame, "processing frame in closed state"); + self.path_stats .for_path(path_id) .frame_rx @@ -4498,6 +4507,7 @@ impl Connection { if let Frame::Close(_error) = frame { self.state.move_to_draining(None); + self.endpoint_events.push_back(EndpointEventInner::Draining); break; } } @@ -4827,6 +4837,7 @@ impl Connection { } Frame::Close(reason) => { self.state.move_to_draining(Some(reason.into())); + self.endpoint_events.push_back(EndpointEventInner::Draining); return Ok(()); } _ => { @@ -5578,6 +5589,7 @@ impl Connection { if let Some(reason) = close { self.state.move_to_draining(Some(reason.into())); + self.endpoint_events.push_back(EndpointEventInner::Draining); self.connection_close_pending = true; } @@ -6946,7 +6958,10 @@ impl Connection { /// Terminate the connection instantly, without sending a close packet fn kill(&mut self, reason: ConnectionError) { self.close_common(); - self.state.move_to_drained(Some(reason)); + let was_draining = self.state.move_to_drained(Some(reason)); + if !was_draining { + self.endpoint_events.push_back(EndpointEventInner::Draining); + } // move_to_drained checks that we were never in drained before, so we // never sent a `Drained` event before (it's illegal to send more events after drained). self.endpoint_events.push_back(EndpointEventInner::Drained); diff --git a/noq-proto/src/connection/state.rs b/noq-proto/src/connection/state.rs index 3e6ce2975..06b593711 100644 --- a/noq-proto/src/connection/state.rs +++ b/noq-proto/src/connection/state.rs @@ -66,14 +66,20 @@ impl State { /// Moves to the drained state. /// /// Panics if the state was already drained. - pub(super) fn move_to_drained(&mut self, error: Option) { - let (error, is_local) = if let Some(error) = error { - (Some(error), false) + /// + /// Returns whether we were in the draining state before. + pub(super) fn move_to_drained(&mut self, error: Option) -> bool { + let (error, is_local, was_draining) = if let Some(error) = error { + ( + Some(error), + false, + matches!(self.inner, InnerState::Draining { .. }), + ) } else { - let error = match &mut self.inner { - InnerState::Draining { error, .. } => error.take(), + let (error, was_draining) = match &mut self.inner { + InnerState::Draining { error, .. } => (error.take(), true), InnerState::Drained { .. } => panic!("invalid state transition drained -> drained"), - InnerState::Closed { error_read, .. } if *error_read => None, + InnerState::Closed { error_read, .. } if *error_read => (None, false), InnerState::Closed { remote_reason, .. } => { let error = match remote_reason.clone().into() { ConnectionError::ConnectionClosed(close) => { @@ -89,14 +95,15 @@ impl State { } e => e, }; - Some(error) + (Some(error), false) } - InnerState::Handshake(_) | InnerState::Established => None, + InnerState::Handshake(_) | InnerState::Established => (None, false), }; - (error, self.is_local_close()) + (error, self.is_local_close(), was_draining) }; self.inner = InnerState::Drained { error, is_local }; trace!("connection state: drained"); + was_draining } /// Moves to a draining state. diff --git a/noq-proto/src/endpoint.rs b/noq-proto/src/endpoint.rs index 17c7c9555..bb885d5b6 100644 --- a/noq-proto/src/endpoint.rs +++ b/noq-proto/src/endpoint.rs @@ -129,6 +129,9 @@ impl Endpoint { } } } + Draining => { + // Nothing to do. + } Drained => { if let Some(conn) = self.connections.try_remove(ch.0) { self.index.remove(&conn); diff --git a/noq-proto/src/shared.rs b/noq-proto/src/shared.rs index 4097af02d..703ee9cf7 100644 --- a/noq-proto/src/shared.rs +++ b/noq-proto/src/shared.rs @@ -48,10 +48,17 @@ impl EndpointEvent { pub fn is_drained(&self) -> bool { self.0 == EndpointEventInner::Drained } + + /// Whether this is the event is the event indicating the start of the draining period. + pub fn is_draining(&self) -> bool { + self.0 == EndpointEventInner::Draining + } } #[derive(Clone, Debug, Eq, PartialEq)] pub(crate) enum EndpointEventInner { + /// The connection started draining + Draining, /// The connection has been drained Drained, /// The connection has a new active reset token diff --git a/noq-proto/src/tests/mod.rs b/noq-proto/src/tests/mod.rs index 4858b83a3..e7b547f8a 100644 --- a/noq-proto/src/tests/mod.rs +++ b/noq-proto/src/tests/mod.rs @@ -4320,3 +4320,59 @@ fn regression_close_without_connection_event() { Some(Event::ConnectionLost { .. }) ); } + +/// Ensures that the draining delay for the server is exactly 0.5 RTT and 1 RTT for the client. +/// +/// The draining delay is the time between the connection being closed and the connection +/// entering the "draining" state (either on the same or on the other side). +/// +/// We expect the side that *receives* the CONNECTION_CLOSE to immediately enter the draining +/// state. However in absolute terms, it'll be delayed by 0.5 RTT (exactly the latency) compared +/// to when `connection.close()` was called. +/// On the side that called `connection.close()` we first enter the "closed" state, and only +/// enter the "draining" state once we *receive* a "reciprocal" CONNECTION_CLOSE from the other +/// side. In the normal case this will be exactly 1 RTT after calling `connection.close()` to +/// account for the latency of CONNECTION_CLOSE going one way and then coming back. +/// +/// The "draining" state from noq-proto is observed by noq to enable `wait_idle` waiting the +/// ideal amount of time before allowing us to close the socket. +#[test] +fn timely_graceful_close() { + const ONE_WAY_LATENCY: Duration = Duration::from_millis(100); + + let _guard = subscribe(); + let mut pair = Pair::default(); + pair.latency = ONE_WAY_LATENCY; + let mut pair = ConnPair::connect_with(pair, client_config()); + + let start = pair.time; + pair.close(Client, 0, b"done!"); + + assert!(!pair.is_draining(Client)); + assert!(!pair.is_draining(Server)); + + // The client now sends CONNECTION_CLOSE to the server and it processes it. + // When the server receives CONNECTION_CLOSE, it responds with one of its own + // and enters the draining state. + pair.drive_client(); + pair.advance_time(); + let now = pair.time; + pair.drive_server(); + + assert!(pair.is_draining(Server)); + let server_draining_delay = now.saturating_duration_since(start); + info!(?server_draining_delay); + assert_eq!(server_draining_delay, ONE_WAY_LATENCY); + + // The server has now sent a CONNECTION_CLOSE back in response and the client processes it. + // The client then enters the draining state once it processed the response. + // already drove server + pair.advance_time(); + let now = pair.time; + pair.drive_client(); + + assert!(pair.is_draining(Client)); + let client_draining_delay = now.saturating_duration_since(start); + info!(?client_draining_delay); + assert_eq!(client_draining_delay, ONE_WAY_LATENCY * 2); +} diff --git a/noq-proto/src/tests/util.rs b/noq-proto/src/tests/util.rs index b36514e64..ac1a2592d 100644 --- a/noq-proto/src/tests/util.rs +++ b/noq-proto/src/tests/util.rs @@ -812,6 +812,13 @@ impl ConnPair { let now = self.pair.time; self.conn_mut(side).handle_network_change(hint, now); } + + pub(super) fn is_draining(&self, side: Side) -> bool { + match side { + Client => self.client.draining_connections.contains(&self.client_ch), + Server => self.server.draining_connections.contains(&self.server_ch), + } + } } impl Default for Pair { @@ -829,7 +836,7 @@ pub(super) struct TestEndpoint { pub(super) inbound: VecDeque, pub(super) accepted: Option>, pub(super) connections: HashMap, - drained_connections: HashSet, + pub(super) draining_connections: HashSet, conn_events: HashMap>, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, @@ -872,7 +879,7 @@ impl TestEndpoint { inbound: VecDeque::new(), accepted: None, connections: HashMap::default(), - drained_connections: HashSet::default(), + draining_connections: HashSet::default(), conn_events: HashMap::default(), captured_packets: Vec::new(), capture_inbound_packets: false, @@ -975,8 +982,8 @@ impl TestEndpoint { } for (ch, event) in endpoint_events { - if event.is_drained() { - self.drained_connections.insert(ch); + if event.is_draining() { + self.draining_connections.insert(ch); } if let Some(event) = self.handle_event(ch, event) && let Some(conn) = self.connections.get_mut(&ch) @@ -1084,11 +1091,9 @@ pub(crate) fn subscribe() -> tracing::subscriber::DefaultGuard { .with_default_directive(tracing::Level::TRACE.into()) .from_env_lossy(), ) + .without_time() .with_line_number(true) .with_writer(|| TestWriter); - // tracing uses std::time to trace time, which panics in wasm. - #[cfg(all(target_family = "wasm", target_os = "unknown"))] - let builder = builder.without_time(); tracing::subscriber::set_default(builder.finish()) } diff --git a/noq/examples/client.rs b/noq/examples/client.rs index 2b3bc79c7..0afbca3c8 100644 --- a/noq/examples/client.rs +++ b/noq/examples/client.rs @@ -162,7 +162,7 @@ async fn run(options: Opt) -> Result<()> { conn.close(0u32.into(), b"done"); // Give the server a fair chance to receive the close packet - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; Ok(()) } diff --git a/noq/examples/connection.rs b/noq/examples/connection.rs index 958ca0a8c..ae9a2f07f 100644 --- a/noq/examples/connection.rs +++ b/noq/examples/connection.rs @@ -49,7 +49,7 @@ async fn main() -> Result<(), Box> { let _ = connection.accept_uni().await; // Make sure the server has a chance to clean up - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; Ok(()) } diff --git a/noq/examples/insecure_connection.rs b/noq/examples/insecure_connection.rs index 359a84294..d7fdc0a25 100644 --- a/noq/examples/insecure_connection.rs +++ b/noq/examples/insecure_connection.rs @@ -66,7 +66,7 @@ async fn run_client(server_addr: SocketAddr) -> Result<(), Box Result<(), Box> { ); // Make sure the server has a chance to clean up - client.wait_idle().await; + client.wait_all_draining().await; Ok(()) } diff --git a/noq/src/endpoint.rs b/noq/src/endpoint.rs index d4b20d63c..4caae8ea0 100644 --- a/noq/src/endpoint.rs +++ b/noq/src/endpoint.rs @@ -366,7 +366,31 @@ impl Endpoint { } } - /// Wait for all connections on the endpoint to be cleanly shut down + /// Waits for all connections on the endpoint to be cleanly shut down and drained. + /// + /// This is equivalent to [`wait_all_draining()`] with additionally waiting for the connections to be + /// drained. Please see its documentation for more information. + /// + /// Use `wait_idle()` in favor of `wait_all_draining()` if you care about waiting for the + /// [`Connection`] structs to be dropped. + /// + /// [`wait_all_draining()`]: Self::wait_all_draining + /// [`Connection`]: crate::Connection + pub async fn wait_idle(&self) { + loop { + { + let endpoint = &mut *self.inner.state.lock().unwrap(); + if endpoint.recv_state.connections.is_empty() { + break; + } + // Construct future while lock is held to avoid race + self.inner.shared.idle.notified() + } + .await; + } + } + + /// Waits for all connections on the endpoint to be ready for shutting down. /// /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify /// peers of recent connection closes, whereas exiting immediately could force them to wait out @@ -375,16 +399,23 @@ impl Endpoint { /// Does not proactively close existing connections or cause incoming connections to be /// rejected. Consider calling [`close()`] if that is desired. /// - /// [`close()`]: Endpoint::close - pub async fn wait_idle(&self) { + /// Unlike [`wait_idle()`], this doesn't wait for the full draining period, so it can't be + /// used to wait for all now-idle [`Connection`]s to be dropped. + /// + /// See also this section in the QUIC RFC: + /// + /// [`close()`]: Self::close + /// [`wait_idle()`]: Self::wait_idle + /// [`Connection`]: crate::Connection + pub async fn wait_all_draining(&self) { loop { { let endpoint = &mut *self.inner.state.lock().unwrap(); - if endpoint.recv_state.connections.is_empty() { + if endpoint.recv_state.connections.active_connections == 0 { break; } // Construct future while lock is held to avoid race - self.inner.shared.idle.notified() + self.inner.shared.all_draining.notified() } .await; } @@ -467,6 +498,7 @@ impl Drop for EndpointDriver { // Drop all outgoing channels, signaling the termination of the endpoint to the associated // connections. endpoint.recv_state.connections.senders.clear(); + endpoint.recv_state.connections.active_connections = 0; } } @@ -550,7 +582,17 @@ pub(crate) struct State { #[derive(Debug)] pub(crate) struct Shared { + /// Notifies subscribers of new incoming connections. + /// + /// This enables the `Endpoint::accept` API. incoming: Notify, + /// Notifies subscribers when *all* connections have entered the draining state. + /// + /// This powers the `Endpoint::wait_idle` API. + all_draining: Notify, + /// Notifies subscribesr when *all* connections have been dropped. + /// + /// This powers the `Endpoint::wait_drained` API. idle: Notify, /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: AtomicUsize, @@ -602,7 +644,12 @@ impl State { } }; - if event.is_drained() { + if event.is_draining() { + self.recv_state.connections.active_connections -= 1; + if self.recv_state.connections.active_connections == 0 { + shared.all_draining.notify_waiters(); + } + } else if event.is_drained() { self.recv_state.connections.senders.remove(&ch); if self.recv_state.connections.is_empty() { shared.idle.notify_waiters(); @@ -700,6 +747,16 @@ struct ConnectionSet { sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, /// Set if the endpoint has been manually closed close: Option<(VarInt, Bytes)>, + /// Counter for all active (non-draining/drained) connections. + /// + /// This is directly related to the QUIC connection states "Initial", "Handshake", + /// "Established", "Closed", "Draining" and "Drained" (see also `proto/src/connection/state.rs`). + /// + /// Any connection state that is not "Draining" or "Drained" is considered active. + /// + /// This counter is updated when new connections are added ([`ConnectionSet::insert`]) and when + /// a connection informs us about entering the draining state ([`State::handle_events`]). + active_connections: u64, } impl ConnectionSet { @@ -719,6 +776,9 @@ impl ConnectionSet { .unwrap(); } self.senders.insert(handle, send); + if self.close.is_none() { + self.active_connections += 1; + } Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime) } @@ -789,6 +849,7 @@ impl EndpointRef { Self(Arc::new(EndpointInner { shared: Shared { incoming: Notify::new(), + all_draining: Notify::new(), idle: Notify::new(), ref_count: AtomicUsize::new(0), }, @@ -864,6 +925,7 @@ impl RecvState { senders: FxHashMap::default(), sender, close: None, + active_connections: 0, }, incoming: VecDeque::new(), recv_buf: recv_buf.into(), diff --git a/noq/src/tests.rs b/noq/src/tests.rs index b4d19cbdc..8fbee7d90 100755 --- a/noq/src/tests.rs +++ b/noq/src/tests.rs @@ -411,7 +411,7 @@ async fn zero_rtt() { drop((stream, connection)); - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; } #[test] @@ -586,7 +586,7 @@ fn run_echo(args: EchoArgs) { tokio::spawn(echo(stream)); } }); - server.wait_idle().await; + server.wait_all_draining().await; }); info!("connecting from {} to {}", args.client_addr, server_addr); @@ -617,7 +617,7 @@ fn run_echo(args: EchoArgs) { assert_eq!(data[..], msg[..], "Data mismatch"); } new_conn.close(0u32.into(), b"done"); - client.wait_idle().await; + client.wait_all_draining().await; } .instrument(error_span!("client")), ); @@ -1260,7 +1260,7 @@ async fn dropped_connection_cleans_up() { }, async { endpoint.accept().await.unwrap().await.unwrap() } ); - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; } /// Test that accessing stats from `Path` works as expected. @@ -1480,7 +1480,7 @@ async fn close_path() -> TestResult { test_done_tx.send(()).expect("not dropped"); - server.wait_idle().await; + server.wait_all_draining().await; TestResult::Ok(()) } @@ -1526,7 +1526,7 @@ async fn close_path() -> TestResult { test_done_rx.await.expect("not dropped"); client.close(0u8.into(), b"test finished"); - client.wait_idle().await; + client.wait_all_draining().await; TestResult::Ok(()) } diff --git a/noq/tests/post_quantum.rs b/noq/tests/post_quantum.rs index 21896e5d6..cfa608f62 100644 --- a/noq/tests/post_quantum.rs +++ b/noq/tests/post_quantum.rs @@ -79,7 +79,7 @@ async fn check_post_quantum_key_exchange(min_mtu: u16) { let _ = connection.accept_uni().await; // Make sure the server has a chance to clean up - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; jh.await.unwrap(); } diff --git a/perf/src/client.rs b/perf/src/client.rs index 92c1318cf..3835c44a2 100644 --- a/perf/src/client.rs +++ b/perf/src/client.rs @@ -201,7 +201,7 @@ pub async fn run(opt: Opt) -> Result<()> { } } - endpoint.wait_idle().await; + endpoint.wait_all_draining().await; #[cfg(feature = "json-output")] if let Some(path) = opt.json {