diff --git a/src/quic/ack_handler.zig b/src/quic/ack_handler.zig index d95d8f2..78b5f6a 100644 --- a/src/quic/ack_handler.zig +++ b/src/quic/ack_handler.zig @@ -34,9 +34,6 @@ const MAX_PTO: i64 = 60_000_000_000; /// Caps backoff to ensure timely retransmission under extreme packet loss. const MAX_HANDSHAKE_PTO: i64 = 3_000_000_000; -/// Maximum number of packets tracked per ACK result. -const MAX_ACK_RESULT: usize = 256; - /// Maximum number of stream frame records per sent packet. /// Must be large enough to track all stream frames in a packet (e.g., many small 0-RTT streams). // Cap inline stream-frame records per sent packet. Each is 32 bytes and lives @@ -91,37 +88,51 @@ pub const SentPacket = struct { } }; -/// Fixed-capacity list of SentPackets for ACK results. +/// Heap-backed list of SentPackets for ACK results. +/// +/// Bulk benchmark rows can acknowledge or declare lost thousands of packets in +/// a single ACK/loss pass after ACK compression or loss bursts. A fixed small +/// result buffer silently dropped the excess packets, which left +/// bytes_in_flight and stream ack_offset accounting stale even though packets +/// had been removed from the sent-packet map. pub const SentPacketList = struct { - buf: [MAX_ACK_RESULT]SentPacket = undefined, - len: usize = 0, + items: std.ArrayListUnmanaged(SentPacket) = .empty, - pub fn append(self: *SentPacketList, item: SentPacket) void { - if (self.len < MAX_ACK_RESULT) { - self.buf[self.len] = item; - self.len += 1; - } + pub fn deinit(self: *SentPacketList, allocator: Allocator) void { + self.items.deinit(allocator); + } + + pub fn clearRetainingCapacity(self: *SentPacketList) void { + self.items.clearRetainingCapacity(); + } + + pub fn append(self: *SentPacketList, allocator: Allocator, item: SentPacket) !void { + try self.items.append(allocator, item); } pub fn constSlice(self: *const SentPacketList) []const SentPacket { - return self.buf[0..self.len]; + return self.items.items; + } + + pub fn count(self: *const SentPacketList) usize { + return self.items.items.len; } }; -/// Fixed-capacity list of u64 for tracking packet numbers. +/// Heap-backed list of u64 for tracking packet numbers. const PnList = struct { - buf: [MAX_ACK_RESULT]u64 = undefined, - len: usize = 0, + items: std.ArrayListUnmanaged(u64) = .empty, - pub fn append(self: *PnList, item: u64) void { - if (self.len < MAX_ACK_RESULT) { - self.buf[self.len] = item; - self.len += 1; - } + pub fn deinit(self: *PnList, allocator: Allocator) void { + self.items.deinit(allocator); + } + + pub fn append(self: *PnList, allocator: Allocator, item: u64) !void { + try self.items.append(allocator, item); } pub fn constSlice(self: *const PnList) []const u64 { - return self.buf[0..self.len]; + return self.items.items; } }; @@ -130,6 +141,17 @@ pub const AckResult = struct { acked: SentPacketList = .{}, lost: SentPacketList = .{}, persistent_congestion: bool = false, + + pub fn deinit(self: *AckResult, allocator: Allocator) void { + self.acked.deinit(allocator); + self.lost.deinit(allocator); + } + + pub fn reset(self: *AckResult) void { + self.acked.clearRetainingCapacity(); + self.lost.clearRetainingCapacity(); + self.persistent_congestion = false; + } }; /// Tracks sent packets and handles loss detection for a single packet number space. @@ -183,59 +205,58 @@ pub const SentPacketTracker = struct { now: i64, result: *AckResult, ) !void { - result.* = .{}; + result.reset(); if (self.largest_acked == null or largest_ack > self.largest_acked.?) { self.largest_acked = largest_ack; } - // Process the first ACK range: [largest_ack - first_ack_range, largest_ack] - { - const range_start = largest_ack -| first_ack_range; - var pn = range_start; - while (pn <= largest_ack) : (pn += 1) { - if (self.sent_packets.fetchSwapRemove(pn)) |kv| { - const pkt = kv.value; - if (pkt.ack_eliciting) { - self.ack_eliciting_in_flight -|= 1; - } - - if (pkt.pn == largest_ack) { - const send_delta = now - pkt.time_sent; - rtt_stats.updateRtt(send_delta, ack_delay_ns, true); + // ACK ranges can span hundreds of thousands of packet numbers after ACK + // compression. Iterate the packets still in flight instead of walking + // every packet number in the encoded ranges. + const first_range_start = largest_ack -| first_ack_range; + var i: usize = 0; + while (i < self.sent_packets.count()) { + const pn = self.sent_packets.keys()[i]; + var acked = pn >= first_range_start and pn <= largest_ack; + if (!acked) { + for (ack_ranges) |range| { + if (pn >= range.start and pn <= range.end) { + acked = true; + break; } - - result.acked.append(pkt); } - if (pn == largest_ack) break; } - } + if (!acked) { + i += 1; + continue; + } - // Process additional ACK ranges - for (ack_ranges) |range| { - var pn = range.start; - while (pn <= range.end) : (pn += 1) { - if (self.sent_packets.fetchSwapRemove(pn)) |kv| { - const pkt = kv.value; - if (pkt.ack_eliciting) { - self.ack_eliciting_in_flight -|= 1; - } - result.acked.append(pkt); - } - if (pn == range.end) break; + const kv = self.sent_packets.fetchSwapRemove(pn).?; + const pkt = kv.value; + if (pkt.ack_eliciting) { + self.ack_eliciting_in_flight -|= 1; } + + if (pkt.pn == largest_ack) { + const send_delta = now - pkt.time_sent; + rtt_stats.updateRtt(send_delta, ack_delay_ns, true); + } + + try result.acked.append(self.allocator, pkt); } // Detect lost packets - self.detectLostPackets(rtt_stats, now, result); + try self.detectLostPackets(rtt_stats, now, result); } - fn detectLostPackets(self: *SentPacketTracker, rtt_stats: *RttStats, now: i64, result: *AckResult) void { + fn detectLostPackets(self: *SentPacketTracker, rtt_stats: *RttStats, now: i64, result: *AckResult) !void { self.loss_time = null; const loss_delay = rtt_stats.lossDelay(); const lost_send_time = now - loss_delay; var to_remove: PnList = .{}; + defer to_remove.deinit(self.allocator); // Track earliest and latest send times of lost ack-eliciting packets // for persistent congestion detection (RFC 9002 §7.6.2) @@ -250,8 +271,8 @@ pub const SentPacketTracker = struct { } if (pkt.time_sent <= lost_send_time) { - result.lost.append(pkt); - to_remove.append(pkt.pn); + try result.lost.append(self.allocator, pkt); + try to_remove.append(self.allocator, pkt.pn); if (pkt.ack_eliciting) { if (earliest_lost_time == null or pkt.time_sent < earliest_lost_time.?) { earliest_lost_time = pkt.time_sent; @@ -266,8 +287,8 @@ pub const SentPacketTracker = struct { if (self.largest_acked.? >= PACKET_THRESHOLD and pkt.pn <= self.largest_acked.? - PACKET_THRESHOLD) { - result.lost.append(pkt); - to_remove.append(pkt.pn); + try result.lost.append(self.allocator, pkt); + try to_remove.append(self.allocator, pkt.pn); if (pkt.ack_eliciting) { if (earliest_lost_time == null or pkt.time_sent < earliest_lost_time.?) { earliest_lost_time = pkt.time_sent; @@ -695,10 +716,10 @@ pub const PacketHandler = struct { /// Run loss detection for a specific packet number space (called when loss_time fires). /// Returns the lost packets for congestion control processing. - pub fn detectLossesForSpace(self: *PacketHandler, level: EncLevel, now: i64, result: *AckResult) void { + pub fn detectLossesForSpace(self: *PacketHandler, level: EncLevel, now: i64, result: *AckResult) !void { const idx = @intFromEnum(level); - result.* = .{}; - self.sent[idx].detectLostPackets(&self.rtt_stats, now, result); + result.reset(); + try self.sent[idx].detectLostPackets(&self.rtt_stats, now, result); for (result.lost.constSlice()) |pkt| { if (pkt.in_flight) { self.bytes_in_flight -|= pkt.size; @@ -762,9 +783,10 @@ test "SentPacketTracker: basic send and ack" { const ack_time = now + 50_000_000; var result: AckResult = .{}; + defer result.deinit(testing.allocator); try tracker.onAckReceived(0, 0, &.{}, 0, &rtt_stats, ack_time, &result); - try testing.expectEqual(@as(usize, 1), result.acked.len); + try testing.expectEqual(@as(usize, 1), result.acked.count()); try testing.expectEqual(@as(u64, 0), result.acked.constSlice()[0].pn); try testing.expectEqual(@as(u32, 0), tracker.ack_eliciting_in_flight); try testing.expect(rtt_stats.has_measurement); @@ -782,6 +804,39 @@ test "ReceivedPacketTracker: immediate ACK on reorder" { try testing.expect(tracker.ack_queued); } +test "SentPacketTracker: large ACK range scans in-flight packets" { + var tracker = SentPacketTracker.init(testing.allocator); + defer tracker.deinit(); + + var rtt_stats = RttStats{}; + const now: i64 = 1_000_000_000; + + try tracker.onPacketSent(.{ + .pn = 10, + .time_sent = now, + .size = 1200, + .ack_eliciting = true, + .in_flight = true, + .enc_level = .application, + }); + try tracker.onPacketSent(.{ + .pn = 1_000_000, + .time_sent = now + 1_000, + .size = 1200, + .ack_eliciting = true, + .in_flight = true, + .enc_level = .application, + }); + + var result: AckResult = .{}; + defer result.deinit(testing.allocator); + try tracker.onAckReceived(1_000_000, 0, &.{}, 999_990, &rtt_stats, now + 50_000_000, &result); + + try testing.expectEqual(@as(usize, 2), result.acked.count()); + try testing.expectEqual(@as(usize, 0), tracker.sent_packets.count()); + try testing.expectEqual(@as(u32, 0), tracker.ack_eliciting_in_flight); +} + test "PacketHandler: integration" { var handler = PacketHandler.init(testing.allocator); defer handler.deinit(); @@ -804,6 +859,7 @@ test "PacketHandler: integration" { const ack_time = now + 50_000_000; var result: AckResult = .{}; + defer result.deinit(testing.allocator); try handler.onAckReceived(.initial, 0, 0, 3, &.{}, 0, ack_time, &result); try testing.expectEqual(@as(u64, 0), handler.bytes_in_flight); @@ -947,4 +1003,3 @@ test "NewReno: app_limited suppresses cwnd growth" { cc.onPacketAcked(1200, 300); try testing.expect(cc.congestion_window > after_ack); } - diff --git a/src/quic/connection.zig b/src/quic/connection.zig index 19d7948..e3f1059 100644 --- a/src/quic/connection.zig +++ b/src/quic/connection.zig @@ -1400,6 +1400,7 @@ pub const Connection = struct { self.cc.app_limited = self.pkt_handler.bytes_in_flight < self.cc.sendWindow(); var ack_result: ack_handler.AckResult = .{}; + defer ack_result.deinit(self.allocator); try self.pkt_handler.onAckReceived( enc_level, ack.largest_ack, @@ -1432,11 +1433,11 @@ pub const Connection = struct { for (pkt.getStreamFrames()) |sf| { if (stream_mod.isBidi(sf.stream_id)) { if (self.streams.getStream(sf.stream_id)) |s| { - s.send.onAck(sf.offset, sf.length); + try s.send.onAck(sf.offset, sf.length); } } else { if (self.streams.send_streams.get(sf.stream_id)) |s| { - s.onAck(sf.offset, sf.length); + try s.onAck(sf.offset, sf.length); } } } @@ -1514,7 +1515,7 @@ pub const Connection = struct { ql.metricsUpdated(now, rs.min_rtt, rs.smoothed_rtt, rs.latest_rtt, rs.rtt_var, self.cc.sendWindow(), self.pkt_handler.bytes_in_flight); } - self.maybeConfirmHandshake(enc_level, result.acked.len); + self.maybeConfirmHandshake(enc_level, result.acked.count()); // Update pacer self.pacer.setBandwidth(self.cc.sendWindow(), &self.pkt_handler.rtt_stats); @@ -1529,6 +1530,7 @@ pub const Connection = struct { self.cc.app_limited = self.pkt_handler.bytes_in_flight < self.cc.sendWindow(); var ack_result: ack_handler.AckResult = .{}; + defer ack_result.deinit(self.allocator); try self.pkt_handler.onAckReceived( enc_level, ack.largest_ack, @@ -1564,11 +1566,11 @@ pub const Connection = struct { for (pkt.getStreamFrames()) |sf| { if (stream_mod.isBidi(sf.stream_id)) { if (self.streams.getStream(sf.stream_id)) |s| { - s.send.onAck(sf.offset, sf.length); + try s.send.onAck(sf.offset, sf.length); } } else { if (self.streams.send_streams.get(sf.stream_id)) |s| { - s.onAck(sf.offset, sf.length); + try s.onAck(sf.offset, sf.length); } } } @@ -1667,7 +1669,7 @@ pub const Connection = struct { self.peer_ecn_ect1[space_idx] = ack.ecn_ect1; self.peer_ecn_ce[space_idx] = ack.ecn_ce; - self.maybeConfirmHandshake(enc_level, result.acked.len); + self.maybeConfirmHandshake(enc_level, result.acked.count()); // Update pacer self.pacer.setBandwidth(self.cc.sendWindow(), &self.pkt_handler.rtt_stats); @@ -1863,8 +1865,39 @@ pub const Connection = struct { self.streams.setMaxStreams(self.streams.max_bidi_streams, max); }, - .data_blocked => {}, - .stream_data_blocked => {}, + .data_blocked => |limit| { + // If the peer is blocked on connection credit, immediately + // re-advertise our current receive limit. MAX_DATA is + // idempotent, and blocked frames are the peer's explicit + // signal that a previous update might not have arrived. + self.queueFlowControlUpdates(); + const current = self.conn_flow_ctrl.base.receive_window; + if (current > limit) { + self.pending_frames.push(.{ .max_data = current }); + } + }, + .stream_data_blocked => |blocked| { + // Same logic for per-stream credit. For bidi streams, the + // receive side carries the limit we advertise to the peer. + self.queueFlowControlUpdates(); + if (self.streams.getStream(blocked.stream_id)) |s| { + const current = s.recv.receive_window; + if (current > blocked.limit) { + self.pending_frames.push(.{ .max_stream_data = .{ + .stream_id = blocked.stream_id, + .max = current, + } }); + } + } else if (self.streams.recv_streams.get(blocked.stream_id)) |s| { + const current = s.receive_window; + if (current > blocked.limit) { + self.pending_frames.push(.{ .max_stream_data = .{ + .stream_id = blocked.stream_id, + .max = current, + } }); + } + } + }, .streams_blocked_bidi => |val| { // RFC 9000 §19.14: STREAMS_BLOCKED must not exceed 2^60 if (val > (1 << 60)) { @@ -2950,13 +2983,20 @@ pub const Connection = struct { if (self.pkt_num_spaces[1].crypto_seal != null) { self.queueCryptoRetransmission(.handshake); } - // Reset stream send_offset for unACKed data + // Queue unacked stream bytes as retransmission ranges. Do not + // rewind send_offset: already-sent bytes must not be recounted + // against connection flow control when MAX_DATA credit is zero. var resend_it = self.streams.streams.valueIterator(); while (resend_it.next()) |s_ptr| { const s = s_ptr.*; if (s.send.hasUnackedData()) { - s.send.send_offset = s.send.ack_offset; - if (s.send.fin_queued) s.send.fin_sent = false; + const start = s.send.ack_offset; + const end = s.send.write_offset; + if (end > start) { + s.send.queueRetransmit(start, end - start, s.send.fin_queued); + } else if (s.send.fin_queued) { + s.send.queueRetransmit(end, 0, true); + } } } } @@ -3205,7 +3245,8 @@ pub const Connection = struct { // Loss timers don't increment pto_count — they run loss detection directly. if (self.pkt_handler.getExpiredLossTime(now)) |loss_level| { var loss_result: ack_handler.AckResult = .{}; - self.pkt_handler.detectLossesForSpace(loss_level, now, &loss_result); + defer loss_result.deinit(self.allocator); + try self.pkt_handler.detectLossesForSpace(loss_level, now, &loss_result); var has_non_probe_loss_lt = false; var earliest_lost_sent_time_lt: ?i64 = null; for (loss_result.lost.constSlice()) |pkt| { @@ -3322,29 +3363,22 @@ pub const Connection = struct { } } } - // Always scan in-flight packets for stream data to retransmit - // (RFC 9002 §6.2.4: prefer data over PING). This is critical - // under loss: after the packer consumes retransmission data, - // hasData() returns false, but the retransmission packet might - // still be in-flight (not yet ACKed/declared lost). Without this - // unconditional scan, the PTO sends PINGs instead of data, - // and the peer never receives the file. - // If no stream data is pending AND the stream has unsent - // data that was consumed by the packer but never ACKed, - // reset send_offset to write_offset - data_size to force - // retransmission. Only do this for small streams (multiconnect - // serves 1KB files) to avoid resending entire large transfers. - // If no pending data but there IS unACKed data, - // reset send_offset to ack_offset to retransmit - // only the unACKed portion. - if (!has_data) { + // PTO must rescue every stream with outstanding data, not only + // the connection as a whole. In multi-stream transfers, one + // stream can already have pending data while another has an + // unacked hole and no queued retransmission. A connection-level + // has_data guard would leave that second stream stalled. + { var resend_it = self.streams.streams.valueIterator(); while (resend_it.next()) |s_ptr| { const s = s_ptr.*; if (s.send.hasUnackedData()) { - s.send.send_offset = s.send.ack_offset; - if (s.send.fin_queued) { - s.send.fin_sent = false; + const start = s.send.ack_offset; + const end = s.send.write_offset; + if (end > start) { + s.send.queueRetransmit(start, end - start, s.send.fin_queued); + } else if (s.send.fin_queued) { + s.send.queueRetransmit(end, 0, true); } has_data = true; } diff --git a/src/quic/flow_control.zig b/src/quic/flow_control.zig index 1c9cc87..b518cb4 100644 --- a/src/quic/flow_control.zig +++ b/src/quic/flow_control.zig @@ -34,7 +34,7 @@ pub const BaseFlowController = struct { return .{ .receive_window = receive_window, .receive_window_size = receive_window, - .max_receive_window_size = max_receive_window, + .max_receive_window_size = @max(max_receive_window, receive_window), }; } @@ -65,13 +65,12 @@ pub const BaseFlowController = struct { // Check if we should send a BLOCKED frame and mark it as sent. // Returns the limit if a blocked frame should be sent, null otherwise. - // Only triggers once per limit to avoid duplicates. + // BLOCKED frames are advisory. Emitting once per blocked limit avoids + // packet storms while still re-arming when the peer raises credit. pub fn shouldSendBlocked(self: *BaseFlowController) ?u64 { - if (self.bytes_sent >= self.send_window) { - if (self.blocked_at == null or self.blocked_at.? != self.send_window) { - self.blocked_at = self.send_window; - return self.send_window; - } + if (self.bytes_sent >= self.send_window and self.blocked_at != self.send_window) { + self.blocked_at = self.send_window; + return self.send_window; } return null; } @@ -285,15 +284,15 @@ test "StreamFlowController: limited by connection" { try testing.expectEqual(@as(u64, 500), sfc.sendWindowSize()); } -test "BaseFlowController: shouldSendBlocked fires once per limit" { +test "BaseFlowController: shouldSendBlocked emits once per blocked limit" { var fc = BaseFlowController.init(1000, MAX_RECEIVE_WINDOW); fc.send_window = 100; fc.addBytesSent(100); // First call should return the limit try testing.expectEqual(@as(?u64, 100), fc.shouldSendBlocked()); - // Second call should return null (already sent for this limit) - try testing.expect(fc.shouldSendBlocked() == null); + // Repeated calls at the same blocked limit should not generate a packet storm. + try testing.expectEqual(@as(?u64, null), fc.shouldSendBlocked()); // After window update, should be able to send again fc.updateSendWindow(200); diff --git a/src/quic/frame.zig b/src/quic/frame.zig index 81ed678..fc4f78b 100644 --- a/src/quic/frame.zig +++ b/src/quic/frame.zig @@ -850,6 +850,82 @@ pub const PendingFrameQueue = struct { len: u8 = 0, pub fn push(self: *PendingFrameQueue, frame: PendingControlFrame) void { + var i: u8 = 0; + while (i < self.len) : (i += 1) { + switch (frame) { + .ping => switch (self.items[i]) { + .ping => return, + else => {}, + }, + .immediate_ack => switch (self.items[i]) { + .immediate_ack => return, + else => {}, + }, + .max_data => |new_max| switch (self.items[i]) { + .max_data => |old_max| { + self.items[i] = .{ .max_data = @max(old_max, new_max) }; + return; + }, + else => {}, + }, + .max_stream_data => |new_msd| switch (self.items[i]) { + .max_stream_data => |old_msd| if (old_msd.stream_id == new_msd.stream_id) { + self.items[i] = .{ .max_stream_data = .{ + .stream_id = new_msd.stream_id, + .max = @max(old_msd.max, new_msd.max), + } }; + return; + }, + else => {}, + }, + .max_streams_bidi => |new_max| switch (self.items[i]) { + .max_streams_bidi => |old_max| { + self.items[i] = .{ .max_streams_bidi = @max(old_max, new_max) }; + return; + }, + else => {}, + }, + .max_streams_uni => |new_max| switch (self.items[i]) { + .max_streams_uni => |old_max| { + self.items[i] = .{ .max_streams_uni = @max(old_max, new_max) }; + return; + }, + else => {}, + }, + .data_blocked => |new_limit| switch (self.items[i]) { + .data_blocked => |old_limit| { + self.items[i] = .{ .data_blocked = @max(old_limit, new_limit) }; + return; + }, + else => {}, + }, + .stream_data_blocked => |new_sdb| switch (self.items[i]) { + .stream_data_blocked => |old_sdb| if (old_sdb.stream_id == new_sdb.stream_id) { + self.items[i] = .{ .stream_data_blocked = .{ + .stream_id = new_sdb.stream_id, + .limit = @max(old_sdb.limit, new_sdb.limit), + } }; + return; + }, + else => {}, + }, + .streams_blocked_bidi => |new_limit| switch (self.items[i]) { + .streams_blocked_bidi => |old_limit| { + self.items[i] = .{ .streams_blocked_bidi = @max(old_limit, new_limit) }; + return; + }, + else => {}, + }, + .streams_blocked_uni => |new_limit| switch (self.items[i]) { + .streams_blocked_uni => |old_limit| { + self.items[i] = .{ .streams_blocked_uni = @max(old_limit, new_limit) }; + return; + }, + else => {}, + }, + else => {}, + } + } if (self.len < capacity) { self.items[self.len] = frame; self.len += 1; diff --git a/src/quic/packet_packer.zig b/src/quic/packet_packer.zig index f3ac1f0..a89e3a1 100644 --- a/src/quic/packet_packer.zig +++ b/src/quic/packet_packer.zig @@ -332,11 +332,13 @@ pub const PacketPacker = struct { // 0-RTT packets only contain STREAM and DATAGRAM frames — skip ACK, CRYPTO, control if (!zero_rtt) { // 1. ACK frame (always first if pending) - // Force ACK generation whenever there are unacknowledged ack-eliciting packets. - // In ack_only mode (congestion-limited), prompt ACKs are critical for the peer's - // CC to grow its window. Delaying ACKs starves the peer of feedback. + // Force ACK generation only when it can piggyback on non-ACK data. + // ACK-only packets bypass congestion control; forcing one on every + // congestion-limited poll can create an ACK-only send loop that + // never drains bytes_in_flight or returns quiescent state to the + // application. In ack_only mode, honor the normal ACK queue/alarm. const ack_delay_exp: u64 = 3; - const ack_frame_opt: ?Frame = if (pkt_handler.hasUnackedAckEliciting(level)) + const ack_frame_opt: ?Frame = if (!ack_only and pkt_handler.hasUnackedAckEliciting(level)) pkt_handler.getAckFrameForced(level, now, ack_delay_exp) else pkt_handler.getAckFrame(level, now, ack_delay_exp); @@ -411,16 +413,19 @@ pub const PacketPacker = struct { // 4. Pending control frames (only in 1-RTT) // PATH_CHALLENGE/PATH_RESPONSE are always sent (path probing is exempt from CC). - // Other control frames are skipped when ack_only (congestion-limited). + // Receiver-credit frames must also be allowed through the congestion-limited + // path: otherwise a small control packet can fill cwnd and prevent MAX_DATA + // from unblocking a peer that is already stuck at the connection window. if (level == .application) { var remaining = pending_frames.len; while (remaining > 0) : (remaining -= 1) { const pcf = pending_frames.pop() orelse break; - const is_path_probing = switch (pcf) { + const is_urgent_control = switch (pcf) { .path_challenge, .path_response => true, + .max_data, .max_stream_data, .max_streams_bidi, .max_streams_uni => true, else => false, }; - if (is_path_probing or !ack_only) { + if (is_urgent_control or !ack_only) { try pcf.write(writer); ack_eliciting = true; } else { @@ -445,12 +450,16 @@ pub const PacketPacker = struct { const sched_count = streams.getScheduledStreams(&sched_buf); for (sched_buf[0..sched_count]) |s| { if (fbs.seek + AEAD_TAG_LEN + 16 >= effective_max) break; - if (conn_budget == 0) break; + const retransmitting = s.send.hasRetransmitData(); + if (conn_budget == 0 and !retransmitting) continue; if (stream_frame_info_count >= ack_handler.MAX_STREAM_FRAMES_PER_PACKET) break; const remaining = effective_max - fbs.seek - AEAD_TAG_LEN; const header_overhead = streamFrameHeaderOverhead(s.send.stream_id, s.send.send_offset, remaining); if (remaining <= header_overhead) break; - const max_stream_data = @min(remaining - header_overhead, conn_budget); + const max_stream_data = if (retransmitting) + remaining - header_overhead + else + @min(remaining - header_overhead, conn_budget); const prev_send_offset = s.send.send_offset; if (s.send.popStreamFrame(max_stream_data)) |stream_frame| { try stream_frame.write(writer); @@ -478,12 +487,16 @@ pub const PacketPacker = struct { if (uni_sched_count > 0) { for (uni_sched_buf[0..uni_sched_count]) |s| { if (fbs.seek + AEAD_TAG_LEN + 16 >= effective_max) break; - if (conn_budget == 0) break; + const retransmitting = s.hasRetransmitData(); + if (conn_budget == 0 and !retransmitting) continue; if (stream_frame_info_count >= ack_handler.MAX_STREAM_FRAMES_PER_PACKET) break; const remaining_uni = effective_max - fbs.seek - AEAD_TAG_LEN; const uni_header_overhead = streamFrameHeaderOverhead(s.stream_id, s.send_offset, remaining_uni); if (remaining_uni <= uni_header_overhead) break; - const max_stream_data = @min(remaining_uni - uni_header_overhead, conn_budget); + const max_stream_data = if (retransmitting) + remaining_uni - uni_header_overhead + else + @min(remaining_uni - uni_header_overhead, conn_budget); const prev_uni_offset = s.send_offset; if (s.popStreamFrame(max_stream_data)) |stream_frame| { try stream_frame.write(writer); @@ -889,7 +902,7 @@ test "PacketPacker: no data produces no packet" { try testing.expectEqual(@as(usize, 0), written); } -test "PacketPacker: ack_only skips stream data" { +test "PacketPacker: ack_only skips stream data but sends receiver credit" { const scid = &[_]u8{0x01}; const dcid = &[_]u8{ 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08 }; var packer = PacketPacker.init(testing.allocator, false, scid, dcid, 0x00000001); @@ -904,6 +917,7 @@ test "PacketPacker: ack_only skips stream data" { defer streams.deinit(); var pending_frames = frame_mod.PendingFrameQueue{}; + pending_frames.push(.{ .max_data = 12345 }); // Create a stream with data streams.setMaxStreams(10, 10); @@ -933,13 +947,56 @@ test "PacketPacker: ack_only skips stream data" { true, // ack_only = true ); - // Should produce an ACK-only packet + // Should produce a packet carrying ACK plus receiver-credit control. try testing.expect(written > 0); + try testing.expectEqual(@as(u8, 0), pending_frames.len); // Stream data should still be waiting (not consumed) try testing.expect(s.send.hasData()); } +test "PacketPacker: ack_only does not force sub-threshold ACKs" { + const scid = &[_]u8{0x01}; + const dcid = &[_]u8{ 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08 }; + var packer = PacketPacker.init(testing.allocator, false, scid, dcid, 0x00000001); + + var pkt_handler = ack_handler.PacketHandler.init(testing.allocator); + defer pkt_handler.deinit(); + + var crypto_mgr = crypto_stream.CryptoStreamManager.init(testing.allocator); + defer crypto_mgr.deinit(); + + var streams = stream_mod.StreamsMap.init(testing.allocator, false); + defer streams.deinit(); + + var pending_frames = frame_mod.PendingFrameQueue{}; + + // One ack-eliciting packet arms a delayed ACK, but it is below the immediate + // ACK threshold. ACK-only packing must not force an immediate packet. + try pkt_handler.recv[2].onPacketReceived(0, true, 1000, 0); + + const keys = try testClientKeys(); + + var out_buf: [1500]u8 = undefined; + const written = try packer.packCoalesced( + &out_buf, + &pkt_handler, + &crypto_mgr, + &streams, + &pending_frames, + null, + null, + null, + keys.seal, + 1000, + null, + true, + ); + + try testing.expectEqual(@as(usize, 0), written); + try testing.expectEqual(@as(u64, 0), pkt_handler.next_pn[2]); +} + test "PacketPacker: coalesced Initial + Handshake" { const scid = &[_]u8{0x01}; const dcid = &[_]u8{ 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08 }; diff --git a/src/quic/stream.zig b/src/quic/stream.zig index 7545ef9..6c86758 100644 --- a/src/quic/stream.zig +++ b/src/quic/stream.zig @@ -4,6 +4,7 @@ const testing = std.testing; const flow_control = @import("flow_control.zig"); const Frame = @import("frame.zig").Frame; +const ranges = @import("ranges.zig"); /// Stream ID encoding per RFC 9000 Section 2.1: /// Bit 0: initiator (0 = client, 1 = server) @@ -104,17 +105,62 @@ pub const FrameSorter = struct { effective_offset = self.read_pos; } - // Check if there's already a chunk at this offset. - // Don't overwrite a longer chunk with a shorter one (retransmission - // with different fragmentation boundaries). Also free old data to - // prevent memory leaks. - if (self.chunks.get(effective_offset)) |existing| { - if (existing.len >= effective_data.len) { - // Existing chunk covers at least as much data — skip. - return; + // Fast path for the dominant receive case: new STREAM data appends at + // or beyond the highest byte ever buffered. It cannot overlap an + // existing chunk, so avoid scanning the full chunk map for every packet. + if (effective_offset >= self.highest_buffered) { + const owned = try self.allocator.dupe(u8, effective_data); + errdefer self.allocator.free(owned); + try self.chunks.put(self.allocator, effective_offset, owned); + self.highest_buffered = effective_offset + owned.len; + return; + } + + while (true) { + const new_start = effective_offset; + const new_end = effective_offset + effective_data.len; + var changed = false; + var i: usize = 0; + while (i < self.chunks.count()) { + const existing_offset = self.chunks.keys()[i]; + const existing = self.chunks.values()[i]; + const existing_end = existing_offset + existing.len; + if (existing_end <= new_start or existing_offset >= new_end) { + i += 1; + continue; + } + + if (existing_offset <= new_start and existing_end >= new_end) { + return; + } + + if (existing_offset <= new_start and existing_end > new_start) { + const skip: usize = @intCast(existing_end - new_start); + effective_offset = existing_end; + effective_data = effective_data[skip..]; + if (effective_data.len == 0) return; + changed = true; + break; + } + + if (existing_offset < new_end and existing_end > new_end) { + const suffix_start: usize = @intCast(new_end - existing_offset); + const suffix = existing[suffix_start..]; + const owned_suffix = try self.allocator.dupe(u8, suffix); + errdefer self.allocator.free(owned_suffix); + _ = self.chunks.swapRemove(existing_offset); + self.allocator.free(existing); + try self.chunks.put(self.allocator, new_end, owned_suffix); + changed = true; + break; + } + + _ = self.chunks.swapRemove(existing_offset); + self.allocator.free(existing); + changed = true; + break; } - // New chunk is longer — free old, overwrite below. - self.allocator.free(existing); + if (!changed) break; } // Copy data to owned buffer @@ -130,10 +176,39 @@ pub const FrameSorter = struct { /// Pop the next contiguous chunk of data from the read position. /// Returns null if there's no data available at the current read position. pub fn pop(self: *FrameSorter) ?[]const u8 { - if (self.chunks.get(self.read_pos)) |data| { - _ = self.chunks.orderedRemove(self.read_pos); - self.read_pos += data.len; - return data; + if (self.chunks.fetchSwapRemove(self.read_pos)) |entry| { + self.read_pos += entry.value.len; + return entry.value; + } + + var best_index: ?usize = null; + var best_end: u64 = 0; + for (self.chunks.keys(), 0..) |offset, index| { + const data = self.chunks.values()[index]; + const end = offset + data.len; + if (offset <= self.read_pos and self.read_pos < end and end > best_end) { + best_index = index; + best_end = end; + } + } + if (best_index) |index| { + const offset = self.chunks.keys()[index]; + const data = self.chunks.values()[index]; + _ = self.chunks.swapRemove(offset); + const skip: usize = @intCast(self.read_pos - offset); + const readable = data[skip..]; + const owned = if (skip == 0) + data + else blk: { + const copy = self.allocator.dupe(u8, readable) catch { + self.allocator.free(data); + return null; + }; + self.allocator.free(data); + break :blk copy; + }; + self.read_pos += readable.len; + return owned; } return null; } @@ -283,6 +358,11 @@ pub const SendStream = struct { /// ack_offset and write_offset may not have been received. ack_offset: u64 = 0, + /// STREAM frame ACKs can arrive out of order because packet ACK ranges are + /// not equivalent to contiguous stream-byte delivery. Track acknowledged + /// byte ranges and only advance ack_offset through a contiguous prefix. + acked_ranges: ranges.RangeSet, + /// Maximum data the peer allows us to send on this stream. send_window: u64 = std.math.maxInt(u64), @@ -293,7 +373,7 @@ pub const SendStream = struct { urgency: u3 = 3, /// RFC 9218 priority: incremental streams are interleaved round-robin. - incremental: bool = false, + incremental: bool = true, /// WebTransport sendOrder: higher values transmitted first. When set, /// takes precedence over RFC 9218 urgency for scheduling. @@ -323,10 +403,12 @@ pub const SendStream = struct { .stream_id = stream_id, .allocator = allocator, .write_buffer = .{ .items = &.{}, .capacity = 0 }, + .acked_ranges = ranges.RangeSet.init(allocator), }; } pub fn deinit(self: *SendStream) void { + self.acked_ranges.deinit(); self.write_buffer.deinit(self.allocator); } @@ -353,11 +435,32 @@ pub const SendStream = struct { } /// Update the acknowledged offset when a packet carrying stream frames is ACKed. - pub fn onAck(self: *SendStream, offset: u64, length: u64) void { + pub fn onAck(self: *SendStream, offset: u64, length: u64) !void { + if (length == 0) { + return; + } const end = offset + length; - if (end > self.ack_offset) { - self.ack_offset = end; + if (end <= self.ack_offset) { + return; + } + + try self.acked_ranges.addRange(@max(offset, self.ack_offset), end - 1); + while (true) { + var advanced = false; + for (self.acked_ranges.getRanges()) |range| { + if (range.start <= self.ack_offset and self.ack_offset <= range.end) { + self.ack_offset = range.end + 1; + self.acked_ranges.removeBelow(self.ack_offset); + advanced = true; + break; + } + } + if (!advanced) { + break; + } } + self.trimRetransmitRangesBelow(self.ack_offset); + self.send_offset = @max(self.send_offset, self.ack_offset); } /// Cancel the stream with an error code (sends RESET_STREAM). @@ -374,13 +477,12 @@ pub const SendStream = struct { } // Check if we should send STREAM_DATA_BLOCKED. Returns the limit if yes. - // Only triggers once per limit to avoid duplicates. + // STREAM_DATA_BLOCKED is advisory; emit once per blocked limit and re-arm + // when MAX_STREAM_DATA advances the send window. pub fn shouldSendBlocked(self: *SendStream) ?u64 { - if (self.send_offset >= self.send_window and self.hasData()) { - if (self.blocked_at == null or self.blocked_at.? != self.send_window) { - self.blocked_at = self.send_window; - return self.send_window; - } + if (self.send_offset >= self.send_window and self.hasData() and self.blocked_at != self.send_window) { + self.blocked_at = self.send_window; + return self.send_window; } return null; } @@ -424,18 +526,24 @@ pub const SendStream = struct { }; self.retransmit_count += 1; } else { - // Queue overflow: fall back to resending from the earliest lost offset. - // Find minimum offset across all queued ranges and the new range, - // then reset send_offset so the packer resends everything from there. - // The receiver's FrameSorter deduplicates any already-received data. + // Queue overflow: coalesce to one broad retransmit range. Do not + // rewind send_offset: retransmitted bytes have already consumed + // connection-level flow-control credit and must stay on the + // retransmit path. var min_offset = offset; + var max_end = offset + length; var has_fin = fin; for (self.retransmit_ranges[0..self.retransmit_count]) |r| { min_offset = @min(min_offset, r.offset); + max_end = @max(max_end, r.offset + r.length); if (r.fin) has_fin = true; } - self.send_offset = @min(self.send_offset, min_offset); - self.retransmit_count = 0; + self.retransmit_ranges[0] = .{ + .offset = min_offset, + .length = max_end - min_offset, + .fin = has_fin, + }; + self.retransmit_count = 1; if (has_fin) { self.fin_lost = true; self.fin_sent = false; @@ -451,6 +559,13 @@ pub const SendStream = struct { (self.fin_queued and !self.fin_sent); } + /// Check if the next send is retransmission-only data. Retransmitted bytes + /// were already counted against connection flow control when first sent, so + /// packet assembly must not block them behind exhausted MAX_DATA credit. + pub fn hasRetransmitData(self: *const SendStream) bool { + return self.retransmit_count > 0 or self.fin_lost; + } + /// Check if there's data that has been sent but not yet acknowledged. /// Used by PTO to determine if retransmission is needed. pub fn hasUnackedData(self: *const SendStream) bool { @@ -589,6 +704,22 @@ pub const SendStream = struct { } self.retransmit_count -= 1; } + + fn trimRetransmitRangesBelow(self: *SendStream, acked_offset: u64) void { + var i: usize = 0; + while (i < self.retransmit_count) { + const end = self.retransmit_ranges[i].offset + self.retransmit_ranges[i].length; + if (end <= acked_offset) { + self.removeRetransmitRange(i); + continue; + } + if (self.retransmit_ranges[i].offset < acked_offset) { + self.retransmit_ranges[i].length = end - acked_offset; + self.retransmit_ranges[i].offset = acked_offset; + } + i += 1; + } + } }; /// A bidirectional QUIC stream combining send and receive. @@ -1164,6 +1295,58 @@ test "FrameSorter: out-of-order data" { testing.allocator.free(chunk2.?); } +test "FrameSorter: sequential append fast path remains readable" { + var sorter = FrameSorter.init(testing.allocator); + defer sorter.deinit(); + + try sorter.push(0, "hello", false); + try sorter.push(5, " ", false); + try sorter.push(6, "world", true); + + const chunk1 = sorter.pop(); + try testing.expect(chunk1 != null); + try testing.expectEqualStrings("hello", chunk1.?); + testing.allocator.free(chunk1.?); + + const chunk2 = sorter.pop(); + try testing.expect(chunk2 != null); + try testing.expectEqualStrings(" ", chunk2.?); + testing.allocator.free(chunk2.?); + + const chunk3 = sorter.pop(); + try testing.expect(chunk3 != null); + try testing.expectEqualStrings("world", chunk3.?); + testing.allocator.free(chunk3.?); + + try testing.expect(sorter.isComplete()); +} + +test "FrameSorter: out-of-order gap still accepts sequential tail" { + var sorter = FrameSorter.init(testing.allocator); + defer sorter.deinit(); + + try sorter.push(6, "world", true); + try sorter.push(0, "hello", false); + try sorter.push(5, " ", false); + + const chunk1 = sorter.pop(); + try testing.expect(chunk1 != null); + try testing.expectEqualStrings("hello", chunk1.?); + testing.allocator.free(chunk1.?); + + const chunk2 = sorter.pop(); + try testing.expect(chunk2 != null); + try testing.expectEqualStrings(" ", chunk2.?); + testing.allocator.free(chunk2.?); + + const chunk3 = sorter.pop(); + try testing.expect(chunk3 != null); + try testing.expectEqualStrings("world", chunk3.?); + testing.allocator.free(chunk3.?); + + try testing.expect(sorter.isComplete()); +} + // RFC 9000 §4.5: final size validation test "FrameSorter: conflicting final size from FIN" { var sorter = FrameSorter.init(testing.allocator); @@ -1670,8 +1853,8 @@ test "SendStream: partial retransmit due to max_len" { } // Retransmit queue overflow: when MAX_RETRANSMIT_RANGES is exceeded, -// send_offset is lowered to cover all lost data (no silent data loss). -test "SendStream: retransmit queue overflow falls back to send_offset" { +// lost data is coalesced without rewinding send_offset. +test "SendStream: retransmit queue overflow coalesces ranges" { var ss = SendStream.init(testing.allocator, 0); defer ss.deinit(); @@ -1694,10 +1877,70 @@ test "SendStream: retransmit queue overflow falls back to send_offset" { // Queue one more — should trigger overflow fallback ss.queueRetransmit(1700, 50, false); - // After overflow: retransmit queue is cleared, send_offset lowered to earliest + // After overflow: retransmit queue is coalesced, while send_offset stays at + // the real high-water mark so flow-control credit is not double-counted. + try testing.expectEqual(@as(u8, 1), ss.retransmit_count); + try testing.expectEqual(@as(u64, 0), ss.retransmit_ranges[0].offset); + try testing.expectEqual(@as(u64, 1750), ss.retransmit_ranges[0].length); + try testing.expectEqual(@as(u64, 2048), ss.send_offset); + try testing.expect(ss.hasRetransmitData()); +} + +test "SendStream: contiguous ACK advances send offset after retransmit" { + var ss = SendStream.init(testing.allocator, 0); + defer ss.deinit(); + + const data = "x" ** 100; + try ss.writeData(data); + ss.send_offset = 20; + + ss.queueRetransmit(20, 80, false); + const retransmit = ss.popStreamFrame(100).?; + try testing.expectEqual(@as(u64, 20), retransmit.stream.offset); + try testing.expectEqual(@as(u64, 80), retransmit.stream.length); + try testing.expectEqual(@as(u64, 20), ss.send_offset); + + try ss.onAck(0, 100); + + try testing.expectEqual(@as(u64, 100), ss.ack_offset); + try testing.expectEqual(@as(u64, 100), ss.send_offset); + try testing.expect(!ss.hasData()); + try testing.expect(ss.popStreamFrame(100) == null); +} + +test "SendStream: ACK progress trims stale retransmit ranges" { + var ss = SendStream.init(testing.allocator, 0); + defer ss.deinit(); + + const data = "x" ** 100; + try ss.writeData(data); + ss.send_offset = 100; + ss.queueRetransmit(0, 80, false); + + try ss.onAck(0, 32); + + try testing.expectEqual(@as(u8, 1), ss.retransmit_count); + try testing.expectEqual(@as(u64, 32), ss.retransmit_ranges[0].offset); + try testing.expectEqual(@as(u64, 48), ss.retransmit_ranges[0].length); + + try ss.onAck(32, 48); try testing.expectEqual(@as(u8, 0), ss.retransmit_count); - try testing.expectEqual(@as(u64, 0), ss.send_offset); // min of all range offsets - try testing.expect(ss.hasData()); // still has data to send +} + +test "SendStream: shouldSendBlocked emits once per blocked limit" { + var ss = SendStream.init(testing.allocator, 0); + defer ss.deinit(); + + ss.send_window = 4; + try ss.writeData("abcdefgh"); + _ = ss.popStreamFrame(16).?; + + try testing.expectEqual(@as(?u64, 4), ss.shouldSendBlocked()); + try testing.expectEqual(@as(?u64, null), ss.shouldSendBlocked()); + + ss.updateSendWindow(6); + _ = ss.popStreamFrame(16).?; + try testing.expectEqual(@as(?u64, 6), ss.shouldSendBlocked()); } // RFC 9218 priority scheduling tests diff --git a/src/quic/tls13.zig b/src/quic/tls13.zig index fd40096..6c003ed 100644 --- a/src/quic/tls13.zig +++ b/src/quic/tls13.zig @@ -1,7 +1,7 @@ // TLS 1.3 handshake for QUIC (RFC 8446 + RFC 9001) // // Supports TLS_AES_128_GCM_SHA256 (0x1301) only. -// ECDSA P-256 for signatures, X25519 for key exchange. +// ECDSA P-256 or Ed25519 for signatures, X25519 for key exchange. // X.509 certificate chain validation via std.crypto.Certificate. const std = @import("std"); @@ -21,6 +21,7 @@ const Sha384 = crypto.hash.sha2.Sha384; const Sha512 = crypto.hash.sha2.Sha512; const X25519 = crypto.dh.X25519; const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; +const Ed25519 = crypto.sign.Ed25519; const Aes128Gcm = crypto.aead.aes_gcm.Aes128Gcm; // TLS 1.3 handshake message types @@ -54,6 +55,7 @@ const SIG_ECDSA_P256_SHA256: u16 = 0x0403; const SIG_RSA_PSS_RSAE_SHA256: u16 = 0x0804; const SIG_RSA_PSS_RSAE_SHA384: u16 = 0x0805; const SIG_RSA_PSS_RSAE_SHA512: u16 = 0x0806; +const SIG_ED25519: u16 = 0x0807; // Named groups const GROUP_SECP256R1: u16 = 0x0017; @@ -70,6 +72,11 @@ const CIPHER_SUITE_CHACHA20_POLY1305_SHA256: u16 = 0x1303; pub const EncryptionLevel = quic_crypto.EncryptionLevel; +pub const PrivateKeyAlgorithm = enum { + ecdsa_p256_sha256, + ed25519, +}; + // ─── CertificateVerify signature verification ──────────────────────── fn verifyCertificateVerifySignature( @@ -86,6 +93,14 @@ fn verifyCertificateVerifySignature( const sig = EcdsaP256Sha256.Signature.fromDer(sig_bytes) catch return error.BadCertificateVerify; sig.verify(signed_content, pub_key) catch return error.BadCertificateVerify; }, + SIG_ED25519 => { + if (pub_key_algo != .curveEd25519) return error.BadCertificateVerify; + if (pub_key_bytes.len != Ed25519.PublicKey.encoded_length) return error.BadCertificateVerify; + if (sig_bytes.len != Ed25519.Signature.encoded_length) return error.BadCertificateVerify; + const pub_key = Ed25519.PublicKey.fromBytes(pub_key_bytes[0..Ed25519.PublicKey.encoded_length].*) catch return error.BadCertificateVerify; + const sig = Ed25519.Signature.fromBytes(sig_bytes[0..Ed25519.Signature.encoded_length].*); + sig.verify(signed_content, pub_key) catch return error.BadCertificateVerify; + }, SIG_RSA_PSS_RSAE_SHA256 => verifyRsaPss(pub_key_bytes, pub_key_algo, sig_bytes, signed_content, Sha256) catch return error.BadCertificateVerify, SIG_RSA_PSS_RSAE_SHA384 => verifyRsaPss(pub_key_bytes, pub_key_algo, sig_bytes, signed_content, Sha384) catch return error.BadCertificateVerify, SIG_RSA_PSS_RSAE_SHA512 => verifyRsaPss(pub_key_bytes, pub_key_algo, sig_bytes, signed_content, Sha512) catch return error.BadCertificateVerify, @@ -505,7 +520,8 @@ pub const SessionTicket = struct { pub const TlsConfig = struct { cert_chain_der: []const []const u8, // DER-encoded certificates - private_key_bytes: []const u8, // Raw ECDSA P-256 private key (32 bytes) + private_key_bytes: []const u8, // Raw P-256 scalar or Ed25519 seed (32 bytes) + private_key_algorithm: PrivateKeyAlgorithm = .ecdsa_p256_sha256, alpn: []const []const u8, server_name: ?[]const u8 = null, // SNI (client only) skip_cert_verify: bool = true, // Skip X.509 chain + CertificateVerify validation @@ -1597,6 +1613,7 @@ pub const Tls13Handshake = struct { &buf, transcript_hash, self.config.private_key_bytes, + self.config.private_key_algorithm, true, // is_server ) catch return error.InternalError; @@ -2125,8 +2142,10 @@ fn buildClientHello( pos += 65; // signature_algorithms extension - pos = writeExtHeader(buf, pos, @intFromEnum(ExtType.signature_algorithms), 2 + 8); - writeU16(buf[pos..], 8); // list length (4 algorithms x 2 bytes) + pos = writeExtHeader(buf, pos, @intFromEnum(ExtType.signature_algorithms), 2 + 10); + writeU16(buf[pos..], 10); // list length (5 algorithms x 2 bytes) + pos += 2; + writeU16(buf[pos..], SIG_ED25519); pos += 2; writeU16(buf[pos..], SIG_ECDSA_P256_SHA256); pos += 2; @@ -2453,6 +2472,7 @@ fn buildCertificateVerify( buf: []u8, transcript_hash: [32]u8, private_key_bytes: []const u8, + private_key_algorithm: PrivateKeyAlgorithm, is_server: bool, ) ![]const u8 { // Build the content to sign: @@ -2464,29 +2484,41 @@ fn buildCertificateVerify( sign_content[64 + 33] = 0x00; @memcpy(sign_content[64 + 34 ..][0..32], &transcript_hash); - // Sign with ECDSA P-256 if (private_key_bytes.len != 32) return error.InternalError; - const secret_key = EcdsaP256Sha256.SecretKey.fromBytes(private_key_bytes[0..32].*) catch return error.InternalError; - const key_pair = EcdsaP256Sha256.KeyPair.fromSecretKey(secret_key) catch return error.InternalError; - - const sig = key_pair.sign(&sign_content, null) catch return error.InternalError; - - var der_buf: [EcdsaP256Sha256.Signature.der_encoded_length_max]u8 = undefined; - const sig_bytes = sig.toDer(&der_buf); + var sig_storage: [EcdsaP256Sha256.Signature.der_encoded_length_max]u8 = undefined; + var sig_len: usize = 0; + const sig_algo: u16 = switch (private_key_algorithm) { + .ecdsa_p256_sha256 => sig_algo: { + const secret_key = EcdsaP256Sha256.SecretKey.fromBytes(private_key_bytes[0..32].*) catch return error.InternalError; + const key_pair = EcdsaP256Sha256.KeyPair.fromSecretKey(secret_key) catch return error.InternalError; + const sig = key_pair.sign(&sign_content, null) catch return error.InternalError; + const sig_bytes = sig.toDer(&sig_storage); + sig_len = sig_bytes.len; + break :sig_algo SIG_ECDSA_P256_SHA256; + }, + .ed25519 => sig_algo: { + const key_pair = Ed25519.KeyPair.generateDeterministic(private_key_bytes[0..32].*) catch return error.InternalError; + const sig = key_pair.sign(&sign_content, null) catch return error.InternalError; + const sig_bytes = sig.toBytes(); + @memcpy(sig_storage[0..sig_bytes.len], &sig_bytes); + sig_len = sig_bytes.len; + break :sig_algo SIG_ED25519; + }, + }; // Build message var pos: usize = 4; // reserve for header // signature_algorithm - writeU16(buf[pos..], SIG_ECDSA_P256_SHA256); + writeU16(buf[pos..], sig_algo); pos += 2; // signature length + signature - writeU16(buf[pos..], @intCast(sig_bytes.len)); + writeU16(buf[pos..], @intCast(sig_len)); pos += 2; - @memcpy(buf[pos..][0..sig_bytes.len], sig_bytes); - pos += sig_bytes.len; + @memcpy(buf[pos..][0..sig_len], sig_storage[0..sig_len]); + pos += sig_len; // Fill in message header const body_len: u24 = @intCast(pos - 4); @@ -2683,6 +2715,53 @@ pub fn extractPkcs8EcPrivateKey(der: []const u8) ![]const u8 { return extractEcPrivateKey(der[pos..][0..octet_len]); } +fn readDerValue(der: []const u8, pos: *usize, expected_tag: u8) ![]const u8 { + if (pos.* >= der.len or der[pos.*] != expected_tag) return error.DecodeError; + pos.* += 1; + if (pos.* >= der.len) return error.DecodeError; + + var len: usize = der[pos.*]; + pos.* += 1; + if (len & 0x80 != 0) { + const num = len & 0x7f; + if (num == 0 or num > @sizeOf(usize) or pos.* + num > der.len) return error.DecodeError; + len = 0; + for (0..num) |i| { + len = (len << 8) | der[pos.* + i]; + } + pos.* += num; + } + + if (pos.* + len > der.len) return error.DecodeError; + const value = der[pos.*..][0..len]; + pos.* += len; + return value; +} + +// Extract the 32-byte Ed25519 seed from a PKCS#8 DER-encoded private key. +// RFC 8410: AlgorithmIdentifier OID 1.3.101.112, PrivateKey OCTET STRING. +pub fn extractEd25519PrivateKey(der: []const u8) ![]const u8 { + var outer_pos: usize = 0; + const outer = try readDerValue(der, &outer_pos, 0x30); + + var pos: usize = 0; + _ = try readDerValue(outer, &pos, 0x02); // version + + const algorithm = try readDerValue(outer, &pos, 0x30); + var algorithm_pos: usize = 0; + const oid = try readDerValue(algorithm, &algorithm_pos, 0x06); + if (!std.mem.eql(u8, oid, &.{ 0x2b, 0x65, 0x70 })) return error.DecodeError; + if (algorithm_pos != algorithm.len) return error.DecodeError; + + const private_key = try readDerValue(outer, &pos, 0x04); + if (private_key.len == 32) return private_key; + + var nested_pos: usize = 0; + const nested = try readDerValue(private_key, &nested_pos, 0x04); + if (nested_pos != private_key.len or nested.len != 32) return error.DecodeError; + return nested; +} + // ─── Tests ─────────────────────────────────────────────────────────── test "TranscriptHash: basic usage" {