From 1498128390190e243108ef570ec35fbf78e0c5c3 Mon Sep 17 00:00:00 2001 From: Young Joon Lee Date: Fri, 30 Jun 2023 09:36:26 +0900 Subject: [PATCH] feat(pipeline): update run_pipeline method --- src/hyfi/pipeline/__init__.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/hyfi/pipeline/__init__.py b/src/hyfi/pipeline/__init__.py index 1929877b..e4c87321 100644 --- a/src/hyfi/pipeline/__init__.py +++ b/src/hyfi/pipeline/__init__.py @@ -22,6 +22,8 @@ class PipelineConfig(BaseRunConfig): """Pipeline Configuration""" steps: Optional[List[Union[str, Dict]]] = [] + initial_object: Optional[Any] = None + use_self_as_initial_object: bool = False @validator("steps", pre=True) def steps_to_list(cls, v): @@ -88,7 +90,7 @@ class PIPELINEs: @staticmethod def run_pipeline( config: Union[Dict, PipelineConfig], - initial_obj: Optional[Any] = None, + initial_object: Optional[Any] = None, task: Optional[TaskConfig] = None, ) -> Any: """ @@ -106,17 +108,19 @@ def run_pipeline( if not isinstance(config, PipelineConfig): config = PipelineConfig(**Composer.to_dict(config)) pipes = config.get_pipes(task) + if initial_object is None and config.initial_object is not None: + initial_object = config.initial_object # Return initial object for the initial object if not pipes: logger.warning("No pipes specified") - return initial_obj + return initial_object logger.info("Applying %s pipes", len(pipes)) # Run the task in the current directory. if task is not None: with change_directory(task.root_dir): - return reduce(PIPELINEs.run_pipe, pipes, initial_obj) - return reduce(PIPELINEs.run_pipe, pipes, initial_obj) + return reduce(PIPELINEs.run_pipe, pipes, initial_object) + return reduce(PIPELINEs.run_pipe, pipes, initial_object) @staticmethod def run_pipe( @@ -225,10 +229,10 @@ def run_task(task: TaskConfig, project: Optional[ProjectConfig] = None): task.project = project # Run all pipelines in the pipeline. for pipeline in PIPELINEs.get_pipelines(task): - # Run the pipeline if verbose is true if task.verbose: logger.info("Running pipeline: %s", pipeline.dict()) - PIPELINEs.run_pipeline(pipeline, task=task) + initial_object = task if pipeline.use_self_as_initial_object else None + PIPELINEs.run_pipeline(pipeline, initial_object, task) @staticmethod def run_workflow(workflow: WorkflowConfig):