Skip to content
This repository has been archived by the owner on Apr 20, 2020. It is now read-only.

[WIP] Distributed ilastik: Adds OpFormattedDataExport.run_distributed_export() #335

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions lazyflow/distributed/TaskOrchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from mpi4py import MPI
from threading import Thread
import enum
from typing import Generator, TypeVar, Generic, Callable, Tuple
import uuid

# Sentinel payload: the orchestrator sends it to a worker to make the worker
# leave its receive loop. The UUID suffix presumably keeps it from colliding
# with a real task datum.
END = "END-eb23ae13-709e-4ac3-931d-99ab059ef0c2"
# Type variable for the per-task payload passed from orchestrator to workers.
TASK_DATUM = TypeVar("TASK_DATUM")


@enum.unique
class Tags(enum.IntEnum):
    """Message tags that distinguish the two directions of orchestrator/worker traffic."""

    # Worker -> orchestrator: the task it was given has finished.
    TASK_DONE = 13
    # Orchestrator -> worker: a task datum (or the END sentinel) is being delivered.
    TO_WORKER = enum.auto()


class _Worker(Generic[TASK_DATUM]):
    """Orchestrator-side handle for a single worker process, addressed by its rank."""

    def __init__(self, comm, rank: int):
        """Store the communicator and the rank this handle sends to.

        Args:
            comm: Communicator object; only its ``send`` method is used here.
            rank: Rank of the worker process this handle represents.
        """
        self.comm, self.rank = comm, rank
        # Becomes True once the END sentinel has been sent to this worker.
        self.stopped = False

    def send(self, datum: TASK_DATUM):
        """Log ``datum`` and deliver it to this worker, tagged ``Tags.TO_WORKER``."""
        print(f"Sending datum {datum} to worker {self.rank}...")
        destination = self.rank
        self.comm.send(datum, dest=destination, tag=Tags.TO_WORKER)

    def stop(self):
        """Send the END sentinel to this worker and record that it was stopped."""
        self.send(END)
        self.stopped = True


class TaskOrchestrator(Generic[TASK_DATUM]):
    """Hand out task data to worker ranks over an MPI communicator.

    Every participating process builds an instance. One process calls
    :meth:`orchestrate` to distribute work to ranks ``1..size-1``; the other
    processes call :meth:`start_as_worker` and process data until they
    receive the ``END`` sentinel.
    """

    def __init__(self, comm=None):
        """Bind to ``comm`` (default ``MPI.COMM_WORLD``) and create the worker handles.

        Args:
            comm: Communicator to use; falls back to ``MPI.COMM_WORLD`` when falsy.

        Raises:
            Exception: If the communicator has no rank besides rank 0,
                i.e. there would be no workers.
        """
        self.comm = comm or MPI.COMM_WORLD
        self.rank = self.comm.Get_rank()
        num_workers = self.comm.size - 1
        if num_workers <= 0:
            raise Exception("Trying to orchestrate tasks with no workers!!!")
        # Workers are ranks 1..size-1; rank 0 is never given a worker handle.
        self.workers = {rank: _Worker(self.comm, rank) for rank in range(1, num_workers + 1)}

    def get_finished_worker(self) -> _Worker[TASK_DATUM]:
        """Block until some worker sends ``TASK_DONE`` and return its handle.

        The message payload (the value returned by the worker's target) is discarded.
        """
        status = MPI.Status()
        self.comm.recv(source=MPI.ANY_SOURCE, tag=Tags.TASK_DONE, status=status)
        return self.workers[status.Get_source()]

    def orchestrate(self, task_data: Generator[TASK_DATUM, None, None]):
        """Distribute every item of ``task_data`` to the workers, then stop them all.

        Each worker first receives one item. Whenever a worker reports completion
        it gets the next item, or the END sentinel once the data is exhausted.
        Workers that never received an item are stopped at the end.

        Args:
            task_data: Source of task payloads; any iterable is accepted.
        """
        print(f"ORCHESTRATOR: Starting orchestration of {len(self.workers)}...")
        tasks = iter(task_data)

        # Initial round: at most one task per worker.
        num_busy_workers = 0
        for worker in self.workers.values():
            try:
                datum = next(tasks)
            except StopIteration:
                break
            worker.send(datum)
            num_busy_workers += 1

        # Only wait for completions while tasks are outstanding. If nothing was
        # handed out (empty task data), waiting for TASK_DONE would block forever.
        while num_busy_workers > 0:
            worker = self.get_finished_worker()
            try:
                datum = next(tasks)
            except StopIteration:
                worker.stop()
                num_busy_workers -= 1
            else:
                worker.send(datum)

        # Release the workers that never received a task.
        for worker in self.workers.values():
            if not worker.stopped:
                worker.stop()

    def start_as_worker(self, target: Callable[[TASK_DATUM, int], None]):
        """Receive task data in a loop and run ``target`` on each one until END arrives.

        After each task, the return value of ``target`` is sent back to the
        sender of that task with the ``TASK_DONE`` tag.

        Args:
            target: Called as ``target(datum, rank)`` for every received datum.
        """
        # NOTE(review): if ``target`` raises, no TASK_DONE is sent for that task;
        # the orchestrator presumably keeps waiting for it - confirm intended handling.
        print(f"WORKER {self.rank}: Started")
        while True:
            status = MPI.Status()
            datum = self.comm.recv(source=MPI.ANY_SOURCE, tag=Tags.TO_WORKER, status=status)
            # Only a string can be the sentinel; checking the type first avoids
            # invoking a custom or element-wise ``==`` on the task datum.
            if isinstance(datum, str) and datum == END:
                break
            self.comm.send(target(datum, self.rank), dest=status.Get_source(), tag=Tags.TASK_DONE)
        print(f"WORKER {self.rank}: Terminated")
