Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconstruct a subset of time indices #398

Merged
merged 16 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/birefringence-and-phase.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ input_channel_names:
- State1
- State2
- State3
time_indices: all
reconstruction_dimension: 3
birefringence:
transfer_function:
Expand Down
1 change: 1 addition & 0 deletions examples/birefringence.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ input_channel_names:
- State1
- State2
- State3
time_indices: all
reconstruction_dimension: 3
birefringence:
transfer_function:
Expand Down
1 change: 1 addition & 0 deletions examples/fluorescence.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
input_channel_names:
- GFP
time_indices: all
reconstruction_dimension: 3
fluorescence:
transfer_function:
Expand Down
1 change: 1 addition & 0 deletions examples/phase.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
input_channel_names:
- BF
time_indices: all
reconstruction_dimension: 3
phase:
transfer_function:
Expand Down
76 changes: 47 additions & 29 deletions recOrder/cli/apply_inverse_transfer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@
import numpy as np
import torch
from iohub import open_ome_zarr
from recOrder.cli.printing import echo_headline, echo_settings
from recOrder.cli.settings import ReconstructionSettings
from recOrder.cli.parsing import (
input_data_path_argument,
config_path_option,
output_dataset_option,
)
from recOrder.io import utils
from waveorder.models import (
inplane_oriented_thick_pol3d,
isotropic_fluorescent_thick_3d,
isotropic_thin_3d,
phase_thick_3d,
isotropic_fluorescent_thick_3d,
)

from recOrder.cli.parsing import (
config_path_option,
input_data_path_argument,
output_dataset_option,
)
from recOrder.cli.printing import echo_headline, echo_settings
from recOrder.cli.settings import ReconstructionSettings
from recOrder.io import utils


def _check_background_consistency(background_shape, data_shape):
data_cyx_shape = (data_shape[1],) + data_shape[3:]
Expand Down Expand Up @@ -53,8 +54,20 @@ def apply_inverse_transfer_function_cli(
input_dataset.channel_names.index(input_channel_name)
)

# Load dataset shape
t_shape = input_dataset.data.shape[0]
# Find time indices
if settings.time_indices == "all":
time_indices = range(input_dataset.data.shape[0])
elif isinstance(settings.time_indices, list):
time_indices = settings.time_indices
elif isinstance(settings.time_indices, int):
time_indices = [settings.time_indices]

# Check for invalid times
time_ubound = input_dataset.data.shape[0] - 1
if np.max(time_indices) > time_ubound:
raise ValueError(
f"time_indices = {time_indices} includes a time index beyond the maximum index of the dataset = {time_ubound}"
)

