Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ repos:
hooks:
- id: reorder-python-imports
- repo: https://github.com/psf/black
rev: 25.11.0
rev: 26.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.0
rev: v1.19.1
hooks:
- id: mypy
additional_dependencies: ["pydantic==2.11.7", "types-docutils", "sphinx>=7.4.7"]
Expand All @@ -18,11 +18,11 @@ repos:
- id: autoflake
args: [--remove-all-unused-imports, --in-place]
- repo: https://github.com/google/yamlfmt
rev: v0.20.0
rev: v0.21.0
hooks:
- id: yamlfmt
- repo: https://github.com/commitizen-tools/commitizen
rev: v4.10.0
rev: v4.12.1
hooks:
- id: commitizen
stages: [commit-msg]
Expand Down Expand Up @@ -56,6 +56,6 @@ repos:
args: [--py310-plus]
exclude: '^.*test.*\.py$'
- repo: https://github.com/tox-dev/tox-ini-fmt
rev: "1.7.0"
rev: "1.7.1"
hooks:
- id: tox-ini-fmt
1 change: 0 additions & 1 deletion packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from amgi_types import MessageScope
from amgi_types import MessageSendEvent


logger = logging.getLogger("amgi-aiokafka.error")


Expand Down
13 changes: 13 additions & 0 deletions packages/asyncfast/docs/examples/header_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Annotated

from asyncfast import AsyncFast
from asyncfast import Header

app = AsyncFast()


@app.channel("topic")
async def topic_handler(
etag: Annotated[str, Header(alias="ETag")],
) -> None:
print(etag)
23 changes: 23 additions & 0 deletions packages/asyncfast/docs/message_headers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,33 @@
Header Parameters
###################

Use ``Header()`` to declare that a parameter should be read from message headers.

AsyncFast header handling follows the `AsyncAPI specification
<https://www.asyncapi.com/docs/reference/specification/v3.0.0#specification>`_ for message headers and schema and is
therefore case sensitive

By default, the header name is derived from the argument name and underscores become hyphens, for example, the argument
``request_id`` would become ``request-id``

You can also set an explicit header key using ``alias=`` (useful for exact casing like ``Idempotency-Key`` or ``ETag``).
These aliases are reflected both at runtime and in the generated AsyncAPI schema.

.. async-fast-example:: examples/header_builtin.py

*****************
Header Aliasing
*****************

Use ``alias=`` when you need a specific header name (including casing):

.. async-fast-example:: examples/header_alias.py

*********
Default
*********

Header parameters can have defaults, just like normal arguments. If the header is not present, the default value is
used.

.. async-fast-example:: examples/header_default.py
92 changes: 65 additions & 27 deletions packages/asyncfast/src/asyncfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class Message(Mapping[str, Any]):

__address__: ClassVar[str | None] = None
__headers__: ClassVar[dict[str, _Field]]
__headers_model__: ClassVar[type[BaseModel] | None]
__parameters__: ClassVar[dict[str, TypeAdapter[Any]]]
__payload__: ClassVar[tuple[str, _Field] | None]
__bindings__: ClassVar[dict[str, _Field]]
Expand Down Expand Up @@ -160,18 +161,23 @@ def _get_address(self) -> str | None:

return self.__address__.format(**parameters)

def _generate_headers(self) -> Iterable[tuple[str, bytes]]:
for name, field in self.__headers__.items():
_, annotation = get_args(field.type)
alias = annotation.alias if annotation.alias else name.replace("_", "-")
yield alias, self._get_value(name, field.type_adapter)

def _get_headers(self) -> Iterable[tuple[bytes, bytes]]:
return [
(name.encode(), self._get_value(name, field.type_adapter))
for name, field in self.__headers__.items()
]
return [(name.encode(), value) for name, value in self._generate_headers()]

def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes:
json_value = type_adapter.dump_json(getattr(self, name))
value = json.loads(json_value)
if isinstance(value, str):
return value.encode()
return json_value
value = getattr(self, name)
python = type_adapter.dump_python(value, mode="json")
if isinstance(python, str):
return python.encode()
if isinstance(python, bytes):
return python
return type_adapter.dump_json(value)

def _get_payload(self) -> bytes | None:
if self.__payload__ is None:
Expand All @@ -190,6 +196,17 @@ def _get_bindings(self) -> dict[str, dict[str, Any]]:
] = self._get_value(name, field.type_adapter)
return bindings

@classmethod
def _headers_model(cls) -> type[BaseModel] | None:
if not hasattr(cls, "__headers_model__"):
if cls.__headers__:
cls.__headers_model__ = _create_headers_model(
f"{cls.__name__}Headers", cls.__headers__
)
else:
cls.__headers_model__ = None
return cls.__headers_model__


