diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
index ac148a87ec0b..0920ae0870e8 100644
--- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
+++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml
@@ -121,7 +121,6 @@ model:
     use_flash_attention: True
     enable_amp_o2_fp16: False
     resblock_gn_groups: 32
-    use_te_fp8: False
 
   first_stage_config:
     _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
@@ -192,7 +191,7 @@ model:
     synthetic_data_length: 10000
     train:
      dataset_path:
-        - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
+        - /datasets/coyo/test.pkl
      augmentations:
        resize_smallest_side: 512
        center_crop_h_w: 512, 512
diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py
index 0f7744abcb3f..e70a473d658b 100644
--- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py
+++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import os
 from inspect import isfunction
 
 import torch
@@ -22,13 +21,6 @@
 from torch import einsum, nn
 from torch._dynamo import disable
 
-if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
-    from nemo.gn_native import GroupNormNormlization as GroupNorm
-else:
-    from apex.contrib.group_norm import GroupNorm
-
-from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP
-
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
     AdapterName,
@@ -103,19 +95,13 @@ def forward(self, x):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
 
-        if use_te:
-            activation = 'gelu' if not glu else 'geglu'
-            # TODO: more parameters to be confirmed, dropout, seq_length
-            self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
-        else:
-            norm = nn.LayerNorm(dim)
-            project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
-            self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
 
     def forward(self, x):
         return self.net(x)
@@ -238,7 +224,6 @@ def __init__(
         dropout=0.0,
         use_flash_attention=False,
         lora_network_alpha=None,
-        use_te=False,
     ):
         super().__init__()
 
