Skip to content

Commit

Permalink
fix(pipeline/configs): handle both string and dict inputs for run config
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Jul 26, 2023
1 parent 65b05a8 commit 793f218
Showing 1 changed file with 39 additions and 12 deletions.
51 changes: 39 additions & 12 deletions src/hyfi/pipeline/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class BaseRunConfig(BaseModel):
"""Run Configuration"""

run_with: Optional[Dict[str, Any]] = {}
run: Optional[Union[str, Dict[str, Any]]] = {}

name: Optional[str] = ""
desc: Optional[str] = ""
Expand All @@ -36,8 +36,21 @@ def validate_model_config_before(cls, data):
return Composer.replace_special_keys(Composer.to_dict(data))

@property
def kwargs(self):
return self.run_with or {}
def run_config(self) -> Dict[str, Any]:
if self.run and isinstance(self.run, str):
return {"_target_": self.run}
return self.run or {}

@property
def run_target(self) -> str:
return self.run_config.get("_target_") or ""

@property
def run_kwargs(self) -> Dict[str, Any]:
_kwargs = self.run_config.copy()
_kwargs.pop("_target_", None)
_kwargs.pop("_partial_", None)
return _kwargs


class RunningConfig(BaseRunConfig):
Expand All @@ -55,7 +68,6 @@ class PipeConfig(BaseRunConfig):
"""Pipe Configuration"""

pipe_target: str = ""
run: str = ""

use_pipe_obj: bool = True
pipe_obj_arg_name: Optional[str] = ""
Expand All @@ -75,20 +87,35 @@ def get_pipe_func(self) -> Optional[Callable]:
return None

def get_run_func(self) -> Optional[Callable]:
if self.run.startswith("lambda"):
logger.info("Returning lambda function: %s", self.run)
return eval(self.run)
elif self.run:
kwargs = self.run_with or {}
run_cfg = self.run_config
run_target = self.run_target
if run_target and run_target.startswith("lambda"):
logger.info("Returning lambda function: %s", run_target)
return eval(run_target)
elif run_cfg:
if self.pipe_obj_arg_name:
kwargs.pop(self.pipe_obj_arg_name)
run_cfg.pop(self.pipe_obj_arg_name)
logger.info(
"Returning partial function: %s with kwargs: %s", self.run, kwargs
"Returning partial function: %s with kwargs: %s", run_target, run_cfg
)
return Composer.partial(self.run, **kwargs)
return Composer.partial(run_cfg)
else:
logger.warning("No function found for %s", self)
return None
# if self.run.startswith("lambda"):
# logger.info("Returning lambda function: %s", self.run)
# return eval(self.run)
# elif self.run:
# kwargs = self.run_with or {}
# if self.pipe_obj_arg_name:
# kwargs.pop(self.pipe_obj_arg_name)
# logger.info(
# "Returning partial function: %s with kwargs: %s", self.run, kwargs
# )
# return Composer.partial(self.run, **kwargs)
# else:
# logger.warning("No function found for %s", self)
# return None


class DataframePipeConfig(PipeConfig):
Expand Down

0 comments on commit 793f218

Please sign in to comment.