Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA][PASS]Legalize tensorcore #7147

Merged
merged 17 commits into from
Jan 29, 2021
42 changes: 42 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.dense_legalize(attrs, inputs, types)


# dense
reg.register_strategy("nn.dense", strategy.dense_strategy)
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
Expand All @@ -67,6 +88,27 @@ def compute_fifo_buffer(attrs, inputs, out_type):
reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


@reg.register_legalize("nn.batch_matmul")
def legalize_batch_matmul(attrs, inputs, types):
"""Legalize batch_matmul op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.batch_matmul_legalize(attrs, inputs, types)


# batch_matmul
reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
from ..expr import TupleWrapper, const, Expr, Tuple
from ..expr import TupleWrapper, const, Expr, Tuple, Constant
from ...tir import expr as _expr


Expand Down Expand Up @@ -884,7 +884,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
The computed result.
"""
strides = strides or [1]
if isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr):
if any([(isinstance(i, Expr) and not isinstance(i, Constant)) for i in (begin, end, strides)]):
if isinstance(begin, (tuple, list)):
begin = const(list(begin))
if isinstance(end, (tuple, list)):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from . import tensorcore_alter_op
from .argwhere import *
49 changes: 49 additions & 0 deletions python/tvm/topi/cuda/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from .. import nn
from ..utils import get_const_tuple
from .conv2d_winograd import _infer_tile_size
from .tensorcore_alter_op import pad_to_tensorcore
from ..nn import conv2d_legalize


logger = logging.getLogger("topi")


Expand Down Expand Up @@ -345,4 +347,51 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
elif data_dtype in ["float16"]: # todo: support int8/int4
if data_layout == "NHWC" and kernel_layout == "HWIO":
batch = data_tensor.shape[0].value
in_channel = data_tensor.shape[3].value
out_channel = kernel_tensor.shape[3].value

if (
(batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
):
# no need to pad
return None

(db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)

if extra_flops > 2:
logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
return None

logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)

# Pad batch size
if db != 0:
data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))

# Pad input channel
if di != 0:
data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))

# Pad output channel
if do != 0:
kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))

if do != 0:
new_out_channel = out_channel + do
new_attrs["channels"] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
Meteorix marked this conversation as resolved.
Show resolved Hide resolved

if db != 0 or do != 0:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)

return out
return None
204 changes: 204 additions & 0 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Tensorcore alter op and legalize functions for cuda backend"""

import logging
import math
from tvm import relay

from .. import nn

logger = logging.getLogger("topi")


@nn.batch_matmul_legalize.register("cuda")
def _batch_matmul_legalize(attrs, inputs, arg_types):
"""Legalizes batch_matmul op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
arg_types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the input tensors.
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

# Pad input and output channels to use tensorcore schedule.
if dtype in ["float16"]: # todo: support int8/int4
B, M, K = x_tensor.shape
B, N, K = y_tensor.shape
M = M.value
K = K.value
N = N.value

# The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
if (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
# no need to pad
return None

(dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)

if extra_flops > 2:
logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
return None

logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
if dm or dk:
x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
else:
x_ = x
if dn or dk:
y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk)))
else:
y_ = y
out_ = relay.nn.batch_matmul(x_, y_)
if dm or dn:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
else:
out = out_
return out
return None


@nn.dense_legalize.register("cuda")
def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes dense op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the input tensors.
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

# Pad input and output channels to use tensorcore schedule.
if dtype in ["float16"]: # todo: support int8/int4
M, K = x_tensor.shape
N, K = y_tensor.shape
try:
M = M.value
K = K.value
N = N.value
except AttributeError:
# todo: deal with unfixed shape when compiling wdl model
return None

# The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
if (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
# no need to pad
return None

(dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)

if extra_flops_ratio > 2:
logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
return None

logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)

if dm or dk:
x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
else:
x_ = x
if dn or dk:
y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk)))
else:
y_ = y
out_ = relay.nn.dense(x_, y_)
if dm or dn:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape)
else:
out = out_
return out
return None


def pad_to_tensorcore(M, K, N):
"""pad shape to enable tensorcore"""
candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]

flops = M * K * N
extra_flops = math.inf
best_pad = (0, 0, 0)
for padding in candidates:
dm, dk, dn = _pad_to(M, K, N, padding)
e = (M + dm) * (N + dn) * (K + dk) - M * N * K
# print(dm, dk, dn, e, flops)
if e < extra_flops:
extra_flops = e
best_pad = (dm, dk, dn)
return best_pad, extra_flops / flops


def _pad_to(M, K, N, PADDING):
dm, dk, dn = 0, 0, 0

if M % PADDING[0] != 0:
M_ = ((M + PADDING[0]) // PADDING[0]) * PADDING[0]
dm = M_ - M
if K % PADDING[1] != 0:
K_ = ((K + PADDING[1]) // PADDING[1]) * PADDING[1]
dk = K_ - K
if N % PADDING[2] != 0:
N_ = ((N + PADDING[2]) // PADDING[2]) * PADDING[2]
dn = N_ - N

return dm, dk, dn
24 changes: 24 additions & 0 deletions python/tvm/topi/nn/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Batch matrix multiplication"""
# pylint: disable=invalid-name
import tvm
from tvm import te, auto_scheduler
from ..utils import get_const_tuple

Expand Down Expand Up @@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)

return output


@tvm.target.generic_func
def batch_matmul_legalize(attrs, inputs, types):
"""Legalizes batch_matmul op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current batch_matmul
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# not to change by default
# pylint: disable=unused-argument
return None
Loading