From d7132f598d2d5d8992cebbbd82cb56cbe86c60ba Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Wed, 10 Sep 2025 13:58:04 +0200 Subject: [PATCH 1/9] Add base Websocket class --- homeassistant_api/rawbasewebsocket.py | 82 +++++++++++++++++++++++++++ homeassistant_api/rawwebsocket.py | 48 +--------------- homeassistant_api/websocket.py | 3 +- 3 files changed, 86 insertions(+), 47 deletions(-) create mode 100644 homeassistant_api/rawbasewebsocket.py diff --git a/homeassistant_api/rawbasewebsocket.py b/homeassistant_api/rawbasewebsocket.py new file mode 100644 index 00000000..f959d502 --- /dev/null +++ b/homeassistant_api/rawbasewebsocket.py @@ -0,0 +1,82 @@ +import logging +import time +from typing import Optional, cast + +from pydantic import ValidationError + +from homeassistant_api.errors import ( + ReceivingError, + RequestError, +) +from homeassistant_api.models.websocket import ( + ErrorResponse, + EventResponse, + PingResponse, + ResultResponse, +) +from homeassistant_api.utils import JSONType + +logger = logging.getLogger(__name__) + + +class RawBaseWebsocketClient: + """Shared methods for Websocket clients.""" + + api_url: str + token: str + _id_counter: int + _result_responses: dict[int, Optional[ResultResponse]] + _event_responses: dict[int, list[EventResponse]] + _ping_responses: dict[int, PingResponse] + + def __init__(self, api_url: str, token: str) -> None: + self.api_url = api_url + self.token = token.strip() + + self._id_counter = 0 + self._result_responses = {} # id -> response + self._event_responses = {} # id -> [response, ...] + self._ping_responses = {} # id -> (sent, received) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.api_url!r})" + + def _request_id(self) -> int: + """Get a unique id for a message.""" + self._id_counter += 1 + return self._id_counter + + def check_success(self, data: dict[str, JSONType]) -> None: + """Check if a command message was successful.""" + try: + error_resp = ErrorResponse.model_validate(data) + raise RequestError(error_resp.error.code, error_resp.error.message) + except ValidationError: + pass + + def handle_recv(self, data: dict[str, JSONType]) -> None: + """Handle a received message.""" + if "id" not in data: + raise ReceivingError( + "Received a message without an id outside the auth phase." + ) + self.check_success(data) + self.parse_response(data) + + def parse_response(self, data: dict[str, JSONType]) -> None: + data_id = cast(int, data["id"]) + if data.get("type") == "pong": + logger.info("Received pong message") + self._ping_responses[data_id].end = time.perf_counter_ns() + elif data.get("type") == "result": + logger.info("Received result message") + if data.get("success"): + self._result_responses[data_id] = ResultResponse.model_validate(data) + else: + error_resp = ErrorResponse.model_validate(data) + raise RequestError(error_resp.error.code, error_resp.error.message) + elif data.get("type") == "event": + logger.info("Received event message %s", data["event"]) + self._event_responses[data_id].append(EventResponse.model_validate(data)) + else: + raise ReceivingError(f"Received unexpected message type: {data}") diff --git a/homeassistant_api/rawwebsocket.py b/homeassistant_api/rawwebsocket.py index 09e702f5..c14560dd 100644 --- a/homeassistant_api/rawwebsocket.py +++ b/homeassistant_api/rawwebsocket.py @@ -8,7 +8,6 @@ from homeassistant_api.errors import ( ReceivingError, - RequestError, ResponseError, UnauthorizedError, ) @@ -16,17 +15,17 @@ AuthInvalid, AuthOk, AuthRequired, - ErrorResponse, EventResponse, PingResponse, ResultResponse, ) +from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient from homeassistant_api.utils import JSONType logger = logging.getLogger(__name__) -class RawWebsocketClient: +class RawWebsocketClient(RawBaseWebsocketClient): api_url: str token: str _conn: Optional[ws.ClientConnection] @@ -36,8 +35,7 @@ def __init__( api_url: str, token: str, ) -> None: - self.api_url = api_url - self.token = token.strip() + super().__init__(api_url, token) self._conn = None self._id_counter = 0 @@ -66,11 +64,6 @@ def __exit__(self, exc_type, exc_value, traceback): self._conn.__exit__(exc_type, exc_value, traceback) self._conn = None - def _request_id(self) -> int: - """Get a unique id for a message.""" - self._id_counter += 1 - return self._id_counter - def _send(self, data: dict[str, JSONType]) -> None: """Send a message to the websocket server.""" logger.debug(f"Sending message: {data}") @@ -112,41 +105,6 @@ def send(self, type: str, include_id: bool = True, **data: Any) -> int: return data["id"] return -1 # non-command messages don't have an id - def check_success(self, data: dict[str, JSONType]) -> None: - """Check if a command message was successful.""" - try: - error_resp = ErrorResponse.model_validate(data) - raise RequestError(error_resp.error.code, error_resp.error.message) - except ValidationError: - pass - - def handle_recv(self, data: dict[str, JSONType]) -> None: - """Handle a received message.""" - if "id" not in data: - raise ReceivingError( - "Received a message without an id outside the auth phase." - ) - self.check_success(data) - self.parse_response(data) - - def parse_response(self, data: dict[str, JSONType]) -> None: - data_id = cast(int, data["id"]) - if data.get("type") == "pong": - logger.info("Received pong message") - self._ping_responses[data_id].end = time.perf_counter_ns() - elif data.get("type") == "result": - logger.info("Received result message") - if data.get("success"): - self._result_responses[data_id] = ResultResponse.model_validate(data) - else: - error_resp = ErrorResponse.model_validate(data) - raise RequestError(error_resp.error.code, error_resp.error.message) - elif data.get("type") == "event": - logger.info("Received event message %s", data["event"]) - self._event_responses[data_id].append(EventResponse.model_validate(data)) - else: - raise ReceivingError(f"Received unexpected message type: {data}") - def recv(self, id: int) -> Union[EventResponse, ResultResponse, PingResponse]: """Receive a response to a message from the websocket server.""" while True: diff --git a/homeassistant_api/websocket.py b/homeassistant_api/websocket.py index eb9d2c79..9af064da 100644 --- a/homeassistant_api/websocket.py +++ b/homeassistant_api/websocket.py @@ -23,10 +23,9 @@ ResultResponse, TemplateEvent, ) +from homeassistant_api.rawwebsocket import RawWebsocketClient from homeassistant_api.utils import JSONType, prepare_entity_id -from .rawwebsocket import RawWebsocketClient - logger = logging.getLogger(__name__) From dfd0349052e03218ce280ff0cc80e6a88e6e0e73 Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Wed, 10 Sep 2025 14:39:40 +0200 Subject: [PATCH 2/9] Add Async Websocket client --- homeassistant_api/__init__.py | 2 + homeassistant_api/asyncwebsocket.py | 395 +++++++++++++++++++++++++ homeassistant_api/models/domains.py | 8 +- homeassistant_api/rawasyncwebsocket.py | 155 ++++++++++ 4 files changed, 557 insertions(+), 3 deletions(-) create mode 100644 homeassistant_api/asyncwebsocket.py create mode 100644 homeassistant_api/rawasyncwebsocket.py diff --git a/homeassistant_api/__init__.py b/homeassistant_api/__init__.py index cc637912..055cc2f5 100644 --- a/homeassistant_api/__init__.py +++ b/homeassistant_api/__init__.py @@ -12,6 +12,7 @@ "Event", "LogbookEntry", "WebsocketClient", + "AsyncWebsocketClient", "AuthInvalid", "AuthOk", "AuthRequired", @@ -21,6 +22,7 @@ "EventResponse", ) +from .asyncwebsocket import AsyncWebsocketClient from .client import Client from .models.domains import Domain, Service from .models.entity import Entity, Group diff --git a/homeassistant_api/asyncwebsocket.py b/homeassistant_api/asyncwebsocket.py new file mode 100644 index 00000000..28b110b4 --- /dev/null +++ b/homeassistant_api/asyncwebsocket.py @@ -0,0 +1,395 @@ +import contextlib +import logging +import urllib.parse as urlparse +from typing import AsyncGenerator, Dict, Optional, Tuple, Union, cast + +from homeassistant_api.models import Domain, Entity, Group, State +from homeassistant_api.models.states import Context +from homeassistant_api.models.websocket import ( + EventResponse, + FiredEvent, + FiredTrigger, + ResultResponse, + TemplateEvent, +) +from homeassistant_api.rawasyncwebsocket import RawAsyncWebsocketClient +from homeassistant_api.utils import JSONType, prepare_entity_id + +logger = logging.getLogger(__name__) + + +class AsyncWebsocketClient(RawAsyncWebsocketClient): + """ + + The main class for interacting with the Async Home Assistant WebSocket API client. + + Here's a quick example of how to use the :py:class:`AsyncWebsocketClient` class: + + .. code-block:: python + + from homeassistant_api import AsyncWebsocketClient + + async with AsyncWebsocketClient( + '', # i.e. 'ws://homeassistant.local:8123/api/websocket' + '' + ) as ws_client: + light = await ws_client.trigger_service('light', 'turn_on', entity_id="light.living_room") + """ + + def __init__( + self, + api_url: str, + token: str, + ) -> None: + parsed = urlparse.urlparse(api_url) + + if parsed.scheme not in {"ws", "wss"}: + raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") + super().__init__(api_url, token) + logger.debug(f"AsyncWebsocketClient initialized with api_url: {api_url}") + + async def get_rendered_template(self, template: str) -> str: + """ + Renders a Jinja2 template with Home Assistant context data. + See https://www.home-assistant.io/docs/configuration/templating. + + Sends command :code:`{"type": "render_template", ...}`. + """ + id = await self.send("render_template", template=template, report_errors=True) + first = await self.recv(id) + assert cast(ResultResponse, first).result is None + second = await self.recv(id) + await self._unsubscribe(id) + return cast(TemplateEvent, cast(EventResponse, second).event).result + + async def get_config(self) -> dict[str, JSONType]: + """ + Get the Home Assistant configuration. + + Sends command :code:`{"type": "get_config", ...}`. + """ + return cast( + dict[str, JSONType], + cast( + ResultResponse, + await self.recv(await self.send("get_config")), + ).result, + ) + + async def get_states(self) -> Tuple[State, ...]: + """ + Get a list of states. + + Sends command :code:`{"type": "get_states", ...}`. + """ + return tuple( + State.from_json(state) + for state in cast( + list[dict[str, JSONType]], + cast( + ResultResponse, await self.recv(await self.send("get_states")) + ).result, + ) + ) + + async def get_state( # pylint: disable=duplicate-code + self, + *, + entity_id: Optional[str] = None, + group_id: Optional[str] = None, + slug: Optional[str] = None, + ) -> State: + """ + Just calls the :py:meth:`get_states` method and filters the result. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + entity_id = prepare_entity_id( + group_id=group_id, + slug=slug, + entity_id=entity_id, + ) + + for state in await self.get_states(): + if state.entity_id == entity_id: + return state + raise ValueError(f"Entity {entity_id} not found!") + + async def get_entities(self) -> Dict[str, Group]: + """ + Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. + For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). + """ + entities: Dict[str, Group] = {} + for state in await self.get_states(): + group_id, entity_slug = state.entity_id.split(".") + if group_id not in entities: + entities[group_id] = Group( + group_id=group_id, + _client=self, # type: ignore[arg-type] + ) + entities[group_id]._add_entity(entity_slug, state) + return entities + + async def get_entity( + self, + group_id: Optional[str] = None, + slug: Optional[str] = None, + entity_id: Optional[str] = None, + ) -> Optional[Entity]: + """ + Returns an :py:class:`Entity` model for an :code:`entity_id`. + + Calls :py:meth:`get_states` under the hood. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + if group_id is not None and slug is not None: + state = await self.get_state(group_id=group_id, slug=slug) + elif entity_id is not None: + state = await self.get_state(entity_id=entity_id) + else: + help_msg = ( + "Use keyword arguments to pass entity_id. " + "Or you can pass the group_id and slug instead" + ) + raise ValueError( + f"Neither group_id and slug or entity_id provided. {help_msg}" + ) + split_group_id, split_slug = state.entity_id.split(".") + group = Group( + group_id=split_group_id, + _client=self, # type: ignore[arg-type] + ) + group._add_entity(split_slug, state) + return group.get_entity(split_slug) + + async def get_domains(self) -> dict[str, Domain]: + """ + Get a list of services that Home Assistant offers (organized into a dictionary of service domains). + + For example, the service :code:`light.turn_on` would be in the domain :code:`light`. + + Sends command :code:`{"type": "get_services", ...}`. + """ + resp = await self.recv(await self.send("get_services")) + domains = map( + lambda item: Domain.from_json( + {"domain": item[0], "services": item[1]}, + client=self, + ), + cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), + ) + return {domain.domain_id: domain for domain in domains} + + async def get_domain(self, domain: str) -> Domain: + """Get a domain. + + Note: This is not a method in the WS API client... yet. + + Please tell home-assistant/core to add a `get_domain` command to the WS API! + + For now, just call the :py:meth":`get_domains` method and parsing the result. + """ + return (await self.get_domains())[domain] + + async def trigger_service( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> None: + """ + Trigger a service (that doesn't return a response). + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": False, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.recv( + await self.send("call_service", include_id=True, **params) + ) + + # TODO: handle data["result"]["context"] ? + + assert ( + cast( + dict[str, JSONType], + cast(ResultResponse, data).result, + ).get("response") + is None + ) # should always be None for services without a response + + async def trigger_service_with_response( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> dict[str, JSONType]: + """ + Trigger a service (that returns a response) and return the response. + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": True, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.recv( + await self.send("call_service", include_id=True, **params) + ) + + return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ + "response" + ] + + @contextlib.asynccontextmanager + async def listen_events( + self, + event_type: Optional[str] = None, + ) -> AsyncGenerator[AsyncGenerator[FiredEvent, None], None]: + """ + Listen for all events of a certain type. + + For example, to listen for all events of type `test_event`: + + .. code-block:: python + + async with ws_client.listen_events("test_event") as events: + async for i, event in zip(range(2), events): # to only wait for two events to be received + print(event) + """ + subscription = await self._subscribe_events(event_type) + yield cast(AsyncGenerator[FiredEvent, None], self._wait_for(subscription)) + await self._unsubscribe(subscription) + + async def _subscribe_events(self, event_type: Optional[str]) -> int: + """ + Subscribe to all events of a certain type. + + + Sends command :code:`{"type": "subscribe_events", ...}`. + """ + params = {"event_type": event_type} if event_type else {} + return ( + await self.recv( + await self.send("subscribe_events", include_id=True, **params) + ) + ).id + + @contextlib.asynccontextmanager + async def listen_trigger( + self, trigger: str, **trigger_fields + ) -> AsyncGenerator[AsyncGenerator[dict[str, JSONType], None], None]: + """ + Listen to a Home Assistant trigger. + Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). + + For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: + + .. code-block:: yaml + + triggers: + # ... + - trigger: state + entity_id: light.kitchen + + To subscribe to that same state trigger with :py:class:`AsyncWebsocketClient` instead + + .. code-block:: python + + async with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: + async for event in trigger: # will iterate until we manually break out of the loop + print(event) + if : + break + # exiting the context manager unsubscribes from the trigger + + Woohoo! We can now listen to triggers in Python code! + """ + subscription = await self._subscribe_trigger(trigger, **trigger_fields) + yield ( + fired_trigger.variables + async for fired_trigger in cast( + AsyncGenerator[FiredTrigger, None], + self._wait_for(subscription), + ) + ) + await self._unsubscribe(subscription) + + async def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: + """ + Return the subscription id of the trigger we subscribe to. + + Sends command :code:`{"type": "subscribe_trigger", ...}`. + """ + return ( + await self.recv( + await self.send( + "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} + ) + ) + ).id + + async def _wait_for( + self, subscription_id: int + ) -> AsyncGenerator[Union[FiredEvent, FiredTrigger], None]: + """ + An iterator that waits for events of a certain type. + """ + while True: + yield cast( + Union[ + FiredEvent, FiredTrigger + ], # we can cast this because TemplateEvent is only used for rendering templates + cast(EventResponse, await self.recv(subscription_id)).event, + ) + + async def _unsubscribe(self, subcription_id: int) -> None: + """ + Unsubscribe from all events of a certain type. + + Sends command :code:`{"type": "unsubscribe_events", ...}`. + """ + resp = await self.recv( + await self.send("unsubscribe_events", subscription=subcription_id) + ) + assert cast(ResultResponse, resp).result is None + self._event_responses.pop(subcription_id) + + async def fire_event(self, event_type: str, **event_data) -> Context: + """ + Fire an event. + + Sends command :code:`{"type": "fire_event", ...}`. + """ + params: dict[str, JSONType] = {"event_type": event_type} + if event_data: + params["event_data"] = event_data + return Context.from_json( + cast( + dict[str, dict[str, JSONType]], + cast( + ResultResponse, + await self.recv( + await self.send("fire_event", include_id=True, **params) + ), + ).result, + )["context"] + ) diff --git a/homeassistant_api/models/domains.py b/homeassistant_api/models/domains.py index 9c8c9505..7f2e5367 100644 --- a/homeassistant_api/models/domains.py +++ b/homeassistant_api/models/domains.py @@ -27,7 +27,7 @@ from .states import State if TYPE_CHECKING: - from homeassistant_api import Client, WebsocketClient + from homeassistant_api import Client, WebsocketClient, AsyncWebsocketClient class Domain(BaseModel): @@ -36,7 +36,7 @@ class Domain(BaseModel): def __init__( self, *args, - _client: Optional[Union["Client", "WebsocketClient"]] = None, + _client: Optional[Union["Client", "WebsocketClient", "AsyncWebsocketClient"]] = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -44,7 +44,7 @@ def __init__( raise ValueError("No client passed.") object.__setattr__(self, "_client", _client) - _client: Union["Client", "WebsocketClient"] + _client: Union["Client", "WebsocketClient", "AsyncWebsocketClient"] domain_id: str = Field( ..., description="The name of the domain that services belong to. " @@ -65,6 +65,8 @@ def from_json(cls, json: Union[dict[str, JSONType], Any, None], **kwargs) -> Sel @classmethod def from_json_with_client( cls, json: Dict[str, JSONType], client: Union["Client", "WebsocketClient"] + def from_json( + cls, json: Dict[str, JSONType], client: Union["Client", "WebsocketClient", "AsyncWebsocketClient"] ) -> "Domain": """Constructs Domain and Service models from json data.""" if "domain" not in json or "services" not in json: diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py new file mode 100644 index 00000000..462bfcd2 --- /dev/null +++ b/homeassistant_api/rawasyncwebsocket.py @@ -0,0 +1,155 @@ +import json +import logging +import time +from typing import Any, Optional, Union, cast + +import websockets.asyncio.client as ws +from pydantic import ValidationError + +from homeassistant_api.errors import ( + ReceivingError, + ResponseError, + UnauthorizedError, +) +from homeassistant_api.models.websocket import ( + AuthInvalid, + AuthOk, + AuthRequired, + EventResponse, + PingResponse, + ResultResponse, +) +from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient +from homeassistant_api.utils import JSONType + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class RawAsyncWebsocketClient(RawBaseWebsocketClient): + api_url: str + token: str + _conn: Optional[ws.ClientConnection] + + def __init__( + self, + api_url: str, + token: str, + ) -> None: + super().__init__(api_url, token) + self._conn = None + + async def __aenter__(self): + self._conn = await ws.connect(self.api_url) + await self._conn.__aenter__() + okay = await self.authentication_phase() + logging.info("Authenticated with Home Assistant (%s)", okay.ha_version) + await self.supported_features_phase() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if not self._conn: + raise ReceivingError("Connection is not open!") + await self._conn.__aexit__(exc_type, exc_value, traceback) + self._conn = None + + async def _send(self, data: dict[str, JSONType]) -> None: + """Send a message to the websocket server.""" + logger.debug(f"Sending message: {data}") + if self._conn is None: + raise ReceivingError("Connection is not open!") + await self._conn.send(json.dumps(data)) + + async def _recv(self) -> dict[str, JSONType]: + """Receive a message from the websocket server.""" + if self._conn is None: + raise ReceivingError("Connection is not open!") + _bytes = await self._conn.recv() + logger.debug("Received message: %s", _bytes) + return cast(dict[str, JSONType], json.loads(_bytes)) + + async def send(self, type: str, include_id: bool = True, **data: Any) -> int: + """ + Send a command message to the websocket server and wait for a "result" response. + + Returns the id of the message sent. + """ + if include_id: # auth messages don't have an id + data["id"] = self._request_id() + + data["type"] = type + await self._send(data) + + if "id" in data: + assert isinstance(data["id"], int) + if data["type"] == "ping": + self._ping_responses[data["id"]] = PingResponse( + start=time.perf_counter_ns(), + id=data["id"], + type="pong", + ) + else: + self._event_responses[data["id"]] = [] + self._result_responses[data["id"]] = None + return data["id"] + return -1 # non-command messages don't have an id + + async def recv(self, id: int) -> Union[EventResponse, ResultResponse, PingResponse]: + """Receive a response to a message from the websocket server.""" + while True: + ## have we received a message with the id we're looking for? + if self._result_responses.get(id) is not None: + return cast(dict[int, ResultResponse], self._result_responses).pop( + id + ) # ughhh why can't mypy figure this out + if self._event_responses.get(id, []): + return self._event_responses[id].pop(0) + if self._ping_responses.get(id) is not None: + if self._ping_responses[id].end is not None: + return self._ping_responses.pop(id) + + ## if not, keep receiving messages until we do + self.handle_recv(await self._recv()) + + async def authentication_phase(self) -> AuthOk: + """Authenticate with the websocket server.""" + # Capture the first message from the server saying we need to authenticate + try: + welcome = AuthRequired.model_validate(await self._recv()) + logger.debug(f"Received welcome message: {welcome}") + except ValidationError as e: + raise ResponseError("Unexpected response during authentication") from e + + # Send our authentication token + await self.send("auth", access_token=self.token, include_id=False) + logger.debug("Sent auth message") + + # Check the response + resp = await self._recv() + try: + return AuthOk.model_validate(resp) + except ValidationError as e: + error_resp = AuthInvalid.model_validate(resp) + raise UnauthorizedError(error_resp.message) from e + except Exception as e: + raise ResponseError( + "Unexpected response during authentication", resp["message"] + ) from e + + async def supported_features_phase(self) -> None: + """Get the supported features from the websocket server.""" + resp = await self.recv( + await self.send( + "supported_features", + features={ + # "coalesce_messages": 42, # including this key sets it to True + }, + ) + ) + assert cast(ResultResponse, resp).result is None + + async def ping_latency(self) -> float: + """Get the latency (in milliseconds) of the connection by sending a ping message.""" + pong = cast(PingResponse, await self.recv(await self.send("ping"))) + assert pong.end is not None + return (pong.end - pong.start) / 1_000_000 From 85f2e60dfb4ff223e06606086c69bca6344cdd5a Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Mon, 15 Sep 2025 10:50:20 +0200 Subject: [PATCH 3/9] Rework structure and fix typing --- homeassistant_api/__init__.py | 2 - homeassistant_api/asyncwebsocket.py | 395 ------------------ homeassistant_api/models/domains.py | 30 +- homeassistant_api/rawasyncwebsocket.py | 438 ++++++++++++++++++-- homeassistant_api/rawwebsocket.py | 351 +++++++++++++++- homeassistant_api/websocket.py | 531 +------------------------ 6 files changed, 777 insertions(+), 970 deletions(-) delete mode 100644 homeassistant_api/asyncwebsocket.py diff --git a/homeassistant_api/__init__.py b/homeassistant_api/__init__.py index 055cc2f5..cc637912 100644 --- a/homeassistant_api/__init__.py +++ b/homeassistant_api/__init__.py @@ -12,7 +12,6 @@ "Event", "LogbookEntry", "WebsocketClient", - "AsyncWebsocketClient", "AuthInvalid", "AuthOk", "AuthRequired", @@ -22,7 +21,6 @@ "EventResponse", ) -from .asyncwebsocket import AsyncWebsocketClient from .client import Client from .models.domains import Domain, Service from .models.entity import Entity, Group diff --git a/homeassistant_api/asyncwebsocket.py b/homeassistant_api/asyncwebsocket.py deleted file mode 100644 index 28b110b4..00000000 --- a/homeassistant_api/asyncwebsocket.py +++ /dev/null @@ -1,395 +0,0 @@ -import contextlib -import logging -import urllib.parse as urlparse -from typing import AsyncGenerator, Dict, Optional, Tuple, Union, cast - -from homeassistant_api.models import Domain, Entity, Group, State -from homeassistant_api.models.states import Context -from homeassistant_api.models.websocket import ( - EventResponse, - FiredEvent, - FiredTrigger, - ResultResponse, - TemplateEvent, -) -from homeassistant_api.rawasyncwebsocket import RawAsyncWebsocketClient -from homeassistant_api.utils import JSONType, prepare_entity_id - -logger = logging.getLogger(__name__) - - -class AsyncWebsocketClient(RawAsyncWebsocketClient): - """ - - The main class for interacting with the Async Home Assistant WebSocket API client. - - Here's a quick example of how to use the :py:class:`AsyncWebsocketClient` class: - - .. code-block:: python - - from homeassistant_api import AsyncWebsocketClient - - async with AsyncWebsocketClient( - '', # i.e. 'ws://homeassistant.local:8123/api/websocket' - '' - ) as ws_client: - light = await ws_client.trigger_service('light', 'turn_on', entity_id="light.living_room") - """ - - def __init__( - self, - api_url: str, - token: str, - ) -> None: - parsed = urlparse.urlparse(api_url) - - if parsed.scheme not in {"ws", "wss"}: - raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") - super().__init__(api_url, token) - logger.debug(f"AsyncWebsocketClient initialized with api_url: {api_url}") - - async def get_rendered_template(self, template: str) -> str: - """ - Renders a Jinja2 template with Home Assistant context data. - See https://www.home-assistant.io/docs/configuration/templating. - - Sends command :code:`{"type": "render_template", ...}`. - """ - id = await self.send("render_template", template=template, report_errors=True) - first = await self.recv(id) - assert cast(ResultResponse, first).result is None - second = await self.recv(id) - await self._unsubscribe(id) - return cast(TemplateEvent, cast(EventResponse, second).event).result - - async def get_config(self) -> dict[str, JSONType]: - """ - Get the Home Assistant configuration. - - Sends command :code:`{"type": "get_config", ...}`. - """ - return cast( - dict[str, JSONType], - cast( - ResultResponse, - await self.recv(await self.send("get_config")), - ).result, - ) - - async def get_states(self) -> Tuple[State, ...]: - """ - Get a list of states. - - Sends command :code:`{"type": "get_states", ...}`. - """ - return tuple( - State.from_json(state) - for state in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, await self.recv(await self.send("get_states")) - ).result, - ) - ) - - async def get_state( # pylint: disable=duplicate-code - self, - *, - entity_id: Optional[str] = None, - group_id: Optional[str] = None, - slug: Optional[str] = None, - ) -> State: - """ - Just calls the :py:meth:`get_states` method and filters the result. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - entity_id = prepare_entity_id( - group_id=group_id, - slug=slug, - entity_id=entity_id, - ) - - for state in await self.get_states(): - if state.entity_id == entity_id: - return state - raise ValueError(f"Entity {entity_id} not found!") - - async def get_entities(self) -> Dict[str, Group]: - """ - Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. - For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). - """ - entities: Dict[str, Group] = {} - for state in await self.get_states(): - group_id, entity_slug = state.entity_id.split(".") - if group_id not in entities: - entities[group_id] = Group( - group_id=group_id, - _client=self, # type: ignore[arg-type] - ) - entities[group_id]._add_entity(entity_slug, state) - return entities - - async def get_entity( - self, - group_id: Optional[str] = None, - slug: Optional[str] = None, - entity_id: Optional[str] = None, - ) -> Optional[Entity]: - """ - Returns an :py:class:`Entity` model for an :code:`entity_id`. - - Calls :py:meth:`get_states` under the hood. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - if group_id is not None and slug is not None: - state = await self.get_state(group_id=group_id, slug=slug) - elif entity_id is not None: - state = await self.get_state(entity_id=entity_id) - else: - help_msg = ( - "Use keyword arguments to pass entity_id. " - "Or you can pass the group_id and slug instead" - ) - raise ValueError( - f"Neither group_id and slug or entity_id provided. {help_msg}" - ) - split_group_id, split_slug = state.entity_id.split(".") - group = Group( - group_id=split_group_id, - _client=self, # type: ignore[arg-type] - ) - group._add_entity(split_slug, state) - return group.get_entity(split_slug) - - async def get_domains(self) -> dict[str, Domain]: - """ - Get a list of services that Home Assistant offers (organized into a dictionary of service domains). - - For example, the service :code:`light.turn_on` would be in the domain :code:`light`. - - Sends command :code:`{"type": "get_services", ...}`. - """ - resp = await self.recv(await self.send("get_services")) - domains = map( - lambda item: Domain.from_json( - {"domain": item[0], "services": item[1]}, - client=self, - ), - cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), - ) - return {domain.domain_id: domain for domain in domains} - - async def get_domain(self, domain: str) -> Domain: - """Get a domain. - - Note: This is not a method in the WS API client... yet. - - Please tell home-assistant/core to add a `get_domain` command to the WS API! - - For now, just call the :py:meth":`get_domains` method and parsing the result. - """ - return (await self.get_domains())[domain] - - async def trigger_service( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> None: - """ - Trigger a service (that doesn't return a response). - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": False, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = await self.recv( - await self.send("call_service", include_id=True, **params) - ) - - # TODO: handle data["result"]["context"] ? - - assert ( - cast( - dict[str, JSONType], - cast(ResultResponse, data).result, - ).get("response") - is None - ) # should always be None for services without a response - - async def trigger_service_with_response( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> dict[str, JSONType]: - """ - Trigger a service (that returns a response) and return the response. - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": True, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = await self.recv( - await self.send("call_service", include_id=True, **params) - ) - - return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ - "response" - ] - - @contextlib.asynccontextmanager - async def listen_events( - self, - event_type: Optional[str] = None, - ) -> AsyncGenerator[AsyncGenerator[FiredEvent, None], None]: - """ - Listen for all events of a certain type. - - For example, to listen for all events of type `test_event`: - - .. code-block:: python - - async with ws_client.listen_events("test_event") as events: - async for i, event in zip(range(2), events): # to only wait for two events to be received - print(event) - """ - subscription = await self._subscribe_events(event_type) - yield cast(AsyncGenerator[FiredEvent, None], self._wait_for(subscription)) - await self._unsubscribe(subscription) - - async def _subscribe_events(self, event_type: Optional[str]) -> int: - """ - Subscribe to all events of a certain type. - - - Sends command :code:`{"type": "subscribe_events", ...}`. - """ - params = {"event_type": event_type} if event_type else {} - return ( - await self.recv( - await self.send("subscribe_events", include_id=True, **params) - ) - ).id - - @contextlib.asynccontextmanager - async def listen_trigger( - self, trigger: str, **trigger_fields - ) -> AsyncGenerator[AsyncGenerator[dict[str, JSONType], None], None]: - """ - Listen to a Home Assistant trigger. - Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). - - For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: - - .. code-block:: yaml - - triggers: - # ... - - trigger: state - entity_id: light.kitchen - - To subscribe to that same state trigger with :py:class:`AsyncWebsocketClient` instead - - .. code-block:: python - - async with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: - async for event in trigger: # will iterate until we manually break out of the loop - print(event) - if : - break - # exiting the context manager unsubscribes from the trigger - - Woohoo! We can now listen to triggers in Python code! - """ - subscription = await self._subscribe_trigger(trigger, **trigger_fields) - yield ( - fired_trigger.variables - async for fired_trigger in cast( - AsyncGenerator[FiredTrigger, None], - self._wait_for(subscription), - ) - ) - await self._unsubscribe(subscription) - - async def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: - """ - Return the subscription id of the trigger we subscribe to. - - Sends command :code:`{"type": "subscribe_trigger", ...}`. - """ - return ( - await self.recv( - await self.send( - "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} - ) - ) - ).id - - async def _wait_for( - self, subscription_id: int - ) -> AsyncGenerator[Union[FiredEvent, FiredTrigger], None]: - """ - An iterator that waits for events of a certain type. - """ - while True: - yield cast( - Union[ - FiredEvent, FiredTrigger - ], # we can cast this because TemplateEvent is only used for rendering templates - cast(EventResponse, await self.recv(subscription_id)).event, - ) - - async def _unsubscribe(self, subcription_id: int) -> None: - """ - Unsubscribe from all events of a certain type. - - Sends command :code:`{"type": "unsubscribe_events", ...}`. - """ - resp = await self.recv( - await self.send("unsubscribe_events", subscription=subcription_id) - ) - assert cast(ResultResponse, resp).result is None - self._event_responses.pop(subcription_id) - - async def fire_event(self, event_type: str, **event_data) -> Context: - """ - Fire an event. - - Sends command :code:`{"type": "fire_event", ...}`. - """ - params: dict[str, JSONType] = {"event_type": event_type} - if event_data: - params["event_data"] = event_data - return Context.from_json( - cast( - dict[str, dict[str, JSONType]], - cast( - ResultResponse, - await self.recv( - await self.send("fire_event", include_id=True, **params) - ), - ).result, - )["context"] - ) diff --git a/homeassistant_api/models/domains.py b/homeassistant_api/models/domains.py index 7f2e5367..0499fc44 100644 --- a/homeassistant_api/models/domains.py +++ b/homeassistant_api/models/domains.py @@ -27,7 +27,7 @@ from .states import State if TYPE_CHECKING: - from homeassistant_api import Client, WebsocketClient, AsyncWebsocketClient + from homeassistant_api import Client, WebsocketClient class Domain(BaseModel): @@ -36,7 +36,7 @@ class Domain(BaseModel): def __init__( self, *args, - _client: Optional[Union["Client", "WebsocketClient", "AsyncWebsocketClient"]] = None, + _client: Optional[Union["Client", "WebsocketClient"]] = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) @@ -44,7 +44,7 @@ def __init__( raise ValueError("No client passed.") object.__setattr__(self, "_client", _client) - _client: Union["Client", "WebsocketClient", "AsyncWebsocketClient"] + _client: Union["Client", "WebsocketClient"] domain_id: str = Field( ..., description="The name of the domain that services belong to. " @@ -66,7 +66,9 @@ def from_json(cls, json: Union[dict[str, JSONType], Any, None], **kwargs) -> Sel def from_json_with_client( cls, json: Dict[str, JSONType], client: Union["Client", "WebsocketClient"] def from_json( - cls, json: Dict[str, JSONType], client: Union["Client", "WebsocketClient", "AsyncWebsocketClient"] + cls, + json: Dict[str, JSONType], + client: Union["Client", "WebsocketClient"], ) -> "Domain": """Constructs Domain and Service models from json data.""" if "domain" not in json or "services" not in json: @@ -616,14 +618,13 @@ def trigger( async def async_trigger( self, **service_data - ) -> Union[Tuple[State, ...], Tuple[Tuple[State, ...], dict[str, JSONType]]]: + ) -> Union[ + Tuple[State, ...], + None, + dict[str, JSONType], + tuple[tuple[State, ...], dict[str, JSONType]], + ]: """Triggers the service associated with this object.""" - from homeassistant_api import WebsocketClient # prevent circular import - - if isinstance(self.domain._client, WebsocketClient): - raise NotImplementedError( - "WebsocketClient does not support async/await syntax." - ) try: return await self.domain._client.async_trigger_service_with_response( self.domain.domain_id, @@ -649,7 +650,12 @@ def __call__( Coroutine[ Any, Any, - Union[Tuple[State, ...], Tuple[Tuple[State, ...], dict[str, JSONType]]], + Union[ + Tuple[State, ...], + Tuple[Tuple[State, ...], dict[str, JSONType]], + dict[str, JSONType], + None, + ], ], ]: """ diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py index 462bfcd2..352d18f3 100644 --- a/homeassistant_api/rawasyncwebsocket.py +++ b/homeassistant_api/rawasyncwebsocket.py @@ -1,7 +1,17 @@ +import contextlib import json import logging import time -from typing import Any, Optional, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Optional, + Tuple, + Union, + cast, +) import websockets.asyncio.client as ws from pydantic import ValidationError @@ -11,64 +21,65 @@ ResponseError, UnauthorizedError, ) +from homeassistant_api.models import Domain, Entity, Group, State +from homeassistant_api.models.states import Context from homeassistant_api.models.websocket import ( AuthInvalid, AuthOk, AuthRequired, EventResponse, + FiredEvent, + FiredTrigger, PingResponse, ResultResponse, + TemplateEvent, ) from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient -from homeassistant_api.utils import JSONType +from homeassistant_api.utils import JSONType, prepare_entity_id + +if TYPE_CHECKING: + from homeassistant_api import WebsocketClient logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) class RawAsyncWebsocketClient(RawBaseWebsocketClient): - api_url: str - token: str - _conn: Optional[ws.ClientConnection] + _async_conn: Optional[ws.ClientConnection] - def __init__( - self, - api_url: str, - token: str, - ) -> None: + def __init__(self, api_url: str, token: str) -> None: super().__init__(api_url, token) - self._conn = None + self._async_conn = None async def __aenter__(self): - self._conn = await ws.connect(self.api_url) - await self._conn.__aenter__() - okay = await self.authentication_phase() + self._async_conn = await ws.connect(self.api_url) + await self._async_conn.__aenter__() + okay = await self.async_authentication_phase() logging.info("Authenticated with Home Assistant (%s)", okay.ha_version) - await self.supported_features_phase() + await self.async_supported_features_phase() return self async def __aexit__(self, exc_type, exc_value, traceback): - if not self._conn: + if not self._async_conn: raise ReceivingError("Connection is not open!") - await self._conn.__aexit__(exc_type, exc_value, traceback) - self._conn = None + await self._async_conn.__aexit__(exc_type, exc_value, traceback) + self._async_conn = None - async def _send(self, data: dict[str, JSONType]) -> None: + async def _async_send(self, data: dict[str, JSONType]) -> None: """Send a message to the websocket server.""" logger.debug(f"Sending message: {data}") - if self._conn is None: + if self._async_conn is None: raise ReceivingError("Connection is not open!") - await self._conn.send(json.dumps(data)) + await self._async_conn.send(json.dumps(data)) - async def _recv(self) -> dict[str, JSONType]: + async def _async_recv(self) -> dict[str, JSONType]: """Receive a message from the websocket server.""" - if self._conn is None: + if self._async_conn is None: raise ReceivingError("Connection is not open!") - _bytes = await self._conn.recv() + _bytes = await self._async_conn.recv() logger.debug("Received message: %s", _bytes) return cast(dict[str, JSONType], json.loads(_bytes)) - async def send(self, type: str, include_id: bool = True, **data: Any) -> int: + async def async_send(self, type: str, include_id: bool = True, **data: Any) -> int: """ Send a command message to the websocket server and wait for a "result" response. @@ -78,7 +89,7 @@ async def send(self, type: str, include_id: bool = True, **data: Any) -> int: data["id"] = self._request_id() data["type"] = type - await self._send(data) + await self._async_send(data) if "id" in data: assert isinstance(data["id"], int) @@ -94,7 +105,9 @@ async def send(self, type: str, include_id: bool = True, **data: Any) -> int: return data["id"] return -1 # non-command messages don't have an id - async def recv(self, id: int) -> Union[EventResponse, ResultResponse, PingResponse]: + async def async_recv( + self, id: int + ) -> Union[EventResponse, ResultResponse, PingResponse]: """Receive a response to a message from the websocket server.""" while True: ## have we received a message with the id we're looking for? @@ -109,23 +122,23 @@ async def recv(self, id: int) -> Union[EventResponse, ResultResponse, PingRespon return self._ping_responses.pop(id) ## if not, keep receiving messages until we do - self.handle_recv(await self._recv()) + self.handle_recv(await self._async_recv()) - async def authentication_phase(self) -> AuthOk: + async def async_authentication_phase(self) -> AuthOk: """Authenticate with the websocket server.""" # Capture the first message from the server saying we need to authenticate try: - welcome = AuthRequired.model_validate(await self._recv()) + welcome = AuthRequired.model_validate(await self._async_recv()) logger.debug(f"Received welcome message: {welcome}") except ValidationError as e: raise ResponseError("Unexpected response during authentication") from e # Send our authentication token - await self.send("auth", access_token=self.token, include_id=False) + await self.async_send("auth", access_token=self.token, include_id=False) logger.debug("Sent auth message") # Check the response - resp = await self._recv() + resp = await self._async_recv() try: return AuthOk.model_validate(resp) except ValidationError as e: @@ -136,10 +149,10 @@ async def authentication_phase(self) -> AuthOk: "Unexpected response during authentication", resp["message"] ) from e - async def supported_features_phase(self) -> None: + async def async_supported_features_phase(self) -> None: """Get the supported features from the websocket server.""" - resp = await self.recv( - await self.send( + resp = await self.async_recv( + await self.async_send( "supported_features", features={ # "coalesce_messages": 42, # including this key sets it to True @@ -148,8 +161,357 @@ async def supported_features_phase(self) -> None: ) assert cast(ResultResponse, resp).result is None - async def ping_latency(self) -> float: + async def async_ping_latency(self) -> float: """Get the latency (in milliseconds) of the connection by sending a ping message.""" - pong = cast(PingResponse, await self.recv(await self.send("ping"))) + pong = cast(PingResponse, await self.async_recv(await self.async_send("ping"))) assert pong.end is not None return (pong.end - pong.start) / 1_000_000 + + async def async_get_rendered_template(self, template: str) -> str: + """ + Renders a Jinja2 template with Home Assistant context data. + See https://www.home-assistant.io/docs/configuration/templating. + + Sends command :code:`{"type": "render_template", ...}`. + """ + id = await self.async_send( + "render_template", template=template, report_errors=True + ) + first = await self.async_recv(id) + assert cast(ResultResponse, first).result is None + second = await self.async_recv(id) + await self._async_unsubscribe(id) + return cast(TemplateEvent, cast(EventResponse, second).event).result + + async def async_get_config(self) -> dict[str, JSONType]: + """ + Get the Home Assistant configuration. + + Sends command :code:`{"type": "get_config", ...}`. + """ + return cast( + dict[str, JSONType], + cast( + ResultResponse, + await self.async_recv(await self.async_send("get_config")), + ).result, + ) + + async def async_get_states(self) -> Tuple[State, ...]: + """ + Get a list of states. + + Sends command :code:`{"type": "get_states", ...}`. + """ + return tuple( + State.from_json(state) + for state in cast( + list[dict[str, JSONType]], + cast( + ResultResponse, + await self.async_recv(await self.async_send("get_states")), + ).result, + ) + ) + + async def async_get_state( # pylint: disable=duplicate-code + self, + *, + entity_id: Optional[str] = None, + group_id: Optional[str] = None, + slug: Optional[str] = None, + ) -> State: + """ + Just calls the :py:meth:`get_states` method and filters the result. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + entity_id = prepare_entity_id( + group_id=group_id, + slug=slug, + entity_id=entity_id, + ) + + for state in await self.async_get_states(): + if state.entity_id == entity_id: + return state + raise ValueError(f"Entity {entity_id} not found!") + + async def async_get_entities(self) -> Dict[str, Group]: + """ + Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. + For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). + """ + entities: Dict[str, Group] = {} + for state in await self.async_get_states(): + group_id, entity_slug = state.entity_id.split(".") + if group_id not in entities: + entities[group_id] = Group( + group_id=group_id, + _client=self, # type: ignore[arg-type] + ) + entities[group_id]._add_entity(entity_slug, state) + return entities + + async def async_get_entity( + self, + group_id: Optional[str] = None, + slug: Optional[str] = None, + entity_id: Optional[str] = None, + ) -> Optional[Entity]: + """ + Returns an :py:class:`Entity` model for an :code:`entity_id`. + + Calls :py:meth:`get_states` under the hood. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + if group_id is not None and slug is not None: + state = await self.async_get_state(group_id=group_id, slug=slug) + elif entity_id is not None: + state = await self.async_get_state(entity_id=entity_id) + else: + help_msg = ( + "Use keyword arguments to pass entity_id. " + "Or you can pass the group_id and slug instead" + ) + raise ValueError( + f"Neither group_id and slug or entity_id provided. {help_msg}" + ) + split_group_id, split_slug = state.entity_id.split(".") + group = Group( + group_id=split_group_id, + _client=self, # type: ignore[arg-type] + ) + group._add_entity(split_slug, state) + return group.get_entity(split_slug) + + async def async_get_domains(self) -> dict[str, Domain]: + """ + Get a list of services that Home Assistant offers (organized into a dictionary of service domains). + + For example, the service :code:`light.turn_on` would be in the domain :code:`light`. + + Sends command :code:`{"type": "get_services", ...}`. + """ + resp = await self.async_recv(await self.async_send("get_services")) + domains = map( + lambda item: Domain.from_json( + {"domain": item[0], "services": item[1]}, + client=cast(WebsocketClient, self), + ), + cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), + ) + return {domain.domain_id: domain for domain in domains} + + async def async_get_domain(self, domain: str) -> Domain: + """Get a domain. + + Note: This is not a method in the WS API client... yet. + + Please tell home-assistant/core to add a `get_domain` command to the WS API! + + For now, just call the :py:meth":`get_domains` method and parsing the result. + """ + return (await self.async_get_domains())[domain] + + async def async_trigger_service( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> None: + """ + Trigger a service (that doesn't return a response). + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": False, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.async_recv( + await self.async_send("call_service", include_id=True, **params) + ) + + # TODO: handle data["result"]["context"] ? + + assert ( + cast( + dict[str, JSONType], + cast(ResultResponse, data).result, + ).get("response") + is None + ) # should always be None for services without a response + + async def async_trigger_service_with_response( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> dict[str, JSONType]: + """ + Trigger a service (that returns a response) and return the response. + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": True, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = await self.async_recv( + await self.async_send("call_service", include_id=True, **params) + ) + + return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ + "response" + ] + + @contextlib.asynccontextmanager + async def async_listen_events( + self, + event_type: Optional[str] = None, + ) -> AsyncGenerator[AsyncGenerator[FiredEvent, None], None]: + """ + Listen for all events of a certain type. + + For example, to listen for all events of type `test_event`: + + .. code-block:: python + + async with ws_client.listen_events("test_event") as events: + async for i, event in zip(range(2), events): # to only wait for two events to be received + print(event) + """ + subscription = await self._async_subscribe_events(event_type) + yield cast(AsyncGenerator[FiredEvent, None], self._async_wait_for(subscription)) + await self._async_unsubscribe(subscription) + + async def _async_subscribe_events(self, event_type: Optional[str]) -> int: + """ + Subscribe to all events of a certain type. + + + Sends command :code:`{"type": "subscribe_events", ...}`. + """ + params = {"event_type": event_type} if event_type else {} + return ( + await self.async_recv( + await self.async_send("subscribe_events", include_id=True, **params) + ) + ).id + + @contextlib.asynccontextmanager + async def async_listen_trigger( + self, trigger: str, **trigger_fields + ) -> AsyncGenerator[AsyncGenerator[dict[str, JSONType], None], None]: + """ + Listen to a Home Assistant trigger. + Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). + + For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: + + .. code-block:: yaml + + triggers: + # ... + - trigger: state + entity_id: light.kitchen + + To subscribe to that same state trigger with :py:class:`AsyncWebsocketClient` instead + + .. code-block:: python + + async with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: + async for event in trigger: # will iterate until we manually break out of the loop + print(event) + if : + break + # exiting the context manager unsubscribes from the trigger + + Woohoo! We can now listen to triggers in Python code! + """ + subscription = await self._async_subscribe_trigger(trigger, **trigger_fields) + yield ( + fired_trigger.variables + async for fired_trigger in cast( + AsyncGenerator[FiredTrigger, None], + self._async_wait_for(subscription), + ) + ) + await self._async_unsubscribe(subscription) + + async def _async_subscribe_trigger(self, trigger: str, **trigger_fields) -> int: + """ + Return the subscription id of the trigger we subscribe to. + + Sends command :code:`{"type": "subscribe_trigger", ...}`. + """ + return ( + await self.async_recv( + await self.async_send( + "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} + ) + ) + ).id + + async def _async_wait_for( + self, subscription_id: int + ) -> AsyncGenerator[Union[FiredEvent, FiredTrigger], None]: + """ + An iterator that waits for events of a certain type. + """ + while True: + yield cast( + Union[ + FiredEvent, FiredTrigger + ], # we can cast this because TemplateEvent is only used for rendering templates + cast(EventResponse, await self.async_recv(subscription_id)).event, + ) + + async def _async_unsubscribe(self, subcription_id: int) -> None: + """ + Unsubscribe from all events of a certain type. + + Sends command :code:`{"type": "unsubscribe_events", ...}`. + """ + resp = await self.async_recv( + await self.async_send("unsubscribe_events", subscription=subcription_id) + ) + assert cast(ResultResponse, resp).result is None + self._event_responses.pop(subcription_id) + + async def async_fire_event(self, event_type: str, **event_data) -> Context: + """ + Fire an event. + + Sends command :code:`{"type": "fire_event", ...}`. + """ + params: dict[str, JSONType] = {"event_type": event_type} + if event_data: + params["event_data"] = event_data + return Context.from_json( + cast( + dict[str, dict[str, JSONType]], + cast( + ResultResponse, + await self.async_recv( + await self.async_send("fire_event", include_id=True, **params) + ), + ).result, + )["context"] + ) diff --git a/homeassistant_api/rawwebsocket.py b/homeassistant_api/rawwebsocket.py index c14560dd..ef00fdb5 100644 --- a/homeassistant_api/rawwebsocket.py +++ b/homeassistant_api/rawwebsocket.py @@ -1,7 +1,8 @@ +import contextlib import json import logging import time -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union, cast import websockets.sync.client as ws from pydantic import ValidationError @@ -11,30 +12,32 @@ ResponseError, UnauthorizedError, ) +from homeassistant_api.models import Domain, Entity, Group, State +from homeassistant_api.models.states import Context from homeassistant_api.models.websocket import ( AuthInvalid, AuthOk, AuthRequired, EventResponse, + FiredEvent, + FiredTrigger, PingResponse, ResultResponse, + TemplateEvent, ) from homeassistant_api.rawbasewebsocket import RawBaseWebsocketClient -from homeassistant_api.utils import JSONType +from homeassistant_api.utils import JSONType, prepare_entity_id + +if TYPE_CHECKING: + from homeassistant_api import WebsocketClient logger = logging.getLogger(__name__) class RawWebsocketClient(RawBaseWebsocketClient): - api_url: str - token: str _conn: Optional[ws.ClientConnection] - def __init__( - self, - api_url: str, - token: str, - ) -> None: + def __init__(self, api_url: str, token: str) -> None: super().__init__(api_url, token) self._conn = None @@ -164,3 +167,333 @@ def ping_latency(self) -> float: pong = cast(PingResponse, self.recv(self.send("ping"))) assert pong.end is not None return (pong.end - pong.start) / 1_000_000 + + def get_rendered_template(self, template: str) -> str: + """ + Renders a Jinja2 template with Home Assistant context data. + See https://www.home-assistant.io/docs/configuration/templating. + + Sends command :code:`{"type": "render_template", ...}`. + """ + id = self.send("render_template", template=template, report_errors=True) + first = self.recv(id) + assert cast(ResultResponse, first).result is None + second = self.recv(id) + self._unsubscribe(id) + return cast(TemplateEvent, cast(EventResponse, second).event).result + + def get_config(self) -> dict[str, JSONType]: + """ + Get the Home Assistant configuration. + + Sends command :code:`{"type": "get_config", ...}`. + """ + return cast( + dict[str, JSONType], + cast( + ResultResponse, + self.recv(self.send("get_config")), + ).result, + ) + + def get_states(self) -> Tuple[State, ...]: + """ + Get a list of states. + + Sends command :code:`{"type": "get_states", ...}`. + """ + return tuple( + State.from_json(state) + for state in cast( + list[dict[str, JSONType]], + cast(ResultResponse, self.recv(self.send("get_states"))).result, + ) + ) + + def get_state( # pylint: disable=duplicate-code + self, + *, + entity_id: Optional[str] = None, + group_id: Optional[str] = None, + slug: Optional[str] = None, + ) -> State: + """ + Just calls the :py:meth:`get_states` method and filters the result. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + entity_id = prepare_entity_id( + group_id=group_id, + slug=slug, + entity_id=entity_id, + ) + + for state in self.get_states(): + if state.entity_id == entity_id: + return state + raise ValueError(f"Entity {entity_id} not found!") + + def get_entities(self) -> Dict[str, Group]: + """ + Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. + For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). + """ + entities: Dict[str, Group] = {} + for state in self.get_states(): + group_id, entity_slug = state.entity_id.split(".") + if group_id not in entities: + entities[group_id] = Group( + group_id=group_id, + _client=self, # type: ignore[arg-type] + ) + entities[group_id]._add_entity(entity_slug, state) + return entities + + def get_entity( + self, + group_id: Optional[str] = None, + slug: Optional[str] = None, + entity_id: Optional[str] = None, + ) -> Optional[Entity]: + """ + Returns an :py:class:`Entity` model for an :code:`entity_id`. + + Calls :py:meth:`get_states` under the hood. + + Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! + There is a lot of disappointment and frustration in the community because this is not available. + """ + if group_id is not None and slug is not None: + state = self.get_state(group_id=group_id, slug=slug) + elif entity_id is not None: + state = self.get_state(entity_id=entity_id) + else: + help_msg = ( + "Use keyword arguments to pass entity_id. " + "Or you can pass the group_id and slug instead" + ) + raise ValueError( + f"Neither group_id and slug or entity_id provided. {help_msg}" + ) + split_group_id, split_slug = state.entity_id.split(".") + group = Group( + group_id=split_group_id, + _client=self, # type: ignore[arg-type] + ) + group._add_entity(split_slug, state) + return group.get_entity(split_slug) + + def get_domains(self) -> dict[str, Domain]: + """ + Get a list of services that Home Assistant offers (organized into a dictionary of service domains). + + For example, the service :code:`light.turn_on` would be in the domain :code:`light`. + + Sends command :code:`{"type": "get_services", ...}`. + """ + resp = self.recv(self.send("get_services")) + domains = map( + lambda item: Domain.from_json( + {"domain": item[0], "services": item[1]}, + client=cast(WebsocketClient, self), + ), + cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), + ) + return {domain.domain_id: domain for domain in domains} + + def get_domain(self, domain: str) -> Domain: + """Get a domain. + + Note: This is not a method in the WS API client... yet. + + Please tell home-assistant/core to add a `get_domain` command to the WS API! + + For now, just call the :py:meth":`get_domains` method and parsing the result. + """ + return self.get_domains()[domain] + + def trigger_service( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> None: + """ + Trigger a service (that doesn't return a response). + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": False, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = self.recv(self.send("call_service", include_id=True, **params)) + + # TODO: handle data["result"]["context"] ? + + assert ( + cast( + dict[str, JSONType], + cast(ResultResponse, data).result, + ).get("response") + is None + ) # should always be None for services without a response + + def trigger_service_with_response( + self, + domain: str, + service: str, + entity_id: Optional[str] = None, + **service_data, + ) -> dict[str, JSONType]: + """ + Trigger a service (that returns a response) and return the response. + + Sends command :code:`{"type": "call_service", ...}`. + """ + params = { + "domain": domain, + "service": service, + "service_data": service_data, + "return_response": True, + } + if entity_id is not None: + params["target"] = {"entity_id": entity_id} + + data = self.recv(self.send("call_service", include_id=True, **params)) + + return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ + "response" + ] + + @contextlib.contextmanager + def listen_events( + self, + event_type: Optional[str] = None, + ) -> Generator[Generator[FiredEvent, None, None], None, None]: + """ + Listen for all events of a certain type. + + For example, to listen for all events of type `test_event`: + + .. code-block:: python + + with ws_client.listen_events("test_event") as events: + for i, event in zip(range(2), events): # to only wait for two events to be received + print(event) + """ + subscription = self._subscribe_events(event_type) + yield cast(Generator[FiredEvent, None, None], self._wait_for(subscription)) + self._unsubscribe(subscription) + + def _subscribe_events(self, event_type: Optional[str]) -> int: + """ + Subscribe to all events of a certain type. + + + Sends command :code:`{"type": "subscribe_events", ...}`. + """ + params = {"event_type": event_type} if event_type else {} + return self.recv(self.send("subscribe_events", include_id=True, **params)).id + + @contextlib.contextmanager + def listen_trigger( + self, trigger: str, **trigger_fields + ) -> Generator[Generator[dict[str, JSONType], None, None], None, None]: + """ + Listen to a Home Assistant trigger. + Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). + + For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: + + .. code-block:: yaml + + triggers: + # ... + - trigger: state + entity_id: light.kitchen + + To subscribe to that same state trigger with :py:class:`WebsocketClient` instead + + .. code-block:: python + + with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: + for event in trigger: # will iterate until we manually break out of the loop + print(event) + if : + break + # exiting the context manager unsubscribes from the trigger + + Woohoo! We can now listen to triggers in Python code! + """ + subscription = self._subscribe_trigger(trigger, **trigger_fields) + yield ( + fired_trigger.variables + for fired_trigger in cast( + Generator[FiredTrigger, None, None], + self._wait_for(subscription), + ) + ) + self._unsubscribe(subscription) + + def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: + """ + Return the subscription id of the trigger we subscribe to. + + Sends command :code:`{"type": "subscribe_trigger", ...}`. + """ + return self.recv( + self.send( + "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} + ) + ).id + + def _wait_for( + self, subscription_id: int + ) -> Generator[Union[FiredEvent, FiredTrigger], None, None]: + """ + An iterator that waits for events of a certain type. + """ + while True: + yield cast( + Union[ + FiredEvent, FiredTrigger + ], # we can cast this because TemplateEvent is only used for rendering templates + cast(EventResponse, self.recv(subscription_id)).event, + ) + + def _unsubscribe(self, subcription_id: int) -> None: + """ + Unsubscribe from all events of a certain type. + + Sends command :code:`{"type": "unsubscribe_events", ...}`. + """ + resp = self.recv(self.send("unsubscribe_events", subscription=subcription_id)) + assert cast(ResultResponse, resp).result is None + self._event_responses.pop(subcription_id) + + def fire_event(self, event_type: str, **event_data) -> Context: + """ + Fire an event. + + Sends command :code:`{"type": "fire_event", ...}`. + """ + params: dict[str, JSONType] = {"event_type": event_type} + if event_data: + params["event_data"] = event_data + return Context.from_json( + cast( + dict[str, dict[str, JSONType]], + cast( + ResultResponse, + self.recv(self.send("fire_event", include_id=True, **params)), + ).result, + )["context"] + ) diff --git a/homeassistant_api/websocket.py b/homeassistant_api/websocket.py index 9af064da..5a4d0d0a 100644 --- a/homeassistant_api/websocket.py +++ b/homeassistant_api/websocket.py @@ -1,38 +1,18 @@ -import contextlib +"""Module containing the primary Client class.""" + import logging import urllib.parse as urlparse -from typing import Dict, Generator, List, Optional, Tuple, Union, cast -from homeassistant_api.models import ( - ConfigEntry, - ConfigEntryEvent, - ConfigSubEntry, - DisableEnableResult, - Domain, - Entity, - FlowResult, - Group, - IntegrationTypes, - State, -) -from homeassistant_api.models.states import Context -from homeassistant_api.models.websocket import ( - EventResponse, - FiredEvent, - FiredTrigger, - ResultResponse, - TemplateEvent, -) -from homeassistant_api.rawwebsocket import RawWebsocketClient -from homeassistant_api.utils import JSONType, prepare_entity_id +from .rawasyncwebsocket import RawAsyncWebsocketClient +from .rawwebsocket import RawWebsocketClient logger = logging.getLogger(__name__) -class WebsocketClient(RawWebsocketClient): +class WebsocketClient(RawWebsocketClient, RawAsyncWebsocketClient): """ - The main class for interactign with the Home Assistant WebSocket API client. + The main class for interacting with the Home Assistant WebSocket API client. Here's a quick example of how to use the :py:class:`WebsocketClient` class: @@ -47,496 +27,19 @@ class WebsocketClient(RawWebsocketClient): light = ws_client.trigger_service('light', 'turn_on', entity_id="light.living_room") """ - def __init__( - self, - api_url: str, - token: str, - ) -> None: + def __init__(self, api_url: str, token: str, use_async: bool = False) -> None: parsed = urlparse.urlparse(api_url) - if parsed.scheme not in {"ws", "wss"}: - raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") - super().__init__(api_url, token) - logger.debug(f"WebSocketClient initialized with api_url: {api_url}") - - def get_rendered_template(self, template: str) -> str: - """ - Renders a Jinja2 template with Home Assistant context data. - See https://www.home-assistant.io/docs/configuration/templating. - - Sends command :code:`{"type": "render_template", ...}`. - """ - id = self.send("render_template", template=template, report_errors=True) - first = self.recv(id) - assert cast(ResultResponse, first).result is None - second = self.recv(id) - self._unsubscribe(id) - return cast(TemplateEvent, cast(EventResponse, second).event).result - - def get_config(self) -> dict[str, JSONType]: - """ - Get the Home Assistant configuration. - - Sends command :code:`{"type": "get_config", ...}`. - """ - return cast( - dict[str, JSONType], - cast( - ResultResponse, - self.recv(self.send("get_config")), - ).result, - ) - - def get_states(self) -> Tuple[State, ...]: - """ - Get a list of states. - - Sends command :code:`{"type": "get_states", ...}`. - """ - return tuple( - State.from_json(state) - for state in cast( - list[dict[str, JSONType]], - cast(ResultResponse, self.recv(self.send("get_states"))).result, - ) - ) - - def get_state( # pylint: disable=duplicate-code - self, - *, - entity_id: Optional[str] = None, - group_id: Optional[str] = None, - slug: Optional[str] = None, - ) -> State: - """ - Just calls the :py:meth:`get_states` method and filters the result. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - entity_id = prepare_entity_id( - group_id=group_id, - slug=slug, - entity_id=entity_id, - ) - - for state in self.get_states(): - if state.entity_id == entity_id: - return state - raise ValueError(f"Entity {entity_id} not found!") - - def get_entities(self) -> Dict[str, Group]: - """ - Fetches all entities from the Websocket API and returns them as a dictionary of :py:class:`Group`'s. - For example :code:`light.living_room` would be in the group :code:`light` (i.e. :code:`get_entities()["light"].living_room`). - """ - entities: Dict[str, Group] = {} - for state in self.get_states(): - group_id, entity_slug = state.entity_id.split(".") - if group_id not in entities: - entities[group_id] = Group( - group_id=group_id, - _client=self, # type: ignore[arg-type] - ) - entities[group_id]._add_entity(entity_slug, state) - return entities - - def get_entity( - self, - group_id: Optional[str] = None, - slug: Optional[str] = None, - entity_id: Optional[str] = None, - ) -> Optional[Entity]: - """ - Returns an :py:class:`Entity` model for an :code:`entity_id`. - - Calls :py:meth:`get_states` under the hood. - - Please tell home-assistant/core to add a :code:`{"type": "get_state", ...}` command to the WS API! - There is a lot of disappointment and frustration in the community because this is not available. - """ - if group_id is not None and slug is not None: - state = self.get_state(group_id=group_id, slug=slug) - elif entity_id is not None: - state = self.get_state(entity_id=entity_id) + if parsed.scheme in {"ws", "wss"}: + if use_async: + RawAsyncWebsocketClient.__init__(self, api_url, token) + client_type = "Async" + else: + RawWebsocketClient.__init__(self, api_url, token) + client_type = "" else: - help_msg = ( - "Use keyword arguments to pass entity_id. " - "Or you can pass the group_id and slug instead" - ) - raise ValueError( - f"Neither group_id and slug or entity_id provided. {help_msg}" - ) - split_group_id, split_slug = state.entity_id.split(".") - group = Group( - group_id=split_group_id, - _client=self, # type: ignore[arg-type] - ) - group._add_entity(split_slug, state) - return group.get_entity(split_slug) - - def get_domains(self) -> dict[str, Domain]: - """ - Get a list of services that Home Assistant offers (organized into a dictionary of service domains). - - For example, the service :code:`light.turn_on` would be in the domain :code:`light`. - - Sends command :code:`{"type": "get_services", ...}`. - """ - resp = self.recv(self.send("get_services")) - domains = map( - lambda item: Domain.from_json_with_client( - {"domain": item[0], "services": item[1]}, - client=self, - ), - cast(dict[str, JSONType], cast(ResultResponse, resp).result).items(), - ) - return {domain.domain_id: domain for domain in domains} - - def get_domain(self, domain: str) -> Domain: - """Get a domain. - - Note: This is not a method in the WS API client... yet. - - Please tell home-assistant/core to add a `get_domain` command to the WS API! - - For now, just call the :py:meth":`get_domains` method and parsing the result. - """ - return self.get_domains()[domain] - - # config_entries.py - - def get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: - """ - Get config entries that are in progress but not initiated by a user. - - Sends command :code:`{"type": "config_entries/flow/progress"}`. - """ - return tuple( - FlowResult.from_json(flow_result) - for flow_result in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, self.recv(self.send("config_entries/flow/progress")) - ).result, - ) - ) - - def disable_config_entry(self, entry_id: str) -> DisableEnableResult: - """ - Disable a config entry. - - Sends command :code:`{"type": "config_entries/disable", disabled_by="user", ...}`. - """ - return DisableEnableResult.from_json( - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/disable", entry_id=entry_id, disabled_by="user" - ) - ), - ).result, - ) - - def enable_config_entry(self, entry_id: str) -> DisableEnableResult: - """Enable a config entry. - - Sends command :code:`{"type": "config_entries/disable", disabled_by=None, ...}`. - - """ - return DisableEnableResult.from_json( - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/disable", entry_id=entry_id, disabled_by=None - ) - ), - ).result, - ) - - def ignore_config_flow(self, flow_id: str, title: str) -> None: - """ - Ignore an active config flow. - - Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. - """ - self.recv(self.send("config_entries/ignore_flow", flow_id=flow_id, title=title)) - - def get_config_entries( - self, type_filter: List[IntegrationTypes] = [], domain: str = "" - ) -> Tuple[ConfigEntry, ...]: - """ - Get filtered config entries. - - Sends command :code:`{"type": "config_entries/get", ...}`. - """ - return tuple( - ConfigEntry.from_json(config_entry) - for config_entry in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, - self.recv( - self.send( - "config_entries/get", type_filter=type_filter, domain=domain - ) - ), - ).result, - ) - ) - - def _subscribe_config_entries(self) -> int: - """ - Subscribe to config entry flows. - - Sends command :code:`{"type": "config_entries/subscribe"}`. - """ - - return self.recv(self.send("config_entries/subscribe")).id - - @contextlib.contextmanager - def listen_config_entries( - self, disconnect_client: bool = True - ) -> Generator[Generator[List[ConfigEntryEvent], None, None], None, None]: - """ - Listen to all config entry flow events. - - For example: - - .. code-block:: python - - with ws_client.listen_config_entries() as flows: - for i, flow in zip(range(2), flows): # to only wait for two flows to be received - print(flow) - """ - subscription = self._subscribe_config_entries() - yield cast( - Generator[List[ConfigEntryEvent], None, None], self._wait_for(subscription) - ) - # There is no "unsubscribe" method available for these events. - # Provide the ability to "unsubscribe" by disconnecting and reconnecting the Websocket client. - if disconnect_client: - logger.info("Reloading Websocket Client. Undefined behavior may occur.") - self.__exit__(None, None, None) - self.__enter__() - - def get_entry_subentries(self, entry_id: str) -> Tuple[ConfigSubEntry, ...]: - """ - Get an entry's sub-entries. - - Sends command :code:`{"type": "config_entries/subentries/list", ...}`. - """ - return tuple( - ConfigSubEntry.from_json(sub_entry) - for sub_entry in cast( - list[dict[str, JSONType]], - cast( - ResultResponse, - self.recv( - self.send("config_entries/subentries/list", entry_id=entry_id) - ), - ).result, - ) - ) - - # UNTESTED - def delete_entry_subentry(self, entry_id: str, subentry_id: str) -> None: - """ - Delete an entry's sub-entry. - - Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. - """ - self.recv( - self.send( - "config_entries/subentries/delete", - entry_id=entry_id, - subentry_id=subentry_id, - ) - ) - - def trigger_service( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> None: - """ - Trigger a service (that doesn't return a response). - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": False, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = self.recv(self.send("call_service", include_id=True, **params)) - - # TODO: handle data["result"]["context"] ? - - assert ( - cast( - dict[str, JSONType], - cast(ResultResponse, data).result, - ).get("response") - is None - ) # should always be None for services without a response - - def trigger_service_with_response( - self, - domain: str, - service: str, - entity_id: Optional[str] = None, - **service_data, - ) -> dict[str, JSONType]: - """ - Trigger a service (that returns a response) and return the response. - - Sends command :code:`{"type": "call_service", ...}`. - """ - params = { - "domain": domain, - "service": service, - "service_data": service_data, - "return_response": True, - } - if entity_id is not None: - params["target"] = {"entity_id": entity_id} - - data = self.recv(self.send("call_service", include_id=True, **params)) - - return cast(dict[str, dict[str, JSONType]], cast(ResultResponse, data).result)[ - "response" - ] - - @contextlib.contextmanager - def listen_events( - self, - event_type: Optional[str] = None, - ) -> Generator[Generator[FiredEvent, None, None], None, None]: - """ - Listen for all events of a certain type. - - For example, to listen for all events of type `test_event`: - - .. code-block:: python - - with ws_client.listen_events("test_event") as events: - for i, event in zip(range(2), events): # to only wait for two events to be received - print(event) - """ - subscription = self._subscribe_events(event_type) - yield cast(Generator[FiredEvent, None, None], self._wait_for(subscription)) - self._unsubscribe(subscription) - - def _subscribe_events(self, event_type: Optional[str]) -> int: - """ - Subscribe to all events of a certain type. - - - Sends command :code:`{"type": "subscribe_events", ...}`. - """ - params = {"event_type": event_type} if event_type else {} - return self.recv(self.send("subscribe_events", include_id=True, **params)).id - - @contextlib.contextmanager - def listen_trigger( - self, trigger: str, **trigger_fields - ) -> Generator[Generator[dict[str, JSONType], None, None], None, None]: - """ - Listen to a Home Assistant trigger. - Allows additional trigger keyword parameters with :code:`**kwargs` (i.e. passing :code:`tag_id=...` for NFC tag triggers). - - For example, in Home Assistant Automations we can subscribe to a state trigger for a light entity with YAML: - - .. code-block:: yaml - - triggers: - # ... - - trigger: state - entity_id: light.kitchen - - To subscribe to that same state trigger with :py:class:`WebsocketClient` instead - - .. code-block:: python - - with ws_client.listen_trigger("state", entity_id="light.kitchen") as trigger: - for event in trigger: # will iterate until we manually break out of the loop - print(event) - if : - break - # exiting the context manager unsubscribes from the trigger - - Woohoo! We can now listen to triggers in Python code! - """ - subscription = self._subscribe_trigger(trigger, **trigger_fields) - yield ( - fired_trigger.variables - for fired_trigger in cast( - Generator[FiredTrigger, None, None], - self._wait_for(subscription), - ) - ) - self._unsubscribe(subscription) - - def _subscribe_trigger(self, trigger: str, **trigger_fields) -> int: - """ - Return the subscription id of the trigger we subscribe to. - - Sends command :code:`{"type": "subscribe_trigger", ...}`. - """ - return self.recv( - self.send( - "subscribe_trigger", trigger={"platform": trigger, **trigger_fields} - ) - ).id - - def _wait_for( - self, subscription_id: int - ) -> Generator[Union[FiredEvent, FiredTrigger, List[ConfigEntryEvent]], None, None]: - """ - An iterator that waits for events of a certain type. - """ - while True: - yield cast( - Union[ - FiredEvent, FiredTrigger, List[ConfigEntryEvent] - ], # we can cast this because TemplateEvent is only used for rendering templates - cast(EventResponse, self.recv(subscription_id)).event, - ) - - def _unsubscribe(self, subcription_id: int) -> None: - """ - Unsubscribe from all events of a certain type. - - Sends command :code:`{"type": "unsubscribe_events", ...}`. - """ - resp = self.recv(self.send("unsubscribe_events", subscription=subcription_id)) - assert cast(ResultResponse, resp).result is None - self._event_responses.pop(subcription_id) - - def fire_event(self, event_type: str, **event_data) -> Context: - """ - Fire an event. + raise ValueError(f"Unknown scheme {parsed.scheme} in {api_url}") - Sends command :code:`{"type": "fire_event", ...}`. - """ - params: dict[str, JSONType] = {"event_type": event_type} - if event_data: - params["event_data"] = event_data - return Context.from_json( - cast( - dict[str, dict[str, JSONType]], - cast( - ResultResponse, - self.recv(self.send("fire_event", include_id=True, **params)), - ).result, - )["context"] + logger.debug( + f"{client_type}WebSocketClient initialized with api_url: {api_url}" ) From 17c8845b5978d5f71bd35339332ee085814b871b Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Mon, 22 Sep 2025 11:55:35 +0200 Subject: [PATCH 4/9] Fix typing --- homeassistant_api/rawasyncwebsocket.py | 2 ++ homeassistant_api/rawwebsocket.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py index 352d18f3..19cadb76 100644 --- a/homeassistant_api/rawasyncwebsocket.py +++ b/homeassistant_api/rawasyncwebsocket.py @@ -39,6 +39,8 @@ if TYPE_CHECKING: from homeassistant_api import WebsocketClient +else: + WebsocketClient = None # pylint: disable=invalid-name logger = logging.getLogger(__name__) diff --git a/homeassistant_api/rawwebsocket.py b/homeassistant_api/rawwebsocket.py index ef00fdb5..8bb403a5 100644 --- a/homeassistant_api/rawwebsocket.py +++ b/homeassistant_api/rawwebsocket.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: from homeassistant_api import WebsocketClient +else: + WebsocketClient = None # pylint: disable=invalid-name logger = logging.getLogger(__name__) @@ -293,6 +295,7 @@ def get_domains(self) -> dict[str, Domain]: Sends command :code:`{"type": "get_services", ...}`. """ resp = self.recv(self.send("get_services")) + print(resp) domains = map( lambda item: Domain.from_json( {"domain": item[0], "services": item[1]}, From 71e40fb4b6b2d588374bf609bf4d634b15efb856 Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Wed, 24 Sep 2025 11:09:38 +0200 Subject: [PATCH 5/9] Add testing --- tests/conftest.py | 13 ++++++++ tests/test_client.py | 9 +++++ tests/test_endpoints.py | 74 +++++++++++++++++++++++++++++++++++++++++ tests/test_errors.py | 10 ++++++ tests/test_events.py | 27 +++++++++++++++ 5 files changed, 133 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 16c26a75..aeeda93f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,3 +57,16 @@ def setup_websocket_client( os.environ["HOMEASSISTANTAPI_TOKEN"], ) as client: yield client + + +@pytest.fixture(name="async_websocket_client", scope="session") +async def setup_async_websocket_client( + wait_for_server: Literal[None], +) -> AsyncGenerator[Client, None]: + """Initializes the Client and enters an async WebSocket session.""" + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + os.environ["HOMEASSISTANTAPI_TOKEN"], + use_async=True + ) as client: + yield client diff --git a/tests/test_client.py b/tests/test_client.py index 22c3a6d2..2a5a2447 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -55,3 +55,12 @@ def test_websocket_client_ping() -> None: os.environ["HOMEASSISTANTAPI_TOKEN"], ) as client: assert client.ping_latency() > 0 + + +async def test_async_websocket_client_ping() -> None: + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + os.environ["HOMEASSISTANTAPI_TOKEN"], + use_async=True + ) as client: + assert (await client.async_ping_latency()) > 0 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index f712ac9c..a3e89c0e 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -129,6 +129,19 @@ def test_websocket_get_rendered_template(websocket_client: WebsocketClient) -> N } +async def test_async_websocket_get_rendered_template( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "render_template"` websocket command.""" + rendered_template = await async_websocket_client.async_get_rendered_template( + 'The sun is {{ states("sun.sun").replace("_", " the ") }}.' + ) + assert rendered_template in { + "The sun is above the horizon.", + "The sun is below the horizon.", + } + + def test_check_api_config(cached_client: Client) -> None: """Tests the `POST /api/config/core/check_config` endpoint.""" assert cached_client.check_api_config() @@ -199,6 +212,14 @@ def test_websocket_get_entities(websocket_client: WebsocketClient) -> None: assert "sun" in entities +async def test_async_websocket_get_entities( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "get_entities"` websocket command.""" + entities = await async_websocket_client.async_get_entities() + assert "sun" in entities + + def test_get_domains(cached_client: Client) -> None: """Tests the `GET /api/services` endpoint.""" domains = cached_client.get_domains() @@ -217,6 +238,14 @@ def test_websocket_get_domains(websocket_client: WebsocketClient) -> None: assert "homeassistant" in domains +async def test_async_websocket_get_domains( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "get_domains"` websocket command.""" + domains = await async_websocket_client.async_get_domains() + assert "homeassistant" in domains + + def test_get_domain(cached_client: Client) -> None: """Tests the `GET /api/services` endpoint.""" domain = cached_client.get_domain("homeassistant") @@ -238,6 +267,15 @@ def test_websocket_get_domain(websocket_client: WebsocketClient) -> None: assert domain.services +async def test_async_websocket_get_domain( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "get_domain"` websocket command.""" + domain = await async_websocket_client.async_get_domain("homeassistant") + assert domain is not None + assert domain.services + + def test_get_nonuser_flows_in_progress(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/flow/progress"` websocket command.""" # No flows in progress @@ -364,6 +402,19 @@ def test_websocket_trigger_service(websocket_client: WebsocketClient) -> None: assert resp is None +async def test_async_websocket_trigger_service( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "trigger_service"` websocket command.""" + notify = await async_websocket_client.async_get_domain("notify") + assert notify is not None + resp = await notify.persistent_notification( + message="Your API Test Suite just said hello!", title="Test Suite Notifcation" + ) + # Websocket API doesnt return changed states so we check for None + assert resp is None + + def test_websocket_trigger_service_with_entity_id( websocket_client: WebsocketClient, ) -> None: @@ -416,6 +467,20 @@ def test_websocket_trigger_service_with_response( assert data is not None +async def test_async_websocket_trigger_service_with_response( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "trigger_service_with_response"` websocket command.""" + weather = await async_websocket_client.async_get_domain("weather") + assert weather is not None + data = weather.get_forecasts( + entity_id="weather.forecast_home", + type="hourly", + ) + # Websocket API doesnt return changed states so we check data is not None because we expect a response + assert data is not None + + def test_get_states(cached_client: Client) -> None: """Tests the `GET /api/states` endpoint.""" states = cached_client.get_states() @@ -437,6 +502,15 @@ def test_websocket_get_states(websocket_client: WebsocketClient) -> None: assert isinstance(state, State) +async def test_async_websocket_get_states( + async_websocket_client: WebsocketClient +) -> None: + """Tests the `"type": "get_states"` websocket command.""" + states = await async_websocket_client.async_get_states() + for state in states: + assert isinstance(state, State) + + def test_get_state(cached_client: Client) -> None: """Tests the `GET /api/states/` endpoint.""" state = cached_client.get_state(entity_id="sun.sun") diff --git a/tests/test_errors.py b/tests/test_errors.py index f892a015..471716d4 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -45,6 +45,16 @@ def test_websocket_unauthorized() -> None: pass +async def test_async_websocket_unauthorized() -> None: + with pytest.raises(UnauthorizedError): + async with WebsocketClient( + os.environ["HOMEASSISTANTAPI_WS_URL"], + "lolthisisawrongtokenforsure", + use_async=True, + ): + pass + + async def test_async_unauthorized() -> None: with pytest.raises(UnauthorizedError): async with Client( diff --git a/tests/test_events.py b/tests/test_events.py index 458e3342..f8cc0cdc 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -21,6 +21,17 @@ def test_listen_events(websocket_client: WebsocketClient) -> None: assert event.data["message"] == "Triggered by websocket client" +async def test_async_listen_events(async_websocket_client: WebsocketClient) -> None: + async with async_websocket_client.async_listen_events("async_test_event") as events: + await async_websocket_client.async_fire_event( + "async_test_event", message="Triggered by async websocket client" + ) + async for _, event in zip(range(1), events): + assert event.origin == "LOCAL" + assert event.event_type == "async_test_event" + assert event.data["message"] == "Triggered by async websocket client" + + def test_listen_trigger(websocket_client: WebsocketClient) -> None: future = datetime.fromisoformat( websocket_client.get_rendered_template("{{ (now() + timedelta(seconds=1)) }}") @@ -70,3 +81,19 @@ def test_listen_config_entries(websocket_client: WebsocketClient) -> None: assert flow[0].type == ConfigEntryChange.UPDATED assert flow[0].entry.disabled_by is None assert flow[0].entry.state == ConfigEntryState.LOADED + + +async def test_async_listen_trigger(async_websocket_client: WebsocketClient) -> None: + future = datetime.fromisoformat( + await async_websocket_client.async_get_rendered_template( + "{{ (now() + timedelta(seconds=1)) }}" + ) + ) + async with async_websocket_client.async_listen_trigger( + "time", at=future.strftime("%H:%M:%S") + ) as triggers: + async for _, trigger in zip(range(1), triggers): + assert trigger["trigger"]["platform"] == "time" + assert datetime.fromisoformat( + trigger["trigger"]["now"] + ).timestamp() == pytest.approx(future.timestamp(), abs=1) From a7311857a39844d02b0a6f7744950175b9aa90fc Mon Sep 17 00:00:00 2001 From: Markus Jacobsen Date: Thu, 25 Sep 2025 14:27:18 +0200 Subject: [PATCH 6/9] Use breaks instead of zip(range(1), ...) in async event tests --- tests/test_events.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index f8cc0cdc..ed5fcd5b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -26,10 +26,12 @@ async def test_async_listen_events(async_websocket_client: WebsocketClient) -> N await async_websocket_client.async_fire_event( "async_test_event", message="Triggered by async websocket client" ) - async for _, event in zip(range(1), events): + # Typing breaks when using zip in an async context, so break instead + async for event in events: assert event.origin == "LOCAL" assert event.event_type == "async_test_event" assert event.data["message"] == "Triggered by async websocket client" + break def test_listen_trigger(websocket_client: WebsocketClient) -> None: @@ -92,8 +94,10 @@ async def test_async_listen_trigger(async_websocket_client: WebsocketClient) -> async with async_websocket_client.async_listen_trigger( "time", at=future.strftime("%H:%M:%S") ) as triggers: - async for _, trigger in zip(range(1), triggers): + # Typing breaks when using zip in an async context, so break instead + async for trigger in triggers: assert trigger["trigger"]["platform"] == "time" assert datetime.fromisoformat( trigger["trigger"]["now"] ).timestamp() == pytest.approx(future.timestamp(), abs=1) + break From b4de87ac39437c3f5f02355c357d8e8bf7156a76 Mon Sep 17 00:00:00 2001 From: Adam Logan Date: Sun, 22 Mar 2026 09:38:20 -0700 Subject: [PATCH 7/9] fixes from rebase --- homeassistant_api/models/domains.py | 10 +++------- tests/conftest.py | 2 +- tests/test_client.py | 2 +- tests/test_endpoints.py | 14 +++++++------- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/homeassistant_api/models/domains.py b/homeassistant_api/models/domains.py index 0499fc44..467defe0 100644 --- a/homeassistant_api/models/domains.py +++ b/homeassistant_api/models/domains.py @@ -65,10 +65,6 @@ def from_json(cls, json: Union[dict[str, JSONType], Any, None], **kwargs) -> Sel @classmethod def from_json_with_client( cls, json: Dict[str, JSONType], client: Union["Client", "WebsocketClient"] - def from_json( - cls, - json: Dict[str, JSONType], - client: Union["Client", "WebsocketClient"], ) -> "Domain": """Constructs Domain and Service models from json data.""" if "domain" not in json or "services" not in json: @@ -387,9 +383,9 @@ class ServiceFieldSelectorObject(BaseModel): class ServiceFieldSelectorQRCode(BaseModel): data: str scale: Optional[Union[int, float]] = None - error_correction_level: Optional[ - ServiceFieldSelectorQRCodeErrorCorrectionLevel - ] = None + error_correction_level: Optional[ServiceFieldSelectorQRCodeErrorCorrectionLevel] = ( + None + ) center_image: Optional[str] = None diff --git a/tests/conftest.py b/tests/conftest.py index aeeda93f..1d2ffa94 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,6 +67,6 @@ async def setup_async_websocket_client( async with WebsocketClient( os.environ["HOMEASSISTANTAPI_WS_URL"], os.environ["HOMEASSISTANTAPI_TOKEN"], - use_async=True + use_async=True, ) as client: yield client diff --git a/tests/test_client.py b/tests/test_client.py index 2a5a2447..6a11939b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -61,6 +61,6 @@ async def test_async_websocket_client_ping() -> None: async with WebsocketClient( os.environ["HOMEASSISTANTAPI_WS_URL"], os.environ["HOMEASSISTANTAPI_TOKEN"], - use_async=True + use_async=True, ) as client: assert (await client.async_ping_latency()) > 0 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index a3e89c0e..4045aa5e 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -130,7 +130,7 @@ def test_websocket_get_rendered_template(websocket_client: WebsocketClient) -> N async def test_async_websocket_get_rendered_template( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "render_template"` websocket command.""" rendered_template = await async_websocket_client.async_get_rendered_template( @@ -213,7 +213,7 @@ def test_websocket_get_entities(websocket_client: WebsocketClient) -> None: async def test_async_websocket_get_entities( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "get_entities"` websocket command.""" entities = await async_websocket_client.async_get_entities() @@ -239,7 +239,7 @@ def test_websocket_get_domains(websocket_client: WebsocketClient) -> None: async def test_async_websocket_get_domains( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "get_domains"` websocket command.""" domains = await async_websocket_client.async_get_domains() @@ -268,7 +268,7 @@ def test_websocket_get_domain(websocket_client: WebsocketClient) -> None: async def test_async_websocket_get_domain( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "get_domain"` websocket command.""" domain = await async_websocket_client.async_get_domain("homeassistant") @@ -403,7 +403,7 @@ def test_websocket_trigger_service(websocket_client: WebsocketClient) -> None: async def test_async_websocket_trigger_service( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "trigger_service"` websocket command.""" notify = await async_websocket_client.async_get_domain("notify") @@ -468,7 +468,7 @@ def test_websocket_trigger_service_with_response( async def test_async_websocket_trigger_service_with_response( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "trigger_service_with_response"` websocket command.""" weather = await async_websocket_client.async_get_domain("weather") @@ -503,7 +503,7 @@ def test_websocket_get_states(websocket_client: WebsocketClient) -> None: async def test_async_websocket_get_states( - async_websocket_client: WebsocketClient + async_websocket_client: WebsocketClient, ) -> None: """Tests the `"type": "get_states"` websocket command.""" states = await async_websocket_client.async_get_states() From 01ba5b38edd17139927f827f52e6731a625c4cc1 Mon Sep 17 00:00:00 2001 From: Adam Logan Date: Sun, 22 Mar 2026 10:00:01 -0700 Subject: [PATCH 8/9] Add config entry websocket methods and fix Domain.from_json usage Use Domain.from_json_with_client() in websocket clients since Domain requires a client reference. Add config entry methods (get, disable, enable, ignore flow, subentries, subscribe) to RawWebsocketClient. --- homeassistant_api/rawasyncwebsocket.py | 2 +- homeassistant_api/rawwebsocket.py | 143 ++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 4 deletions(-) diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py index 19cadb76..15ca2feb 100644 --- a/homeassistant_api/rawasyncwebsocket.py +++ b/homeassistant_api/rawasyncwebsocket.py @@ -300,7 +300,7 @@ async def async_get_domains(self) -> dict[str, Domain]: """ resp = await self.async_recv(await self.async_send("get_services")) domains = map( - lambda item: Domain.from_json( + lambda item: Domain.from_json_with_client( {"domain": item[0], "services": item[1]}, client=cast(WebsocketClient, self), ), diff --git a/homeassistant_api/rawwebsocket.py b/homeassistant_api/rawwebsocket.py index 8bb403a5..3e71d85a 100644 --- a/homeassistant_api/rawwebsocket.py +++ b/homeassistant_api/rawwebsocket.py @@ -12,7 +12,16 @@ ResponseError, UnauthorizedError, ) -from homeassistant_api.models import Domain, Entity, Group, State +from homeassistant_api.models import ( + ConfigEntry, + ConfigEntryEvent, + ConfigSubEntry, + Domain, + Entity, + Group, + State, +) +from homeassistant_api.models.config_entries import DisableEnableResult, FlowResult from homeassistant_api.models.states import Context from homeassistant_api.models.websocket import ( AuthInvalid, @@ -295,9 +304,8 @@ def get_domains(self) -> dict[str, Domain]: Sends command :code:`{"type": "get_services", ...}`. """ resp = self.recv(self.send("get_services")) - print(resp) domains = map( - lambda item: Domain.from_json( + lambda item: Domain.from_json_with_client( {"domain": item[0], "services": item[1]}, client=cast(WebsocketClient, self), ), @@ -482,6 +490,135 @@ def _unsubscribe(self, subcription_id: int) -> None: assert cast(ResultResponse, resp).result is None self._event_responses.pop(subcription_id) + def get_config_entries(self) -> Tuple[ConfigEntry, ...]: + """ + Get all config entries. + + Sends command :code:`{"type": "config_entries/get", ...}`. + """ + resp = self.recv(self.send("config_entries/get")) + return tuple( + ConfigEntry.from_json(entry) + for entry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def disable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Disable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = self.recv( + self.send( + "config_entries/disable", + entry_id=entry_id, + disabled_by="user", + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + def enable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Enable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = self.recv( + self.send( + "config_entries/disable", + entry_id=entry_id, + disabled_by=None, + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + def ignore_config_flow(self, flow_id: str, title: str) -> None: + """ + Ignore a config flow. + + Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. + """ + self.recv( + self.send( + "config_entries/ignore_flow", + flow_id=flow_id, + title=title, + ) + ) + + def get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: + """ + Get non-user config flows in progress. + + Sends command :code:`{"type": "config_entries/flow/progress", ...}`. + """ + resp = self.recv(self.send("config_entries/flow/progress")) + return tuple( + FlowResult.from_json(flow) + for flow in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def get_entry_subentries(self, entry_id: str) -> Tuple[ConfigSubEntry, ...]: + """ + Get subentries for a config entry. + + Sends command :code:`{"type": "config_entries/subentries/list", ...}`. + """ + resp = self.recv(self.send("config_entries/subentries/list", entry_id=entry_id)) + return tuple( + ConfigSubEntry.from_json(subentry) + for subentry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + def delete_entry_subentry(self, entry_id: str, subentry_id: str) -> None: + """ + Delete a subentry from a config entry. + + Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. + """ + self.recv( + self.send( + "config_entries/subentries/delete", + entry_id=entry_id, + subentry_id=subentry_id, + ) + ) + + @contextlib.contextmanager + def listen_config_entries( + self, + ) -> Generator[Generator[list[ConfigEntryEvent], None, None], None, None]: + """ + Listen for config entry changes. + + Sends command :code:`{"type": "config_entries/subscribe", ...}`. + """ + subscription = self.recv(self.send("config_entries/subscribe")).id + yield self._wait_for_config_entries(subscription) + self._unsubscribe(subscription) + + def _wait_for_config_entries( + self, subscription_id: int + ) -> Generator[list[ConfigEntryEvent], None, None]: + """An iterator that waits for config entry events.""" + while True: + event_resp = cast(EventResponse, self.recv(subscription_id)) + entries = cast(list[dict[str, JSONType]], event_resp.event) + yield [ConfigEntryEvent.from_json(entry) for entry in entries] + def fire_event(self, event_type: str, **event_data) -> Context: """ Fire an event. From 4e70cdf7d40642a8c1941d57cc060e4f0ee1399a Mon Sep 17 00:00:00 2001 From: Adam Logan Date: Sun, 22 Mar 2026 13:33:32 -0700 Subject: [PATCH 9/9] Add async config entry methods and improve test coverage to 99% Add async counterparts for all config entry websocket methods in RawAsyncWebsocketClient. Add async tests for websocket state, entity, config entry, and error path coverage. Fix conftest return type annotation, replace zip(range) with break, and use pytest.raises instead of try/except in error tests. --- homeassistant_api/rawasyncwebsocket.py | 150 ++++++++++++++++++++++++- tests/conftest.py | 2 +- tests/test_endpoints.py | 139 ++++++++++++++++++++--- tests/test_events.py | 52 ++++++++- tests/test_websocket.py | 75 ++++++++++++- 5 files changed, 399 insertions(+), 19 deletions(-) diff --git a/homeassistant_api/rawasyncwebsocket.py b/homeassistant_api/rawasyncwebsocket.py index 15ca2feb..6c99ff6d 100644 --- a/homeassistant_api/rawasyncwebsocket.py +++ b/homeassistant_api/rawasyncwebsocket.py @@ -21,7 +21,16 @@ ResponseError, UnauthorizedError, ) -from homeassistant_api.models import Domain, Entity, Group, State +from homeassistant_api.models import ( + ConfigEntry, + ConfigEntryEvent, + ConfigSubEntry, + Domain, + Entity, + Group, + State, +) +from homeassistant_api.models.config_entries import DisableEnableResult, FlowResult from homeassistant_api.models.states import Context from homeassistant_api.models.websocket import ( AuthInvalid, @@ -497,6 +506,145 @@ async def _async_unsubscribe(self, subcription_id: int) -> None: assert cast(ResultResponse, resp).result is None self._event_responses.pop(subcription_id) + async def async_get_config_entries(self) -> Tuple[ConfigEntry, ...]: + """ + Get all config entries. + + Sends command :code:`{"type": "config_entries/get", ...}`. + """ + resp = await self.async_recv(await self.async_send("config_entries/get")) + return tuple( + ConfigEntry.from_json(entry) + for entry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_disable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Disable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = await self.async_recv( + await self.async_send( + "config_entries/disable", + entry_id=entry_id, + disabled_by="user", + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + async def async_enable_config_entry(self, entry_id: str) -> DisableEnableResult: + """ + Enable a config entry. + + Sends command :code:`{"type": "config_entries/disable", ...}`. + """ + resp = await self.async_recv( + await self.async_send( + "config_entries/disable", + entry_id=entry_id, + disabled_by=None, + ) + ) + return DisableEnableResult.from_json( + cast(dict[str, JSONType], cast(ResultResponse, resp).result) + ) + + async def async_ignore_config_flow(self, flow_id: str, title: str) -> None: + """ + Ignore a config flow. + + Sends command :code:`{"type": "config_entries/ignore_flow", ...}`. + """ + await self.async_recv( + await self.async_send( + "config_entries/ignore_flow", + flow_id=flow_id, + title=title, + ) + ) + + async def async_get_nonuser_flows_in_progress(self) -> Tuple[FlowResult, ...]: + """ + Get non-user config flows in progress. + + Sends command :code:`{"type": "config_entries/flow/progress", ...}`. + """ + resp = await self.async_recv( + await self.async_send("config_entries/flow/progress") + ) + return tuple( + FlowResult.from_json(flow) + for flow in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_get_entry_subentries( + self, entry_id: str + ) -> Tuple[ConfigSubEntry, ...]: + """ + Get subentries for a config entry. + + Sends command :code:`{"type": "config_entries/subentries/list", ...}`. + """ + resp = await self.async_recv( + await self.async_send("config_entries/subentries/list", entry_id=entry_id) + ) + return tuple( + ConfigSubEntry.from_json(subentry) + for subentry in cast( + list[dict[str, JSONType]], + cast(ResultResponse, resp).result, + ) + ) + + async def async_delete_entry_subentry( + self, entry_id: str, subentry_id: str + ) -> None: + """ + Delete a subentry from a config entry. + + Sends command :code:`{"type": "config_entries/subentries/delete", ...}`. + """ + await self.async_recv( + await self.async_send( + "config_entries/subentries/delete", + entry_id=entry_id, + subentry_id=subentry_id, + ) + ) + + @contextlib.asynccontextmanager + async def async_listen_config_entries( + self, + ) -> AsyncGenerator[AsyncGenerator[list[ConfigEntryEvent], None], None]: + """ + Listen for config entry changes. + + Sends command :code:`{"type": "config_entries/subscribe", ...}`. + """ + subscription = ( + await self.async_recv(await self.async_send("config_entries/subscribe")) + ).id + yield self._async_wait_for_config_entries(subscription) + await self._async_unsubscribe(subscription) + + async def _async_wait_for_config_entries( + self, subscription_id: int + ) -> AsyncGenerator[list[ConfigEntryEvent], None]: + """An async iterator that waits for config entry events.""" + while True: + event_resp = cast(EventResponse, await self.async_recv(subscription_id)) + entries = cast(list[dict[str, JSONType]], event_resp.event) + yield [ConfigEntryEvent.from_json(entry) for entry in entries] + async def async_fire_event(self, event_type: str, **event_data) -> Context: """ Fire an event. diff --git a/tests/conftest.py b/tests/conftest.py index 1d2ffa94..78ef2373 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,7 +62,7 @@ def setup_websocket_client( @pytest.fixture(name="async_websocket_client", scope="session") async def setup_async_websocket_client( wait_for_server: Literal[None], -) -> AsyncGenerator[Client, None]: +) -> AsyncGenerator[WebsocketClient, None]: """Initializes the Client and enters an async WebSocket session.""" async with WebsocketClient( os.environ["HOMEASSISTANTAPI_WS_URL"], diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 4045aa5e..0745600d 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -171,6 +171,15 @@ def test_websocket_get_config(websocket_client: WebsocketClient) -> None: assert config.get("state") in {"RUNNING", "NOT_RUNNING"} +async def test_async_websocket_get_config( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "get_config"` websocket command.""" + config = await async_websocket_client.async_get_config() + assert isinstance(config, dict) + assert config.get("state") in {"RUNNING", "NOT_RUNNING"} + + def test_websocket_get_state(websocket_client: WebsocketClient) -> None: """Tests WebsocketClient.get_state with entity_id.""" state = websocket_client.get_state(entity_id="sun.sun") @@ -200,6 +209,53 @@ def test_websocket_get_entity_no_args(websocket_client: WebsocketClient) -> None websocket_client.get_entity() +async def test_async_websocket_get_state( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_state with entity_id.""" + state = await async_websocket_client.async_get_state(entity_id="sun.sun") + assert state.entity_id == "sun.sun" + assert state.state in {"above_horizon", "below_horizon"} + + +async def test_async_websocket_get_entity_by_group_slug( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity with group_id and slug.""" + entity = await async_websocket_client.async_get_entity(group_id="sun", slug="sun") + assert entity is not None + assert entity.entity_id == "sun.sun" + + +async def test_async_websocket_get_entity_by_entity_id( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity with entity_id.""" + entity = await async_websocket_client.async_get_entity(entity_id="sun.sun") + assert entity is not None + assert entity.entity_id == "sun.sun" + + +async def test_async_websocket_get_entity_no_args( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_entity raises ValueError with no arguments.""" + with pytest.raises( + ValueError, match="Neither group_id and slug or entity_id provided" + ): + await async_websocket_client.async_get_entity() + + +async def test_async_websocket_get_state_not_found( + async_websocket_client: WebsocketClient, +) -> None: + """Tests async WebsocketClient.async_get_state raises ValueError for nonexistent entity.""" + with pytest.raises(ValueError, match="not found"): + await async_websocket_client.async_get_state( + entity_id="fake.nonexistent_entity_12345" + ) + + def test_websocket_get_state_not_found(websocket_client: WebsocketClient) -> None: """Tests WebsocketClient.get_state raises ValueError for nonexistent entity.""" with pytest.raises(ValueError, match="not found"): @@ -283,6 +339,14 @@ def test_get_nonuser_flows_in_progress(websocket_client: WebsocketClient) -> Non assert not flows +async def test_async_get_nonuser_flows_in_progress( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/flow/progress"` websocket command.""" + flows = await async_websocket_client.async_get_nonuser_flows_in_progress() + assert not flows + + def test_disable_enable_config_entry(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/disable"` websocket command.""" # Get sun entry @@ -299,21 +363,42 @@ def test_disable_enable_config_entry(websocket_client: WebsocketClient) -> None: # Re-enable websocket_client.enable_config_entry(entry.entry_id) - # Check that it was enable + # Check that it was enabled enabled_entry = websocket_client.get_config_entries()[0] assert enabled_entry.disabled_by is None +async def test_async_disable_enable_config_entry( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/disable"` websocket command.""" + entry = (await async_websocket_client.async_get_config_entries())[0] + assert entry.disabled_by is None + + await async_websocket_client.async_disable_config_entry(entry.entry_id) + + disabled_entry = (await async_websocket_client.async_get_config_entries())[0] + assert disabled_entry.disabled_by is ConfigEntryDisabler.USER + + await async_websocket_client.async_enable_config_entry(entry.entry_id) + + enabled_entry = (await async_websocket_client.async_get_config_entries())[0] + assert enabled_entry.disabled_by is None + + def test_ignore_config_flow(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/ignore_flow"` websocket command.""" # Currently not able to test as no flows are in progress. Send invalid parameters and handle that error - try: + with pytest.raises(RequestError, match="Config entry not found"): websocket_client.ignore_config_flow("", "") - except RequestError as error: - assert ( - error.__str__() - == "An error occurred while making the request to 'Config entry not found' with data: 'not_found'" - ) + + +async def test_async_ignore_config_flow( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/ignore_flow"` websocket command.""" + with pytest.raises(RequestError, match="Config entry not found"): + await async_websocket_client.async_ignore_config_flow("", "") def test_get_config_entries(websocket_client: WebsocketClient) -> None: @@ -345,6 +430,20 @@ def test_get_config_entries(websocket_client: WebsocketClient) -> None: assert sun.num_subentries == 0 +async def test_async_get_config_entries( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/get"` websocket command.""" + entries = await async_websocket_client.async_get_config_entries() + assert len(entries) == 4 + + sun = entries[0] + assert sun.entry_id == "5f8426fa502435857743f302651753c9" + assert sun.domain == "sun" + assert sun.title == "Sun" + assert sun.disabled_by is None + + def test_get_entry_subentries(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/subentries/list"` websocket command.""" # Currently not able to test as no entries with subentries available @@ -356,16 +455,28 @@ def test_get_entry_subentries(websocket_client: WebsocketClient) -> None: assert not websocket_client.get_entry_subentries(sun.entry_id) +async def test_async_get_entry_subentries( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/subentries/list"` websocket command.""" + sun = (await async_websocket_client.async_get_config_entries())[0] + assert sun + assert not await async_websocket_client.async_get_entry_subentries(sun.entry_id) + + def test_delete_entry_subentry(websocket_client: WebsocketClient) -> None: """Tests the `"type": "config_entries/subentries/delete"` websocket command.""" # Currently not able to test as no entries with subentries available. Send invalid parameters and handle that error - try: + with pytest.raises(RequestError, match="Config entry not found"): websocket_client.delete_entry_subentry("", "") - except RequestError as error: - assert ( - error.__str__() - == "An error occurred while making the request to 'Config entry not found' with data: 'not_found'" - ) + + +async def test_async_delete_entry_subentry( + async_websocket_client: WebsocketClient, +) -> None: + """Tests the `"type": "config_entries/subentries/delete"` websocket command.""" + with pytest.raises(RequestError, match="Config entry not found"): + await async_websocket_client.async_delete_entry_subentry("", "") def test_trigger_service(cached_client: Client) -> None: @@ -473,7 +584,7 @@ async def test_async_websocket_trigger_service_with_response( """Tests the `"type": "trigger_service_with_response"` websocket command.""" weather = await async_websocket_client.async_get_domain("weather") assert weather is not None - data = weather.get_forecasts( + data = await weather.get_forecasts( entity_id="weather.forecast_home", type="hourly", ) diff --git a/tests/test_events.py b/tests/test_events.py index ed5fcd5b..a8bf6aa4 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -15,10 +15,11 @@ def test_listen_events(websocket_client: WebsocketClient) -> None: websocket_client.fire_event( "test_event", message="Triggered by websocket client" ) - for _, event in zip(range(1), events): + for event in events: assert event.origin == "LOCAL" assert event.event_type == "test_event" assert event.data["message"] == "Triggered by websocket client" + break async def test_async_listen_events(async_websocket_client: WebsocketClient) -> None: @@ -41,11 +42,12 @@ def test_listen_trigger(websocket_client: WebsocketClient) -> None: with websocket_client.listen_trigger( "time", at=future.strftime("%H:%M:%S") ) as triggers: - for _, trigger in zip(range(1), triggers): + for trigger in triggers: assert trigger["trigger"]["platform"] == "time" assert datetime.fromisoformat( trigger["trigger"]["now"] ).timestamp() == pytest.approx(future.timestamp(), abs=1) + break def test_listen_config_entries(websocket_client: WebsocketClient) -> None: @@ -85,6 +87,52 @@ def test_listen_config_entries(websocket_client: WebsocketClient) -> None: assert flow[0].entry.state == ConfigEntryState.LOADED +async def test_async_listen_config_entries( + async_websocket_client: WebsocketClient, +) -> None: + async with async_websocket_client.async_listen_config_entries() as flows: + i = 0 + async for flow in flows: + if i == 0: + # The first "events" are currently available entries + assert flow[0].type is None + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.LOADED + + # Trigger an "updated" event + await async_websocket_client.async_disable_config_entry( + flow[0].entry.entry_id + ) + + if i == 1: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by == ConfigEntryDisabler.USER + assert flow[0].entry.state == ConfigEntryState.UNLOAD_IN_PROGRESS + + if i == 2: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by == ConfigEntryDisabler.USER + assert flow[0].entry.state == ConfigEntryState.NOT_LOADED + + # Restore original state + await async_websocket_client.async_enable_config_entry( + flow[0].entry.entry_id + ) + + if i == 3: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.SETUP_IN_PROGRESS + + if i == 4: + assert flow[0].type == ConfigEntryChange.UPDATED + assert flow[0].entry.disabled_by is None + assert flow[0].entry.state == ConfigEntryState.LOADED + break + + i += 1 + + async def test_async_listen_trigger(async_websocket_client: WebsocketClient) -> None: future = datetime.fromisoformat( await async_websocket_client.async_get_rendered_template( diff --git a/tests/test_websocket.py b/tests/test_websocket.py index ccd0757a..e86f4d09 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -1,8 +1,9 @@ -"""Unit tests for RawWebsocketClient and WebsocketClient error paths.""" +"""Unit tests for RawWebsocketClient, RawAsyncWebsocketClient, and WebsocketClient error paths.""" import pytest from homeassistant_api.errors import ReceivingError, RequestError, ResponseError +from homeassistant_api.rawasyncwebsocket import RawAsyncWebsocketClient from homeassistant_api.rawwebsocket import RawWebsocketClient from homeassistant_api.models import websocket as ws_models @@ -12,6 +13,11 @@ def make_raw_client() -> RawWebsocketClient: return RawWebsocketClient("ws://localhost:8123/api/websocket", "fake_token") +def make_raw_async_client() -> RawAsyncWebsocketClient: + """Create a RawAsyncWebsocketClient without connecting.""" + return RawAsyncWebsocketClient("ws://localhost:8123/api/websocket", "fake_token") + + def test_exit_without_connection() -> None: """Tests __exit__ raises ReceivingError when connection is not open.""" client = make_raw_client() @@ -98,3 +104,70 @@ def raise_runtime_error(*args, **kwargs): ResponseError, match="Unexpected response during authentication" ): client.authentication_phase() + + +async def test_async_aexit_without_connection() -> None: + """Tests __aexit__ raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client.__aexit__(None, None, None) + + +async def test_async_send_without_connection() -> None: + """Tests _async_send raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client._async_send({"type": "test"}) + + +async def test_async_recv_without_connection() -> None: + """Tests _async_recv raises ReceivingError when connection is not open.""" + client = make_raw_async_client() + with pytest.raises(ReceivingError, match="Connection is not open"): + await client._async_recv() + + +async def test_async_authentication_phase_invalid_welcome(monkeypatch) -> None: + """Tests async_authentication_phase raises ResponseError on invalid welcome message.""" + client = make_raw_async_client() + + async def fake_recv(): + return {"type": "not_auth_required"} + + monkeypatch.setattr(client, "_async_recv", fake_recv) + with pytest.raises( + ResponseError, match="Unexpected response during authentication" + ): + await client.async_authentication_phase() + + +async def test_async_authentication_phase_unexpected_auth_response( + monkeypatch, +) -> None: + """Tests async_authentication_phase raises ResponseError when AuthOk.model_validate raises a non-ValidationError.""" + call_count = 0 + + async def fake_recv(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return {"type": "auth_required", "ha_version": "2024.1.0"} + return {"type": "auth_ok", "ha_version": "2024.1.0", "message": "unexpected"} + + client = make_raw_async_client() + monkeypatch.setattr(client, "_async_recv", fake_recv) + + async def fake_send(data): + pass + + monkeypatch.setattr(client, "_async_send", fake_send) + + def raise_runtime_error(*args, **kwargs): + raise RuntimeError("something went wrong") + + monkeypatch.setattr(ws_models.AuthOk, "model_validate", raise_runtime_error) + + with pytest.raises( + ResponseError, match="Unexpected response during authentication" + ): + await client.async_authentication_phase()