diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 60a4362550..4f68518b5e 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -10,6 +10,7 @@ from .modules import (
     TransformerEncoder,
+    ProjectionLayer,
 )

 import logging
 logger = logging.getLogger(__name__)

@@ -25,6 +26,8 @@ class RobertaEncoderConf:
     num_attention_heads: int = 12
     num_encoder_layers: int = 12
     dropout: float = 0.1
+    projection_dim: Optional[int] = None
+    projection_dropout: Optional[float] = None
     scaling: Optional[float] = None
     normalize_before: bool = False

@@ -40,6 +43,8 @@ def __init__(
         num_attention_heads: int,
         num_encoder_layers: int,
         dropout: float = 0.1,
+        projection_dim: Optional[int] = None,
+        projection_dropout: Optional[float] = None,
         scaling: Optional[float] = None,
         normalize_before: bool = False,
     ):
@@ -62,6 +67,10 @@ def __init__(
             return_all_layers=False,
         )

+        self.project = nn.Identity()
+        if projection_dim is not None:
+            self.project = ProjectionLayer(embed_dim=embedding_dim, projection_dim=projection_dim, dropout=projection_dropout)
+
     @classmethod
     def from_config(cls, config: RobertaEncoderConf):
         return cls(**asdict(config))
@@ -73,6 +82,9 @@ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         output = output.transpose(1, 0)
         if mask is not None:
             output = output[mask.to(torch.bool), :]
+
+        output = self.project(output)
+
         return output


diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 901e896270..0c3702d28b 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -31,6 +31,27 @@ def _make_positions(self, tensor, pad_index: int):
        return torch.cumsum(masked, dim=1) * masked + pad_index


+class ProjectionLayer(Module):
+    def __init__(self,
+                 embed_dim: int,
+                 projection_dim: int,
+                 dropout: Optional[float] = None) -> None:
+        super().__init__()
+
+        self.projection_layer = nn.Linear(embed_dim, projection_dim)
+        self.norm_layer = nn.LayerNorm(projection_dim)
+        if dropout is not None:
+            self.dropout_layer = nn.Dropout(dropout)
+        else:
+            self.dropout_layer = nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.projection_layer(x)
+        x = self.norm_layer(x)
+        x = self.dropout_layer(x)
+        return x
+
+
 class ResidualMLP(Module):
     def __init__(
         self,
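
For context, a minimal, self-contained sketch of the behaviour these changes introduce: when `projection_dim` is set, the encoder output is passed through a Linear -> LayerNorm -> (optional) Dropout block, reducing the feature dimension from `embedding_dim` to `projection_dim`. The class below mirrors the `ProjectionLayer` added in `modules.py`; the dimensions in the shape check are illustrative, not values used by the PR.

```python
from typing import Optional

import torch
import torch.nn as nn


class ProjectionLayer(nn.Module):
    """Standalone mirror of the ProjectionLayer added in modules.py:
    Linear -> LayerNorm -> (optional) Dropout."""

    def __init__(self, embed_dim: int, projection_dim: int, dropout: Optional[float] = None) -> None:
        super().__init__()
        self.projection_layer = nn.Linear(embed_dim, projection_dim)
        self.norm_layer = nn.LayerNorm(projection_dim)
        # Dropout is only applied when a rate is given, matching the diff's
        # nn.Identity() fallback.
        self.dropout_layer = nn.Dropout(dropout) if dropout is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection_layer(x)
        x = self.norm_layer(x)
        x = self.dropout_layer(x)
        return x


# Shape check with illustrative sizes: project 768-dim encoder features to 256 dims.
proj = ProjectionLayer(embed_dim=768, projection_dim=256, dropout=0.1)
features = torch.randn(2, 5, 768)  # (batch, sequence, embed_dim)
print(proj(features).shape)  # torch.Size([2, 5, 256])
```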