Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/asyncfast/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ classifiers = [
]
dependencies = [
"amgi-types==0.27.0",
"pydantic==2.11.7",
"pydantic>=2.0.0",
]

optional-dependencies.standard = [
Expand Down
164 changes: 90 additions & 74 deletions packages/asyncfast/src/asyncfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,29 @@ async def send(self, message: M) -> None:
await _send_message(self._send, message)


class _Field:
def __init__(self, type_: type):
self.type = type_
self.type_adapter = TypeAdapter[Any](type_)

def __hash__(self) -> int:
return hash(self.type)


class Message(Mapping[str, Any]):

__address__: ClassVar[str | None] = None
__headers__: ClassVar[dict[str, TypeAdapter[Any]]]
__headers__: ClassVar[dict[str, _Field]]
__parameters__: ClassVar[dict[str, TypeAdapter[Any]]]
__payload__: ClassVar[tuple[str, TypeAdapter[Any]] | None]
__bindings__: ClassVar[dict[str, TypeAdapter[Any]]]
__payload__: ClassVar[tuple[str, _Field] | None]
__bindings__: ClassVar[dict[str, _Field]]

def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None:
cls.__address__ = address
annotations = list(_generate_message_annotations(address, cls.__annotations__))

headers = {
name: TypeAdapter(annotated)
name: _Field(annotated)
for name, annotated in annotations
if isinstance(get_args(annotated)[1], Header)
}
Expand All @@ -98,13 +107,13 @@ def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None:
}

bindings = {
name: TypeAdapter(annotated)
name: _Field(annotated)
for name, annotated in annotations
if isinstance(get_args(annotated)[1], Binding)
}

payloads = [
(name, TypeAdapter(annotated))
(name, _Field(annotated))
for name, annotated in annotations
if isinstance(get_args(annotated)[1], Payload)
]
Expand Down Expand Up @@ -153,8 +162,8 @@ def _get_address(self) -> str | None:

def _get_headers(self) -> Iterable[tuple[bytes, bytes]]:
return [
(name.encode(), self._get_value(name, type_adapter))
for name, type_adapter in self.__headers__.items()
(name.encode(), self._get_value(name, field.type_adapter))
for name, field in self.__headers__.items()
]

def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes:
Expand All @@ -167,18 +176,18 @@ def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes:
def _get_payload(self) -> bytes | None:
if self.__payload__ is None:
return None
name, type_adapter = self.__payload__
return type_adapter.dump_json(getattr(self, name))
name, field = self.__payload__
return field.type_adapter.dump_json(getattr(self, name))

def _get_bindings(self) -> dict[str, dict[str, Any]]:
bindings: dict[str, dict[str, Any]] = {}
for name, type_adapter in self.__bindings__.items():
binding_type = get_args(type_adapter._type)[1]
for name, field in self.__bindings__.items():
binding_type = get_args(field.type)[1]
assert isinstance(binding_type, Binding)

bindings.setdefault(binding_type.__protocol__, {})[
binding_type.__field_name__
] = self._get_value(name, type_adapter)
] = self._get_value(name, field.type_adapter)
return bindings


