diff --git a/faststream_outbox/broker.py b/faststream_outbox/broker.py index f75af73..9bff053 100644 --- a/faststream_outbox/broker.py +++ b/faststream_outbox/broker.py @@ -36,6 +36,7 @@ if typing.TYPE_CHECKING: from fast_depends.dependencies import Dependant + from fast_depends.library.serializer import SerializerProto from faststream._internal.context.repository import ContextRepo from sqlalchemy import Table from sqlalchemy.ext.asyncio import AsyncEngine @@ -114,6 +115,7 @@ def __init__( # noqa: PLR0913 log_level: int = logging.INFO, # FastDepends apply_types: bool = True, + serializer: "SerializerProto | None" = EMPTY, # AsyncAPI description: str | None = None, tags: Iterable[Tag | TagDict] = (), @@ -121,7 +123,7 @@ def __init__( # noqa: PLR0913 self._outbox_table = outbox_table engine_state = EngineState(engine) client = OutboxClient(engine, outbox_table) if engine is not None else None - fd_config = FastDependsConfig(use_fastdepends=apply_types) + fd_config = FastDependsConfig(use_fastdepends=apply_types, serializer=serializer) broker_config = OutboxBrokerConfig( engine_state=engine_state, client=client, @@ -222,7 +224,13 @@ async def publish( # ty: ignore[invalid-method-override] # noqa: PLR0913 if activate_in is not None and activate_at is not None: msg = "broker.publish accepts at most one of activate_in / activate_at" raise ValueError(msg) - payload, hdrs = _encode_payload(body, headers=headers, correlation_id=correlation_id) + serializer = self.config.broker_config.fd_config._serializer # noqa: SLF001 + payload, hdrs = _encode_payload( + body, + headers=headers, + correlation_id=correlation_id, + serializer=serializer, + ) t = self._outbox_table values: dict[str, typing.Any] = {"queue": queue, "payload": payload, "headers": hdrs} # Server-side compute keeps timing immune to worker/DB clock skew (mirrors @@ -291,9 +299,10 @@ async def publish_batch( # ty: ignore[invalid-method-override] next_at = _dt.datetime.now(tz=_dt.UTC) + activate_in elif activate_at is not None: next_at = activate_at + serializer = self.config.broker_config.fd_config._serializer # noqa: SLF001 rows = [] for body in bodies: - payload, hdrs = _encode_payload(body, headers=headers) + payload, hdrs = _encode_payload(body, headers=headers, serializer=serializer) row: dict[str, typing.Any] = {"queue": queue, "payload": payload, "headers": hdrs} if next_at is not None: row["next_attempt_at"] = next_at diff --git a/faststream_outbox/envelope.py b/faststream_outbox/envelope.py index 28a5e30..bc216b5 100644 --- a/faststream_outbox/envelope.py +++ b/faststream_outbox/envelope.py @@ -1,25 +1,33 @@ """Internal payload encoding for ``broker.publish``.""" -from typing import Any +from typing import TYPE_CHECKING, Any from faststream.message import gen_cor_id from faststream.message.utils import encode_message +if TYPE_CHECKING: + from fast_depends.library.serializer import SerializerProto + + def _encode_payload( body: Any, *, headers: dict[str, str] | None = None, correlation_id: str | None = None, + serializer: "SerializerProto | None" = None, ) -> tuple[bytes, dict[str, str]]: """ Serialize *body* into ``(payload_bytes, headers_dict)`` for an outbox row. *body* may be ``bytes``, a pydantic model, a dataclass, a ``dict``, or any value FastStream's ``encode_message`` accepts. *correlation_id* is auto-generated if - not supplied so handlers can always rely on it being present. + not supplied so handlers can always rely on it being present. *serializer* is + forwarded to FastStream's ``encode_message`` — pass the broker's resolved + ``FastDependsConfig._serializer`` so pydantic models / dataclasses encode the + same way they do for every other FastStream broker. """ - payload, content_type = encode_message(body, serializer=None) + payload, content_type = encode_message(body, serializer=serializer) out_headers: dict[str, str] = dict(headers or {}) if content_type and "content-type" not in out_headers: out_headers["content-type"] = content_type diff --git a/tests/test_unit.py b/tests/test_unit.py index 2f7a1af..414cf1c 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -1,8 +1,10 @@ import datetime as _dt +import json import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pydantic import BaseModel from sqlalchemy import MetaData from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncSession @@ -116,6 +118,21 @@ def test_encode_payload_merges_user_headers() -> None: assert headers["content-type"] == "application/json" +class _PydanticBody(BaseModel): + order_id: int + name: str + + +def test_encode_payload_serializes_pydantic_model_with_default_serializer() -> None: + """Default broker resolves PydanticSerializer so BaseModel encodes as JSON.""" + broker = _make_broker() + serializer = broker.config.broker_config.fd_config._serializer # noqa: SLF001 + body = _PydanticBody(order_id=1, name="x") + payload, headers = _encode_payload(body, serializer=serializer) + assert json.loads(payload) == body.model_dump() + assert headers["content-type"] == "application/json" + + # --- retry strategies --- @@ -312,7 +329,7 @@ async def test_broker_publish_executes_insert_then_pg_notify_on_session() -> Non assert "INSERT INTO" in str(insert_stmt) params = insert_stmt.compile().params assert params["queue"] == "orders" - assert params["payload"] == b'{"order_id": 1}' + assert json.loads(params["payload"]) == {"order_id": 1} assert params["headers"]["content-type"] == "application/json" notify_stmt, notify_params = session.execute.await_args_list[1].args assert "pg_notify" in str(notify_stmt) @@ -320,6 +337,30 @@ async def test_broker_publish_executes_insert_then_pg_notify_on_session() -> Non assert notify_params["payload"] == "orders" +async def test_broker_publish_encodes_pydantic_model() -> None: + broker = _make_broker() + session = _make_session_mock() + body = _PydanticBody(order_id=7, name="alpha") + await broker.publish(body, queue="orders", session=session) + insert_stmt = session.execute.await_args_list[0].args[0] + params = insert_stmt.compile().params + assert json.loads(params["payload"]) == body.model_dump() + assert params["headers"]["content-type"] == "application/json" + + +async def test_broker_publish_batch_encodes_pydantic_models() -> None: + broker = _make_broker() + session = _make_session_mock() + bodies = [_PydanticBody(order_id=1, name="a"), _PydanticBody(order_id=2, name="b")] + await broker.publish_batch(*bodies, queue="orders", session=session) + # First execute is the INSERT (executemany), second is pg_notify. + insert_call = session.execute.await_args_list[0] + rows = insert_call.args[1] + assert [json.loads(row["payload"]) for row in rows] == [b.model_dump() for b in bodies] + for row in rows: + assert row["headers"]["content-type"] == "application/json" + + async def test_broker_publish_does_not_commit() -> None: broker = _make_broker() session = _make_session_mock() @@ -884,6 +925,17 @@ async def handle(body: str) -> None: ... assert sub._make_response_publisher(MagicMock()) == () # noqa: SLF001 +async def test_subscriber_client_property_raises_when_broker_has_no_engine() -> None: + broker = _make_broker() # no engine → broker_config.client is None + + @broker.subscriber("orders") + async def handle(body: dict) -> None: ... + + sub = next(iter(broker._subscribers)) # noqa: SLF001 + with pytest.raises(RuntimeError, match="not connected"): + _ = sub._client # noqa: SLF001 + + # --- _open_listen_connection fallback paths ---