Remove softplus from blocklist
- unblock nn.functional.softplus test
- unblock test_softplus test
Ronian526 committed Jan 27, 2023 · 1 parent 628cecd · commit 21e01b8
Showing 1 changed file with 21 additions and 19 deletions.
test/test_mps.py
@@ -4679,29 +4679,31 @@ def helper(shape, dim=0):
         for dim in range(len(shape)):
             helper(shape, dim)
 
-    # # Test softplus
-    # def test_softplus(self):
-    #     def helper(shape, beta=1, threshold=20):
-    #         cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
-    #         x = cpu_x.detach().clone().to('mps').requires_grad_()
+    # Test softplus
+    def test_softplus(self):
+        def helper(shape, beta=1, threshold=20):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
 
-    #         softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
-    #         softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)
+            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
+            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)
 
-    #         cpu_grad = torch.randn(softplus_result.shape)
-    #         grad = cpu_grad.to('mps')
+            cpu_grad = torch.randn(softplus_result.shape)
+            grad = cpu_grad.to('mps')
 
-    #         softplus_result.backward(gradient=grad)
-    #         softplus_result_cpu.backward(gradient=cpu_grad)
+            softplus_result.backward(gradient=grad)
+            softplus_result_cpu.backward(gradient=cpu_grad)
 
-    #         self.assertEqual(softplus_result, softplus_result_cpu)
-    #         self.assertEqual(x.grad, cpu_x.grad)
+            self.assertEqual(softplus_result, softplus_result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
 
-    #     # Test empty shape too
-    #     for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-    #         for beta in [0.5, 1, 2, 3, 4]:
-    #             for threshold in [0.5, 20, 30, 40, 50]:
-    #                 helper(shape, beta, threshold)
+        # Test empty shape too
+        for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
+            for beta in [0.5, 1, 2, 3, 4]:
+                for threshold in [0.5, 20, 30, 40, 50]:
+                    helper(shape, beta, threshold)
+
     # Test silu
+
     def test_silu(self):
         def helper(shape):
@@ -9198,6 +9200,7 @@ class TestConsistency(TestCase):
         'zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'zeros_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.softplus': ['f32'],
     }
 
     ALLOWLIST_OP_GRAD = {
@@ -9379,7 +9382,6 @@ class TestConsistency(TestCase):
     # All the entries in this list should be removed
     BLOCKLIST = {
         # Functions that hard crash
-        'nn.functional.softplus': [torch.float32],
         'sgn': [torch.bool],
         'linalg.inv': [torch.float32],
         'linalg.inv_ex': [torch.float32],
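For context, here is a minimal standalone sketch of the CPU-vs-MPS parity check that the re-enabled test performs. It assumes an MPS-capable PyTorch build; the input shape and Softplus parameters are illustrative, not taken from the commit.

import torch

if torch.backends.mps.is_available():
    # Same random input on both devices, tracking gradients on each copy.
    cpu_x = torch.randn(2, 3, dtype=torch.float32, requires_grad=True)
    mps_x = cpu_x.detach().clone().to("mps").requires_grad_()

    softplus = torch.nn.Softplus(beta=1, threshold=20)
    cpu_out = softplus(cpu_x)
    mps_out = softplus(mps_x)

    # Push the same upstream gradient through both graphs.
    cpu_grad = torch.randn_like(cpu_out)
    cpu_out.backward(gradient=cpu_grad)
    mps_out.backward(gradient=cpu_grad.to("mps"))

    # Forward and backward results should agree across devices.
    torch.testing.assert_close(mps_out.cpu(), cpu_out)
    torch.testing.assert_close(mps_x.grad.cpu(), cpu_x.grad)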
