Skip to content

Commit

Permalink
Fix layer_norm_1p Apex call to match Apex API
Browse files Browse the repository at this point in the history
Signed-off-by: Shriya Palsamudram <spalsamudram@nvidia.com>
  • Loading branch information
ShriyaPalsamudram committed Jan 9, 2024
1 parent 76a712a commit d210a75
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def reset_parameters(self):
torch.nn.init.zeros_(self.bias)

def forward(self, x):
return _fast_layer_norm(x, self.weight + 1, self.bias, self.epsilon)
return _fast_layer_norm(x, self.weight + 1, self.bias, self.epsilon, self.memory_efficient)


else:
Expand Down

0 comments on commit d210a75

Please sign in to comment.