Skip to content

Commit

Permalink
add log_softmax cudnn test
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh committed Jun 30, 2021
1 parent 68d81a1 commit 8d8fdc1
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,30 +176,40 @@ def test_conv3d():
verify_conv3d("float32", "float32", tensor_format=0, groups=2)


def verify_softmax(shape, axis, dtype="float32"):
def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
cudnn_op = cudnn.log_softmx if log_softmax else cudnn.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = cudnn.softmax(A, axis)
B = cudnn_op(A, axis)
s = te.create_schedule([B.op])

dev = tvm.cuda(0)
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = tvm.topi.testing.softmax_python(a_np)
b_np = testing_op(a_np)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
f = tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax")
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


def verify_softmax_4d(shape, dtype="float32"):
def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
cudnn_op = cudnn.log_softmx if log_softmax else cudnn.softmax
testing_op = (
tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
)

A = te.placeholder(shape, dtype=dtype, name="A")
B = cudnn.softmax(A, axis=1)
B = cudnn_op(A, axis=1)
s = te.create_schedule([B.op])

dev = tvm.cuda(0)
n, c, h, w = shape
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
Expand All @@ -217,6 +227,12 @@ def test_softmax():
verify_softmax_4d((1, 16, 256, 256))
verify_softmax_4d((1, 16, 256, 256), "float64")

verify_softmax((32, 10), -1, log_softmax=True)
verify_softmax((3, 4), -1, log_softmax=True)
verify_softmax((1, 5), -1, "float64", log_softmax=True)
verify_softmax_4d((1, 16, 256, 256), log_softmax=True)
verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)


test_kwargs_default_2d = {
"tensor_format": 0,
Expand Down

0 comments on commit 8d8fdc1

Please sign in to comment.