Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
# Pact dependencies
"pact-python-ffi~=0.4.0",
# External dependencies
"sniffio~=1.0",
"yarl~=1.0",
"typing-extensions~=4.0 ; python_version < '3.13'",
]
Expand Down
148 changes: 136 additions & 12 deletions src/pact/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,34 @@

from __future__ import annotations

import asyncio
import contextvars
import inspect
import logging
import socket
import warnings
from contextlib import closing
from functools import partial
from inspect import Parameter, _ParameterKind
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from collections.abc import Callable, Coroutine, Mapping

try:
import sniffio # type: ignore[import-not-found]
except ImportError: # pragma: no cover
sniffio = None # type: ignore[assignment]

try:
import trio # type: ignore[import-not-found]
except ImportError: # pragma: no cover
trio = None # type: ignore[assignment]

try:
import curio # type: ignore[import-not-found,import-untyped]
except (ImportError, AttributeError):
curio = None # type: ignore[assignment]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,7 +196,7 @@ def find_free_port() -> int:
return s.getsockname()[1]


def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
"""
Apply arguments to a function.

Expand All @@ -188,6 +205,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
it is possible to pass arguments by name, and falling back to positional
arguments if not.

This function supports both synchronous and asynchronous callables. If the
callable is a coroutine function, it will be executed in an event loop.

Args:
f:
The function to apply the arguments to.
Expand All @@ -200,6 +220,8 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
Returns:
The result of the function.
"""
func_to_check = f.func if isinstance(f, partial) else f
is_async = inspect.iscoroutinefunction(func_to_check)
signature = inspect.signature(f)
f_name = (
f.__qualname__
Expand All @@ -226,7 +248,17 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
# First, we inspect the keyword arguments and try and pass in some arguments
# by currying them in.
for param in signature.parameters.values():
if param.name not in args:
arg_key = None
if param.name in args:
arg_key = param.name
elif (
param.name.startswith("_")
and len(param.name) > 1
and param.name[1:] in args
):
arg_key = param.name[1:]

if arg_key is None:
# If a parameter is not known, we will ignore it.
#
# If the ignored parameter doesn't have a default value, it will
Expand All @@ -246,12 +278,13 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
if param.kind in positional_match:
# We iterate through the parameters in order that they are defined,
# making it fine to pass them in by position one at a time.
f = partial(f, args[param.name])
del args[param.name]
f = partial(f, args[arg_key])
del args[arg_key]
continue

if param.kind in keyword_match:
f = partial(f, **{param.name: args[param.name]})
del args[param.name]
f = partial(f, **{param.name: args[arg_key]})
del args[arg_key]
continue

# At this stage, we have checked all arguments. If we have any arguments
Expand Down Expand Up @@ -281,8 +314,99 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
},
)

if is_async:
result = f()
if inspect.iscoroutine(result):
return _run_async_coroutine(result)
return result # pragma: no cover
return f()


def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T:
"""
Run a coroutine in an event loop.

Detects the current async runtime and runs the coroutine in it,
preserving ContextVars across the dispatch.

Args:
coro:
The coroutine to run.

Returns:
The result of the coroutine.

Raises:
RuntimeError:
If the detected runtime (trio or curio) is not installed.
"""
runtime = _detect_async_runtime(coro)

if runtime == "trio":
if trio is None:
msg = "trio is not installed"
raise RuntimeError(msg)

ctx = contextvars.copy_context()

async def _run_trio() -> _T:
return await coro

return ctx.run(trio.run, _run_trio)

if runtime == "curio":
if curio is None:
msg = "curio is not installed"
raise RuntimeError(msg)

try:
return curio.AWAIT(coro)
except RuntimeError:
ctx = contextvars.copy_context()

async def _run_curio() -> _T:
return await coro

return ctx.run(curio.run, _run_curio)

try:
return f()
except Exception:
logger.exception("Error occurred while calling function %s", f_name)
raise
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop is not None:
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
return future.result()

ctx = contextvars.copy_context()
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]


def _detect_async_runtime(coro: Coroutine[Any, Any, _T]) -> str:
"""
Detect the async runtime to use for a given coroutine.

When called from within a running async context, `sniffio` is used to
identify the library. Otherwise the coroutine's `co_names` is inspected
for `trio` or `curio` references.

Args:
coro:
The coroutine to inspect.

Returns:
The detected runtime: "asyncio", "trio", or "curio".
"""
if sniffio is not None:
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass

names = set(coro.cr_code.co_names) # type: ignore[attr-defined]
if trio is not None and "trio" in names:
return "trio"
if curio is not None and "curio" in names:
return "curio"

return "asyncio"
Loading
Loading