Skip to content

Commit

Permalink
♻️ refactor: format task name based on its own matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Oct 22, 2024
1 parent c5b8d87 commit 687e65b
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions nanoflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class TaskConfig(BaseModel):
>>> config.get_command()
'echo task1'
>>> config = TaskConfig(command="echo", args=["{task}"], matrix={"task": ["task1", "task2"]})
>>> for task in config.wrap_matrix():
... print(task.get_command())
echo task1
echo task2
>>> for name, task in config.wrap_matrix("task:{task}").items():
... print(name, task.get_command())
task:task1 echo task1
task:task2 echo task2
"""

command: str
Expand All @@ -50,24 +50,30 @@ def get_command(self) -> str:
assert self.matrix is None, "Matrix is not None, you must run wrap_matrix first"
return f"{self.command} {' '.join(self.args)}"

def format(self, template_values: dict[str, str], inplace: bool = False) -> TaskConfig:
def format(
self, template_values: dict[str, str], *, format_deps: bool = False, inplace: bool = False
) -> TaskConfig:
if inplace:
task = self
else:
task = self.model_copy(deep=True)

task.command = task.command.format_map(template_values)
task.args = [arg.format_map(template_values) for arg in task.args]
task.deps = [dep.format_map(template_values) for dep in task.deps]
if format_deps:
task.deps = [dep.format_map(template_values) for dep in task.deps]
return task

def wrap_matrix(self) -> list[TaskConfig]:
def wrap_matrix(self, name: str) -> dict[str, TaskConfig]:
assert self.matrix is not None, "You cannot run wrap_matrix without matrix"
tasks: list[TaskConfig] = []
for template_values in flatten_matrix(self.matrix):
tasks: dict[str, TaskConfig] = {}
for i, template_values in enumerate(flatten_matrix(self.matrix)):
task = self.format(template_values, inplace=False)
task.matrix = None
tasks.append(task)
task_name = name.format(**template_values)
if task_name in tasks:
task_name = f"{task_name}_{i}"
tasks[task_name] = task
return tasks


Expand Down Expand Up @@ -100,8 +106,8 @@ class WorkflowConfig(BaseModel):
"""

name: str
matrix: dict[str, list[str]] | None = None
tasks: dict[str, TaskConfig]
matrix: dict[str, list[str]] | None = None
resources: Literal["gpus"] | list[str] | None = None

def model_post_init(self, __context: Any) -> None:
Expand All @@ -112,14 +118,14 @@ def model_post_init(self, __context: Any) -> None:
for i, template_values in enumerate(flatten_matrix(self.matrix)):
for task_name, task_config in self.tasks.items():
if task_config.matrix is not None:
wrapped_tasks = task_config.wrap_matrix()
for j, wrapped_task in enumerate(wrapped_tasks):
wrapped_task_name = f"{i}_{task_name.format_map(template_values)}_{j}"
wrapped_tasks = task_config.wrap_matrix(task_name)
for wrapped_name, wrapped_task in wrapped_tasks.items():
wrapped_task_name = f"{i}_{wrapped_name}"
task = wrapped_task.format(template_values, inplace=False)
task.deps = [f"{i}_{dep}" for dep in task.deps]
tasks[wrapped_task_name] = task
else:
task_name = f"{i}_{task_name.format_map(template_values)}"
task_name = f"{i}_{task_name}"
task = task_config.format(template_values, inplace=False)
task.deps = [f"{i}_{dep}" for dep in task.deps]
tasks[task_name] = task
Expand Down

0 comments on commit 687e65b

Please sign in to comment.