Skip to content

Commit

Permalink
fix: Added a wrapper class for renderers for testing purposes
Browse files Browse the repository at this point in the history
  • Loading branch information
KamenDimitrov97 committed Sep 5, 2024
1 parent 3bff5e5 commit 20219a0
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 96 deletions.
66 changes: 42 additions & 24 deletions src/dewret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
import click
import json

from .core import set_configuration, set_render_configuration, RawRenderModule, StructuredRenderModule
from .core import (
set_configuration,
set_render_configuration,
RawRenderModule,
StructuredRenderModule,
)
from .utils import load_module_or_package
from .render import get_render_method
from .tasks import Backend, construct
Expand All @@ -52,27 +57,23 @@
default=Backend.DASK.name,
help="Backend to use for workflow evaluation.",
)
@click.option(
"--construct-args",
default="simplify_ids:true"
)
@click.option(
"--renderer",
default="cwl"
)
@click.option(
"--renderer-args",
default=""
)
@click.option(
"--output",
default="-"
)
@click.option("--construct-args", default="simplify_ids:true")
@click.option("--renderer", default="cwl")
@click.option("--renderer-args", default="")
@click.option("--output", default="-")
@click.argument("workflow_py", type=click.Path(exists=True, path_type=Path))
@click.argument("task")
@click.argument("arguments", nargs=-1)
def render(
workflow_py: Path, task: str, arguments: list[str], pretty: bool, backend: Backend, construct_args: str, renderer: str, renderer_args: str, output: str
workflow_py: Path,
task: str,
arguments: list[str],
pretty: bool,
backend: Backend,
construct_args: str,
renderer: str,
renderer_args: str,
output: str,
) -> None:
"""Render a workflow.
Expand All @@ -91,14 +92,21 @@ def render(
kwargs[key] = json.loads(val)

render_module: Path | ModuleType
if (mtch := re.match(r"^([a-z_0-9-.]+)$", renderer)):
if mtch := re.match(r"^([a-z_0-9-.]+)$", renderer):
render_module = importlib.import_module(f"dewret.renderers.{mtch.group(1)}")
if not isinstance(render_module, RawRenderModule) and not isinstance(render_module, StructuredRenderModule):
raise NotImplementedError("The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols.")
if not isinstance(render_module.Renderer, RawRenderModule) and not isinstance(
render_module.Renderer, StructuredRenderModule
):
raise NotImplementedError(
"The imported render module does not seem to match the `RawRenderModule` or `StructuredRenderModule` protocols."
)
render_module = render_module.Renderer
elif renderer.startswith("@"):
render_module = Path(renderer[1:])
else:
raise RuntimeError("Renderer argument should be a known dewret renderer, or '@FILENAME' where FILENAME is a renderer")
raise RuntimeError(
"Renderer argument should be a known dewret renderer, or '@FILENAME' where FILENAME is a renderer"
)

if construct_args.startswith("@"):
with Path(construct_args[1:]).open() as construct_args_f:
Expand All @@ -118,18 +126,22 @@ def render(
renderer_kwargs = dict(pair.split(":") for pair in renderer_args.split(","))

if output == "-":

@contextmanager
def _opener(key: str, _: str) -> Generator[IO[Any], None, None]:
print(" ------ ", key, " ------ ")
yield sys.stdout
print()

opener = _opener
else:

@contextmanager
def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
output_file = output.replace("%", key)
with Path(output_file).open(mode) as output_f:
yield output_f

opener = _opener

render = get_render_method(render_module, pretty=pretty)
Expand All @@ -138,8 +150,13 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
task_fn = getattr(workflow, task)

try:
with set_configuration(**construct_kwargs), set_render_configuration(renderer_kwargs):
rendered = render(construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs)
with (
set_configuration(**construct_kwargs),
set_render_configuration(renderer_kwargs),
):
rendered = render(
construct(task_fn(**kwargs), **construct_kwargs), **renderer_kwargs
)
except Exception as exc:
import traceback

Expand All @@ -164,4 +181,5 @@ def _opener(key: str, mode: str) -> Generator[IO[Any], None, None]:
output_f.write("\n---\n")
output_f.write(value)


render()
2 changes: 1 addition & 1 deletion src/dewret/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
Annotated,
Callable,
cast,
runtime_checkable,
runtime_checkable
)
from contextlib import contextmanager
from contextvars import ContextVar
Expand Down
45 changes: 26 additions & 19 deletions src/dewret/renderers/cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
Unset,
)
from dewret.render import base_render
from dewret.core import Reference, get_render_configuration, set_render_configuration
from dewret.core import (
Reference,
get_render_configuration,
set_render_configuration,
StructuredRenderModule,
)


class CommandInputSchema(TypedDict):
Expand Down Expand Up @@ -744,23 +749,25 @@ def render(self) -> dict[str, RawType]:
}


def render(
workflow: Workflow, **kwargs: Unpack[CWLRendererConfiguration]
) -> dict[str, dict[str, RawType]]:
"""Render to a dict-like structure.
class Renderer(StructuredRenderModule):
"""Wrapper class implementing StructuredRenderModule protocol."""
def render(
workflow: Workflow, **kwargs: Unpack[CWLRendererConfiguration]
) -> dict[str, dict[str, RawType]]:
"""Render to a dict-like structure.
Args:
workflow: workflow to evaluate result.
**kwargs: additional configuration arguments - these should match CWLRendererConfiguration.
Args:
workflow: workflow to evaluate result.
**kwargs: additional configuration arguments - these should match CWLRendererConfiguration.
Returns:
Reduced form as a native Python dict structure for
serialization.
"""
# TODO: Again, convincing mypy that a TypedDict has RawType values.
with set_render_configuration(kwargs): # type: ignore
rendered = base_render(
workflow,
lambda workflow: WorkflowDefinition.from_workflow(workflow).render(),
)
return rendered
Returns:
Reduced form as a native Python dict structure for
serialization.
"""
# TODO: Again, convincing mypy that a TypedDict has RawType values.
with set_render_configuration(kwargs): # type: ignore
rendered = base_render(
workflow,
lambda workflow: WorkflowDefinition.from_workflow(workflow).render(),
)
return rendered
10 changes: 5 additions & 5 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import yaml

from dewret.tasks import construct, workflow, TaskException
from dewret.renderers.cwl import render
from dewret.renderers.cwl import Renderer
from dewret.annotations import AtRender, FunctionAnalyser, Fixed
from dewret.core import set_configuration

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_at_render() -> None:

result = to_int(num=increment(num=3), should_double=True)
wkflw = construct(result, simplify_ids=True)
subworkflows = render(wkflw, allow_complex_types=True)
subworkflows = Renderer.render(wkflw, allow_complex_types=True)
rendered = subworkflows["__root__"]
assert rendered == yaml.safe_load("""
cwlVersion: 1.2
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_at_render() -> None:

