Re-network the DIT, fix some parameters, and simplify the model networking code #632
Conversation
Thanks for your contribution!
if qkv is not None:
    state_dict[qkv_key_b] = paddle.concat([qkv, state_dict.pop(key)], axis=-1)

for key in list(state_dict.keys()):
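The q/k/v fusion pattern in this diff can be sketched with plain Python stand-ins: lists take the place of paddle tensors and list concatenation stands in for `paddle.concat(..., axis=-1)`. The `fuse_qkv_keys` helper and the key layout are illustrative, loosely mirroring the `to_q`/`to_k`/`to_v` names used elsewhere in this PR:

```python
def fuse_qkv_keys(state_dict, num_layers=2):
    """Sketch: merge separate q/k/v weight entries into one fused qkv entry."""
    for i in range(num_layers):
        prefix = f"transformer_blocks.{i}.attn1"
        # Gather and remove the separate q/k/v weights...
        parts = [state_dict.pop(f"{prefix}.to_{n}.weight") for n in ("q", "k", "v")]
        # ...and store them under a single fused key. In the real code this
        # is paddle.concat(parts, axis=-1); lists stand in for tensors here.
        state_dict[f"{prefix}.to_qkv.weight"] = parts[0] + parts[1] + parts[2]
    return state_dict
```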
Change the code below line 518 to:
map_from_my_dit = {}
for i in range(28):
map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_qkv.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.qkv.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_qkv.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_out.0.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.out_proj.{i}.bias'] = f'transformer_blocks.{i}.attn1.to_out.0.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.weight'] = f'transformer_blocks.{i}.ff.net.0.proj.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn1.{i}.bias'] = f'transformer_blocks.{i}.ff.net.0.proj.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.weight'] = f'transformer_blocks.{i}.ff.net.2.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.ffn2.{i}.bias'] = f'transformer_blocks.{i}.ff.net.2.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs0.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_1.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs1.{i}.bias'] = f'transformer_blocks.{i}.norm1.emb.timestep_embedder.linear_2.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.weight'] = f'transformer_blocks.{i}.norm1.linear.weight'
map_from_my_dit[f'tmp_ZKKFacebookDIT.fcs2.{i}.bias'] = f'transformer_blocks.{i}.norm1.linear.bias'
map_from_my_dit[f'tmp_ZKKFacebookDIT.embs.{i}.weight'] = f'transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight'
for key in map_from_my_dit.keys():
state_dict[key] = paddle.assign(state_dict[map_from_my_dit[key]])
Changed! Thanks for the review feedback!
def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int):
    super().__init__()
    self.num_layers = num_layers
    self.dtype = "float16"
Make `self.dtype = "float16"` configurable.
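A minimal sketch of making the dtype configurable, assuming a `dtype` keyword argument with the current hard-coded value as its default (the real class subclasses `paddle.nn.Layer`; this stand-in omits that to stay self-contained):

```python
class SimplifiedFacebookDIT:
    # Sketch only: dtype becomes a constructor argument instead of a
    # hard-coded "float16", so callers can opt into other precisions.
    def __init__(self, num_layers, dim, num_attention_heads,
                 attention_head_dim, dtype="float16"):
        self.num_layers = num_layers
        self.dtype = dtype
```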
Changed! Thanks for the review feedback!
@@ -1130,6 +1134,8 @@ def _find_mismatched_keys(
                error_msgs.append(
                    f"Error size mismatch, {key_name} receives a shape {loaded_shape}, but the expected shape is {model_shape}."
                )
        if os.getenv('Inference_Optimize'):
Remove this here; move the check into transformer_2d.py instead.
Changed! Thanks for the review feedback!
@@ -28,11 +28,15 @@
    recompute_use_reentrant,
    use_old_recompute,
)
from .simplified_facebook_dit import Simplified_FacebookDIT
Rename `Simplified_FacebookDIT` to `SimplifiedFacebookDIT`.
Changed! Thanks for the review feedback!
@@ -213,6 +219,8 @@ def __init__(
                for d in range(num_layers)
            ]
        )
        if self.Inference_Optimize:
            self.simplified_facebookDIT = SimplifiedFacebookDIT(num_layers, inner_dim, num_attention_heads, attention_head_dim)
Consider doing `del self.transformer_blocks` here.
Changing this raises errors, because the blocks are still needed in other places; leaving it unchanged for now. Thanks for the review feedback!
@@ -114,6 +118,8 @@ def __init__(
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
        self.data_format = data_format

        self.Inference_Optimize = bool(os.getenv('Inference_Optimize'))
self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
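The suggested comparison matters because `bool()` on any non-empty string is `True`, so `bool(os.getenv(...))` would read `"False"` as enabled. A minimal demonstration (variable name uppercased, as suggested elsewhere in this review):

```python
import os

# Setting the flag to "False" should mean "disabled", but bool() on a
# non-empty string is always True, so the original check misreads it:
os.environ["INFERENCE_OPTIMIZE"] = "False"
enabled_buggy = bool(os.getenv("INFERENCE_OPTIMIZE"))      # True  (wrong)
enabled_fixed = os.getenv("INFERENCE_OPTIMIZE") == "True"  # False (intended)
```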
Changed! Thanks for the review feedback!
The code should be run through a formatter.
return
map_from_my_dit = {}
for i in range(28):
    map_from_my_dit[f'simplified_facebookDIT.q.{i}.weight'] = f'transformer_blocks.{i}.attn1.to_q.weight'
Minimize code duplication; for example, common name prefixes should be factored out so future changes only touch one place.
> Minimize code duplication; for example, common name prefixes should be factored out so future changes only touch one place.

Changed; collapsed part of the key-naming code! Thanks for the review feedback!
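As a sketch of the suggested refactor, the per-layer key map can be built from a small table of shared name fragments so each sub-key is written only once. The `simplified_facebookDIT.*` and `transformer_blocks.*` names follow the snippets in this thread; the exact entry set is illustrative:

```python
def build_dit_key_map(num_layers=28):
    """Build source->target key mapping with common prefixes factored out."""
    # Each entry maps a simplified-module field to its per-layer target path.
    per_layer = {
        "qkv": "attn1.to_qkv",
        "out_proj": "attn1.to_out.0",
        "ffn1": "ff.net.0.proj",
        "ffn2": "ff.net.2",
        "fcs0": "norm1.emb.timestep_embedder.linear_1",
        "fcs1": "norm1.emb.timestep_embedder.linear_2",
        "fcs2": "norm1.linear",
    }
    mapping = {}
    for i in range(num_layers):
        for src, dst in per_layer.items():
            for suffix in ("weight", "bias"):
                mapping[f"simplified_facebookDIT.{src}.{i}.{suffix}"] = (
                    f"transformer_blocks.{i}.{dst}.{suffix}"
                )
        # The class embedder only has a weight, no bias.
        mapping[f"simplified_facebookDIT.embs.{i}.weight"] = (
            f"transformer_blocks.{i}.norm1.emb.class_embedder.embedding_table.weight"
        )
    return mapping
```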
from ppdiffusers import DDIMScheduler, DiTPipeline

dtype = paddle.float32
os.environ["Inference_Optimize"] = "False"
Make environment variable names all uppercase.
Changed! Thanks for the review feedback!
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
set_seed(42)

words = ["golden retriever"]  # class_ids [207]
class_ids = pipe.get_label_ids(words)

# warmup
for i in range(5):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
This is only for benchmarking; actual users don't need warmup. Consider adding a benchmark switch.
Changed; added switches for benchmark & inference_optimize! Thanks for the review feedback!
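One way the requested switch could look is sketched below. The `generate` helper, flag name, and defaults are hypothetical (not the PR's actual code), and the real pipeline returns an object with `.images` rather than a bare value:

```python
import time

def generate(pipe, class_ids, benchmark=False, warmup=5, repeat=1):
    """Run the pipeline; warm up and average timings only when benchmarking."""
    if benchmark:
        for _ in range(warmup):
            pipe(class_labels=class_ids, num_inference_steps=25)
    start = time.perf_counter()
    for _ in range(repeat):
        image = pipe(class_labels=class_ids, num_inference_steps=25)
    elapsed_ms = (time.perf_counter() - start) * 1000 / repeat
    return image, elapsed_ms
```

A normal user call (`benchmark=False`) runs the pipeline exactly once, with no warmup loop.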
import datetime
import time
Move these imports to the top of the file.
Changed! Thanks for the review feedback!
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
for i in range(repeat_times):
    image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
Same as above: only needed for benchmarking, not for normal use.
Changed! Thanks for the review feedback!
enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
Follow the code style guide: lines should not exceed 80 characters.
Adjusted with pre-commit! Thanks for the review feedback!
@@ -114,6 +118,8 @@ def __init__(
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
        self.data_format = data_format

        self.Inference_Optimize = os.getenv('Inference_Optimize') == "True"
Use `self.inference_optimize`, following the naming conventions.
Changed! Thanks for the review feedback!
import paddle.nn.functional as F
import math

class SimplifiedFacebookDIT(nn.Layer):
Is it really necessary to simplify this module?
> Is it really necessary to simplify this module?

It is needed for the manual optimization.
The manual optimization requires a high-performance, streamlined restructuring of the original dynamic-graph model network; this module also hoists redundant computation out of the transformer-block loop, reducing the overall compute. Thanks for the review feedback!
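The hoisting idea can be illustrated abstractly. `layers` and `emb_fn` below are hypothetical stand-ins, not the PR's actual code; the point is only that a per-timestep embedding identical across layers should be computed once, not once per layer:

```python
def forward_naive(layers, emb_fn, x, t):
    # Redundant: the timestep embedding is recomputed every iteration.
    for layer in layers:
        x = layer(x, emb_fn(t))
    return x

def forward_hoisted(layers, emb_fn, x, t):
    # The embedding does not depend on the layer index: compute it once
    # outside the loop, then reuse it for every transformer block.
    emb = emb_fn(t)
    for layer in layers:
        x = layer(x, emb)
    return x
```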
@@ -221,7 +240,9 @@ def __init__(
        if use_linear_projection:
            self.proj_out = linear_cls(inner_dim, in_channels)
        else:
            self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format)
            self.proj_out = conv_cls(
Please ignore, this is only a formatting change.
> Please ignore, this is only a formatting change.

Formatting was unified with pre-commit! Thanks for the review feedback!
@@ -154,11 +158,15 @@ def __init__(
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format)
            self.norm = nn.GroupNorm(
Please ignore, this is only a formatting change.
if use_linear_projection:
    self.proj_in = linear_cls(in_channels, inner_dim)
else:
    self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format)
    self.proj_in = conv_cls(
Please ignore, this is only a formatting change.
Latest optimization: re-networked the DIT, simplifying the original dynamic-graph model into a high-performance network; uses `paddle.incubate.jit.inference` for dynamic-to-static conversion and removes redundant parts of the loop. Currently facebook-DIT takes 219.936 ms.