Skip to content

Commit

Permalink
chore: Formated methods
Browse files Browse the repository at this point in the history
  • Loading branch information
KamenDimitrov97 committed Sep 2, 2024
1 parent f6f478b commit e76a0a9
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 31 deletions.
62 changes: 46 additions & 16 deletions src/dewret/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,30 @@
import importlib
from functools import lru_cache
from types import FunctionType, ModuleType
from typing import Any, TypeVar, Annotated, Callable, get_origin, get_args, Mapping, get_type_hints
from typing import (
Any,
TypeVar,
Annotated,
Callable,
get_origin,
get_args,
Mapping,
get_type_hints,
)

T = TypeVar("T")
AtRender = Annotated[T, "AtRender"]
Fixed = Annotated[T, "Fixed"]


class FunctionAnalyser:
"""Convenience class for analysing a function with reduced duplication of effort.
Attributes:
_fn: the wrapped callable
_annotations: stored annotations for the function.
"""

_fn: Callable[..., Any]
_annotations: dict[str, Any]

Expand Down Expand Up @@ -79,9 +90,13 @@ def _typ_has(typ: type, annotation: type) -> bool:
Returns: True if the type has the given annotation, otherwise False.
"""
if not hasattr(annotation, "__metadata__"):
return False
if (origin := get_origin(typ)):
if origin is Annotated and hasattr(typ, "__metadata__") and typ.__metadata__ == annotation.__metadata__:
return False
if origin := get_origin(typ):
if (
origin is Annotated
and hasattr(typ, "__metadata__")
and typ.__metadata__ == annotation.__metadata__
):
return True
if any(FunctionAnalyser._typ_has(arg, annotation) for arg in get_args(typ)):
return True
Expand Down Expand Up @@ -113,19 +128,28 @@ def _get_all_imported_names(mod: ModuleType) -> dict[str, tuple[ModuleType, str]
if isinstance(node, ast.ImportFrom):
for name in node.names:
imported_names[name.asname or name.name] = (
importlib.import_module("".join(["."] * node.level) + (node.module or ""), package=mod.__package__),
name.name
importlib.import_module(
"".join(["."] * node.level) + (node.module or ""),
package=mod.__package__,
),
name.name,
)
return imported_names

@property
def free_vars(self) -> dict[str, Any]:
"""Get the free variables for this Callable."""
if self.fn.__code__ and self.fn.__closure__:
return dict(zip(self.fn.__code__.co_freevars, (c.cell_contents for c in self.fn.__closure__), strict=False))
return dict(
zip(
self.fn.__code__.co_freevars,
(c.cell_contents for c in self.fn.__closure__),
strict=False,
)
)
return {}

def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> Any:
def get_argument_annotation(self, arg: str, exhaustive: bool = False) -> Any:
"""Retrieve the annotations for this argument.
Args:
Expand All @@ -135,22 +159,28 @@ def get_argument_annotation(self, arg: str, exhaustive: bool=False) -> Any:
Returns: annotation if available, else None.
"""
typ: type | None = None
if (typ := self.fn.__annotations__.get(arg)):
if typ := self.fn.__annotations__.get(arg):
if isinstance(typ, str):
typ = get_type_hints(self.fn, include_extras=True).get(arg)
elif exhaustive:
if (anns := get_type_hints(sys.modules[self.fn.__module__], include_extras=True)):
if (typ := anns.get(arg)):
if anns := get_type_hints(
sys.modules[self.fn.__module__], include_extras=True
):
if typ := anns.get(arg):
...
elif (orig_pair := self.get_all_imported_names().get(arg)):
elif orig_pair := self.get_all_imported_names().get(arg):
orig_module, orig_name = orig_pair
typ = orig_module.__annotations__.get(orig_name)
elif (value := self.free_vars.get(arg)):
elif value := self.free_vars.get(arg):
if not inspect.isclass(value) or inspect.isfunction(value):
raise RuntimeError(f"Cannot use free variables - please put {arg} at the global scope")
raise RuntimeError(
f"Cannot use free variables - please put {arg} at the global scope"
)
return typ

def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bool:
def argument_has(
self, arg: str, annotation: type, exhaustive: bool = False
) -> bool:
"""Check if the named argument has the given annotation.
Args:
Expand All @@ -163,7 +193,7 @@ def argument_has(self, arg: str, annotation: type, exhaustive: bool=False) -> bo
typ = self.get_argument_annotation(arg, exhaustive)
return bool(typ and self._typ_has(typ, annotation))

