[Relay][AlterOp] Improving support for broadcast layout alteration. #4040

Merged 1 commit on Oct 6, 2019
22 changes: 22 additions & 0 deletions include/tvm/data_layout.h
@@ -210,6 +210,28 @@ class Layout : public NodeRef {
return ct;
}

/*!
 * \brief Returns a new layout whose primal dims are expanded to match those of dst_layout.
 * \param dst_layout The dst layout to which the current layout has to be expanded.
* \return The expanded Layout.
*/
inline Layout ExpandPrimal(const Layout& dst_layout) {
Layout new_src_layout;
// 1) Find the axes that are missing in the current layout. Make them the prefix.
std::string new_src_layout_str = "";
for (auto dst_axis : dst_layout->axes) {
if (LayoutAxis::Get(dst_axis).IsPrimal()) {
if (!this->Contains(LayoutAxis::Get(dst_axis))) {
new_src_layout_str += dst_axis->var->name_hint;
}
}
}
// 2) Now, append the axes of the current layout.
new_src_layout_str += this->name();
new_src_layout = Layout(new_src_layout_str);
return new_src_layout;
}

/*!
* \brief return the index of the input axis.
* If it is not found in the layout or the layout is undefined,
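
To make the prefixing behaviour of ExpandPrimal concrete, here is a minimal pure-Python sketch of the same logic (a hypothetical illustration, not part of the patch; primal axes are the upper-case characters of a layout string):

def expand_primal(src_layout: str, dst_layout: str) -> str:
    """Prepend the primal (upper-case) axes of dst_layout that are missing
    from src_layout, mirroring the new Layout::ExpandPrimal above."""
    missing = [c for c in dst_layout if c.isupper() and c not in src_layout]
    return "".join(missing) + src_layout

# Examples matching the doc comment: "C" expanded against "NCHW16c" gives "NHWC".
assert expand_primal("C", "NCHW16c") == "NHWC"
assert expand_primal("CHW", "NCHW16c") == "NCHW"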
1 change: 1 addition & 0 deletions src/relay/op/tensor/transform.cc
@@ -37,6 +37,7 @@
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../../pass/pattern_util.h"
#include "transform.h"

namespace tvm {
41 changes: 29 additions & 12 deletions src/relay/pass/alter_op_layout.cc
@@ -38,26 +38,43 @@
#include <unordered_map>

#include "alter_op_layout.h"
#include "pattern_util.h"

namespace tvm {
namespace relay {

namespace alter_op_layout {

// Make a transform CallNode
/* Performs 2 operations
 * 1) If src_layout has fewer primal dims than dst_layout, an expand_dims is inserted to match
 *    the number of dims. For example, with src_layout = C and dst_layout = NCHW16c, the src is
 *    expanded to NHWC.
 * 2) Call layout transform with the new src layout.
*/
Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
if (src_layout.Equals(dst_layout)) { return raw; }
CHECK(src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts";
CHECK(BijectiveLayoutNode::make(src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: "
<< src_layout << " v.s. " << dst_layout;
static auto &transform_op = Op::Get("layout_transform");
NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>();
attrs->src_layout = src_layout.name();
attrs->dst_layout = dst_layout.name();
Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs});
return std::move(transform);
if (src_layout.Equals(dst_layout)) {
return raw;
}

// 1) Check if the shape lengths are different. If yes, expand dims.
Expr input_expr = raw;
Layout new_src_layout = src_layout;
if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
new_src_layout = src_layout.ExpandPrimal(dst_layout);
input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
if (new_src_layout.Equals(dst_layout)) {
return input_expr;
}
}

// 2) Insert layout transform on the transformed src.
CHECK(new_src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts";
CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: "
<< new_src_layout << " v.s. " << dst_layout;
return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
}

// Memorize layout transform so we can reuse internal transformed nodes
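
At the Relay level, the effect of the reworked TransformLayout on a rank-mismatched operand can be sketched as below. This is a hand-written illustration of the graph the pass effectively produces for the [C]-operand case; shapes, dtypes, and variable names are assumptions, not taken from the patch:

from tvm import relay

# a is already in the packed NCHW16c layout; b is a plain channel vector.
a = relay.var("a", shape=(1, 4, 56, 56, 16), dtype="float32")
b = relay.var("b", shape=(64,), dtype="float32")

# 1) expand_dims prepends the missing primal axes: layout C becomes NHWC.
b_nhwc = relay.expand_dims(b, axis=0, num_newaxis=3)          # shape (1, 1, 1, 64)
# 2) layout_transform converts the expanded operand into the packed layout.
b_packed = relay.layout_transform(b_nhwc, "NHWC", "NCHW16c")  # shape (1, 4, 1, 1, 16)
out = relay.add(a, b_packed)
func = relay.Function([a, b], out)
print(func)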
92 changes: 75 additions & 17 deletions src/relay/pass/alter_op_layout.h
@@ -30,10 +30,57 @@

#include <tvm/data_layout.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Returns a new layout where the subordinate factors are adjusted based on the tensor
* shape.
 * \param src_layout The layout whose subordinate factors are to be adjusted.
 * \param old_layout The old layout before any transformation.
* \param old_shape The shape of the original tensor.
* \return The adjusted Layout.
*/
inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layout,
const Array<tvm::Expr>& old_shape) {
// For each subordinate axis
// 1) Find the corresponding dual axis.
// 2) Find the Index of this dual axis in old_layout.
// 3) Find the shape of that axis in old_shape.
// 4) a) Set the factor to 1 if that shape is 1. b) Else retain the factor.
std::string new_layout;
for (auto axis : src_layout->axes) {
if (!LayoutAxis::Get(axis).IsPrimal()) {
// 1) Find the corresponding dual axis
auto dual_axis = LayoutAxis::Get(axis).ToPrimal().name()[0];

// 2) Find the index of this dual axis in old_layout
int old_axis = old_layout.IndexOf(LayoutAxis::Get(dual_axis));
[Review comment on lines +56 to +59]
Member: actually you don't need to do .name() first then ::Get, ToPrimal() already returns LayoutAxis
Contributor (author): I tried that. But it does not work because some of the methods are private. I didn't want to change the class description.

// 3) Find the shape at this index in old_shape
auto shape_val = old_shape[old_axis];

// 4) a) Check if this shape element is 1.
bool is_shape_one = false;
if (auto* shape_int = shape_val.as<IntImm>()) {
if (shape_int->value == 1) {
new_layout += "1";
is_shape_one = true;
}
}

// 4) b) If shape is not 1, retain the factor.
if (!is_shape_one) {
auto new_shape_val = src_layout.FactorOf(LayoutAxis::Get(dual_axis));
new_layout += std::to_string(new_shape_val);
}
}
new_layout += LayoutAxis::Get(axis).name();
}
return Layout(new_layout);
}
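
A simplified pure-Python sketch of this adjustment (a hypothetical re-implementation for illustration only; it assumes single-character axis names and integer shapes):

def adjust_subordinate_factors(src_layout, old_layout, old_shape):
    """Mirror AdjustSubordinateFactors: for every subordinate axis (e.g. '16c')
    in src_layout, set its factor to 1 if the matching primal axis had extent 1
    in the original tensor, otherwise keep the factor."""
    new_layout = ""
    i = 0
    while i < len(src_layout):
        # Collect an optional numeric factor followed by one axis character.
        factor = ""
        while src_layout[i].isdigit():
            factor += src_layout[i]
            i += 1
        axis = src_layout[i]
        i += 1
        if axis.islower():  # subordinate axis
            primal = axis.upper()
            extent = old_shape[old_layout.index(primal)]
            new_layout += ("1" if extent == 1 else factor) + axis
        else:
            new_layout += axis
    return new_layout

# Operand b originally had layout NHWC; if its channel extent is 1 the packed
# layout NCHW16c is adjusted to NCHW1c, otherwise it is kept as is.
assert adjust_subordinate_factors("NCHW16c", "NHWC", (1, 225, 225, 1)) == "NCHW1c"
assert adjust_subordinate_factors("NCHW16c", "NHWC", (1, 225, 225, 64)) == "NCHW16c"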

/*!
* \brief Infer & correct function of node layout. See \p Layout for layout convention
* \param attrs The attribute of the node.
@@ -111,28 +158,39 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
int scalar = layouts[0].ndim() == 0 ? 0 : 1;
return Array<Array<Layout> >{layouts, {layouts[1-scalar]}};
} else {
// try to broadcast the tensors to the larger dimension
// Set the layout to that of the operand with more primal dims. If one operand has fewer dims,
// we call expand_dims while transforming its layout.
int large_idx = layouts[0].ndim_primal() >= layouts[1].ndim_primal() ? 0 : 1;
int small_idx = 1 - large_idx;
Layout ret = layouts[large_idx];

// extract common part
size_t i = layouts[large_idx].ndim();
for (; i != 0; --i) {
const auto& axis = layouts[large_idx][i-1];
if (!layouts[small_idx].Contains(axis.ToPrimal())) {
break;
}
}

Layout common_part = layouts[large_idx].SubLayout(i, layouts[large_idx].ndim() - i);
if (!BijectiveLayoutNode::make(layouts[small_idx], common_part).defined()) {
// not convertible
return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
if (old_in_layouts[0].Equals(old_in_layouts[1])) {
// Support scenarios where the original operands were of shapes like [N, H, W, C] and
// [N, H, W, 1]. In this case, one operand may come in as NCHW16c while the other operand
// does not have a large enough C dimension to split. To reuse broadcasting, we want to use
// NCHW1c for the second operand. The following section of code walks through the layouts
// and shapes to perform that operation.
// a in NCHW16c
// b in NHWC with C = 1
// b = layout_transform(b) from NHWC -> NCHW1c
// add(a, b)
auto old_small_shape = old_in_shapes[small_idx];
auto old_small_layout = old_in_layouts[small_idx];
auto new_small_layout =
AdjustSubordinateFactors(layouts[large_idx], old_small_layout, old_small_shape);
layouts.Set(small_idx, new_small_layout);
} else {
// Support scenarios where original operands were of type [N, H, W, C] and [C]. In this case,
// while transforming the layout, we expand dims to make C go to NHWC, and then use the
// modified layout of the first operand to call the layout transform. E.g.
// a in NCHW16c
// b in C
// b = expand_dims(b) from C -> NHWC
// b = layout_transform(b) from NHWC -> NCHW16c
// add(a, b)
layouts.Set(small_idx, ret);
}

layouts.Set(small_idx, common_part);
return Array<Array<Layout> > {layouts, {ret}};
return Array<Array<Layout>>{layouts, {ret}};
}
}

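For the first branch (both operands originally shared the same layout, e.g. NHWC, but the smaller one has channel extent 1), the rewritten graph can be sketched like this. Again a hand-written illustration with assumed shapes, mirroring the comment above:

from tvm import relay

# a was converted to NCHW16c by an earlier transformation; b kept its original
# NHWC layout but has only a single channel.
a = relay.var("a", shape=(1, 4, 225, 225, 16), dtype="int32")
b = relay.var("b", shape=(1, 225, 225, 1), dtype="int32")

# AdjustSubordinateFactors turns NCHW16c into NCHW1c for b because its channel
# extent is 1, so no expand_dims is needed, only a layout_transform.
b_packed = relay.layout_transform(b, "NHWC", "NCHW1c")   # shape (1, 1, 225, 225, 1)
out = relay.add(a, b_packed)
func = relay.Function([a, b], out)
print(func)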
2 changes: 2 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -505,6 +505,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);

Expr StopFusion(Expr data);

Expr CastHint(Expr data, DataType dtype);
14 changes: 1 addition & 13 deletions src/relay/qnn/op/convolution.cc
@@ -242,19 +242,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& zp_kernel, const QnnC
if (param->kernel_zero_point != 1) {
multiplied_t2 = Multiply(zp_kernel, reduced_t2);
}

// Replicate to go back to NHWC/NCHW. This is not strictly needed, but AlterOpLayout fails without it.
// We can remove this once AlterOpLayout refactoring completes -
// https://github.com/dmlc/tvm/issues/3670
Array<Integer> reps;
if (param->data_layout == "NCHW") {
reps = {1, out_channels, 1, 1};
} else if (param->data_layout == "NHWC") {
reps = {1, 1, 1, out_channels};
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
return Tile(multiplied_t2, reps);
return multiplied_t2;
}

/*
34 changes: 34 additions & 0 deletions tests/python/relay/test_op_qnn_conv2d.py
@@ -608,6 +608,39 @@ def tflite_anistropic_strides():
golden_output = np.array((124, -92, 164, -132)).reshape(1, 1, 2, 2)
np.testing.assert_equal(qnn_output, golden_output)

def broadcast_layout_test():
# Test broadcast support for NHWC layout.
data_shape = (1, 229, 229, 3) # NHWC
data_dtype = 'uint8'
kernel_shape = (7, 7, 3, 64) # HWIO
kernel_dtype = 'int8'
_, qnn_func = get_funcs(data_shape=data_shape,
data_dtype=data_dtype,
kernel_shape=kernel_shape,
kernel_dtype=kernel_dtype,
input_zero_point=8,
kernel_zero_point=3,
kernel_size=(7, 7),
padding=(1, 1),
strides=(1, 1),
dilation=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
out_dtype="int32")
func = qnn_func['main'].body
bias = relay.var("bias", shape=(64,), dtype="int32")
bias2 = relay.var("bias2", shape=(1, 225, 225, 1), dtype="int32")

# Check broadcast support on both lhs and rhs
func = relay.add(func, bias2)
func = relay.add(bias2, func)
func = relay.add(bias, func)
func = relay.add(func, bias)
func = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(func)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512")

if __name__ == "__main__":
no_zero_point_test()
input_zero_point_test()
@@ -621,3 +654,4 @@ def tflite_anistropic_strides():
tflite_large_irregular_test()
tflite_output_multiplier_greater_than_one()
tflite_anistropic_strides()
broadcast_layout_test()
9 changes: 6 additions & 3 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -134,7 +134,8 @@ def expected():
kernel_layout="OIHW16i",
data_layout="NCHW16c")
b = relay.expand_dims(bias, axis=1, num_newaxis=2)
b = relay.layout_transform(b, "CHW", "CHW16c")
b = relay.expand_dims(b, axis=0, num_newaxis=1)
b = relay.layout_transform(b, "NCHW", "NCHW16c")
y = relay.add(y, b)

y = relay.nn.relu(y)
@@ -304,8 +305,10 @@ def expected():
weight = relay.var("weight")
x = relay.layout_transform(x, "NCHW", "NCHW16c")
bias = relay.expand_dims(bias, 1, 2)
bias = relay.layout_transform(bias, "CHW", "CHW16c")
scale = relay.layout_transform(scale, "CHW", "CHW16c")
bias = relay.expand_dims(bias, 0, 1)
bias = relay.layout_transform(bias, "NCHW", "NCHW16c")
scale = relay.expand_dims(scale, 0, 1)
scale = relay.layout_transform(scale, "NCHW", "NCHW16c")
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW16c")
y = relay.add(y, bias) # test broadcasting to lhs