[cuDNN] Add support for log_softmax #8369

Merged · 4 commits · Jul 1, 2021

Changes from all commits

26 changes: 26 additions & 0 deletions python/tvm/contrib/cudnn.py
@@ -582,3 +582,29 @@ def softmax(x, axis=-1):
         ),
         name="y",
     )
+
+
+def log_softmax(x, axis=-1):
+    """Compute log_softmax using CuDNN
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        The input tensor
+
+    axis : int
+        The axis to compute log softmax over
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor
+    """
+    return te.extern(
+        x.shape,
+        [x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.log_softmax.forward", ins[0], outs[0], axis
+        ),
+        name="y",
+    )
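
Note: a minimal sketch of how this new entry point can be exercised directly, mirroring the pattern used in tests/python/contrib/test_cudnn.py further down; it assumes a CUDA device and a TVM build with cuDNN enabled.

import numpy as np
import tvm
from tvm import te
from tvm.contrib import cudnn

# Build a tiny TE graph that lowers to the packed cuDNN call registered below.
shape = (32, 10)
A = te.placeholder(shape, dtype="float32", name="A")
B = cudnn.log_softmax(A, axis=-1)
s = te.create_schedule([B.op])
f = tvm.build(s, [A, B], target="cuda --host=llvm", name="log_softmax")

dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(size=shape).astype("float32"), dev)
b = tvm.nd.array(np.zeros(shape, dtype="float32"), dev)
f(a, b)  # b now holds log_softmax(a) computed by cuDNN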
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -48,7 +48,7 @@
 
 
 # log_softmax
-reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
+reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
22 changes: 17 additions & 5 deletions python/tvm/relay/op/strategy/cuda.py
@@ -101,11 +101,23 @@ def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
-@schedule_log_softmax.register(["cuda", "gpu"])
-def schedule_log_softmax_cuda(attrs, outs, target):
-    """scheudle log_softmax for cuda"""
-    with target:
-        return topi.cuda.schedule_softmax(outs)
+@log_softmax_strategy.register(["cuda", "gpu"])
+def log_softmax_strategy_cuda(attrs, inputs, out_type, target):
+    """log_softmax cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="log_softmax.cuda",
+    )
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_softmax(topi.cuda.log_softmax_cudnn),
+            wrap_topi_schedule(topi.cuda.schedule_log_softmax_cudnn),
+            name="log_softmax.cudnn",
+            plevel=15,
+        )
+    return strategy
 
 
 @schedule_lrn.register(["cuda", "gpu"])
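
Note: because the cuDNN implementation registers at plevel=15, above the default TOPI implementation, it is selected automatically whenever "cudnn" appears in target.libs. A minimal sketch of driving that selection from Relay, assuming a cuDNN-enabled build; the variable name and shape are illustrative.

import tvm
from tvm import relay

# log_softmax over the class axis of a (batch, classes) tensor.
x = relay.var("x", shape=(8, 100), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.log_softmax(x, axis=-1))

# "-libs=cudnn" puts "cudnn" into target.libs, so log_softmax.cudnn
# (plevel=15) outranks log_softmax.cuda in the strategy above.
target = tvm.target.Target("cuda -libs=cudnn")
lib = relay.build(mod, target=target)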
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -181,12 +181,15 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
-# log_softmax
-@generic_func
-def schedule_log_softmax(attrs, outs, target):
-    """Schedule log_softmax op"""
-    with target:
-        return topi.generic.schedule_softmax(outs)
+@override_native_generic_func("log_softmax_strategy")
+def log_softmax_strategy(attrs, inputs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.generic.schedule_softmax),
+        name="log_softmax.generic",
+    )
+    return strategy
 
 
 # lrn
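
Note: for readers new to the strategy mechanism, wrap_compute_softmax adapts a TOPI compute such as topi.nn.log_softmax to the (attrs, inputs, out_type) convention that add_implementation expects. Roughly the following, as a sketch of the existing helper in this file (not part of this diff):

def wrap_compute_softmax(topi_compute):
    """Wrap a softmax-family TOPI compute for use in an OpStrategy."""

    def _compute_softmax(attrs, inputs, out_type):
        axis = attrs.get_int("axis")  # the op attribute forwarded to TOPI
        return [topi_compute(inputs[0], axis)]

    return _compute_softmax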
15 changes: 10 additions & 5 deletions python/tvm/relay/op/strategy/hls.py
@@ -68,11 +68,16 @@ def softmax_strategy_hls(attrs, inputs, out_type, target):
     return strategy
 
 
-@schedule_log_softmax.register("hls")
-def schedule_log_softmax_hls(attrs, inputs, out_type, target):
-    """schedule log_softmax for hls"""
-    with target:
-        return topi.hls.schedule_softmax(outs)
+@log_softmax_strategy.register("hls")
+def log_softmax_strategy_hls(attrs, inputs, out_type, target):
+    """log_softmax hls strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.hls.schedule_softmax),
+        name="log_softmax.hls",
+    )
+    return strategy
 
 
 @override_native_generic_func("conv2d_strategy")
15 changes: 10 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
@@ -91,11 +91,16 @@ def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
-@schedule_log_softmax.register("cpu")
-def schedule_log_softmax_cpu(attrs, outs, target):
-    """schedule log_softmax op for x86"""
-    with target:
-        return topi.x86.schedule_softmax(outs)
+@log_softmax_strategy.register("cpu")
+def log_softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """log_softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.log_softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="log_softmax.x86",
+    )
+    return strategy
 
 
 @conv2d_strategy.register("cpu")
10 changes: 10 additions & 0 deletions python/tvm/topi/cuda/softmax.py
@@ -164,3 +164,13 @@ def softmax_cudnn(x, axis=-1):
 def schedule_softmax_cudnn(outs):
     """Schedule for softmax cudnn op"""
     return generic.schedule_extern(outs)
+
+
+def log_softmax_cudnn(x, axis=-1):
+    """Perform log_softmax on the data using cudnn"""
+    return cudnn.log_softmax(x, axis)
+
+
+def schedule_log_softmax_cudnn(outs):
+    """Schedule for log_softmax cudnn op"""
+    return generic.schedule_extern(outs)
5 changes: 3 additions & 2 deletions python/tvm/topi/nn/softmax.py
@@ -123,7 +123,7 @@ def _normalize(exp, expsum, *indices):
 
 
 @tvm.te.tag_scope(tag="log_softmax_output")
-def log_softmax(x):
+def log_softmax(x, axis=-1):
     """Perform log softmax activation on the data
 
     Parameters
@@ -136,8 +136,9 @@ def log_softmax(x):
     output : tvm.te.Tensor
         2-D output with same shape
     """
-
     assert len(x.shape) == 2, "only support 2-dim log softmax"
+    # pylint: disable=R1714
+    assert axis == -1 or axis == len(x.shape) - 1, "only support last axis log softmax"
     m, n = x.shape
     k = te.reduce_axis((0, n), name="k")
     max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
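
Note: the compute above implements the numerically stable form log_softmax(x)[j] = (x[j] - max(x)) - log(sum_k exp(x[k] - max(x))); subtracting the row maximum first keeps exp() from overflowing. A NumPy sketch of the same computation for a 2-D input:

import numpy as np

def log_softmax_ref(x):
    """Stable log-softmax over the last axis of a 2-D array."""
    max_elem = np.max(x, axis=-1, keepdims=True)  # mirrors max_elem above
    shifted = x - max_elem
    expsum = np.sum(np.exp(shifted), axis=-1, keepdims=True)
    return shifted - np.log(expsum)

x = np.random.randn(4, 10).astype("float32")
# rows of exp(log_softmax) must sum to one
assert np.allclose(np.sum(np.exp(log_softmax_ref(x)), axis=-1), 1.0, atol=1e-5)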
91 changes: 48 additions & 43 deletions src/runtime/contrib/cudnn/softmax.cc
@@ -31,54 +31,59 @@ namespace contrib {
 
 using namespace runtime;
 
-TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      DLTensor* x = args[0];
-      DLTensor* y = args[1];
-      int axis = args[2];
-      int ndim = x->ndim;
-      int64_t* shape = x->shape;
-      if (axis < 0) axis += ndim;
-      ICHECK(axis >= 0 && axis < ndim);
-
-      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-      entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
-
-      // Set mode and shape descriptor
-      if (axis == ndim - 1) {
-        int64_t N = 1;
-        for (int i = 0; i < ndim - 1; ++i) {
-          N *= shape[i];
-        }
-        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
-        CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
-                                              CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
-                                              static_cast<int>(N),
-                                              static_cast<int>(shape[ndim - 1]), 1, 1));
-      } else {
-        int64_t pre_axis_dim = 1;
-        int64_t post_axis_dim = 1;
-        for (int i = 0; i < ndim; ++i) {
-          if (i < axis) {
-            pre_axis_dim *= shape[i];
-          } else if (i > axis) {
-            post_axis_dim *= shape[i];
-          }
-        }
-        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
-        CUDNN_CALL(cudnnSetTensor4dDescriptor(
-            entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
-            entry_ptr->softmax_entry.data_type, static_cast<int>(pre_axis_dim),
-            static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
-      }
-
-      auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
-      auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
-      CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE,
                                     entry_ptr->softmax_entry.mode, alpha,
-                                     entry_ptr->softmax_entry.shape_desc, x->data, beta,
-                                     entry_ptr->softmax_entry.shape_desc, y->data));
-    });
+void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) {
+  DLTensor* x = args[0];
+  DLTensor* y = args[1];
+  int axis = args[2];
+  int ndim = x->ndim;
+  int64_t* shape = x->shape;
+  if (axis < 0) axis += ndim;
+  ICHECK(axis >= 0 && axis < ndim);
+
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
+
+  // Set mode and shape descriptor
+  if (axis == ndim - 1) {
+    int64_t N = 1;
+    for (int i = 0; i < ndim - 1; ++i) {
+      N *= shape[i];
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
+                                          entry_ptr->softmax_entry.data_type, static_cast<int>(N),
+                                          static_cast<int>(shape[ndim - 1]), 1, 1));
+  } else {
+    int64_t pre_axis_dim = 1;
+    int64_t post_axis_dim = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (i < axis) {
+        pre_axis_dim *= shape[i];
+      } else if (i > axis) {
+        post_axis_dim *= shape[i];
+      }
+    }
+    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
+        static_cast<int>(pre_axis_dim), static_cast<int>(shape[axis]),
+        static_cast<int>(post_axis_dim), 1));
+  }
+
+  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
+  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
+  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, alg, entry_ptr->softmax_entry.mode, alpha,
+                                 entry_ptr->softmax_entry.shape_desc, x->data, beta,
+                                 entry_ptr->softmax_entry.shape_desc, y->data));
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret);
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); });
 
 }  // namespace contrib
 }  // namespace tvm
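
Note: cuDNN only accepts 4-D tensor descriptors, so softmax_impl flattens the input around the reduction axis before calling cudnnSoftmaxForward. A Python sketch of the same mapping, for illustration only:

def cudnn_softmax_descriptor(shape, axis):
    """Mirror softmax_impl's (N, C, H, W) flattening for a given axis."""
    ndim = len(shape)
    axis = axis + ndim if axis < 0 else axis
    if axis == ndim - 1:
        # MODE_INSTANCE: reduce over C, with H = W = 1.
        n = 1
        for dim in shape[:-1]:
            n *= dim
        return "MODE_INSTANCE", (n, shape[-1], 1, 1)
    # MODE_CHANNEL: reduce over C; dims before axis fold into N, after into H.
    pre = post = 1
    for i, dim in enumerate(shape):
        if i < axis:
            pre *= dim
        elif i > axis:
            post *= dim
    return "MODE_CHANNEL", (pre, shape[axis], post, 1)

# e.g. the (1, 16, 256, 256) NCHW case from the tests, axis=1:
print(cudnn_softmax_descriptor((1, 16, 256, 256), 1))  # ('MODE_CHANNEL', (1, 16, 65536, 1))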
28 changes: 22 additions & 6 deletions tests/python/contrib/test_cudnn.py
@@ -176,30 +176,40 @@ def test_conv3d():
     verify_conv3d("float32", "float32", tensor_format=0, groups=2)
 
 
-def verify_softmax(shape, axis, dtype="float32"):
+def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
+    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
     A = te.placeholder(shape, dtype=dtype, name="A")
-    B = cudnn.softmax(A, axis)
+    B = cudnn_op(A, axis)
     s = te.create_schedule([B.op])
 
     dev = tvm.cuda(0)
     a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np)
+    b_np = testing_op(a_np)
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
     f = tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax")
     f(a, b)
     tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)
 
 
-def verify_softmax_4d(shape, dtype="float32"):
+def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
+    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
+    testing_op = (
+        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
+    )
+
     A = te.placeholder(shape, dtype=dtype, name="A")
-    B = cudnn.softmax(A, axis=1)
+    B = cudnn_op(A, axis=1)
     s = te.create_schedule([B.op])
 
     dev = tvm.cuda(0)
     n, c, h, w = shape
     a_np = np.random.uniform(size=shape).astype(dtype)
-    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
+    b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
     b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
@@ -217,6 +227,12 @@ def test_softmax():
     verify_softmax_4d((1, 16, 256, 256))
     verify_softmax_4d((1, 16, 256, 256), "float64")
 
+    verify_softmax((32, 10), -1, log_softmax=True)
+    verify_softmax((3, 4), -1, log_softmax=True)
+    verify_softmax((1, 5), -1, "float64", log_softmax=True)
+    verify_softmax_4d((1, 16, 256, 256), log_softmax=True)
+    verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)
+
 
 test_kwargs_default_2d = {
     "tensor_format": 0,
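
Note: the tests use tvm.topi.testing.log_softmax_python as the NumPy oracle. A sketch of what such a 2-D reference looks like (the real helper lives under python/tvm/topi/testing/ and may differ in detail):

import numpy as np

def log_softmax_python(a_np):
    """Reference log-softmax over the last axis of a 2-D NumPy array."""
    max_elem = np.amax(a_np, axis=1, keepdims=True)
    expsum = np.sum(np.exp(a_np - max_elem), axis=1, keepdims=True)
    return a_np - max_elem - np.log(expsum)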