diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py
index 8b1407d23..f607b1a67 100644
--- a/pydra/engine/tests/test_specs.py
+++ b/pydra/engine/tests/test_specs.py
@@ -15,7 +15,6 @@ from pydra.engine.core import Workflow
 from pydra.engine.node import Node
 from pydra.engine.submitter import Submitter, NodeExecution, DiGraph
-from pydra.utils.typing import StateArray
 from pydra.design import python, workflow
 from .utils import Foo, FunAddTwo, FunAddVar, ListSum
 
diff --git a/pydra/engine/tests/test_submitter.py b/pydra/engine/tests/test_submitter.py
index 909d1c064..ddce586cd 100644
--- a/pydra/engine/tests/test_submitter.py
+++ b/pydra/engine/tests/test_submitter.py
@@ -5,10 +5,10 @@ import time
 import attrs
 import typing as ty
-from random import randint
 import os
 from unittest.mock import patch
 import pytest
+from pydra.design import workflow
 from fileformats.generic import Directory
 from .utils import (
     need_sge,
@@ -23,65 +23,67 @@ from pydra.design import python
 from pathlib import Path
 from datetime import datetime
+from pydra.engine.specs import Result
 
 
 @python.define
-def sleep_add_one(x):
+def SleepAddOne(x):
     time.sleep(1)
     return x + 1
 
 
 def test_callable_wf(plugin, tmpdir):
-    wf = BasicWorkflow()
-    res = wf()
-    assert res.output.out == 9
-    del wf, res
+    wf = BasicWorkflow(x=5)
+    outputs = wf(cache_dir=tmpdir)
+    assert outputs.out == 9
+    del wf, outputs
 
     # providing plugin
-    wf = BasicWorkflow()
-    res = wf(plugin="cf")
-    assert res.output.out == 9
-    del wf, res
+    wf = BasicWorkflow(x=5)
+    outputs = wf(worker="cf")
+    assert outputs.out == 9
+    del wf, outputs
 
     # providing plugin_kwargs
-    wf = BasicWorkflow()
-    res = wf(plugin="cf", plugin_kwargs={"n_procs": 2})
-    assert res.output.out == 9
-    del wf, res
+    wf = BasicWorkflow(x=5)
+    outputs = wf(worker="cf", n_procs=2)
+    assert outputs.out == 9
+    del wf, outputs
 
     # providing wrong plugin_kwargs
-    wf = BasicWorkflow()
+    wf = BasicWorkflow(x=5)
     with pytest.raises(TypeError, match="an unexpected keyword argument"):
-        wf(plugin="cf", plugin_kwargs={"sbatch_args": "-N2"})
+        wf(worker="cf", sbatch_args="-N2")
 
     # providing submitter
-    wf = BasicWorkflow()
-    wf.cache_dir = tmpdir
-    sub = Submitter(plugin)
-    res = wf(submitter=sub)
-    assert res.output.out == 9
+    wf = BasicWorkflow(x=5)
+
+    with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
+        res = sub(wf)
+    assert res.outputs.out == 9
 
 
 def test_concurrent_wf(plugin, tmpdir):
     # concurrent workflow
     # A --> C
     # B --> D
-    wf = Workflow("new_wf", input_spec=["x", "y"])
-    wf.inputs.x = 5
-    wf.inputs.y = 10
-    wf.add(sleep_add_one(name="taska", x=wf.lzin.x))
-    wf.add(sleep_add_one(name="taskb", x=wf.lzin.y))
-    wf.add(sleep_add_one(name="taskc", x=wf.taska.lzout.out))
-    wf.add(sleep_add_one(name="taskd", x=wf.taskb.lzout.out))
-    wf.set_output([("out1", wf.taskc.lzout.out), ("out2", wf.taskd.lzout.out)])
-    wf.cache_dir = tmpdir
-
-    with Submitter(plugin) as sub:
-        sub(wf)
+    @workflow.define(outputs=["out1", "out2"])
+    def Workflow(x, y):
+        taska = workflow.add(SleepAddOne(x=x), name="taska")
+        taskb = workflow.add(SleepAddOne(x=y), name="taskb")
+        taskc = workflow.add(SleepAddOne(x=taska.out), name="taskc")
+        taskd = workflow.add(SleepAddOne(x=taskb.out), name="taskd")
+        return taskc.out, taskd.out
+
+    wf = Workflow(x=5, y=10)
+
+    with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
+        results = sub(wf)
 
-    res = wf.result()
-    assert res.output.out1 == 7
-    assert res.output.out2 == 12
+    assert not results.errored, " ".join(results.errors["error message"])
+    outputs = results.outputs
+    assert outputs.out1 == 7
+    assert outputs.out2 == 12
 
 
 def test_concurrent_wf_nprocs(tmpdir):
@@ -89,49 +91,49 @@
     # setting n_procs in Submitter that is passed to the worker
     # A --> C
     # B --> D
-    wf = Workflow("new_wf", input_spec=["x", "y"])
-    wf.inputs.x = 5
-    wf.inputs.y = 10
-    wf.add(sleep_add_one(name="taska", x=wf.lzin.x))
-    wf.add(sleep_add_one(name="taskb", x=wf.lzin.y))
-    wf.add(sleep_add_one(name="taskc", x=wf.taska.lzout.out))
-    wf.add(sleep_add_one(name="taskd", x=wf.taskb.lzout.out))
-    wf.set_output([("out1", wf.taskc.lzout.out), ("out2", wf.taskd.lzout.out)])
-    wf.cache_dir = tmpdir
-
-    with Submitter("cf", n_procs=2) as sub:
-        sub(wf)
+    @workflow.define(outputs=["out1", "out2"])
+    def Workflow(x, y):
+        taska = workflow.add(SleepAddOne(x=x), name="taska")
+        taskb = workflow.add(SleepAddOne(x=y), name="taskb")
+        taskc = workflow.add(SleepAddOne(x=taska.out), name="taskc")
+        taskd = workflow.add(SleepAddOne(x=taskb.out), name="taskd")
+        return taskc.out, taskd.out
+
+    wf = Workflow(x=5, y=10)
+    with Submitter(worker="cf", n_procs=2, cache_dir=tmpdir) as sub:
+        res = sub(wf)
 
-    res = wf.result()
-    assert res.output.out1 == 7
-    assert res.output.out2 == 12
+    assert not res.errored, " ".join(res.errors["error message"])
+    outputs = res.outputs
+    assert outputs.out1 == 7
+    assert outputs.out2 == 12
 
 
 def test_wf_in_wf(plugin, tmpdir):
     """WF(A --> SUBWF(A --> B) --> B)"""
-    wf = Workflow(name="wf_in_wf", input_spec=["x"])
-    wf.inputs.x = 3
-    wf.add(sleep_add_one(name="wf_a", x=wf.lzin.x))  # workflow task
-    subwf = Workflow(name="sub_wf", input_spec=["x"])
-    subwf.add(sleep_add_one(name="sub_a", x=subwf.lzin.x))
-    subwf.add(sleep_add_one(name="sub_b", x=subwf.sub_a.lzout.out))
-    subwf.set_output([("out", subwf.sub_b.lzout.out)])
-    # connect, then add
-    subwf.inputs.x = wf.wf_a.lzout.out
-    subwf.cache_dir = tmpdir
-
-    wf.add(subwf)
-    wf.add(sleep_add_one(name="wf_b", x=wf.sub_wf.lzout.out))
-    wf.set_output([("out", wf.wf_b.lzout.out)])
-    wf.cache_dir = tmpdir
-
-    with Submitter(plugin) as sub:
-        sub(wf)
+    @workflow.define
+    def SubWf(x):
+        sub_a = workflow.add(SleepAddOne(x=x), name="sub_a")
+        sub_b = workflow.add(SleepAddOne(x=sub_a.out), name="sub_b")
+        return sub_b.out
 
-    res = wf.result()
-    assert res.output.out == 7
+    @workflow.define
+    def WfInWf(x):
+        a = workflow.add(SleepAddOne(x=x), name="a")
+        subwf = workflow.add(SubWf(x=a.out), name="subwf")
+        b = workflow.add(SleepAddOne(x=subwf.out), name="b")
+        return b.out
+
+    wf = WfInWf(x=3)
+
+    with Submitter(worker=plugin, cache_dir=tmpdir) as sub:
+        results = sub(wf)
+
+    assert not results.errored, " ".join(results.errors["error message"])
+    outputs = results.outputs
+    assert outputs.out == 7
 
 
 @pytest.mark.flaky(reruns=2)  # when dask
@@ -139,60 +141,59 @@ def test_wf2(plugin_dask_opt, tmpdir):
     """workflow as a node
     workflow-node with one task and no splitter
     """
-    wfnd = Workflow(name="wfnd", input_spec=["x"])
-    wfnd.add(sleep_add_one(name="add2", x=wfnd.lzin.x))
-    wfnd.set_output([("out", wfnd.add2.lzout.out)])
-    wfnd.inputs.x = 2
-    wf = Workflow(name="wf", input_spec=["x"])
-    wf.add(wfnd)
-    wf.set_output([("out", wf.wfnd.lzout.out)])
-    wf.cache_dir = tmpdir
+    @workflow.define
+    def Wfnd(x):
+        add2 = workflow.add(SleepAddOne(x=x))
+        return add2.out
 
-    with Submitter(worker=plugin_dask_opt) as sub:
-        sub(wf)
+    @workflow.define
+    def Workflow(x):
+        wfnd = workflow.add(Wfnd(x=x))
+        return wfnd.out
+
+    wf = Workflow(x=2)
+
+    with Submitter(worker=plugin_dask_opt, cache_dir=tmpdir) as sub:
+        res = sub(wf)
 
-    res = wf.result()
-    assert res.output.out == 3
+    assert res.outputs.out == 3
 
 
 @pytest.mark.flaky(reruns=2)  # when dask
 def test_wf_with_state(plugin_dask_opt, tmpdir):
-    wf = Workflow(name="wf_with_state", input_spec=["x"])
-    wf.add(sleep_add_one(name="taska", x=wf.lzin.x))
-    wf.add(sleep_add_one(name="taskb", x=wf.taska.lzout.out))
+    @workflow.define
+    def Workflow(x):
+        taska = workflow.add(SleepAddOne(x=x), name="taska")
+        taskb = workflow.add(SleepAddOne(x=taska.out), name="taskb")
+        return taskb.out
 
-    wf.split("x", x=[1, 2, 3])
-    wf.set_output([("out", wf.taskb.lzout.out)])
-    wf.cache_dir = tmpdir
+    wf = Workflow().split(x=[1, 2, 3])
 
-    with Submitter(worker=plugin_dask_opt) as sub:
-        sub(wf)
-
-    res = wf.result()
+    with Submitter(cache_dir=tmpdir, worker=plugin_dask_opt) as sub:
+        res = sub(wf)
 
-    assert res[0].output.out == 3
-    assert res[1].output.out == 4
-    assert res[2].output.out == 5
+    assert res.outputs.out[0] == 3
+    assert res.outputs.out[1] == 4
+    assert res.outputs.out[2] == 5
 
 
-def test_serial_wf():
+def test_debug_wf():
     # Use serial plugin to execute workflow instead of CF
-    wf = BasicWorkflow()
-    res = wf(plugin="serial")
-    assert res.output.out == 9
+    wf = BasicWorkflow(x=5)
+    outputs = wf(worker="debug")
+    assert outputs.out == 9
 
 
 @need_slurm
 def test_slurm_wf(tmpdir):
-    wf = BasicWorkflow()
-    wf.cache_dir = tmpdir
+    wf = BasicWorkflow(x=1)
     # submit workflow and every task as slurm job
-    with Submitter("slurm") as sub:
-        sub(wf)
+    with Submitter(worker="slurm", cache_dir=tmpdir) as sub:
+        res = sub(wf)
 
-    res = wf.result()
-    assert res.output.out == 9
+    outputs = res.outputs
+    assert outputs.out == 9
     script_dir = tmpdir / "SlurmWorker_scripts"
     assert script_dir.exists()
     # ensure each task was executed with slurm
@@ -202,13 +203,11 @@
 @need_slurm
 def test_slurm_wf_cf(tmpdir):
     # submit entire workflow as single job executing with cf worker
-    wf = BasicWorkflow()
-    wf.cache_dir = tmpdir
-    wf.plugin = "cf"
-    with Submitter("slurm") as sub:
-        sub(wf)
-    res = wf.result()
-    assert res.output.out == 9
+    wf = BasicWorkflow(x=1)
+    with Submitter(worker="slurm", cache_dir=tmpdir) as sub:
+        res = sub(wf)
+    outputs = res.outputs
+    assert outputs.out == 9
     script_dir = tmpdir / "SlurmWorker_scripts"
     assert script_dir.exists()
     # ensure only workflow was executed with slurm
@@ -220,14 +219,12 @@
 @need_slurm
 def test_slurm_wf_state(tmpdir):
-    wf = BasicWorkflow()
-    wf.split("x", x=[5, 6])
-    wf.cache_dir = tmpdir
-    with Submitter("slurm") as sub:
-        sub(wf)
-    res = wf.result()
-    assert res[0].output.out == 9
-    assert res[1].output.out == 10
+    wf = BasicWorkflow(x=1).split(x=[5, 6])
+    with Submitter(worker="slurm", cache_dir=tmpdir) as sub:
+        res = sub(wf)
+
+    assert res.outputs.out[0] == 9
+    assert res.outputs.out[1] == 10
     script_dir = tmpdir / "SlurmWorker_scripts"
     assert script_dir.exists()
     sdirs = [sd for sd in script_dir.listdir() if sd.isdir()]
@@ -237,16 +234,18 @@
 @need_slurm
 @pytest.mark.flaky(reruns=3)
 def test_slurm_max_jobs(tmpdir):
-    wf = Workflow("new_wf", input_spec=["x", "y"], cache_dir=tmpdir)
-    wf.inputs.x = 5
-    wf.inputs.y = 10
-    wf.add(sleep_add_one(name="taska", x=wf.lzin.x))
-    wf.add(sleep_add_one(name="taskb", x=wf.lzin.y))
-    wf.add(sleep_add_one(name="taskc", x=wf.taska.lzout.out))
-    wf.add(sleep_add_one(name="taskd", x=wf.taskb.lzout.out))
-    wf.set_output([("out1", wf.taskc.lzout.out), ("out2", wf.taskd.lzout.out)])
-    with Submitter("slurm", max_jobs=1) as sub:
-        sub(wf)
+    @workflow.define(outputs=["out1", "out2"])
+    def Workflow(x, y):
+        taska = workflow.add(SleepAddOne(x=x))
+        taskb = workflow.add(SleepAddOne(x=y))
+        taskc = workflow.add(SleepAddOne(x=taska.out))
+        taskd = workflow.add(SleepAddOne(x=taskb.out))
+        return taskc.out, taskd.out
+
+    wf = Workflow(x=5, y=10)
+
+    with Submitter(worker="slurm", cache_dir=tmpdir, max_jobs=1) as sub:
+        res = sub(wf)
 
     jobids = []
     time.sleep(0.5)  # allow time for sacct to collect itself
@@ -277,14 +276,12 @@
 @need_slurm
 def test_slurm_args_1(tmpdir):
     """testing sbatch_args provided to the submitter"""
-    task = sleep_add_one(x=1)
-    task.cache_dir = tmpdir
+    task = SleepAddOne(x=1)
     # submit workflow and every task as slurm job
-    with Submitter("slurm", sbatch_args="-N1") as sub:
-        sub(task)
+    with Submitter(worker="slurm", cache_dir=tmpdir, sbatch_args="-N1") as sub:
+        res = sub(task)
 
-    res = task.result()
-    assert res.output.out == 2
+    assert res.outputs.out == 2
     script_dir = tmpdir / "SlurmWorker_scripts"
     assert script_dir.exists()
 
@@ -294,11 +291,12 @@ def test_slurm_args_2(tmpdir):
     """testing sbatch_args provided to the submitter
     exception should be raised for invalid options
     """
-    task = sleep_add_one(x=1)
-    task.cache_dir = tmpdir
+    task = SleepAddOne(x=1)
     # submit workflow and every task as slurm job
     with pytest.raises(RuntimeError, match="Error returned from sbatch:"):
-        with Submitter("slurm", sbatch_args="-N1 --invalid") as sub:
+        with Submitter(
+            worker="slurm", cache_dir=tmpdir, sbatch_args="-N1 --invalid"
+        ) as sub:
             sub(task)
 
 
@@ -347,26 +345,23 @@ def test_slurm_cancel_rerun_1(tmpdir):
    """testing that tasks run with slurm are re-queued
    Running wf with 2 tasks, one sleeps and the other trying to get
    job_id of the first task and cancel it.
    The first job should be re-queued and finish without problem.
    (possibly has to be improved, in theory cancel job might finish before cancel)
    """
-    wf = Workflow(
-        name="wf",
-        input_spec=["x", "job_name_cancel", "job_name_resqueue"],
-        cache_dir=tmpdir,
-    )
-    wf.add(sleep(name="sleep1", x=wf.lzin.x, job_name_part=wf.lzin.job_name_cancel))
-    wf.add(cancel(name="cancel1", job_name_part=wf.lzin.job_name_resqueue))
-    wf.inputs.x = 10
-    wf.inputs.job_name_resqueue = "sleep1"
-    wf.inputs.job_name_cancel = "cancel1"
-
-    wf.set_output([("out", wf.sleep1.lzout.out), ("canc_out", wf.cancel1.lzout.out)])
-    with Submitter("slurm") as sub:
-        sub(wf)
-    res = wf.result()
-    assert res.output.out == 10
+    @workflow.define(outputs=["out", "canc_out"])
+    def Workflow(x, job_name_cancel, job_name_resqueue):
+        sleep1 = workflow.add(sleep(x=x, job_name_part=job_name_cancel))
+        cancel1 = workflow.add(cancel(job_name_part=job_name_resqueue))
+        return sleep1.out, cancel1.out
+
+    wf = Workflow(x=10, job_name_resqueue="sleep1", job_name_cancel="cancel1")
+
+    with Submitter(worker="slurm", cache_dir=tmpdir) as sub:
+        res = sub(wf)
+
+    outputs = res.outputs
+    assert outputs.out == 10
     # checking if indeed the sleep-task job was cancelled by cancel-task
-    assert "Terminating" in res.output.canc_out
-    assert "Invalid" not in res.output.canc_out
+    assert "Terminating" in outputs.canc_out
+    assert "Invalid" not in outputs.canc_out
     script_dir = tmpdir / "SlurmWorker_scripts"
     assert script_dir.exists()
 
@@ -379,32 +374,32 @@ def test_slurm_cancel_rerun_2(tmpdir):
    """testing that tasks run with slurm that has --no-requeue
    Running wf with 2 tasks, one sleeps and the other gets
    job_id of the first task and cancel it.
    The first job is not able to be rescheduled and the error is returned.
""" - wf = Workflow(name="wf", input_spec=["x", "job_name"], cache_dir=tmpdir) - wf.add(sleep(name="sleep2", x=wf.lzin.x)) - wf.add(cancel(name="cancel2", job_name_part=wf.lzin.job_name)) - wf.inputs.x = 10 - wf.inputs.job_name = "sleep2" + @workflow.define(outputs=["out", "canc_out"]) + def Workflow(x, job_name): + sleep2 = workflow.add(sleep(x=x)) + cancel2 = workflow.add(cancel(job_name_part=job_name)) + return sleep2.out, cancel2.out + + wf = Workflow(x=10, job_name="sleep2") - wf.set_output([("out", wf.sleep2.lzout.out), ("canc_out", wf.cancel2.lzout.out)]) with pytest.raises(Exception): - with Submitter("slurm", sbatch_args="--no-requeue") as sub: + with Submitter( + worker="slurm", cache_dir=tmpdir, sbatch_args="--no-requeue" + ) as sub: sub(wf) @need_sge def test_sge_wf(tmpdir): """testing that a basic workflow can be run with the SGEWorker""" - wf = BasicWorkflow() - wf.cache_dir = tmpdir + wf = BasicWorkflow(x=1) # submit workflow and every task as sge job - with Submitter( - "sge", - ) as sub: - sub(wf) + with Submitter(worker="sge", cache_dir=tmpdir) as sub: + res = sub(wf) - res = wf.result() - assert res.output.out == 9 + outputs = res.outputs + assert outputs.out == 9 script_dir = tmpdir / "SGEWorker_scripts" assert script_dir.exists() # ensure each task was executed with sge @@ -412,18 +407,16 @@ def test_sge_wf(tmpdir): @need_sge -def test_sge_wf_cf(tmpdir): +def test_sge_wf_cf(tmp_path): """testing the SGEWorker can submit SGE tasks while the workflow uses the concurrent futures plugin""" # submit entire workflow as single job executing with cf worker - wf = BasicWorkflow() - wf.cache_dir = tmpdir - wf.plugin = "cf" - with Submitter("sge") as sub: - sub(wf) - res = wf.result() - assert res.output.out == 9 - script_dir = tmpdir / "SGEWorker_scripts" + wf = BasicWorkflow(x=1) + with Submitter(worker="sge", cache_dir=tmp_path) as sub: + res = sub(wf) + outputs = res.outputs + assert outputs.out == 9 + script_dir = tmp_path / "SGEWorker_scripts" assert script_dir.exists() # ensure only workflow was executed with slurm sdirs = [sd for sd in script_dir.listdir() if sd.isdir()] @@ -435,15 +428,11 @@ def test_sge_wf_cf(tmpdir): @need_sge def test_sge_wf_state(tmpdir): """testing the SGEWorker can be used with a workflow with state""" - wf = BasicWorkflow() - wf.split("x") - wf.inputs.x = [5, 6] - wf.cache_dir = tmpdir - with Submitter("sge") as sub: - sub(wf) - res = wf.result() - assert res[0].output.out == 9 - assert res[1].output.out == 10 + wf = BasicWorkflow().split(x=[5, 6]) + with Submitter(worker="sge", cache_dir=tmpdir) as sub: + res = sub(wf) + assert res.output.out[0] == 9 + assert res.output.out[1] == 10 script_dir = tmpdir / "SGEWorker_scripts" assert script_dir.exists() sdirs = [sd for sd in script_dir.listdir() if sd.isdir()] @@ -469,12 +458,10 @@ def qacct_output_to_dict(qacct_output): def test_sge_set_threadcount(tmpdir): """testing the number of threads for an SGEWorker task can be set using the input_spec variable sgeThreads""" - wf = BasicWorkflowWithThreadCount() - wf.inputs.x = 5 - wf.cache_dir = tmpdir + wf = BasicWorkflowWithThreadCount(x=5) jobids = [] - with Submitter("sge") as sub: + with Submitter(worker="sge", cache_dir=tmpdir) as sub: sub(wf) jobids = list(sub.worker.jobid_by_task_uid.values()) jobids.sort() @@ -499,13 +486,10 @@ def test_sge_set_threadcount(tmpdir): def test_sge_limit_maxthreads(tmpdir): """testing the ability to limit the number of threads used by the SGE at one time with the max_threads argument to SGEWorker""" - wf = 
-    wf = BasicWorkflowWithThreadCountConcurrent()
-    wf.inputs.x = [5, 6]
-    wf.split("x")
-    wf.cache_dir = tmpdir
+    wf = BasicWorkflowWithThreadCountConcurrent().split(x=[5, 6])
     jobids = []
-    with Submitter("sge", max_threads=8) as sub:
+    with Submitter(worker="sge", max_threads=8, cache_dir=tmpdir) as sub:
         sub(wf)
         jobids = list(sub.worker.jobid_by_task_uid.values())
         jobids.sort()
@@ -543,13 +527,10 @@ def test_sge_no_limit_maxthreads(tmpdir):
     """testing unlimited threads can be used at once by SGE
     when max_threads is not set"""
-    wf = BasicWorkflowWithThreadCountConcurrent()
-    wf.inputs.x = [5, 6]
-    wf.split("x")
-    wf.cache_dir = tmpdir
+    wf = BasicWorkflowWithThreadCountConcurrent().split(x=[5, 6])
     jobids = []
-    with Submitter("sge", max_threads=None) as sub:
+    with Submitter(worker="sge", max_threads=None, cache_dir=tmpdir) as sub:
         sub(wf)
         jobids = list(sub.worker.jobid_by_task_uid.values())
         jobids.sort()
@@ -587,7 +568,7 @@ def output_dir_as_input(out_dir: Directory) -> Directory:
     task = output_dir_as_input(out_dir=tmp_path)
     with pytest.raises(RuntimeError, match="Input field hashes have changed"):
-        task()
+        task(cache_dir=tmp_path)
 
 
 def test_hash_changes_in_task_inputs_unstable(tmp_path):
@@ -605,69 +586,28 @@ def unstable_input(unstable: Unstable) -> int:
     task = unstable_input(unstable=Unstable(1))
     with pytest.raises(RuntimeError, match="Input field hashes have changed"):
-        task()
+        task(cache_dir=tmp_path)
 
 
 def test_hash_changes_in_workflow_inputs(tmp_path):
     @python.define
-    def output_dir_as_output(out_dir: Path) -> Directory:
+    def OutputDirAsOutput(out_dir: Path) -> Directory:
         (out_dir / "new-file.txt").touch()
         return out_dir
 
-    wf = Workflow(
-        name="test_hash_change", input_spec={"in_dir": Directory}, in_dir=tmp_path
-    )
-    wf.add(output_dir_as_output(out_dir=wf.lzin.in_dir, name="task"))
-    wf.set_output(("out_dir", wf.task.lzout.out))
-    with pytest.raises(RuntimeError, match="Input field hashes have changed.*Workflow"):
-        wf()
-
-
-def test_hash_changes_in_workflow_graph(tmpdir):
-    class X:
-        """Dummy class with unstable hash (i.e. which isn't altered in a node in which
-        it is an input)"""
-
-        value = 1
-
-        def __bytes_repr__(self, cache):
-            """Bytes representation from class attribute, which will be changed be
-            'alter_x" node.
+    @workflow.define(outputs=["out_dir"])
+    def Workflow(in_dir: Directory):
+        task = workflow.add(OutputDirAsOutput(out_dir=in_dir), name="task")
+        return task.out
-            NB: this is a contrived example where the bytes_repr implementation returns
-            a bytes representation of a class attribute in order to trigger the exception,
-            hopefully cases like this will be very rare"""
-            yield bytes(self.value)
+    in_dir = tmp_path / "in_dir"
+    in_dir.mkdir()
+    cache_dir = tmp_path / "cache_dir"
+    cache_dir.mkdir()
-    @python.define
-    @mark.annotate({"return": {"x": X, "y": int}})
-    def identity(x: X) -> ty.Tuple[X, int]:
-        return x, 99
-
-    @python.define
-    def alter_x(y):
-        X.value = 2
-        return y
-
-    @python.define
-    def to_tuple(x, y):
-        return (x, y)
-
-    wf = Workflow(name="wf_with_blocked_tasks", input_spec=["x", "y"])
-    wf.add(identity(name="taska", x=wf.lzin.x))
-    wf.add(alter_x(name="taskb", y=wf.taska.lzout.y))
-    wf.add(to_tuple(name="taskc", x=wf.taska.lzout.x, y=wf.taskb.lzout.out))
-    wf.set_output([("out", wf.taskc.lzout.out)])
-
-    wf.inputs.x = X()
-
-    wf.cache_dir = tmpdir
-
-    with pytest.raises(
-        RuntimeError, match="Graph of 'wf_with_blocked_tasks' workflow is not empty"
-    ):
-        with Submitter("cf") as sub:
-            result = sub(wf)
+    wf = Workflow(in_dir=in_dir)
+    with pytest.raises(RuntimeError, match="Input field hashes have changed.*"):
+        wf(cache_dir=cache_dir)
 
 
 @python.define
@@ -684,36 +624,38 @@ def __init__(self, add_var, **kwargs):
         super().__init__(**kwargs)
         self.add_var = add_var
 
-    async def exec_serial(self, runnable, rerun=False, environment=None):
-        if isinstance(runnable, Task):
-            with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
-                result = runnable._run(rerun, environment=environment)
-            return result
-        else:  # it could be tuple that includes pickle files with tasks and inputs
-            return super().exec_serial(runnable, rerun, environment)
+    def run(
+        self,
+        task: "Task",
+        rerun: bool = False,
+    ) -> "Result":
+        with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
+            return super().run(task, rerun)
 
 
 @python.define
-def add_env_var_task(x: int) -> int:
+def AddEnvVarTask(x: int) -> int:
     return x + int(os.environ.get("BYO_ADD_VAR", 0))
 
 
-def test_byo_worker():
+def test_byo_worker(tmp_path):
 
-    task1 = add_env_var_task(x=1)
+    task1 = AddEnvVarTask(x=1)
 
-    with Submitter(worker=BYOAddVarWorker, add_var=10) as sub:
-        assert sub.plugin == "byo_add_env_var"
+    with Submitter(worker=BYOAddVarWorker, add_var=10, cache_dir=tmp_path) as sub:
+        assert sub.worker_name == "byo_add_env_var"
         result = sub(task1)
 
-    assert outputs.out == 11
+    assert result.outputs.out == 11
+
+    task2 = AddEnvVarTask(x=2)
 
-    task2 = add_env_var_task(x=2)
+    new_cache_dir = tmp_path / "new"
 
-    with Submitter(worker="serial") as sub:
+    with Submitter(worker="debug", cache_dir=new_cache_dir) as sub:
         result = sub(task2)
 
-    assert outputs.out == 2
+    assert result.outputs.out == 2
 
 
 def test_bad_builtin_worker():