Skip to content

Commit

Permalink
Merge pull request huggingface#18 from stevezheng23/dev/zheng/quac
Browse files Browse the repository at this point in the history
fix at issues in roberta/berta modeling (cont.)
  • Loading branch information
stevezheng23 authored Oct 30, 2019
2 parents 97c6ac9 + c7e3cae commit 245834d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,11 +1004,11 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_

at_span_loss.backward()
at_perturb_grads = at_perturbs.grad.detach()
at_perturb_grads /= torch.norm(at_perturb_grads, p='fro', dim=-1)
at_perturb_grads /= torch.norm(at_perturb_grads, p='fro', dim=-1, keepdim=True)
at_updated_perturbs = at_perturbs + self.at_alpha * at_perturb_grads
at_updated_perturbs = torch.clamp(at_updated_perturbs, min=-self.at_epsilon, max=self.at_epsilon)
at_perturbs.data = at_updated_perturbs
at_perturbs.grad.zero()
at_perturbs.grad.zero_()

at_outputs = self.roberta(input_ids,
attention_mask=attention_mask,
Expand Down

0 comments on commit 245834d

Please sign in to comment.