1e31aef6aSopenharmony_ciimport inspect
2e31aef6aSopenharmony_ciimport typing as t
3e31aef6aSopenharmony_cifrom functools import WRAPPER_ASSIGNMENTS
4e31aef6aSopenharmony_cifrom functools import wraps
5e31aef6aSopenharmony_ci
6e31aef6aSopenharmony_cifrom .utils import _PassArg
7e31aef6aSopenharmony_cifrom .utils import pass_eval_context
8e31aef6aSopenharmony_ci
# Generic placeholder for the item/result type used in the annotations of
# the async helper functions in this module.
V = t.TypeVar("V")
10e31aef6aSopenharmony_ci
11e31aef6aSopenharmony_ci
def async_variant(normal_func):  # type: ignore
    """Create a decorator that pairs ``normal_func`` with an async twin.

    The returned decorator receives the async implementation and yields a
    single dispatcher. On each call the dispatcher reads the ``is_async``
    flag reachable from its first positional argument and forwards the
    call to the async implementation when the flag is truthy, otherwise
    to ``normal_func``.

    The dispatcher is tagged with ``jinja_async_variant = True``.
    """

    def decorator(async_func):  # type: ignore
        pass_arg = _PassArg.from_obj(normal_func)
        # When ``from_obj`` reports nothing for the sync function, the
        # dispatcher is wrapped with ``pass_eval_context`` further down;
        # the extra leading argument is then stripped before delegating.
        need_eval_context = pass_arg is None
        env_is_first_arg = pass_arg is _PassArg.environment

        def is_async(args: t.Any) -> bool:
            # The first argument is either the environment itself or an
            # object exposing it through ``.environment``.
            owner = args[0] if env_is_first_arg else args[0].environment
            return t.cast(bool, owner.is_async)

        # Documentation and annotations come from the sync function while
        # the identifying names come from the async one.
        # Pallets-Sphinx-Themes build_function_directive expects
        # __wrapped__ to point to the sync function, hence the sync
        # function is applied last.
        async_func_attrs = ("__module__", "__name__", "__qualname__")
        normal_func_attrs = tuple(
            attr for attr in WRAPPER_ASSIGNMENTS if attr not in async_func_attrs
        )

        def wrapper(*args, **kwargs):  # type: ignore
            # Decide before any argument is dropped: the flag lives on
            # the first positional argument.
            use_async = is_async(args)
            forwarded = args[1:] if need_eval_context else args
            target = async_func if use_async else normal_func
            return target(*forwarded, **kwargs)

        # Same order as stacked decorators: async metadata first, then
        # the sync function's remaining attributes, __dict__ and
        # __wrapped__.
        wrapper = wraps(async_func, assigned=async_func_attrs, updated=())(wrapper)
        wrapper = wraps(normal_func, assigned=normal_func_attrs)(wrapper)

        if need_eval_context:
            wrapper = pass_eval_context(wrapper)

        wrapper.jinja_async_variant = True
        return wrapper

    return decorator
54e31aef6aSopenharmony_ci
55e31aef6aSopenharmony_ci
# Exact types that are never awaitable; checked by identity of ``type(value)``
# so the common case can skip ``inspect.isawaitable``.
_common_primitives = {
    int,
    float,
    bool,
    str,
    list,
    dict,
    tuple,
    type(None),
}


async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
    """Return ``value``, awaiting it first when it is awaitable.

    Values whose exact type is one of ``_common_primitives`` are handed
    back immediately without consulting ``inspect.isawaitable``.
    """
    # Short-circuit keeps the costly isawaitable call off the fast path.
    must_await = type(value) not in _common_primitives and inspect.isawaitable(
        value
    )

    if not must_await:
        return t.cast("V", value)

    return await t.cast("t.Awaitable[V]", value)
68e31aef6aSopenharmony_ci
69e31aef6aSopenharmony_ci
async def auto_aiter(
    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
    """Iterate ``iterable`` asynchronously, whether it is async or not.

    Objects exposing ``__aiter__`` are consumed with ``async for``;
    anything else is consumed with a regular ``for`` loop. Items are
    yielded unchanged and in order.
    """
    if not hasattr(iterable, "__aiter__"):
        # Plain synchronous iterable: re-yield its items as they come.
        for element in iterable:
            yield element
        return

    async for element in t.cast("t.AsyncIterable[V]", iterable):
        yield element
79e31aef6aSopenharmony_ci
80e31aef6aSopenharmony_ci
async def auto_to_list(
    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> t.List["V"]:
    """Collect every item of a sync or async iterable into a new list.

    Iteration is delegated to ``auto_aiter``; items keep their order.
    """
    collected = []

    async for item in auto_aiter(value):
        collected.append(item)

    return collected
85