Skip to content

Commit 4973d72

Browse files
committed
Fix auth for sample upsert stream
The client got authentication and signing metadata from frequenz-client-base only for unary-unary and unary-stream RPCs. UpsertMarketLocationSamplesStream is a streaming upsert call, so it missed the key, timestamp, nonce, and signature metadata and the service rejected it with a missing-signature authentication error. Attach the same metadata explicitly for the upsert stream call and add a regression test covering the streaming path. Signed-off-by: Mathias L. Baumann <mathias.baumann@frequenz.com>
1 parent 0328ee3 commit 4973d72

5 files changed

Lines changed: 161 additions & 13 deletions

File tree

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ Initial release of the Frequenz Market Metering API client for Python.
5353
## Bug Fixes
5454

5555
- `update_market_location()`: Add missing `expected_revision` parameter required for optimistic concurrency control.
56+
- `upsert_samples()`: Attach auth and signing metadata to the streaming upsert RPC so authenticated sample upserts work against services that require signed requests.

src/frequenz/client/marketmetering/_client.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
from __future__ import annotations
77

8+
import hmac
9+
import secrets
10+
import time
11+
from base64 import urlsafe_b64encode
812
from datetime import datetime, timedelta
913
from typing import AsyncIterator, cast
1014

@@ -152,6 +156,33 @@ def stub(self) -> marketmetering_pb2_grpc.MarketMeteringServiceStub:
152156
raise ClientNotConnected(server_url=self.server_url, operation="stub")
153157
return self._stub
154158

