Allow ApplyTransformToPointsd to receive a sequence of refer keys #8063

Merged: 14 commits, Sep 4, 2024
30 changes: 17 additions & 13 deletions monai/transforms/utility/dictionary.py
@@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
refer_key: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived.
refer_keys: The key of the reference item used for transformation.
It can directly refer to an affine or an image from which the affine can be derived. It can also be a
sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
dtype: The desired data type for the output.
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
@@ -1782,31 +1783,34 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
def __init__(
self,
keys: KeysCollection,
refer_key: str | None = None,
refer_keys: KeysCollection | None = None,
dtype: DtypeLike | torch.dtype = torch.float64,
affine: torch.Tensor | None = None,
invert_affine: bool = True,
affine_lps_to_ras: bool = False,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.refer_key = refer_key
self.refer_keys = ensure_tuple_rep(None, len(self.keys)) if refer_keys is None else ensure_tuple(refer_keys)
if len(self.keys) != len(self.refer_keys):
raise ValueError("refer_keys should have the same length as keys.")
self.converter = ApplyTransformToPoints(
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
)
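# Illustrative note (not part of the diff): with the normalization above, refer_keys must either
# be omitted or match keys one-to-one. For example, with hypothetical key names:
#   ApplyTransformToPointsd(keys=["point_1", "point_2"], refer_keys="image")                 # ValueError: lengths differ
#   ApplyTransformToPointsd(keys=["point_1", "point_2"], refer_keys=["image_1", "image_2"])  # accepted, paired positionally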

def __call__(self, data: Mapping[Hashable, torch.Tensor]):
d = dict(data)
if self.refer_key is not None:
if self.refer_key in d:
refer_data = d[self.refer_key]
else:
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
else:
refer_data = None
affine = getattr(refer_data, "affine", refer_data)
for key in self.key_iterator(d):
for key, refer_key in self.key_iterator(d, self.refer_keys):
coords = d[key]
affine = None # represents using affine given in constructor
if refer_key is not None:
if refer_key in d:
refer_data = d[refer_key]
else:
raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")

# use the "affine" member of refer_data, or refer_data itself, as the affine matrix
affine = getattr(refer_data, "affine", refer_data)
d[key] = self.converter(coords, affine)
return d
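
A minimal usage sketch of the sequence form documented in the docstring above (illustrative only, not part of the diff; key names and tensors are made up). Each entry in `keys` is paired positionally with the entry in `refer_keys`, and the affine is read from the referenced image:

import torch
from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 1.0, 0], [0, 0, 0, 1.0]])
data = {
    "image_1": MetaTensor(torch.zeros(1, 16, 16, 16), affine=affine),
    "image_2": MetaTensor(torch.zeros(1, 16, 16, 16), affine=affine),
    "point_1": torch.tensor([[[2.0, 4.0, 6.0]]]),   # shape (1, N, 3): one batch of N points
    "point_2": torch.tensor([[[8.0, 10.0, 12.0]]]),
}
transform = ApplyTransformToPointsd(
    keys=["point_1", "point_2"],
    refer_keys=["image_1", "image_2"],  # one reference key per point key
)
out = transform(data)  # each point set is mapped into the space of its reference image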

136 changes: 94 additions & 42 deletions tests/test_apply_transform_to_pointsd.py
@@ -30,72 +30,90 @@
POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])

TEST_CASES = [
[MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine
[None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_2D, affine=AFFINE_1),
POINT_2D_WORLD,
None,
True,
False,
POINT_2D_IMAGE,
],
True,
POINT_2D_IMAGE_RAS,
], # test affine_lps_to_ras
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself
[
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(DATA_3D, affine=AFFINE_2),
MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
None,
False,
False,
POINT_2D_WORLD,
POINT_3D_WORLD,
],
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
]
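# Field order in each TEST_CASES entry (inferred from test_transform_coordinates below):
#   (image, points, affine, invert_affine, affine_lps_to_ras, expected_output)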
TEST_CASES_SEQUENCE = [
[
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
False,
True,
False,
POINT_2D_WORLD,
],
["image_1", "image_2"],
[POINT_2D_IMAGE, POINT_3D_IMAGE],
], # use image affine
[
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_2D_WORLD,
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[POINT_2D_WORLD, POINT_3D_WORLD],
None,
True,
True,
POINT_2D_IMAGE_RAS,
],
["image_1", "image_2"],
[POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
], # test affine_lps_to_ras
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
(None, None),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
None,
[POINT_2D_WORLD, POINT_3D_WORLD],
], # use point affine
[
(None, None),
[POINT_2D_WORLD, POINT_2D_WORLD],
AFFINE_1,
True,
False,
POINT_3D_IMAGE,
],
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
None,
[POINT_2D_IMAGE, POINT_2D_IMAGE],
], # use input affine
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
None,
False,
False,
POINT_3D_WORLD,
],
[
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
POINT_3D_WORLD,
None,
True,
True,
POINT_3D_IMAGE_RAS,
["image_1", "image_2"],
[POINT_2D_WORLD, POINT_3D_WORLD],
],
]
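# Field order in each TEST_CASES_SEQUENCE entry (matching test_transform_coordinates_sequences below):
#   (image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output)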

TEST_CASES_WRONG = [
[POINT_2D_WORLD, True, None],
[POINT_2D_WORLD.unsqueeze(0), False, None],
[POINT_3D_WORLD[..., 0:1], False, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
[POINT_2D_WORLD, True, None, None],
[POINT_2D_WORLD.unsqueeze(0), False, None, None],
[POINT_3D_WORLD[..., 0:1], False, None, None],
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
[POINT_3D_WORLD, False, None, "image"],
[POINT_3D_WORLD, False, None, []],
]
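# Field order in each TEST_CASES_WRONG entry (matching test_wrong_input below):
#   (input, invert_affine, affine, refer_keys)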


@@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
"point": points,
"affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
}
refer_key = "image" if (image is not None and image != "affine") else image
refer_keys = "image" if (image is not None and image != "affine") else image
transform = ApplyTransformToPointsd(
keys="point",
refer_key=refer_key,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
@@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point"], points))

@parameterized.expand(TEST_CASES_SEQUENCE)
def test_transform_coordinates_sequences(
self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
):
data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
keys = ["point_1", "point_2"]
transform = ApplyTransformToPointsd(
keys=keys,
refer_keys=refer_keys,
dtype=torch.int64,
affine=affine,
invert_affine=invert_affine,
affine_lps_to_ras=affine_lps_to_ras,
)
output = transform(data)

self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
invert_out = transform.inverse(output)
self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))

@parameterized.expand(TEST_CASES_WRONG)
def test_wrong_input(self, input, invert_affine, affine):
transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
with self.assertRaises(ValueError):
transform({"point": input})
def test_wrong_input(self, input, invert_affine, affine, refer_keys):
if refer_keys == []:
with self.assertRaises(ValueError):
ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
else:
transform = ApplyTransformToPointsd(
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
)
data = {"point": input}
if refer_keys == "image":
with self.assertRaises(KeyError):
transform(data)
else:
with self.assertRaises(ValueError):
transform(data)


if __name__ == "__main__":