Skip to content

Commit

Permalink
[Layout] Unify dense op input layout (#8921)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Sep 7, 2021
1 parent 0034732 commit 0fb840e
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 10 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1017,8 +1017,8 @@ struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
TVM_ATTR_FIELD(weight_layout)
.set_default("NK")
.describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible.");
.set_default("NC")
.describe("Dimension ordering of weight. Packed layouts, such as NC8n, are possible.");
}
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ def dense(data, weight, units=None, out_dtype=""):
return _make.dense(data, weight, units, out_dtype)


def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""):
def contrib_dense_pack(data, weight, weight_layout="NC", units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation with packed weight
Expand All @@ -1567,7 +1567,7 @@ def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype="
of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
weight_layout: str
The layout of weight, such as "NK" or "NK8n".
The layout of weight, such as "NC" or "NC8n".
units : int, optional
Number of hidden units of the dense transformation.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/x86/dense_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
if cfg.is_fallback:
_default_dense_pack_config(cfg, M, N, K)
packw_bn = cfg["tile_x"].size[-1]
weight_layout = "NK%dn" % packw_bn
weight_layout = "NC%dn" % packw_bn
new_weight = te.placeholder(
(N // packw_bn, K, packw_bn),
dtype=weight_tensor.dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs);
return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);
Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,13 @@ class LinearNoBias(Module):
def forward(self, input, weight):
return F.linear(input, weight)

class LinearNested(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y, z):
return F.linear(x, F.linear(y, z))

input2d = torch.rand([2, 2]).float()
input3d = torch.rand([4, 3, 2]).float()
weight1d = torch.rand([2]).float()
Expand All @@ -1595,6 +1602,9 @@ def forward(self, input, weight):
verify_model(LinearNoBias(), input_data=[input2d, weight1d])
# 3D input, 2D weight, no bias
verify_model(LinearNoBias(), input_data=[input3d, weight3x2])

verify_model(LinearNested(), input_data=[torch.randn(10, 10) for _ in range(3)])

# TODO: Add the following cases when matmul(1D, _) is supported by TVM
# 1D input, 2D weight, 1D bias
# 1D input, 2D weight, no bias
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,8 +1317,8 @@ def before():
def expected():
x = relay.var("x", shape=(32, 64))
weight = relay.var("weight", shape=(48, 64))
target_layout = "NK16n"
weight_transform = relay.layout_transform(weight, "NK", target_layout)
target_layout = "NC16n"
weight_transform = relay.layout_transform(weight, "NC", target_layout)
y = relay.nn.contrib_dense_pack(
x, weight_transform, target_layout, units=None, out_dtype="float32"
)
Expand Down Expand Up @@ -1387,8 +1387,8 @@ def expected():
squeeze = relay.squeeze(pool, axis=[2, 3])
dense = relay.nn.contrib_dense_pack(
relay.layout_transform(squeeze, "NC8c", "NC"),
relay.layout_transform(dense_weight, "NK", "NK16n"),
"NK16n",
relay.layout_transform(dense_weight, "NC", "NC16n"),
"NC16n",
out_dtype="float32",
)
return relay.Function(analysis.free_vars(dense), dense)
Expand Down

0 comments on commit 0fb840e

Please sign in to comment.