4 changes: 4 additions & 0 deletions lazyflow/metaDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import copy
import numpy
from collections import OrderedDict, defaultdict
from ndstructs import Shape5D


class MetaDict(defaultdict):
Expand Down Expand Up @@ -153,6 +154,9 @@ def getTaggedShape(self):
keys = self.getAxisKeys()
return OrderedDict(list(zip(keys, self.shape)))

def getShape5D(self):
    """Return the shape as a ``Shape5D``.

    The mapping produced by ``getTaggedShape()`` is passed to ``Shape5D``
    as keyword arguments.
    """
    tagged_shape = self.getTaggedShape()
    return Shape5D(**tagged_shape)

def getAxisKeys(self):
    """Return the ``key`` of every entry in ``self.axistags`` as a list, in order.

    ``self.axistags`` must not be None (checked with an ``assert``).
    """
    assert self.axistags is not None
    keys = []
    for axistag in self.axistags:
        keys.append(axistag.key)
    return keys
Expand Down
56 changes: 55 additions & 1 deletion lazyflow/operators/ioOperators/opFormattedDataExport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
import collections
import warnings
import numpy
from typing import Tuple
from pathlib import Path

import z5py
from ndstructs import Shape5D, Slice5D

from lazyflow.utility import format_known_keys
from lazyflow.graph import Operator, InputSlot, OutputSlot
from lazyflow.roi import roiFromShape
from lazyflow.operators.generic import OpSubRegion, OpPixelOperator
from lazyflow.operators.valueProviders import OpMetadataInjector
from lazyflow.operators.opReorderAxes import OpReorderAxes
from lazyflow.utility.pathHelpers import PathComponents

from .opExportSlot import OpExportSlot

Expand Down Expand Up @@ -125,7 +131,7 @@ def __init__(self, *args, **kwargs):
self.FormatSelectionErrorMsg.connect(self._opExportSlot.FormatSelectionErrorMsg)
self.progressSignal = self._opExportSlot.progressSignal

def setupOutputs(self):
def get_new_roi(self) -> Tuple[Tuple, Tuple]:
# Prepare subregion operator
total_roi = roiFromShape(self.Input.meta.shape)
total_roi = list(map(tuple, total_roi))
Expand All @@ -150,7 +156,22 @@ def setupOutputs(self):
)

new_start, new_stop = tuple(clipped_start), tuple(clipped_stop)
return new_start, new_stop

