Skip to content

Commit

Permalink
Merge pull request #79 from Miksus/dev/update_session
Browse files Browse the repository at this point in the history
ENH: Update session
  • Loading branch information
Miksus authored Aug 14, 2022
2 parents 05861b7 + 9168608 commit 18d0cb1
Show file tree
Hide file tree
Showing 42 changed files with 999 additions and 644 deletions.
3 changes: 3 additions & 0 deletions docs/versions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ Version history
- ``2.3.0``

- Add: Cron style scheduling
- Add: Task groups (``Grouper``) to support bigger applications
- Add: New condition, ``TaskRunnable``
- Add: New methods to session (``remove_task`` & ``create_task``)
- Add: ``always`` time period
- Fix: Various bugs related to ``Any``, ``All`` and ``StaticInterval`` time periods
- Fix: Integers as start and end in time periods
- Upd: Now time periods are immutable
- Upd: Now if session is not specified, tasks create new one.

- ``2.2.0``

Expand Down
2 changes: 1 addition & 1 deletion rocketry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
session = Session()
session.set_as_default()

from .application import Rocketry
from .application import Rocketry, Grouper

from . import _version
__version__ = _version.get_versions()['version']
52 changes: 31 additions & 21 deletions rocketry/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,35 @@
from rocketry import Session

class _AppMixin:

session: Session

def task(self, start_cond=None, name=None, *, command=None, path=None, **kwargs):
def task(self, start_cond=None, name=None, **kwargs):
"Create a task"

kwargs['session'] = self.session
kwargs['start_cond'] = start_cond
kwargs['name'] = name

if command is not None:
return CommandTask(command=command, **kwargs)
elif path is not None:
# Non-wrapped FuncTask
return FuncTask(path=path, **kwargs)
else:
return FuncTask(name_include_module=False, _name_template='{func_name}', **kwargs)
return self.session.create_task(start_cond=start_cond, name=name, **kwargs)

def param(self, name:Optional[str]=None):
"Set one session parameter (decorator)"
return FuncParam(name, session=self.session)

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)

def params(self, **kwargs):
"Set session parameters"
self.session.parameters.update(kwargs)

def include_grouper(self, group:'Grouper'):
for task in group.session.tasks:
if group.prefix:
task.name = group.prefix + task.name
if group.start_cond is not None:
task.start_cond = task.start_cond & group.start_cond
task.execution = group.execution if task.execution is None else task.execution

self.session.add_task(task)
self.session.parameters.update(group.session.parameters)

class Rocketry(_AppMixin):
"""Rocketry scheduling application"""
Expand Down Expand Up @@ -66,14 +75,6 @@ async def serve(self, debug=False):
self.session.set_as_default()
await self.session.serve()

def cond(self, syntax: Union[str, Pattern, List[Union[str, Pattern]]]=None):
"Create a condition (decorator)"
return FuncCond(syntax=syntax, session=self.session, decor_return_func=False)

def params(self, **kwargs):
"Set session parameters"
self.session.parameters.update(kwargs)

