Skip to content

Commit

Permalink
[BERT/PyT] fix onnx export (NVIDIA#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
sharathts authored and changlan committed Apr 5, 2021
1 parent c29efa3 commit 1f65809
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions PyTorch/LanguageModeling/BERT/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def __init__(self, hidden_size, eps=1e-12):

def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
s = (x - u)
s = s * s
s = s.mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias

Expand Down Expand Up @@ -323,7 +325,9 @@ def forward(self, x):
x = self.fused_layer_norm(x)
else:
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
s = (x - u)
s = s * s
s = s.mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight * x + self.bias
return x
Expand Down

0 comments on commit 1f65809

Please sign in to comment.