Skip to content

Commit

Permalink
Fix unique IDs per block and label visualizaiton
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Jun 30, 2024
1 parent d57599a commit 357f616
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions examples/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,9 @@ def smooth_in_block_dask(x):
plt.imshow(smoothed.transpose((1, 2, 0)), origin="lower")

# %% [markdown]
# For some tasks, dask is much simpler and better suited. One key difference between dask and daisy is that in dask, functions are not supposed to have side effects. In daisy, functions are expected to have "side effects" - saving the output, rather than returning it. In general, daisy is designed for...
# - Cases where the output is too big to be kept in memory
# For many situations, both dask and daisy can work well. Indeed, for some tasks, dask is simpler and better suited, as it does for you many things that daisy leaves to you to implement. One key difference between dask and daisy is that in dask, functions are not supposed to have side effects. In daisy, functions can have side effects, allowing blocks to depend on other blocks in the scheduling order (see the last two examples in this tutorial about task chaining and read-write conflicts).
#
# In general, daisy is designed for...
# - Cases where you want to be able to pick up where you left off after an error, rather than starting the whole task over (because blocks that finished saved their results to disk)
# - Cases where blocks should be executed in a particular order, so that certain blocks see other blocks outputs (without passing the output through the scheduler)
# - Cases where the worker function needs setup and teardown that takes longer than processing a block (see our next example!)
Expand Down Expand Up @@ -597,7 +598,7 @@ def smooth_in_block_with_failure(block: daisy.Block):
# Debugging multi-process code is inherently difficult, but daisy tries to provide as much information as possible. First, you see the progress bar, which also reports the number of blocks at each state, including failed blocks. Any worker error messages are also logged to the scheduler log, although not the full traceback. Upon completion, daisy provides an error summary, which informs you of the final status of all the blocks, and points you to the full output and error logs for each worker, which can be found in `daisy_logs/<task_name>`. The worker error log will contain the full traceback for debugging the exact source of an error.

# %% [markdown]
# You may have noticed that while we coded the task to fail 50% of the time, much more than 50% of the blocks succeeded. This is because daisy by default retries each block 3 times (on different workers) before marking it as failed, to deal gracefully with random error. We will re-run this example, but set max_retries=1 to see the effect of this parameter.
# You may have noticed that while we coded the task to fail 50% of the time, much more than 50% of the blocks succeeded. This is because daisy by default retries each block 3 times (on different workers) before marking it as failed, to deal gracefully with random error. We will re-run this example, but set max_retries=0 to see the effect of this parameter.

# %%
# delete and re-create the dataset, so that we start from zeros again
Expand All @@ -613,7 +614,7 @@ def smooth_in_block_with_failure(block: daisy.Block):
read_roi=read_roi,
write_roi=block_roi,
read_write_conflict=False,
max_retries=1,
max_retries=0,
num_workers=5,
)
]
Expand All @@ -624,7 +625,7 @@ def smooth_in_block_with_failure(block: daisy.Block):


# %% [markdown]
# We still have greater than 50% success! This is because if a specific worker fails multiple times, daisy will assume something might have gone wrong with that worker. The scheduler will shut down and restart the worker, and retry the blocks that failed on that worker. So daisy is very robust to random error.
# Now we should see a success rate closer to what we would expect. You might notice in the logs a message like ```Worker hostname=...:port=..:task_id=fault tolerance test:worker_id=... failed too many times, restarting this worker...```, which shows another way that daisy is robust to pseudo-random errors. If a specific worker fails multiple times, daisy will assume something might have gone wrong with that worker (e.g. GPU memory taken by another process that the cluster scheduler was not aware of). The scheduler will shut down and restart the worker, and retry the blocks that failed on that worker. So daisy is very robust to random error.
#
# But what about non-random error, like stopping after 8 hours? If you don't want to re-process all the already processed blocks from a prior run, you can write a function that takes a block and checks if it is complete, and pass it to the scheduler. The scheduler will run this check function on each block and skip the block if the check function returns true.
#
Expand Down Expand Up @@ -703,15 +704,16 @@ def segment_blue_objects(input_group, output_group, block):

# Threshold the image to get only blue colors
mask = cv2.inRange(hsv_image, lower_blue, upper_blue) # this returns 0 and 255
mask = mask.astype(np.uint16)
mask = mask.astype(np.uint32)
mask = mask // 255 # turn into 0/1 labels

# give each connected component its own instance segmentation label
labels = skimage.measure.label(mask)

# get a unique ID for each element in the whole volume (avoid repeats between blocks)
block_id_mask = mask * (block.block_id[1])
labels = labels + block_id_mask
max_number_obj = 128*128 # This number is an upper bound on the maximum number of objects in a block
block_id_mask = mask * (block.block_id[1] * max_number_obj)
labels = labels + block_id_mask

