Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ConvertLayout] slice_like support #7184

Merged
merged 1 commit into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,46 @@ Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) {
return Call(op, {data, shape_like}, Attrs(attrs), {});
}

Array<Array<Layout>> SliceLikeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
Array<Integer> new_axes;
if (old_in_layouts.defined() && new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 2);
ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name);
ICHECK_EQ(old_in_layouts.size(), 2);
ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name);

auto old_layout = old_in_layouts[0];
auto new_layout = new_in_layouts[0];

// Discard "const" qualifier.
auto* params = const_cast<SliceLikeAttrs*>(attrs.as<SliceLikeAttrs>());
ICHECK(params != nullptr);

for (auto axis : params->axes) {
auto new_axis = new_layout.IndexOf(old_layout[axis->value]);
// Cannot find the target axis in the new layout.
if (new_axis == -1) {
new_axes.clear();
break;
}
new_axes.push_back(new_axis);
}
if (!new_axes.empty()) {
params->axes = std::move(new_axes);
return Array<Array<Layout>>({{new_layout, new_layout}, {new_layout}});
}
}

if (old_in_layouts.defined()) {
ICHECK_EQ(old_in_layouts.size(), 2);
return {{old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]}};
}
return Array<Array<Layout>>({{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}});
}

Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SliceLikeAttrs>();
Expand Down Expand Up @@ -2760,6 +2800,7 @@ RELAY_REGISTER_OP("slice_like")
.set_support_level(10)
.add_type_rel("SliceLike", SliceLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SliceLikeInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// relay.layout_transform
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,75 @@ def before():
assert len(has_lt) == 1


def test_slice_like_convert_layout():
def verify_slice_like(after, expected_axes):
# Verify if the slice_like after the convert layout has the expected axes.
has_expected = list()
checker = lambda x: has_expected.append(
isinstance(x, tvm.relay.expr.Call)
and x.op.name == "slice_like"
and str(x.attrs.axes) == str(expected_axes)
)
relay.analysis.post_order_visit(after, checker)
assert any(has_expected)

def func_nhwc():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
out = relay.slice_like(y, y, axes=[1, 2])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_nhwc(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
verify_slice_like(after, [2, 3])

def func_nchw():
x = relay.var("x", shape=(1, 64, 56, 56))
weight1 = relay.var("weight1", shape=(32, 64, 3, 3))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
out = relay.slice_like(y, y, axes=[2, 3])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_nchw(), transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
verify_slice_like(after, [1, 2])

def func_vars():
x = relay.var("x", shape=(1, 56, 56, 64))
weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
y = relay.nn.conv2d(
x,
weight1,
channels=32,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
# z has no layout information so convert layout won't happen.
z = relay.var("y", shape=(1, 56, 56, 32))
out = relay.slice_like(y, z, axes=[1, 2])
return relay.Function(analysis.free_vars(out), out)

after = run_opt_pass(func_vars(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
verify_slice_like(after, [1, 2])


def test_resnet_convert_layout():
def before():
x = relay.var("x", shape=(1, 56, 56, 64))
Expand Down Expand Up @@ -1412,6 +1481,7 @@ def expected():
test_conv_concat_convert_layout()
test_dual_path_convert_layout()
test_bn_convert_layout()
test_slice_like_convert_layout()
test_resnet_convert_layout()
test_scalar_convert_layout()
test_conv_bn_convert_layout()
Expand Down