Replace TransformerEncoder in torchtext with better transformer (facebookresearch#34)

Summary:
Pull Request resolved: facebookresearch#34

X-link: pytorch/text#1700

Replace the usage of TransformerEncoder with BetterTransformerEncoder.
In theory, we should be able to remove torchtext.TransformerEncoderLayer after this diff.

Reviewed By: parmeet

Differential Revision: D36084653

fbshipit-source-id: 64ed3810e809fc1db840e75e2e05783089ff31d2
zrphercule2 authored and facebook-github-bot committed May 5, 2022
1 parent d216331 commit 3b537be
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions torchmultimodal/modules/encoders/clip_text_encoder.py
@@ -71,20 +71,16 @@ def initialize_parameters(self) -> None:
             std=self.POS_EMBEDDING_INIT_STD,
         )
 
-        proj_std = (self.width ** -0.5) * ((2 * len(self.encoder.layers)) ** -0.5)
+        proj_std = (self.width ** -0.5) * ((2 * self.encoder.layers.num_layers) ** -0.5)
         attn_std = self.width ** -0.5
         fc_std = (2 * self.width) ** -0.5
-        for layer in self.encoder.layers:
-            nn.init.normal_(
-                layer.better_transformer.self_attn.in_proj_weight, std=attn_std
-            )
-            nn.init.normal_(
-                layer.better_transformer.self_attn.out_proj.weight, std=proj_std
-            )
+        for layer in self.encoder.layers.layers:
+            nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
             # c_fc in CLIP corresponds to the first residual MLP layer
-            nn.init.normal_(layer.better_transformer.linear1.weight, std=fc_std)
+            nn.init.normal_(layer.linear1.weight, std=fc_std)
             # c_proj in CLIP corresponds to the last residual MLP layer
-            nn.init.normal_(layer.better_transformer.linear2.weight, std=proj_std)
+            nn.init.normal_(layer.linear2.weight, std=proj_std)
 
         # Initialize projection
         nn.init.normal_(self.projection.weight, std=self.width ** -0.5)
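For reference, the new loop maps onto a plain torch.nn.TransformerEncoder as in the standalone sketch below. This is not code from the repo: it assumes the replacement encoder wraps an nn.TransformerEncoder at self.encoder.layers (which is why the diff reads self.encoder.layers.num_layers and self.encoder.layers.layers), and width / n_layers are placeholder hyperparameters.

import torch.nn as nn

# Placeholder hyperparameters for illustration only.
width, n_layers = 512, 12

# A plain nn.TransformerEncoder; its .layers is a ModuleList of
# nn.TransformerEncoderLayer modules, and .num_layers is their count.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=width, nhead=8, dim_feedforward=4 * width),
    num_layers=n_layers,
)

proj_std = (width ** -0.5) * ((2 * encoder.num_layers) ** -0.5)
attn_std = width ** -0.5
fc_std = (2 * width) ** -0.5

for layer in encoder.layers:
    # Fused q/k/v projection of nn.MultiheadAttention.
    nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)
    # Attention output projection.
    nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
    # First feed-forward layer (CLIP's c_fc).
    nn.init.normal_(layer.linear1.weight, std=fc_std)
    # Second feed-forward layer (CLIP's c_proj).
    nn.init.normal_(layer.linear2.weight, std=proj_std)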
