Refactor DiffusionModelUNetMaisi #7989

Merged
36 changes: 18 additions & 18 deletions monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
@@ -37,21 +37,15 @@
 from torch import nn

 from monai.networks.blocks import Convolution
-from monai.utils import ensure_tuple_rep, optional_import
-from monai.utils.type_conversion import convert_to_tensor
-
-get_down_block, has_get_down_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
-)
-get_mid_block, has_get_mid_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
-)
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
-)
-get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
-xformers, has_xformers = optional_import("xformers")
-zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
+from monai.networks.nets.diffusion_model_unet import (
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    get_up_block,
+    zero_module,
+)
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor

 __all__ = ["DiffusionModelUNetMaisi"]
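
For context on this change: `optional_import` resolves a symbol from an optional dependency and returns a boolean flag alongside it, deferring the ImportError until the symbol is actually used, whereas a plain import fails immediately if the module is missing. A minimal sketch of the two patterns, using the same helper names as above:

from monai.utils import optional_import

# Old pattern: the helper comes from the optional `generative` package;
# `has_get_down_block` is False when that package is not installed.
get_down_block, has_get_down_block = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
)

# New pattern: the helper now lives in MONAI core, so it is imported directly
# and no `has_*` guard flag is needed.
from monai.networks.nets.diffusion_model_unet import get_down_block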

@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
@@ -102,6 +98,8 @@ def __init__(
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         dropout_cattn: float = 0.0,
         include_top_region_index_input: bool = False,
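
The two new flags are forwarded to the attention blocks further down. As an illustration only, constructing the network with them might look as follows; apart from `include_fc`, `use_combined_linear`, and `use_flash_attention`, the argument names and values are assumptions modeled on the existing 2D test cases.

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

# Hypothetical configuration; only the last level enables attention, so the
# new attention-block flags actually take effect.
net = DiffusionModelUNetMaisi(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    norm_num_groups=8,
    include_fc=False,  # no final linear layer in the attention blocks
    use_combined_linear=False,  # separate q/k/v projections
    use_flash_attention=False,  # enabling this requires a CUDA device
)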
@@ -152,9 +150,6 @@ def __init__(
                 "`num_channels`."
             )

-        if use_flash_attention and not has_xformers:
-            raise ValueError("use_flash_attention is True but xformers is not installed.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
@@ -210,7 +205,6 @@ def __init__(
             input_channel = output_channel
             output_channel = num_channels[i]
             is_final_block = i == len(num_channels) - 1
-
             down_block = get_down_block(
                 spatial_dims=spatial_dims,
                 in_channels=input_channel,
@@ -227,6 +221,8 @@
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
@@ -245,6 +241,8 @@
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             dropout_cattn=dropout_cattn,
         )
@@ -280,6 +278,8 @@ def __init__(
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
7 changes: 1 addition & 6 deletions tests/test_diffusion_model_unet_maisi.py
@@ -17,14 +17,11 @@
 import torch
 from parameterized import parameterized

+from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
 from monai.networks import eval_mode
 from monai.utils import optional_import

 _, has_einops = optional_import("einops")
-_, has_generative = optional_import("generative")
-
-if has_generative:
-    from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

 UNCOND_CASES_2D = [
     [
@@ -291,7 +288,6 @@
 ]


-@skipUnless(has_generative, "monai-generative required")
 class TestDiffusionModelUNetMaisi2D(unittest.TestCase):

     @parameterized.expand(UNCOND_CASES_2D)
@@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param):
         self.assertEqual(result.shape, (1, 1, 16, 16))


-@skipUnless(has_generative, "monai-generative required")
 class TestDiffusionModelUNetMaisi3D(unittest.TestCase):

     @parameterized.expand(UNCOND_CASES_3D)
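
With the `generative` guard gone, the tests presumably fall back to the einops guard alone for optional dependencies, following the pattern already visible in the imports above; the decorator placement below is an assumption for illustration.

import unittest
from unittest import skipUnless

from monai.utils import optional_import

_, has_einops = optional_import("einops")


class TestOptionalDependencyGuard(unittest.TestCase):
    @skipUnless(has_einops, "Requires einops")
    def test_runs_only_with_einops(self):
        # Hypothetical test body; the real tests build DiffusionModelUNetMaisi here.
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()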