Skip to content

Commit

Permalink
fix naming
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet committed Nov 28, 2021
1 parent eeedc24 commit 6158899
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torchtext/models/roberta/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def __init__(
for p in self.parameters():
p.requires_grad = False

def forward(self, tokens: Tensor, maked_tokens: Optional[Tensor] = None) -> Tensor:
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
output = self.transformer(tokens)
if torch.jit.isinstance(output, List[Tensor]):
output = output[-1]
output = output.transpose(1, 0)
if maked_tokens is not None:
output = output[maked_tokens.to(torch.bool), :]
if masked_tokens is not None:
output = output[masked_tokens.to(torch.bool), :]
return output


Expand Down Expand Up @@ -119,8 +119,8 @@ def __init__(self,
self.encoder = RobertaEncoder(**asdict(encoder_conf), freeze=freeze_encoder)
self.head = head

def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, mask)
def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
features = self.encoder(tokens, masked_tokens)
if self.head is None:
return features

Expand Down

0 comments on commit 6158899

Please sign in to comment.