[ConvertLayout] Squeeze and reduce ops (apache#7835)
lixiaoquan authored and Trevor Morris committed May 6, 2021
1 parent 01db55c commit 887d22c
Showing 4 changed files with 294 additions and 17 deletions.
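
At a glance, this commit teaches ConvertLayout to infer layouts through `squeeze` and reduce ops (`sum`, etc.) instead of stopping layout propagation at them. Below is a minimal sketch of the effect, adapted from the tests added in this commit; the `Sequential`/`PassContext` wrapping is just the usual way to run the pass and is not part of this diff.

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 1, 2048))
w = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(x, w, channels=1000, kernel_size=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
y = relay.squeeze(relay.nn.relu(y), axis=[-3])
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# Convert conv2d to NCHW. With this commit the squeeze axis is remapped
# (NHWC axis -3, i.e. "H", becomes axis 2 in NCHW) and the transform back
# to the original layout is placed after the squeeze ("NCW" -> "NWC").
with tvm.transform.PassContext(opt_level=3):
    mod = tvm.transform.Sequential(
        [relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})]
    )(mod)
print(mod)
```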
44 changes: 31 additions & 13 deletions src/relay/op/tensor/reduce.cc
@@ -131,13 +131,11 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
uint32_t indim = old_in_shapes[0].size();
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);

Layout ret = Layout::Undef();
if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);
Layout inferred_in = Layout::Undef();
Layout inferred_out = Layout::Undef();

// Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
auto infer = [&](const Layout& layout) {
// 1) Collect the original axes
std::unordered_set<std::string> old_r_dims;
for (auto r_axis : r_axes) {
@@ -146,31 +144,51 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,

// 2) Collect the new axes by walking new_layout.
tvm::Array<tvm::Integer> new_r_axes;
std::string new_layout_string = "";
std::string inferred_in_string = "";
std::string inferred_out_string = "";
int axis_index = 0;
for (auto iter_var : new_in_layouts[0]->axes) {
for (auto iter_var : layout->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
if (old_r_dims.count(layout_dim)) {
new_r_axes.push_back(tvm::Integer(axis_index));
}
// Collect only the primal axis.
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
if (!old_r_dims.count(layout_dim) || params->keepdims) {
inferred_out_string += layout_dim;
}
inferred_in_string += layout_dim;
axis_index++;
}
}

// 3) Set the new axis and layout.
ret = Layout(new_layout_string);
return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
};

std::string new_layout_string;
Array<Integer> new_r_axes;

if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);

// Get inferred_in and inferred_out from new_in_layout.
std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
params->axis = new_r_axes;
} else if (old_in_layouts.defined()) {
// If the new layout is undefined, set the old layout as the inferred layout.
ICHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];

// If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
if (old_in_layouts[0].defined()) {
std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
}
}

return Array<Array<Layout>>{{ret}, {ret}};
return Array<Array<Layout>>{{inferred_in}, {inferred_out}};
}

template <typename F>
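
To make the remapping above concrete: with `keepdims=false` the reduced primal axes are dropped from the inferred output layout, and the reduce axes are re-indexed against the new layout. A small sketch mirroring `_test_conv_reduce_convert_layout1` further down (same caveat as above: the pass-invocation boilerplate is standard Relay API, not part of this diff).

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 1, 2048))
w = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(x, w, channels=1000, kernel_size=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
y = relay.sum(relay.nn.relu(y), axis=(1, 2))  # reduce H and W in NHWC
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

with tvm.transform.PassContext(opt_level=3):
    mod = tvm.transform.Sequential(
        [relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})]
    )(mod)
# After conversion the sum reduces over axis=(2, 3): H and W sit at indices
# 2 and 3 in NCHW, and the inferred output layout shrinks to "NC".
print(mod)
```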
66 changes: 65 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -2211,6 +2211,69 @@ Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::squeeze(inputs[0], param->axis)};
}

Array<Array<Layout>> SqueezeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
SqueezeAttrs* params = const_cast<SqueezeAttrs*>(attrs.as<SqueezeAttrs>());

Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0];
Layout inferred_output = inferred_input;

ICHECK(old_in_types[0].as<TensorTypeNode>());
const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape;

// axis to squeeze
Array<Integer> axis;
if (params->axis.defined()) {
axis = params->axis;
} else {
// if axes is None, squeeze all axes of dimension 1
for (size_t i = 0; i < shape.size(); i++) {
if (topi::detail::GetConstInt(shape[i]) == 1) {
axis.push_back(i);
}
}
}

// If new_in_layouts are defined, this code tries to modify the layout
if (new_in_layouts.defined() && old_in_layouts.defined()) {
Array<Integer> new_axis;
for (const auto& e : axis) {
const auto& dim = old_in_layouts[0][e];
new_axis.push_back((new_in_layouts[0]).IndexOf(dim));
}
params->axis = new_axis;
axis = new_axis;
}

// Infer output layout
Array<tir::IterVar> kept_axes;
for (size_t i = 0; i < inferred_input.ndim(); i++) {
bool is_dim_kept = true;

// Check whether the dim should be kept
for (const auto& e : axis) {
int64_t axis_val = e->value;
if (axis_val < 0) {
axis_val += inferred_input.ndim();
}
if (static_cast<int64_t>(i) == axis_val) {
is_dim_kept = false;
break;
}
}

if (is_dim_kept) {
kept_axes.push_back(inferred_input->axes[i]);
}
}
inferred_output = Layout(kept_axes);

return Array<Array<Layout>>{{inferred_input}, {inferred_output}};
}

RELAY_REGISTER_OP("squeeze")
.describe(R"code(Squeeze the input tensor at the dimensions given by axes
@@ -2223,7 +2286,8 @@ RELAY_REGISTER_OP("squeeze")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);

// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
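
The core of `SqueezeInferCorrectLayout` is a per-axis lookup: each squeezed dim of the old layout is located in the new layout via `IndexOf`, and the surviving axes form the output layout. The same lookup is available from the Python `Layout` API; a tiny illustration of the index arithmetic only (this is not how the pass itself is driven):

```python
import tvm

old = tvm.tir.layout("NHWC")
new = tvm.tir.layout("NCHW")

# squeeze(axis=[-3]) on a 4-d NHWC tensor targets "H" (index 1 in NHWC);
# the same dim lives at index 2 in NCHW, matching the expected squeeze(axis=[2]).
print(old.index_of("H"))  # 1
print(new.index_of("H"))  # 2
```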
14 changes: 11 additions & 3 deletions src/relay/transforms/convert_layout.cc
@@ -91,9 +91,17 @@ class ConvertTransformMemorizer : public TransformMemorizer {
auto desired_layouts = operator->()->desired_layouts_;
if (desired_layouts.find(op->name) != desired_layouts.end()) {
tvm::Array<tvm::te::Tensor> tinfos;
for (auto expr : ref_call->args) {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
for (auto& expr : ref_call->args) {
if (expr->checked_type()->IsInstance<TupleTypeNode>()) {
auto tuple_ttype_node = expr->type_as<TupleTypeNode>();
for (auto& ttype : tuple_ttype_node->fields) {
auto ttype_node = ttype.as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype_node->shape, ttype_node->dtype));
}
} else {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
}
}

Array<String> op_desired_layouts = desired_layouts.at(op->name);
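
The loop above previously assumed every call argument is a single tensor; it now also expands tuple-typed arguments into one placeholder per tuple field. A short, hypothetical illustration of why that case exists (not exercised by this commit's tests): `concatenate` takes one Tuple argument whose checked type is a TupleType.

```python
import tvm
from tvm import relay

a = relay.var("a", shape=(1, 4, 4, 8))
b = relay.var("b", shape=(1, 4, 4, 8))
c = relay.concatenate([a, b], axis=3)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([a, b], c)))

call = mod["main"].body           # the concatenate CallNode
print(call.args[0].checked_type)  # a TupleType with one TensorType per field
```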
187 changes: 187 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -1556,6 +1556,191 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_squeeze_convert_layout():
def _test_conv_squeeze_convert_layout1():
# specified axis is squeezed
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[-3])
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[2])
y = relay.layout_transform(y, "NCW", "NWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_squeeze_convert_layout2():
# all axes of dimension 1 are squeezed
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, [0, 2, 3])
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_squeeze_convert_layout3():
# squeeze axis is empty
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[])
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[])
y = relay.layout_transform(y, "NCHW", "NHWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

_test_conv_squeeze_convert_layout1()
_test_conv_squeeze_convert_layout2()
_test_conv_squeeze_convert_layout3()


def test_conv_reduce_convert_layout():
def _test_conv_reduce_convert_layout1():
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.sum(y, axis=(1, 2))
y = relay.sum(y, axis=(1,))
y = relay.sum(y)
y = relay.sum(y)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.sum(y, axis=(2, 3))
y = relay.sum(y, axis=(1,))
y = relay.sum(y)
y = relay.sum(y)
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_reduce_convert_layout2():
def before():
x = relay.var("x", shape=(1, 38, 38, 512))
weight = relay.var("weight", shape=(3, 3, 512, 512))
y = relay.nn.conv2d(
x,
weight,
channels=512,
kernel_size=(3, 3),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.multiply(y, y)
y = relay.sum(y, axis=(3,), keepdims=True)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 38, 38, 512))
weight = relay.var("weight", shape=(3, 3, 512, 512))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=512, kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.multiply(y, y)
y = relay.sum(y, axis=(1,), keepdims=True)
y = relay.layout_transform(y, "NCHW", "NHWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

_test_conv_reduce_convert_layout1()
_test_conv_reduce_convert_layout2()


if __name__ == "__main__":
test_qnn_binary_no_convert_layout()
test_no_convert_layout()
@@ -1584,3 +1769,5 @@ def expected():
test_different_ops_convert_layout()
test_no_desired_layout()
test_convert_with_config()
test_conv_squeeze_convert_layout()
test_conv_reduce_convert_layout()
