diff --git a/README.md b/README.md
index f3c9be62..4022e605 100644
--- a/README.md
+++ b/README.md
@@ -190,7 +190,6 @@ model:
         attention_dropout: 0.0

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: False


 dataset:
@@ -259,7 +258,6 @@ __model__:
 - `model.vq_model.pretrained`: The pretrained vq model to use. Can be a path to a saved checkpoint or a huggingface model name.
 - `model.transformer`: The transformer model configuration.
 - `model.gradient_checkpointing`: Enable gradient checkpointing for the transformer model.
-- `enable_xformers_memory_efficient_attention`: Enable memory efficient attention or flash attention for the transformer model. For flash attention we need to use `fp16` or `bf16`. [xformers](https://github.com/facebookresearch/xformers) needs to be installed for this to work.

 __dataset__:
 - `dataset.params.train_shards_path_or_url`: The path or url to the `webdataset` training shards.
diff --git a/configs/cc12m.yaml b/configs/cc12m.yaml
index 21ed3233..e005211e 100644
--- a/configs/cc12m.yaml
+++ b/configs/cc12m.yaml
@@ -49,7 +49,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/cc12m_movq.yaml b/configs/cc12m_movq.yaml
index 477d1c42..597fe3d2 100644
--- a/configs/cc12m_movq.yaml
+++ b/configs/cc12m_movq.yaml
@@ -51,8 +51,6 @@ model:
         patch_size: 2

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True
-

 dataset:
     type: "text2image"
diff --git a/configs/cc12m_uvit.yaml b/configs/cc12m_uvit.yaml
index 74c28ccc..875ede0a 100644
--- a/configs/cc12m_uvit.yaml
+++ b/configs/cc12m_uvit.yaml
@@ -54,7 +54,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/cc12m_uvit_clip.yaml b/configs/cc12m_uvit_clip.yaml
index a80dab21..276b5d01 100644
--- a/configs/cc12m_uvit_clip.yaml
+++ b/configs/cc12m_uvit_clip.yaml
@@ -54,7 +54,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/cc12m_uvit_larger_paellavq_f8_clip.yaml b/configs/cc12m_uvit_larger_paellavq_f8_clip.yaml
index ca79474e..14b25453 100644
--- a/configs/cc12m_uvit_larger_paellavq_f8_clip.yaml
+++ b/configs/cc12m_uvit_larger_paellavq_f8_clip.yaml
@@ -54,7 +54,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/cc12m_uvit_paellavq.yaml b/configs/cc12m_uvit_paellavq.yaml
index c5305cd5..000b2390 100644
--- a/configs/cc12m_uvit_paellavq.yaml
+++ b/configs/cc12m_uvit_paellavq.yaml
@@ -55,7 +55,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/cc12m_uvit_paellavq_larger.yaml b/configs/cc12m_uvit_paellavq_larger.yaml
index 829af687..80e0a6b6 100644
--- a/configs/cc12m_uvit_paellavq_larger.yaml
+++ b/configs/cc12m_uvit_paellavq_larger.yaml
@@ -54,7 +54,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/imagenet.yaml b/configs/imagenet.yaml
index c31b8147..f24c5ac6 100644
--- a/configs/imagenet.yaml
+++ b/configs/imagenet.yaml
@@ -42,7 +42,6 @@ model:
         attention_dropout: 0.0

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/imagenet_movq.yaml b/configs/imagenet_movq.yaml
index 16b812d1..05d75bf9 100644
--- a/configs/imagenet_movq.yaml
+++ b/configs/imagenet_movq.yaml
@@ -42,7 +42,6 @@ model:
         attention_dropout: 0.0

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/imagenet_text2image.yaml b/configs/imagenet_text2image.yaml
index 36af7e2a..d4f39672 100644
--- a/configs/imagenet_text2image.yaml
+++ b/configs/imagenet_text2image.yaml
@@ -47,7 +47,6 @@ model:
         attention_dropout: 0.0

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/imagenet_text2image_movq_conv.yaml b/configs/imagenet_text2image_movq_conv.yaml
index 9783ad7f..d370360d 100644
--- a/configs/imagenet_text2image_movq_conv.yaml
+++ b/configs/imagenet_text2image_movq_conv.yaml
@@ -51,7 +51,6 @@ model:
         patch_size: 2

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/laiona6plus_uvit_clip.yaml b/configs/laiona6plus_uvit_clip.yaml
index 73c678ec..fc102383 100644
--- a/configs/laiona6plus_uvit_clip.yaml
+++ b/configs/laiona6plus_uvit_clip.yaml
@@ -54,7 +54,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/laiona6plus_uvit_clip_f8.yaml b/configs/laiona6plus_uvit_clip_f8.yaml
index 92e1e5a3..013b43e3 100644
--- a/configs/laiona6plus_uvit_clip_f8.yaml
+++ b/configs/laiona6plus_uvit_clip_f8.yaml
@@ -55,7 +55,6 @@ model:
         use_codebook_size_for_output: True

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: True


 dataset:
diff --git a/configs/template_config.yaml b/configs/template_config.yaml
index 4e4d688c..68182de9 100644
--- a/configs/template_config.yaml
+++ b/configs/template_config.yaml
@@ -39,7 +39,6 @@ model:
         attention_dropout: 0.0

     gradient_checkpointing: True
-    enable_xformers_memory_efficient_attention: False


 dataset:
diff --git a/muse/modeling_movq.py b/muse/modeling_movq.py
index e1379a0c..1037a800 100644
--- a/muse/modeling_movq.py
+++ b/muse/modeling_movq.py
@@ -10,13 +10,6 @@
 from .modeling_utils import ConfigMixin, ModelMixin, register_to_config

-try:
-    import xformers.ops as xops
-
-    is_xformers_available = True
-except ImportError:
-    is_xformers_available = False
-


 class SpatialNorm(nn.Module):
     def __init__(
@@ -170,17 +163,6 @@ def __init__(self, in_channels, zq_ch=None, add_conv=False):
         self.v = nn.Linear(in_channels, in_channels)
         self.proj_out = nn.Linear(in_channels, in_channels)

-        self.use_memory_efficient_attention_xformers = False
-        self.xformers_attention_op = None
-
-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ):
-        if use_memory_efficient_attention_xformers and not is_xformers_available:
-            raise ImportError("Please install xformers to use memory efficient attention")
-        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self.xformers_attention_op = attention_op
-
     def forward(self, hidden_states, zq=None):
         residual = hidden_states
         batch, channel, height, width = hidden_states.shape
@@ -190,33 +172,12 @@ def forward(self, hidden_states, zq=None):
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-        scale = 1.0 / torch.sqrt(torch.tensor(channel, dtype=hidden_states.dtype, device=hidden_states.device))

         query = self.q(hidden_states)
         key = self.k(hidden_states)
         value = self.v(hidden_states)

-        if self.use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xops.memory_efficient_attention(
-                query, key, value, attn_bias=None, op=self.xformers_attention_op
-            )
-        else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query.shape[0],
-                    query.shape[1],
-                    key.shape[1],
-                    dtype=query.dtype,
-                    device=query.device,
-                ),
-                query,
-                key.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = F.scaled_dot_product_attention(query, key, value)

         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)
