Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of chi2_shift as a Method to calculate Shifts to the "Calculate Registration (image-based)" Task #741

Merged
merged 15 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions fractal_tasks_core/__FRACTAL_MANIFEST__.json
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,16 @@
"type": "string",
"description": "Wavelength that will be used for image-based registration; e.g. `A01_C01` for Yokogawa, `C01` for MD."
},
"method": {
"default": "phase_cross_correlation",
"allOf": [
{
"$ref": "#/definitions/RegistrationMethod"
}
],
"title": "Method",
"description": "Method to use for image registration. The available methods are `phase_cross_correlation` (scikit-image package, works for 2D & 3D) and \"chi2_shift\" (image_registration package, only works for 2D images)."
},
"roi_table": {
"title": "Roi Table",
"default": "FOV_ROI_table",
Expand Down Expand Up @@ -1063,6 +1073,14 @@
"required": [
"reference_zarr_url"
]
},
"RegistrationMethod": {
"title": "RegistrationMethod",
"description": "An enumeration.",
"enum": [
"phase_cross_correlation",
"chi2_shift"
]
}
}
},
Expand Down
50 changes: 50 additions & 0 deletions fractal_tasks_core/tasks/_registration_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import copy

import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
from image_registration import chi2_shift

from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta
from fractal_tasks_core.tasks._zarr_utils import _split_well_path_image_path
Expand Down Expand Up @@ -235,3 +237,51 @@ def apply_registration_to_single_ROI_table(
+ float(min_df.loc[roi, "translation_x"])
)
return roi_table


def chi2_shift_out(img_ref, img_cycle_x) -> list[np.ndarray]:
"""
Helper function to get the output of chi2_shift into the same format as
phase_cross_correlation. Calculates the shift between two images using
the chi2_shift method.

Args:
img_ref (np.ndarray): First image.
img_cycle_x (np.ndarray): Second image.

Returns:
List containing numpy array of shift in y and x direction.
"""
x, y, a, b = chi2_shift(np.squeeze(img_ref), np.squeeze(img_cycle_x))

"""
Running into issues when using direct float output for fractal.
When rounding to integer and using integer dtype, it typically works
but for some reasons fails when run over a whole 384 well plate (but
the well where it fails works fine when run alone). For now, rounding
to integer, but still using float64 dtype (like the scikit-image
phase cross correlation function) seems to be the safest option.
"""
shifts = np.array([-np.round(y), -np.round(x)], dtype="float64")
# return as a list to adhere to the phase_cross_correlation output format
return [shifts]


