[ppdiffusers] add uvit recompute (#347)
nemonameless authored Dec 12, 2023
1 parent 9a74e91 commit f9daafc
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions ppdiffusers/ppdiffusers/models/uvit.py
@@ -15,6 +15,7 @@
from typing import Optional

import einops
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -53,7 +54,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self._use_memory_efficient_attention_xformers = False
self._use_memory_efficient_attention_xformers = is_ppxformers_available()
self._attention_op = None

def reshape_heads_to_batch_dim(self, tensor, transpose=True):
@@ -71,9 +72,8 @@ def reshape_batch_dim_to_heads(self, tensor, transpose=True):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[str] = None
):
# remove this PR: https://github.com/PaddlePaddle/Paddle/pull/56045
# if self.head_size > 128 and attention_op == "flash":
# attention_op = "cutlass"
if self.head_size > 128 and attention_op == "flash":
attention_op = "cutlass"
if use_memory_efficient_attention_xformers:
if not is_ppxformers_available():
raise NotImplementedError(
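
The two hunks above turn memory-efficient attention on by default whenever ppxformers is available and re-activate the kernel fallback for wide attention heads. The sketch below isolates that guard as a standalone function; `pick_attention_op` is a hypothetical name introduced only for illustration, and the 128 head-size limit is taken directly from the condition in the diff.

```python
from typing import Optional

# Hypothetical helper mirroring the guard made active in the hunk above:
# the "flash" kernel is assumed to handle head sizes up to 128, and wider
# heads fall back to the "cutlass" kernel.
def pick_attention_op(head_size: int, attention_op: Optional[str]) -> Optional[str]:
    if head_size > 128 and attention_op == "flash":
        return "cutlass"
    return attention_op

print(pick_attention_op(64, "flash"))   # flash
print(pick_attention_op(160, "flash"))  # cutlass
print(pick_attention_op(160, None))     # None -> kernel choice left to the backend
```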
@@ -194,6 +194,7 @@ class UViTModel(ModelMixin, ConfigMixin):
after concatenating a long skip connection, which stabilizes the training of U-ViT in UniDiffuser.
"""
_supports_gradient_checkpointing = True

@register_to_config
def __init__(
@@ -253,6 +254,7 @@ def __init__(
norm_layer = nn.LayerNorm
self.pos_drop = nn.Dropout(p=pos_drop_rate)

dpr = np.linspace(0, drop_rate, depth + 1)
self.in_blocks = nn.LayerList(
[
Block(
@@ -261,11 +263,11 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
drop=dpr[i],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
for _ in range(depth // 2)
for i in range(depth // 2)
]
)

@@ -275,7 +277,7 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
drop=dpr[depth // 2],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
@@ -288,12 +290,12 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
drop=dpr[i + 1 + depth // 2],
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
skip=True,
)
for _ in range(depth // 2)
for i in range(depth // 2)
]
)

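The three hunks above replace the uniform `drop=drop_rate` with a per-block schedule: `np.linspace(0, drop_rate, depth + 1)` yields one dropout rate per transformer block (in blocks, mid block, out blocks), growing linearly from 0 at the shallowest block to `drop_rate` at the deepest. A small sketch of the indexing, with illustrative values for `depth` and `drop_rate` (not taken from any real config):

```python
import numpy as np

# Illustrative values only; the real depth and drop_rate come from the model config.
depth, drop_rate = 4, 0.1
dpr = np.linspace(0, drop_rate, depth + 1)
print(dpr)  # [0.    0.025 0.05  0.075 0.1  ]

# How the blocks consume the schedule, mirroring the indexing in the diff:
in_rates  = dpr[: depth // 2]        # in_blocks:  dpr[i] for i in range(depth // 2)
mid_rate  = dpr[depth // 2]          # mid_block:  dpr[depth // 2]
out_rates = dpr[1 + depth // 2 :]    # out_blocks: dpr[i + 1 + depth // 2]
```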
@@ -305,6 +307,10 @@
self.pos_embed_token = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=nn.initializer.Constant(0.0)
)
self.gradient_checkpointing = False

def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value

def forward(
self,
@@ -362,13 +368,22 @@ def forward(

skips = []
for blk in self.in_blocks:
x = blk(x)
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(blk, x)
else:
x = blk(x)
skips.append(x)

x = self.mid_block(x)
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(self.mid_block, x)
else:
x = self.mid_block(x)

for blk in self.out_blocks:
x = blk(x, skips.pop())
if self.gradient_checkpointing:
x = paddle.distributed.fleet.utils.recompute(blk, x, skips.pop())
else:
x = blk(x, skips.pop())

x = self.norm(x)

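The forward changes above route every in/mid/out block through `paddle.distributed.fleet.utils.recompute` when `gradient_checkpointing` is set, so block activations are rebuilt during backward instead of being stored. Below is a minimal, self-contained sketch of that mechanism on a toy block; it assumes the flag is flipped either directly or through a `ModelMixin.enable_gradient_checkpointing()`-style helper that calls `_set_gradient_checkpointing(value=True)` (that helper is assumed, not shown in this diff).

```python
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute

# Toy stand-in for a UViT Block: recompute(blk, x) matches the call pattern
# added to UViTModel.forward above.
blk = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

x = paddle.randn([4, 16])
x.stop_gradient = False

y = recompute(blk, x)   # activations inside blk are not kept...
loss = y.mean()
loss.backward()         # ...blk is re-run here to rebuild them for the gradient
print(x.grad.shape)     # gradients still flow to the input: [4, 16]
```

In `UViTModel` itself, setting `gradient_checkpointing = True` (which is what `_set_gradient_checkpointing` does when called with `value=True`) switches all three loops in `forward` to this recompute path, trading extra compute for lower peak activation memory.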