diff --git a/muse/modeling_transformer.py b/muse/modeling_transformer.py
index 954ca4c7..5d653e36 100644
--- a/muse/modeling_transformer.py
+++ b/muse/modeling_transformer.py
@@ -29,14 +29,6 @@
 from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
 from .sampling import cosine_schedule, gumbel_sample, mask_by_random_topk, top_k

-try:
-    import xformers.ops as xops
-
-    is_xformers_available = True
-except ImportError:
-    is_xformers_available = False
-
-


 # classifier free guidance functions
@@ -116,7 +108,7 @@ def forward(self, x):


 class GlobalResponseNorm(nn.Module):
-    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

     def __init__(self, dim):
         super().__init__()
@@ -309,21 +301,7 @@ def __init__(self, hidden_size, num_heads, encoder_hidden_size=None, attention_d
         self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
         self.dropout = nn.Dropout(attention_dropout)

-        self.use_memory_efficient_attention_xformers = False
-        self.xformers_attention_op = None
-
-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ):
-        if use_memory_efficient_attention_xformers and not is_xformers_available:
-            raise ImportError("Please install xformers to use memory efficient attention")
-        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self.xformers_attention_op = attention_op
-
     def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
-        if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
-            raise ValueError("Memory efficient attention does not yet support encoder attention mask")
-
         context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
         batch, q_seq_len, _ = hidden_states.shape
         kv_seq_len = q_seq_len if encoder_hidden_states is None else encoder_hidden_states.shape[1]
