-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add deepedit transforms Signed-off-by: Andres <diazandr3s@gmail.com> * Run unittests - autofix Signed-off-by: Andres <diazandr3s@gmail.com> * Update transform Signed-off-by: Andres <diazandr3s@gmail.com>
- Loading branch information
1 parent
f981ad0
commit f99ebda
Showing
4 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import json | ||
import logging | ||
from typing import Dict, Hashable, Mapping, Tuple | ||
|
||
import numpy as np | ||
|
||
from monai.config import KeysCollection | ||
from monai.transforms.transform import MapTransform, Randomizable, Transform | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
from monai.utils import optional_import | ||
|
||
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") | ||
|
||
|
||
class DiscardAddGuidanced(MapTransform): | ||
def __init__( | ||
self, | ||
keys: KeysCollection, | ||
probability: float = 1.0, | ||
allow_missing_keys: bool = False, | ||
): | ||
""" | ||
Discard positive and negative points randomly or Add the two channels for inference time | ||
:param probability: Discard probability; For inference it will be always 1.0 | ||
""" | ||
super().__init__(keys, allow_missing_keys) | ||
self.probability = probability | ||
|
||
def _apply(self, image): | ||
if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]): | ||
signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32) | ||
if image.shape[0] == 3: | ||
image[1] = signal | ||
image[2] = signal | ||
else: | ||
image = np.concatenate((image, signal, signal), axis=0) | ||
return image | ||
|
||
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: | ||
d: Dict = dict(data) | ||
for key in self.key_iterator(d): | ||
if key == "image": | ||
d[key] = self._apply(d[key]) | ||
else: | ||
print("This transform only applies to the image") | ||
return d | ||
|
||
|
||
class ResizeGuidanceCustomd(Transform): | ||
""" | ||
Resize the guidance based on cropped vs resized image. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
guidance: str, | ||
ref_image: str, | ||
) -> None: | ||
self.guidance = guidance | ||
self.ref_image = ref_image | ||
|
||
def __call__(self, data): | ||
d = dict(data) | ||
current_shape = d[self.ref_image].shape[1:] | ||
|
||
factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4]) | ||
pos_clicks, neg_clicks = d["foreground"], d["background"] | ||
|
||
pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else [] | ||
neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else [] | ||
|
||
d[self.guidance] = [pos, neg] | ||
return d | ||
|
||
|
||
class ClickRatioAddRandomGuidanced(Randomizable, Transform): | ||
""" | ||
Add random guidance based on discrepancies that were found between label and prediction. | ||
Args: | ||
guidance: key to guidance source, shape (2, N, # of dim) | ||
discrepancy: key that represents discrepancies found between label and prediction, shape (2, C, D, H, W) or (2, C, H, W) | ||
probability: key that represents click/interaction probability, shape (1) | ||
fn_fp_click_ratio: ratio of clicks between FN and FP | ||
""" | ||
|
||
def __init__( | ||
self, | ||
guidance: str = "guidance", | ||
discrepancy: str = "discrepancy", | ||
probability: str = "probability", | ||
fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0), | ||
): | ||
self.guidance = guidance | ||
self.discrepancy = discrepancy | ||
self.probability = probability | ||
self.fn_fp_click_ratio = fn_fp_click_ratio | ||
self._will_interact = None | ||
|
||
def randomize(self, data=None): | ||
probability = data[self.probability] | ||
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability]) | ||
|
||
def find_guidance(self, discrepancy): | ||
distance = distance_transform_cdt(discrepancy).flatten() | ||
probability = np.exp(distance) - 1.0 | ||
idx = np.where(discrepancy.flatten() > 0)[0] | ||
|
||
if np.sum(discrepancy > 0) > 0: | ||
seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx])) | ||
dst = distance[seed] | ||
|
||
g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0] | ||
g[0] = dst[0] | ||
return g | ||
return None | ||
|
||
def add_guidance(self, discrepancy, will_interact): | ||
if not will_interact: | ||
return None, None | ||
|
||
pos_discr = discrepancy[0] | ||
neg_discr = discrepancy[1] | ||
|
||
can_be_positive = np.sum(pos_discr) > 0 | ||
can_be_negative = np.sum(neg_discr) > 0 | ||
|
||
pos_prob = self.fn_fp_click_ratio[0] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1]) | ||
neg_prob = self.fn_fp_click_ratio[1] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1]) | ||
|
||
correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob]) | ||
|
||
if can_be_positive and not can_be_negative: | ||
return self.find_guidance(pos_discr), None | ||
|
||
if not can_be_positive and can_be_negative: | ||
return None, self.find_guidance(neg_discr) | ||
|
||
if correct_pos and can_be_positive: | ||
return self.find_guidance(pos_discr), None | ||
|
||
if not correct_pos and can_be_negative: | ||
return None, self.find_guidance(neg_discr) | ||
return None, None | ||
|
||
def _apply(self, guidance, discrepancy): | ||
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance | ||
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance | ||
pos, neg = self.add_guidance(discrepancy, self._will_interact) | ||
if pos: | ||
guidance[0].append(pos) | ||
guidance[1].append([-1] * len(pos)) | ||
if neg: | ||
guidance[0].append([-1] * len(neg)) | ||
guidance[1].append(neg) | ||
|
||
return json.dumps(np.asarray(guidance).astype(int).tolist()) | ||
|
||
def __call__(self, data): | ||
d = dict(data) | ||
guidance = d[self.guidance] | ||
discrepancy = d[self.discrepancy] | ||
self.randomize(data) | ||
d[self.guidance] = self._apply(guidance, discrepancy) | ||
return d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
from parameterized import parameterized | ||
|
||
from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd | ||
|
||
IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]) | ||
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]) | ||
|
||
DATA_1 = { | ||
"image": IMAGE, | ||
"label": LABEL, | ||
"image_meta_dict": {"dim": IMAGE.shape}, | ||
"label_meta_dict": {}, | ||
"foreground": [0, 0, 0], | ||
"background": [0, 0, 0], | ||
} | ||
|
||
DISCARD_ADD_GUIDANCE_TEST_CASE = [ | ||
{"image": IMAGE, "label": LABEL}, | ||
DATA_1, | ||
(3, 1, 5, 5), | ||
] | ||
|
||
DATA_2 = { | ||
"image": IMAGE, | ||
"label": LABEL, | ||
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]), | ||
"discrepancy": np.array( | ||
[ | ||
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], | ||
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]], | ||
] | ||
), | ||
"probability": 1.0, | ||
} | ||
|
||
CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [ | ||
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"}, | ||
DATA_2, | ||
"[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]", | ||
] | ||
|
||
DATA_3 = { | ||
"image": np.arange(1000).reshape((1, 5, 10, 20)), | ||
"image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]}, | ||
"guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]], | ||
"foreground": [[10, 14, 6], [10, 14, 8]], | ||
"background": [[10, 16, 8]], | ||
} | ||
|
||
RESIZE_GUIDANCE_TEST_CASE_1 = [ | ||
{"ref_image": "image", "guidance": "guidance"}, | ||
DATA_3, | ||
[[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]], | ||
] | ||
|
||
|
||
class TestDiscardAddGuidanced(unittest.TestCase): | ||
@parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE]) | ||
def test_correct_results(self, arguments, input_data, expected_result): | ||
add_fn = DiscardAddGuidanced(arguments) | ||
result = add_fn(input_data) | ||
self.assertEqual(result["image"].shape, expected_result) | ||
|
||
|
||
class TestClickRatioAddRandomGuidanced(unittest.TestCase): | ||
@parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1]) | ||
def test_correct_results(self, arguments, input_data, expected_result): | ||
seed = 0 | ||
add_fn = ClickRatioAddRandomGuidanced(**arguments) | ||
add_fn.set_random_state(seed) | ||
result = add_fn(input_data) | ||
self.assertEqual(result[arguments["guidance"]], expected_result) | ||
|
||
|
||
class TestResizeGuidanced(unittest.TestCase): | ||
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1]) | ||
def test_correct_results(self, arguments, input_data, expected_result): | ||
result = ResizeGuidanceCustomd(**arguments)(input_data) | ||
self.assertEqual(result[arguments["guidance"]], expected_result) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |