|
14 | 14 | from monai.deploy.core import Image |
15 | 15 | from monai.deploy.operators.monai_bundle_inference_operator import MonaiBundleInferenceOperator, get_bundle_config |
16 | 16 | from monai.deploy.utils.importutil import optional_import |
| 17 | +from monai.transforms import ConcatItemsd, ResampleToMatch |
17 | 18 |
|
18 | 19 | torch, _ = optional_import("torch", "1.10.2") |
19 | | - |
| 20 | +MetaTensor, _ = optional_import("monai.data.meta_tensor", name="MetaTensor") |
20 | 21 | __all__ = ["MONetBundleInferenceOperator"] |
21 | 22 |
|
22 | 23 |
|
@@ -60,10 +61,18 @@ def _init_config(self, config_names): |
60 | 61 | self._nnunet_predictor = parser.get_parsed_content("network_def") |
61 | 62 |
|
62 | 63 | def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]: |
63 | | - """Predicts output using the inferer.""" |
| 64 | + """Predicts output using the inferer. If multimodal data is provided as keyword arguments, |
| 65 | + it concatenates the data with the main input data.""" |
64 | 66 |
|
65 | 67 | self._nnunet_predictor.predictor.network = self._model_network |
66 | 68 |
|
| 69 | + if len(kwargs) > 0: |
| 70 | + multimodal_data = {"image": data} |
| 71 | + for key in kwargs.keys(): |
| 72 | + if isinstance(kwargs[key], MetaTensor): |
| 73 | + multimodal_data[key] = ResampleToMatch(mode="bilinear")(kwargs[key], img_dst=data |
| 74 | + ) |
| 75 | + data = ConcatItemsd(keys=list(multimodal_data.keys()),name="image")(multimodal_data)["image"] |
67 | 76 | if len(data.shape) == 4: |
68 | 77 | data = data[None] |
69 | 78 | return self._nnunet_predictor(data) |
0 commit comments