Skip to content

Commit

Permalink
[Relay,Topi] Correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed May 20, 2020
1 parent 9a9fe97 commit 308e8b5
Show file tree
Hide file tree
Showing 18 changed files with 892 additions and 1 deletion.
30 changes: 30 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,36 @@ struct SubPixelAttrs : public tvm::AttrsNode<SubPixelAttrs> {
}
}; // struct SubPixelAttrs

/*! \brief Attributes used in correlation operators */
struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
int kernel_size;
int max_displacement;
int stride1;
int stride2;
Array<IndexExpr> padding;
bool is_multiply;
String layout;

TVM_DECLARE_ATTRS(CorrelationAttrs, "relay.attrs.CorrelationAttrs") {
TVM_ATTR_FIELD(kernel_size)
.describe("Kernel size for correlation, must be an odd number.")
.set_default(1);
TVM_ATTR_FIELD(max_displacement).describe("Max displacement of Correlation.").set_default(1);
TVM_ATTR_FIELD(stride1).describe("Stride for data1.").set_default(1);
TVM_ATTR_FIELD(stride2).describe("Stride for data2.").set_default(1);
TVM_ATTR_FIELD(padding)
.describe("Padding for data1 and data2.")
.set_default(Array<IndexExpr>{0, 0});
TVM_ATTR_FIELD(is_multiply)
.describe("Operation type is either multiplication or substraction.")
.set_default(true);
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively.");
}
}; // struct CorrelationAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
14 changes: 14 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,19 @@ def _mx_space_to_depth(inputs, attrs):
return _op.nn.space_to_depth(*inputs, **new_attrs)


def _mx_correlation(inputs, attrs):
assert len(inputs) == 2
new_attrs = {}
new_attrs["kernel_size"] = attrs.get_int("kernel_size", 1)
new_attrs["max_displacement"] = attrs.get_int("max_displacement", 1)
new_attrs["stride1"] = attrs.get_int("stride1", 1)
new_attrs["stride2"] = attrs.get_int("stride2", 1)
new_attrs["padding"] = attrs.get_int("pad_size", 0)
new_attrs["is_multiply"] = attrs.get_bool("is_multiply", True)
new_attrs["layout"] = "NCHW"
return _op.nn.correlation(*inputs, **new_attrs)


def _mx_contrib_fifo_buffer(inputs, attrs):
new_attrs = {}
new_attrs['axis'] = attrs.get_int('axis')
Expand Down Expand Up @@ -1942,6 +1955,7 @@ def impl(inputs, input_types):
"one_hot" : _mx_one_hot,
"depth_to_space" : _mx_depth_to_space,
"space_to_depth" : _mx_space_to_depth,
"Correlation" : _mx_correlation,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,11 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)


# correlation
reg.register_strategy("nn.correlation", strategy.correlation_strategy)
reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE)


