Move gen_taskio from .map to .submit (for all fps to be stored in the local storage)
annshress committed Feb 6, 2024
1 parent babc445 commit bf29a5d
Showing 3 changed files with 33 additions and 38 deletions.
10 changes: 5 additions & 5 deletions em_workflows/lrg_2d_rgb/flow.py
@@ -3,7 +3,7 @@

 import SimpleITK as sitk
 from pytools import HedwigZarrImage, HedwigZarrImages
-from prefect import flow, task, unmapped
+from prefect import flow, task
 
 from em_workflows.utils import utils
 from em_workflows.utils import neuroglancer as ng
@@ -208,10 +208,10 @@ def lrg_2d_flow(
         VALID_LRG_2D_RGB_INPUTS,
         single_file=x_file_name,
     )
-    fps = gen_taskio.map(
-        share_name=unmapped(file_share),
-        input_dir=unmapped(input_dir_fp),
-        fp_in=input_fps.result(),
+    fps = gen_taskio.submit(
+        share_name=file_share,
+        input_dir=input_dir_fp,
+        input_fps=input_fps,
     )
     tiffs = convert_png_to_tiff.map(taskio=fps)
     zarrs = gen_zarr.map(taskio=tiffs)
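
Why the call shape changes here: `.map` creates one task run per element of `input_fps`, so flow-level constants must be wrapped in `unmapped` and every mapped run persists its own result, while `.submit` creates a single run that receives the whole list and leaves behind exactly one stored result. A minimal sketch of the two styles, assuming Prefect 2.x; the task names are illustrative, not the repository's:

from typing import List

from prefect import flow, task, unmapped


@task
def per_file(share: str, fp: str) -> str:
    # .map style: this body runs once per file
    return f"{share}:{fp}"


@task
def all_files(share: str, fps: List[str]) -> List[str]:
    # .submit style: one run, one result holding every path
    return [f"{share}:{fp}" for fp in fps]


@flow
def demo(fps: List[str]):
    mapped = per_file.map(unmapped("share"), fps)  # N runs, N results
    batched = all_files.submit("share", fps)       # 1 run, 1 result
    # Downstream .map calls (like convert_png_to_tiff.map above) can still
    # fan out over `batched`: Prefect resolves the future to its list first.
    return mapped, batched
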
53 changes: 24 additions & 29 deletions em_workflows/utils/task_io.py
@@ -40,19 +40,16 @@ def taskio_handler(func):
     does not pass into downstream tasks
     """
 
-    def wrapper(**kwargs):
-        assert (
-            "taskio" in kwargs
-        ), "Task functions must have `taskio` keyword argument in their definition to use `taskio_handler` definition."
-        prev_taskio: TaskIO = kwargs["taskio"]
+    def wrapper(taskio):
+        prev_taskio: TaskIO = taskio
 
         if prev_taskio.error:
             return prev_taskio
 
         try:
-            new_taskio = func(**kwargs)
+            new_taskio = func(taskio)
         except RuntimeError as e:
-            # We are currently handling only ValueError.
-            # So any other exception will cause pipeline to fail
+            # We are raising errors that are instances of RuntimeError
             new_taskio = TaskIO(
                 file_path=prev_taskio.file_path,
                 output_path=None,
@@ -66,39 +63,40 @@ def wrapper(**kwargs):
error="Something went wrong!",
)
new_taskio.file_path = prev_taskio.file_path

# if we want to save history of upstream tasks
# new_taskio.upstream_history = prev_taskio.history
# new_taskio.upstream_history[func.__name__] = new_taskio

return new_taskio

return wrapper


@task
def gen_response(fps: List[TaskIO], taskios: List[TaskIO]):
def gen_response(fps: List[TaskIO], assets: List[TaskIO]):
# turning a list to dict with primary filepath as the key
etl_items = {
etl_item.file_path.fp_in: etl_item.file_path.gen_prim_fp_elt()
for etl_item in fps
}

for taskio in taskios:
print(f"\n---\nTaskIO being processed for {taskio.file_path.fp_in}\n\n***")
etl_item = etl_items[taskio.file_path.fp_in]
# if error is already registered... ignore
for asset in assets:
print(f"\n---\nTaskIO being processed for {asset.file_path.fp_in}\n\n***")
etl_item = etl_items[asset.file_path.fp_in]

# if error is already registered from previous asset... ignore
if etl_item["status"] == "error":
continue

if taskio.error:
if asset.error:
etl_item["status"] = "error"
etl_item["message"] = taskio.error
etl_item["message"] = asset.error
etl_item["imageSet"] = None
else:
if isinstance(taskio.data, list):
etl_item["imageSet"][0]["assets"].extend(taskio.data)
elif isinstance(taskio.data, dict):
etl_item["imageSet"][0]["assets"].append(taskio.data)
if isinstance(asset.data, list):
etl_item["imageSet"][0]["assets"].extend(asset.data)
elif isinstance(asset.data, dict):
etl_item["imageSet"][0]["assets"].append(asset.data)

resp = list(etl_items.values())
return resp
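
The `taskio_handler` wrapper above is an error-channel pattern: a step that raises a `RuntimeError` records the failure on its `TaskIO` instead of crashing the flow, every later step passes an errored `TaskIO` through untouched, and `gen_response` reports the per-file outcome at the end. A standalone sketch of the same idea, with `Payload` as a hypothetical stand-in for `TaskIO`:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Payload:  # hypothetical stand-in for TaskIO
    value: Optional[int] = None
    error: Optional[str] = None


def error_channel(func):
    def wrapper(payload: Payload) -> Payload:
        if payload.error:          # an upstream step already failed: pass through
            return payload
        try:
            return func(payload)
        except RuntimeError as e:  # expected failures travel in-band
            return Payload(error=str(e))
    return wrapper


@error_channel
def halve(p: Payload) -> Payload:
    if p.value % 2:
        raise RuntimeError(f"{p.value} is odd")
    return Payload(value=p.value // 2)


print(halve(Payload(value=4)))         # Payload(value=2, error=None)
print(halve(halve(Payload(value=6))))  # second call fails in-band: error='3 is odd'
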
@@ -111,12 +109,9 @@ def gen_response(fps: List[TaskIO], taskios: List[TaskIO]):
     result_serializer=Config.pickle_serializer,
     result_storage_key="{flow_run.id}__gen_fps",
 )
-def gen_taskio(share_name: str, input_dir: Path, fp_in: Path) -> TaskIO:
-    file_path = FilePath(share_name=share_name, input_dir=input_dir, fp_in=fp_in)
-    return TaskIO(file_path=file_path, output_path=file_path)
-
-
-@task
-def gen_prim_fps(taskio: TaskIO) -> Dict:
-    base_elt = taskio.file_path.gen_prim_fp_elt()
-    return base_elt
+def gen_taskio(share_name: str, input_dir: Path, input_fps: List[Path]) -> List[TaskIO]:
+    result = list()
+    for fp_in in input_fps:
+        file_path = FilePath(share_name=share_name, input_dir=input_dir, fp_in=fp_in)
+        result.append(TaskIO(file_path=file_path, output_path=file_path))
+    return result
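
With `gen_taskio` now a single task run, the `result_storage_key="{flow_run.id}__gen_fps"` on its decorator maps each flow run to exactly one pickled file holding the complete list of `TaskIO` objects, which is what the hook in utils.py reads back. A minimal sketch of that pairing, assuming Prefect 2.x; the `LocalFileSystem` basepath and task body are illustrative:

from typing import List

from prefect import flow, task
from prefect.filesystems import LocalFileSystem

storage = LocalFileSystem(basepath="/tmp/prefect-results")  # assumed location


@task(
    persist_result=True,
    result_storage=storage,
    result_storage_key="{flow_run.id}__gen_fps",  # one key per flow run
)
def gen_items(paths: List[str]) -> List[str]:
    # one run -> one stored file under the key above, holding the whole list
    return sorted(paths)


@flow
def demo():
    return gen_items.submit(["b.png", "a.png"])
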
8 changes: 4 additions & 4 deletions em_workflows/utils/utils.py
@@ -682,16 +682,16 @@ def send_callback_body(

 def copy_workdirs_and_cleanup_hook(flow, flow_run, state):
     stored_result = Config.local_storage.read_path(f"{flow_run.id}__gen_fps")
-    fps: List[FilePath] = Config.pickle_serializer.loads(
+    taskio_fps = Config.pickle_serializer.loads(
         json.loads(stored_result)["data"].encode()
     )
     parameters = flow_run.parameters
-    x_keep_workdir = parameters["x_keep_workdir"]
+    x_keep_workdir = parameters.get("x_keep_workdir", False)
 
-    for fp in fps:
+    for fp in taskio_fps:
         copy_workdir_logs.fn(file_path=fp)
 
-    cleanup_workdir.fn(fps, x_keep_workdir)
+    cleanup_workdir.fn(taskio_fps, x_keep_workdir)
 
 
 def callback_with_cleanup(
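
The hook unpacks the stored result by hand: `read_path` returns a JSON envelope whose "data" field carries the serialized payload. A round-trip sketch of that format using Prefect's `PickleSerializer`; the envelope layout mirrors the hook above and is an assumption about the on-disk format, not a stable API:

import json

from prefect.serializers import PickleSerializer

serializer = PickleSerializer()
payload = serializer.dumps(["fp1", "fp2"])         # base64-encoded pickle bytes
envelope = json.dumps({"data": payload.decode()})  # roughly what read_path returns

restored = serializer.loads(json.loads(envelope)["data"].encode())
print(restored)  # ['fp1', 'fp2']

The switch to `parameters.get("x_keep_workdir", False)` in the same hunk also keeps the hook from raising a KeyError on flow runs that never received that parameter.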