Skip to content

Commit

Permalink
Prototype of ROI-based illumination_correction task (ref #114)
Browse files Browse the repository at this point in the history
  • Loading branch information
tcompa committed Jul 21, 2022
1 parent d068c4a commit 21ce80d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
68 changes: 44 additions & 24 deletions fractal/tasks/illumination_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@
import json
import warnings

import anndata as ad
import dask
import dask.array as da
import numpy as np
from skimage.io import imread

from fractal.tasks.lib_pyramid_creation import write_pyramid
from fractal.tasks.lib_regions_of_interest import convert_ROI_table_to_indices
from fractal.tasks.lib_regions_of_interest import (
split_3D_indices_into_z_layers,
)


def correct(
img,
illum_img=None,
background=110,
img_size_y=2160,
img_size_x=2560,
block_info=None,
):
"""
Corrects single Z level input image using an illumination profile
Expand All @@ -42,25 +45,21 @@ def correct(
:type illum_img: np.array
:param background: value for background subtraction (optional, default 110)
:type background: int
:param img_size_y: image size along Y (optional, default 2160)
:type img_size_y: int
:param img_size_x: image size along X (optional, default 2560)
:type img_size_x: int
"""

# Check shapes
if img.shape != (1, img_size_y, img_size_x):
if illum_img.shape != img.shape[1:]:
raise Exception(
f"Error in illumination_correction, img.shape: {img.shape}"
)
if illum_img.shape != (img_size_y, img_size_x):
raise Exception(
"Error in illumination_correction, "
"Error in illumination_correction\n"
f"img.shape: {img.shape}\n"
f"illum_img.shape: {illum_img.shape}"
)

# Background subtraction
# FIXME: is there a problem with these changes?
# devdoc.net/python/dask-2.23.0-doc/delayed-best-practices.html
# ?highlight=delayed#don-t-mutate-inputs
img[img <= background] = 0
img[img > background] = img[img > background] - background

Expand Down Expand Up @@ -190,24 +189,45 @@ def illumination_correction(
f"Error in illumination_correction, chunks_x: {chunks_x}"
)

# Read FOV ROIs
FOV_ROI_table = ad.read_zarr(f"{zarrurl}tables/FOV_ROI_table")

# Create list of indices for 3D FOVs spanning the entire Z direction
list_indices = convert_ROI_table_to_indices(
FOV_ROI_table, level=0, coarsening_xy=coarsening_xy
)

# Create the final list of single-Z-layer FOVs
list_indices = split_3D_indices_into_z_layers(list_indices)

# Prepare delayed function
delayed_correct = dask.delayed(correct)

# FIXME The dask array will consist of a single chunk.
# (docs.dask.org/en/stable/_modules/dask/array/core.html#from_delayed)

# Loop over channels
# FIXME: map_blocks could take care of this
data_czyx_new = []
for ind_ch, ch in enumerate(chl_list):

data_zyx = data_czyx[ind_ch]
illum_img = corrections[ch]

# Map correct(..) function onto each block
data_zyx_new = data_zyx.map_blocks(
correct,
chunks=(1, img_size_y, img_size_x),
meta=np.array((), dtype=dtype),
illum_img=illum_img,
background=background,
img_size_y=img_size_y,
img_size_x=img_size_x,
)
data_zyx_new = da.empty_like(data_zyx)

for indices in list_indices:
s_z, e_z, s_y, e_y, s_x, e_x = indices[:]
shape = [e_z - s_z, e_y - s_y, e_x - s_x]
new_img = delayed_correct(
data_zyx[s_z:e_z, s_y:e_y, s_x:e_x],
illum_img,
background=background,
)
# FIXME what about meta and name kwargs?
data_zyx_new[s_z:e_z, s_y:e_y, s_x:e_x] = da.from_delayed(
new_img, shape, dtype
)

data_czyx_new.append(data_zyx_new)
accumulated_data = da.stack(data_czyx_new, axis=0)

Expand Down
4 changes: 2 additions & 2 deletions fractal/tasks/lib_regions_of_interest.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def convert_ROI_table_to_indices(
return list_indices


def split_3D_ROI_indices_into_z_layers(
def split_3D_indices_into_z_layers(
list_indices: List[List[int]],
) -> List[List[int]]:

Expand Down Expand Up @@ -164,7 +164,7 @@ def _inspect_ROI_table(
adata, level=level, coarsening_xy=coarsening_xy
)

list_indices = split_3D_ROI_indices_into_z_layers(list_indices)
list_indices = split_3D_indices_into_z_layers(list_indices)

print(f"level: {level}")
print(f"coarsening_xy: {coarsening_xy}")
Expand Down

0 comments on commit 21ce80d

Please sign in to comment.