diff --git a/src/hyfi/pipeline/configs.py b/src/hyfi/pipeline/configs.py index 7b06ba6a..9dd2e9be 100644 --- a/src/hyfi/pipeline/configs.py +++ b/src/hyfi/pipeline/configs.py @@ -16,7 +16,7 @@ class BaseRunConfig(BaseModel): """Run Configuration""" - run_with: Optional[Dict[str, Any]] = {} + run: Optional[Union[str, Dict[str, Any]]] = {} name: Optional[str] = "" desc: Optional[str] = "" @@ -36,8 +36,21 @@ def validate_model_config_before(cls, data): return Composer.replace_special_keys(Composer.to_dict(data)) @property - def kwargs(self): - return self.run_with or {} + def run_config(self) -> Dict[str, Any]: + if self.run and isinstance(self.run, str): + return {"_target_": self.run} + return self.run or {} + + @property + def run_target(self) -> str: + return self.run_config.get("_target_") or "" + + @property + def run_kwargs(self) -> Dict[str, Any]: + _kwargs = self.run_config.copy() + _kwargs.pop("_target_", None) + _kwargs.pop("_partial_", None) + return _kwargs class RunningConfig(BaseRunConfig): @@ -55,7 +68,6 @@ class PipeConfig(BaseRunConfig): """Pipe Configuration""" pipe_target: str = "" - run: str = "" use_pipe_obj: bool = True pipe_obj_arg_name: Optional[str] = "" @@ -75,20 +87,35 @@ def get_pipe_func(self) -> Optional[Callable]: return None def get_run_func(self) -> Optional[Callable]: - if self.run.startswith("lambda"): - logger.info("Returning lambda function: %s", self.run) - return eval(self.run) - elif self.run: - kwargs = self.run_with or {} + run_cfg = self.run_config + run_target = self.run_target + if run_target and run_target.startswith("lambda"): + logger.info("Returning lambda function: %s", run_target) + return eval(run_target) + elif run_cfg: if self.pipe_obj_arg_name: - kwargs.pop(self.pipe_obj_arg_name) + run_cfg.pop(self.pipe_obj_arg_name) logger.info( - "Returning partial function: %s with kwargs: %s", self.run, kwargs + "Returning partial function: %s with kwargs: %s", run_target, run_cfg ) - return Composer.partial(self.run, **kwargs) + return Composer.partial(run_cfg) else: logger.warning("No function found for %s", self) return None + # if self.run.startswith("lambda"): + # logger.info("Returning lambda function: %s", self.run) + # return eval(self.run) + # elif self.run: + # kwargs = self.run_with or {} + # if self.pipe_obj_arg_name: + # kwargs.pop(self.pipe_obj_arg_name) + # logger.info( + # "Returning partial function: %s with kwargs: %s", self.run, kwargs + # ) + # return Composer.partial(self.run, **kwargs) + # else: + # logger.warning("No function found for %s", self) + # return None class DataframePipeConfig(PipeConfig):