Skip to content

Commit

Permalink
Merge pull request #33 from angelolab/uint16
Browse files Browse the repository at this point in the history
Fixed uint16 inference bug
  • Loading branch information
JLrumberger authored Dec 3, 2024
2 parents 3f380f2 + ebeebac commit 983cd9b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def get_channel_normalized(self, fov: str, channel: str):
print("No normalization dict found. Preparing normalization dict...")
self.prepare_normalization_dict()
mplex_img = self.get_channel(fov, channel)
mplex_img = mplex_img.astype(np.float32)
if channel in self.normalization_dict.keys():
norm_factor = self.normalization_dict[channel]
else:
Expand Down
24 changes: 22 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,10 @@ def segmentation_naming_convention(fov_path):
fov_0_seg_ = dataset_ome.get_segmentation(fov="fov_0")
assert np.alltrue(fov_0_seg == fov_0_seg_)

# test everything again with single channel
# test everything again with single channel and float dtypes
fov_paths, _ = prepare_tif_data(
num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"],
shape=(512, 256)
shape=(512, 256), image_dtype=np.float32, instance_dtype=np.float32
)
cd4_channel = io.imread(os.path.join(fov_paths[0], "CD4.tiff"))
fov_0_seg = io.imread(segmentation_naming_convention(fov_paths[0]))
Expand Down Expand Up @@ -428,6 +428,26 @@ def segmentation_naming_convention(fov_path):
(groundtruth_df["fov"] == "fov_0") & (groundtruth_df["channel"] == "CD4")
]
assert np.alltrue(df == subset_df)

# test everything again with single channel and uint16 dtypes
fov_paths, _ = prepare_tif_data(
num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"],
shape=(512, 256), image_dtype=np.uint16, instance_dtype=np.uint16
)
cd4_channel = io.imread(os.path.join(fov_paths[0], "CD4.tiff"))
fov_0_seg = io.imread(segmentation_naming_convention(fov_paths[0]))
dataset = MultiplexDataset(
fov_paths, segmentation_naming_convention, suffix=".tiff",
groundtruth_df=groundtruth_df, output_dir=temp_dir
)
assert len(dataset) == 1
assert set(dataset.channels) == set(["CD4", "CD56"])
assert dataset.fov_paths == fov_paths
assert dataset.multi_channel == False
cd4_channel_ = dataset.get_channel(fov="fov_0", channel="CD4")
assert np.alltrue(cd4_channel == cd4_channel_)
fov_0_seg_ = dataset.get_segmentation(fov="fov_0")
assert np.alltrue(fov_0_seg == fov_0_seg_)


def test_prepare_normalization_dict():
Expand Down

0 comments on commit 983cd9b

Please sign in to comment.