Skip to content

Commit

Permalink
Rename "pipelines" to "tasks" (#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashleve authored May 26, 2022
1 parent a17c461 commit 56c4d00
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 58 deletions.
10 changes: 2 additions & 8 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,13 @@
import pyrootutils
from omegaconf import DictConfig

root = pyrootutils.setup_root(
indicator=".git",
search_from=__file__,
pythonpath=True,
cwd=True,
dotenv=True,
)
root = pyrootutils.setup_root(__file__)


@hydra.main(config_path=root / "configs", config_name="eval.yaml")
def main(cfg: DictConfig):

from src.pipelines.eval_pipeline import eval
from src.tasks.eval_task import eval

eval(cfg)

Expand Down
1 change: 0 additions & 1 deletion src/pipelines/__init__.py

This file was deleted.

42 changes: 0 additions & 42 deletions src/pipelines/_pipeline_wrapper.py

This file was deleted.

1 change: 1 addition & 0 deletions src/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from src.tasks._task_wrapper import task_wrapper
47 changes: 47 additions & 0 deletions src/tasks/_task_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import time

from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only

from src import utils

log = utils.get_logger(__name__)


def task_wrapper(task_func):
"""Optional decorator that wraps the task function in extra utilities:
- logging the total time of execution
- enabling repeating task execution on failure
- calling the utils.extras() before the task is started
- calling the utils.finish() after the task is finished
"""

def wrap(cfg: DictConfig):
start = time.time()

# applies optional config utilities
utils.extras(cfg)

# TODO: repeat call if fails...
result = task_func(cfg=cfg)
metric_value, object_dict = result

# make sure everything closed properly
utils.finish(object_dict)

end = time.time()
save_exec_time(task_name=task_func.__name__, time_in_seconds=end - start)

assert isinstance(metric_value, float) or metric_value is None
assert isinstance(object_dict, dict)
return metric_value, object_dict

return wrap


@rank_zero_only
def save_exec_time(task_name, time_in_seconds):
with open("execution_time.log", "w+") as file:
file.write("Total execution time.\n")
file.write(task_name + ": " + str(time_in_seconds) + "\n")
6 changes: 3 additions & 3 deletions src/pipelines/eval_pipeline.py → src/tasks/eval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils
from src.pipelines import pipeline_wrapper
from src.tasks import task_wrapper

log = utils.get_logger(__name__)


@pipeline_wrapper
@task_wrapper
def eval(cfg: DictConfig) -> Tuple[None, Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in @pipeline_wrapper decorator which applies extra utilities
This method is wrapped in @task_wrapper decorator which applies extra utilities
before and after the call.
Args:
Expand Down
6 changes: 3 additions & 3 deletions src/pipelines/train_pipeline.py → src/tasks/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from pytorch_lightning.loggers import LightningLoggerBase

from src import utils
from src.pipelines import pipeline_wrapper
from src.tasks import task_wrapper

log = utils.get_logger(__name__)


@pipeline_wrapper
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Optional[float], Dict[str, Any]]:
"""Trains the model. Can additionally evaluate on a testset, using best weights obtained during
training.
This method is wrapped in @pipeline_wrapper decorator which applies extra utilities
This method is wrapped in @task_wrapper decorator which applies extra utilities
before and after the call.
Args:
Expand Down
2 changes: 1 addition & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main(cfg: DictConfig) -> float:

# imports can be nested inside @hydra.main to optimize tab completion
# https://github.com/facebookresearch/hydra/issues/934
from src.pipelines.train_pipeline import train
from src.tasks.train_task import train

# train model
metric_value, _ = train(cfg)
Expand Down

0 comments on commit 56c4d00

Please sign in to comment.