diff --git a/roborock/map/b01_map_parser.py b/roborock/map/b01_map_parser.py index 5e0a4bad..2a0f0d17 100644 --- a/roborock/map/b01_map_parser.py +++ b/roborock/map/b01_map_parser.py @@ -10,14 +10,9 @@ `roborock/map/proto/b01_scmap.proto`. """ -import base64 -import binascii -import hashlib import io -import zlib from dataclasses import dataclass -from Crypto.Cipher import AES from google.protobuf.message import DecodeError, Message from PIL import Image from vacuum_map_parser_base.config.image_config import ImageConfig @@ -25,7 +20,7 @@ from roborock.exceptions import RoborockException from roborock.map.proto.b01_scmap_pb2 import RobotMap # type: ignore[attr-defined] -from roborock.protocol import Utils +from roborock.protocols.b01_map_protocol import decode_map_response from .map_parser import ParsedMapData @@ -48,7 +43,7 @@ def __init__(self, config: B01MapParserConfig | None = None) -> None: def parse(self, raw_payload: bytes, *, serial: str, model: str) -> ParsedMapData: """Parse a raw MAP_RESPONSE payload and return a PNG + MapData.""" - inflated = _decode_b01_map_payload(raw_payload, serial=serial, model=model) + inflated = decode_map_response(raw_payload, serial=serial, model=model) parsed = _parse_scmap_payload(inflated) size_x, size_y, grid = _extract_grid(parsed) room_names = _extract_room_names(parsed) @@ -78,43 +73,6 @@ def parse(self, raw_payload: bytes, *, serial: str, model: str) -> ParsedMapData ) -def _derive_map_key(serial: str, model: str) -> bytes: - """Derive the B01/Q7 map decrypt key from serial + model.""" - model_suffix = model.split(".")[-1] - model_key = (model_suffix + "0" * 16)[:16].encode() - material = f"{serial}+{model_suffix}+{serial}".encode() - encrypted = Utils.encrypt_ecb(material, model_key) - md5 = hashlib.md5(base64.b64encode(encrypted), usedforsecurity=False).hexdigest() - return md5[8:24].encode() - - -def _decode_base64_payload(raw_payload: bytes) -> bytes: - blob = raw_payload.strip() - padded = blob + b"=" * (-len(blob) % 4) - try: - return base64.b64decode(padded, validate=True) - except binascii.Error as err: - raise RoborockException("Failed to decode B01 map payload") from err - - -def _decode_b01_map_payload(raw_payload: bytes, *, serial: str, model: str) -> bytes: - """Decode raw B01 `MAP_RESPONSE` payload into inflated SCMap bytes.""" - # TODO: Move this lower-level B01 transport decode under `roborock.protocols` - # so this module only handles SCMap parsing/rendering. - encrypted_payload = _decode_base64_payload(raw_payload) - if len(encrypted_payload) % AES.block_size != 0: - raise RoborockException("Unexpected encrypted B01 map payload length") - - map_key = _derive_map_key(serial, model) - - try: - compressed_hex = Utils.decrypt_ecb(encrypted_payload, map_key).decode("ascii") - compressed_payload = bytes.fromhex(compressed_hex) - return zlib.decompress(compressed_payload) - except (ValueError, UnicodeDecodeError, zlib.error) as err: - raise RoborockException("Failed to decode B01 map payload") from err - - def _parse_proto(blob: bytes, message: Message, *, context: str) -> None: try: message.ParseFromString(blob) diff --git a/roborock/protocols/b01_map_protocol.py b/roborock/protocols/b01_map_protocol.py new file mode 100644 index 00000000..b4e380cb --- /dev/null +++ b/roborock/protocols/b01_map_protocol.py @@ -0,0 +1,46 @@ +"""B01/Q7 map transport decoding helpers.""" + +import base64 +import binascii +import hashlib +import zlib + +from Crypto.Cipher import AES + +from roborock.exceptions import RoborockException +from roborock.protocol import Utils + + +def derive_map_key(serial: str, model: str) -> bytes: + """Derive the B01/Q7 map decrypt key from serial + model.""" + model_suffix = model.split(".")[-1] + model_key = (model_suffix + "0" * 16)[:16].encode() + material = f"{serial}+{model_suffix}+{serial}".encode() + encrypted = Utils.encrypt_ecb(material, model_key) + md5 = hashlib.md5(base64.b64encode(encrypted), usedforsecurity=False).hexdigest() + return md5[8:24].encode() + + +def decode_map_response(raw_payload: bytes, *, serial: str, model: str) -> bytes: + """Decode raw B01 ``MAP_RESPONSE`` payload into inflated SCMap bytes.""" + encrypted_payload = _decode_base64_payload(raw_payload) + if len(encrypted_payload) % AES.block_size != 0: + raise RoborockException("Unexpected encrypted B01 map payload length") + + map_key = derive_map_key(serial, model) + + try: + compressed_hex = Utils.decrypt_ecb(encrypted_payload, map_key).decode("ascii") + compressed_payload = bytes.fromhex(compressed_hex) + return zlib.decompress(compressed_payload) + except (ValueError, UnicodeDecodeError, zlib.error) as err: + raise RoborockException("Failed to decode B01 map payload") from err + + +def _decode_base64_payload(raw_payload: bytes) -> bytes: + blob = raw_payload.strip() + padded = blob + b"=" * (-len(blob) % 4) + try: + return base64.b64decode(padded, validate=True) + except binascii.Error as err: + raise RoborockException("Failed to decode B01 map payload") from err diff --git a/tests/protocols/test_b01_map_protocol.py b/tests/protocols/test_b01_map_protocol.py new file mode 100644 index 00000000..1d316fdc --- /dev/null +++ b/tests/protocols/test_b01_map_protocol.py @@ -0,0 +1,35 @@ +import base64 +import gzip +import zlib +from pathlib import Path + +import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad + +from roborock.exceptions import RoborockException +from roborock.protocols.b01_map_protocol import decode_map_response, derive_map_key + +FIXTURE = Path(__file__).resolve().parent.parent / "map" / "testdata" / "raw-mqtt-map301.bin.inflated.bin.gz" + + +def _build_payload(inflated: bytes, *, serial: str, model: str) -> bytes: + compressed_hex = zlib.compress(inflated).hex().encode() + map_key = derive_map_key(serial, model) + encrypted = AES.new(map_key, AES.MODE_ECB).encrypt(pad(compressed_hex, AES.block_size)) + return base64.b64encode(encrypted) + + +def test_decode_map_response_decodes_fixture_payload() -> None: + serial = "testsn012345" + model = "roborock.vacuum.sc05" + inflated = gzip.decompress(FIXTURE.read_bytes()) + + payload = _build_payload(inflated, serial=serial, model=model) + + assert decode_map_response(payload, serial=serial, model=model) == inflated + + +def test_decode_map_response_rejects_invalid_payload() -> None: + with pytest.raises(RoborockException, match="Failed to decode B01 map payload"): + decode_map_response(b"not a map", serial="testsn012345", model="roborock.vacuum.sc05")