From 9f27c154069dc796cdaa8ddd65c51832237846fc Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Mon, 24 Jun 2024 17:30:38 -0400 Subject: [PATCH 01/22] inital commit Signed-off-by: Pengfei Guo --- monai/apps/generation/__init__.py | 10 ++ monai/apps/generation/maisi/__init__.py | 10 ++ .../generation/maisi/networks/__init__.py | 10 ++ .../maisi/networks/controlnet_maisi.py | 163 ++++++++++++++++++ 4 files changed, 193 insertions(+) create mode 100644 monai/apps/generation/__init__.py create mode 100644 monai/apps/generation/maisi/__init__.py create mode 100644 monai/apps/generation/maisi/networks/__init__.py create mode 100644 monai/apps/generation/maisi/networks/controlnet_maisi.py diff --git a/monai/apps/generation/__init__.py b/monai/apps/generation/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/networks/__init__.py b/monai/apps/generation/maisi/networks/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/apps/generation/maisi/networks/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py new file mode 100644 index 0000000000..e44e39a2f4 --- /dev/null +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -0,0 +1,163 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import monai +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Sequence +from monai.networks.blocks import Convolution +from generative.networks.nets.controlnet import ControlNet +from generative.networks.nets.diffusion_model_unet import get_timestep_embedding + + +class CustomControlNet(ControlNet): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + ) -> None: + super().__init__( + spatial_dims, + in_channels, + num_res_blocks, + num_channels, + attention_levels, + norm_num_groups, + norm_eps, + resblock_updown, + num_head_channels, + with_conditioning, + transformer_num_layers, + cross_attention_dim, + num_class_embeds, + upcast_attention, + use_flash_attention, + conditioning_embedding_in_channels, + conditioning_embedding_num_channels, + ) + + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims). + conditioning_scale: conditioning scale. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + controlnet_cond = torch.utils.checkpoint.checkpoint(self.controlnet_cond_embedding, + controlnet_cond, + use_reentrant=False) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample From 63290b5ec1e5ec42477859ae117641f3efd5571a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jun 2024 21:36:24 +0000 Subject: [PATCH 02/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index e44e39a2f4..2da6f0b6b9 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -9,14 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import monai -import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from typing import Sequence -from monai.networks.blocks import Convolution from generative.networks.nets.controlnet import ControlNet from generative.networks.nets.diffusion_model_unet import get_timestep_embedding From 15fb9f6f9d2dec78d66ba579b22758d4c9bb47b8 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 18:28:31 -0400 Subject: [PATCH 03/22] add unit test and refactor forward Signed-off-by: Pengfei Guo --- .../maisi/networks/controlnet_maisi.py | 71 ++++++++++------ requirements-dev.txt | 1 + tests/test_controlnet_maisi.py | 81 +++++++++++++++++++ 3 files changed, 127 insertions(+), 26 deletions(-) create mode 100644 tests/test_controlnet_maisi.py diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 2da6f0b6b9..663f68980f 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -9,14 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch +from __future__ import annotations from typing import Sequence -from generative.networks.nets.controlnet import ControlNet -from generative.networks.nets.diffusion_model_unet import get_timestep_embedding + +import torch + +from monai.utils import optional_import + +ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") +get_timestep_embedding, has_get_timestep_embedding = optional_import( + "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +) + +__all__ = ["ControlNetMaisi"] -class CustomControlNet(ControlNet): +class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) @@ -40,6 +49,7 @@ class CustomControlNet(ControlNet): use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + use_checkpointing: if True, use activation checkpointing to save memory. """ def __init__( @@ -61,6 +71,7 @@ def __init__( use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + use_checkpointing: bool = True, ) -> None: super().__init__( spatial_dims, @@ -81,7 +92,7 @@ def __init__( conditioning_embedding_in_channels, conditioning_embedding_num_channels, ) - + self.use_checkpointing = use_checkpointing def forward( self, @@ -92,15 +103,25 @@ def forward( context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, ) -> tuple[tuple[torch.Tensor], torch.Tensor]: - """ - Args: - x: input tensor (N, C, SpatialDims). - timesteps: timestep tensor (N,). - controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims). - conditioning_scale: conditioning scale. - context: context tensor (N, 1, ContextDim). - class_labels: context tensor (N, ). - """ + emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) + h = self._apply_initial_convolution(x) + if self.use_checkpointing: + controlnet_cond = torch.utils.checkpoint.checkpoint( + self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False + ) + else: + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + h += controlnet_cond + down_block_res_samples, h = self._apply_down_blocks(emb, context, h) + h = self._apply_mid_block(emb, context, h) + down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples) + # scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample + + def _prepare_time_and_class_embedding(self, x, timesteps, class_labels): # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) @@ -118,16 +139,14 @@ def forward( class_emb = class_emb.to(dtype=x.dtype) emb = emb + class_emb + return emb + + def _apply_initial_convolution(self, x): # 3. initial convolution h = self.conv_in(x) + return h - # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - controlnet_cond = torch.utils.checkpoint.checkpoint(self.controlnet_cond_embedding, - controlnet_cond, - use_reentrant=False) - - h += controlnet_cond - + def _apply_down_blocks(self, emb, context, h): # 4. down if context is not None and self.with_conditioning is False: raise ValueError("model should have with_conditioning = True if context is provided") @@ -137,12 +156,16 @@ def forward( for residual in res_samples: down_block_res_samples.append(residual) + return down_block_res_samples, h + + def _apply_mid_block(self, emb, context, h): # 5. mid h = self.middle_block(hidden_states=h, temb=emb, context=context) + return h + def _apply_controlnet_blocks(self, h, down_block_res_samples): # 6. Control net blocks controlnet_down_block_res_samples = () - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples += (down_block_res_sample,) @@ -151,8 +174,4 @@ def forward( mid_block_res_sample = self.controlnet_mid_block(h) - # 6. scaling - down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] - mid_block_res_sample *= conditioning_scale - return down_block_res_samples, mid_block_res_sample diff --git a/requirements-dev.txt b/requirements-dev.txt index a8ba25966b..37e5917c6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,3 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub +monai-generative diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py new file mode 100644 index 0000000000..6c3cdefa55 --- /dev/null +++ b/tests/test_controlnet_maisi.py @@ -0,0 +1,81 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.utils import optional_import + +_, has_generative = optional_import("generative") + +if has_generative: + from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi + +TEST_CASES = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + "use_checkpointing": False, + }, + 6, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + "use_checkpointing": True, + }, + 6, + (1, 8, 4, 4, 4), + ], +] + + +@skipUnless(has_generative, "monai-generative required") +class TestControlNet(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): + net = ControlNetMaisi(**input_param) + with eval_mode(net): + x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16)) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = ( + torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32)) + ) + result = net.forward(x, timesteps, controlnet_cond) + self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) + self.assertEqual(result[1].shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 55ad03186a865538b19302ffc2f64228f9523567 Mon Sep 17 00:00:00 2001 From: Adam Klimont Date: Tue, 25 Jun 2024 09:39:39 +0100 Subject: [PATCH 04/22] Change deprecated scipy.ndimage namespaces in optional imports (#7847) Fixes #7677 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. --------- Signed-off-by: alkamid Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo --- monai/apps/deepedit/transforms.py | 2 +- monai/apps/deepgrow/transforms.py | 2 +- monai/apps/nuclick/transforms.py | 2 +- monai/apps/pathology/transforms/post/array.py | 2 +- monai/apps/pathology/utils.py | 4 ++-- monai/metrics/utils.py | 6 +++--- monai/transforms/signal/array.py | 2 +- 7 files changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py index 6d0825f54a..5af082e2b0 100644 --- a/monai/apps/deepedit/transforms.py +++ b/monai/apps/deepedit/transforms.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) -distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") class DiscardAddGuidanced(MapTransform): diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 9aca77a36c..c2f97091fd 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -27,7 +27,7 @@ from monai.utils.enums import PostFix measure, _ = optional_import("skimage.measure", "0.14.2", min_version) -distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") DEFAULT_POST_FIX = PostFix.meta() diff --git a/monai/apps/nuclick/transforms.py b/monai/apps/nuclick/transforms.py index f22ea764be..4828bd2e5a 100644 --- a/monai/apps/nuclick/transforms.py +++ b/monai/apps/nuclick/transforms.py @@ -24,7 +24,7 @@ measure, _ = optional_import("skimage.measure") morphology, _ = optional_import("skimage.morphology") -distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") class NuclickKeys(StrEnum): diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py index 99e94f89c0..42ca385fa0 100644 --- a/monai/apps/pathology/transforms/post/array.py +++ b/monai/apps/pathology/transforms/post/array.py @@ -33,7 +33,7 @@ from monai.utils.misc import ensure_tuple_rep from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor -label, _ = optional_import("scipy.ndimage.measurements", name="label") +label, _ = optional_import("scipy.ndimage", name="label") disk, _ = optional_import("skimage.morphology", name="disk") opening, _ = optional_import("skimage.morphology", name="opening") watershed, _ = optional_import("skimage.segmentation", name="watershed") diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py index d3ebe0a7a6..3aa0bfab86 100644 --- a/monai/apps/pathology/utils.py +++ b/monai/apps/pathology/utils.py @@ -33,10 +33,10 @@ def compute_multi_instance_mask(mask: np.ndarray, threshold: float) -> Any: """ neg = 255 - mask * 255 - distance = ndimage.morphology.distance_transform_edt(neg) + distance = ndimage.distance_transform_edt(neg) binary = distance < threshold - filled_image = ndimage.morphology.binary_fill_holes(binary) + filled_image = ndimage.binary_fill_holes(binary) multi_instance_mask = measure.label(filled_image, connectivity=2) return multi_instance_mask diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index e7057256fb..340e54a1d7 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -35,9 +35,9 @@ optional_import, ) -binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") -distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") -distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") +binary_erosion, _ = optional_import("scipy.ndimage", name="binary_erosion") +distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt") +distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt") __all__ = [ "ignore_background", diff --git a/monai/transforms/signal/array.py b/monai/transforms/signal/array.py index 938f42192c..97df04f233 100644 --- a/monai/transforms/signal/array.py +++ b/monai/transforms/signal/array.py @@ -28,7 +28,7 @@ from monai.utils.enums import TransformBackends from monai.utils.type_conversion import convert_data_type, convert_to_tensor -shift, has_shift = optional_import("scipy.ndimage.interpolation", name="shift") +shift, has_shift = optional_import("scipy.ndimage", name="shift") iirnotch, has_iirnotch = optional_import("scipy.signal", name="iirnotch") with warnings.catch_warnings(): warnings.simplefilter("ignore", UserWarning) # project-monai/monai#5204 From 2ea47187366b13cb85fcf3d64f055a481f10ab68 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 18:57:06 -0400 Subject: [PATCH 05/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 663f68980f..0135461e6c 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING, Sequence, cast import torch @@ -22,7 +22,10 @@ "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" ) -__all__ = ["ControlNetMaisi"] +if TYPE_CHECKING: + from generative.networks.nets.controlnet import ControlNet as ControlNetType +else: + ControlNetType = cast(type, ControlNet) class ControlNetMaisi(ControlNet): From b191e7f36d2bd9e08accddad3e4d47b2e9995bdb Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 19:12:20 -0400 Subject: [PATCH 06/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 0135461e6c..63940debca 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -28,7 +28,7 @@ ControlNetType = cast(type, ControlNet) -class ControlNetMaisi(ControlNet): +class ControlNetMaisi(ControlNetType): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) From 24c6fb2d626a7d28ffc61907b3ca8b5224cb4513 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 19:57:25 -0400 Subject: [PATCH 07/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 63940debca..9de9269803 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 63ca2411d4232f5a4addc6d1bc8f220131bca791 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 20:41:05 -0400 Subject: [PATCH 08/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 9de9269803..7d99870519 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, cast +from typing import TYPE_CHECKING, Any, Sequence, cast import torch @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + ) -> tuple[tuple[Any, ...], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From bca7aa223be14248040b6f43d4af455e15c81f3c Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 20:54:58 -0400 Subject: [PATCH 09/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 7d99870519..a8e3d22f4c 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence, cast +from typing import TYPE_CHECKING, Sequence, cast import torch @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[Any, ...], torch.Tensor]: + ) -> tuple[tuple[torch.Tensor], torch.Tensor] | tuple[()]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 6413581be39d02fc503715a3990daf32c9629662 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 21:09:31 -0400 Subject: [PATCH 10/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index a8e3d22f4c..fd50e1ea6d 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[torch.Tensor], torch.Tensor] | tuple[()]: + ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 8bdbc79795b8c2963c7c3b6d88e7c1056c70a114 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 21:38:17 -0400 Subject: [PATCH 11/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index fd50e1ea6d..2066887eb8 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + ): emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 1daeddfe810dbaeb0bbb3c00f90210c0521d90d9 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 21:53:26 -0400 Subject: [PATCH 12/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 2066887eb8..e182808ca4 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ): + ) -> tuple[tuple[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: @@ -171,7 +171,7 @@ def _apply_controlnet_blocks(self, h, down_block_res_samples): controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) + controlnet_down_block_res_samples += (down_block_res_sample,) # type: ignore down_block_res_samples = controlnet_down_block_res_samples From f992471c3b8654eff1f6db3c8df8e374f3fbd223 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 25 Jun 2024 22:08:22 -0400 Subject: [PATCH 13/22] update Signed-off-by: Pengfei Guo --- monai/apps/generation/maisi/networks/controlnet_maisi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index e182808ca4..696e83378d 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -105,7 +105,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[tuple[torch.Tensor], torch.Tensor]: + ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: From 9ae94fe49f09407f946bd06342fe040c68ee7f30 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:50:04 -0400 Subject: [PATCH 14/22] Update monai/apps/generation/maisi/networks/controlnet_maisi.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Signed-off-by: Pengfei Guo --- .../generation/maisi/networks/controlnet_maisi.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 696e83378d..e960975dce 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -168,13 +168,11 @@ def _apply_mid_block(self, emb, context, h): def _apply_controlnet_blocks(self, h, down_block_res_samples): # 6. Control net blocks - controlnet_down_block_res_samples = () - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) # type: ignore + controlnet_down_block_res_samples = [] + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) - down_block_res_samples = controlnet_down_block_res_samples + mid_block_res_sample = self.controlnet_mid_block(h) - mid_block_res_sample = self.controlnet_mid_block(h) - - return down_block_res_samples, mid_block_res_sample + return controlnet_down_block_res_samples, mid_block_res_sample From c342b23af1c7030691aa59ebdce4561d742ae165 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Wed, 26 Jun 2024 15:13:58 -0400 Subject: [PATCH 15/22] add more test cases Signed-off-by: Pengfei Guo --- .../maisi/networks/controlnet_maisi.py | 12 +-- tests/test_controlnet_maisi.py | 86 +++++++++++++++++++ 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index e960975dce..3641124b7d 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -168,11 +168,11 @@ def _apply_mid_block(self, emb, context, h): def _apply_controlnet_blocks(self, h, down_block_res_samples): # 6. Control net blocks - controlnet_down_block_res_samples = [] - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples.append(down_block_res_sample) + controlnet_down_block_res_samples = [] + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples.append(down_block_res_sample) - mid_block_res_sample = self.controlnet_mid_block(h) + mid_block_res_sample = self.controlnet_mid_block(h) - return controlnet_down_block_res_samples, mid_block_res_sample + return controlnet_down_block_res_samples, mid_block_res_sample diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 6c3cdefa55..06c2de6d60 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -60,6 +60,71 @@ ], ] +TEST_CASES_CONDITIONAL = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + "use_checkpointing": False, + "with_conditioning": True, + "cross_attention_dim": 2, + }, + 6, + (1, 8, 4, 4), + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + "conditioning_embedding_in_channels": 1, + "conditioning_embedding_num_channels": (8, 8), + "use_checkpointing": True, + "with_conditioning": True, + "cross_attention_dim": 2, + }, + 6, + (1, 8, 4, 4, 4), + ], +] + +TEST_CASES_ERROR = [ + [ + {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, + "ControlNet expects dimension of the cross-attention conditioning " + "(cross_attention_dim) when using with_conditioning.", + ], + [ + {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, + "ControlNet expects with_conditioning=True when specifying the cross_attention_dim.", + ], + [ + {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, + "ControlNet expects all num_channels being multiple of norm_num_groups", + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "num_channels": (8, 16), + "attention_levels": (True,), + "norm_num_groups": 8, + }, + "ControlNet expects num_channels being same size of attention_levels", + ], +] + @skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): @@ -76,6 +141,27 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) self.assertEqual(result[1].shape, expected_shape) + @parameterized.expand(TEST_CASES_CONDITIONAL) + def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): + net = ControlNetMaisi(**input_param) + with eval_mode(net): + x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16)) + timesteps = torch.randint(0, 1000, (1,)).long() + controlnet_cond = ( + torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32)) + ) + context = torch.randn((1, 1, input_param["cross_attention_dim"])) + result = net.forward(x, timesteps, controlnet_cond, context=context) + self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) + self.assertEqual(result[1].shape, expected_shape) + + @parameterized.expand(TEST_CASES_ERROR) + def test_error_input(self, input_param, expected_error): + with self.assertRaises(ValueError) as context: # output shape too small + _ = ControlNetMaisi(**input_param) + runtime_error = context.exception + self.assertEqual(str(runtime_error), expected_error) + if __name__ == "__main__": unittest.main() From 6d24a5190897f4199cc45b23ac55f310fb98d098 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 28 Jun 2024 12:50:13 -0400 Subject: [PATCH 16/22] update torch version req Signed-off-by: Pengfei Guo --- tests/test_controlnet_maisi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 06c2de6d60..9a44b66505 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -127,6 +127,7 @@ @skipUnless(has_generative, "monai-generative required") +@skipUnless(torch.__version__ >= "2.0", "torch>=2.0 required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): From 5f85e3b9b2b8ce7bd206b035eb5a214c4dc4f5d9 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 28 Jun 2024 13:23:10 -0400 Subject: [PATCH 17/22] update torch version req Signed-off-by: Pengfei Guo --- tests/test_controlnet_maisi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 9a44b66505..b522b750c8 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -19,6 +19,7 @@ from monai.networks import eval_mode from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion _, has_generative = optional_import("generative") @@ -126,8 +127,8 @@ ] +@SkipIfBeforePyTorchVersion((2, 0)) @skipUnless(has_generative, "monai-generative required") -@skipUnless(torch.__version__ >= "2.0", "torch>=2.0 required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): From 403bf9bffa353878943363ba6561127f18cee3a3 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 28 Jun 2024 13:34:24 -0400 Subject: [PATCH 18/22] update pre-commit-config Signed-off-by: Pengfei Guo --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9debaf08f..3fff6ed631 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace From fcc95ccebeb139275c27cfbd435c0b0bc76aec99 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 1 Jul 2024 13:39:02 +0800 Subject: [PATCH 19/22] temp test Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4c9b8c6b75..8aae3bc312 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,5 +57,5 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub -monai-generative +# monai-generative pyamg>=5.0.0 From 8eedb252b2d6c03d3887fe582b451eca7bce5369 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 1 Jul 2024 13:52:47 +0800 Subject: [PATCH 20/22] fix flake8 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/torchscript_utils.py | 2 +- requirements-dev.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index cabf06ce89..507cf411d6 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -116,7 +116,7 @@ def load_net_with_metadata( Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ - extra_files = {f: "" for f in more_extra_files} + extra_files = dict.fromkeys(more_extra_files, "") extra_files[METADATA_FILENAME] = "" jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) diff --git a/requirements-dev.txt b/requirements-dev.txt index 8aae3bc312..517c842d1e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -57,5 +57,4 @@ zarr lpips==0.1.4 nvidia-ml-py huggingface_hub -# monai-generative pyamg>=5.0.0 From 984c5da610ac31d74272688f5e94c6342e208ee6 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 1 Jul 2024 17:34:19 +0800 Subject: [PATCH 21/22] temp-fix Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 517c842d1e..e5eb28ff7c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -58,3 +58,4 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0 +git+https://github.com/KumoLiu/GenerativeModels.git@cuda#egg=monai-generative \ No newline at end of file From ebcf787fbcadb99f830a4ff5bc49b999b1c9acd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:36:02 +0000 Subject: [PATCH 22/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index e5eb28ff7c..b598f301f6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -58,4 +58,4 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0 -git+https://github.com/KumoLiu/GenerativeModels.git@cuda#egg=monai-generative \ No newline at end of file +git+https://github.com/KumoLiu/GenerativeModels.git@cuda#egg=monai-generative