diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
index d5f5f6136b..e990b5fc98 100644
--- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
+++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
@@ -37,21 +37,15 @@
 from torch import nn
 
 from monai.networks.blocks import Convolution
-from monai.utils import ensure_tuple_rep, optional_import
-from monai.utils.type_conversion import convert_to_tensor
-
-get_down_block, has_get_down_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
-)
-get_mid_block, has_get_mid_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
-)
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
+from monai.networks.nets.diffusion_model_unet import (
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    get_up_block,
+    zero_module,
 )
-get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
-xformers, has_xformers = optional_import("xformers")
-zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor
 
 __all__ = ["DiffusionModelUNetMaisi"]
 
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
@@ -102,6 +98,8 @@ def __init__(
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         dropout_cattn: float = 0.0,
         include_top_region_index_input: bool = False,
@@ -152,9 +150,6 @@ def __init__(
                 "`num_channels`."
             )
 
-        if use_flash_attention and not has_xformers:
-            raise ValueError("use_flash_attention is True but xformers is not installed.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
            raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
@@ -210,7 +205,6 @@ def __init__(
             input_channel = output_channel
             output_channel = num_channels[i]
             is_final_block = i == len(num_channels) - 1
-
             down_block = get_down_block(
                 spatial_dims=spatial_dims,
                 in_channels=input_channel,
@@ -227,6 +221,8 @@
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
@@ -245,6 +241,8 @@
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             dropout_cattn=dropout_cattn,
         )
@@ -280,6 +278,8 @@
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py
index 059a4a4ba8..f9384e6d82 100644
--- a/tests/test_diffusion_model_unet_maisi.py
+++ b/tests/test_diffusion_model_unet_maisi.py
@@ -17,14 +17,11 @@
 import torch
 from parameterized import parameterized
 
+from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
 from monai.networks import eval_mode
 from monai.utils import optional_import
 
 _, has_einops = optional_import("einops")
-_, has_generative = optional_import("generative")
-
-if has_generative:
-    from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
 
 UNCOND_CASES_2D = [
     [
@@ -291,7 +288,6 @@
 ]
 
 
-@skipUnless(has_generative, "monai-generative required")
 class TestDiffusionModelUNetMaisi2D(unittest.TestCase):
 
     @parameterized.expand(UNCOND_CASES_2D)
@@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param):
             self.assertEqual(result.shape, (1, 1, 16, 16))
 
 
-@skipUnless(has_generative, "monai-generative required")
 class TestDiffusionModelUNetMaisi3D(unittest.TestCase):
 
     @parameterized.expand(UNCOND_CASES_3D)
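
Usage note (not part of the patch): with the helpers now imported from monai.networks.nets.diffusion_model_unet, the MAISI UNet can be constructed and run without monai-generative or optional_import guards, and the two new attention flags are plain constructor arguments. The sketch below assumes a toy 2D configuration chosen to keep the example small; only include_fc, use_combined_linear, and the (1, 1, 16, 16) output shape are taken from this diff, the remaining argument values are illustrative.

import torch

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

# Assumed toy configuration; channel/attention settings are illustrative, not from this patch.
net = DiffusionModelUNetMaisi(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    norm_num_groups=8,
    include_fc=False,  # new flag: include the final linear layer in the attention blocks
    use_combined_linear=False,  # new flag: use a single combined linear for qkv projection
)

x = torch.rand(1, 1, 16, 16)
timesteps = torch.randint(0, 1000, (1,)).long()
with torch.no_grad():
    out = net(x, timesteps)  # output shape matches the input, here (1, 1, 16, 16)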