Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Tailored ControlNet Implementations into Generative Model Application #7875

Merged
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand Down
10 changes: 10 additions & 0 deletions monai/apps/generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
10 changes: 10 additions & 0 deletions monai/apps/generation/maisi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
10 changes: 10 additions & 0 deletions monai/apps/generation/maisi/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
178 changes: 178 additions & 0 deletions monai/apps/generation/maisi/networks/controlnet_maisi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, cast

import torch

from monai.utils import optional_import

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

if TYPE_CHECKING:
from generative.networks.nets.controlnet import ControlNet as ControlNetType
else:
ControlNetType = cast(type, ControlNet)


class ControlNetMaisi(ControlNetType):
"""
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
Diffusion Models" (https://arxiv.org/abs/2302.05543)

Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
num_res_blocks: number of residual blocks (see ResnetBlock) per level.
num_channels: tuple of block output channels.
attention_levels: list of levels to add attention.
norm_num_groups: number of groups for the normalization.
norm_eps: epsilon for the normalization.
resblock_updown: if True use residual blocks for up/downsampling.
num_head_channels: number of channels in each attention head.
with_conditioning: if True add spatial transformers to perform conditioning.
transformer_num_layers: number of layers of Transformer blocks to use.
cross_attention_dim: number of context dimensions to use.
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
classes.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
use_checkpointing: if True, use activation checkpointing to save memory.
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
num_channels: Sequence[int] = (32, 64, 64, 64),
attention_levels: Sequence[bool] = (False, False, True, True),
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
resblock_updown: bool = False,
num_head_channels: int | Sequence[int] = 8,
with_conditioning: bool = False,
transformer_num_layers: int = 1,
cross_attention_dim: int | None = None,
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
conditioning_embedding_in_channels: int = 1,
conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
use_checkpointing: bool = True,
) -> None:
super().__init__(
spatial_dims,
in_channels,
num_res_blocks,
num_channels,
attention_levels,
norm_num_groups,
norm_eps,
resblock_updown,
num_head_channels,
with_conditioning,
transformer_num_layers,
cross_attention_dim,
num_class_embeds,
upcast_attention,
use_flash_attention,
conditioning_embedding_in_channels,
conditioning_embedding_num_channels,
)
self.use_checkpointing = use_checkpointing

def forward(
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
self,
x: torch.Tensor,
timesteps: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
context: torch.Tensor | None = None,
class_labels: torch.Tensor | None = None,
) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
h = self._apply_initial_convolution(x)
if self.use_checkpointing:
controlnet_cond = torch.utils.checkpoint.checkpoint(
self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False
)
else:
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
h += controlnet_cond
down_block_res_samples, h = self._apply_down_blocks(emb, context, h)
h = self._apply_mid_block(emb, context, h)
down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)
# scaling
down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
mid_block_res_sample *= conditioning_scale

return down_block_res_samples, mid_block_res_sample

def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
# 1. time
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=x.dtype)
emb = self.time_embed(t_emb)

# 2. class
if self.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels)
class_emb = class_emb.to(dtype=x.dtype)
emb = emb + class_emb

return emb

def _apply_initial_convolution(self, x):
# 3. initial convolution
h = self.conv_in(x)
return h

def _apply_down_blocks(self, emb, context, h):
# 4. down
if context is not None and self.with_conditioning is False:
raise ValueError("model should have with_conditioning = True if context is provided")
down_block_res_samples: list[torch.Tensor] = [h]
for downsample_block in self.down_blocks:
h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
for residual in res_samples:
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
down_block_res_samples.append(residual)

return down_block_res_samples, h

def _apply_mid_block(self, emb, context, h):
# 5. mid
h = self.middle_block(hidden_states=h, temb=emb, context=context)
return h

def _apply_controlnet_blocks(self, h, down_block_res_samples):
# 6. Control net blocks
controlnet_down_block_res_samples = []
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples.append(down_block_res_sample)

mid_block_res_sample = self.controlnet_mid_block(h)

return controlnet_down_block_res_samples, mid_block_res_sample
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
monai-generative
pyamg>=5.0.0
169 changes: 169 additions & 0 deletions tests/test_controlnet_maisi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

_, has_generative = optional_import("generative")

if has_generative:
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

TEST_CASES = [
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
[
{
"spatial_dims": 2,
"in_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 8,
"norm_num_groups": 8,
"conditioning_embedding_in_channels": 1,
"conditioning_embedding_num_channels": (8, 8),
"use_checkpointing": False,
},
6,
(1, 8, 4, 4),
],
[
{
"spatial_dims": 3,
"in_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 8,
"norm_num_groups": 8,
"conditioning_embedding_in_channels": 1,
"conditioning_embedding_num_channels": (8, 8),
"use_checkpointing": True,
},
6,
(1, 8, 4, 4, 4),
],
]

TEST_CASES_CONDITIONAL = [
[
{
"spatial_dims": 2,
"in_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 8,
"norm_num_groups": 8,
"conditioning_embedding_in_channels": 1,
"conditioning_embedding_num_channels": (8, 8),
"use_checkpointing": False,
"with_conditioning": True,
"cross_attention_dim": 2,
},
6,
(1, 8, 4, 4),
],
[
{
"spatial_dims": 3,
"in_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 8,
"norm_num_groups": 8,
"conditioning_embedding_in_channels": 1,
"conditioning_embedding_num_channels": (8, 8),
"use_checkpointing": True,
"with_conditioning": True,
"cross_attention_dim": 2,
},
6,
(1, 8, 4, 4, 4),
],
]

TEST_CASES_ERROR = [
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None},
"ControlNet expects dimension of the cross-attention conditioning "
"(cross_attention_dim) when using with_conditioning.",
],
[
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2},
"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.",
],
[
{"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16},
"ControlNet expects all num_channels being multiple of norm_num_groups",
],
[
{
"spatial_dims": 2,
"in_channels": 1,
"num_channels": (8, 16),
"attention_levels": (True,),
"norm_num_groups": 8,
},
"ControlNet expects num_channels being same size of attention_levels",
],
]


@SkipIfBeforePyTorchVersion((2, 0))
@skipUnless(has_generative, "monai-generative required")
class TestControlNet(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16))
timesteps = torch.randint(0, 1000, (1,)).long()
controlnet_cond = (
torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32))
)
result = net.forward(x, timesteps, controlnet_cond)
self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
self.assertEqual(result[1].shape, expected_shape)

@parameterized.expand(TEST_CASES_CONDITIONAL)
def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
net = ControlNetMaisi(**input_param)
with eval_mode(net):
x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16))
timesteps = torch.randint(0, 1000, (1,)).long()
controlnet_cond = (
torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32))
)
context = torch.randn((1, 1, input_param["cross_attention_dim"]))
result = net.forward(x, timesteps, controlnet_cond, context=context)
self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
self.assertEqual(result[1].shape, expected_shape)

@parameterized.expand(TEST_CASES_ERROR)
def test_error_input(self, input_param, expected_error):
with self.assertRaises(ValueError) as context: # output shape too small
_ = ControlNetMaisi(**input_param)
runtime_error = context.exception
self.assertEqual(str(runtime_error), expected_error)


if __name__ == "__main__":
unittest.main()
Loading