From 35c2b37f8be9f13cf4638a79f11e0df0887c4b07 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 12 Jan 2022 11:20:24 -0500 Subject: [PATCH] `EnsureChannelFirst`: avoid re-creation of `AddChannel` (#3649) * avoid re-creation of AddChannel in EnsureChannelFirst --- monai/inferers/inferer.py | 3 ++- monai/transforms/utility/array.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index c7b70e06ca..0b700d7c2e 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -30,7 +30,8 @@ class Inferer(ABC): Example code:: device = torch.device("cuda:0") - data = ToTensor()(LoadImage()(filename=img_path)).to(device) + transform = Compose([ToTensor(), LoadImage(image_only=True)]) + data = transform(img_path).to(device) model = UNet(...).to(device) inferer = SlidingWindowInferer(...) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index a107cf1cb1..664433270b 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -202,6 +202,7 @@ def __init__(self, strict_check: bool = True): strict_check: whether to raise an error when the meta information is insufficient. """ self.strict_check = strict_check + self.add_channel = AddChannel() def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor: """ @@ -223,7 +224,7 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> warnings.warn(msg) return img if channel_dim == "no_channel": - return AddChannel()(img) + return self.add_channel(img) return AsChannelFirst(channel_dim=channel_dim)(img)