Merge pull request PaddlePaddle#41 from wjm202/mp_bug_fix
Mp bug fix
lyuwenyu authored Aug 3, 2023
2 parents 373cfd2 + 4e36e42 commit 28dec7c
Showing 5 changed files with 60 additions and 20 deletions.
3 changes: 0 additions & 3 deletions paddlevlp/examples/blip2/run_pretrain_stage1.py
@@ -144,9 +144,6 @@ class PreTrainingArguments(TrainingArguments):
per_device_eval_batch_size : int = field(
default=128, metadata={"help": " Batch size per GPU core/CPU for evaluation. (default:8)"}
)
warmup_start_lr : float = field(
default=1e-6, metadata={"help": " The initial learning rate of blip2."}
)
output_dir : str = field(
default=".", metadata={"help": "The output path"}
)
25 changes: 20 additions & 5 deletions paddlevlp/models/blip2/Qformer.py
@@ -105,6 +105,7 @@ def __init__(self, config):
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
self.mp_degree=config.mp_degree

def forward(
self,
@@ -134,7 +135,10 @@ def forward(
else:
embeddings = query_embeds
embeddings = self.LayerNorm(embeddings)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
embeddings = self.dropout(embeddings)
else:
embeddings = self.dropout(embeddings)
return embeddings

@@ -182,7 +186,7 @@ def __init__(self, config, is_cross_attention):
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.mp_degree=config.mp_degree
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
@@ -305,7 +309,10 @@ def forward(

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
attention_probs_dropped = self.dropout(attention_probs)
else:
attention_probs_dropped = self.dropout(attention_probs)

# Mask heads if we want to
@@ -332,10 +339,14 @@ def __init__(self, config):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.mp_degree=config.mp_degree

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
hidden_states = self.dropout(hidden_states)
else:
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
@@ -395,10 +406,14 @@ def __init__(self, config):
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.mp_degree =config.mp_degree

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
hidden_states = self.dropout(hidden_states)
else:
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
2 changes: 1 addition & 1 deletion paddlevlp/models/blip2/blip2_stage1.py
@@ -81,7 +81,7 @@ def forward(self, pixel_values,text_input):

image = pixel_values
image_embeds = self.ln_vision(self.visual_encoder(image))
# breakpoint()

image_atts = paddle.ones(image_embeds.shape[:-1], dtype="int64")
query_tokens = self.query_tokens.expand(shape=[image_embeds.shape[0], -1, -1])

24 changes: 19 additions & 5 deletions paddlevlp/models/blip2/eva_vit.py
@@ -82,14 +82,18 @@ def __init__(self,
else:
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.mp_degree = mp_degree
self.act = act_layer()
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
x = self.drop(x)
else:
x = self.drop(x)
return x

@@ -122,6 +126,7 @@ def __init__(self,
gather_output=True)
else:
self.proj = nn.Linear(dim, dim)
self.mp_degree=mp_degree
self.proj_drop = nn.Dropout(proj_drop)

def _register_relative_position_index(
@@ -177,12 +182,18 @@ def forward(self, x, rel_pos_bias=None):
attn = attn + relative_position_bias.unsqueeze(0)

attn = nn.functional.softmax(attn, axis=-1)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
attn = self.attn_drop(attn)
else:
attn = self.attn_drop(attn)

x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
x = self.proj(x)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
x = self.proj_drop(x)
else:
x = self.proj_drop(x)
return x

@@ -394,7 +405,7 @@ def __init__(self,
])

#self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)

self.mp_degree=mp_degree
if self.pos_embed is not None:
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
@@ -418,7 +429,10 @@ def forward_features(self, x):

if self.pos_embed is not None:
x = x + self.pos_embed
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
x = self.pos_drop(x)
else:
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if hasattr(self,
'rel_pos_bias') else None
26 changes: 20 additions & 6 deletions paddlevlp/models/blip2/modeling_opt.py
@@ -117,6 +117,7 @@ def __init__(
self.dropout = config.attention_probs_dropout_prob
self.need_weights = need_weights
self.fuse_attention_qkv = config.fuse_attention_qkv
self.mp_degree=config.mp_degree

assert (
self.head_dim * self.num_heads * config.mp_degree == config.hidden_size
@@ -268,7 +269,10 @@ def forward(self, query, key, value, attn_mask=None, use_cache=False, cache=None

weights = F.softmax(product)
if self.dropout:
with get_rng_state_tracker().rng_state("local_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("local_seed"):
weights = F.dropout(weights, self.dropout, training=self.training, mode="upscale_in_train")
else:
weights = F.dropout(weights, self.dropout, training=self.training, mode="upscale_in_train")

out = tensor.matmul(weights, v)
@@ -358,6 +362,7 @@ def __init__(self, config):
self.activation = nn.GELU(approximate=True)
else:
self.activation = getattr(F, activation)
self.mp_degree=config.mp_degree

def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None, output_attentions=False):
residual = tgt
@@ -370,15 +375,21 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None, output_attentions=False):
tgt, attn_weights = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, attn_weights, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
tgt = residual + self.dropout1(tgt)
else:
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)

residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
tgt = self.dropout2(self.linear2(self.activation(self.linear1(tgt))))
else:
tgt = self.dropout2(self.linear2(self.activation(self.linear1(tgt))))
tgt = residual + tgt

@@ -587,7 +598,7 @@ def __init__(self, config: OPTConfig):
embedding_dim=config.hidden_size,
initializer_range=config.initializer_range,
)

self.mp_degree=config.mp_degree
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids=None, attention_mask=None, input_embeddings=None, past_key_values_length=None):
@@ -600,9 +611,12 @@ def forward(self, input_ids=None, attention_mask=None, input_embeddings=None, past_key_values_length=None):
position_embeddings = self.position_embeddings(attention_mask, past_key_values_length)

embeddings = input_embeddings + position_embeddings
with get_rng_state_tracker().rng_state("global_seed"):
if self.mp_degree>1:
with get_rng_state_tracker().rng_state("global_seed"):
embeddings = self.dropout(embeddings)
else:
embeddings = self.dropout(embeddings)
return embeddings


class OPTPretrainedModel(PretrainedModel):
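The change repeated across all five files follows one pattern: dropout that previously always ran inside `get_rng_state_tracker().rng_state(...)` is now wrapped in the RNG state tracker only when `config.mp_degree > 1` (tensor/model parallelism enabled), and falls back to a plain dropout call otherwise. A minimal sketch of that pattern follows; it is not part of the commit, the `MPDropout` wrapper name is hypothetical, and the assumption that the tracker's named seeds are only registered under hybrid parallelism (so entering the tracker on a single card would fail) is an inference from the diff, not something the PR states.

```python
# Sketch of the mp_degree-gated dropout pattern applied throughout this PR.
# Hypothetical wrapper; the actual commit inlines this if/else at every dropout call site.
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker


class MPDropout(nn.Layer):
    def __init__(self, p, mp_degree, seed_name="global_seed"):
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.mp_degree = mp_degree
        # The diff uses "global_seed" everywhere except the OPT attention
        # probabilities, which use "local_seed".
        self.seed_name = seed_name

    def forward(self, x):
        if self.mp_degree > 1:
            # Model parallelism: draw the dropout mask under the tracked RNG
            # state so the ranks stay consistent with each other.
            with get_rng_state_tracker().rng_state(self.seed_name):
                return self.dropout(x)
        # Single card / mp_degree == 1: ordinary dropout, no tracker needed.
        return self.dropout(x)
```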
