diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c5af5d83bd7d..37ee6b6e929f 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -52,6 +52,27 @@ reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) +@reg.register_legalize("nn.dense") +def legalize_dense(attrs, inputs, types): + """Legalize dense op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.dense_legalize(attrs, inputs, types) + + # dense reg.register_strategy("nn.dense", strategy.dense_strategy) reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -67,6 +88,27 @@ def compute_fifo_buffer(attrs, inputs, out_type): reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE) +@reg.register_legalize("nn.batch_matmul") +def legalize_batch_matmul(attrs, inputs, types): + """Legalize batch_matmul op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.batch_matmul_legalize(attrs, inputs, types) + + # batch_matmul reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy) reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index e0ff5a12a9b2..bf3582c01d4f 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -55,5 +55,6 @@ from .conv2d_hwnc_tensorcore import * from .correlation import * from .sparse import * +from . import tensorcore_alter_op from .argwhere import * from .scan import * diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 8cf0519ebe29..65bf9d1f178d 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -24,8 +24,10 @@ from .. import nn from ..utils import get_const_tuple from .conv2d_winograd import _infer_tile_size +from .tensorcore_alter_op import pad_to_tensorcore from ..nn import conv2d_legalize + logger = logging.getLogger("topi") @@ -345,4 +347,50 @@ def _conv2d_legalize(attrs, inputs, arg_types): else: out = relay.nn.conv2d(data, kernel, **new_attrs) return out + elif data_dtype in ["float16"]: # todo: support int8/int4 + if data_layout == "NHWC" and kernel_layout == "HWIO": + batch = data_tensor.shape[0].value + in_channel = data_tensor.shape[3].value + out_channel = kernel_tensor.shape[3].value + + if ( + (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0) + or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0) + or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0) + ): + # no need to pad + return None + + (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel) + + if extra_flops > 2: + logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops) + + # Pad batch size + if db != 0: + data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0))) + + # Pad input channel + if di != 0: + data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di))) + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0))) + + # Pad output channel + if do != 0: + kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do))) + + if do != 0: + new_out_channel = out_channel + do + new_attrs["channels"] = new_out_channel + + out = relay.nn.conv2d(data, kernel, **new_attrs) + + if db != 0 or do != 0: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape) + + return out return None diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py new file mode 100644 index 000000000000..aec7acbfde56 --- /dev/null +++ b/python/tvm/topi/cuda/tensorcore_alter_op.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Tensorcore alter op and legalize functions for cuda backend""" + +import logging +import math +from tvm import relay + +from .. import nn + +logger = logging.getLogger("topi") + + +@nn.batch_matmul_legalize.register("cuda") +def _batch_matmul_legalize(attrs, inputs, arg_types): + """Legalizes batch_matmul op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + arg_types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # Collect the input tensors. + x_tensor, y_tensor = arg_types[0], arg_types[1] + dtype = x_tensor.dtype + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Collect the input exprs. + x, y = inputs + + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16"]: # todo: support int8/int4 + B, M, K = x_tensor.shape + B, N, K = y_tensor.shape + M = M.value + K = K.value + N = N.value + + # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) + if ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ): + # no need to pad + return None + + (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N) + + if extra_flops > 2: + logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops) + return None + + logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops) + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.batch_matmul(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape) + else: + out = out_ + return out + return None + + +@nn.dense_legalize.register("cuda") +def _dense_legalize(attrs, inputs, arg_types): + """Legalizes dense op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # Collect the input tensors. + x_tensor, y_tensor = arg_types[0], arg_types[1] + dtype = x_tensor.dtype + + # Collect the output tensor. + output_tensor = arg_types[2] + + # Collect the input exprs. + x, y = inputs + + # Pad input and output channels to use tensorcore schedule. + if dtype in ["float16"]: # todo: support int8/int4 + M, K = x_tensor.shape + N, K = y_tensor.shape + try: + M = M.value + K = K.value + N = N.value + except AttributeError: + # todo: deal with unfixed shape when compiling wdl model + return None + + # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) + if ( + (M % 8 == 0 and K % 16 == 0 and N % 32 == 0) + or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) + or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0) + ): + # no need to pad + return None + + (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N) + + if extra_flops_ratio > 2: + logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio) + return None + + logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio) + + if dm or dk: + x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) + else: + x_ = x + if dn or dk: + y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) + else: + y_ = y + out_ = relay.nn.dense(x_, y_) + if dm or dn: + original_out_shape = [x.value for x in output_tensor.shape] + out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape) + else: + out = out_ + return out + return None + + +def pad_to_tensorcore(M, K, N): + """pad shape to enable tensorcore""" + candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)] + + flops = M * K * N + extra_flops = math.inf + best_pad = (0, 0, 0) + for padding in candidates: + dm, dk, dn = _pad_to(M, K, N, padding) + e = (M + dm) * (N + dn) * (K + dk) - M * N * K + # print(dm, dk, dn, e, flops) + if e < extra_flops: + extra_flops = e + best_pad = (dm, dk, dn) + return best_pad, extra_flops / flops + + +def _pad_to(M, K, N, PADDING): + dm, dk, dn = 0, 0, 0 + + if M % PADDING[0] != 0: + M_ = ((M + PADDING[0]) // PADDING[0]) * PADDING[0] + dm = M_ - M + if K % PADDING[1] != 0: + K_ = ((K + PADDING[1]) // PADDING[1]) * PADDING[1] + dk = K_ - K + if N % PADDING[2] != 0: + N_ = ((N + PADDING[2]) // PADDING[2]) * PADDING[2] + dn = N_ - N + + return dm, dk, dn diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index 9ca2df7c46e1..9c5848129397 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -16,6 +16,7 @@ # under the License. """Batch matrix multiplication""" # pylint: disable=invalid-name +import tvm from tvm import te, auto_scheduler from ..utils import get_const_tuple @@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout) return output + + +@tvm.target.generic_func +def batch_matmul_legalize(attrs, inputs, types): + """Legalizes batch_matmul op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current batch_matmul + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + # pylint: disable=unused-argument + return None diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index 474fea42a7cb..bb6ea90c3fcd 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """TVM operator fully connected compute.""" +import tvm from tvm import te, auto_scheduler from .. import tag @@ -80,3 +81,26 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout) return matmul + + +@tvm.target.generic_func +def dense_legalize(attrs, inputs, types): + """Legalizes dense op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current dense + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # not to change by default + # pylint: disable=unused-argument + return None diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py new file mode 100644 index 000000000000..5ecda4ba07a8 --- /dev/null +++ b/tests/python/relay/test_pass_legalize_tensorcore.py @@ -0,0 +1,239 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test legalize pass""" +import numpy as np +import tvm +from tvm import te +from tvm import topi +from tvm import relay +from tvm.contrib import graph_runtime +from tvm.relay import transform, analysis +from tvm.relay.testing.temp_op_attr import TempOpAttr + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = tvm.IRModule.from_expr(expr) + seq = tvm.transform.Sequential(passes) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +@tvm.testing.uses_gpu +def test_legalize_conv2d(): + """test legalize conv2d to enable tensorcore""" + + def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, do_pad=True): + out_channel = kernel_shape[3] + out_shape = list(data_shape) + out_shape[3] = out_channel + db, di, do = pad_shape + + def before(): + x = relay.var("x", shape=data_shape, dtype="float16") + weight = relay.var("weight", shape=kernel_shape, dtype="float16") + y = relay.nn.conv2d( + x, + weight, + channels=out_channel, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.Function([x, weight], y) + return y + + def legalize_conv2d(attrs, inputs, types): + with tvm.target.Target("cuda"): + return topi.nn.conv2d_legalize(attrs, inputs, types) + + def expected(): + if not do_pad: + return before() + x = relay.var("x", shape=data_shape, dtype="float16") + if db or di: + x_pad = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di))) + else: + x_pad = x + weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + if di or do: + weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do))) + else: + weight_pad = weight + y_pad = relay.nn.conv2d( + x_pad, + weight=weight_pad, + channels=out_channel + do, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + if db or do: + y = relay.strided_slice(y_pad, begin=[0, 0, 0, 0], end=out_shape) + else: + y = y_pad + y = relay.Function([x, weight], y) + return y + + with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d): + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + + # conv2d pad batch + _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0)) + _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0)) + _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), False) + # conv2d pad in_channel + _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0)) + _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0)) + _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0)) + _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), False) + # conv2d pad out_channel + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1)) + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31)) + _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), False) + + +@tvm.testing.uses_gpu +def test_legalize_dense(): + def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True): + """test legalize dense to enable tensorcore""" + M, K = data_shape + N, _ = kernel_shape + out_shape = (M, N) + dm, dk, dn = pad_shape + + def before(): + x = relay.var("x", shape=data_shape, dtype="float16") + weight = relay.var("weight", shape=kernel_shape, dtype="float16") + y = relay.nn.dense(x, weight) + y = relay.Function([x, weight], y) + return y + + def legalize_dense(attrs, inputs, types): + with tvm.target.Target("cuda"): + return topi.nn.dense_legalize(attrs, inputs, types) + + def expected(): + if not do_pad: + return before() + x = relay.var("x", shape=data_shape, dtype="float16") + if dm or dk: + x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) + else: + x_pad = x + weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + if dn or dk: + weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk))) + else: + weight_pad = weight + y_pad = relay.nn.dense( + x_pad, + weight_pad, + ) + if dm or dn: + y = relay.strided_slice(y_pad, begin=[0, 0], end=out_shape) + else: + y = y_pad + y = relay.Function([x, weight], y) + return y + + with TempOpAttr("nn.dense", "FTVMLegalize", legalize_dense): + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + + # dense + _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False) + _test_legalize_dense((7, 16), (32, 16), (1, 0, 0)) + _test_legalize_dense((8, 15), (32, 15), (0, 1, 0)) + _test_legalize_dense((8, 16), (31, 16), (0, 0, 1)) + _test_legalize_dense((7, 15), (31, 15), (1, 1, 1)) + _test_legalize_dense((3, 16), (32, 16), (5, 0, 0)) + _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False) + + +@tvm.testing.uses_gpu +def test_legalize_batch_matmul(): + def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True): + """test legalize dense to enable tensorcore""" + B, M, _ = data_shape + _, N, _ = kernel_shape + out_shape = (B, M, N) + dm, dk, dn = pad_shape + + def before(): + x = relay.var("x", shape=data_shape, dtype="float16") + weight = relay.var("weight", shape=kernel_shape, dtype="float16") + y = relay.nn.batch_matmul(x, weight) + y = relay.Function([x, weight], y) + return y + + def legalize_batch_matmul(attrs, inputs, types): + with tvm.target.Target("cuda"): + return topi.nn.batch_matmul_legalize(attrs, inputs, types) + + def expected(): + if not do_pad: + return before() + x = relay.var("x", shape=data_shape, dtype="float16") + if dm or dk: + x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) + else: + x_pad = x + weight = relay.var("weight", shape=(kernel_shape), dtype="float16") + if dn or dk: + weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk))) + else: + weight_pad = weight + y_pad = relay.nn.batch_matmul( + x_pad, + weight_pad, + ) + if dm or dn: + y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape) + else: + y = y_pad + y = relay.Function([x, weight], y) + return y + + with TempOpAttr("nn.batch_matmul", "FTVMLegalize", legalize_batch_matmul): + a = before() + a = run_opt_pass(a, transform.Legalize()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b) + + _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), False) + _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0)) + _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0)) + _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1)) + _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1)) + _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0)) + _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), False) + + +if __name__ == "__main__": + test_legalize_conv2d() + test_legalize_dense() + test_legalize_batch_matmul()