From 2796235260aa5c5672af807cd161e1b564ac933b Mon Sep 17 00:00:00 2001 From: Hannes Tschofenig Date: Sat, 14 Mar 2026 20:02:34 +0100 Subject: [PATCH] Allowing psk_id also in the unprotected header. --- cwt/cose.py | 3 +- cwt/recipient_algs/hpke.py | 14 +++++--- tests/test_cose_hpke.py | 71 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/cwt/cose.py b/cwt/cose.py index 838deaf..0578b97 100644 --- a/cwt/cose.py +++ b/cwt/cose.py @@ -699,7 +699,8 @@ def _validate_cose_message( if not isinstance(v, (bytes, bytearray)): raise ValueError("ek (-4) must be bstr.") if k == -5: # psk_id - raise ValueError("psk_id (-5) must be placed only in the protected header.") + if not isinstance(v, (bytes, bytearray)): + raise ValueError("psk_id (-5) must be bstr.") h[k] = v if len(h) != len(p) + len(u): raise ValueError("The same keys are both in protected and unprotected headers.") diff --git a/cwt/recipient_algs/hpke.py b/cwt/recipient_algs/hpke.py index c2afa2f..943fd85 100644 --- a/cwt/recipient_algs/hpke.py +++ b/cwt/recipient_algs/hpke.py @@ -97,13 +97,20 @@ def _build_recipient_info(self, content_alg: int) -> bytes: ] return self._dumps(recipient_structure) + def _get_psk_id(self) -> Optional[bytes]: + # Prefer the protected header when both buckets are available, but + # accept unprotected psk_id for interoperability. + psk_id = self._protected.get(-5, None) if isinstance(self._protected, dict) else None + if psk_id is None and isinstance(self._unprotected, dict): + psk_id = self._unprotected.get(-5, None) + return psk_id + def encode(self, plaintext: bytes = b"", aad: bytes = b"") -> Tuple[List[Any], Optional[COSEKeyInterface]]: if self._recipient_key is None: raise ValueError("recipient_key should be set in advance.") self._kem_key = self._to_kem_key(self._recipient_key) try: - # psk_id MUST be in the protected header (draft-ietf-cose-hpke) - psk_id = self._protected.get(-5, None) if isinstance(self._protected, dict) else None + psk_id = self._get_psk_id() if psk_id is not None and not isinstance(psk_id, (bytes, bytearray)): raise EncodeError("psk_id (-5) must be bstr.") if self._psk is not None and psk_id is None: @@ -148,8 +155,7 @@ def decode( if not isinstance(ek, (bytes, bytearray)): raise DecodeError("ek (-4) must be bstr.") try: - # psk_id MUST be in the protected header (draft-ietf-cose-hpke) - psk_id = self._protected.get(-5, None) if isinstance(self._protected, dict) else None + psk_id = self._get_psk_id() if psk_id is not None and not isinstance(psk_id, (bytes, bytearray)): raise DecodeError("psk_id (-5) must be bstr.") if self._psk is not None and psk_id is None: diff --git a/tests/test_cose_hpke.py b/tests/test_cose_hpke.py index 9891826..13b7c3e 100644 --- a/tests/test_cose_hpke.py +++ b/tests/test_cose_hpke.py @@ -223,6 +223,38 @@ def test_cose_hpke_encrypt0_with_psk_id_roundtrip(self): recipient = COSE.new() assert b"This is the content." == recipient.decode(encoded, rsk, hpke_psk=b"secret-psk") + def test_cose_hpke_encrypt0_with_unprotected_psk_id_roundtrip(self): + rpk = COSEKey.from_jwk( + { + "kty": "EC", + "kid": "01", + "crv": "P-256", + "x": "usWxHK2PmfnHKwXPS54m0kTcGJ90UiglWiGahtagnv8", + "y": "IBOL-C3BttVivg-lSreASjpkttcsz-1rb7btKLv8EX4", + } + ) + sender = COSE.new() + encoded = sender.encode_and_encrypt( + b"This is the content.", + rpk, + protected={COSEHeaders.ALG: COSEAlgs.HPKE_0}, + unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: b"psk-01"}, + hpke_psk=b"secret-psk", + ) + + rsk = COSEKey.from_jwk( + { + "kty": "EC", + "kid": "01", + "crv": "P-256", + "x": "usWxHK2PmfnHKwXPS54m0kTcGJ90UiglWiGahtagnv8", + "y": "IBOL-C3BttVivg-lSreASjpkttcsz-1rb7btKLv8EX4", + "d": "V8kgd2ZBRuh2dgyVINBUqpPDr7BOMGcF22CQMIUHtNM", + } + ) + recipient = COSE.new() + assert b"This is the content." == recipient.decode(encoded, rsk, hpke_psk=b"secret-psk") + @pytest.mark.parametrize( "alg", [COSEAlgs.HPKE_1], @@ -705,6 +737,45 @@ def test_cose_hpke_ke_with_psk_roundtrip(self): recipient = COSE.new() assert b"This is the content." == recipient.decode(encoded, rsk, hpke_psk=b"secret-psk") + def test_cose_hpke_ke_with_unprotected_psk_roundtrip(self): + enc_key = COSEKey.from_symmetric_key(alg="A128GCM") + rpk = COSEKey.from_jwk( + { + "kty": "EC", + "kid": "01", + "crv": "P-256", + "x": "usWxHK2PmfnHKwXPS54m0kTcGJ90UiglWiGahtagnv8", + "y": "IBOL-C3BttVivg-lSreASjpkttcsz-1rb7btKLv8EX4", + } + ) + + r = Recipient.new( + protected={COSEHeaders.ALG: COSEAlgs.HPKE_0_KE}, + unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: b"psk-01"}, + recipient_key=rpk, + hpke_psk=b"secret-psk", + ) + sender = COSE.new() + encoded = sender.encode_and_encrypt( + b"This is the content.", + enc_key, + protected={COSEHeaders.ALG: COSEAlgs.A128GCM}, + recipients=[r], + ) + + rsk = COSEKey.from_jwk( + { + "kty": "EC", + "kid": "01", + "crv": "P-256", + "x": "usWxHK2PmfnHKwXPS54m0kTcGJ90UiglWiGahtagnv8", + "y": "IBOL-C3BttVivg-lSreASjpkttcsz-1rb7btKLv8EX4", + "d": "V8kgd2ZBRuh2dgyVINBUqpPDr7BOMGcF22CQMIUHtNM", + } + ) + recipient = COSE.new() + assert b"This is the content." == recipient.decode(encoded, rsk, hpke_psk=b"secret-psk") + def test_cose_hpke_ke_integrated_as_recipient_should_fail(self): """Integrated HPKE algorithms cannot be used as recipient algorithms.""" rpk = COSEKey.from_jwk(