Skip to content

Commit

Permalink
Merge pull request #56 from Miksus/fix/nested_params
Browse files Browse the repository at this point in the history
fix: support for nested parameters and FuncParam reference
  • Loading branch information
Miksus authored Jul 23, 2022
2 parents 7599e82 + b0c9d59 commit 91fe241
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 28 deletions.
12 changes: 12 additions & 0 deletions docs/handbooks/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,15 @@ Options
system might take a lot of CPU and if too high, the scheduler might not be accurate.

By default it is set to ``0.1``.

**param_materialize**: When to turn arguments to actual values.

Whether to turn the arguments to actual values before or after
creating threads (for ``execution="thread``) and processes
(for ``execution="process``). Options:

- ``pre``: Before thread/process creation.
- ``post``: After thread/process creation. (default)

Only applicable for some argument types and materialization type
specified in the argument itself overrides configuration setting.
45 changes: 27 additions & 18 deletions rocketry/args/builtin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@

from typing import Any, Callable
from typing import Any, Callable, Optional
import warnings
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from rocketry.core.parameters import BaseArgument
from rocketry.core.parameters import BaseArgument, Parameters
from rocketry.core.utils import filter_keyword_args

class SimpleArg(BaseArgument):
Expand Down Expand Up @@ -32,8 +36,10 @@ class Arg(BaseArgument):
def __init__(self, key:Any):
self.key = key

def get_value(self, task=None, **kwargs) -> Any:
return task.session.parameters[self.key]
def get_value(self, task=None, session=None, **kwargs) -> Any:
if session is None:
session = task.session
return session.parameters._get(self.key, task=task, session=session, **kwargs)

def __repr__(self):
return f'session.parameters[{repr(self.key)}]'
Expand Down Expand Up @@ -159,25 +165,28 @@ def my_task(my_param):
>>> session.parameters
Parameters(myarg1=FuncArg(myarg1), myarg2=FuncArg(myfunc))
"""
def __init__(self, __func:Callable, *args, **kwargs):
def __init__(self, __func:Callable, *args, materialize:Optional[Literal['pre', 'post']]=None, **kwargs):
self.func = __func
self.materialize = materialize
self.args = args
self.kwargs = self._get_kwargs(kwargs)
self.kwargs = kwargs

def _get_kwargs(self, kwargs):
defaults = {
"session": self.session,
}
defaults.update(kwargs)
return filter_keyword_args(self.func, defaults)

def get_value(self, task=None, **kwargs):
return self(task=task)
def get_value(self, **kwargs):
return self(**kwargs)

def __call__(self, **kwargs):
kwargs.update(self.kwargs)
kwargs = filter_keyword_args(self.func, kwargs)
return self.func(*self.args, **kwargs)
param_kwargs = Parameters._from_signature(self.func)
params = param_kwargs.materialize(**kwargs)
return self.func(*self.args, **params, **self.kwargs)

def stage(self, **kwargs):
session = kwargs['session']
materialize = self.materialize if self.materialize is not None else session.config.param_materialize

if materialize == "pre":
return self.get_value(**kwargs)
else:
return self

def __repr__(self):
cls_name = type(self).__name__
Expand Down
2 changes: 1 addition & 1 deletion rocketry/args/secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Private(BaseArgument):
def __init__(self, value):
self.__value = value

def get_value(self, task=None):
def get_value(self, task=None, **kwargs):
if task is None:
return self.string_hidden
else:
Expand Down
10 changes: 8 additions & 2 deletions rocketry/core/parameters/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ def get(self, item, default=None):
except KeyError:
return default

def _get(self, __item, **kwargs):
item = __item
if callable(item) and hasattr(item, "__rocketry__") and "param_name" in item.__rocketry__:
item = item.__rocketry__['param_name']
value = self._params[item]
return value if not isinstance(value, BaseArgument) else value.get_value(**kwargs)

def __iter__(self):
return iter(self._params)

Expand All @@ -76,8 +83,7 @@ def __len__(self):

def __getitem__(self, item):
"Materializes the parameters and hide private"
value = self._params[item]
return value if not isinstance(value, BaseArgument) else value.get_value()
return self._get(item)

def pre_materialize(self, *args, **kwargs):
"""Turn arguments to their values before passed
Expand Down
8 changes: 4 additions & 4 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

try:
from typing import Literal
except ImportError:
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import BaseModel, Field, PrivateAttr, validator
Expand Down Expand Up @@ -449,7 +449,7 @@ async def _run_as_async(self, params:Parameters, direct_params:Parameters, execu
exc_info = (None, None, None)
params = self.postfilter_params(params)
params = Parameters(params) | Parameters(direct_params)
params = params.materialize(task=self)
params = params.materialize(task=self, session=self.session)

if execution in ('main', 'async'):
self.log_running()
Expand Down Expand Up @@ -522,8 +522,8 @@ async def _run_as_async(self, params:Parameters, direct_params:Parameters, execu
def run_as_thread(self, params:Parameters, **kwargs):
"""Create a new thread and run the task on that."""

params = params.pre_materialize(task=self)
direct_params = self.parameters.pre_materialize(task=self)
params = params.pre_materialize(task=self, session=self.session)
direct_params = self.parameters.pre_materialize(task=self, session=self.session)

self._thread_terminate.clear()

Expand Down
1 change: 1 addition & 0 deletions rocketry/parameters/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __call__(self, func: Callable):
session = FuncArg.session if self.session is None else self.session
name = self._get_name(func)
session.parameters[name] = FuncArg(func)
func.__rocketry__ = {'param_name': name}
return func

def _get_name(self, func):
Expand Down
6 changes: 6 additions & 0 deletions rocketry/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from rocketry.log.defaults import create_default_handler
from typing import TYPE_CHECKING, Callable, ClassVar, Iterable, Dict, List, Optional, Set, Tuple, Type, Union, Any
from itertools import chain
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from redbird.logging import RepoHandler
from rocketry._base import RedBase
Expand Down Expand Up @@ -59,6 +63,8 @@ class Config:
timeout: datetime.timedelta = datetime.timedelta(minutes=30)
shut_cond: Optional['BaseCondition'] = None

param_materialize:Literal['pre', 'post'] = 'post'

@validator('shut_cond', pre=True)
def parse_shut_cond(cls, value):
from rocketry.parse import parse_condition
Expand Down
2 changes: 1 addition & 1 deletion rocketry/tasks/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

try:
from typing import Literal
except ImportError:
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import Field, validator
Expand Down
60 changes: 59 additions & 1 deletion rocketry/test/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rocketry import Session
from rocketry.tasks import CommandTask
from rocketry.tasks import FuncTask
from rocketry.conds import false
from rocketry.conds import false, true

def set_logging_defaults():
task_logger = logging.getLogger("rocketry.task")
Expand Down Expand Up @@ -69,6 +69,64 @@ def do_func():

assert app.session['do_never'].start_cond == false

def test_nested_args():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

@app.param('arg_2')
def my_func_2(arg=Arg('arg_1')):
assert arg == "arg 1"
return 'arg 2'

@app.param('arg_3')
def my_func_3(arg_1=Arg('arg_1'), arg_2=Arg("arg_2")):
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"
return 'arg 3'

# Creating a task to test this
@app.task(true)
def do_daily(arg=Arg('arg_3')):
...
assert arg == "arg 3"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_arg_ref():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

@app.param('arg_2')
def my_arg_2():
return 'arg 2'

# Creating a task to test this
@app.task(true)
def do_daily(arg_1=Arg(my_arg_1), arg_2=Arg(my_arg_2)):
...
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_app_async():
set_logging_defaults()

Expand Down
36 changes: 35 additions & 1 deletion rocketry/test/session/params/test_func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

import platform
import pytest

from rocketry.args import Private, Return
Expand All @@ -13,9 +13,17 @@
def get_x():
return "x"

def get_y():
return "y"

def func_x_with_arg(myparam):
assert myparam == "x"

def get_with_nested_args(arg = FuncArg(get_y), arg_2 = Arg('session_arg')):
assert arg == "y"
assert arg_2 == "z"
return 'x'

@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_simple(session, execution):

Expand Down Expand Up @@ -51,6 +59,32 @@ def test_session(session, execution):

assert "success" == task.status

@pytest.mark.parametrize("config_mater", ['pre', 'post', None])
@pytest.mark.parametrize("materialize", ['pre', 'post', None])
@pytest.mark.parametrize("execution", ["main", "thread", "process"])
def test_nested(session, execution, materialize, config_mater):
if config_mater is not None:
session.config.param_materialize = config_mater

session.parameters["session_arg"] = "z"
session.parameters["myparam"] = FuncArg(get_with_nested_args, materialize=materialize)

task = FuncTask(
func_x_with_arg,
execution=execution,
name="a task",
start_cond=AlwaysTrue()
)
session.config.shut_cond = (TaskStarted(task="a task") >= 1)

assert task.status is None
session.start()
if platform.system() == "Windows" and execution == "process" and (materialize == "post" or (materialize is None and config_mater in ("post", None))):
# Windows cannot pickle the session but apparently Linux can
assert "fail" == task.status
else:
assert "success" == task.status

class UnPicklable:
def __getstate__(self):
raise RuntimeError("Cannot be pickled")
Expand Down

0 comments on commit 91fe241

Please sign in to comment.