[FastMath] Add cuda & x86 schedules for fast_softmax (apache#8150)
* Add cuda & x86 schedules for fast_softmax

* Bug fix

* Re-trigger CI
jcf94 authored and Trevor Morris committed Jun 17, 2021
1 parent ee08dd5 commit d2d1db8
Showing 5 changed files with 60 additions and 7 deletions.
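For context, nn.fast_softmax is not written by users directly; it is produced by the FastMath relay pass, and the strategies added in this commit give it a CUDA and an x86 lowering. A minimal sketch of how the op appears (the shape is illustrative, not part of this commit):

import tvm
from tvm import relay

# Minimal sketch: FastMath rewrites nn.softmax into nn.fast_softmax, which the
# strategies added in this commit then lower for cuda and x86.
x = relay.var("x", shape=(1, 128), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.softmax(x)))
mod = relay.transform.FastMath()(mod)
print(mod)  # the body now calls nn.fast_softmax instead of nn.softmax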
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -89,6 +89,18 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
    return strategy


@fast_softmax_strategy.register(["cuda", "gpu"])
def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
    """fast_softmax cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.fast_softmax),
        wrap_topi_schedule(topi.cuda.schedule_softmax),
        name="fast_softmax.cuda",
    )
    return strategy


@schedule_log_softmax.register(["cuda", "gpu"])
def schedule_log_softmax_cuda(attrs, outs, target):
"""scheudle log_softmax for cuda"""
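The registration above pairs the target-independent topi.nn.fast_softmax compute with the existing topi.cuda.schedule_softmax schedule (extended below to recognize the fast_softmax_output tag). A rough TE-level sketch of that pairing, assuming a CUDA-enabled TVM build and an illustrative shape:

import tvm
from tvm import te, topi

# Sketch only: the same compute/schedule pair that the new strategy registers.
x = te.placeholder((1, 128), name="x", dtype="float32")
y = topi.nn.fast_softmax(x)                  # output is tagged "fast_softmax_output"
with tvm.target.Target("cuda"):
    s = topi.cuda.schedule_softmax([y])      # handled by the new branch added below
f = tvm.build(s, [x, y], target="cuda")      # requires a CUDA-enabled TVM build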
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -79,6 +79,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target):
    return strategy


@fast_softmax_strategy.register("cpu")
def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
    """fast_softmax x86 strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.fast_softmax),
        wrap_topi_schedule(topi.x86.schedule_softmax),
        name="fast_softmax.x86",
    )
    return strategy


@schedule_log_softmax.register("cpu")
def schedule_log_softmax_cpu(attrs, outs, target):
    """schedule log_softmax op for x86"""
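The x86 registration mirrors the CUDA one, reusing topi.x86.schedule_softmax. A quick numerical sanity check of the same TE path on CPU (shape and tolerances are illustrative; fast_exp is an approximation, so accuracy is slightly looser than exact softmax):

import numpy as np
import tvm
from tvm import te, topi

# Sketch only: build the fast_softmax compute with the x86 schedule and
# compare against a NumPy softmax reference.
x = te.placeholder((1, 128), name="x", dtype="float32")
y = topi.nn.fast_softmax(x)
s = topi.x86.schedule_softmax([y])
f = tvm.build(s, [x, y], target="llvm")

dev = tvm.cpu(0)
a = np.random.uniform(-1, 1, size=(1, 128)).astype("float32")
out = tvm.nd.empty((1, 128), "float32", dev)
f(tvm.nd.array(a, dev), out)

ref = np.exp(a - a.max(axis=-1, keepdims=True))
ref /= ref.sum(axis=-1, keepdims=True)
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-4, atol=1e-4)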
21 changes: 18 additions & 3 deletions python/tvm/topi/cuda/softmax.py
@@ -47,8 +47,15 @@ def schedule_softmax(outs):
        expsum = softmax.op.input_tensors[1]
        exp = softmax.op.input_tensors[0]
        max_elem = s[exp].op.input_tensors[1]
        delta = None
    elif op_tag == "fast_softmax_output":
        expsum = softmax.op.input_tensors[1]
        exp = softmax.op.input_tensors[0]
        delta = s[exp].op.input_tensors[0]
        max_elem = s[delta].op.input_tensors[1]
    elif op_tag == "log_softmax_output":
        exp = None
        delta = None
        max_elem = softmax.op.input_tensors[1]
        expsum = softmax.op.input_tensors[2]
    else:
@@ -73,6 +80,8 @@ def sched_warp_softmax():

    if len(softmax.shape) > 2:
        ops = [max_elem.op, expsum.op, softmax.op]
        if delta is not None:
            ops.append(delta.op)
        if exp is not None:
            ops.append(exp.op)

@@ -99,7 +108,10 @@ def sched_warp_softmax():
        s[expsum].compute_at(s[softmax], xo)

        # (2) exp
        if exp is not None:
        if delta is not None:
            s[exp].compute_inline()
            s[delta].compute_inline()
        elif exp is not None:
            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
            _, xii = s[exp].split(xi, factor=4)
            s[exp].vectorize(xii)
@@ -112,7 +124,7 @@ def sched_warp_softmax():
        k = max_elem.op.reduce_axis[0]
        ko, _ = s[max_elem].split(k, nparts=num_thread)
        s[max_elem].bind(ko, thread_x)
        if exp is not None:
        if exp is not None and delta is None:
            s[max_elem].compute_at(s[exp], xo)
        else:
            s[max_elem].bind(ko, thread_x)
@@ -123,7 +135,10 @@ def sched_warp_softmax():
        block_x = te.thread_axis("blockIdx.x")
        thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

        if exp is not None:
        if delta is not None:
            s[exp].compute_inline()
            s[delta].compute_inline()
        elif exp is not None:
            s[exp].bind(exp.op.axis[0], block_x)

        s[max_elem].bind(max_elem.op.axis[0], block_x)
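The index walking in the fast_softmax_output branch follows the producer chain built by topi.nn.fast_softmax. A rough sketch of that chain (not a verbatim copy of the topi implementation; shapes are illustrative):

import tvm
from tvm import te, topi

# Rough sketch of the compute DAG the schedule walks for fast_softmax.
m, n = 4, 128
x = te.placeholder((m, n), name="x")
k = te.reduce_axis((0, n), name="k")
max_elem = te.compute((m,), lambda i: te.max(x[i, k], axis=k), name="max_elem")
delta = te.compute((m, n), lambda i, j: x[i, j] - max_elem[i], name="delta")
exp = topi.fast_exp(delta)            # exp.op.input_tensors[0] is delta
r = te.reduce_axis((0, n), name="r")
expsum = te.compute((m,), lambda i: te.sum(exp[i, r], axis=r), name="expsum")
out = te.compute((m, n), lambda i, j: exp[i, j] / expsum[i], name="out")

# delta.op.input_tensors[1] is max_elem, which is why fast_softmax reaches
# max_elem through delta, while regular softmax exposes it directly on exp.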
11 changes: 11 additions & 0 deletions python/tvm/topi/x86/nn.py
@@ -42,9 +42,17 @@ def schedule_softmax(outs):
        exp = softmax.op.input_tensors[0]
        expsum = softmax.op.input_tensors[1]
        max_elem = s[exp].op.input_tensors[1]
        delta = None
        axis = int(softmax.op.attrs["axis"])
    elif op_tag == "fast_softmax_output":
        exp = softmax.op.input_tensors[0]
        expsum = softmax.op.input_tensors[1]
        delta = s[exp].op.input_tensors[0]
        max_elem = s[delta].op.input_tensors[1]
        axis = int(softmax.op.attrs["axis"])
    elif op_tag == "log_softmax_output":
        exp = None
        delta = None
        max_elem = softmax.op.input_tensors[1]
        expsum = softmax.op.input_tensors[2]
        axis = 1
@@ -65,6 +73,9 @@ def schedule_softmax(outs):
    s[max_elem].compute_at(s[softmax], fused_outer_axes)
    s[expsum].compute_at(s[softmax], fused_outer_axes)

    if delta is not None:
        s[exp].compute_inline()
        s[delta].compute_inline()
    if exp is not None:
        s[exp].compute_at(s[softmax], fused_outer_axes)

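For the fast path, both schedules simply inline delta and exp so the subtraction and fast_exp fuse into the softmax loop. A minimal standalone illustration of compute_inline (names and shapes are illustrative, unrelated to the commit):

import tvm
from tvm import te

# Sketch only: inlining a producer folds its expression into the consumer's
# loop, removing the intermediate stage and its buffer.
n = 1024
x = te.placeholder((n,), name="x")
delta = te.compute((n,), lambda i: x[i] - 1.0, name="delta")
out = te.compute((n,), lambda i: delta[i] * 2.0, name="out")

s = te.create_schedule(out.op)
s[delta].compute_inline()  # out's loop now computes (x[i] - 1.0) * 2.0 directly
print(tvm.lower(s, [x, out], simple_mode=True))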
11 changes: 7 additions & 4 deletions tests/python/relay/test_op_fast_math.py
@@ -23,9 +23,11 @@
from tvm import topi
from tvm import te
from tvm.contrib import graph_executor
from tvm.topi import testing


def test_fastmath():
@tvm.testing.parametrize_targets("llvm", "cuda")
def test_fastmath(target, dev):
    def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
        a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
        b_np = f_numpy(a_np)
@@ -36,13 +38,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
        mod = tvm.IRModule.from_expr(func)

        with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
            graph, lib, params = relay.build(mod, target="llvm", params=None)
            graph, lib, params = relay.build(mod, target=target, params=None)

        # Check that the ops related to fast math have been converted to functions in lib
        func_name = "fused_" + name
        assert lib.get_function(func_name)
        # When there are multiple targets in tvm.testing.parametrize_targets, the built
        # function will have a "_1" suffix in its name
        assert func_name in graph

        dev = tvm.cpu(0)
        m = graph_executor.create(graph, lib, dev)
        # Set inputs
        m.set_input("x", tvm.nd.array(a_np, dev))
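The test change relies on tvm.testing.parametrize_targets, which runs the test once per listed target and skips targets the local build cannot execute. A minimal sketch of the mechanism (the body is illustrative):

import numpy as np
import tvm
import tvm.testing

# Sketch only: the decorator injects (target, dev) pairs and skips
# unavailable targets, e.g. "cuda" on a CPU-only build.
@tvm.testing.parametrize_targets("llvm", "cuda")
def test_roundtrip(target, dev):
    a = np.arange(8, dtype="float32")
    x = tvm.nd.array(a, device=dev)
    np.testing.assert_array_equal(x.numpy(), a)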
