Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions faststream_outbox/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if typing.TYPE_CHECKING:
from fast_depends.dependencies import Dependant
from fast_depends.library.serializer import SerializerProto
from faststream._internal.context.repository import ContextRepo
from sqlalchemy import Table
from sqlalchemy.ext.asyncio import AsyncEngine
Expand Down Expand Up @@ -114,14 +115,15 @@ def __init__( # noqa: PLR0913
log_level: int = logging.INFO,
# FastDepends
apply_types: bool = True,
serializer: "SerializerProto | None" = EMPTY,
# AsyncAPI
description: str | None = None,
tags: Iterable[Tag | TagDict] = (),
) -> None:
self._outbox_table = outbox_table
engine_state = EngineState(engine)
client = OutboxClient(engine, outbox_table) if engine is not None else None
fd_config = FastDependsConfig(use_fastdepends=apply_types)
fd_config = FastDependsConfig(use_fastdepends=apply_types, serializer=serializer)
broker_config = OutboxBrokerConfig(
engine_state=engine_state,
client=client,
Expand Down Expand Up @@ -222,7 +224,13 @@ async def publish( # ty: ignore[invalid-method-override] # noqa: PLR0913
if activate_in is not None and activate_at is not None:
msg = "broker.publish accepts at most one of activate_in / activate_at"
raise ValueError(msg)
payload, hdrs = _encode_payload(body, headers=headers, correlation_id=correlation_id)
serializer = self.config.broker_config.fd_config._serializer # noqa: SLF001
payload, hdrs = _encode_payload(
body,
headers=headers,
correlation_id=correlation_id,
serializer=serializer,
)
t = self._outbox_table
values: dict[str, typing.Any] = {"queue": queue, "payload": payload, "headers": hdrs}
# Server-side compute keeps timing immune to worker/DB clock skew (mirrors
Expand Down Expand Up @@ -291,9 +299,10 @@ async def publish_batch( # ty: ignore[invalid-method-override]
next_at = _dt.datetime.now(tz=_dt.UTC) + activate_in
elif activate_at is not None:
next_at = activate_at
serializer = self.config.broker_config.fd_config._serializer # noqa: SLF001
rows = []
for body in bodies:
payload, hdrs = _encode_payload(body, headers=headers)
payload, hdrs = _encode_payload(body, headers=headers, serializer=serializer)
row: dict[str, typing.Any] = {"queue": queue, "payload": payload, "headers": hdrs}
if next_at is not None:
row["next_attempt_at"] = next_at
Expand Down
14 changes: 11 additions & 3 deletions faststream_outbox/envelope.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
"""Internal payload encoding for ``broker.publish``."""

from typing import Any
from typing import TYPE_CHECKING, Any

from faststream.message import gen_cor_id
from faststream.message.utils import encode_message


if TYPE_CHECKING:
from fast_depends.library.serializer import SerializerProto


def _encode_payload(
body: Any,
*,
headers: dict[str, str] | None = None,
correlation_id: str | None = None,
serializer: "SerializerProto | None" = None,
) -> tuple[bytes, dict[str, str]]:
"""
Serialize *body* into ``(payload_bytes, headers_dict)`` for an outbox row.

*body* may be ``bytes``, a pydantic model, a dataclass, a ``dict``, or any value
FastStream's ``encode_message`` accepts. *correlation_id* is auto-generated if
not supplied so handlers can always rely on it being present.
not supplied so handlers can always rely on it being present. *serializer* is
forwarded to FastStream's ``encode_message`` — pass the broker's resolved
``FastDependsConfig._serializer`` so pydantic models / dataclasses encode the
same way they do for every other FastStream broker.
"""
payload, content_type = encode_message(body, serializer=None)
payload, content_type = encode_message(body, serializer=serializer)
out_headers: dict[str, str] = dict(headers or {})
if content_type and "content-type" not in out_headers:
out_headers["content-type"] = content_type
Expand Down
54 changes: 53 additions & 1 deletion tests/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime as _dt
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel
from sqlalchemy import MetaData
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -116,6 +118,21 @@ def test_encode_payload_merges_user_headers() -> None:
assert headers["content-type"] == "application/json"


class _PydanticBody(BaseModel):
order_id: int
name: str


def test_encode_payload_serializes_pydantic_model_with_default_serializer() -> None:
"""Default broker resolves PydanticSerializer so BaseModel encodes as JSON."""
broker = _make_broker()
serializer = broker.config.broker_config.fd_config._serializer # noqa: SLF001
body = _PydanticBody(order_id=1, name="x")
payload, headers = _encode_payload(body, serializer=serializer)
assert json.loads(payload) == body.model_dump()
assert headers["content-type"] == "application/json"


# --- retry strategies ---


Expand Down Expand Up @@ -312,14 +329,38 @@ async def test_broker_publish_executes_insert_then_pg_notify_on_session() -> Non
assert "INSERT INTO" in str(insert_stmt)
params = insert_stmt.compile().params
assert params["queue"] == "orders"
assert params["payload"] == b'{"order_id": 1}'
assert json.loads(params["payload"]) == {"order_id": 1}
assert params["headers"]["content-type"] == "application/json"
notify_stmt, notify_params = session.execute.await_args_list[1].args
assert "pg_notify" in str(notify_stmt)
assert notify_params["channel"] == "outbox_outbox"
assert notify_params["payload"] == "orders"


async def test_broker_publish_encodes_pydantic_model() -> None:
broker = _make_broker()
session = _make_session_mock()
body = _PydanticBody(order_id=7, name="alpha")
await broker.publish(body, queue="orders", session=session)
insert_stmt = session.execute.await_args_list[0].args[0]
params = insert_stmt.compile().params
assert json.loads(params["payload"]) == body.model_dump()
assert params["headers"]["content-type"] == "application/json"


async def test_broker_publish_batch_encodes_pydantic_models() -> None:
broker = _make_broker()
session = _make_session_mock()
bodies = [_PydanticBody(order_id=1, name="a"), _PydanticBody(order_id=2, name="b")]
await broker.publish_batch(*bodies, queue="orders", session=session)
# First execute is the INSERT (executemany), second is pg_notify.
insert_call = session.execute.await_args_list[0]
rows = insert_call.args[1]
assert [json.loads(row["payload"]) for row in rows] == [b.model_dump() for b in bodies]
for row in rows:
assert row["headers"]["content-type"] == "application/json"


async def test_broker_publish_does_not_commit() -> None:
broker = _make_broker()
session = _make_session_mock()
Expand Down Expand Up @@ -884,6 +925,17 @@ async def handle(body: str) -> None: ...
assert sub._make_response_publisher(MagicMock()) == () # noqa: SLF001


async def test_subscriber_client_property_raises_when_broker_has_no_engine() -> None:
broker = _make_broker() # no engine → broker_config.client is None

@broker.subscriber("orders")
async def handle(body: dict) -> None: ...

sub = next(iter(broker._subscribers)) # noqa: SLF001
with pytest.raises(RuntimeError, match="not connected"):
_ = sub._client # noqa: SLF001


# --- _open_listen_connection fallback paths ---


Expand Down
Loading