Skip to content

Commit

Permalink
✨ feat: support matrix in task config
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Oct 18, 2024
1 parent 4badd52 commit f7430a3
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 17 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ Nanoflow is a simple and efficient workflow framework for Python. It allows you
- GPU resource management for parallel task execution

## Roadmap
- [x] Split commands into command and args to avoid too long
- [ ] Integration with FastAPI for managing workflows as web APIs
- [ ] Enhance TUI, improve task log display, use terminal-like style
- [ ] Support for multiple configuration files or folders
- [ ] Support for passing parameters and matrix
- [x] Split commands into command and args to avoid overly long command strings
- [ ] Support depending on a task that has a matrix

## Installation [![Downloads](https://pepy.tech/badge/nanoflow)](https://pepy.tech/project/nanoflow)

Expand Down
9 changes: 8 additions & 1 deletion examples/matrix.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,12 @@ command = "sleep 1 && echo '{content}:a'"
command = "sleep 2 && echo '{content}:b'"

[tasks.c]
command = "sleep 1 && echo '{content}:c'"
command = "sleep {time} && echo '{content}:c'"

[tasks.c.matrix]
time = ["1", "2", "3"]

[tasks.d]
command = "sleep 1 && echo '{content}:d'"
# You can't depend on a task that has a matrix, such as c.
deps = ["a", "b"]
62 changes: 48 additions & 14 deletions nanoflow/config.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,48 @@
from __future__ import annotations

from collections.abc import Generator
from itertools import product
from typing import Any, Literal

from pydantic import BaseModel


class DefaultDict(dict):
    """Dict that answers a missing key with that key's own ``{key}`` placeholder.

    When passed to ``str.format_map``, placeholders that have no value in the
    mapping are rendered back unchanged instead of raising ``KeyError``.
    """

    def __missing__(self, key: str):
        # Rebuild the literal placeholder text; the key is not stored.
        return "{{{}}}".format(key)


def flatten_matrix(matrix: dict[str, list[str]]) -> Generator[dict[str, str], Any, None]:
    """Yield one template-value mapping per combination of the matrix values.

    The combinations are the Cartesian product of the value lists, taken in
    the matrix's key order.

    Args:
        matrix: Maps each template key to the list of values it may take.

    Yields:
        A ``DefaultDict`` mapping every matrix key to one chosen value. An
        empty matrix yields a single empty mapping; a key with an empty value
        list yields nothing.
    """
    matrix_keys = tuple(matrix)
    # Iterating keys and values directly (instead of unpacking
    # ``zip(*matrix.items())``) avoids a ValueError on an empty matrix:
    # ``product()`` of zero iterables yields exactly one empty tuple.
    for values in product(*matrix.values()):
        yield DefaultDict(zip(matrix_keys, values, strict=True))


class TaskConfig(BaseModel):
"""Task config.
Example:
>>> config = TaskConfig(command="echo", args=["task1"])
>>> config.get_command()
'echo task1'
>>> config = TaskConfig(command="echo {task}", args=["task1"])
>>> config.format({"task": "task1"})
TaskConfig(command='echo task1', args=['task1'], deps=[])
>>> config = TaskConfig(command="echo", args=["{task}"])
>>> config.format({"task": "task1"}).get_command()
'echo task1'
>>> config = TaskConfig(command="echo", args=["{task}"], matrix={"task": ["task1", "task2"]})
>>> for task in config.wrap_matrix():
... print(task.get_command())
echo task1
echo task2
"""

command: str
matrix: dict[str, list[str]] | None = None
args: list[str] = []
deps: list[str] = []

def get_command(self) -> str:
    """Return the command string: ``command`` followed by its ``args``.

    Returns:
        The command and its arguments joined by single spaces. When there
        are no arguments the bare command is returned, without a trailing
        space.

    Raises:
        AssertionError: If ``matrix`` is still set; the task must be expanded
            with ``wrap_matrix`` first.
    """
    if self.matrix is not None:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type is kept.
        raise AssertionError("Matrix is not None, you must run wrap_matrix first")
    return " ".join([self.command, *self.args])

def format(self, template_values: dict[str, str], inplace: bool = False) -> TaskConfig:
Expand All @@ -31,11 +51,20 @@ def format(self, template_values: dict[str, str], inplace: bool = False) -> Task
else:
task = self.model_copy(deep=True)

task.command = task.command.format(**template_values)
task.args = [arg.format(**template_values) for arg in task.args]
task.deps = [dep.format(**template_values) for dep in task.deps]
task.command = task.command.format_map(template_values)
task.args = [arg.format_map(template_values) for arg in task.args]
task.deps = [dep.format_map(template_values) for dep in task.deps]
return task

def wrap_matrix(self) -> list[TaskConfig]:
    """Expand this task into one concrete task per matrix combination.

    Returns:
        One formatted task per combination produced by ``flatten_matrix``,
        in that order, each with ``matrix`` reset to ``None``.

    Raises:
        AssertionError: If this task has no ``matrix``.
    """
    if self.matrix is None:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type is kept.
        raise AssertionError("You cannot run wrap_matrix without matrix")
    tasks: list[TaskConfig] = []
    for template_values in flatten_matrix(self.matrix):
        task = self.format(template_values, inplace=False)
        # Clear the matrix on the expanded task so that get_command()
        # accepts it.
        task.matrix = None
        tasks.append(task)
    return tasks


class WorkflowConfig(BaseModel):
"""
Expand Down Expand Up @@ -74,16 +103,21 @@ def model_post_init(self, __context: Any) -> None:
if self.matrix is None:
return

matrix_keys, matrix_values = zip(*self.matrix.items(), strict=True)
product_matrix = product(*matrix_values)

tasks: dict[str, TaskConfig] = {}
for i, values in enumerate(product_matrix):
template_values = dict(zip(matrix_keys, values, strict=True))
for i, template_values in enumerate(flatten_matrix(self.matrix)):
for task_name, task_config in self.tasks.items():
task = task_config.format(template_values)
task.deps = [f"{i}_{dep}" for dep in task.deps]
tasks[f"{i}_{task_name}"] = task
if task_config.matrix is not None:
wrapped_tasks = task_config.wrap_matrix()
for j, wrapped_task in enumerate(wrapped_tasks):
wrapped_task_name = f"{i}_{task_name.format_map(template_values)}_{j}"
task = wrapped_task.format(template_values, inplace=False)
task.deps = [f"{i}_{dep}" for dep in task.deps]
tasks[wrapped_task_name] = task
else:
task_name = f"{i}_{task_name.format_map(template_values)}"
task = task_config.format(template_values, inplace=False)
task.deps = [f"{i}_{dep}" for dep in task.deps]
tasks[task_name] = task

self.tasks = tasks

Expand Down
2 changes: 1 addition & 1 deletion nanoflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def wrapper_fn():
try:
if self.resource_pool is not None:
resource = await self.resource_pool.acquire()
logger.info(f"Acquired resource: {resource}")
logger.info(f"Acquired resource by task [blue]{self.name}[/blue]: {resource}")
if self.resource_modifier is not None:
fn = self.resource_modifier(self.fn, resource)
else:
Expand Down

0 comments on commit f7430a3

Please sign in to comment.