def is_at_construct_arg(self, arg: str, exhaustive: bool=False) -> bool:
def is_at_construct_arg(self, arg: str, exhaustive: bool = False) -> bool:
"""Convience function to check for `AtConstruct`, wrapping `FunctionAnalyser.argument_has`."""
return self.argument_has(arg, AtRender, exhaustive)

Expand Down
11 changes: 8 additions & 3 deletions src/dewret/backends/backend_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,19 @@ def is_lazy(task: Any) -> bool:
True if so, False otherwise.
"""
return isinstance(task, Delayed) or (
isinstance(task, tuple | list) and
all(is_lazy(elt) for elt in task)
isinstance(task, tuple | list) and all(is_lazy(elt) for elt in task)
)


lazy = delayed

def run(workflow: Workflow | None, task: Lazy | list[Lazy] | tuple[Lazy], thread_pool: ThreadPoolExecutor | None=None, **kwargs: Any) -> Any:

def run(
workflow: Workflow | None,
task: Lazy | list[Lazy] | tuple[Lazy],
thread_pool: ThreadPoolExecutor | None = None,
**kwargs: Any,
) -> Any:
"""Execute a task as the output of a workflow.
Runs a task with dask.
Expand Down
4 changes: 1 addition & 3 deletions src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@
BasicType = str | float | bool | bytes | int | None
RawType = BasicType | list["RawType"] | dict[str, "RawType"]
FirmType = RawType | list["FirmType"] | dict[str, "FirmType"] | tuple["FirmType", ...]
ExprType = (
FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]
) # type: ignore
ExprType = (FirmType | Basic | list["ExprType"] | dict[str, "ExprType"] | tuple["ExprType", ...]) # type: ignore

U = TypeVar("U")
T = TypeVar("T")
Expand Down
38 changes: 29 additions & 9 deletions src/dewret/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,20 @@
import yaml

from .workflow import Workflow, NestedStep
from .core import RawType, RenderCall, BaseRenderModule, RawRenderModule, StructuredRenderModule, RenderConfiguration
from .core import (
RawType,
RenderCall,
BaseRenderModule,
RawRenderModule,
StructuredRenderModule,
RenderConfiguration,
)
from .utils import load_module_or_package

T = TypeVar("T")

def structured_to_raw(rendered: RawType, pretty: bool=False) -> str:

def structured_to_raw(rendered: RawType, pretty: bool = False) -> str:
"""Serialize a serializable structure to a string.
Args:
Expand All @@ -45,7 +53,10 @@ def structured_to_raw(rendered: RawType, pretty: bool=False) -> str:
output = str(rendered)
return output

def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool=False) -> RenderCall:

def get_render_method(
renderer: Path | RawRenderModule | StructuredRenderModule, pretty: bool = False
) -> RenderCall:
"""Create a ready-made callable to render the workflow that is appropriate for the renderer module.
Args:
Expand All @@ -70,20 +81,29 @@ def get_render_method(renderer: Path | RawRenderModule | StructuredRenderModule,
if isinstance(render_module, RawRenderModule):
return render_module.render_raw
elif isinstance(render_module, (StructuredRenderModule)):
def _render(workflow: Workflow, render_module: StructuredRenderModule, pretty: bool=False, **kwargs: RenderConfiguration) -> dict[str, str]:

def _render(
workflow: Workflow,
render_module: StructuredRenderModule,
pretty: bool = False,
**kwargs: RenderConfiguration,
) -> dict[str, str]:
rendered = render_module.render(workflow, **kwargs)
return {
key: structured_to_raw(value, pretty=pretty)
for key, value in rendered.items()
}

return cast(RenderCall, partial(_render, render_module=render_module, pretty=pretty))
return cast(
RenderCall, partial(_render, render_module=render_module, pretty=pretty)
)

raise NotImplementedError(
"This render module neither seems to be a structured nor a raw render module."
)

raise NotImplementedError("This render module neither seems to be a structured nor a raw render module.")

def base_render(
workflow: Workflow, build_cb: Callable[[Workflow], T]
) -> dict[str, T]:
def base_render(workflow: Workflow, build_cb: Callable[[Workflow], T]) -> dict[str, T]:
"""Render to a dict-like structure.
Args:
Expand Down

0 comments on commit e76a0a9

Please sign in to comment.