Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 2 additions & 44 deletions roborock/map/b01_map_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,17 @@
`roborock/map/proto/b01_scmap.proto`.
"""

import base64
import binascii
import hashlib
import io
import zlib
from dataclasses import dataclass

from Crypto.Cipher import AES
from google.protobuf.message import DecodeError, Message
from PIL import Image
from vacuum_map_parser_base.config.image_config import ImageConfig
from vacuum_map_parser_base.map_data import ImageData, MapData

from roborock.exceptions import RoborockException
from roborock.map.proto.b01_scmap_pb2 import RobotMap # type: ignore[attr-defined]
from roborock.protocol import Utils
from roborock.protocols.b01_map_protocol import decode_map_response

from .map_parser import ParsedMapData

Expand All @@ -48,7 +43,7 @@ def __init__(self, config: B01MapParserConfig | None = None) -> None:

def parse(self, raw_payload: bytes, *, serial: str, model: str) -> ParsedMapData:
"""Parse a raw MAP_RESPONSE payload and return a PNG + MapData."""
inflated = _decode_b01_map_payload(raw_payload, serial=serial, model=model)
inflated = decode_map_response(raw_payload, serial=serial, model=model)
parsed = _parse_scmap_payload(inflated)
size_x, size_y, grid = _extract_grid(parsed)
room_names = _extract_room_names(parsed)
Expand Down Expand Up @@ -78,43 +73,6 @@ def parse(self, raw_payload: bytes, *, serial: str, model: str) -> ParsedMapData
)


def _derive_map_key(serial: str, model: str) -> bytes:
"""Derive the B01/Q7 map decrypt key from serial + model."""
model_suffix = model.split(".")[-1]
model_key = (model_suffix + "0" * 16)[:16].encode()
material = f"{serial}+{model_suffix}+{serial}".encode()
encrypted = Utils.encrypt_ecb(material, model_key)
md5 = hashlib.md5(base64.b64encode(encrypted), usedforsecurity=False).hexdigest()
return md5[8:24].encode()


def _decode_base64_payload(raw_payload: bytes) -> bytes:
blob = raw_payload.strip()
padded = blob + b"=" * (-len(blob) % 4)
try:
return base64.b64decode(padded, validate=True)
except binascii.Error as err:
raise RoborockException("Failed to decode B01 map payload") from err


def _decode_b01_map_payload(raw_payload: bytes, *, serial: str, model: str) -> bytes:
"""Decode raw B01 `MAP_RESPONSE` payload into inflated SCMap bytes."""
# TODO: Move this lower-level B01 transport decode under `roborock.protocols`
# so this module only handles SCMap parsing/rendering.
encrypted_payload = _decode_base64_payload(raw_payload)
if len(encrypted_payload) % AES.block_size != 0:
raise RoborockException("Unexpected encrypted B01 map payload length")

map_key = _derive_map_key(serial, model)

try:
compressed_hex = Utils.decrypt_ecb(encrypted_payload, map_key).decode("ascii")
compressed_payload = bytes.fromhex(compressed_hex)
return zlib.decompress(compressed_payload)
except (ValueError, UnicodeDecodeError, zlib.error) as err:
raise RoborockException("Failed to decode B01 map payload") from err


def _parse_proto(blob: bytes, message: Message, *, context: str) -> None:
try:
message.ParseFromString(blob)
Expand Down
46 changes: 46 additions & 0 deletions roborock/protocols/b01_map_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""B01/Q7 map transport decoding helpers."""

import base64
import binascii
import hashlib
import zlib

from Crypto.Cipher import AES

from roborock.exceptions import RoborockException
from roborock.protocol import Utils


def derive_map_key(serial: str, model: str) -> bytes:
"""Derive the B01/Q7 map decrypt key from serial + model."""
model_suffix = model.split(".")[-1]
model_key = (model_suffix + "0" * 16)[:16].encode()
material = f"{serial}+{model_suffix}+{serial}".encode()
encrypted = Utils.encrypt_ecb(material, model_key)
md5 = hashlib.md5(base64.b64encode(encrypted), usedforsecurity=False).hexdigest()
return md5[8:24].encode()


def decode_map_response(raw_payload: bytes, *, serial: str, model: str) -> bytes:
"""Decode raw B01 ``MAP_RESPONSE`` payload into inflated SCMap bytes."""
encrypted_payload = _decode_base64_payload(raw_payload)
if len(encrypted_payload) % AES.block_size != 0:
raise RoborockException("Unexpected encrypted B01 map payload length")

map_key = derive_map_key(serial, model)

try:
compressed_hex = Utils.decrypt_ecb(encrypted_payload, map_key).decode("ascii")
compressed_payload = bytes.fromhex(compressed_hex)
return zlib.decompress(compressed_payload)
except (ValueError, UnicodeDecodeError, zlib.error) as err:
raise RoborockException("Failed to decode B01 map payload") from err


def _decode_base64_payload(raw_payload: bytes) -> bytes:
blob = raw_payload.strip()
padded = blob + b"=" * (-len(blob) % 4)
try:
return base64.b64decode(padded, validate=True)
except binascii.Error as err:
raise RoborockException("Failed to decode B01 map payload") from err
35 changes: 35 additions & 0 deletions tests/protocols/test_b01_map_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import base64
import gzip
import zlib
from pathlib import Path

import pytest
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad

from roborock.exceptions import RoborockException
from roborock.protocols.b01_map_protocol import decode_map_response, derive_map_key

FIXTURE = Path(__file__).resolve().parent.parent / "map" / "testdata" / "raw-mqtt-map301.bin.inflated.bin.gz"


def _build_payload(inflated: bytes, *, serial: str, model: str) -> bytes:
compressed_hex = zlib.compress(inflated).hex().encode()
map_key = derive_map_key(serial, model)
encrypted = AES.new(map_key, AES.MODE_ECB).encrypt(pad(compressed_hex, AES.block_size))
return base64.b64encode(encrypted)


def test_decode_map_response_decodes_fixture_payload() -> None:
serial = "testsn012345"
model = "roborock.vacuum.sc05"
inflated = gzip.decompress(FIXTURE.read_bytes())

payload = _build_payload(inflated, serial=serial, model=model)

assert decode_map_response(payload, serial=serial, model=model) == inflated


def test_decode_map_response_rejects_invalid_payload() -> None:
with pytest.raises(RoborockException, match="Failed to decode B01 map payload"):
decode_map_response(b"not a map", serial="testsn012345", model="roborock.vacuum.sc05")