diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index d8202836a4d4f..8367a681d0229 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -89,6 +89,18 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@fast_softmax_strategy.register(["cuda", "gpu"])
+def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
+    """fast_softmax cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.fast_softmax),
+        wrap_topi_schedule(topi.cuda.schedule_softmax),
+        name="fast_softmax.cuda",
+    )
+    return strategy
+
+
 @schedule_log_softmax.register(["cuda", "gpu"])
 def schedule_log_softmax_cuda(attrs, outs, target):
     """scheudle log_softmax for cuda"""
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 60bd92ef63d1c..c21ec4d139061 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -79,6 +79,18 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target):
     return strategy
 
 
+@fast_softmax_strategy.register("cpu")
+def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
+    """fast_softmax x86 strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_softmax(topi.nn.fast_softmax),
+        wrap_topi_schedule(topi.x86.schedule_softmax),
+        name="fast_softmax.x86",
+    )
+    return strategy
+
+
 @schedule_log_softmax.register("cpu")
 def schedule_log_softmax_cpu(attrs, outs, target):
     """schedule log_softmax op for x86"""
diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index 99fbdd0367db5..b743aefc50d5e 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -47,8 +47,15 @@ def schedule_softmax(outs):
         expsum = softmax.op.input_tensors[1]
         exp = softmax.op.input_tensors[0]
         max_elem = s[exp].op.input_tensors[1]
+        delta = None
+    elif op_tag == "fast_softmax_output":
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        delta = s[exp].op.input_tensors[0]
+        max_elem = s[delta].op.input_tensors[1]
     elif op_tag == "log_softmax_output":
         exp = None
+        delta = None
         max_elem = softmax.op.input_tensors[1]
         expsum = softmax.op.input_tensors[2]
     else:
@@ -73,6 +80,8 @@ def sched_warp_softmax():
 
     if len(softmax.shape) > 2:
         ops = [max_elem.op, expsum.op, softmax.op]
+        if delta is not None:
+            ops.append(delta.op)
         if exp is not None:
             ops.append(exp.op)
 
@@ -99,7 +108,10 @@ def sched_warp_softmax():
         s[expsum].compute_at(s[softmax], xo)
 
         # (2) exp
-        if exp is not None:
+        if delta is not None:
+            s[exp].compute_inline()
+            s[delta].compute_inline()
+        elif exp is not None:
             xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
             _, xii = s[exp].split(xi, factor=4)
             s[exp].vectorize(xii)
@@ -112,7 +124,7 @@ def sched_warp_softmax():
         k = max_elem.op.reduce_axis[0]
         ko, _ = s[max_elem].split(k, nparts=num_thread)
         s[max_elem].bind(ko, thread_x)
-        if exp is not None:
+        if exp is not None and delta is None:
             s[max_elem].compute_at(s[exp], xo)
         else:
             s[max_elem].bind(ko, thread_x)
@@ -123,7 +135,10 @@ def sched_warp_softmax():
 
         block_x = te.thread_axis("blockIdx.x")
         thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
-        if exp is not None:
+        if delta is not None:
+            s[exp].compute_inline()
+            s[delta].compute_inline()
+        elif exp is not None:
             s[exp].bind(exp.op.axis[0], block_x)
 
         s[max_elem].bind(max_elem.op.axis[0], block_x)
diff --git a/python/tvm/topi/x86/nn.py b/python/tvm/topi/x86/nn.py
index 0994700fe98c8..4c39f2ad7382e 100644
--- a/python/tvm/topi/x86/nn.py
+++ b/python/tvm/topi/x86/nn.py
@@ -42,9 +42,17 @@ def schedule_softmax(outs):
         exp = softmax.op.input_tensors[0]
         expsum = softmax.op.input_tensors[1]
         max_elem = s[exp].op.input_tensors[1]
+        delta = None
+        axis = int(softmax.op.attrs["axis"])
+    elif op_tag == "fast_softmax_output":
+        exp = softmax.op.input_tensors[0]
+        expsum = softmax.op.input_tensors[1]
+        delta = s[exp].op.input_tensors[0]
+        max_elem = s[delta].op.input_tensors[1]
         axis = int(softmax.op.attrs["axis"])
     elif op_tag == "log_softmax_output":
         exp = None
+        delta = None
         max_elem = softmax.op.input_tensors[1]
         expsum = softmax.op.input_tensors[2]
         axis = 1
@@ -65,6 +73,9 @@ def schedule_softmax(outs):
 
     s[max_elem].compute_at(s[softmax], fused_outer_axes)
     s[expsum].compute_at(s[softmax], fused_outer_axes)
+    if delta is not None:
+        s[exp].compute_inline()
+        s[delta].compute_inline()
     if exp is not None:
         s[exp].compute_at(s[softmax], fused_outer_axes)
 
diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py
index c9314fae37aca..f968dbedddfe2 100644
--- a/tests/python/relay/test_op_fast_math.py
+++ b/tests/python/relay/test_op_fast_math.py
@@ -23,9 +23,11 @@
 from tvm import topi
 from tvm import te
 from tvm.contrib import graph_executor
+from tvm.topi import testing
 
 
-def test_fastmath():
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_fastmath(target, dev):
     def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
         a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
         b_np = f_numpy(a_np)
@@ -36,13 +38,14 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
 
         mod = tvm.IRModule.from_expr(func)
         with tvm.transform.PassContext(opt_level=3, required_pass=["FastMath"]):
-            graph, lib, params = relay.build(mod, target="llvm", params=None)
+            graph, lib, params = relay.build(mod, target=target, params=None)
 
         # Check that the op related to fast math have been convered to function in lib
         func_name = "fused_" + name
-        assert lib.get_function(func_name)
+        # When there are multiple targets in tvm.testing.parametrize_targets, the built
+        # function will have a "_1" suffix in its name, so check the graph JSON instead
+        assert func_name in graph
 
-        dev = tvm.cpu(0)
         m = graph_executor.create(graph, lib, dev)
         # Set inputs
         m.set_input("x", tvm.nd.array(a_np, dev))
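
Note on the schedule changes above: topi.nn.fast_softmax is lowered with the tag "fast_softmax_output", and unlike the regular softmax compute it keeps the shifted input (x - max) as a separate "delta" stage, so the schedules recover max_elem through that extra producer and simply inline delta and exp. The standalone TE sketch below is not TOPI's actual implementation; the tensor names, the (4, 256) shape, and the use of te.exp in place of the fast-exp approximation are illustrative assumptions. It only mirrors the producer chain and the input_tensors indexing that the schedules rely on.

    import tvm
    from tvm import te

    n, m = 4, 256
    x = te.placeholder((n, m), name="x")

    # Row-wise max, the shifted input ("delta"), its exponential, and the row sum.
    k = te.reduce_axis((0, m), name="k")
    max_elem = te.compute((n,), lambda i: te.max(x[i, k], axis=k), name="max_elem")
    delta = te.compute((n, m), lambda i, j: x[i, j] - max_elem[i], name="delta")
    # fast_softmax would use a polynomial fast-exp approximation here instead of te.exp
    exp = te.compute((n, m), lambda i, j: te.exp(delta[i, j]), name="exp")
    k2 = te.reduce_axis((0, m), name="k2")
    expsum = te.compute((n,), lambda i: te.sum(exp[i, k2], axis=k2), name="expsum")
    softmax = te.compute((n, m), lambda i, j: exp[i, j] / expsum[i], name="softmax")

    s = te.create_schedule(softmax.op)
    # In this sketch the producers line up the same way the schedules above expect:
    #   exp      = softmax.op.input_tensors[0]
    #   expsum   = softmax.op.input_tensors[1]
    #   delta    = s[exp].op.input_tensors[0]   # extra stage only present in the fast path
    #   max_elem = s[delta].op.input_tensors[1]
    # and the elementwise stages are then folded away, as in the patched schedules:
    s[exp].compute_inline()
    s[delta].compute_inline()
    print(tvm.lower(s, [x, softmax], simple_mode=True))

Inlining exp and delta leaves only the two reductions (max_elem, expsum) and the final division as materialized stages, which is why the fast_softmax branches in the CUDA and x86 schedules can otherwise reuse the existing softmax scheduling logic.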