def _generate_message_annotations(
address: str | None,
Expand Down Expand Up @@ -224,7 +241,7 @@ def __init__(
version: str = "0.1.0",
lifespan: Lifespan | None = None,
) -> None:
self._channels: list[Channel] = []
self._channels: list[_Channel] = []
self._title = title
self._version = version
self._lifespan_context = lifespan
Expand Down Expand Up @@ -319,7 +336,7 @@ def _add_channel(

address_pattern = _address_pattern(address)

channel = Channel(
channel = _Channel(
address,
address_pattern,
function,
Expand Down Expand Up @@ -414,6 +431,12 @@ def _generate_inputs(
field.type
), "serialization", field.type_adapter.core_schema

message_headers_model = message._headers_model()
if message_headers_model:
yield hash(message_headers_model), "serialization", TypeAdapter(
message_headers_model
).core_schema


def _generate_annotations(
address: str,
Expand Down Expand Up @@ -500,7 +523,24 @@ async def _receive_messages(
more_messages = message.get("more_messages", False)


class Channel:
def _generate_field_definitions(
headers: Mapping[str, _Field],
) -> Iterator[tuple[str, Any]]:
for name, field in headers.items():
type_, annotation = get_args(field.type)
alias = annotation.alias if annotation.alias else name.replace("_", "-")
yield alias, (type_, annotation)


def _create_headers_model(
headers_name: str, headers: Mapping[str, _Field]
) -> type[BaseModel]:
return create_model(
headers_name, __base__=BaseModel, **dict(_generate_field_definitions(headers))
)


class _Channel:

def __init__(
self,
Expand Down Expand Up @@ -543,15 +583,7 @@ def headers(self) -> Mapping[str, _Field]:
@cached_property
def headers_model(self) -> type[BaseModel] | None:
if self._headers:
headers_name = f"{self.title}Headers"
field_definitions: dict[str, Any] = {
name.replace("_", "-"): get_args(field.type)
for name, field in self._headers.items()
}
headers_model = create_model(
headers_name, __base__=BaseModel, **field_definitions
)
return headers_model
return _create_headers_model(f"{self.title}Headers", self._headers)
return None

@property
Expand Down Expand Up @@ -639,7 +671,7 @@ def _generate_headers(
self, message_receive_event: MessageReceiveEvent
) -> Generator[tuple[str, Any], None, None]:
if self.headers:
headers = Headers(message_receive_event["headers"])
headers = _Headers(message_receive_event["headers"])
for name, field in self.headers.items():
annotated_args = get_args(field.type)
header_alias = annotated_args[1].alias
Expand Down Expand Up @@ -688,7 +720,7 @@ def _generate_parameters(


def _generate_messages(
channels: Iterable[Channel],
channels: Iterable[_Channel],
field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue],
) -> Generator[tuple[str, dict[str, Any]], None, None]:
for channel in channels:
Expand Down Expand Up @@ -726,6 +758,12 @@ def _generate_messages(
_, field = channel_message.__payload__
message_message["payload"] = field_mapping[hash(field), "serialization"]

message_headers_model = channel_message._headers_model()
if message_headers_model:
message_message["headers"] = field_mapping[
hash(message_headers_model), "serialization"
]

if channel_message.__bindings__:
bindings = {}
for field in channel_message.__bindings__.values():
Expand All @@ -741,7 +779,7 @@ def _generate_messages(


def _generate_channels(
channels: Iterable[Channel],
channels: Iterable[_Channel],
) -> Generator[tuple[str, dict[str, Any]], None, None]:
for channel in channels:
message_name = f"{channel.title}Message"
Expand Down Expand Up @@ -776,7 +814,7 @@ def _generate_channels(


def _generate_operations(
channels: Iterable[Channel],
channels: Iterable[_Channel],
) -> Generator[tuple[str, dict[str, Any]], None, None]:
for channel in channels:
yield f"receive{channel.title}", {
Expand Down Expand Up @@ -815,14 +853,14 @@ def _get_address_parameters(address: str | None) -> set[str]:
return set(parameters)


class Headers(Mapping[str, str]):
class _Headers(Mapping[str, str]):

def __init__(self, raw_list: Iterable[tuple[bytes, bytes]]) -> None:
self.raw_list = list(raw_list)

def __getitem__(self, key: str, /) -> str:
for header_key, header_value in self.raw_list:
if header_key.decode().lower() == key.lower():
if header_key.decode() == key:
return header_value.decode()
raise KeyError(key)

Expand Down
Loading