-
Notifications
You must be signed in to change notification settings - Fork 75
feat: Support both a01 and v1 device types with traits #425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
f541de5
4370d30
1a0d3df
c589925
9165448
440b0e2
ba06e49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| """Thin wrapper around the MQTT channel for Roborock A01 devices.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import Any, overload | ||
|
|
||
| from roborock.protocols.a01_protocol import ( | ||
| decode_rpc_response, | ||
| encode_mqtt_payload, | ||
| ) | ||
| from roborock.roborock_message import RoborockDyadDataProtocol, RoborockZeoProtocol | ||
|
|
||
| from .mqtt_channel import MqttChannel | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @overload | ||
| async def send_decoded_command( | ||
| mqtt_channel: MqttChannel, | ||
| params: dict[RoborockDyadDataProtocol, Any], | ||
| ) -> dict[RoborockDyadDataProtocol, Any]: | ||
| ... | ||
|
|
||
|
|
||
| @overload | ||
| async def send_decoded_command( | ||
| mqtt_channel: MqttChannel, | ||
| params: dict[RoborockZeoProtocol, Any], | ||
| ) -> dict[RoborockZeoProtocol, Any]: | ||
| ... | ||
|
|
||
|
|
||
| async def send_decoded_command( | ||
| mqtt_channel: MqttChannel, | ||
| params: dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any], | ||
| ) -> dict[RoborockDyadDataProtocol, Any] | dict[RoborockZeoProtocol, Any]: | ||
| """Send a command on the MQTT channel and get a decoded response.""" | ||
| _LOGGER.debug("Sending MQTT command: %s", params) | ||
| roborock_message = encode_mqtt_payload(params) | ||
| response = await mqtt_channel.send_message(roborock_message) | ||
| return decode_rpc_response(response) # type: ignore[return-value] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| """Low-level interface for connections to Roborock devices.""" | ||
|
|
||
| import logging | ||
| from collections.abc import Callable | ||
| from typing import Protocol | ||
|
|
||
| from roborock.roborock_message import RoborockMessage | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class Channel(Protocol): | ||
| """A generic channel for establishing a connection with a Roborock device. | ||
|
|
||
| Individual channel implementations have their own methods for speaking to | ||
| the device that hide some of the protocol specific complexity, but they | ||
| are still specialized for the device type and protocol. | ||
| """ | ||
|
|
||
| @property | ||
| def is_connected(self) -> bool: | ||
| """Return true if the channel is connected.""" | ||
| ... | ||
|
|
||
| async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: | ||
| """Subscribe to messages from the device.""" | ||
| ... |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,101 +4,72 @@ | |
| until the API is stable. | ||
| """ | ||
|
|
||
| import enum | ||
| import logging | ||
| from collections.abc import Callable | ||
| from functools import cached_property | ||
|
|
||
| from roborock.containers import ( | ||
| HomeDataDevice, | ||
| HomeDataProduct, | ||
| ModelStatus, | ||
| S7MaxVStatus, | ||
| Status, | ||
| UserData, | ||
| ) | ||
| from abc import ABC | ||
| from collections.abc import Callable, Mapping | ||
| from types import MappingProxyType | ||
|
|
||
| from roborock.containers import HomeDataDevice | ||
| from roborock.roborock_message import RoborockMessage | ||
| from roborock.roborock_typing import RoborockCommand | ||
|
|
||
| from .v1_channel import V1Channel | ||
| from .channel import Channel | ||
| from .traits.trait import Trait | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
|
|
||
| __all__ = [ | ||
| "RoborockDevice", | ||
| "DeviceVersion", | ||
| ] | ||
|
|
||
|
|
||
| class DeviceVersion(enum.StrEnum): | ||
| """Enum for device versions.""" | ||
|
|
||
| V1 = "1.0" | ||
| A01 = "A01" | ||
| UNKNOWN = "unknown" | ||
|
|
||
| class RoborockDevice(ABC): | ||
| """A generic channel for establishing a connection with a Roborock device. | ||
|
|
||
| class RoborockDevice: | ||
| """Unified Roborock device class with automatic connection setup.""" | ||
| Individual channel implementations have their own methods for speaking to | ||
| the device that hide some of the protocol specific complexity, but they | ||
| are still specialized for the device type and protocol. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| user_data: UserData, | ||
| device_info: HomeDataDevice, | ||
| product_info: HomeDataProduct, | ||
| v1_channel: V1Channel, | ||
| channel: Channel, | ||
| traits: list[Trait], | ||
| ) -> None: | ||
| """Initialize the RoborockDevice. | ||
|
|
||
| The device takes ownership of the V1 channel for communication with the device. | ||
| Use `connect()` to establish the connection, which will set up the appropriate | ||
| protocol channel. Use `close()` to clean up all connections. | ||
| """ | ||
| self._user_data = user_data | ||
| self._device_info = device_info | ||
| self._product_info = product_info | ||
| self._v1_channel = v1_channel | ||
| self._duid = device_info.duid | ||
| self._name = device_info.name | ||
| self._channel = channel | ||
| self._unsub: Callable[[], None] | None = None | ||
| self._trait_map = {trait.name: trait for trait in traits} | ||
| if len(self._trait_map) != len(traits): | ||
| raise ValueError("Duplicate trait names found in traits list") | ||
|
|
||
| @property | ||
| def duid(self) -> str: | ||
| """Return the device unique identifier (DUID).""" | ||
| return self._device_info.duid | ||
| return self._duid | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| """Return the device name.""" | ||
| return self._device_info.name | ||
|
|
||
| @cached_property | ||
| def device_version(self) -> str: | ||
| """Return the device version. | ||
|
|
||
| At the moment this is a simple check against the product version (pv) of the device | ||
| and used as a placeholder for upcoming functionality for devices that will behave | ||
| differently based on the version and capabilities. | ||
| """ | ||
| if self._device_info.pv == DeviceVersion.V1.value: | ||
| return DeviceVersion.V1 | ||
| elif self._device_info.pv == DeviceVersion.A01.value: | ||
| return DeviceVersion.A01 | ||
| _LOGGER.warning( | ||
| "Unknown device version %s for device %s, using default UNKNOWN", | ||
| self._device_info.pv, | ||
| self._device_info.name, | ||
| ) | ||
| return DeviceVersion.UNKNOWN | ||
| return self._name | ||
|
|
||
| @property | ||
| def is_connected(self) -> bool: | ||
| """Return whether the device is connected.""" | ||
| return self._v1_channel.is_mqtt_connected or self._v1_channel.is_local_connected | ||
| return self._channel.is_connected | ||
|
|
||
| async def connect(self) -> None: | ||
| """Connect to the device using the appropriate protocol channel.""" | ||
| if self._unsub: | ||
| raise ValueError("Already connected to the device") | ||
| self._unsub = await self._v1_channel.subscribe(self._on_message) | ||
| self._unsub = await self._channel.subscribe(self._on_message) | ||
| _LOGGER.info("Connected to V1 device %s", self.name) | ||
|
|
||
| async def close(self) -> None: | ||
|
|
@@ -111,10 +82,7 @@ def _on_message(self, message: RoborockMessage) -> None: | |
| """Handle incoming messages from the device.""" | ||
| _LOGGER.debug("Received message from device: %s", message) | ||
|
|
||
| async def get_status(self) -> Status: | ||
| """Get the current status of the device. | ||
|
|
||
| This is a placeholder command and will likely be changed/moved in the future. | ||
| """ | ||
| status_type: type[Status] = ModelStatus.get(self._product_info.model, S7MaxVStatus) | ||
| return await self._v1_channel.rpc_channel.send_command(RoborockCommand.GET_STATUS, response_type=status_type) | ||
| @property | ||
| def traits(self) -> Mapping[str, Trait]: | ||
| """Return the traits of the device.""" | ||
| return MappingProxyType(self._trait_map) | ||
|
Comment on lines
+86
to
+88
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why can't this just stay as a dict as is? I don't really know what MappingProxyType is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When this returns a |
||
Uh oh!
There was an error while loading. Please reload this page.