# save the output
output_ds = open_ds('sample_data.zarr', output_group, 'a')
Expand All @@ -733,7 +735,7 @@ def segment_blue_objects(input_group, output_group, block):
total_roi=total_read_roi,
read_roi=read_roi,
write_roi=block_roi,
num_workers=1,
num_workers=3,
read_write_conflict=False,
check_function=partial(check_complete, "smoothed_for_seg")
)
Expand All @@ -753,7 +755,7 @@ def segment_blue_objects(input_group, output_group, block):
"blue_objects",
total_roi=total_roi,
voxel_size=daisy.Coordinate((1,1)),
dtype=np.uint16, # This is different! Our labels will be uint16
dtype=np.uint32, # This is different! Our labels will be uint32
write_size=seg_block_roi.shape, # use the new block roi to determine the chunk size
)

Expand All @@ -775,10 +777,17 @@ def segment_blue_objects(input_group, output_group, block):
daisy.run_blockwise([smoothing_task, seg_task])

# %%
from skimage.color import label2rgb
# make a colormap to display the labels as unique colors
# This doesn't actually guarantee uniqueness, since it wraps the list at some point
import matplotlib.colors as mcolors
colors_list = list(mcolors.XKCD_COLORS.keys())
# colors_list.remove("black")
colormap = mcolors.ListedColormap(colors=["black"] + colors_list*101)

figure, axes = plt.subplots(1, 2)
axes[0].imshow(zarr.open('sample_data.zarr', 'r')['smoothed_for_seg'][:].transpose(1,2,0), origin="lower")
axes[1].imshow(label2rgb(zarr.open('sample_data.zarr', 'r')['blue_objects'][:]), origin="lower")
blue_objs = zarr.open('sample_data.zarr', 'r')['blue_objects'][:]
axes[1].imshow(blue_objs, origin="lower", cmap=colormap, interpolation='nearest')

# %% [markdown]
# Now we know how to chain tasks together, but we've created another issue. If objects crossed block boundaries, they were assigned different IDs. We will address this problem in the next section.
Expand All @@ -787,10 +796,9 @@ def segment_blue_objects(input_group, output_group, block):
# ## Process functions that need to read their neighbor's output (and the "read_write_conflict" flag)

# %% [markdown]
# So far, we have always written process functions that read from one input array and write to one output array. However, we have much more flexibility than that - daisy just gives the worker the block, and the worker can do whatever it wants.
#
# Say we expand our read_roi as before. Reading more of the input context of a block does not help us resolve our conflicting labels, but looking at the block's neighbors' output can help! Let's visualize an example:
# There is a class of problem where it is useful for a block to see the output of its neighboring blocks. Usually, this need comes up when the task performs detection and linking in the same step. To detect an object in a block, it is useful to know if a neighbor has already detected an object that the object in the current block should link to. This is especially useful if there is some sort of continuity constraint, as in tracking objects over time. Even for tasks without a continuity constraint, like agglomeration of fragments for instance segmentation, performing the detection and linking at the same time can save you an extra pass over the dataset, which matters when the datasets are large.
#
# The example in this tutorial is only to illustrate how daisy implements the ability to depend on neighboring blocks outputs, by showing how we can relabel the segmentation IDs to be consistent across blocks during the detection step. Let's visualize an example block as though it were not completed, but its neighbors are done, and the read_roi is expanded by a certain amount of context.

# %%
context = 10 # It could be as low as 1, but we use 10 for ease of visualization
Expand All @@ -806,26 +814,25 @@ def segment_blue_objects(input_group, output_group, block):
)

# simulate this block not being completed yet
from funlib.persistence import open_ds
output_ds = open_ds("sample_data.zarr", "blue_objects", "a")
output_ds[seg_block.write_roi] = 0
blue_objs = zarr.open('sample_data.zarr', 'r')['blue_objects'][:]
blue_objs[128:356, 128:256] = 0

from skimage.color import label2rgb
figure, axes = plt.subplots(1, 2)
axes[0].imshow(zarr.open('sample_data.zarr', 'r')['smoothed_for_seg'][:].transpose(1,2,0), origin="lower")
axes[1].imshow(label2rgb(zarr.open('sample_data.zarr', 'r')['blue_objects'][:]), origin="lower")
axes[1].imshow(blue_objs, origin="lower", cmap=colormap, interpolation='nearest')
display_roi(figure.axes[0], seg_block.read_roi, color="purple")
display_roi(figure.axes[0], seg_block.write_roi, color="white")
display_roi(figure.axes[1], seg_block.read_roi, color="purple")
display_roi(figure.axes[1], seg_block.write_roi, color="white")


