Skip to content

Commit

Permalink
update doc to clarify that only 2D input data is supported
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Aug 5, 2021
1 parent 1c2ec67 commit 95a3299
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _):
@script
def _dense_pack_shape_func(data_shape, weight_shape):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = data_shape[i]
out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
assert data_shape.shape[0] == 2, "Input data must be 2D"
out[0] = data_shape[0]
out[1] = weight_shape[0] * weight_shape[2]

return out

Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,7 +1550,7 @@ def dense(data, weight, units=None, out_dtype=""):

def contrib_dense_pack(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
Applies a linear transformation with packed weight
.. math::
Expand All @@ -1560,7 +1560,7 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
----------
data : tvm.relay.Expr
The input data to the operator,
of shape `(d_1, d_2, ..., d_n, units_in)`.
of shape `(batch, units_in)`.
weight : tvm.relay.Expr
The transformed weight expressions, 3-D matrix,
Expand All @@ -1570,8 +1570,7 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
Number of hidden units of the dense transformation.
out_dtype : str, optional
Specifies the output data type for mixed precision dense,
of shape `(d_1, d_2, ..., d_n, units)`.
Specifies the output data type for mixed precision dense.
Returns
-------
Expand Down
10 changes: 5 additions & 5 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const DenseAttrs* param = attrs.as<DenseAttrs>();
ICHECK(param != nullptr);

ICHECK(data->shape.size() == 2) << "Input data must be 2D";
Array<tvm::PrimExpr> oshape = data->shape;
oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);
oshape.Set(1, weight->shape[0] * weight->shape[2]);

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
Expand All @@ -261,17 +262,16 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
RELAY_REGISTER_OP("nn.contrib_dense_pack")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **data**: `(batch, input_dim)`
- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
- **out**: `(x1, x2, ..., xn, units)`.
- **out**: `(batch, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<DenseAttrs>()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("data", "2D Tensor", "Input data.")
.add_argument("weight", "3D Tensor", "Packed weight matrix.")
.set_support_level(10)

.add_type_rel("DensePack", DensePackRel);
// ------------------- relay.nn.contrib_dense_pack

Expand Down

0 comments on commit 95a3299

Please sign in to comment.