@@ -252,16 +237,10 @@ def __init__(
 
         self.scale = dim_head ** -0.5
         self.heads = heads
 
+        self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
         self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
         self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
-        if use_te:
-            self.norm_to_q = LayerNormLinear(query_dim, self.inner_dim, bias=False)
-        else:
-            norm = nn.LayerNorm(query_dim)
-            to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
-            self.norm_to_q = nn.Sequential(norm, to_q)
-
         self.to_out = nn.Sequential(
             LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
         )
@@ -276,7 +255,7 @@ def __init__(
     def forward(self, x, context=None, mask=None):
         h = self.heads
 
-        q = self.norm_to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
@@ -356,7 +335,6 @@ def __init__(
         use_flash_attention=False,
         disable_self_attn=False,
         lora_network_alpha=None,
-        use_te=False,
     ):
         super().__init__()
         self.disable_self_attn = disable_self_attn
@@ -368,9 +346,8 @@ def __init__(
             use_flash_attention=use_flash_attention,
             context_dim=context_dim if self.disable_self_attn else None,
             lora_network_alpha=lora_network_alpha,
-            use_te=use_te,
         )  # is a self-attention
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
         self.attn2 = CrossAttention(
             query_dim=dim,
             context_dim=context_dim,
@@ -379,8 +356,10 @@ def __init__(
             dropout=dropout,
             use_flash_attention=use_flash_attention,
             lora_network_alpha=lora_network_alpha,
-            use_te=use_te,
         )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
         self.use_checkpoint = use_checkpoint
 
     def forward(self, x, context=None):
@@ -390,9 +369,9 @@ def forward(self, x, context=None):
             return self._forward(x, context)
 
     def _forward(self, x, context=None):
-        x = self.attn1(x, context=context if self.disable_self_attn else None) + x
-        x = self.attn2(x, context=context) + x
-        x = self.ff(x) + x
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
         return x
 
 
@@ -418,7 +397,6 @@ def __init__(
         use_checkpoint=False,
         use_flash_attention=False,
         lora_network_alpha=None,
-        use_te=False,
     ):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
@@ -444,7 +422,6 @@ def __init__(
                     use_flash_attention=use_flash_attention,
                     disable_self_attn=disable_self_attn,
                     lora_network_alpha=lora_network_alpha,
-                    use_te=use_te,
                 )
                 for d in range(depth)
             ]
diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
index 831681761446..62842da602dc 100644
--- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
+++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-import os
 from abc import abstractmethod
-from contextlib import nullcontext
 
 import numpy as np
 import torch
@@ -22,9 +20,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-# FP8 related import
-import transformer_engine
-
 from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import (
     avg_pool_nd,
@@ -50,39 +45,6 @@ def convert_module_to_fp16(module):
     convert_module_to_dtype(module, torch.float16)
 
 
-def convert_module_to_fp32(module):
-    convert_module_to_dtype(module, torch.float32)
-
-
-def convert_module_to_fp8(model):
-    def _set_module(model, submodule_key, module):
-        tokens = submodule_key.split('.')
-        sub_tokens = tokens[:-1]
-        cur_mod = model
-        for s in sub_tokens:
-            cur_mod = getattr(cur_mod, s)
-        setattr(cur_mod, tokens[-1], module)
-
-    import copy
-
-    from transformer_engine.pytorch.module import Linear as te_Linear
-
-    for n, v in model.named_modules():
-        if isinstance(v, torch.nn.Linear):
-            # if n in ['class_embed', 'bbox_embed.layers.0', 'bbox_embed.layers.1', 'bbox_embed.layers.2']: continue
-            logging.info(f'[INFO] Replace Linear: {n}, weight: {v.weight.shape}')
-            if v.bias is None:
-                is_bias = False
-            else:
-                is_bias = True
-            newlinear = te_Linear(v.in_features, v.out_features, bias=is_bias)
-            newlinear.weight = copy.deepcopy(v.weight)
-            if v.bias is not None:
-                newlinear.bias = copy.deepcopy(v.bias)
-            _set_module(model, n, newlinear)
-
-
-## go
 class AttentionPool2d(nn.Module):
     """
     Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -509,7 +471,6 @@ def __init__(
         use_flash_attention: bool = False,
         enable_amp_o2_fp16: bool = False,
         lora_network_alpha=None,
-        use_te_fp8: bool = False,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -565,7 +526,6 @@ def __init__(
         input_block_chans = [model_channels]
         ch = model_channels
         ds = 1
-        self.use_te_fp8 = use_te_fp8
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
@@ -608,7 +568,6 @@ def __init__(
                             use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint,
                             use_flash_attention=use_flash_attention,
-                            use_te=self.use_te_fp8,
                             lora_network_alpha=lora_network_alpha,
                         )
                     )
@@ -674,7 +633,6 @@ def __init__(
                 use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint,
                 use_flash_attention=use_flash_attention,
-                use_te=self.use_te_fp8,
                 lora_network_alpha=lora_network_alpha,
             ),
             ResBlock(
@@ -702,7 +660,6 @@ def __init__(
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        resblock_gn_groups=resblock_gn_groups,
                     )
                 ]
                 ch = model_channels * mult
@@ -733,7 +690,6 @@ def __init__(
                         use_linear=use_linear_in_transformer,
                         use_checkpoint=use_checkpoint,
                         use_flash_attention=use_flash_attention,
-                        use_te=self.use_te_fp8,
                         lora_network_alpha=lora_network_alpha,
                     )
                 )
@@ -790,32 +746,6 @@ def __init__(
 
         if enable_amp_o2_fp16:
             self.convert_to_fp16()
-        elif self.use_te_fp8:
-            assert enable_amp_o2_fp16 is False, "fp8 training can't work with fp16 O2 amp recipe"
-            convert_module_to_fp8(self)
-
-            fp8_margin = int(os.getenv("FP8_MARGIN", '0'))
-            fp8_interval = int(os.getenv("FP8_INTERVAL", '1'))
-            fp8_format = os.getenv("FP8_FORMAT", "hybrid")
-            fp8_amax_history_len = int(os.getenv("FP8_HISTORY_LEN", '1024'))
-            fp8_amax_compute_algo = os.getenv("FP8_COMPUTE_ALGO", 'max')
-            fp8_wgrad = os.getenv("FP8_WGRAD", '1') == '1'
-
-            fp8_format_dict = {
-                'hybrid': transformer_engine.common.recipe.Format.HYBRID,
-                'e4m3': transformer_engine.common.recipe.Format.E4M3,
-            }
-            fp8_format = fp8_format_dict[fp8_format]
-
-            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
-                margin=fp8_margin,
-                interval=fp8_interval,
-                fp8_format=fp8_format,
-                amax_history_len=fp8_amax_history_len,
-                amax_compute_algo=fp8_amax_compute_algo,
-                override_linear_precision=(False, False, not fp8_wgrad),
-            )
-
     def _input_blocks_mapping(self, input_dict):
         res_dict = {}
         for key_, value_ in input_dict.items():
@@ -1030,7 +960,7 @@ def convert_to_fp16(self):
         """
         self.apply(convert_module_to_fp16)
 
-    def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         """
         Apply the model to an input batch.
 
@@ -1069,13 +999,6 @@ def _forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         else:
             return self.out(h)
 
-    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
-        with transformer_engine.pytorch.fp8_autocast(
-            enabled=self.use_te_fp8, fp8_recipe=self.fp8_recipe,
-        ) if self.use_te_fp8 else nullcontext():
-            out = self._forward(x, timesteps, context, y, **kwargs)
-        return out
-
 
 class EncoderUNetModel(nn.Module):
     """
diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 782c90577882..3491befc4caa 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -1003,30 +1003,6 @@ def should_process(key):
                 new_state_dict[key_] = state_dict[key_]
             state_dict = new_state_dict
 
-        if conf.get('unet_config') and conf.get('unet_config').get('use_te_fp8') == False:
-            # remove _extra_state in fp8 if there is.
-            new_state_dict = {}
-            for key in state_dict.keys():
-                if 'extra_state' in key:
-                    continue
-
-                ### LayerNormLinear
-                # norm_to_q.layer_norm_{weight|bias} -> norm_to_q.0.{weight|bias}
-                # norm_to_q.weight -> norm_to_q.1.weight
-                new_key = key.replace('norm_to_q.layer_norm_', 'norm_to_q.0.')
-                new_key = new_key.replace('norm_to_q.weight', 'norm_to_q.1.weight')
-
-                ### LayerNormMLP
-                # ff.net.layer_norm_{weight|bias} -> ff.net.0.{weight|bias}
-                # ff.net.fc1_{weight|bias} -> ff.net.1.proj.{weight|bias}
-                # ff.net.fc2_{weight|bias} -> ff.net.3.{weight|bias}
-                new_key = new_key.replace('ff.net.layer_norm_', 'ff.net.0.')
-                new_key = new_key.replace('ff.net.fc1_', 'ff.net.1.proj.')
-                new_key = new_key.replace('ff.net.fc2_', 'ff.net.3.')
-
-                new_state_dict[new_key] = state_dict[key]
-            state_dict = new_state_dict
-
         return state_dict
 
     def _load_state_dict_from_disk(self, model_weights, map_location=None):