# %% [markdown]
# Here the purple is the read_roi of our current block, and the white is the write_roi. As before, the process function will read the input image in the read_roi and segment it. From the previous result visualization, we can see that the function will detect the top half of the crest and assign it the pink label.
# Here the purple is the read_roi of our current block, and the white is the write_roi. As before, the process function will read the input image in the read_roi and segment it. From the previous result visualization, we can see that the function will detect the top half of the crest and assign it the label id that is visualized as brown.
#
# Before we write out the pink object to the write_roi of the output dataset, however, we can adapt the process function to **also read in the existing results in the output dataset** in the read_roi. The existing blue label will then overlap with the pink label, and our process function can relabel the top half of the crest blue based on this information, before writing to the write_roi.
# Before we write out the brown label to the write_roi of the output dataset, however, we can adapt the process function to **also read in the existing results in the output dataset** in the read_roi. The existing green label will then overlap with the brown label, and our process function can relabel the top half of the crest green based on this information, before writing to the write_roi.
#
# This approach only works if the overlapping blocks are run sequentially. If they are run in parallel, it is possible that the blue label will not yet be there when our current block reads existing labels, but also that our pink label will not yet be there when the block containing the blue object reads existing labels. If `read_write_conflicts` argument is set to True in a task, the daisy scheduler will ensure that pairs of blocks with overlapping read/write rois will never be run at the same time, thus avoiding this race condition.
# This approach only works if the overlapping blocks are run sequentially (and if objects don't span across non-adjacent blocks - for large objects, you cannot relabel without a second pass). If the neighbors are run in parallel, it is possible that the green label will not yet be there when our current block reads existing labels, but also that our brown label will not yet be there when the block containing the green object reads existing labels. If `read_write_conflicts` argument is set to True in a task, the daisy scheduler will ensure that pairs of blocks with overlapping read/write rois will never be run at the same time, thus avoiding this race condition.

# %%
# here is the new and improved segmentation function that reads in neighboring output context
Expand Down Expand Up @@ -863,14 +870,15 @@ def get_overlapping_labels(array1, array2):

# Threshold the image to get only blue colors
mask = cv2.inRange(hsv_image, lower_blue, upper_blue) # this returns 0 and 255
mask = mask.astype(np.uint16)
mask = mask.astype(np.uint32)
mask = mask // 255 # turn into 0/1 labels

# give each connected component its own instance segmentation label
labels = skimage.measure.label(mask)

# get a unique ID for each element in the whole volume (avoid repeats between blocks)
block_id_mask = mask * (block.block_id[1])
max_number_obj = 128*128 # This number is an upper bound on the maximum number of objects in a block
block_id_mask = mask * (block.block_id[1] * max_number_obj)
labels = labels + block_id_mask

# load the existing labels in the output
Expand Down Expand Up @@ -903,7 +911,7 @@ def get_overlapping_labels(array1, array2):
"blue_objects_with_context",
total_roi=total_roi,
voxel_size=daisy.Coordinate((1,1)),
dtype=np.uint16,
dtype=np.uint32,
write_size=seg_block_roi.shape,
)

Expand All @@ -920,17 +928,14 @@ def get_overlapping_labels(array1, array2):
daisy.run_blockwise([seg_task])

# %%
from skimage.color import label2rgb
figure, axes = plt.subplots(1, 2)
axes[0].imshow(zarr.open('sample_data.zarr', 'r')['smoothed_for_seg'][:].transpose(1,2,0), origin="lower")
axes[1].imshow(label2rgb(zarr.open('sample_data.zarr', 'r')['blue_objects_with_context'][:]), origin="lower")
axes[1].imshow(zarr.open('sample_data.zarr', 'r')['blue_objects_with_context'][:], cmap=colormap, interpolation="nearest", origin="lower")

# %% [markdown]
# All the labels are now consistent! If you re-run the previous cell with `read_write_conflict=False`, you should see an inconsistent result again due to the race conditions, even though the process function still reads the neighboring output.

# %% [markdown]
# This pattern where one block wants to see the output of its neighbors so it can incorporate them into its processing comes up surprisingly often, from agglomeration to tracking. See [this blog post](https://localshapedescriptors.github.io/#throughput) for visualizations of this process in neuron segmentation. Setting `read_write_conflicts=True` allows you to solve many complex tasks in this manner.
#
# **IMPORTANT PERFORMANCE NOTE:** Be aware that `read_write_conflicts` is set to `True` by default and can lead to performance hits in cases where you don't need it, so be sure to turn it off if you want every block to be run in parallel!

# %%

0 comments on commit 357f616

Please sign in to comment.