result = to_int(num=increment(num=3), should_double=False)
wkflw = construct(result, simplify_ids=True)
subworkflows = render(wkflw, allow_complex_types=True)
subworkflows = Renderer.render(wkflw, allow_complex_types=True)
rendered = subworkflows["__root__"]
assert rendered == yaml.safe_load("""
cwlVersion: 1.2
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_at_render_between_modules() -> None:
"""Test rendering of workflows across different modules using `dewret.annotations.AtRender`."""
result = try_nothing()
wkflw = construct(result, simplify_ids=True)
subworkflows = render(wkflw, allow_complex_types=True)
subworkflows = Renderer.render(wkflw, allow_complex_types=True)
subworkflows["__root__"]


Expand All @@ -181,7 +181,7 @@ def loop_over_lists(list_1: list[int]) -> list[int]:
with set_configuration(flatten_all_nested=True):
result = loop_over_lists(list_1=[5, 6, 7, 8])
wkflw = construct(result, simplify_ids=True)
subworkflows = render(wkflw, allow_complex_types=True)
subworkflows = Renderer.render(wkflw, allow_complex_types=True)
rendered = subworkflows["__root__"]
assert rendered == yaml.safe_load("""
class: Workflow
Expand Down
4 changes: 2 additions & 2 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import yaml
import pytest
from dewret.tasks import construct, workflow, TaskException
from dewret.renderers.cwl import render
from dewret.renderers.cwl import Renderer
from dewret.core import set_configuration
from dewret.annotations import AtRender
from ._lib.extra import increment
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_cwl_with_parameter() -> None:
with set_configuration(flatten_all_nested=True):
result = increment(num=floor(num=3, expected=True))
workflow = construct(result, simplify_ids=True)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
num_param = list(workflow.find_parameters())[0]
assert num_param

Expand Down
28 changes: 16 additions & 12 deletions tests/test_cwl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime, timedelta
from dewret.core import set_configuration
from dewret.tasks import construct, task, factory, TaskException
from dewret.renderers.cwl import render
from dewret.renderers.cwl import Renderer
from dewret.utils import hasher
from dewret.workflow import param

Expand All @@ -19,6 +19,7 @@
tuple_float_return,
)