# Simplify important settings names
recon_biref = settings.birefringence is not None
Expand Down Expand Up @@ -87,26 +100,31 @@ def apply_inverse_transfer_function_cli(
output_z_shape = input_dataset.data.shape[2]

output_shape = (
t_shape,
input_dataset.data.shape[0],
len(channel_names),
output_z_shape,
) + input_dataset.data.shape[3:]

# Create output dataset
output_dataset = open_ome_zarr(
output_path, layout="fov", mode="w", channel_names=channel_names
output_path, layout="fov", mode="a", channel_names=channel_names
)
output_array = output_dataset.create_zeros(
name="0",
shape=output_shape,
dtype=np.float32,
chunks=(
1,
1,
1,

# Create an empty TCZYX array if it doesn't exist
if "0" not in output_dataset:
output_array = output_dataset.create_zeros(
name="0",
shape=output_shape,
dtype=np.float32,
chunks=(
1,
1,
1,
)
+ input_dataset.data.shape[3:], # chunk by YX
)
+ input_dataset.data.shape[3:], # chunk by YX
)
else:
output_array = output_dataset[0]

# Load data
tczyx_uint16_numpy = input_dataset.data.oindex[:, channel_indices]
Expand Down Expand Up @@ -143,7 +161,7 @@ def apply_inverse_transfer_function_cli(
transfer_function_dataset["intensity_to_stokes_matrix"][0, 0, 0]
)

for time_index in range(t_shape):
for time_index in time_indices:
# Apply
reconstructed_parameters = (
inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
Expand Down Expand Up @@ -180,7 +198,7 @@ def apply_inverse_transfer_function_cli(
transfer_function_dataset["phase_transfer_function"][0, 0]
)

for time_index in range(t_shape):
for time_index in time_indices:
# Apply
(
_,
Expand Down Expand Up @@ -210,7 +228,7 @@ def apply_inverse_transfer_function_cli(
)

# Apply
for time_index in range(t_shape):
for time_index in time_indices:
zyx_phase = phase_thick_3d.apply_inverse_transfer_function(
tczyx_data[time_index, 0],
real_potential_transfer_function,
Expand Down Expand Up @@ -246,7 +264,7 @@ def apply_inverse_transfer_function_cli(
transfer_function_dataset["phase_transfer_function"][0, 0]
)

for time_index in range(t_shape):
for time_index in time_indices:
# Apply
reconstructed_parameters_2d = inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
tczyx_data[time_index],
Expand Down Expand Up @@ -304,7 +322,7 @@ def apply_inverse_transfer_function_cli(
)

# Apply
for time_index in range(t_shape):
for time_index in time_indices:
reconstructed_parameters_3d = inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
tczyx_data[time_index],
intensity_to_stokes_matrix,
Expand Down Expand Up @@ -348,7 +366,7 @@ def apply_inverse_transfer_function_cli(
)

# Apply
for time_index in range(t_shape):
for time_index in time_indices:
zyx_recon = isotropic_fluorescent_thick_3d.apply_inverse_transfer_function(
tczyx_data[time_index, 0],
optical_transfer_function,
Expand Down
5 changes: 4 additions & 1 deletion recOrder/cli/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
root_validator,
validator,
)
from typing import Literal, List, Optional
from typing import Literal, List, Optional, Union

# This file defines the configuration settings for the CLI.

Expand Down Expand Up @@ -148,6 +148,9 @@ class FluorescenceSettings(MyBaseModel):
# Top level settings
class ReconstructionSettings(MyBaseModel):
input_channel_names: List[str] = [f"State{i}" for i in range(4)]
time_indices: Union[
NonNegativeInt, List[NonNegativeInt], Literal["all"]
] = "all"
reconstruction_dimension: Literal[2, 3] = 3
birefringence: Optional[BirefringenceSettings]
phase: Optional[PhaseSettings]
Expand Down
35 changes: 24 additions & 11 deletions recOrder/tests/cli_tests/test_reconstruct.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from recOrder.cli.main import cli
from recOrder.cli import settings
from recOrder.io import utils
from click.testing import CliRunner
from iohub.ngff import open_ome_zarr

from recOrder.cli import settings
from recOrder.cli.main import cli
from recOrder.io import utils


def test_reconstruct(tmp_path):
input_path = tmp_path / "input.zarr"
Expand All @@ -17,26 +18,35 @@ def test_reconstruct(tmp_path):
mode="w",
channel_names=channel_names,
)
dataset.create_zeros("0", (2, 4, 4, 5, 6), dtype=np.uint16)
dataset.create_zeros("0", (5, 4, 4, 5, 6), dtype=np.uint16)

# Setup options
birefringence_settings = settings.BirefringenceSettings(
transfer_function=settings.BirefringenceTransferFunctionSettings()
)

# birefringence_option, time_indices, phase_option, dimension_option, time_length_target
all_options = [
(birefringence_settings, None, 2),
(birefringence_settings, settings.PhaseSettings(), 2),
(birefringence_settings, None, 3),
(birefringence_settings, settings.PhaseSettings(), 3),
(birefringence_settings, [0, 3, 4], None, 2, 5),
(birefringence_settings, 0, settings.PhaseSettings(), 2, 5),
(birefringence_settings, [0, 1], None, 3, 5),
(birefringence_settings, "all", settings.PhaseSettings(), 3, 5),
]

for birefringence_option, phase_option, dimension_option in all_options:
for i, (
birefringence_option,
time_indices,
phase_option,
dimension_option,
time_length_target,
) in enumerate(all_options):
if (birefringence_option is None) and (phase_option is None):
continue

# Generate recon settings
recon_settings = settings.ReconstructionSettings(
input_channel_names=channel_names,
time_indices=time_indices,
reconstruction_dimension=dimension_option,
birefringence=birefringence_option,
phase=phase_option,
Expand All @@ -57,11 +67,12 @@ def test_reconstruct(tmp_path):
"-o",
str(tf_path),
],
catch_exceptions=False,
)
assert tf_path.exists()

# Apply the tf
result_path = input_path.with_name("result.zarr")
result_path = input_path.with_name(f"result{i}.zarr")

result_inv = runner.invoke(
cli,
Expand All @@ -74,14 +85,15 @@ def test_reconstruct(tmp_path):
"-o",
str(result_path),
],
catch_exceptions=False,
)
assert result_path.exists()
assert result_inv.exit_code == 0
assert "Reconstructing" in result_inv.output

# Check output
result_dataset = open_ome_zarr(result_path)
assert result_dataset["0"].shape[0] == 2
assert result_dataset["0"].shape[0] == time_length_target
assert result_dataset["0"].shape[3:] == (5, 6)

# Test direct recon
Expand All @@ -99,3 +111,4 @@ def test_reconstruct(tmp_path):
assert result_path.exists()
assert result_inv.exit_code == 0
assert "Reconstructing" in result_inv.output
assert "Reconstructing" in result_inv.output
Loading