diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 0ca3c3d6d423..2227d440126d 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -582,3 +582,29 @@ def softmax(x, axis=-1):
         ),
         name="y",
     )
+
+
+def log_softmax(x, axis=-1):
+    """Compute log_softmax using CuDNN
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute log softmax over
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape,
+        [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.log_softmax.forward", ins[0], outs[0], axis
+        ),
+        name="y",
+    )
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 056cb5694a48..753a17605667 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -48,7 +48,7 @@


 # log_softmax
-reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
+reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index dd265e4b4d5b..aeeb62af11a9 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -101,11 +101,23 @@ def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
     return strategy


-@schedule_log_softmax.register(["cuda", "gpu"])
-def schedule_log_softmax_cuda(attrs, outs, target):
-    """scheudle log_softmax for cuda"""
-    with target:
-        return topi.cuda.schedule_softmax(outs)
+@log_softmax_strategy.register(["cuda", "gpu"])
+def log_softmax_strategy_cuda(attrs, inputs, out_type, target):
+    """log_softmax cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="log_softmax.cuda",
+    )
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(topi.cuda.log_softmax_cudnn),
+            wrap_topi_schedule(topi.cuda.schedule_log_softmax_cudnn),
+            name="log_softmax.cudnn",
+            plevel=15,
+        )
+    return strategy


 @schedule_lrn.register(["cuda", "gpu"])
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 5cb3f65f3ebe..3348d8033904 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -181,12 +181,16 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
     return strategy


-# log_softmax
-@generic_func
-def schedule_log_softmax(attrs, outs, target):
-    """Schedule log_softmax op"""
-    with target:
-        return topi.generic.schedule_softmax(outs)
+@override_native_generic_func("log_softmax_strategy")
+def log_softmax_strategy(attrs, inputs, out_type, target):
+    """log_softmax generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.generic.schedule_softmax),
+        name="log_softmax.generic",
+    )
+    return strategy


 # lrn
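With the registrations above, nn.log_softmax now flows through the normal op-strategy machinery, and the cuDNN kernel is chosen only when the target carries the cudnn library. A minimal sketch of exercising the new path end to end (assumes a TVM build with CUDA and cuDNN enabled; OpStrategy.add_implementation defaults to plevel=10, so the cuDNN entry at plevel=15 wins):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.nn.log_softmax(x, axis=-1))

    # "-libs=cudnn" puts "cudnn" into target.libs, so log_softmax_strategy_cuda
    # adds the log_softmax.cudnn implementation at plevel=15, outranking the
    # default TOPI schedule registered at plevel=10.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="cuda -libs=cudnn")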
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 761ac9e6aa01..b147af06cfc3 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -68,11 +68,16 @@ def softmax_strategy_hls(attrs, inputs, out_type, target):
     return strategy


-@schedule_log_softmax.register("hls")
-def schedule_log_softmax_hls(attrs, inputs, out_type, target):
-    """schedule log_softmax for hls"""
-    with target:
-        return topi.hls.schedule_softmax(outs)
+@log_softmax_strategy.register("hls")
+def log_softmax_strategy_hls(attrs, inputs, out_type, target):
+    """log_softmax hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.hls.schedule_softmax),
+        name="log_softmax.hls",
+    )
+    return strategy


 @override_native_generic_func("conv2d_strategy")
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index d09d90a50d41..6a4030514580 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -91,11 +91,16 @@ def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
     return strategy


-@schedule_log_softmax.register("cpu")
-def schedule_log_softmax_cpu(attrs, outs, target):
-    """schedule log_softmax op for x86"""
-    with target:
-        return topi.x86.schedule_softmax(outs)
+@log_softmax_strategy.register("cpu")
+def log_softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """log_softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="log_softmax.x86",
+    )
+    return strategy


 @conv2d_strategy.register("cpu")
diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index b743aefc50d5..516d4f93672e 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -164,3 +164,13 @@ def softmax_cudnn(x, axis=-1):
 def schedule_softmax_cudnn(outs):
     """Schedule for softmax cudnn op"""
     return generic.schedule_extern(outs)
+
+
+def log_softmax_cudnn(x, axis=-1):
+    """Perform log_softmax on the data using cudnn"""
+    return cudnn.log_softmax(x, axis)
+
+
+def schedule_log_softmax_cudnn(outs):
+    """Schedule for log_softmax cudnn op"""
+    return generic.schedule_extern(outs)
diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index 6d2bb1548de5..a13b17686708 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -123,7 +123,7 @@ def _normalize(exp, expsum, *indices):


 @tvm.te.tag_scope(tag="log_softmax_output")
-def log_softmax(x):
+def log_softmax(x, axis=-1):
     """Perform log softmax activation on the data

     Parameters
@@ -136,8 +136,9 @@ def log_softmax(x):
     output : tvm.te.Tensor
         2-D output with same shape
     """
-    assert len(x.shape) == 2, "only support 2-dim log softmax"
+    # pylint: disable=R1714
+    assert axis == -1 or axis == len(x.shape) - 1, "only support last axis log softmax"
     m, n = x.shape
     k = te.reduce_axis((0, n), name="k")
     max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
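For intuition, the TE definition above computes the numerically stable form log_softmax(x)[i, j] = x[i, j] - max_elem[i] - log(expsum[i]), where expsum sums exp(x - max_elem) along the last axis; subtracting the row maximum first keeps the exponentials from overflowing. A NumPy sketch of the same computation (illustrative only, not TVM API):

    import numpy as np

    def log_softmax_ref(x):
        # Mirrors max_elem / expsum in the TE compute, last axis only.
        max_elem = np.max(x, axis=-1, keepdims=True)
        expsum = np.sum(np.exp(x - max_elem), axis=-1, keepdims=True)
        return x - max_elem - np.log(expsum)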
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
index 648c9b633ea4..4b37c428e0b5 100644
--- a/src/runtime/contrib/cudnn/softmax.cc
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -31,54 +31,59 @@ namespace contrib {

 using namespace runtime;

-TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      DLTensor* x = args[0];
-      DLTensor* y = args[1];
-      int axis = args[2];
-      int ndim = x->ndim;
-      int64_t* shape = x->shape;
-      if (axis < 0) axis += ndim;
-      ICHECK(axis >= 0 && axis < ndim);
+void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  ICHECK(axis >= 0 && axis < ndim);

-      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-      entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

-      // Set mode and shape descriptor
-      if (axis == ndim - 1) {
-        int64_t N = 1;
-        for (int i = 0; i < ndim - 1; ++i) {
-          N *= shape[i];
-        }
-        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
-        CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
-                                              CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
-                                              static_cast<int>(N),
-                                              static_cast<int>(shape[ndim - 1]), 1, 1));
-      } else {
-        int64_t pre_axis_dim = 1;
-        int64_t post_axis_dim = 1;
-        for (int i = 0; i < ndim; ++i) {
-          if (i < axis) {
-            pre_axis_dim *= shape[i];
-          } else if (i > axis) {
-            post_axis_dim *= shape[i];
-          }
-        }
-        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
-        CUDNN_CALL(cudnnSetTensor4dDescriptor(
-            entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
-            entry_ptr->softmax_entry.data_type, static_cast<int>(pre_axis_dim),
-            static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
+  // Set mode and shape descriptor
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type, static_cast<int>(N),
+                                          static_cast<int>(shape[ndim - 1]), 1, 1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
       }
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
+        static_cast<int>(pre_axis_dim), static_cast<int>(shape[axis]),
+        static_cast<int>(post_axis_dim), 1));
+  }
+
+  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, alg, entry_ptr->softmax_entry.mode, alpha,
+                                 entry_ptr->softmax_entry.shape_desc, x->data, beta,
+                                 entry_ptr->softmax_entry.shape_desc, y->data));
+}

-      auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
-      auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
-      CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE,
-                                     entry_ptr->softmax_entry.mode, alpha,
-                                     entry_ptr->softmax_entry.shape_desc, x->data, beta,
-                                     entry_ptr->softmax_entry.shape_desc, y->data));
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret);
     });

+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); });
+
 }  // namespace contrib
 }  // namespace tvm
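Both packed functions now share softmax_impl; the only difference is the cudnnSoftmaxAlgorithm_t routed through (CUDNN_SOFTMAX_ACCURATE versus CUDNN_SOFTMAX_LOG). The shape handling is unchanged: a last-axis reduction is flattened to an (N, C, 1, 1) view under CUDNN_SOFTMAX_MODE_INSTANCE, while any other axis maps to (pre, axis, post, 1) under CUDNN_SOFTMAX_MODE_CHANNEL. A hedged sketch of poking the new registration directly from Python (requires a cuDNN-enabled CUDA build and a GPU):

    import numpy as np
    import tvm

    # Look up the global function registered by TVM_REGISTER_GLOBAL above.
    f = tvm.get_global_func("tvm.contrib.cudnn.log_softmax.forward")

    dev = tvm.cuda(0)
    x = tvm.nd.array(np.random.uniform(size=(4, 8)).astype("float32"), dev)
    y = tvm.nd.empty((4, 8), "float32", dev)
    f(x, y, -1)  # the negative axis is normalized inside softmax_impl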
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 7651bdea36a6..069f3f3769b5 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -176,14 +176,19 @@ def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0, groups=2)


-def verify_softmax(shape, axis, dtype="float32"):
+def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
+    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
     A = te.placeholder(shape, dtype=dtype, name="A")
-    B = cudnn.softmax(A, axis)
+    B = cudnn_op(A, axis)
     s = te.create_schedule([B.op])

     dev = tvm.cuda(0)
     a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np)
+    b_np = testing_op(a_np)
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
     f = tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax")
@@ -191,15 +196,20 @@ def verify_softmax(shape, axis, dtype="float32"):
     tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


-def verify_softmax_4d(shape, dtype="float32"):
+def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
+    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
     A = te.placeholder(shape, dtype=dtype, name="A")
-    B = cudnn.softmax(A, axis=1)
+    B = cudnn_op(A, axis=1)
     s = te.create_schedule([B.op])

     dev = tvm.cuda(0)
     n, c, h, w = shape
     a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
+    b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
     b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
@@ -217,6 +227,12 @@ def test_softmax():
     verify_softmax_4d((1, 16, 256, 256))
     verify_softmax_4d((1, 16, 256, 256), "float64")

+    verify_softmax((32, 10), -1, log_softmax=True)
+    verify_softmax((3, 4), -1, log_softmax=True)
+    verify_softmax((1, 5), -1, "float64", log_softmax=True)
+    verify_softmax_4d((1, 16, 256, 256), log_softmax=True)
+    verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)
+

 test_kwargs_default_2d = {
     "tensor_format": 0,
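The new cases reuse the existing verify helpers, so they run as part of test_softmax and need no separate entry point. If a dedicated, GPU-gated test were wanted, a sketch using TVM's standard gating decorator would look like this (hypothetical test name; helpers as defined above):

    import tvm.testing

    @tvm.testing.requires_gpu
    def test_log_softmax_cudnn():
        verify_softmax((32, 10), -1, log_softmax=True)
        verify_softmax((1, 5), -1, "float64", log_softmax=True)
        verify_softmax_4d((1, 16, 256, 256), log_softmax=True)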