@@ -332,43 +310,21 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m
         key = self.key(context)
         value = self.value(context)

-        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
-        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
-        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
+        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, nh, T, hs)
+        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, nh, T, hs)
+        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, nh, T, hs)

-        if self.use_memory_efficient_attention_xformers:
-            attn_output = xops.memory_efficient_attention(query, key, value, op=self.xformers_attention_op)
-            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
-        else:
-            attention_mask = None
-            if encoder_attention_mask is not None:
-                src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
-                attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype)
-            attn_output = self.attention(query, key, value, attention_mask)
+        attention_mask = None
+        if encoder_attention_mask is not None:
+            src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
+            attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask, dtype=query.dtype)

-        attn_output = self.out(attn_output)
-        return attn_output
+        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)  # (B, nh, T, hs)
+        attn_output = attn_output.permute(0, 2, 1, 3).reshape(
+            batch, q_seq_len, self.num_heads * self.head_dim
+        )  # (B, T, q_dim)

-    def attention(self, query, key, value, attention_mask=None):
-        batch, seq_len = query.shape[:2]
-        kv_seq_len = key.shape[1]
-        query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value))  # (B, nh, T, hs)
-
-        attn_weights = torch.baddbmm(
-            input=torch.zeros(batch * self.num_heads, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
-            batch1=query.view(batch * self.num_heads, seq_len, self.head_dim),
-            batch2=key.view(batch * self.num_heads, kv_seq_len, self.head_dim).transpose(1, 2),
-            alpha=1 / self.scale_attn,
-        )
-        attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len)  # -1 is kv_seq_len
-        # Apply the attention mask
-        if attention_mask is not None:
-            attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min)
-        attn_weights = F.softmax(attn_weights, dim=-1)
-        attn_weights = self.dropout(attn_weights)
-        attn_output = torch.matmul(attn_weights, value)  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-        # re-assemble all head outputs side by side
-        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
+        attn_output = self.out(attn_output)
         return attn_output
diff --git a/muse/modeling_utils.py b/muse/modeling_utils.py
index f31c337e..69c50481 100644
--- a/muse/modeling_utils.py
+++ b/muse/modeling_utils.py
@@ -273,61 +273,6 @@ def disable_gradient_checkpointing(self):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))

-    def set_use_memory_efficient_attention_xformers(
-        self, valid: bool, attention_op: Optional[Callable] = None
-    ) -> None:
-        # Recursively walk through all the children.
-        # Any children which exposes the set_use_memory_efficient_attention_xformers method
-        # gets the message
-        def fn_recursive_set_mem_eff(module: torch.nn.Module):
-            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
-
-            for child in module.children():
-                fn_recursive_set_mem_eff(child)
-
-        for module in self.children():
-            if isinstance(module, torch.nn.Module):
-                fn_recursive_set_mem_eff(module)
-
-    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
-        r"""
-        Enable memory efficient attention as implemented in xformers.
-
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
-
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
-
-        Parameters:
-            attention_op (`Callable`, *optional*):
-                Override the default `None` operator for use as `op` argument to the
-                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
-                function of xFormers.
-
-        Examples:
-
-        ```py
-        >>> import torch
-        >>> from diffusers import UNet2DConditionModel
-        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
-
-        >>> model = UNet2DConditionModel.from_pretrained(
-        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
-        ... )
-        >>> model = model.to("cuda")
-        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
-        ```
-        """
-        self.set_use_memory_efficient_attention_xformers(True, attention_op)
-
-    def disable_xformers_memory_efficient_attention(self):
-        r"""
-        Disable memory efficient attention as implemented in xformers.
-        """
-        self.set_use_memory_efficient_attention_xformers(False)
-
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
diff --git a/muse/pipeline_muse.py b/muse/pipeline_muse.py
index 1a94b521..02ac0b33 100644
--- a/muse/pipeline_muse.py
+++ b/muse/pipeline_muse.py
@@ -205,7 +205,7 @@ def from_pretrained(
         # TODO: Add config for pipeline to specify text encoder
         is_clip = "clip" in text_encoder_args["pretrained_model_name_or_path"]
         text_encoder_cls = CLIPTextModel if is_clip else T5EncoderModel
-
+
         text_encoder = text_encoder_cls.from_pretrained(**text_encoder_args)
         tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)