def is_3D(dask_array: da.array) -> bool:
"""
Check if a dask array is 3D.

Treats singelton Z dimensions as 2D images.
(1, 2000, 2000) => False
(10, 2000, 2000) => True

Args:
dask_array: Input array to be checked

Returns:
bool on whether the array is 3D
"""
if len(dask_array.shape) == 3 and dask_array.shape[0] > 1:
return True
else:
return False
46 changes: 31 additions & 15 deletions fractal_tasks_core/tasks/calculate_registration_image_based.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Calculates translation for image-based registration
"""
import logging
from enum import Enum

import anndata as ad
import dask.array as da
Expand All @@ -36,14 +37,27 @@
from fractal_tasks_core.tasks._registration_utils import (
calculate_physical_shifts,
)
from fractal_tasks_core.tasks._registration_utils import chi2_shift_out
from fractal_tasks_core.tasks._registration_utils import (
get_ROI_table_with_translation,
)
from fractal_tasks_core.tasks._registration_utils import is_3D
from fractal_tasks_core.tasks.io_models import InitArgsRegistration

logger = logging.getLogger(__name__)


class RegistrationMethod(Enum):
PHASE_CROSS_CORRELATION = "phase_cross_correlation"
CHI2_SHIFT = "chi2_shift"

def register(self, img_ref, img_acq_x):
if self == RegistrationMethod.PHASE_CROSS_CORRELATION:
return phase_cross_correlation(img_ref, img_acq_x)
elif self == RegistrationMethod.CHI2_SHIFT:
return chi2_shift_out(img_ref, img_acq_x)


@validate_arguments
def calculate_registration_image_based(
*,
Expand All @@ -52,6 +66,7 @@
init_args: InitArgsRegistration,
# Core parameters
wavelength_id: str,
method: RegistrationMethod = "phase_cross_correlation",
roi_table: str = "FOV_ROI_table",
level: int = 2,
) -> None:
Expand All @@ -73,6 +88,10 @@
(standard argument for Fractal tasks, managed by Fractal server).
wavelength_id: Wavelength that will be used for image-based
registration; e.g. `A01_C01` for Yokogawa, `C01` for MD.
method: Method to use for image registration. The available methods
are `phase_cross_correlation` (scikit-image package, works for 2D
& 3D) and "chi2_shift" (image_registration package, only works for
2D images).
roi_table: Name of the ROI table over which the task loops to
calculate the registration. Examples: `FOV_ROI_table` => loop over
the field of views, `well_ROI_table` => process the whole well as
Expand Down Expand Up @@ -115,6 +134,16 @@
channel_index_align
]

# Check if data is 3D (as not all registration methods work in 3D)
# TODO: Abstract this check into a higher-level Zarr loading class
if is_3D(data_reference_zyx):
if method == "chi2_shift":
raise ValueError(

Check notice on line 141 in fractal_tasks_core/tasks/calculate_registration_image_based.py

View workflow job for this annotation

GitHub Actions / Coverage

Missing coverage

Missing coverage on line 141
"The `chi2_shift` registration method has not been "
"implemented for 3D images and the input image had a shape of "
f"{data_reference_zyx.shape}."
)

# Read ROIs
ROI_table_ref = ad.read_zarr(
f"{init_args.reference_zarr_url}/tables/{roi_table}"
Expand Down Expand Up @@ -211,30 +240,17 @@
##############
# Calculate the transformation
##############
# Basic version (no padding, no internal binning)
if img_ref.shape != img_acq_x.shape:
raise NotImplementedError(
"This registration is not implemented for ROIs with "
"different shapes between acquisitions."
)
shifts = phase_cross_correlation(
np.squeeze(img_ref), np.squeeze(img_acq_x)
)[0]

# Registration based on scmultiplex, image-based
# shifts, _, _ = calculate_shift(np.squeeze(img_ref),
# np.squeeze(img_acq_x), bin=binning, binarize=False)

# TODO: Make this work on label images
# (=> different loading) etc.
shifts = method.register(np.squeeze(img_ref), np.squeeze(img_acq_x))[0]

##############
# Storing the calculated transformation ###
# Store the calculated transformation ###
##############
# Store the shift in ROI table
# TODO: Store in OME-NGFF transformations: Check SpatialData approach,
# per ROI storage?

# Adapt ROIs for the given ROI table:
ROI_name = ROI_table_ref.obs.index[i_ROI]
new_shifts[ROI_name] = calculate_physical_shifts(
Expand Down
1 change: 1 addition & 0 deletions fractal_tasks_core/tasks/find_registration_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def find_registration_consensus(
new_roi_table,
shifted_rois[acq_zarr_url],
table_attrs=roi_tables_attrs[acq_zarr_url],
overwrite=True,
)

# TODO: Optionally apply registration to other tables as well?
Expand Down
121 changes: 118 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ napari-skimage-regionprops = { version = "^0.8.1", optional = true }
napari-tools-menu = { version = "^0.1.19", optional = true }
cellpose = { version = "~2.2", optional = true }
torch = { version = "<=2.0.0", optional = true }
image_registration = { version = ">=0.2.9", optional = true }

[tool.poetry.extras]
fractal-tasks = ["Pillow", "imageio-ffmpeg", "scikit-image", "llvmlite", "napari-segment-blobs-and-things-with-membranes", "napari-workflows", "napari-skimage-regionprops", "napari-tools-menu", "cellpose", "torch"]
fractal-tasks = ["Pillow", "imageio-ffmpeg", "scikit-image", "llvmlite", "napari-segment-blobs-and-things-with-membranes", "napari-workflows", "napari-skimage-regionprops", "napari-tools-menu", "cellpose", "torch", "image_registration"]

[tool.poetry.group.dev]
optional = true
Expand Down
Loading
Loading