diff --git a/src/hyfi/workflow/__init__.py b/src/hyfi/workflow/__init__.py index 75dd1b1a..60ca4996 100644 --- a/src/hyfi/workflow/__init__.py +++ b/src/hyfi/workflow/__init__.py @@ -1,6 +1,8 @@ from typing import Dict, List, Optional, Union -from hyfi.composer import BaseConfig +from pydantic import BaseModel, ConfigDict, model_validator + +from hyfi.composer import Composer from hyfi.project import ProjectConfig from hyfi.task import TaskConfig from hyfi.utils.logging import LOGGING @@ -10,16 +12,24 @@ Tasks = List[TaskConfig] -class WorkflowConfig(BaseConfig): - _config_name_: str = "__init__" - _config_group_: str = "workflow" - +class WorkflowConfig(BaseModel): project: Optional[ProjectConfig] = None tasks: Optional[List[Union[str, Dict]]] = [] + verbose: bool = False + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="allow", + validate_assignment=False, + ) # type: ignore + + @model_validator(mode="before") + def validate_model_config_before(cls, data): + # logger.debug("Validating model config before validating each field.") + return Composer.to_dict(data) def get_tasks(self) -> Tasks: self.tasks = self.tasks or [] - tasks: Tasks = [ TaskConfig(**getattr(self, name)) for name in self.tasks