diff --git a/scripts/benchmark_models.py b/scripts/benchmark_models.py
index 72be2fb4..d1132d22 100644
--- a/scripts/benchmark_models.py
+++ b/scripts/benchmark_models.py
@@ -36,7 +36,6 @@ def create_model_and_benchmark(args):
     time_vanilla_fp16 = benchmark_torch_function(f)

     print("Running benchmark for efficient attention in FP16 ...")
-    model.enable_xformers_memory_efficient_attention()
     f = lambda: model.generate2(encoder_hidden_states=encoder_hidden_states, timesteps=args.time_steps)
     time_efficient_fp16 = benchmark_torch_function(f)
diff --git a/scripts/calculate_fid.py b/scripts/calculate_fid.py
index c678ab33..4cdeec05 100644
--- a/scripts/calculate_fid.py
+++ b/scripts/calculate_fid.py
@@ -41,7 +41,6 @@ def generate_and_save_images(args):
     print("Loading pipe")
     pipeline = PipelineMuse.from_pretrained(args.model_name_or_path).to(args.device)
-    pipeline.transformer.enable_xformers_memory_efficient_attention()

     print("Loading data")
     dataset = Flickr8kDataset(args.dataset_root, args.dataset_captions_file)
diff --git a/scripts/log_generations_wandb.py b/scripts/log_generations_wandb.py
index e2e2a545..df1c87d3 100644
--- a/scripts/log_generations_wandb.py
+++ b/scripts/log_generations_wandb.py
@@ -32,7 +32,6 @@ def generate_and_log(args):
         transformer_path=args.transformer,
         is_class_conditioned=args.is_class_conditioned,
     ).to(device=args.device)
-    pipe.transformer.enable_xformers_memory_efficient_attention()

     imagenet_class_ids = list(range(1000))
     with open(args.imagenet_class_mapping_path) as f:
diff --git a/training/train_maskgit_imagenet.py b/training/train_maskgit_imagenet.py
index 1ef83908..5bc58234 100644
--- a/training/train_maskgit_imagenet.py
+++ b/training/train_maskgit_imagenet.py
@@ -225,10 +225,6 @@ def main():
     # Freeze the VQGAN
     vq_model.requires_grad_(False)

-    # Enable flash attention if asked
-    if config.model.enable_xformers_memory_efficient_attention:
-        model.enable_xformers_memory_efficient_attention()
-
     optimizer_config = config.optimizer.params
     learning_rate = optimizer_config.learning_rate
     if optimizer_config.scale_lr:
diff --git a/training/train_muse.py b/training/train_muse.py
index 86fb7338..21be30ca 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -250,10 +250,6 @@ def main():
     text_encoder.requires_grad_(False)
     vq_model.requires_grad_(False)

-    # Enable flash attention if asked
-    if config.model.enable_xformers_memory_efficient_attention:
-        model.enable_xformers_memory_efficient_attention()
-
     optimizer_config = config.optimizer.params
     learning_rate = optimizer_config.learning_rate
     if optimizer_config.scale_lr: