Skip to content

Commit

Permalink
Add deepedit transforms (#2810)
Browse files Browse the repository at this point in the history
* Add deepedit transforms

Signed-off-by: Andres <diazandr3s@gmail.com>

* Run unittests - autofix

Signed-off-by: Andres <diazandr3s@gmail.com>

* Update transform

Signed-off-by: Andres <diazandr3s@gmail.com>
  • Loading branch information
diazandr3s authored Aug 27, 2021
1 parent f981ad0 commit f99ebda
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 0 deletions.
10 changes: 10 additions & 0 deletions monai/apps/deepedit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
167 changes: 167 additions & 0 deletions monai/apps/deepedit/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import json
import logging
from typing import Dict, Hashable, Mapping, Tuple

import numpy as np

from monai.config import KeysCollection
from monai.transforms.transform import MapTransform, Randomizable, Transform

logger = logging.getLogger(__name__)

from monai.utils import optional_import

distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")


class DiscardAddGuidanced(MapTransform):
def __init__(
self,
keys: KeysCollection,
probability: float = 1.0,
allow_missing_keys: bool = False,
):
"""
Discard positive and negative points randomly or Add the two channels for inference time
:param probability: Discard probability; For inference it will be always 1.0
"""
super().__init__(keys, allow_missing_keys)
self.probability = probability

def _apply(self, image):
if self.probability >= 1.0 or np.random.choice([True, False], p=[self.probability, 1 - self.probability]):
signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
if image.shape[0] == 3:
image[1] = signal
image[2] = signal
else:
image = np.concatenate((image, signal, signal), axis=0)
return image

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d: Dict = dict(data)
for key in self.key_iterator(d):
if key == "image":
d[key] = self._apply(d[key])
else:
print("This transform only applies to the image")
return d


class ResizeGuidanceCustomd(Transform):
"""
Resize the guidance based on cropped vs resized image.
"""

def __init__(
self,
guidance: str,
ref_image: str,
) -> None:
self.guidance = guidance
self.ref_image = ref_image

def __call__(self, data):
d = dict(data)
current_shape = d[self.ref_image].shape[1:]

factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4])
pos_clicks, neg_clicks = d["foreground"], d["background"]

pos = np.multiply(pos_clicks, factor).astype(int).tolist() if len(pos_clicks) else []
neg = np.multiply(neg_clicks, factor).astype(int).tolist() if len(neg_clicks) else []

d[self.guidance] = [pos, neg]
return d


class ClickRatioAddRandomGuidanced(Randomizable, Transform):
"""
Add random guidance based on discrepancies that were found between label and prediction.
Args:
guidance: key to guidance source, shape (2, N, # of dim)
discrepancy: key that represents discrepancies found between label and prediction, shape (2, C, D, H, W) or (2, C, H, W)
probability: key that represents click/interaction probability, shape (1)
fn_fp_click_ratio: ratio of clicks between FN and FP
"""

def __init__(
self,
guidance: str = "guidance",
discrepancy: str = "discrepancy",
probability: str = "probability",
fn_fp_click_ratio: Tuple[float, float] = (1.0, 1.0),
):
self.guidance = guidance
self.discrepancy = discrepancy
self.probability = probability
self.fn_fp_click_ratio = fn_fp_click_ratio
self._will_interact = None

def randomize(self, data=None):
probability = data[self.probability]
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])

def find_guidance(self, discrepancy):
distance = distance_transform_cdt(discrepancy).flatten()
probability = np.exp(distance) - 1.0
idx = np.where(discrepancy.flatten() > 0)[0]

if np.sum(discrepancy > 0) > 0:
seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))
dst = distance[seed]

g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]
g[0] = dst[0]
return g
return None

def add_guidance(self, discrepancy, will_interact):
if not will_interact:
return None, None

pos_discr = discrepancy[0]
neg_discr = discrepancy[1]

can_be_positive = np.sum(pos_discr) > 0
can_be_negative = np.sum(neg_discr) > 0

pos_prob = self.fn_fp_click_ratio[0] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])
neg_prob = self.fn_fp_click_ratio[1] / (self.fn_fp_click_ratio[0] + self.fn_fp_click_ratio[1])

correct_pos = self.R.choice([True, False], p=[pos_prob, neg_prob])

if can_be_positive and not can_be_negative:
return self.find_guidance(pos_discr), None

if not can_be_positive and can_be_negative:
return None, self.find_guidance(neg_discr)

if correct_pos and can_be_positive:
return self.find_guidance(pos_discr), None

if not correct_pos and can_be_negative:
return None, self.find_guidance(neg_discr)
return None, None

def _apply(self, guidance, discrepancy):
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
pos, neg = self.add_guidance(discrepancy, self._will_interact)
if pos:
guidance[0].append(pos)
guidance[1].append([-1] * len(pos))
if neg:
guidance[0].append([-1] * len(neg))
guidance[1].append(neg)

return json.dumps(np.asarray(guidance).astype(int).tolist())

def __call__(self, data):
d = dict(data)
guidance = d[self.guidance]
discrepancy = d[self.discrepancy]
self.randomize(data)
d[self.guidance] = self._apply(guidance, discrepancy)
return d
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run_testsuit():
"test_csv_iterable_dataset",
"test_dataset",
"test_dataset_summary",
"test_deepedit_transforms",
"test_deepgrow_dataset",
"test_deepgrow_interaction",
"test_deepgrow_transforms",
Expand Down
97 changes: 97 additions & 0 deletions tests/test_deepedit_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.apps.deepedit.transforms import ClickRatioAddRandomGuidanced, DiscardAddGuidanced, ResizeGuidanceCustomd

IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])

DATA_1 = {
"image": IMAGE,
"label": LABEL,
"image_meta_dict": {"dim": IMAGE.shape},
"label_meta_dict": {},
"foreground": [0, 0, 0],
"background": [0, 0, 0],
}

DISCARD_ADD_GUIDANCE_TEST_CASE = [
{"image": IMAGE, "label": LABEL},
DATA_1,
(3, 1, 5, 5),
]

DATA_2 = {
"image": IMAGE,
"label": LABEL,
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
"discrepancy": np.array(
[
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
),
"probability": 1.0,
}

CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
DATA_2,
"[[[1, 0, 2, 2], [-1, -1, -1, -1]], [[-1, -1, -1, -1], [1, 0, 2, 1]]]",
]

DATA_3 = {
"image": np.arange(1000).reshape((1, 5, 10, 20)),
"image_meta_dict": {"foreground_cropped_shape": (1, 10, 20, 40), "dim": [3, 512, 512, 128]},
"guidance": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]],
"foreground": [[10, 14, 6], [10, 14, 8]],
"background": [[10, 16, 8]],
}

RESIZE_GUIDANCE_TEST_CASE_1 = [
{"ref_image": "image", "guidance": "guidance"},
DATA_3,
[[[0, 0, 0], [0, 0, 1]], [[0, 0, 1]]],
]


class TestDiscardAddGuidanced(unittest.TestCase):
@parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE])
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = DiscardAddGuidanced(arguments)
result = add_fn(input_data)
self.assertEqual(result["image"].shape, expected_result)


class TestClickRatioAddRandomGuidanced(unittest.TestCase):
@parameterized.expand([CLICK_RATIO_ADD_RANDOM_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
seed = 0
add_fn = ClickRatioAddRandomGuidanced(**arguments)
add_fn.set_random_state(seed)
result = add_fn(input_data)
self.assertEqual(result[arguments["guidance"]], expected_result)


class TestResizeGuidanced(unittest.TestCase):
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])
def test_correct_results(self, arguments, input_data, expected_result):
result = ResizeGuidanceCustomd(**arguments)(input_data)
self.assertEqual(result[arguments["guidance"]], expected_result)


if __name__ == "__main__":
unittest.main()

0 comments on commit f99ebda

Please sign in to comment.