Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to skip operations for RestoreLabeld Transform #8125

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,10 @@ def __init__(
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
restore_resize: bool = True,
restore_crop: bool = True,
restore_spacing: bool = True,
restore_slicing: bool = True,
ctongh marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
super().__init__(keys, allow_missing_keys)
self.ref_image = ref_image
Expand All @@ -833,6 +837,10 @@ def __init__(
self.end_coord_key = end_coord_key
self.original_shape_key = original_shape_key
self.cropped_shape_key = cropped_shape_key
self.restore_resize = restore_resize
self.restore_crop = restore_crop
self.restore_spacing = restore_spacing
self.restore_slicing = restore_slicing

def __call__(self, data: Any) -> dict:
d = dict(data)
Expand All @@ -842,38 +850,45 @@ def __call__(self, data: Any) -> dict:
image = d[key]

# Undo Resize
current_shape = image.shape
cropped_shape = meta_dict[self.cropped_shape_key]
if np.any(np.not_equal(current_shape, cropped_shape)):
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
image = resizer(image, mode=mode, align_corners=align_corners)
if self.restore_resize:
current_shape = image.shape
cropped_shape = meta_dict[self.cropped_shape_key]
if np.any(np.not_equal(current_shape, cropped_shape)):
resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
image = resizer(image, mode=mode, align_corners=align_corners)

# Undo Crop
original_shape = meta_dict[self.original_shape_key]
result = np.zeros(original_shape, dtype=np.float32)
box_start = meta_dict[self.start_coord_key]
box_end = meta_dict[self.end_coord_key]

spatial_dims = min(len(box_start), len(image.shape[1:]))
slices = tuple(
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
)
result[slices] = image
if self.restore_crop:
original_shape = meta_dict[self.original_shape_key]
result = np.zeros(original_shape, dtype=np.float32)
box_start = meta_dict[self.start_coord_key]
box_end = meta_dict[self.end_coord_key]

spatial_dims = min(len(box_start), len(image.shape[1:]))
slices = tuple(
[slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
)
result[slices] = image
else:
result = image

# Undo Spacing
current_size = result.shape[1:]
# change spatial_shape from HWD to DHW
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
spatial_size = spatial_shape[-len(current_size) :]
if self.restore_spacing:
current_size = result.shape[1:]
# change spatial_shape from HWD to DHW
spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
spatial_size = spatial_shape[-len(current_size) :]

if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=mode)
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
if np.any(np.not_equal(current_size, spatial_size)):
resizer = Resize(spatial_size=spatial_size, mode=mode)
result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore

# Undo Slicing
slice_idx = meta_dict.get("slice_idx")
final_result: NdarrayOrTensor
if slice_idx is None or self.slice_only:
if self.restore_slicing == False:
ctongh marked this conversation as resolved.
Show resolved Hide resolved
final_result = result
elif slice_idx is None or self.slice_only:
final_result = result if len(result.shape) <= 3 else result[0]
else:
slice_idx = meta_dict["slice_idx"][0]
Expand Down
35 changes: 32 additions & 3 deletions tests/test_deepgrow_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
]

RESTORE_LABEL_TEST_CASE_1 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest"},
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resize": True, "restore_crop": True, "restore_spacing": True, "restore_slicing": True},
ctongh marked this conversation as resolved.
Show resolved Hide resolved
DATA_10,
np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),
]
Expand All @@ -327,7 +327,36 @@
]
)

RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT]
RESTORE_LABEL_TEST_CASE_2 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest"},
DATA_11,
RESULT
]

RESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20))
for layer in range(5):
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 0:10, 0:10] = 1
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 0:10, 10:20] = 2
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 10:20, 0:10] = 3
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 10:20, 10:20] = 4

for layer in range(5, 10):
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 0:10, 0:10] = 5
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 0:10, 10:20] = 6
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 10:20, 0:10] = 7
RESTORE_LABEL_TEST_CASE_3_RESULT[layer, 10:20, 10:20] = 8
ctongh marked this conversation as resolved.
Show resolved Hide resolved

RESTORE_LABEL_TEST_CASE_3 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_crop": False},
DATA_11,
RESTORE_LABEL_TEST_CASE_3_RESULT,
]

RESTORE_LABEL_TEST_CASE_4 = [
{"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resize": False, "restore_spacing": False, "restore_slicing": False, "restore_crop": False},
DATA_11,
np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),
]

FETCH_2D_SLICE_TEST_CASE_1 = [
{"keys": ["image"], "guidance": "guidance"},
Expand Down Expand Up @@ -445,7 +474,7 @@ def test_correct_results(self, arguments, input_data, expected_result):

class TestRestoreLabeld(unittest.TestCase):

@parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2])
@parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2, RESTORE_LABEL_TEST_CASE_3, RESTORE_LABEL_TEST_CASE_4])
def test_correct_results(self, arguments, input_data, expected_result):
result = RestoreLabeld(**arguments)(input_data)
np.testing.assert_allclose(result["pred"], expected_result)
Expand Down
Loading