diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 96cef8bc35889..a9ccc5aa2d241 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _): @script def _dense_pack_shape_func(data_shape, weight_shape): out = output_tensor((data_shape.shape[0],), "int64") - for i in const_range(out.shape[0] - 1): - out[i] = data_shape[i] - out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2] + assert data_shape.shape[0] == 2, "Input data must be 2D" + out[0] = data_shape[0] + out[1] = weight_shape[0] * weight_shape[2] return out diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 64b397a4d4f9b..531a35c94200a 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -1550,7 +1550,7 @@ def dense(data, weight, units=None, out_dtype=""): def contrib_dense_pack(data, weight, units=None, out_dtype=""): """Dense operator. - Applies a linear transformation + Applies a linear transformation with packed weight .. math:: @@ -1560,7 +1560,7 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""): ---------- data : tvm.relay.Expr The input data to the operator, - of shape `(d_1, d_2, ..., d_n, units_in)`. + of shape `(batch, units_in)`. weight : tvm.relay.Expr The transformed weight expressions, 3-D matrix, @@ -1570,8 +1570,7 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""): Number of hidden units of the dense transformation. out_dtype : str, optional - Specifies the output data type for mixed precision dense, - of shape `(d_1, d_2, ..., d_n, units)`. + Specifies the output data type for mixed precision dense. Returns ------- diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index af9ede279f6e6..1384a2f13f955 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -246,8 +246,9 @@ bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, const DenseAttrs* param = attrs.as(); ICHECK(param != nullptr); + ICHECK(data->shape.size() == 2) << "Input data must be 2D"; Array oshape = data->shape; - oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]); + oshape.Set(1, weight->shape[0] * weight->shape[2]); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -261,17 +262,16 @@ bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, RELAY_REGISTER_OP("nn.contrib_dense_pack") .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. -- **data**: `(x1, x2, ..., xn, input_dim)` +- **data**: `(batch, input_dim)` - **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)` -- **out**: `(x1, x2, ..., xn, units)`. +- **out**: `(batch, units)`. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(2) - .add_argument("data", "nD Tensor", "Input data.") + .add_argument("data", "2D Tensor", "Input data.") .add_argument("weight", "3D Tensor", "Packed weight matrix.") .set_support_level(10) - .add_type_rel("DensePack", DensePackRel); // ------------------- relay.nn.contrib_dense_pack