Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, Quantization, TOPI] int8 dense on CUDA & Dense op quantization #2877

Merged
merged 12 commits into from
Apr 26, 2019
6 changes: 6 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,16 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
DataType out_dtype;

TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
TVM_ATTR_FIELD(units)
.describe("Number of hidden units of the dense transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
def _topi_nn_dense(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
data, weight, bias = args
data, weight, bias, _ = args
C = topi.nn.dense(*args, **kwargs)
s = topi.generic.schedule_dense([C])
if bias is not None:
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def schedule_log_softmax(_, outputs, target):
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
return [topi.nn.dense(inputs[0], inputs[1])]
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
return [topi.nn.dense(inputs[0], inputs[1], out_dtype=out_dtype)]

@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)


def dense(data, weight, units=None):
def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation

Expand All @@ -494,12 +494,15 @@ def dense(data, weight, units=None):
units : int, optional
Number of hidden units of the dense transformation.

out_dtype : str, optional
Specifies the output data type for mixed precision dense.

Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.dense(data, weight, units)
return _make.dense(data, weight, units, out_dtype)


def relu(data):
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,26 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
Expand Down
10 changes: 8 additions & 2 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,24 @@ bool DenseRel(const Array<Type>& types,
oshape.Set((oshape.size() - 1), wshape[0]);
}

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}


// Positional relay function to create dense operator used by frontend FFI.
Expr MakeDense(Expr data,
Expr weight,
IndexExpr units) {
IndexExpr units,
DataType out_dtype) {
auto attrs = make_node<DenseAttrs>();
attrs->units = units;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
Expand Down
33 changes: 33 additions & 0 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,39 @@ RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);


Expr DenseRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
return Expr(nullptr);
}
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();

Expr ldata = lhs->data;
if (lhs->dtype != cfg->dtype_input) {
ldata = Cast(ldata, cfg->dtype_input);
}
Expr rdata = Cast(rhs->data, cfg->dtype_weight);

const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
auto attrs = make_node<DenseAttrs>();
*attrs = *ref_attrs;
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;

Expr ret = CallNode::make(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);


Expr MulRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
Expand Down
7 changes: 5 additions & 2 deletions topi/include/topi/cuda/dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ namespace cuda {
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
Expand All @@ -62,6 +64,7 @@ inline tvm::Tensor dense_cuda(const Target& target,
auto out_dim = weight->shape[0];

if (target->libs().count("cublas")) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
Expand All @@ -72,7 +75,7 @@ inline tvm::Tensor dense_cuda(const Target& target,

return mm;
} else {
return topi::nn::dense(data, weight, bias);
return topi::nn::dense(data, weight, bias, out_dtype);
}
}

Expand Down
9 changes: 6 additions & 3 deletions topi/include/topi/nn/dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ using namespace tvm;
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense(const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
Expand All @@ -60,14 +62,15 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
auto matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(data(i, k) * weight(j, k), { k });
return tvm::sum(tvm::cast(out_dtype, data(i, k)) *
tvm::cast(out_dtype, weight(j, k)), { k });
}, "tensor", "dense");

if (bias.defined()) {
matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return matmul(i, j) + bias(j);
return matmul(i, j) + tvm::cast(out_dtype, bias(j));
}, "tensor", kBroadcast);
}

Expand Down
7 changes: 5 additions & 2 deletions topi/include/topi/rocm/dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ namespace rocm {
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
* \param out_dtype Output data type. Used for mixed precision.
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_rocm(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
const tvm::Tensor& bias) {
const tvm::Tensor& bias,
const Type& out_dtype) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias.defined()) {
Expand All @@ -63,6 +65,7 @@ inline tvm::Tensor dense_rocm(const Target& target,
auto out_dim = weight->shape[0];

if (target->libs().count("rocblas")) {
CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported.";
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias.defined()) {
mm = tvm::compute({ batch, out_dim },
Expand All @@ -73,7 +76,7 @@ inline tvm::Tensor dense_rocm(const Target& target,

return mm;
} else {
return topi::nn::dense(data, weight, bias);
return topi::nn::dense(data, weight, bias, out_dtype);
}
}

Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs

from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, group_conv2d_nchw
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \
group_conv2d_nchw, dense
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
Expand Down
Loading