@task()
def floor(num: int | float) -> int:
"""Converts int/float to int."""
Expand All @@ -44,7 +45,7 @@ def test_basic_cwl() -> None:
"""
result = pi()
workflow = construct(result)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
hsh = hasher(("pi",))

assert rendered == yaml.safe_load(f"""
Expand Down Expand Up @@ -77,7 +78,9 @@ def get_now() -> datetime:
now = factory(get_now)()
result = days_in_future(now=now, num=3)
workflow = construct(result, simplify_ids=True)
rendered = render(workflow, allow_complex_types=True, factories_as_params=True)["__root__"]
rendered = Renderer.render(
workflow, allow_complex_types=True, factories_as_params=True
)["__root__"]

assert rendered == yaml.safe_load("""
cwlVersion: 1.2
Expand Down Expand Up @@ -106,7 +109,7 @@ def get_now() -> datetime:
out: [out]
""")

rendered = render(workflow, allow_complex_types=True)["__root__"]
rendered = Renderer.render(workflow, allow_complex_types=True)["__root__"]

assert rendered == yaml.safe_load("""
cwlVersion: 1.2
Expand Down Expand Up @@ -145,7 +148,7 @@ def test_cwl_with_parameter() -> None:
"""
result = increment(num=3)
workflow = construct(result)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
num_param = list(workflow.find_parameters())[0]
hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}")))

Expand All @@ -171,6 +174,7 @@ def test_cwl_with_parameter() -> None:
out: [out]
""")


def test_cwl_with_positional_parameter() -> None:
"""Check whether we can move raw input to parameters.
Expand All @@ -182,7 +186,7 @@ def test_cwl_with_positional_parameter() -> None:
with set_configuration(allow_positional_args=True):
result = increment(3)
workflow = construct(result)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
num_param = list(workflow.find_parameters())[0]
hsh = hasher(("increment", ("num", f"int|:param:{num_param._.unique_name}")))

Expand Down Expand Up @@ -218,7 +222,7 @@ def test_cwl_without_default() -> None:

result = increment(num=my_param)
workflow = construct(result)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
hsh = hasher(("increment", ("num", "int|:param:my_param")))

assert rendered == yaml.safe_load(f"""
Expand Down Expand Up @@ -248,7 +252,7 @@ def test_cwl_with_subworkflow() -> None:
my_param = param("num", typ=int)
result = increment(num=floor(num=triple_and_one(num=increment(num=my_param))))
workflow = construct(result, simplify_ids=True)
subworkflows = render(workflow)
subworkflows = Renderer.render(workflow)
rendered = subworkflows["__root__"]
del subworkflows["__root__"]

Expand Down Expand Up @@ -345,7 +349,7 @@ def test_cwl_references() -> None:
"""
result = double(num=increment(num=3))
workflow = construct(result)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
num_param = list(workflow.find_parameters())[0]
hsh_increment = hasher(
("increment", ("num", f"int|:param:{num_param._.unique_name}"))
Expand Down Expand Up @@ -390,7 +394,7 @@ def test_complex_cwl_references() -> None:
"""
result = sum(left=double(num=increment(num=23)), right=mod10(num=increment(num=23)))
workflow = construct(result, simplify_ids=True)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]

assert rendered == yaml.safe_load("""
cwlVersion: 1.2
Expand Down Expand Up @@ -442,7 +446,7 @@ def test_cwl_with_subworkflow_and_raw_params() -> None:
my_param = param("num", typ=int)
result = increment(num=floor(num=triple_and_one(num=sum(left=my_param, right=3))))
workflow = construct(result, simplify_ids=True)
subworkflows = render(workflow)
subworkflows = Renderer.render(workflow)
rendered = subworkflows["__root__"]

del subworkflows["__root__"]
Expand Down Expand Up @@ -548,7 +552,7 @@ def test_tuple_floats() -> None:
"""
result = tuple_float_return()
workflow = construct(result, simplify_ids=True)
rendered = render(workflow)["__root__"]
rendered = Renderer.render(workflow)["__root__"]
assert rendered == yaml.safe_load("""
cwlVersion: 1.2
class: Workflow
Expand Down
Loading

0 comments on commit 20219a0

Please sign in to comment.