Skip to content

Commit

Permalink
add unit test and refactor forward
Browse files Browse the repository at this point in the history
  • Loading branch information
guopengf committed Jun 25, 2024
1 parent 9c9acb0 commit 7e2a8b0
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 26 deletions.
71 changes: 45 additions & 26 deletions monai/apps/generation/maisi/networks/controlnet_maisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from __future__ import annotations

from typing import Sequence
from generative.networks.nets.controlnet import ControlNet
from generative.networks.nets.diffusion_model_unet import get_timestep_embedding

import torch

from monai.utils import optional_import

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

__all__ = ["ControlNetMaisi"]


class CustomControlNet(ControlNet):
class ControlNetMaisi(ControlNet):
"""
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
Diffusion Models" (https://arxiv.org/abs/2302.05543)
Expand All @@ -40,6 +49,7 @@ class CustomControlNet(ControlNet):
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
use_checkpointing: if True, use activation checkpointing to save memory.
"""

def __init__(
Expand All @@ -61,6 +71,7 @@ def __init__(
use_flash_attention: bool = False,
conditioning_embedding_in_channels: int = 1,
conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
use_checkpointing: bool = True,
) -> None:
super().__init__(
spatial_dims,
Expand All @@ -81,7 +92,7 @@ def __init__(
conditioning_embedding_in_channels,
conditioning_embedding_num_channels,
)

self.use_checkpointing = use_checkpointing

def forward(
self,
Expand All @@ -92,15 +103,25 @@ def forward(
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
) -> tuple[tuple[torch.Tensor], torch.Tensor]:
"""
Args:
x: input tensor (N, C, SpatialDims).
timesteps: timestep tensor (N,).
controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims).
conditioning_scale: conditioning scale.
context: context tensor (N, 1, ContextDim).
class_labels: class label tensor (N, ).
"""
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
h = self._apply_initial_convolution(x)
if self.use_checkpointing:
controlnet_cond = torch.utils.checkpoint.checkpoint(
self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False
)
else:
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
h += controlnet_cond
down_block_res_samples, h = self._apply_down_blocks(emb, context, h)
h = self._apply_mid_block(emb, context, h)
down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)
# scaling
down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample

def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
# 1. time
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

Expand All @@ -118,16 +139,14 @@ def forward(
class_emb = class_emb.to(dtype=x.dtype)
emb = emb + class_emb

return emb

def _apply_initial_convolution(self, x):
# 3. initial convolution
h = self.conv_in(x)
return h

# controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
controlnet_cond = torch.utils.checkpoint.checkpoint(self.controlnet_cond_embedding,
controlnet_cond,
use_reentrant=False)

h += controlnet_cond

def _apply_down_blocks(self, emb, context, h):
# 4. down
if context is not None and self.with_conditioning is False:
raise ValueError("model should have with_conditioning = True if context is provided")
Expand All @@ -137,12 +156,16 @@ def forward(
for residual in res_samples:
down_block_res_samples.append(residual)

return down_block_res_samples, h

def _apply_mid_block(self, emb, context, h):
# 5. mid
h = self.middle_block(hidden_states=h, temb=emb, context=context)
return h

def _apply_controlnet_blocks(self, h, down_block_res_samples):
# 6. Control net blocks
controlnet_down_block_res_samples = ()

for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples += (down_block_res_sample,)
Expand All @@ -151,8 +174,4 @@ def forward(

mid_block_res_sample = self.controlnet_mid_block(h)

# 6. scaling
down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
monai-generative
81 changes: 81 additions & 0 deletions tests/test_controlnet_maisi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.utils import optional_import

_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

# Each case is [constructor kwargs, expected number of down-block residuals, expected mid-block output shape].
# The 2D case runs without activation checkpointing and the 3D case with it; all other settings are shared.
TEST_CASES = [
    [
        {
            "spatial_dims": dims,
            "in_channels": 1,
            "num_res_blocks": 1,
            "num_channels": (8, 8, 8),
            "attention_levels": (False, False, True),
            "num_head_channels": 8,
            "norm_num_groups": 8,
            "conditioning_embedding_in_channels": 1,
            "conditioning_embedding_num_channels": (8, 8),
            "use_checkpointing": checkpointing,
        },
        6,
        (1, 8) + (4,) * dims,
    ]
    for dims, checkpointing in ((2, False), (3, True))
]


@skipUnless(has_generative, "monai-generative required")
class TestControlNet(unittest.TestCase):
    """Output-structure checks for ``ControlNetMaisi`` with and without activation checkpointing."""

    @parameterized.expand(TEST_CASES)
    def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
        """
        Run one forward pass without context or class labels and check the outputs.

        Args:
            input_param: keyword arguments used to construct ``ControlNetMaisi``.
            expected_num_down_blocks_residuals: expected length of the first returned element.
            expected_shape: expected shape of the second returned element.
        """
        spatial_dims = input_param["spatial_dims"]
        net = ControlNetMaisi(**input_param)
        with eval_mode(net):
            # build the inputs from ``spatial_dims`` instead of hard-coding a 2D/3D switch;
            # the conditioning image is twice the spatial size of the network input
            x = torch.rand((1, 1) + (16,) * spatial_dims)
            timesteps = torch.randint(0, 1000, (1,)).long()
            controlnet_cond = torch.rand((1, 1) + (32,) * spatial_dims)
            # call the module itself rather than ``forward`` so that registered hooks are run
            result = net(x, timesteps, controlnet_cond)
            self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
            self.assertEqual(result[1].shape, expected_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 7e2a8b0

Please sign in to comment.