diff --git a/packages/asyncfast/pyproject.toml b/packages/asyncfast/pyproject.toml index ee5501d..89e7a3c 100644 --- a/packages/asyncfast/pyproject.toml +++ b/packages/asyncfast/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ ] dependencies = [ "amgi-types==0.27.0", - "pydantic==2.11.7", + "pydantic>=2.0.0", ] optional-dependencies.standard = [ diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index 91be293..6c22780 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -73,20 +73,29 @@ async def send(self, message: M) -> None: await _send_message(self._send, message) +class _Field: + def __init__(self, type_: type): + self.type = type_ + self.type_adapter = TypeAdapter[Any](type_) + + def __hash__(self) -> int: + return hash(self.type) + + class Message(Mapping[str, Any]): __address__: ClassVar[str | None] = None - __headers__: ClassVar[dict[str, TypeAdapter[Any]]] + __headers__: ClassVar[dict[str, _Field]] __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] - __payload__: ClassVar[tuple[str, TypeAdapter[Any]] | None] - __bindings__: ClassVar[dict[str, TypeAdapter[Any]]] + __payload__: ClassVar[tuple[str, _Field] | None] + __bindings__: ClassVar[dict[str, _Field]] def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: cls.__address__ = address annotations = list(_generate_message_annotations(address, cls.__annotations__)) headers = { - name: TypeAdapter(annotated) + name: _Field(annotated) for name, annotated in annotations if isinstance(get_args(annotated)[1], Header) } @@ -98,13 +107,13 @@ def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: } bindings = { - name: TypeAdapter(annotated) + name: _Field(annotated) for name, annotated in annotations if isinstance(get_args(annotated)[1], Binding) } payloads = [ - (name, TypeAdapter(annotated)) + (name, _Field(annotated)) for name, annotated in annotations if isinstance(get_args(annotated)[1], Payload) ] @@ -153,8 +162,8 @@ def _get_address(self) -> str | None: def _get_headers(self) -> Iterable[tuple[bytes, bytes]]: return [ - (name.encode(), self._get_value(name, type_adapter)) - for name, type_adapter in self.__headers__.items() + (name.encode(), self._get_value(name, field.type_adapter)) + for name, field in self.__headers__.items() ] def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: @@ -167,18 +176,18 @@ def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: def _get_payload(self) -> bytes | None: if self.__payload__ is None: return None - name, type_adapter = self.__payload__ - return type_adapter.dump_json(getattr(self, name)) + name, field = self.__payload__ + return field.type_adapter.dump_json(getattr(self, name)) def _get_bindings(self) -> dict[str, dict[str, Any]]: bindings: dict[str, dict[str, Any]] = {} - for name, type_adapter in self.__bindings__.items(): - binding_type = get_args(type_adapter._type)[1] + for name, field in self.__bindings__.items(): + binding_type = get_args(field.type)[1] assert isinstance(binding_type, Binding) bindings.setdefault(binding_type.__protocol__, {})[ binding_type.__field_name__ - ] = self._get_value(name, type_adapter) + ] = self._get_value(name, field.type_adapter) return bindings @@ -254,28 +263,28 @@ def _add_channel( annotations = list(_generate_annotations(address, signature)) headers = { - name: TypeAdapter(annotated) + name: _Field(annotated) for name, annotated in annotations if get_origin(annotated) is Annotated and isinstance(get_args(annotated)[1], Header) } parameters = { - name: TypeAdapter(annotated) + name: _Field(annotated) for name, annotated in annotations if get_origin(annotated) is Annotated and isinstance(get_args(annotated)[1], Parameter) } payloads = [ - (name, TypeAdapter(annotated)) + (name, _Field(annotated)) for name, annotated in annotations if get_origin(annotated) is Annotated and isinstance(get_args(annotated)[1], Payload) ] bindings = { - name: TypeAdapter(annotated) + name: _Field(annotated) for name, annotated in annotations if get_origin(annotated) is Annotated and isinstance(get_args(annotated)[1], Binding) @@ -381,35 +390,29 @@ def _generate_inputs( self, ) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: for channel in self._channels: - for type_adapter in channel._bindings.values(): - yield hash( - type_adapter._type - ), "serialization", type_adapter.core_schema + for field in channel._bindings.values(): + yield hash(field), "validation", field.type_adapter.core_schema headers_model = channel.headers_model if headers_model: - yield hash(headers_model), "serialization", TypeAdapter( + yield hash(headers_model), "validation", TypeAdapter( headers_model ).core_schema payload = channel.payload if payload: - _, type_adapter = payload - yield hash( - type_adapter._type - ), "serialization", type_adapter.core_schema + _, field = payload + yield hash(field), "validation", field.type_adapter.core_schema for message in channel.messages: if message.__payload__: - _, type_adapter = message.__payload__ + _, field = message.__payload__ - yield hash( - type_adapter._type - ), "serialization", type_adapter.core_schema + yield hash(field), "serialization", field.type_adapter.core_schema - for type_adapter in message.__bindings__.values(): + for field in message.__bindings__.values(): yield hash( - type_adapter._type - ), "serialization", type_adapter.core_schema + field.type + ), "serialization", field.type_adapter.core_schema def _generate_annotations( @@ -504,11 +507,11 @@ def __init__( address: str, address_pattern: Pattern[str], handler: Callable[..., Awaitable[None]], - headers: Mapping[str, TypeAdapter[Any]], - parameters: Mapping[str, TypeAdapter[Any]], - payload: tuple[str, TypeAdapter[Any]] | None, + headers: Mapping[str, _Field], + parameters: Mapping[str, _Field], + payload: tuple[str, _Field] | None, messages: Sequence[type[Message]], - bindings: Mapping[str, TypeAdapter[Any]], + bindings: Mapping[str, _Field], message_sender: str | None, ) -> None: self._address = address @@ -534,30 +537,29 @@ def title(self) -> str: return "".join(part.title() for part in self.name.split("_")) @property - def headers(self) -> Mapping[str, TypeAdapter[Any]]: + def headers(self) -> Mapping[str, _Field]: return self._headers @cached_property def headers_model(self) -> type[BaseModel] | None: if self._headers: headers_name = f"{self.title}Headers" + field_definitions: dict[str, Any] = { + name.replace("_", "-"): get_args(field.type) + for name, field in self._headers.items() + } headers_model = create_model( - headers_name, - **{ - name.replace("_", "-"): value._type - for name, value in self._headers.items() - }, - __base__=BaseModel, + headers_name, __base__=BaseModel, **field_definitions ) return headers_model return None @property - def payload(self) -> tuple[str, TypeAdapter[Any]] | None: + def payload(self) -> tuple[str, _Field] | None: return self._payload @property - def parameters(self) -> Mapping[str, TypeAdapter[Any]]: + def parameters(self) -> Mapping[str, _Field]: return self._parameters @property @@ -626,11 +628,20 @@ def _generate_arguments( parameters: dict[str, str], send: AMGISendCallable, ) -> Generator[tuple[str, Any], None, None]: + yield from self._generate_headers(message_receive_event) + yield from self._generate_payload(message_receive_event) + yield from self._generate_parameters(parameters) + yield from self._generate_bindings(message_receive_event) + if self._message_sender: + yield self._message_sender, MessageSender(send) + def _generate_headers( + self, message_receive_event: MessageReceiveEvent + ) -> Generator[tuple[str, Any], None, None]: if self.headers: headers = Headers(message_receive_event["headers"]) - for name, type_adapter in self.headers.items(): - annotated_args = get_args(type_adapter._type) + for name, field in self.headers.items(): + annotated_args = get_args(field.type) header_alias = annotated_args[1].alias alias = header_alias if header_alias else name.replace("_", "-") header = headers.get( @@ -641,30 +652,39 @@ def _generate_arguments( ) yield name, value + def _generate_payload( + self, message_receive_event: MessageReceiveEvent + ) -> Generator[tuple[str, Any], None, None]: if self.payload: - name, type_adapter = self.payload + name, field = self.payload payload = message_receive_event.get("payload") payload_obj = None if payload is None else json.loads(payload) - value = type_adapter.validate_python(payload_obj, from_attributes=True) + value = field.type_adapter.validate_python( + payload_obj, from_attributes=True + ) yield name, value - if self._parameters: - for name, type_adapter in self._parameters.items(): - yield name, type_adapter.validate_python(parameters[name]) - + def _generate_bindings( + self, message_receive_event: MessageReceiveEvent + ) -> Generator[tuple[str, Any], None, None]: if self._bindings: bindings = message_receive_event.get("bindings", {}) - for name, type_adapter in self._bindings.items(): - binding_type = get_args(type_adapter._type)[1] + for name, field in self._bindings.items(): + binding_type = get_args(field.type)[1] assert isinstance(binding_type, Binding) - yield name, type_adapter.validate_python( + yield name, field.type_adapter.validate_python( bindings.get(binding_type.__protocol__, {}).get( binding_type.__field_name__ ) ) - if self._message_sender: - yield self._message_sender, MessageSender(send) + + def _generate_parameters( + self, parameters: dict[str, str] + ) -> Generator[tuple[str, Any], None, None]: + if self._parameters: + for name, field in self._parameters.items(): + yield name, field.type_adapter.validate_python(parameters[name]) def _generate_messages( @@ -677,26 +697,24 @@ def _generate_messages( headers_model = channel.headers_model if headers_model: message["headers"] = field_mapping[ - hash(channel.headers_model), "serialization" + hash(channel.headers_model), "validation" ] payload = channel.payload if payload: - _, type_adapter = payload - message["payload"] = field_mapping[ - hash(type_adapter._type), "serialization" - ] + _, field = payload + message["payload"] = field_mapping[hash(field), "validation"] bindings: dict[str, dict[str, Any]] if channel._bindings: bindings = {} - for type_adapter in channel._bindings.values(): - binding_type = get_args(type_adapter._type)[1] + for field in channel._bindings.values(): + binding_type = get_args(field.type)[1] assert isinstance(binding_type, Binding) bindings.setdefault(binding_type.__protocol__, {})[ binding_type.__field_name__ - ] = field_mapping[hash(type_adapter._type), "serialization"] + ] = field_mapping[hash(field), "validation"] message["bindings"] = bindings yield f"{channel.title}Message", message @@ -705,20 +723,18 @@ def _generate_messages( message_message = {} if channel_message.__payload__: - _, type_adapter = channel_message.__payload__ - message_message["payload"] = field_mapping[ - hash(type_adapter._type), "serialization" - ] + _, field = channel_message.__payload__ + message_message["payload"] = field_mapping[hash(field), "serialization"] if channel_message.__bindings__: bindings = {} - for type_adapter in channel_message.__bindings__.values(): - binding_type = get_args(type_adapter._type)[1] + for field in channel_message.__bindings__.values(): + binding_type = get_args(field.type)[1] assert isinstance(binding_type, Binding) bindings.setdefault(binding_type.__protocol__, {})[ binding_type.__field_name__ - ] = field_mapping[hash(type_adapter._type), "serialization"] + ] = field_mapping[hash(field), "serialization"] message_message["bindings"] = bindings yield channel_message.__name__, message_message diff --git a/packages/asyncfast/tests_asyncfast/test_asyncapi.py b/packages/asyncfast/tests_asyncfast/test_asyncapi.py index ac6e7a6..1ff2a67 100644 --- a/packages/asyncfast/tests_asyncfast/test_asyncapi.py +++ b/packages/asyncfast/tests_asyncfast/test_asyncapi.py @@ -867,3 +867,56 @@ async def receive_handler( "sendSendB": {"action": "send", "channel": {"$ref": "#/channels/SendB"}}, }, } + + +def test_header_default() -> None: + app = AsyncFast() + + @app.channel("notification_channel") + async def notification_channel_handler( + request_id: Annotated[int, Header()] = 0, + ) -> None: + pass # pragma: no cover + + assert app.asyncapi() == { + "asyncapi": "3.0.0", + "channels": { + "NotificationChannelHandler": { + "address": "notification_channel", + "messages": { + "NotificationChannelHandlerMessage": { + "$ref": "#/components/messages/NotificationChannelHandlerMessage" + } + }, + } + }, + "components": { + "messages": { + "NotificationChannelHandlerMessage": { + "headers": { + "$ref": "#/components/schemas/NotificationChannelHandlerHeaders" + } + } + }, + "schemas": { + "NotificationChannelHandlerHeaders": { + "properties": { + "request-id": { + "default": 0, + "title": "Request-Id", + "type": "integer", + } + }, + "title": "NotificationChannelHandlerHeaders", + "type": "object", + } + }, + }, + "info": {"title": "AsyncFast", "version": "0.1.0"}, + "operations": { + "receiveNotificationChannelHandler": { + "action": "receive", + "channel": {"$ref": "#/channels/NotificationChannelHandler"}, + } + }, + } diff --git a/tox.ini b/tox.ini index 3d655c0..5722a1a 100644 --- a/tox.ini +++ b/tox.ini @@ -2,9 +2,11 @@ requires = tox>=4.2 env_list = + py313-asyncfast-pydantic2{8-12} clean pre-commit - py3{10-13}-{amgi-aiobotocore, amgi-aiokafka, amgi-common, amgi-paho-mqtt, amgi-redis, amgi-sqs-event-source-mapping, asyncfast} + py3{10-13}-{amgi-aiobotocore, amgi-aiokafka, amgi-common, amgi-paho-mqtt, amgi-redis, amgi-sqs-event-source-mapping} + py3{10-12}-asyncfast-pydantic2{0-12} py3{10-13}-{amgi-aiobotocore, amgi-aiokafka, amgi-common, amgi-paho-mqtt, amgi-redis, amgi-sqs-event-source-mapping, amgi-types, asyncfast, asyncfast-cli}-import [testenv] @@ -34,6 +36,25 @@ description = run pre-commit commands = pre-commit run --all-files --show-diff-on-failure +[testenv:py3{10-13}-asyncfast-pydantic2{0-12}] +commands_pre = + pydantic20: uv pip install "pydantic>=2.0,<2.1" + pydantic21: uv pip install "pydantic>=2.1,<2.2" + pydantic22: uv pip install "pydantic>=2.2,<2.3" + pydantic23: uv pip install "pydantic>=2.3,<2.4" + pydantic24: uv pip install "pydantic>=2.4,<2.5" + pydantic25: uv pip install "pydantic>=2.5,<2.6" + pydantic26: uv pip install "pydantic>=2.6,<2.7" + pydantic27: uv pip install "pydantic>=2.7,<2.8" + pydantic28: uv pip install "pydantic>=2.8,<2.9" + pydantic29: uv pip install "pydantic>=2.9,<2.10" + pydantic210: uv pip install "pydantic>=2.10,<2.11" + pydantic211: uv pip install "pydantic>=2.11,<2.12" + pydantic212: uv pip install "pydantic>=2.12,<2.13" +commands = + {[testenv]commands} packages/asyncfast +uv_sync_flags = --package=asyncfast + [testenv:py3{10-13}-amgi-aiobotocore] commands = {[testenv]commands} packages/amgi-aiobotocore @@ -64,11 +85,6 @@ commands = {[testenv]commands} packages/amgi-sqs-event-source-mapping uv_sync_flags = --package=amgi-sqs-event-source-mapping -[testenv:py3{10-13}-asyncfast] -commands = - {[testenv]commands} packages/asyncfast -uv_sync_flags = --package=asyncfast - [testenv:py3{10-13}-amgi-aiobotocore-import] commands = python -c "import amgi_aiobotocore" diff --git a/uv.lock b/uv.lock index 54e18ac..e7705b8 100644 --- a/uv.lock +++ b/uv.lock @@ -575,7 +575,7 @@ dev = [ requires-dist = [ { name = "amgi-types", editable = "packages/amgi-types" }, { name = "asyncfast-cli", marker = "extra == 'standard'", editable = "packages/asyncfast-cli" }, - { name = "pydantic", specifier = "==2.11.7" }, + { name = "pydantic", specifier = ">=2.0.0" }, ] provides-extras = ["standard"]