Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support for nested parameters and FuncParam reference #56

Merged
merged 7 commits into from
Jul 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/handbooks/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,15 @@ Options
system might take a lot of CPU and if too high, the scheduler might not be accurate.

By default it is set to ``0.1``.

**param_materialize**: When to turn arguments to actual values.

Whether to turn the arguments to actual values before or after
creating threads (for ``execution="thread``) and processes
(for ``execution="process``). Options:

- ``pre``: Before thread/process creation.
- ``post``: After thread/process creation. (default)

Only applicable for some argument types and materialization type
specified in the argument itself overrides configuration setting.
45 changes: 27 additions & 18 deletions rocketry/args/builtin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@

from typing import Any, Callable
from typing import Any, Callable, Optional
import warnings
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from rocketry.core.parameters import BaseArgument
from rocketry.core.parameters import BaseArgument, Parameters
from rocketry.core.utils import filter_keyword_args

class SimpleArg(BaseArgument):
Expand Down Expand Up @@ -32,8 +36,10 @@ class Arg(BaseArgument):
def __init__(self, key:Any):
self.key = key

def get_value(self, task=None, **kwargs) -> Any:
return task.session.parameters[self.key]
def get_value(self, task=None, session=None, **kwargs) -> Any:
if session is None:
session = task.session
return session.parameters._get(self.key, task=task, session=session, **kwargs)

def __repr__(self):
return f'session.parameters[{repr(self.key)}]'
Expand Down Expand Up @@ -159,25 +165,28 @@ def my_task(my_param):
>>> session.parameters
Parameters(myarg1=FuncArg(myarg1), myarg2=FuncArg(myfunc))
"""
def __init__(self, __func:Callable, *args, **kwargs):
def __init__(self, __func:Callable, *args, materialize:Optional[Literal['pre', 'post']]=None, **kwargs):
self.func = __func
self.materialize = materialize
self.args = args
self.kwargs = self._get_kwargs(kwargs)
self.kwargs = kwargs

def _get_kwargs(self, kwargs):
defaults = {
"session": self.session,
}
defaults.update(kwargs)
return filter_keyword_args(self.func, defaults)

def get_value(self, task=None, **kwargs):
return self(task=task)
def get_value(self, **kwargs):
return self(**kwargs)

def __call__(self, **kwargs):
kwargs.update(self.kwargs)
kwargs = filter_keyword_args(self.func, kwargs)
return self.func(*self.args, **kwargs)
param_kwargs = Parameters._from_signature(self.func)
params = param_kwargs.materialize(**kwargs)
return self.func(*self.args, **params, **self.kwargs)

def stage(self, **kwargs):
session = kwargs['session']
materialize = self.materialize if self.materialize is not None else session.config.param_materialize

if materialize == "pre":
return self.get_value(**kwargs)
else:
return self

def __repr__(self):
cls_name = type(self).__name__
Expand Down
2 changes: 1 addition & 1 deletion rocketry/args/secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Private(BaseArgument):
def __init__(self, value):
self.__value = value

def get_value(self, task=None):
def get_value(self, task=None, **kwargs):
if task is None:
return self.string_hidden
else:
Expand Down
10 changes: 8 additions & 2 deletions rocketry/core/parameters/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ def get(self, item, default=None):
except KeyError:
return default

def _get(self, __item, **kwargs):
item = __item
if callable(item) and hasattr(item, "__rocketry__") and "param_name" in item.__rocketry__:
item = item.__rocketry__['param_name']
value = self._params[item]
return value if not isinstance(value, BaseArgument) else value.get_value(**kwargs)

def __iter__(self):
return iter(self._params)

Expand All @@ -76,8 +83,7 @@ def __len__(self):

def __getitem__(self, item):
"Materializes the parameters and hide private"
value = self._params[item]
return value if not isinstance(value, BaseArgument) else value.get_value()
return self._get(item)

def pre_materialize(self, *args, **kwargs):
"""Turn arguments to their values before passed
Expand Down
8 changes: 4 additions & 4 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

try:
from typing import Literal
except ImportError:
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import BaseModel, Field, PrivateAttr, validator
Expand Down Expand Up @@ -449,7 +449,7 @@ async def _run_as_async(self, params:Parameters, direct_params:Parameters, execu
exc_info = (None, None, None)
params = self.postfilter_params(params)
params = Parameters(params) | Parameters(direct_params)
params = params.materialize(task=self)
params = params.materialize(task=self, session=self.session)

if execution in ('main', 'async'):
self.log_running()
Expand Down Expand Up @@ -522,8 +522,8 @@ async def _run_as_async(self, params:Parameters, direct_params:Parameters, execu
def run_as_thread(self, params:Parameters, **kwargs):
"""Create a new thread and run the task on that."""

params = params.pre_materialize(task=self)
direct_params = self.parameters.pre_materialize(task=self)
params = params.pre_materialize(task=self, session=self.session)
direct_params = self.parameters.pre_materialize(task=self, session=self.session)

self._thread_terminate.clear()

Expand Down
1 change: 1 addition & 0 deletions rocketry/parameters/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __call__(self, func: Callable):
session = FuncArg.session if self.session is None else self.session
name = self._get_name(func)
session.parameters[name] = FuncArg(func)
func.__rocketry__ = {'param_name': name}
return func

def _get_name(self, func):
Expand Down
6 changes: 6 additions & 0 deletions rocketry/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from rocketry.log.defaults import create_default_handler
from typing import TYPE_CHECKING, Callable, ClassVar, Iterable, Dict, List, Optional, Set, Tuple, Type, Union, Any
from itertools import chain
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from redbird.logging import RepoHandler
from rocketry._base import RedBase
Expand Down Expand Up @@ -59,6 +63,8 @@ class Config:
timeout: datetime.timedelta = datetime.timedelta(minutes=30)
shut_cond: Optional['BaseCondition'] = None

param_materialize:Literal['pre', 'post'] = 'post'

@validator('shut_cond', pre=True)
def parse_shut_cond(cls, value):
from rocketry.parse import parse_condition
Expand Down
2 changes: 1 addition & 1 deletion rocketry/tasks/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

try:
from typing import Literal
except ImportError:
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import Field, validator
Expand Down
60 changes: 59 additions & 1 deletion rocketry/test/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rocketry import Session
from rocketry.tasks import CommandTask
from rocketry.tasks import FuncTask
from rocketry.conds import false
from rocketry.conds import false, true

def set_logging_defaults():
task_logger = logging.getLogger("rocketry.task")
Expand Down Expand Up @@ -69,6 +69,64 @@ def do_func():

assert app.session['do_never'].start_cond == false

def test_nested_args():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

@app.param('arg_2')
def my_func_2(arg=Arg('arg_1')):
assert arg == "arg 1"
return 'arg 2'

@app.param('arg_3')
def my_func_3(arg_1=Arg('arg_1'), arg_2=Arg("arg_2")):
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"
return 'arg 3'

# Creating a task to test this
@app.task(true)
def do_daily(arg=Arg('arg_3')):
...
assert arg == "arg 3"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_arg_ref():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

@app.param('arg_2')
def my_arg_2():
return 'arg 2'

# Creating a task to test this
@app.task(true)
def do_daily(arg_1=Arg(my_arg_1), arg_2=Arg(my_arg_2)):
...
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_app_async():
set_logging_defaults()

Expand Down
36 changes: 35 additions & 1 deletion rocketry/test/session/params/test_func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import platform
import pytest

from rocketry.args import Private, Return
Expand All @@ -13,9 +13,17 @@
def get_x():
return "x"

def get_y():
return "y"

def func_x_with_arg(myparam):
assert myparam == "x"

def get_with_nested_args(arg = FuncArg(get_y), arg_2 = Arg('session_arg')):
assert arg == "y"
assert arg_2 == "z"
return 'x'

@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_simple(session, execution):

Expand Down Expand Up @@ -51,6 +59,32 @@ def test_session(session, execution):

assert "success" == task.status

@pytest.mark.parametrize("config_mater", ['pre', 'post', None])
@pytest.mark.parametrize("materialize", ['pre', 'post', None])
@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_nested(session, execution, materialize, config_mater):
if config_mater is not None:
session.config.param_materialize = config_mater

session.parameters["session_arg"] = "z"
session.parameters["myparam"] = FuncArg(get_with_nested_args, materialize=materialize)

task = FuncTask(
func_x_with_arg,
execution=execution,
name="a task",
start_cond=AlwaysTrue()
)
session.config.shut_cond = (TaskStarted(task="a task") >= 1)

assert task.status is None
session.start()
if platform.system() == "Windows" and execution == "process" and (materialize == "post" or (materialize is None and config_mater in ("post", None))):
# Windows cannot pickle the session but apparently Linux can
assert "fail" == task.status
else:
assert "success" == task.status

class UnPicklable:
def __getstate__(self):
raise RuntimeError("Cannot be pickled")
Expand Down