def set_logger(self):
warnings.warn((
"set_logger is deprecated and will be removed in the future. "
Expand Down Expand Up @@ -103,3 +104,12 @@ def _get_repo(self, repo:str):
return CSVFileRepo(filename=filepath, model=LogRecord)
else:
raise NotImplementedError(f"Repo creation for {repo} not implemented")

class Grouper(_AppMixin):

def __init__(self, prefix:str=None, start_cond=None, execution=None):
self.prefix = prefix
self.start_cond = start_cond
self.execution = execution

self.session = Session()
2 changes: 1 addition & 1 deletion rocketry/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def handle_logs(self):
break
else:
self.logger.debug(f"Inserting record for '{record.task_name}' ({record.action})")
task = self.session.get_task(record.task_name)
task = self.session[record.task_name]
if record.action == "fail":
# There is a caveat in logging
# https://github.com/python/cpython/blame/fad6af2744c0b022568f7f4a8afc93fed056d4db/Lib/logging/handlers.py#L1383
Expand Down
17 changes: 11 additions & 6 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@

_IS_WINDOWS = platform.system()

def _create_session():
# To avoid circular imports
from rocketry import Session
return Session()

class Task(RedBase, BaseModel):
"""Base class for Tasks.
Expand Down Expand Up @@ -116,8 +121,7 @@ class Task(RedBase, BaseModel):
Logger of the task. Typically not needed
to be set.
session : rocketry.session.Session, optional
Session the task is binded to,
by default default session
Session the task is binded to.
Attributes
Expand Down Expand Up @@ -253,7 +257,8 @@ def __init__(self, **kwargs):
hooker.prerun(self)

if kwargs.get("session") is None:
kwargs['session'] = self.session
warnings.warn("Task's session not defined. Creating new.", UserWarning)
kwargs['session'] = _create_session()
kwargs['name'] = self._get_name(**kwargs)

super().__init__(**kwargs)
Expand Down Expand Up @@ -811,7 +816,7 @@ def _lock_to_run_log(self, log_queue):
else:

#self.logger.debug(f"Inserting record for '{record.task_name}' ({record.action})")
task = self.session.get_task(record.task_name)
task = self.session[record.task_name]
task.log_record(record)

action = record.action
Expand Down Expand Up @@ -1065,13 +1070,13 @@ def period(self) -> TimePeriod:
session = self.session

if isinstance(cond, (TaskSucceeded, TaskFinished)):
if session.get_task(cond.kwargs["task"]) is self:
if session[cond.kwargs["task"]] is self:
return cond.period

elif isinstance(cond, All):
task_periods = []
for sub_stmt in cond:
if isinstance(sub_stmt, (TaskFinished, TaskFinished)) and session.get_task(sub_stmt.kwargs["task"]) is self:
if isinstance(sub_stmt, (TaskFinished, TaskFinished)) and session[sub_stmt.kwargs["task"]] is self:
task_periods.append(sub_stmt.period)
if task_periods:
return AllTime(*task_periods)
Expand Down
29 changes: 28 additions & 1 deletion rocketry/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,32 @@ def get_tasks(self) -> list:
return self.tasks

def get_task(self, task):
#! TODO: Do we need this?
warnings.warn((
"Method get_task will be removed in the future version."
"Please use instead: session['task name']"
), DeprecationWarning)
return self[task]

def get_cond_parsers(self):
"Used by the actual string condition parser"
return self._cond_parsers

def create_task(self, *, command=None, path=None, **kwargs):
"Create a task and put it to the session"

# To avoid circular imports
from rocketry.tasks import CommandTask, FuncTask

kwargs['session'] = self

if command is not None:
return CommandTask(command=command, **kwargs)
elif path is not None:
# Non-wrapped FuncTask
return FuncTask(path=path, **kwargs)
else:
return FuncTask(name_include_module=False, _name_template='{func_name}', **kwargs)

def add_task(self, task: 'Task'):
"Add the task to the session"
if_exists = self.config.task_pre_exist
Expand All @@ -359,6 +378,14 @@ def add_task(self, task: 'Task'):
raise KeyError(f"Task '{task.name}' already exists")
else:
self.tasks.add(task)

# Adding the session to the task
task.session = self

def remove_task(self, task: Union['Task', str]):
if isinstance(task, str):
task = self[task]
self.session.tasks.remove(task)

def task_exists(self, task: 'Task'):
warnings.warn((
Expand Down
45 changes: 44 additions & 1 deletion rocketry/test/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,36 @@ def do_daily(arg=Arg('arg_3')):
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_nested_args_from_func_arg():
set_logging_defaults()

# Creating app
app = Rocketry(config={'task_execution': 'main'})

@app.param('arg_1')
def my_arg_1():
return 'arg 1'

def my_func_2(arg=Arg('arg_1')):
assert arg == "arg 1"
return 'arg 2'

def my_func_3(arg_1=Arg('arg_1'), arg_2=FuncArg(my_func_2)):
assert arg_1 == "arg 1"
assert arg_2 == "arg 2"
return 'arg 3'

# Creating a task to test this
@app.task(true)
def do_daily(arg=FuncArg(my_func_3)):
...
assert arg == "arg 3"

app.session.config.shut_cond = TaskStarted(task='do_daily')
app.run()
logger = app.session['do_daily'].logger
assert logger.filter_by(action="success").count() == 1

def test_arg_ref():
set_logging_defaults()

Expand Down Expand Up @@ -203,4 +233,17 @@ def do_never(arg_1):
task_example = session['never done']
assert task_example.execution == 'process'
assert task_example.name == 'never done'
assert dict(task_example.parameters) == {'arg_1': 'something'}
assert dict(task_example.parameters) == {'arg_1': 'something'}


def test_task_name():
set_logging_defaults()

app = Rocketry(config={'task_execution': 'main'})

@app.task()
def do_func():
...
return 'return value'

assert app.session[do_func].name == "do_func"
Loading

0 comments on commit 18d0cb1

Please sign in to comment.