Expand Down Expand Up @@ -254,28 +263,28 @@ def _add_channel(
annotations = list(_generate_annotations(address, signature))

headers = {
name: TypeAdapter(annotated)
name: _Field(annotated)
for name, annotated in annotations
if get_origin(annotated) is Annotated
and isinstance(get_args(annotated)[1], Header)
}

parameters = {
name: TypeAdapter(annotated)
name: _Field(annotated)
for name, annotated in annotations
if get_origin(annotated) is Annotated
and isinstance(get_args(annotated)[1], Parameter)
}

payloads = [
(name, TypeAdapter(annotated))
(name, _Field(annotated))
for name, annotated in annotations
if get_origin(annotated) is Annotated
and isinstance(get_args(annotated)[1], Payload)
]

bindings = {
name: TypeAdapter(annotated)
name: _Field(annotated)
for name, annotated in annotations
if get_origin(annotated) is Annotated
and isinstance(get_args(annotated)[1], Binding)
Expand Down Expand Up @@ -381,35 +390,29 @@ def _generate_inputs(
self,
) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]:
for channel in self._channels:
for type_adapter in channel._bindings.values():
yield hash(
type_adapter._type
), "serialization", type_adapter.core_schema
for field in channel._bindings.values():
yield hash(field), "validation", field.type_adapter.core_schema

headers_model = channel.headers_model
if headers_model:
yield hash(headers_model), "serialization", TypeAdapter(
yield hash(headers_model), "validation", TypeAdapter(
headers_model
).core_schema
payload = channel.payload
if payload:
_, type_adapter = payload
yield hash(
type_adapter._type
), "serialization", type_adapter.core_schema
_, field = payload
yield hash(field), "validation", field.type_adapter.core_schema

for message in channel.messages:
if message.__payload__:
_, type_adapter = message.__payload__
_, field = message.__payload__

yield hash(
type_adapter._type
), "serialization", type_adapter.core_schema
yield hash(field), "serialization", field.type_adapter.core_schema

for type_adapter in message.__bindings__.values():
for field in message.__bindings__.values():
yield hash(
type_adapter._type
), "serialization", type_adapter.core_schema
field.type
), "serialization", field.type_adapter.core_schema


def _generate_annotations(
Expand Down Expand Up @@ -504,11 +507,11 @@ def __init__(
address: str,
address_pattern: Pattern[str],
handler: Callable[..., Awaitable[None]],
headers: Mapping[str, TypeAdapter[Any]],
parameters: Mapping[str, TypeAdapter[Any]],
payload: tuple[str, TypeAdapter[Any]] | None,
headers: Mapping[str, _Field],
parameters: Mapping[str, _Field],
payload: tuple[str, _Field] | None,
messages: Sequence[type[Message]],
bindings: Mapping[str, TypeAdapter[Any]],
bindings: Mapping[str, _Field],
message_sender: str | None,
) -> None:
self._address = address
Expand All @@ -534,30 +537,29 @@ def title(self) -> str:
return "".join(part.title() for part in self.name.split("_"))

@property
def headers(self) -> Mapping[str, TypeAdapter[Any]]:
def headers(self) -> Mapping[str, _Field]:
return self._headers

@cached_property
def headers_model(self) -> type[BaseModel] | None:
if self._headers:
headers_name = f"{self.title}Headers"
field_definitions: dict[str, Any] = {
name.replace("_", "-"): get_args(field.type)
for name, field in self._headers.items()
}
headers_model = create_model(
headers_name,
**{
name.replace("_", "-"): value._type
for name, value in self._headers.items()
},
__base__=BaseModel,
headers_name, __base__=BaseModel, **field_definitions
)
return headers_model
return None

@property
def payload(self) -> tuple[str, TypeAdapter[Any]] | None:
def payload(self) -> tuple[str, _Field] | None:
return self._payload

@property
def parameters(self) -> Mapping[str, TypeAdapter[Any]]:
def parameters(self) -> Mapping[str, _Field]:
return self._parameters

@property
Expand Down Expand Up @@ -626,11 +628,20 @@ def _generate_arguments(
parameters: dict[str, str],
send: AMGISendCallable,
) -> Generator[tuple[str, Any], None, None]:
yield from self._generate_headers(message_receive_event)
yield from self._generate_payload(message_receive_event)
yield from self._generate_parameters(parameters)
yield from self._generate_bindings(message_receive_event)
if self._message_sender:
yield self._message_sender, MessageSender(send)

def _generate_headers(
self, message_receive_event: MessageReceiveEvent
) -> Generator[tuple[str, Any], None, None]:
if self.headers:
headers = Headers(message_receive_event["headers"])
for name, type_adapter in self.headers.items():
annotated_args = get_args(type_adapter._type)
for name, field in self.headers.items():
annotated_args = get_args(field.type)
header_alias = annotated_args[1].alias
alias = header_alias if header_alias else name.replace("_", "-")
header = headers.get(
Expand All @@ -641,30 +652,39 @@ def _generate_arguments(
)
yield name, value

def _generate_payload(
self, message_receive_event: MessageReceiveEvent
) -> Generator[tuple[str, Any], None, None]:
if self.payload:
name, type_adapter = self.payload
name, field = self.payload
payload = message_receive_event.get("payload")
payload_obj = None if payload is None else json.loads(payload)
value = type_adapter.validate_python(payload_obj, from_attributes=True)
value = field.type_adapter.validate_python(
payload_obj, from_attributes=True
)
yield name, value

if self._parameters:
for name, type_adapter in self._parameters.items():
yield name, type_adapter.validate_python(parameters[name])

def _generate_bindings(
self, message_receive_event: MessageReceiveEvent
) -> Generator[tuple[str, Any], None, None]:
if self._bindings:
bindings = message_receive_event.get("bindings", {})
for name, type_adapter in self._bindings.items():
binding_type = get_args(type_adapter._type)[1]
for name, field in self._bindings.items():
binding_type = get_args(field.type)[1]
assert isinstance(binding_type, Binding)

yield name, type_adapter.validate_python(
yield name, field.type_adapter.validate_python(
bindings.get(binding_type.__protocol__, {}).get(
binding_type.__field_name__
)
)
if self._message_sender:
yield self._message_sender, MessageSender(send)

def _generate_parameters(
self, parameters: dict[str, str]
) -> Generator[tuple[str, Any], None, None]:
if self._parameters:
for name, field in self._parameters.items():
yield name, field.type_adapter.validate_python(parameters[name])


def _generate_messages(
Expand All @@ -677,26 +697,24 @@ def _generate_messages(
headers_model = channel.headers_model
if headers_model:
message["headers"] = field_mapping[
hash(channel.headers_model), "serialization"
hash(channel.headers_model), "validation"
]

payload = channel.payload
if payload:
_, type_adapter = payload
message["payload"] = field_mapping[
hash(type_adapter._type), "serialization"
]
_, field = payload
message["payload"] = field_mapping[hash(field), "validation"]

bindings: dict[str, dict[str, Any]]
if channel._bindings:
bindings = {}
for type_adapter in channel._bindings.values():
binding_type = get_args(type_adapter._type)[1]
for field in channel._bindings.values():
binding_type = get_args(field.type)[1]
assert isinstance(binding_type, Binding)

bindings.setdefault(binding_type.__protocol__, {})[
binding_type.__field_name__
] = field_mapping[hash(type_adapter._type), "serialization"]
] = field_mapping[hash(field), "validation"]
message["bindings"] = bindings

yield f"{channel.title}Message", message
Expand All @@ -705,20 +723,18 @@ def _generate_messages(
message_message = {}

if channel_message.__payload__:
_, type_adapter = channel_message.__payload__
message_message["payload"] = field_mapping[
hash(type_adapter._type), "serialization"
]
_, field = channel_message.__payload__
message_message["payload"] = field_mapping[hash(field), "serialization"]

if channel_message.__bindings__:
bindings = {}
for type_adapter in channel_message.__bindings__.values():
binding_type = get_args(type_adapter._type)[1]
for field in channel_message.__bindings__.values():
binding_type = get_args(field.type)[1]
assert isinstance(binding_type, Binding)

bindings.setdefault(binding_type.__protocol__, {})[
binding_type.__field_name__
] = field_mapping[hash(type_adapter._type), "serialization"]
] = field_mapping[hash(field), "serialization"]
message_message["bindings"] = bindings

yield channel_message.__name__, message_message
Expand Down
Loading