#####################
# Shape functions #
#####################
Expand Down
81 changes: 81 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,3 +2761,84 @@ def global_avg_pool3d(data,
"""
output_size = [1, 1, 1]
return _make.adaptive_avg_pool3d(data, output_size, layout)


def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
is_multiply, layout):
r"""Applies correlation to inputs.
The correlation layer performs multiplicative patch comparisons between two feature maps.
Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and
:math:`c` being their width, height, and number of channels, the correlation layer lets the
network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.
For now we consider only a single comparison of two patches. The 'correlation' of two patches
centered at :math:`x_{1}` in the first map and :math:`x_{2}` in the second map is then defined
as:
.. math::
c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>
for a square patch of size :math:`K:=2k+1`.
Note that the equation above is identical to one step of a convolution in neural networks, but
instead of convolving data with a filter, it convolves data with other data. For this
reason, it has no training weights.
Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all
patch combinations involves :math:`w^{2}*h^{2}` such computations.
Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes
correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize
:math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
centered around :math:`x_{1}`.
The final output is defined by the following expression:
.. math::
out[n, q, i, j] = c(x_{i, j}, x_{q})
where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q`
denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
Parameters
----------
data1 : tvm.te.Tensor
4-D with shape [batch, channel, height, width]
data2 : tvm.te.Tensor
4-D with shape [batch, channel, height, width]
kernel_size: int
Kernel size for correlation, must be an odd number
max_displacement: int
Max displacement of Correlation
stride1: int
Stride for data1
stride2: int
Stride for data2 within the neightborhood centered around data1
padding : int or a list/tuple of 2 or 4 ints
Padding size, or
[pad_height, pad_width] for 2 ints, or
[pad_top, pad_left, pad_bottom, pad_right] for 4 ints
is_multiply: bool
operation type is either multiplication or substraction
layout: str
layout of data1, data2 and the output
Returns
-------
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
if isinstance(padding, int):
padding = (padding, padding)
return _make.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2,
padding, is_multiply, layout)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,8 @@ class DilateAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
class SubPixelAttrs(Attrs):
"""Attributes used in depth to space and space to depth operators"""


@tvm._ffi.register_object("relay.attrs.CorrelationAttrs")
class CorrelationAttrs(Attrs):
"""Attributes used in correlation operators"""
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,15 @@ def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h,
stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1
return judge_winograd_tensorcore, judge_winograd_shape

@correlation_strategy.register(["cuda", "gpu"])
def correlation_strategy_cuda(attrs, inputs, out_type, target):
"""correlation cuda strategy"""
layout = attrs.layout
assert layout == "NCHW", "Only support NCHW layout"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_correlation(topi.cuda.correlation_nchw),
wrap_topi_schedule(topi.cuda.schedule_correlation_nchw),
name="correlation.cuda")
return strategy
27 changes: 27 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,3 +829,30 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.generic.schedule_bitserial_dense),
name="bitserial_dense.generic")
return strategy

# correlation
def wrap_compute_correlation(topi_compute):
"""wrap correlation topi compute"""
def _compute_correlation(attrs, inputs, out_type):
kernel_size = attrs.kernel_size
max_displacement = attrs.max_displacement
stride1 = attrs.stride1
stride2 = attrs.stride2
padding = get_const_tuple(attrs.padding)
is_multiply = attrs.is_multiply
return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2,
padding, is_multiply)]
return _compute_correlation

@override_native_generic_func("correlation_strategy")
def correlation_strategy(attrs, inputs, out_type, target):
"""correlation generic strategy"""
logger.warning("correlation is not optimized for this platform.")
layout = attrs.layout
assert layout == "NCHW", "Only support NCHW layout"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_correlation(topi.nn.correlation_nchw),
wrap_topi_schedule(topi.generic.schedule_correlation_nchw),
name="correlation.generic")
return strategy
136 changes: 136 additions & 0 deletions src/relay/op/nn/correlation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file correlation.cc
* \brief Correlation operators
*/
#include <topi/nn.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>

#include <vector>

#include "../op_common.h"

namespace tvm {
namespace relay {

// relay.nn.correlation
TVM_REGISTER_NODE_TYPE(CorrelationAttrs);

Array<Array<Layout>> CorrelationInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* params = attrs.as<CorrelationAttrs>();
Layout layout{params->layout};
return Array<Array<Layout>>{{layout, layout}, {layout}};
}

// Positional relay function to create correlation operator
// used by frontend FFI.
Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1,
int stride2, Array<IndexExpr> padding, bool is_multiply, String layout) {
auto attrs = make_object<CorrelationAttrs>();
attrs->kernel_size = kernel_size;
attrs->max_displacement = max_displacement;
attrs->stride1 = stride1;
attrs->stride2 = stride2;
attrs->padding = padding;
attrs->is_multiply = is_multiply;
attrs->layout = layout;
static const Op& op = Op::Get("nn.correlation");
return Call(op, {data1, data2}, Attrs(attrs), {});
}

bool CorrelationRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data1 = types[0].as<TensorTypeNode>();
const auto* data2 = types[1].as<TensorTypeNode>();
if (data1 == nullptr || data2 == nullptr) return false;

const CorrelationAttrs* param = attrs.as<CorrelationAttrs>();
CHECK(param != nullptr);
CHECK_EQ(param->layout, "NCHW") << "layout not supported.";
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
IndexExpr padded_height = data1->shape[2] + pad_h;
IndexExpr padded_width = data2->shape[3] + pad_w;
int kernel_radius = (param->kernel_size - 1) / 2;
int border_size = param->max_displacement + kernel_radius;
int displacement_radius = param->max_displacement / param->stride2;
int displacement_size = 2 * displacement_radius + 1;
int out_channel = displacement_size * displacement_size;
IndexExpr out_height =
indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1);
IndexExpr out_width =
indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1);
Array<tvm::PrimExpr> oshape{data1->shape[0], out_channel, out_height, out_width};
// assign output type
reporter->Assign(types[2], TensorType(oshape, data1->dtype));
return true;
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation").set_body_typed(MakeCorrelation);

RELAY_REGISTER_OP("nn.correlation")
.describe(R"code(Applies correlation to inputs.
The correlation layer performs multiplicative patch comparisons between two feature maps.
Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels,
the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.
For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and
:math:`x_{2}` in the second map is then defined as:
.. math::
c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>
for a square patch of size :math:`K:=2k+1`.
Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other
data. For this reason, it has no training weights.
Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations.
Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
centered around :math:`x_{1}`.
The final output is defined by the following expression:
.. math::
out[n, q, i, j] = c(x_{i, j}, x_{q})
where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<CorrelationAttrs>()
.set_num_inputs(2)
.add_argument("data1", "Tensor", "Input data1 to the correlation.")
.add_argument("data2", "Tensor", "Input data2 to the correlation.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", CorrelationInferCorrectLayout)
.add_type_rel("Correlation", CorrelationRel);

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 308e8b5

Please sign in to comment.