159+
def _metadata(self, method: str) -> tuple[tuple[str, str | bytes], ...] | None:
160+
"""Build request metadata for RPCs not covered by client-base interceptors."""
161+
if self._auth_key is None:
162+
return None
163+
164+
metadata: list[tuple[str, str | bytes]] = [("key", self._auth_key)]
165+
if self._sign_secret is None:
166+
return tuple(metadata)
167+
168+
ts = str(int(time.time())).encode()
169+
nonce = urlsafe_b64encode(secrets.token_bytes(16))
170+
171+
digest = hmac.new(self._sign_secret.encode(), digestmod="sha256")
172+
digest.update(self._auth_key.encode())
173+
digest.update(ts)
174+
digest.update(nonce)
175+
digest.update(method.encode())
176+
177+
metadata.extend(
178+
[
179+
("ts", ts),
180+
("nonce", nonce),
181+
("sig", urlsafe_b64encode(digest.digest()).rstrip(b"=")),
182+
]
183+
)
184+
return tuple(metadata)
185+
155186
async def create_market_location(
156187
self,
157188
*,
@@ -336,6 +367,7 @@ async def request_generator() -> (
336367
AsyncIterator[pb.UpsertMarketLocationSamplesStreamResponse],
337368
self.stub.UpsertMarketLocationSamplesStream(
338369
request_generator(), # type: ignore[arg-type]
370+
metadata=self._metadata("UpsertMarketLocationSamplesStream"),
339371
timeout=self._stream_timeout_seconds,
340372
),
341373
)

tests/test_client.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -551,9 +551,9 @@ async def test_yields_parsed_series(self) -> None:
551551
response = pb.ReceiveMarketLocationSamplesStreamResponse(series=[series_pb])
552552

553553
# Mock the streaming call to return an async iterator.
554-
async def mock_stream() -> AsyncIterator[
555-
pb.ReceiveMarketLocationSamplesStreamResponse
556-
]:
554+
async def mock_stream() -> (
555+
AsyncIterator[pb.ReceiveMarketLocationSamplesStreamResponse]
556+
):
557557
yield response
558558

559559
client.stub.ReceiveMarketLocationSamplesStream = MagicMock(
@@ -579,9 +579,9 @@ async def test_sends_time_filter(self) -> None:
579579
"""Test that start_time and end_time are sent."""
580580
client = _make_client()
581581

582-
async def mock_stream() -> AsyncIterator[
583-
pb.ReceiveMarketLocationSamplesStreamResponse
584-
]:
582+
async def mock_stream() -> (
583+
AsyncIterator[pb.ReceiveMarketLocationSamplesStreamResponse]
584+
):
585585
return
586586
yield # make it an async generator # noqa: RET504
587587

@@ -638,9 +638,9 @@ async def test_yields_parsed_results(self) -> None:
638638
sample=sample_pb,
639639
)
640640

641-
async def mock_stream() -> AsyncIterator[
642-
pb.UpsertMarketLocationSamplesStreamResponse
643-
]:
641+
async def mock_stream() -> (
642+
AsyncIterator[pb.UpsertMarketLocationSamplesStreamResponse]
643+
):
644644
yield upsert_response
645645

646646
client.stub.UpsertMarketLocationSamplesStream = MagicMock(
@@ -666,9 +666,9 @@ async def mock_stream() -> AsyncIterator[
666666
samples=[sample],
667667
)
668668

669-
async def input_stream() -> AsyncIterator[
670-
tuple[MarketLocationRef, MarketLocationSeries]
671-
]:
669+
async def input_stream() -> (
670+
AsyncIterator[tuple[MarketLocationRef, MarketLocationSeries]]
671+
):
672672
yield (ml_ref, series)
673673

674674
results = []

tests/test_integration.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
uv run pytest -m integration
2828
"""
2929

30+
import os
31+
import socket
3032
from collections.abc import AsyncIterator
3133

3234
import grpc
@@ -52,9 +54,22 @@
5254
pytestmark = pytest.mark.integration
5355

5456

57+
def _service_available() -> bool:
58+
"""Check whether the local integration test service is reachable."""
59+
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
60+
sock.settimeout(0.5)
61+
return sock.connect_ex(("::1", 50051, 0, 0)) == 0
62+
63+
5564
@pytest.fixture
5665
async def client() -> AsyncIterator[MarketMeteringApiClient]:
5766
"""Create a connected client for testing."""
67+
if os.environ.get("CI") == "true":
68+
pytest.skip("integration tests are not run in CI")
69+
70+
if not _service_available():
71+
pytest.skip("integration test service is not running on [::1]:50051")
72+
5873
c = MarketMeteringApiClient(
5974
server_url=SERVICE_URL,
6075
auth_key=AUTH_KEY,

tests/test_marketmetering.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33

44
"""Tests for the Market Metering client."""
55

6-
from datetime import timedelta
6+
from collections.abc import AsyncIterator
7+
from datetime import datetime, timedelta, timezone
8+
from typing import Any, cast
79

10+
import pytest
11+
12+
from frequenz.client.marketmetering import MarketMeteringApiClient
813
from frequenz.client.marketmetering.types import (
914
DataQuality,
1015
DownsamplingMethod,
@@ -13,6 +18,8 @@
1318
MarketLocationId,
1419
MarketLocationIdType,
1520
MarketLocationRef,
21+
MarketLocationSample,
22+
MarketLocationSeries,
1623
MetricType,
1724
MetricUnit,
1825
ResamplingMethod,
@@ -117,3 +124,96 @@ def test_resampling_options_with_resolution(self) -> None:
117124
"""Test ResamplingOptions with resolution."""
118125
options = ResamplingOptions(resolution=TimeResolution.MIN_15)
119126
assert options.resolution == TimeResolution.MIN_15
127+
128+
129+
class _UpsertStub:
130+
"""A fake stub capturing upsert stream call arguments."""
131+
132+
def __init__(self) -> None:
133+
self.metadata: tuple[tuple[str, str | bytes], ...] | None = None
134+
self.requests: list[object] = []
135+
136+
# gRPC-generated method names use UpperCamelCase and this stub mirrors that.
137+
# pylint: disable=invalid-name,missing-function-docstring
138+
async def UpsertMarketLocationSamplesStream( # noqa: N802
139+
self,
140+
request_iterator: AsyncIterator[object],
141+
*,
142+
metadata: tuple[tuple[str, str | bytes], ...] | None = None,
143+
timeout: float | None = None,
144+
) -> AsyncIterator[object]:
145+
del timeout
146+
self.metadata = metadata
147+
async for request in request_iterator:
148+
self.requests.append(request)
149+
items: tuple[object, ...] = ()
150+
for item in items:
151+
yield item
152+
153+
154+
class TestClientMethods:
155+
"""Tests for client RPC helpers."""
156+
157+
@pytest.mark.asyncio
158+
async def test_upsert_samples_adds_stream_metadata(self) -> None:
159+
"""Test that bidi upsert includes auth and signing metadata."""
160+
client = MarketMeteringApiClient(
161+
server_url="grpc://example.com",
162+
auth_key="test-key",
163+
sign_secret="test-secret",
164+
connect=False,
165+
)
166+
stub = _UpsertStub()
167+
setattr(client, "_stub", cast(Any, stub))
168+
setattr(client, "_channel", cast(Any, object()))
169+
170+
market_location_ref = MarketLocationRef(
171+
enterprise_id=42,
172+
market_area=MarketArea.EU_DE,
173+
market_location_id=MarketLocationId(
174+
value="DE01234567890",
175+
type=MarketLocationIdType.MALO_ID,
176+
),
177+
)
178+
series = MarketLocationSeries(
179+
market_location_ref=market_location_ref,
180+
direction=EnergyFlowDirection.IMPORT,
181+
metric_type=MetricType.ACTIVE_ENERGY,
182+
metric_unit=MetricUnit.KWH,
183+
resolution=TimeResolution.MIN_15,
184+
samples=[],
185+
)
186+
sample = MarketLocationSample(
187+
sample_time=datetime.now(timezone.utc),
188+
value=1.0,
189+
quality=DataQuality.MEASURED,
190+
revision=1,
191+
update_time=None,
192+
resampling_method=ResamplingMethod.UNSPECIFIED,
193+
)
194+
195+
async def sample_generator() -> (
196+
AsyncIterator[tuple[MarketLocationRef, MarketLocationSeries]]
197+
):
198+
yield (
199+
market_location_ref,
200+
MarketLocationSeries(
201+
market_location_ref=series.market_location_ref,
202+
direction=series.direction,
203+
metric_type=series.metric_type,
204+
metric_unit=series.metric_unit,
205+
resolution=series.resolution,
206+
samples=[sample],
207+
),
208+
)
209+
210+
assert [
211+
result async for result in client.upsert_samples(sample_generator())
212+
] == []
213+
assert stub.metadata is not None
214+
metadata = dict(stub.metadata)
215+
assert metadata["key"] == "test-key"
216+
assert metadata["ts"]
217+
assert metadata["nonce"]
218+
assert metadata["sig"]
219+
assert len(stub.requests) == 1

0 commit comments

Comments
 (0)