Skip to content

Commit

Permalink
extensive unit tests for Flow and Job update_metadata callback_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Nov 18, 2024
1 parent dc23d2e commit aa4a02f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 15 deletions.
73 changes: 73 additions & 0 deletions tests/core/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,8 @@ def test_set_output():


def test_update_metadata():
from jobflow import Flow, Job

# test no filter
flow = get_test_flow()
flow.update_metadata({"b": 5})
Expand All @@ -841,6 +843,77 @@ def test_update_metadata():
assert "b" not in flow[0].metadata
assert flow[1].metadata["b"] == 8

# test callback filter
flow = get_test_flow()
# Only update jobs with metadata containing "b"
flow.update_metadata(
{"c": 10}, callback_filter=lambda x: isinstance(x, Job) and "b" in x.metadata
)
assert "c" not in flow[0].metadata
assert flow[1].metadata["c"] == 10
assert "c" not in flow.metadata # Flow itself shouldn't be updated

# Test callback filter on Flow only
flow = get_test_flow()
flow.update_metadata(
{"d": 15}, callback_filter=lambda x: isinstance(x, Flow) and x.name == "Flow"
)
assert flow.metadata["d"] == 15
assert "d" not in flow[0].metadata
assert "d" not in flow[1].metadata

# Test callback filter with multiple conditions and nested structure
from dataclasses import dataclass

from jobflow import Maker, job

@dataclass
class TestMaker(Maker):
name: str = "test_maker"

@job
def make(self):
return Job(lambda: None, name="inner_job")

maker = TestMaker()
inner_flow = Flow([maker.make()], name="inner")
outer_flow = Flow([inner_flow], name="outer")

# Update only flows named "inner" and their jobs
outer_flow.update_metadata(
{"e": 20},
callback_filter=lambda x: (isinstance(x, Flow) and x.name == "inner")
or (isinstance(x, Job) and x.name == "inner_job"),
)
assert "e" not in outer_flow.metadata
assert inner_flow.metadata["e"] == 20

# Test callback filter with dynamic updates
flow = get_test_flow()
flow.update_metadata(
{"f": 25},
callback_filter=lambda x: isinstance(x, Job) and x.name.startswith("div"),
dynamic=True,
)
assert "f" not in flow.metadata
assert "f" not in flow[0].metadata
assert flow[1].metadata["f"] == 25
assert any(
update.get("callback_filter") is not None for update in flow[1].metadata_updates
)

# Test callback filter with maker type checking
flow = get_maker_flow()
flow.update_metadata(
{"g": 30},
callback_filter=lambda x: (
isinstance(x, Job) and x.maker is not None and x.maker.name == "div"
),
)
assert "g" not in flow.metadata
assert "g" not in flow[0].metadata
assert flow[1].metadata["g"] == 30


def test_flow_metadata_initialization():
from jobflow import Flow
Expand Down
92 changes: 77 additions & 15 deletions tests/core/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,32 +1096,94 @@ def jsm_wrapped(a, b):
test_job.update_metadata({"b": 5}, function_filter=A.jsm_wrapped)
assert test_job.metadata["b"] == 5

# test dict mod
# test callback filter with complex conditions
test_job = Job(add, function_args=(1,))
test_job.metadata = {"b": 2}
test_job.update_metadata({"_inc": {"b": 5}}, dict_mod=True)
assert test_job.metadata["b"] == 7
test_job.metadata = {"x": 1, "y": 2}
test_job.name = "test_name"

# Test multiple metadata keys
test_job.update_metadata(
{"z": 3},
callback_filter=lambda job: (
all(key in job.metadata for key in ["x", "y"])
and job.name == "test_name"
and isinstance(job.function_args[0], int)
),
)
assert test_job.metadata["z"] == 3

# test applied dynamic updates
# Test callback filter with no match due to complex condition
test_job = Job(add, function_args=(1,))
test_job.metadata = {"x": 1}
test_job.name = "test_name"
test_job.update_metadata(
{"z": 3},
callback_filter=lambda job: (
all(key in job.metadata for key in ["x", "y"]) and job.name == "test_name"
),
)
assert "z" not in test_job.metadata

# Test callback filter with function argument inspection
test_job = Job(add, function_args=(1, 2))
test_job.update_metadata(
{"w": 4},
callback_filter=lambda job: (
len(job.function_args) == 2
and all(isinstance(arg, int) for arg in job.function_args)
),
)
assert test_job.metadata["w"] == 4

# Test callback filter with maker attributes
@dataclass
class TestMaker(Maker):
name = "test"
class SpecialMaker(Maker):
name: str = "special"
value: int = 42

@job
def make(self, a, b):
return a + b
def make(self):
return 1

maker = SpecialMaker()
test_job = maker.make()
test_job.update_metadata(
{"v": 5},
callback_filter=lambda job: (job.maker is not None and job.maker.value == 42),
)
assert test_job.metadata["v"] == 5

# Test callback filter with dynamic updates and complex conditions
@job
def use_maker(maker):
return Response(replace=maker.make())

test_job = use_maker(TestMaker())
test_job.name = "use"
test_job.update_metadata({"b": 2}, name_filter="test")
assert "b" not in test_job.metadata
test_job = use_maker(SpecialMaker())
test_job.update_metadata(
{"u": 6},
callback_filter=lambda job: (
hasattr(job, "maker") and getattr(job.maker, "name", "") == "special"
),
dynamic=True,
)
response = test_job.run(memory_jobstore)
assert response.replace[0].metadata["b"] == 2
assert response.replace[0].metadata_updates[0]["update"] == {"b": 2}
assert "u" not in test_job.metadata # Original job shouldn't match
assert response.replace[0].metadata["u"] == 6 # But replacement should
assert any(
"callback_filter" in update and update["update"].get("u") == 6
for update in response.replace[0].metadata_updates
)

# Test callback filter with function inspection
def has_specific_signature(job):
import inspect

sig = inspect.signature(job.function)
return len(sig.parameters) == 2 and "b" in sig.parameters

test_job = Job(add, function_args=(1,))
test_job.update_metadata({"t": 7}, callback_filter=has_specific_signature)
assert test_job.metadata["t"] == 7


def test_update_config(memory_jobstore):
Expand Down

0 comments on commit aa4a02f

Please sign in to comment.