diff --git a/src/h2/stream.py b/src/h2/stream.py index 9b5bce78..5aeef858 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -1364,15 +1364,23 @@ def _initialize_content_length(self, headers: Iterable[Header]) -> None: self._expected_content_length = 0 return + content_length = None for n, v in headers: if n == b"content-length": try: - self._expected_content_length = int(v, 10) + header_content_length = int(v, 10) except ValueError as err: msg = f"Invalid content-length header: {v!r}" raise ProtocolError(msg) from err - return + if content_length is not None and header_content_length != content_length: + msg = "Conflicting content-length headers received" + raise ProtocolError(msg) + + content_length = header_content_length + + if content_length is not None: + self._expected_content_length = content_length def _track_content_length(self, length: int, end_stream: bool) -> None: """ diff --git a/tests/test_invalid_content_lengths.py b/tests/test_invalid_content_lengths.py index 39401ea2..584365c1 100644 --- a/tests/test_invalid_content_lengths.py +++ b/tests/test_invalid_content_lengths.py @@ -39,6 +39,60 @@ class TestInvalidContentLengths: ] server_config = h2.config.H2Configuration(client_side=False) + @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) + def test_duplicate_matching_content_length(self, frame_factory, request_headers) -> None: + """ + Remote peers sending duplicate matching content-length headers are + accepted. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[*request_headers, request_headers[-1]], + ) + data = frame_factory.build_data_frame( + data=b"\x01"*15, + flags=["END_STREAM"], + ) + events = c.receive_data(headers.serialize() + data.serialize()) + + assert isinstance(events[0], h2.events.RequestReceived) + assert isinstance(events[1], h2.events.DataReceived) + assert isinstance(events[2], h2.events.StreamEnded) + assert not c.data_to_send() + + @pytest.mark.parametrize( + ("request_headers", "conflicting_header"), + [ + (example_request_headers, ("content-length", "14")), + (example_request_headers_bytes, (b"content-length", b"14")), + ], + ) + def test_duplicate_conflicting_content_length(self, frame_factory, request_headers, conflicting_header) -> None: + """ + Remote peers sending duplicate conflicting content-length headers cause + Protocol Errors. + """ + c = h2.connection.H2Connection(config=self.server_config) + c.initiate_connection() + c.receive_data(frame_factory.preamble()) + c.clear_outbound_data_buffer() + + headers = frame_factory.build_headers_frame( + headers=[*request_headers, conflicting_header], + ) + with pytest.raises(h2.exceptions.ProtocolError, match="Conflicting content-length"): + c.receive_data(headers.serialize()) + + expected_frame = frame_factory.build_goaway_frame( + last_stream_id=1, + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, + ) + assert c.data_to_send() == expected_frame.serialize() + @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) def test_too_much_data(self, frame_factory, request_headers) -> None: """