def get_cutout(self) -> Slice5D:
    """Return the roi given by ``get_new_roi()`` as a ``Slice5D``.

    Each axis key of ``Input`` is paired, in order, with ``slice(start, stop)``
    taken from the roi; the pairs are passed to ``Slice5D.zero`` as keyword
    arguments.
    """
    axis_keys = self.Input.meta.getAxisKeys()
    roi_start, roi_stop = self.get_new_roi()
    slice_per_axis = {}
    for axis_key, begin, end in zip(axis_keys, roi_start, roi_stop):
        slice_per_axis[axis_key] = slice(begin, end)
    return Slice5D.zero(**slice_per_axis)

def set_cutout(self, cutout: Slice5D):
    """Set the roi of the internal sub-region operator to ``cutout``.

    The start and stop of ``cutout`` are converted to integer tuples ordered
    by the axis keys of ``Input`` and written to ``_opSubRegion.Roi`` as a
    ``(start, stop)`` pair.
    """
    axis_keys = self.Input.meta.getAxisKeys()
    roi_start = cutout.start.to_tuple(axis_keys, int)
    roi_stop = cutout.stop.to_tuple(axis_keys, int)
    roi = (roi_start, roi_stop)
    self._opSubRegion.Roi.setValue(roi)

def setupOutputs(self):
new_start, new_stop = self.get_new_roi()
# If we're in the process of switching input data,
# then the roi dimensionality might not match up.
# Just leave the roi disconnected for now.
Expand Down Expand Up @@ -258,3 +279,36 @@ def run_export(self):

def run_export_to_array(self):
    """Delegate to the wrapped export-slot operator and return its result unchanged."""
    export_operator = self._opExportSlot
    return export_operator.run_export_to_array()

def run_distributed_export(self, block_shape: Shape5D):
    """Export ``ImageToExport`` tile by tile into an N5 dataset, using one process per rank.

    Every rank must call this method. Rank 0 creates the N5 file and dataset
    and then hands out tiles of the cutout to the other ranks. Each other
    rank computes the tiles it receives and writes them into the dataset.

    Args:
        block_shape: Requested tile shape. On rank 0 it is clamped to the
            output shape and is also used as the dataset's chunk shape.
    """
    from lazyflow.distributed.TaskOrchestrator import TaskOrchestrator

    orchestrator = TaskOrchestrator()
    # The output is the configured filename with its suffix replaced by ".n5".
    n5_file_path = Path(self.OutputFilenameFormat.value).with_suffix(".n5")
    output_meta = self.ImageToExport.meta
    if orchestrator.rank == 0:
        output_shape = output_meta.getShape5D()
        block_shape = block_shape.clamped(maximum=output_shape)

        with z5py.File(n5_file_path, "w") as f:
            ds = f.create_dataset(
                self.OutputInternalPath.value,
                shape=output_meta.shape,
                chunks=block_shape.to_tuple(output_meta.getAxisKeys()),
                # numpy.dtype(...) accepts both scalar type classes and dtype
                # instances and yields the canonical name (e.g. "uint8");
                # ``.__name__`` only exists on type classes.
                dtype=numpy.dtype(output_meta.dtype).name,
            )
            # The axis keys are stored in reversed order in the "axes" attribute.
            ds.attrs["axes"] = list(reversed(output_meta.getAxisKeys()))
            ds[...] = 1  # FIXME: for some reason setting to 0 does nothing

        # Workers only open the file after receiving a tile, so the dataset
        # above already exists by the time they write to it.
        cutout = self.get_cutout()
        orchestrator.orchestrate(cutout.split(block_shape=block_shape))
    else:

        def process_tile(tile: Slice5D, rank: int):
            """Compute one tile of ``ImageToExport`` and write it into the N5 dataset."""
            # Restrict the export region to this tile before reading the slot value.
            self.set_cutout(tile)
            # NOTE(review): ``tile`` comes from get_cutout(), which uses the axis
            # keys of ``Input``, while the slices are ordered by the axis keys of
            # ``ImageToExport`` - confirm both coordinate systems agree.
            slices = tile.to_slices(output_meta.getAxisKeys())
            with z5py.File(n5_file_path, "r+") as n5_file:
                dataset = n5_file[self.OutputInternalPath.value]
                dataset[slices] = self.ImageToExport.value

        orchestrator.start_as_worker(process_tile)