From 3b537be6594d0ccaad044ecbe968a09c8877a4e4 Mon Sep 17 00:00:00 2001
From: Rui Zhu
Date: Thu, 5 May 2022 10:03:51 -0700
Subject: [PATCH] Replace TransformerEncoder in torchtext with better
 transformer (#34)

Summary:
Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/34

X-link: https://github.com/pytorch/text/pull/1700

Replace the usage of TransformerEncoder with BetterTransformerEncoder.
In theory we should be able to remove torchtext.TransformerEncoderLayer after this diff.

Reviewed By: parmeet

Differential Revision: D36084653

fbshipit-source-id: 64ed3810e809fc1db840e75e2e05783089ff31d2
---
 .../modules/encoders/clip_text_encoder.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/torchmultimodal/modules/encoders/clip_text_encoder.py b/torchmultimodal/modules/encoders/clip_text_encoder.py
index d85cd345..920d3f06 100644
--- a/torchmultimodal/modules/encoders/clip_text_encoder.py
+++ b/torchmultimodal/modules/encoders/clip_text_encoder.py
@@ -71,20 +71,16 @@ def initialize_parameters(self) -> None:
             std=self.POS_EMBEDDING_INIT_STD,
         )
 
-        proj_std = (self.width ** -0.5) * ((2 * len(self.encoder.layers)) ** -0.5)
+        proj_std = (self.width ** -0.5) * ((2 * self.encoder.layers.num_layers) ** -0.5)
         attn_std = self.width ** -0.5
         fc_std = (2 * self.width) ** -0.5
-        for layer in self.encoder.layers:
-            nn.init.normal_(
-                layer.better_transformer.self_attn.in_proj_weight, std=attn_std
-            )
-            nn.init.normal_(
-                layer.better_transformer.self_attn.out_proj.weight, std=proj_std
-            )
+        for layer in self.encoder.layers.layers:
+            nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
             # c_fc in CLIP corresponds to the first residual MLP layer
-            nn.init.normal_(layer.better_transformer.linear1.weight, std=fc_std)
+            nn.init.normal_(layer.linear1.weight, std=fc_std)
             # c_proj in CLIP corresponds to the last residual MLP layer
-            nn.init.normal_(layer.better_transformer.linear2.weight, std=proj_std)
+            nn.init.normal_(layer.linear2.weight, std=proj_std)
 
         # Initialize projection
         nn.init.normal_(self.projection.weight, std=self.width ** -0.5)
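
A minimal standalone sketch of the attribute paths the patched init loop depends on, assuming self.encoder.layers is (or wraps) a standard torch.nn.TransformerEncoder, which exposes num_layers and a layers ModuleList of nn.TransformerEncoderLayer modules with self_attn, linear1, and linear2 submodules. The encoder construction and sizes below are illustrative only and not taken from the patch.

import torch.nn as nn

# Illustrative sizes; the real values come from the CLIP text encoder config.
width, n_layers = 512, 12
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=width, nhead=8, batch_first=True),
    num_layers=n_layers,
)

proj_std = (width ** -0.5) * ((2 * encoder.num_layers) ** -0.5)
attn_std = width ** -0.5
fc_std = (2 * width) ** -0.5

# nn.TransformerEncoder stores its layers in a ModuleList named "layers";
# each nn.TransformerEncoderLayer exposes self_attn, linear1, and linear2
# directly, which is why the patched loop no longer goes through a
# better_transformer attribute on each layer.
for layer in encoder.layers:
    nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)   # fused q/k/v projection
    nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)  # attention output projection
    nn.init.normal_(layer.linear1.weight, std=fc_std)               # c_fc in CLIP
    nn.init.normal_(layer.linear2.weight, std=proj_std)             # c_proj in CLIP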