From 0328ee391367bb28b900624ff8a58ec0519de96b Mon Sep 17 00:00:00 2001 From: "Mathias L. Baumann" Date: Thu, 2 Apr 2026 16:05:53 +0200 Subject: [PATCH 1/3] Add mock and integration tests for client RPCs Update test helpers and assertions to match the d9a23f3 proto changes: market_area moved to MarketLocationRef, MarketLocationId.value wrapped in MarketLocationIdValue, structured errors replace flat success/error_code fields. Register the integration pytest mark in pyproject.toml. Signed-off-by: Mathias L. Baumann --- pyproject.toml | 3 + tests/test_client.py | 680 ++++++++++++++++++++++++++++++++++++++ tests/test_integration.py | 396 ++++++++++++++++++++++ 3 files changed, 1079 insertions(+) create mode 100644 tests/test_client.py create mode 100644 tests/test_integration.py diff --git a/pyproject.toml b/pyproject.toml index b1fd41b..b66e411 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -183,6 +183,9 @@ testpaths = ["tests", "src"] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" required_plugins = ["pytest-asyncio", "pytest-mock"] +markers = [ + "integration: integration tests requiring a running gRPC server", +] [tool.mypy] explicit_package_bases = true diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..6482abb --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,680 @@ +# License: MIT +# Copyright © 2025 Frequenz Energy-as-a-Service GmbH + +"""Mock tests for the MarketMeteringApiClient gRPC methods.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import AsyncIterator +from unittest.mock import AsyncMock, MagicMock + +from frequenz.api.common.v1alpha8.pagination import ( + pagination_info_pb2 as pagination_info_pb, +) +from frequenz.api.marketmetering.v1alpha1 import marketmetering_pb2 as pb +from google.protobuf.timestamp_pb2 import Timestamp + +from frequenz.client.marketmetering import MarketMeteringApiClient +from frequenz.client.marketmetering.types import ( + ActivationFilter, + DataQuality, + EnergyFlowDirection, + MarketArea, + MarketLocation, + MarketLocationId, + MarketLocationIdType, + MarketLocationRef, + MarketLocationSample, + MarketLocationSeries, + MarketLocationsFilter, + MarketLocationUpdate, + MetricType, + MetricUnit, + PaginationParams, + ResamplingMethod, + TimeResolution, +) + + +def _make_client() -> MarketMeteringApiClient: + """Create a client with a mocked channel.""" + client = MarketMeteringApiClient( + server_url="grpc://localhost:50051?ssl=false", + auth_key="test-key", + connect=False, + ) + # Inject a mock stub so we don't need a real connection. + # pylint: disable=protected-access + client._stub = MagicMock() # noqa: SLF001 + client._channel = MagicMock() # noqa: SLF001 + # pylint: enable=protected-access + return client + + +def _make_ref( + enterprise_id: int = 42, malo_id: str = "DE0000000001" +) -> MarketLocationRef: + return MarketLocationRef( + enterprise_id=enterprise_id, + market_area=MarketArea.EU_DE, + market_location_id=MarketLocationId( + value=malo_id, + type=MarketLocationIdType.MALO_ID, + ), + ) + + +def _make_location() -> MarketLocation: + return MarketLocation( + display_name="Test Location", + supported_directions=[EnergyFlowDirection.IMPORT], + time_resolution=TimeResolution.MIN_15, + payload={}, + ) + + +def _make_timestamp(dt: datetime) -> Timestamp: + ts = Timestamp() + ts.FromDatetime(dt) + return ts + + +def _make_detail_pb( + enterprise_id: int = 42, + malo_id: str = "DE0000000001", + display_name: str = "Test Location", + revision: int = 1, + is_active: bool = True, +) -> pb.MarketLocationDetail: + """Build a MarketLocationDetail protobuf for mock responses.""" + return pb.MarketLocationDetail( + market_location_ref=pb.MarketLocationRef( + enterprise_id=enterprise_id, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value=malo_id), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ), + market_location=pb.MarketLocation( + display_name=display_name, + supported_directions=[pb.ENERGY_FLOW_DIRECTION_IMPORT], + time_resolution=pb.TIME_RESOLUTION_15_MIN, + ), + revision=revision, + is_active=is_active, + create_time=_make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)), + update_time=_make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)), + ) + + +class TestCreateMarketLocation: + """Tests for create_market_location.""" + + async def test_sends_correct_request(self) -> None: + """Test that create_market_location sends the right protobuf.""" + client = _make_client() + detail_pb = _make_detail_pb() + client.stub.CreateMarketLocation = AsyncMock( + return_value=pb.CreateMarketLocationResponse(market_location=detail_pb) + ) + + ml_ref = _make_ref() + ml = _make_location() + + result = await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + client.stub.CreateMarketLocation.assert_called_once() + request = client.stub.CreateMarketLocation.call_args[0][0] + assert isinstance(request, pb.CreateMarketLocationRequest) + assert request.market_location_ref.enterprise_id == 42 + assert request.market_location_ref.market_location_id.id.value == "DE0000000001" + assert request.market_location_ref.market_area == pb.MARKET_AREA_EU_DE + assert request.market_location.display_name == "Test Location" + assert request.market_location.time_resolution == pb.TIME_RESOLUTION_15_MIN + assert list(request.market_location.supported_directions) == [ + pb.ENERGY_FLOW_DIRECTION_IMPORT + ] + assert result.revision == 1 + assert result.is_active is True + + async def test_with_payload(self) -> None: + """Test that payload is correctly serialized.""" + client = _make_client() + detail_pb = _make_detail_pb(display_name="With Payload") + client.stub.CreateMarketLocation = AsyncMock( + return_value=pb.CreateMarketLocationResponse(market_location=detail_pb) + ) + + ml_ref = _make_ref() + ml = MarketLocation( + display_name="With Payload", + supported_directions=[EnergyFlowDirection.IMPORT], + time_resolution=TimeResolution.MIN_15, + payload={"key": "value", "num": 42}, + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + request = client.stub.CreateMarketLocation.call_args[0][0] + assert request.market_location.payload["key"] == "value" + assert request.market_location.payload["num"] == 42 + + +class TestUpdateMarketLocation: + """Tests for update_market_location.""" + + async def test_sends_expected_revision(self) -> None: + """Test that expected_revision is passed through.""" + client = _make_client() + detail_pb = _make_detail_pb(revision=4) + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + ml_ref = _make_ref() + update = MarketLocationUpdate(display_name="New Name") + + result = await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=3, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert isinstance(request, pb.UpdateMarketLocationRequest) + assert request.expected_revision == 3 + assert result.revision == 4 + + async def test_update_display_name(self) -> None: + """Test updating display_name sets the correct field mask.""" + client = _make_client() + detail_pb = _make_detail_pb() + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + update = MarketLocationUpdate(display_name="Updated") + + await client.update_market_location( + market_location_ref=_make_ref(), + update=update, + expected_revision=1, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert "display_name" in request.update_mask.paths + assert request.update_fields.display_name == "Updated" + + async def test_update_supported_directions(self) -> None: + """Test updating supported_directions.""" + client = _make_client() + detail_pb = _make_detail_pb() + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + update = MarketLocationUpdate( + supported_directions=[ + EnergyFlowDirection.IMPORT, + EnergyFlowDirection.EXPORT, + ] + ) + + await client.update_market_location( + market_location_ref=_make_ref(), + update=update, + expected_revision=1, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert "supported_directions" in request.update_mask.paths + assert list(request.update_fields.supported_directions) == [ + pb.ENERGY_FLOW_DIRECTION_IMPORT, + pb.ENERGY_FLOW_DIRECTION_EXPORT, + ] + + async def test_update_time_resolution(self) -> None: + """Test updating time_resolution.""" + client = _make_client() + detail_pb = _make_detail_pb() + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + update = MarketLocationUpdate(time_resolution=TimeResolution.MIN_5) + + await client.update_market_location( + market_location_ref=_make_ref(), + update=update, + expected_revision=1, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert "time_resolution" in request.update_mask.paths + assert request.update_fields.time_resolution == pb.TIME_RESOLUTION_5_MIN + + async def test_update_payload(self) -> None: + """Test updating payload.""" + client = _make_client() + detail_pb = _make_detail_pb() + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + update = MarketLocationUpdate(payload={"new_key": "new_value"}) + + await client.update_market_location( + market_location_ref=_make_ref(), + update=update, + expected_revision=1, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert "payload" in request.update_mask.paths + assert request.update_fields.payload["new_key"] == "new_value" + + async def test_update_multiple_fields(self) -> None: + """Test updating multiple fields at once.""" + client = _make_client() + detail_pb = _make_detail_pb(revision=6) + client.stub.UpdateMarketLocation = AsyncMock( + return_value=pb.UpdateMarketLocationResponse( + market_location_detail=detail_pb + ) + ) + + update = MarketLocationUpdate( + display_name="Multi", + time_resolution=TimeResolution.MIN_1, + payload={"a": 1}, + ) + + await client.update_market_location( + market_location_ref=_make_ref(), + update=update, + expected_revision=5, + ) + + request = client.stub.UpdateMarketLocation.call_args[0][0] + assert set(request.update_mask.paths) == { + "display_name", + "time_resolution", + "payload", + } + assert request.expected_revision == 5 + + +class TestListMarketLocations: + """Tests for list_market_locations.""" + + async def test_basic_list(self) -> None: + """Test listing market locations returns parsed entries.""" + client = _make_client() + + # Build a mock response with one entry. + ml_pb = pb.MarketLocation( + display_name="Listed", + supported_directions=[pb.ENERGY_FLOW_DIRECTION_IMPORT], + time_resolution=pb.TIME_RESOLUTION_15_MIN, + ) + ref_pb = pb.MarketLocationRef( + enterprise_id=42, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value="DE0000000001"), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ) + detail_pb = pb.MarketLocationDetail( + market_location_ref=ref_pb, + market_location=ml_pb, + revision=1, + is_active=True, + create_time=_make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)), + update_time=_make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)), + ) + entry_pb = pb.ListMarketLocationsResponse.MarketLocationEntry( + enterprise_id=42, + market_location=detail_pb, + ) + response = pb.ListMarketLocationsResponse(market_locations=[entry_pb]) + + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + entries, next_page = await client.list_market_locations(enterprise_id=42) + + assert len(entries) == 1 + assert entries[0].enterprise_id == 42 + assert entries[0].market_location.display_name == "Listed" + assert entries[0].market_location_ref.market_area == MarketArea.EU_DE + assert next_page is None + + async def test_sends_enterprise_id(self) -> None: + """Test that enterprise_id is sent in the request.""" + client = _make_client() + response = pb.ListMarketLocationsResponse() + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + await client.list_market_locations(enterprise_id=99) + + request = client.stub.ListMarketLocations.call_args[0][0] + assert request.enterprise_id == 99 + + async def test_pagination(self) -> None: + """Test that pagination info is returned.""" + client = _make_client() + + pagination_info = pagination_info_pb.PaginationInfo(next_page_token="token123") + response = pb.ListMarketLocationsResponse(pagination_info=pagination_info) + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + _, next_page = await client.list_market_locations(enterprise_id=1) + + assert next_page is not None + assert next_page.page_token == "token123" + + async def test_sends_filters(self) -> None: + """Test that filters are sent in the request.""" + client = _make_client() + response = pb.ListMarketLocationsResponse() + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + filters = MarketLocationsFilter( + market_location_id_filters=[ + MarketLocationId(value="DE123", type=MarketLocationIdType.MALO_ID) + ], + activation_filter=ActivationFilter.ALL, + ) + await client.list_market_locations(enterprise_id=1, filters=filters) + + request = client.stub.ListMarketLocations.call_args[0][0] + assert request.filter.activation_filter == pb.ACTIVATION_FILTER_ALL + assert len(request.filter.market_location_id_filters) == 1 + assert request.filter.market_location_id_filters[0].value == "DE123" + + async def test_sends_page_size(self) -> None: + """Test that page_size is sent in the request. + + page_size and page_token are in a oneof, so only one can be set. + """ + client = _make_client() + response = pb.ListMarketLocationsResponse() + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + params = PaginationParams(page_size=10) + await client.list_market_locations(enterprise_id=1, pagination_params=params) + + request = client.stub.ListMarketLocations.call_args[0][0] + assert request.pagination_params.page_size == 10 + + async def test_sends_page_token(self) -> None: + """Test that page_token is sent in the request. + + page_size and page_token are in a oneof, so only one can be set. + """ + client = _make_client() + response = pb.ListMarketLocationsResponse() + client.stub.ListMarketLocations = AsyncMock(return_value=response) + + params = PaginationParams(page_token="next-page-token") + await client.list_market_locations(enterprise_id=1, pagination_params=params) + + request = client.stub.ListMarketLocations.call_args[0][0] + assert request.pagination_params.page_token == "next-page-token" + + +class TestActivateMarketLocations: + """Tests for activate_market_locations.""" + + async def test_sends_correct_request(self) -> None: + """Test that activate sends the refs in market_location_refs.""" + client = _make_client() + + ml_ref_pb = pb.MarketLocationRef( + enterprise_id=7, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value="DE_ACT_001"), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ) + result_pb = pb.MarketLocationOperationResult( + market_location_ref=ml_ref_pb, + revision=2, + ) + client.stub.ActivateMarketLocation = AsyncMock( + return_value=pb.ActivateMarketLocationResponse(results=[result_pb]) + ) + + ml_ref = _make_ref(enterprise_id=7, malo_id="DE_ACT_001") + + results = await client.activate_market_locations(market_location_refs=[ml_ref]) + + request = client.stub.ActivateMarketLocation.call_args[0][0] + assert isinstance(request, pb.ActivateMarketLocationRequest) + assert len(request.market_location_refs) == 1 + assert request.market_location_refs[0].enterprise_id == 7 + assert ( + request.market_location_refs[0].market_location_id.id.value == "DE_ACT_001" + ) + assert len(results) == 1 + assert results[0].error is None + assert results[0].revision == 2 + + +class TestDeactivateMarketLocations: + """Tests for deactivate_market_locations.""" + + async def test_sends_correct_request(self) -> None: + """Test that deactivate sends the refs in market_location_refs.""" + client = _make_client() + + ml_ref_pb = pb.MarketLocationRef( + enterprise_id=8, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value="DE_DEACT_001"), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ) + result_pb = pb.MarketLocationOperationResult( + market_location_ref=ml_ref_pb, + revision=3, + ) + client.stub.DeactivateMarketLocation = AsyncMock( + return_value=pb.DeactivateMarketLocationResponse(results=[result_pb]) + ) + + ml_ref = _make_ref(enterprise_id=8, malo_id="DE_DEACT_001") + + results = await client.deactivate_market_locations( + market_location_refs=[ml_ref] + ) + + request = client.stub.DeactivateMarketLocation.call_args[0][0] + assert isinstance(request, pb.DeactivateMarketLocationRequest) + assert len(request.market_location_refs) == 1 + assert request.market_location_refs[0].enterprise_id == 8 + assert len(results) == 1 + assert results[0].error is None + + +class TestStreamSamples: + """Tests for stream_samples.""" + + async def test_yields_parsed_series(self) -> None: + """Test that stream_samples yields MarketLocationSeries.""" + client = _make_client() + + ml_ref_pb = pb.MarketLocationRef( + enterprise_id=42, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value="DE001"), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ) + sample_time = _make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)) + sample_pb = pb.MarketLocationSampleDetail( + sample_time=sample_time, + value=100.5, + quality=pb.DATA_QUALITY_MEASURED, + revision=1, + resampling_method=pb.RESAMPLING_METHOD_NATIVE, + ) + series_pb = pb.MarketLocationSeries( + market_location_ref=ml_ref_pb, + direction=pb.ENERGY_FLOW_DIRECTION_IMPORT, + metric_type=pb.METRIC_TYPE_ACTIVE_ENERGY, + metric_unit=pb.METRIC_UNIT_KWH, + resolution=pb.TIME_RESOLUTION_15_MIN, + samples=[sample_pb], + ) + response = pb.ReceiveMarketLocationSamplesStreamResponse(series=[series_pb]) + + # Mock the streaming call to return an async iterator. + async def mock_stream() -> AsyncIterator[ + pb.ReceiveMarketLocationSamplesStreamResponse + ]: + yield response + + client.stub.ReceiveMarketLocationSamplesStream = MagicMock( + return_value=mock_stream() + ) + + results: list[MarketLocationSeries] = [] + async for series in client.stream_samples( + market_locations=[_make_ref()], + directions=[EnergyFlowDirection.IMPORT], + metric_types=[MetricType.ACTIVE_ENERGY], + ): + results.append(series) + + assert len(results) == 1 + assert results[0].direction == EnergyFlowDirection.IMPORT + assert results[0].metric_type == MetricType.ACTIVE_ENERGY + assert results[0].metric_unit == MetricUnit.KWH + assert len(results[0].samples) == 1 + assert results[0].samples[0].value == 100.5 + + async def test_sends_time_filter(self) -> None: + """Test that start_time and end_time are sent.""" + client = _make_client() + + async def mock_stream() -> AsyncIterator[ + pb.ReceiveMarketLocationSamplesStreamResponse + ]: + return + yield # make it an async generator # noqa: RET504 + + client.stub.ReceiveMarketLocationSamplesStream = MagicMock( + return_value=mock_stream() + ) + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 2, tzinfo=timezone.utc) + + async for _ in client.stream_samples( + market_locations=[_make_ref()], + directions=[EnergyFlowDirection.IMPORT], + metric_types=[MetricType.ACTIVE_ENERGY], + start_time=start, + end_time=end, + ): + pass + + request = client.stub.ReceiveMarketLocationSamplesStream.call_args[0][0] + assert request.stream_filter.HasField("time_filter") + interval = request.stream_filter.time_filter.interval + assert interval.start_time == _make_timestamp(start) + assert interval.end_time == _make_timestamp(end) + + +class TestUpsertSamples: + """Tests for upsert_samples.""" + + async def test_yields_parsed_results(self) -> None: + """Test that upsert_samples yields UpsertResult.""" + client = _make_client() + + ml_ref_pb = pb.MarketLocationRef( + enterprise_id=42, + market_area=pb.MARKET_AREA_EU_DE, + market_location_id=pb.MarketLocationId( + id=pb.MarketLocationIdValue(value="DE001"), + type=pb.MARKET_LOCATION_ID_TYPE_MALO_ID, + ), + ) + sample_time = _make_timestamp(datetime(2025, 1, 1, tzinfo=timezone.utc)) + sample_pb = pb.MarketLocationSample( + sample_time=sample_time, + value=50.0, + quality=pb.DATA_QUALITY_MEASURED, + revision=1, + ) + upsert_response = pb.UpsertMarketLocationSamplesStreamResponse( + market_location_ref=ml_ref_pb, + direction=pb.ENERGY_FLOW_DIRECTION_IMPORT, + metric_type=pb.METRIC_TYPE_ACTIVE_ENERGY, + metric_unit=pb.METRIC_UNIT_KWH, + sample=sample_pb, + ) + + async def mock_stream() -> AsyncIterator[ + pb.UpsertMarketLocationSamplesStreamResponse + ]: + yield upsert_response + + client.stub.UpsertMarketLocationSamplesStream = MagicMock( + return_value=mock_stream() + ) + + # Build an input stream. + ml_ref = _make_ref() + sample = MarketLocationSample( + sample_time=datetime(2025, 1, 1, tzinfo=timezone.utc), + value=50.0, + quality=DataQuality.MEASURED, + revision=1, + update_time=None, + resampling_method=ResamplingMethod.UNSPECIFIED, + ) + series = MarketLocationSeries( + market_location_ref=ml_ref, + direction=EnergyFlowDirection.IMPORT, + metric_type=MetricType.ACTIVE_ENERGY, + metric_unit=MetricUnit.KWH, + resolution=TimeResolution.MIN_15, + samples=[sample], + ) + + async def input_stream() -> AsyncIterator[ + tuple[MarketLocationRef, MarketLocationSeries] + ]: + yield (ml_ref, series) + + results = [] + async for result in client.upsert_samples(input_stream()): + results.append(result) + + assert len(results) == 1 + assert results[0].error is None + assert results[0].sample.value == 50.0 diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..bede358 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,396 @@ +# License: MIT +# Copyright © 2025 Frequenz Energy-as-a-Service GmbH + +"""Integration tests for the MarketMeteringApiClient against a live service. + +These tests require a running marketmetering service and are excluded from +CI by default. To run them: + + 1. Start the service with auth disabled and test storage backend: + + ./target/release/marketmeteringd -c test-config.toml + + Where test-config.toml contains: + + [net] + ip = "[::1]" + port = 50051 + + [auth] + enabled = false + + [storage] + backend = "test" + + 2. Run the tests: + + uv run pytest -m integration +""" + +from collections.abc import AsyncIterator + +import grpc +import pytest +from grpc.aio import AioRpcError + +from frequenz.client.marketmetering import MarketMeteringApiClient +from frequenz.client.marketmetering.types import ( + EnergyFlowDirection, + MarketArea, + MarketLocation, + MarketLocationId, + MarketLocationIdType, + MarketLocationRef, + MarketLocationUpdate, + MetricType, + TimeResolution, +) + +SERVICE_URL = "grpc://[::1]:50051?ssl=false" +AUTH_KEY = "test-key" + +pytestmark = pytest.mark.integration + + +@pytest.fixture +async def client() -> AsyncIterator[MarketMeteringApiClient]: + """Create a connected client for testing.""" + c = MarketMeteringApiClient( + server_url=SERVICE_URL, + auth_key=AUTH_KEY, + ) + yield c + + +def make_ref( + enterprise_id: int = 42, malo_id: str = "DE0000000001" +) -> MarketLocationRef: + """Create a MarketLocationRef for testing.""" + return MarketLocationRef( + enterprise_id=enterprise_id, + market_area=MarketArea.EU_DE, + market_location_id=MarketLocationId( + value=malo_id, + type=MarketLocationIdType.MALO_ID, + ), + ) + + +def make_location( + display_name: str = "Test Location", + directions: list[EnergyFlowDirection] | None = None, + time_resolution: TimeResolution = TimeResolution.MIN_15, +) -> MarketLocation: + """Create a MarketLocation for testing.""" + if directions is None: + directions = [EnergyFlowDirection.IMPORT] + return MarketLocation( + display_name=display_name, + supported_directions=directions, + time_resolution=time_resolution, + payload={}, + ) + + +class TestCreateMarketLocation: + """Tests for create_market_location.""" + + async def test_create_basic(self, client: MarketMeteringApiClient) -> None: + """Test creating a basic market location.""" + ml_ref = make_ref(enterprise_id=1, malo_id="DE_CREATE_BASIC_001") + ml = make_location(display_name="Basic Test") + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + async def test_create_with_both_directions( + self, client: MarketMeteringApiClient + ) -> None: + """Test creating a market location with import and export.""" + ml_ref = make_ref(enterprise_id=1, malo_id="DE_CREATE_BOTH_DIR_001") + ml = make_location( + display_name="Both Directions", + directions=[EnergyFlowDirection.IMPORT, EnergyFlowDirection.EXPORT], + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + async def test_create_with_payload(self, client: MarketMeteringApiClient) -> None: + """Test creating a market location with payload metadata.""" + ml_ref = make_ref(enterprise_id=1, malo_id="DE_CREATE_PAYLOAD_001") + ml = MarketLocation( + display_name="With Payload", + supported_directions=[EnergyFlowDirection.IMPORT], + time_resolution=TimeResolution.MIN_15, + payload={"meter_serial": "ABC123", "location_type": "industrial"}, + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + async def test_create_duplicate_fails( + self, client: MarketMeteringApiClient + ) -> None: + """Test that creating a duplicate market location fails.""" + ml_ref = make_ref(enterprise_id=1, malo_id="DE_CREATE_DUP_001") + ml = make_location(display_name="First") + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + with pytest.raises(AioRpcError) as exc_info: + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + assert exc_info.value.code() in ( + grpc.StatusCode.ALREADY_EXISTS, + grpc.StatusCode.INVALID_ARGUMENT, + grpc.StatusCode.INTERNAL, + ) + + async def test_create_with_different_resolutions( + self, client: MarketMeteringApiClient + ) -> None: + """Test creating locations with various time resolutions.""" + for i, res in enumerate( + [ + TimeResolution.MIN_1, + TimeResolution.MIN_5, + TimeResolution.MIN_15, + TimeResolution.MIN_60, + TimeResolution.DAY_1, + ] + ): + ml_ref = make_ref(enterprise_id=1, malo_id=f"DE_CREATE_RES_{i:03d}") + ml = make_location( + display_name=f"Resolution {res.name}", + time_resolution=res, + ) + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + +class TestUpdateMarketLocation: + """Tests for update_market_location.""" + + async def test_update_display_name(self, client: MarketMeteringApiClient) -> None: + """Test updating the display name of a market location.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_NAME_001") + ml = make_location(display_name="Original Name") + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + update = MarketLocationUpdate(display_name="Updated Name") + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_update_supported_directions( + self, client: MarketMeteringApiClient + ) -> None: + """Test adding a direction to supported_directions.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_DIR_001") + ml = make_location( + display_name="Direction Update", + directions=[EnergyFlowDirection.IMPORT], + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + update = MarketLocationUpdate( + supported_directions=[ + EnergyFlowDirection.IMPORT, + EnergyFlowDirection.EXPORT, + ], + ) + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_update_time_resolution_to_finer( + self, client: MarketMeteringApiClient + ) -> None: + """Test changing time_resolution to a finer resolution.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_RES_001") + ml = make_location( + display_name="Resolution Update", + time_resolution=TimeResolution.MIN_15, + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + update = MarketLocationUpdate(time_resolution=TimeResolution.MIN_5) + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_update_time_resolution_to_coarser( + self, client: MarketMeteringApiClient + ) -> None: + """Test that changing to a coarser resolution succeeds.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_RES_COARSE_001") + ml = make_location( + display_name="Coarser OK", + time_resolution=TimeResolution.MIN_15, + ) + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + update = MarketLocationUpdate(time_resolution=TimeResolution.MIN_60) + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_update_payload(self, client: MarketMeteringApiClient) -> None: + """Test updating the payload.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_PAY_001") + ml = make_location(display_name="Payload Update") + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + update = MarketLocationUpdate(payload={"new_key": "new_value"}) + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_update_nonexistent_fails( + self, client: MarketMeteringApiClient + ) -> None: + """Test that updating a non-existent location fails.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_NONEXISTENT_001") + update = MarketLocationUpdate(display_name="Ghost") + + with pytest.raises(AioRpcError): + await client.update_market_location( + market_location_ref=ml_ref, + update=update, + expected_revision=1, + ) + + async def test_sequential_updates_increment_revision( + self, client: MarketMeteringApiClient + ) -> None: + """Test that sequential updates require incrementing revision.""" + ml_ref = make_ref(enterprise_id=2, malo_id="DE_UPDATE_SEQ_001") + ml = make_location(display_name="Sequential") + + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + await client.update_market_location( + market_location_ref=ml_ref, + update=MarketLocationUpdate(display_name="Updated Once"), + expected_revision=1, + ) + + await client.update_market_location( + market_location_ref=ml_ref, + update=MarketLocationUpdate(display_name="Updated Twice"), + expected_revision=2, + ) + + # Stale revision should fail + with pytest.raises(AioRpcError) as exc_info: + await client.update_market_location( + market_location_ref=ml_ref, + update=MarketLocationUpdate(display_name="Stale"), + expected_revision=1, + ) + assert "conflict" in (exc_info.value.details() or "").lower() + + +class TestListMarketLocations: + """Tests for list_market_locations.""" + + async def test_list_returns_results(self, client: MarketMeteringApiClient) -> None: + """Test that list_market_locations returns results.""" + # Create a location first so there's something to list. + ml_ref = make_ref(enterprise_id=3, malo_id="DE_LIST_001") + ml = make_location(display_name="Listable") + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + entries, _ = await client.list_market_locations(enterprise_id=3) + assert len(entries) >= 1 + + +class TestActivateDeactivate: + """Tests for activate/deactivate market locations.""" + + async def test_deactivate_and_activate( + self, client: MarketMeteringApiClient + ) -> None: + """Test deactivating and reactivating a market location.""" + ml_ref = make_ref(enterprise_id=3, malo_id="DE_ACT_DEACT_001") + ml = make_location(display_name="Toggle Active") + await client.create_market_location( + market_location_ref=ml_ref, + market_location=ml, + ) + + results = await client.deactivate_market_locations( + market_location_refs=[ml_ref] + ) + assert len(results) == 1 + assert results[0].error is None + + results = await client.activate_market_locations(market_location_refs=[ml_ref]) + assert len(results) == 1 + assert results[0].error is None + + +class TestStreamSamples: + """Tests for stream_samples.""" + + async def test_stream_returns_no_error( + self, client: MarketMeteringApiClient + ) -> None: + """Test that stream_samples can be called without error.""" + ml_ref = make_ref(enterprise_id=3, malo_id="DE_STREAM_001") + async for _ in client.stream_samples( + market_locations=[ml_ref], + directions=[EnergyFlowDirection.IMPORT], + metric_types=[MetricType.ACTIVE_ENERGY], + ): + break # Just verify the stream starts without error From ea73ca8db87df6fcde7d952c06ee06054624c09d Mon Sep 17 00:00:00 2001 From: "Mathias L. Baumann" Date: Thu, 2 Apr 2026 16:52:33 +0200 Subject: [PATCH 2/3] Fix auth for sample upsert stream The client got authentication and signing metadata from frequenz-client-base only for unary-unary and unary-stream RPCs. UpsertMarketLocationSamplesStream is a streaming upsert call, so it missed the key, timestamp, nonce, and signature metadata and the service rejected it with a missing-signature authentication error. Attach the same metadata explicitly for the upsert stream call and add a regression test covering the streaming path. Signed-off-by: Mathias L. Baumann --- RELEASE_NOTES.md | 1 + src/frequenz/client/marketmetering/_client.py | 32 ++++++ tests/test_client.py | 24 ++--- tests/test_integration.py | 23 ++-- tests/test_marketmetering.py | 102 +++++++++++++++++- 5 files changed, 163 insertions(+), 19 deletions(-) diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index b0fe19c..5faaa85 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -53,3 +53,4 @@ Initial release of the Frequenz Market Metering API client for Python. ## Bug Fixes - `update_market_location()`: Add missing `expected_revision` parameter required for optimistic concurrency control. +- `upsert_samples()`: Attach auth and signing metadata to the streaming upsert RPC so authenticated sample upserts work against services that require signed requests. diff --git a/src/frequenz/client/marketmetering/_client.py b/src/frequenz/client/marketmetering/_client.py index c23f214..7512d4f 100644 --- a/src/frequenz/client/marketmetering/_client.py +++ b/src/frequenz/client/marketmetering/_client.py @@ -5,6 +5,10 @@ from __future__ import annotations +import hmac +import secrets +import time +from base64 import urlsafe_b64encode from datetime import datetime, timedelta from typing import AsyncIterator, cast @@ -152,6 +156,33 @@ def stub(self) -> marketmetering_pb2_grpc.MarketMeteringServiceStub: raise ClientNotConnected(server_url=self.server_url, operation="stub") return self._stub + def _metadata(self, method: str) -> tuple[tuple[str, str | bytes], ...] | None: + """Build request metadata for RPCs not covered by client-base interceptors.""" + if self._auth_key is None: + return None + + metadata: list[tuple[str, str | bytes]] = [("key", self._auth_key)] + if self._sign_secret is None: + return tuple(metadata) + + ts = str(int(time.time())).encode() + nonce = urlsafe_b64encode(secrets.token_bytes(16)) + + digest = hmac.new(self._sign_secret.encode(), digestmod="sha256") + digest.update(self._auth_key.encode()) + digest.update(ts) + digest.update(nonce) + digest.update(method.encode()) + + metadata.extend( + [ + ("ts", ts), + ("nonce", nonce), + ("sig", urlsafe_b64encode(digest.digest()).rstrip(b"=")), + ] + ) + return tuple(metadata) + async def create_market_location( self, *, @@ -336,6 +367,7 @@ async def request_generator() -> ( AsyncIterator[pb.UpsertMarketLocationSamplesStreamResponse], self.stub.UpsertMarketLocationSamplesStream( request_generator(), # type: ignore[arg-type] + metadata=self._metadata("UpsertMarketLocationSamplesStream"), timeout=self._stream_timeout_seconds, ), ) diff --git a/tests/test_client.py b/tests/test_client.py index 6482abb..b1e6564 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -551,9 +551,9 @@ async def test_yields_parsed_series(self) -> None: response = pb.ReceiveMarketLocationSamplesStreamResponse(series=[series_pb]) # Mock the streaming call to return an async iterator. - async def mock_stream() -> AsyncIterator[ - pb.ReceiveMarketLocationSamplesStreamResponse - ]: + async def mock_stream() -> ( + AsyncIterator[pb.ReceiveMarketLocationSamplesStreamResponse] + ): yield response client.stub.ReceiveMarketLocationSamplesStream = MagicMock( @@ -579,9 +579,9 @@ async def test_sends_time_filter(self) -> None: """Test that start_time and end_time are sent.""" client = _make_client() - async def mock_stream() -> AsyncIterator[ - pb.ReceiveMarketLocationSamplesStreamResponse - ]: + async def mock_stream() -> ( + AsyncIterator[pb.ReceiveMarketLocationSamplesStreamResponse] + ): return yield # make it an async generator # noqa: RET504 @@ -638,9 +638,9 @@ async def test_yields_parsed_results(self) -> None: sample=sample_pb, ) - async def mock_stream() -> AsyncIterator[ - pb.UpsertMarketLocationSamplesStreamResponse - ]: + async def mock_stream() -> ( + AsyncIterator[pb.UpsertMarketLocationSamplesStreamResponse] + ): yield upsert_response client.stub.UpsertMarketLocationSamplesStream = MagicMock( @@ -666,9 +666,9 @@ async def mock_stream() -> AsyncIterator[ samples=[sample], ) - async def input_stream() -> AsyncIterator[ - tuple[MarketLocationRef, MarketLocationSeries] - ]: + async def input_stream() -> ( + AsyncIterator[tuple[MarketLocationRef, MarketLocationSeries]] + ): yield (ml_ref, series) results = [] diff --git a/tests/test_integration.py b/tests/test_integration.py index bede358..bff93fe 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,5 +1,5 @@ # License: MIT -# Copyright © 2025 Frequenz Energy-as-a-Service GmbH +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH """Integration tests for the MarketMeteringApiClient against a live service. @@ -27,6 +27,8 @@ uv run pytest -m integration """ +import os +import socket from collections.abc import AsyncIterator import grpc @@ -52,9 +54,22 @@ pytestmark = pytest.mark.integration +def _service_available() -> bool: + """Check whether the local integration test service is reachable.""" + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: + sock.settimeout(0.5) + return sock.connect_ex(("::1", 50051, 0, 0)) == 0 + + @pytest.fixture async def client() -> AsyncIterator[MarketMeteringApiClient]: """Create a connected client for testing.""" + if os.environ.get("CI") == "true": + pytest.skip("integration tests are not run in CI") + + if not _service_available(): + pytest.skip("integration test service is not running on [::1]:50051") + c = MarketMeteringApiClient( server_url=SERVICE_URL, auth_key=AUTH_KEY, @@ -152,11 +167,7 @@ async def test_create_duplicate_fails( market_location_ref=ml_ref, market_location=ml, ) - assert exc_info.value.code() in ( - grpc.StatusCode.ALREADY_EXISTS, - grpc.StatusCode.INVALID_ARGUMENT, - grpc.StatusCode.INTERNAL, - ) + assert exc_info.value.code() == grpc.StatusCode.ALREADY_EXISTS async def test_create_with_different_resolutions( self, client: MarketMeteringApiClient diff --git a/tests/test_marketmetering.py b/tests/test_marketmetering.py index 2ff80e9..b2eeea3 100644 --- a/tests/test_marketmetering.py +++ b/tests/test_marketmetering.py @@ -3,8 +3,13 @@ """Tests for the Market Metering client.""" -from datetime import timedelta +from collections.abc import AsyncIterator +from datetime import datetime, timedelta, timezone +from typing import Any, cast +import pytest + +from frequenz.client.marketmetering import MarketMeteringApiClient from frequenz.client.marketmetering.types import ( DataQuality, DownsamplingMethod, @@ -13,6 +18,8 @@ MarketLocationId, MarketLocationIdType, MarketLocationRef, + MarketLocationSample, + MarketLocationSeries, MetricType, MetricUnit, ResamplingMethod, @@ -117,3 +124,96 @@ def test_resampling_options_with_resolution(self) -> None: """Test ResamplingOptions with resolution.""" options = ResamplingOptions(resolution=TimeResolution.MIN_15) assert options.resolution == TimeResolution.MIN_15 + + +class _UpsertStub: + """A fake stub capturing upsert stream call arguments.""" + + def __init__(self) -> None: + self.metadata: tuple[tuple[str, str | bytes], ...] | None = None + self.requests: list[object] = [] + + # gRPC-generated method names use UpperCamelCase and this stub mirrors that. + # pylint: disable=invalid-name,missing-function-docstring + async def UpsertMarketLocationSamplesStream( # noqa: N802 + self, + request_iterator: AsyncIterator[object], + *, + metadata: tuple[tuple[str, str | bytes], ...] | None = None, + timeout: float | None = None, + ) -> AsyncIterator[object]: + del timeout + self.metadata = metadata + async for request in request_iterator: + self.requests.append(request) + items: tuple[object, ...] = () + for item in items: + yield item + + +class TestClientMethods: + """Tests for client RPC helpers.""" + + @pytest.mark.asyncio + async def test_upsert_samples_adds_stream_metadata(self) -> None: + """Test that bidi upsert includes auth and signing metadata.""" + client = MarketMeteringApiClient( + server_url="grpc://example.com", + auth_key="test-key", + sign_secret="test-secret", + connect=False, + ) + stub = _UpsertStub() + setattr(client, "_stub", cast(Any, stub)) + setattr(client, "_channel", cast(Any, object())) + + market_location_ref = MarketLocationRef( + enterprise_id=42, + market_area=MarketArea.EU_DE, + market_location_id=MarketLocationId( + value="DE01234567890", + type=MarketLocationIdType.MALO_ID, + ), + ) + series = MarketLocationSeries( + market_location_ref=market_location_ref, + direction=EnergyFlowDirection.IMPORT, + metric_type=MetricType.ACTIVE_ENERGY, + metric_unit=MetricUnit.KWH, + resolution=TimeResolution.MIN_15, + samples=[], + ) + sample = MarketLocationSample( + sample_time=datetime.now(timezone.utc), + value=1.0, + quality=DataQuality.MEASURED, + revision=1, + update_time=None, + resampling_method=ResamplingMethod.UNSPECIFIED, + ) + + async def sample_generator() -> ( + AsyncIterator[tuple[MarketLocationRef, MarketLocationSeries]] + ): + yield ( + market_location_ref, + MarketLocationSeries( + market_location_ref=series.market_location_ref, + direction=series.direction, + metric_type=series.metric_type, + metric_unit=series.metric_unit, + resolution=series.resolution, + samples=[sample], + ), + ) + + assert [ + result async for result in client.upsert_samples(sample_generator()) + ] == [] + assert stub.metadata is not None + metadata = dict(stub.metadata) + assert metadata["key"] == "test-key" + assert metadata["ts"] + assert metadata["nonce"] + assert metadata["sig"] + assert len(stub.requests) == 1 From 9fa9e2a1b295fbbd6c3b83a56dfd43fcd81b8547 Mon Sep 17 00:00:00 2001 From: "Mathias L. Baumann" Date: Tue, 7 Apr 2026 13:18:23 +0200 Subject: [PATCH 3/3] WIP: Add testcontainers for integration tests Integration tests now use testcontainers to spin up GreptimeDB and the marketmeteringd service binary automatically. - Added testcontainers and requests dependencies to dev-pytest - Created conftest.py with fixtures for: - GreptimeDB container - marketmeteringd service process - Database schema initialization - Updated test_integration.py to use new fixtures - Tests skip automatically when Docker is unavailable Requires: - Docker running - Service binary at ../frequenz-service-marketmetering/target/release/marketmeteringd Signed-off-by: Mathias L. Baumann --- pyproject.toml | 2 + tests/conftest.py | 235 ++++++++++++++++++++++++++++++++++++++ tests/test_integration.py | 59 +--------- 3 files changed, 243 insertions(+), 53 deletions(-) create mode 100644 tests/conftest.py diff --git a/pyproject.toml b/pyproject.toml index b66e411..36b6f38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,8 @@ dev-pytest = [ "pytest-asyncio == 1.3.0", "async-solipsism == 0.9", "frequenz-client-marketmetering[cli]", + "testcontainers == 4.10.0", + "requests == 2.32.3", ] dev = [ "frequenz-client-marketmetering[dev-mkdocs,dev-flake8,dev-formatting,dev-mkdocs,dev-mypy,dev-noxfile,dev-pylint,dev-pytest]", diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e577efc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,235 @@ +# License: MIT +# Copyright © 2026 Frequenz Energy-as-a-Service GmbH + +"""Pytest fixtures for integration tests using testcontainers.""" + +from __future__ import annotations + +import subprocess +import tempfile +import time +from collections.abc import AsyncIterator +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import pytest +import requests + +from frequenz.client.marketmetering import MarketMeteringApiClient + +if TYPE_CHECKING: + from collections.abc import Generator + +SERVICE_REPO_PATH = ( + Path(__file__).parent.parent.parent / "frequenz-service-marketmetering" +) +SCHEMA_PATH = SERVICE_REPO_PATH / "database" / "greptimedb" / "create_schema.sql" +SERVICE_BINARY = SERVICE_REPO_PATH / "target" / "release" / "marketmeteringd" + +GREPTIMEDB_IMAGE = "greptime/greptimedb:v1.0.0-rc.2" +GREPTIMEDB_GRPC_PORT = 4001 +GREPTIMEDB_HTTP_PORT = 4000 + +pytestmark = pytest.mark.integration + + +def _docker_available() -> bool: + """Check if Docker is available and running.""" + try: + import docker + + client = docker.from_env() + client.ping() + return True + except Exception: + return False + + +def pytest_configure(config: pytest.Config) -> None: + """Configure pytest markers.""" + config.addinivalue_line( + "markers", "integration: integration tests requiring Docker" + ) + + +def pytest_collection_modifyitems( + config: pytest.Config, items: list[pytest.Item] +) -> None: + """Skip integration tests if Docker is not available.""" + if not _docker_available(): + skip_reason = pytest.mark.skip(reason="Docker not available") + for item in items: + if "integration" in [m.name for m in item.iter_markers()]: + item.add_marker(skip_reason) + + +class _GreptimeDBContainer: + """GreptimeDB testcontainer wrapper.""" + + def __init__(self) -> None: + from testcontainers.core.container import DockerContainer + + self._container: DockerContainer = DockerContainer(GREPTIMEDB_IMAGE) + self._container.with_exposed_ports( + GREPTIMEDB_HTTP_PORT, GREPTIMEDB_GRPC_PORT, 4002, 4003 + ) + self._container.with_command( + "standalone start --http-addr 0.0.0.0:4000 --rpc-bind-addr 0.0.0.0:4001 " + "--mysql-addr 0.0.0.0:4002 --postgres-addr 0.0.0.0:4003" + ) + + def start(self) -> None: + """Start the container.""" + self._container.start() + + def stop(self) -> None: + """Stop the container.""" + self._container.stop() + + def get_grpc_endpoint(self) -> str: + """Get the gRPC endpoint for GreptimeDB.""" + host = self._container.get_container_host_ip() + port = self._container.get_exposed_port(GREPTIMEDB_GRPC_PORT) + return f"http://{host}:{port}" + + def get_http_url(self) -> str: + """Get the HTTP URL for GreptimeDB.""" + host = self._container.get_container_host_ip() + port = self._container.get_exposed_port(GREPTIMEDB_HTTP_PORT) + return f"http://{host}:{port}" + + def wait_for_health(self, timeout: int = 30) -> None: + """Wait for GreptimeDB to be healthy.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f"{self.get_http_url()}/health", timeout=2) + if resp.status_code == 200: + return + except Exception: + pass + time.sleep(0.5) + raise RuntimeError("GreptimeDB health check timed out") + + +@pytest.fixture(scope="session") +def greptimedb_container() -> Generator[_GreptimeDBContainer, None, None]: + """Start a GreptimeDB container for the test session.""" + if not _docker_available(): + pytest.skip("Docker not available") + + from testcontainers.core.waiting_utils import wait_for_logs + + container = _GreptimeDBContainer() + container.start() + wait_for_logs(container._container, predicate=".*server started.*", timeout=30) + container.wait_for_health() + yield container + container.stop() + + +@pytest.fixture(scope="session") +def greptimedb_schema( + greptimedb_container: _GreptimeDBContainer, +) -> str: + """Initialize the GreptimeDB schema.""" + http_url = greptimedb_container.get_http_url() + sql_path = SCHEMA_PATH + + if not sql_path.exists(): + pytest.skip( + f"Schema file not found at {sql_path}. Is the service repo checked out?" + ) + + sql_content = sql_path.read_text() + + response = requests.post( + f"{http_url}/v1/sql", + data={"sql": sql_content}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + if response.status_code != 200: + raise RuntimeError(f"Failed to initialize schema: {response.text}") + + return greptimedb_container.get_grpc_endpoint() + + +@pytest.fixture(scope="session") +def service_binary() -> Path: + """Return path to the marketmeteringd binary.""" + if not SERVICE_BINARY.exists(): + pytest.skip( + f"Service binary not found at {SERVICE_BINARY}. " + "Build the service with 'cargo build --release'" + ) + return SERVICE_BINARY + + +@pytest.fixture(scope="session") +def service_config( + greptimedb_container: _GreptimeDBContainer, + service_binary: Path, +) -> Path: + """Create a config file for the marketmeteringd service.""" + greptimedb_endpoint = greptimedb_container.get_grpc_endpoint() + + config_content = f"""[net] +ip = "[::1]" +port = 50051 + +[auth] +enabled = false + +[storage] +backend = "greptime" +endpoint = "{greptimedb_endpoint}/marketmetering" + +[service] +upsert_stream_buf_size = 32 +""" + config_file = tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False) + config_file.write(config_content) + config_file.close() + return Path(config_file.name) + + +@pytest.fixture(scope="session") +def service_process( + service_binary: Path, + service_config: Path, +) -> Generator[subprocess.Popen[Any], None, None]: + """Start the marketmeteringd service.""" + proc = subprocess.Popen( + [str(service_binary), "--config", str(service_config)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + time.sleep(2) + + if proc.poll() is not None: + stdout, stderr = proc.communicate(timeout=5) + raise RuntimeError( + f"Service failed to start.\nstdout: {stdout.decode()}\nstderr: {stderr.decode()}" + ) + + yield proc + + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + + +@pytest.fixture +async def client( + service_process: subprocess.Popen[Any], +) -> AsyncIterator[MarketMeteringApiClient]: + """Create a client connected to the test service.""" + c = MarketMeteringApiClient( + server_url="grpc://[::1]:50051?ssl=false", + auth_key="", + ) + yield c diff --git a/tests/test_integration.py b/tests/test_integration.py index bff93fe..ff5e78a 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -3,34 +3,15 @@ """Integration tests for the MarketMeteringApiClient against a live service. -These tests require a running marketmetering service and are excluded from -CI by default. To run them: +These tests use testcontainers to spin up GreptimeDB and run the +marketmeteringd service binary. They require: +1. Docker to be running +2. The marketmeteringd binary to be built (run 'cargo build --release' in the service repo) +3. The service repo to be checked out at ../frequenz-service-marketmetering - 1. Start the service with auth disabled and test storage backend: - - ./target/release/marketmeteringd -c test-config.toml - - Where test-config.toml contains: - - [net] - ip = "[::1]" - port = 50051 - - [auth] - enabled = false - - [storage] - backend = "test" - - 2. Run the tests: - - uv run pytest -m integration +To run: uv run pytest -m integration """ -import os -import socket -from collections.abc import AsyncIterator - import grpc import pytest from grpc.aio import AioRpcError @@ -48,34 +29,6 @@ TimeResolution, ) -SERVICE_URL = "grpc://[::1]:50051?ssl=false" -AUTH_KEY = "test-key" - -pytestmark = pytest.mark.integration - - -def _service_available() -> bool: - """Check whether the local integration test service is reachable.""" - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: - sock.settimeout(0.5) - return sock.connect_ex(("::1", 50051, 0, 0)) == 0 - - -@pytest.fixture -async def client() -> AsyncIterator[MarketMeteringApiClient]: - """Create a connected client for testing.""" - if os.environ.get("CI") == "true": - pytest.skip("integration tests are not run in CI") - - if not _service_available(): - pytest.skip("integration test service is not running on [::1]:50051") - - c = MarketMeteringApiClient( - server_url=SERVICE_URL, - auth_key=AUTH_KEY, - ) - yield c - def make_ref( enterprise_id: int = 42, malo_id: str = "DE0000000001"