diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index a9c305935ad3..dcb4cb6e88d7 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -1203,6 +1203,36 @@ struct SubPixelAttrs : public tvm::AttrsNode { } }; // struct SubPixelAttrs +/*! \brief Attributes used in correlation operators */ +struct CorrelationAttrs : public tvm::AttrsNode { + int kernel_size; + int max_displacement; + int stride1; + int stride2; + Array padding; + bool is_multiply; + String layout; + + TVM_DECLARE_ATTRS(CorrelationAttrs, "relay.attrs.CorrelationAttrs") { + TVM_ATTR_FIELD(kernel_size) + .describe("Kernel size for correlation, must be an odd number.") + .set_default(1); + TVM_ATTR_FIELD(max_displacement).describe("Max displacement of Correlation.").set_default(1); + TVM_ATTR_FIELD(stride1).describe("Stride for data1.").set_default(1); + TVM_ATTR_FIELD(stride2).describe("Stride for data2.").set_default(1); + TVM_ATTR_FIELD(padding) + .describe("Padding for data1 and data2.") + .set_default(Array{0, 0}); + TVM_ATTR_FIELD(is_multiply) + .describe("Operation type is either multiplication or substraction.") + .set_default(true); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively."); + } +}; // struct CorrelationAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_NN_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index edf668041fd5..9f97ee92cce8 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1133,6 +1133,19 @@ def _mx_space_to_depth(inputs, attrs): return _op.nn.space_to_depth(*inputs, **new_attrs) +def _mx_correlation(inputs, attrs): + assert len(inputs) == 2 + new_attrs = {} + new_attrs["kernel_size"] = attrs.get_int("kernel_size", 1) + new_attrs["max_displacement"] = attrs.get_int("max_displacement", 1) + new_attrs["stride1"] = attrs.get_int("stride1", 1) + new_attrs["stride2"] = attrs.get_int("stride2", 1) + new_attrs["padding"] = attrs.get_int("pad_size", 0) + new_attrs["is_multiply"] = attrs.get_bool("is_multiply", True) + new_attrs["layout"] = "NCHW" + return _op.nn.correlation(*inputs, **new_attrs) + + def _mx_contrib_fifo_buffer(inputs, attrs): new_attrs = {} new_attrs['axis'] = attrs.get_int('axis') @@ -1971,6 +1984,7 @@ def impl(inputs, input_types): "one_hot" : _mx_one_hot, "depth_to_space" : _mx_depth_to_space, "space_to_depth" : _mx_space_to_depth, + "Correlation" : _mx_correlation, # vision "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 9a9bfe0fbcbe..063345152b63 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -563,6 +563,11 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE) +# correlation +reg.register_strategy("nn.correlation", strategy.correlation_strategy) +reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE) + + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 96708c9e51d0..0f1f158f6bac 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -2761,3 +2761,86 @@ def global_avg_pool3d(data, """ output_size = [1, 1, 1] return _make.adaptive_avg_pool3d(data, output_size, layout) + + +def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply, layout): + r"""Applies correlation to inputs. + + The correlation layer performs multiplicative patch comparisons between two feature maps. + Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and + :math:`c` being their width, height, and number of channels, the correlation layer lets the + network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`. + + For now we consider only a single comparison of two patches. The 'correlation' of two patches + centered at :math:`x_{1}` in the first map and :math:`x_{2}` in the second map is then defined + as: + + .. math:: + + c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} + + for a square patch of size :math:`K:=2k+1`. + + Note that the equation above is identical to one step of a convolution in neural networks, but + instead of convolving data with a filter, it convolves data with other data. For this + reason, it has no training weights. + + Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all + patch combinations involves :math:`w^{2}*h^{2}` such computations. + + Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes + correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`, + by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize + :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood + centered around :math:`x_{1}`. + + The final output is defined by the following expression: + + .. math:: + + out[n, q, i, j] = c(x_{i, j}, x_{q}) + + where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` + denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`. + + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bool + operation type is either multiplication or substraction + + layout: str + layout of data1, data2 and the output + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + if isinstance(padding, int): + padding = (padding, padding) + return _make.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply, layout) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index fee213c431a4..0686125cdf95 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -357,3 +357,8 @@ class DilateAttrs(Attrs): @tvm._ffi.register_object("relay.attrs.SubPixelAttrs") class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" + + +@tvm._ffi.register_object("relay.attrs.CorrelationAttrs") +class CorrelationAttrs(Attrs): + """Attributes used in correlation operators""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 83e4e40b53b9..59d4ec9c6c34 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -590,3 +590,15 @@ def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h, stride_h == 1 and stride_w == 1 and \ dilation_h == 1 and dilation_w == 1 return judge_winograd_tensorcore, judge_winograd_shape + +@correlation_strategy.register(["cuda", "gpu"]) +def correlation_strategy_cuda(attrs, inputs, out_type, target): + """correlation cuda strategy""" + layout = attrs.layout + assert layout == "NCHW", "Only support NCHW layout" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_correlation(topi.cuda.correlation_nchw), + wrap_topi_schedule(topi.cuda.schedule_correlation_nchw), + name="correlation.cuda") + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index c3eadce2b8dd..6db5b1459f6b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -829,3 +829,30 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_bitserial_dense), name="bitserial_dense.generic") return strategy + +# correlation +def wrap_compute_correlation(topi_compute): + """wrap correlation topi compute""" + def _compute_correlation(attrs, inputs, out_type): + kernel_size = attrs.kernel_size + max_displacement = attrs.max_displacement + stride1 = attrs.stride1 + stride2 = attrs.stride2 + padding = get_const_tuple(attrs.padding) + is_multiply = attrs.is_multiply + return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2, + padding, is_multiply)] + return _compute_correlation + +@override_native_generic_func("correlation_strategy") +def correlation_strategy(attrs, inputs, out_type, target): + """correlation generic strategy""" + logger.warning("correlation is not optimized for this platform.") + layout = attrs.layout + assert layout == "NCHW", "Only support NCHW layout" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_correlation(topi.nn.correlation_nchw), + wrap_topi_schedule(topi.generic.schedule_correlation_nchw), + name="correlation.generic") + return strategy diff --git a/src/relay/op/nn/correlation.cc b/src/relay/op/nn/correlation.cc new file mode 100644 index 000000000000..67f42b7d3e85 --- /dev/null +++ b/src/relay/op/nn/correlation.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file correlation.cc + * \brief Correlation operators + */ +#include +#include +#include +#include +#include + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relay { + +// relay.nn.correlation +TVM_REGISTER_NODE_TYPE(CorrelationAttrs); + +Array> CorrelationInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + const auto* params = attrs.as(); + Layout layout{params->layout}; + return Array>{{layout, layout}, {layout}}; +} + +// Positional relay function to create correlation operator +// used by frontend FFI. +Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1, + int stride2, Array padding, bool is_multiply, String layout) { + auto attrs = make_object(); + attrs->kernel_size = kernel_size; + attrs->max_displacement = max_displacement; + attrs->stride1 = stride1; + attrs->stride2 = stride2; + attrs->padding = std::move(padding); + attrs->is_multiply = is_multiply; + attrs->layout = std::move(layout); + static const Op& op = Op::Get("nn.correlation"); + return Call(op, {data1, data2}, Attrs(attrs), {}); +} + +bool CorrelationRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data1 = types[0].as(); + const auto* data2 = types[1].as(); + if (data1 == nullptr || data2 == nullptr) return false; + + const CorrelationAttrs* param = attrs.as(); + CHECK(param != nullptr); + CHECK_EQ(param->layout, "NCHW") << "layout not supported."; + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + IndexExpr padded_height = data1->shape[2] + pad_h; + IndexExpr padded_width = data2->shape[3] + pad_w; + int kernel_radius = (param->kernel_size - 1) / 2; + int border_size = param->max_displacement + kernel_radius; + int displacement_radius = param->max_displacement / param->stride2; + int displacement_size = 2 * displacement_radius + 1; + int out_channel = displacement_size * displacement_size; + IndexExpr out_height = + indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1); + IndexExpr out_width = + indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1); + Array oshape{data1->shape[0], out_channel, out_height, out_width}; + // assign output type + reporter->Assign(types[2], TensorType(oshape, data1->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation").set_body_typed(MakeCorrelation); + +RELAY_REGISTER_OP("nn.correlation") + .describe(R"code(Applies correlation to inputs. + +The correlation layer performs multiplicative patch comparisons between two feature maps. +Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels, +the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`. + +For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and +:math:`x_{2}` in the second map is then defined as: + +.. math:: + c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} + +for a square patch of size :math:`K:=2k+1`. + +Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other +data. For this reason, it has no training weights. + +Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations. + +Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`, +by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood +centered around :math:`x_{1}`. + +The final output is defined by the following expression: + +.. math:: + out[n, q, i, j] = c(x_{i, j}, x_{q}) + +where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data1", "Tensor", "Input data1 to the correlation.") + .add_argument("data2", "Tensor", "Input data2 to the correlation.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", CorrelationInferCorrectLayout) + .add_type_rel("Correlation", CorrelationRel); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 6e8acdeab101..99fc6c309567 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1114,6 +1114,38 @@ def verify(shape, blocksize=2): verify((1, 1, 9, 9), 3) +def test_forward_correlation(): + def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply): + data1 = np.random.uniform(size=data_shape).astype("float32") + data2 = np.random.uniform(size=data_shape).astype("float32") + ref_res = mx.nd.Correlation(data1=mx.nd.array(data1), data2=mx.nd.array(data2), + kernel_size=kernel_size, max_displacement=max_displacement, + stride1=stride1, stride2=stride2, pad_size=pad_size, + is_multiply=is_multiply) + mx_sym = mx.sym.Correlation(data1=mx.sym.var('data1'), data2=mx.sym.var('data2'), + kernel_size=kernel_size, max_displacement=max_displacement, + stride1=stride1, stride2=stride2, pad_size=pad_size, + is_multiply=is_multiply) + shape_dict = {"data1": data1.shape, "data2": data2.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data1, data2) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 3, 10, 10), kernel_size = 1, max_displacement = 4, stride1 = 1, stride2 = 1, pad_size = 4, is_multiply = False) + verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = False) + verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = True) + verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 10, stride1 = 1, stride2 = 2, pad_size = 10, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False) + verify((5, 1, 6, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False) + verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1177,4 +1209,5 @@ def verify(shape, blocksize=2): test_forward_cond() test_forward_make_loss() test_forward_unravel_index() - test_forward_swap_axis() \ No newline at end of file + test_forward_swap_axis() + test_forward_correlation() diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index cf9d2d43eb12..68eced328fa8 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1342,6 +1342,45 @@ def test_bitpack_infer_type(): # TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases +def test_correlation(): + def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, dtype='float32'): + data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype)) + data2 = relay.var("data2", relay.ty.TensorType(data_shape, dtype)) + y = relay.nn.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply, "NCHW") + yy = run_infer_type(y) + padded_height = data_shape[2] + 2 * padding + padded_width = data_shape[3] + 2 * padding + border_size = (kernel_size - 1) // 2 + max_displacement + displacement_radius = max_displacement // stride2 + out_channel = ((2 * displacement_radius) + 1) ** 2 + out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1 + out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1 + assert yy.checked_type == relay.TensorType( + (data_shape[0], out_channel, out_height, out_width), dtype + ) + func = relay.Function([data1, data2], y) + data1_np = np.random.uniform(size=data_shape).astype(dtype) + data2_np = np.random.uniform(size=data_shape).astype(dtype) + ref_res = topi.testing.correlation_nchw_python(data1_np, data2_np, kernel_size, max_displacement, stride1, stride2, padding, is_multiply) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data1_np, data2_np) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=4, + stride1=1, stride2=1, padding=4, is_multiply=True) + _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=5, + stride1=1, stride2=1, padding=5, is_multiply=True) + _test_correlation((5, 1, 4, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=1, padding=2, is_multiply=True) + _test_correlation((5, 1, 6, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=2, padding=2, is_multiply=False) + _test_correlation((5, 1, 11, 11), kernel_size=5, max_displacement=1, + stride1=1, stride2=1, padding=2, is_multiply=False) + + if __name__ == "__main__": test_pool1d() test_pool2d() @@ -1374,3 +1413,4 @@ def test_bitpack_infer_type(): test_upsampling3d() test_conv2d_int8_intrinsics() test_depthwise_conv2d_int8() + test_correlation() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 8ccd80f38a91..ba5c54b1addf 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -49,3 +49,4 @@ from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * +from .correlation import * diff --git a/topi/python/topi/cuda/correlation.py b/topi/python/topi/cuda/correlation.py new file mode 100644 index 000000000000..a383e4e7188e --- /dev/null +++ b/topi/python/topi/cuda/correlation.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Correlation operators on CUDA""" +import tvm +from tvm import te +from tvm import autotvm + +from .. import nn +from ..util import traverse_inline + + +@autotvm.register_topi_compute("correlation_nchw.cuda") +def correlation_nchw(cfg, data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply): + """Correlation operator in NCHW layout. + + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bocorrelation + operation type is either multiplication or substraction + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # pylint: disable=unused-argument + return nn.correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply) + + +def _schedule_correlation_nchw(cfg, s, correlation): + """Schedule correlation_nchw direct template""" + # pylint: disable=invalid-name + ##### space definition begin ##### + n, f, y, x = s[correlation].op.axis + rc, ry, rx = s[correlation].op.reduce_axis + cfg.define_split("tile_f", f, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + ##### space definition end ##### + + padded_data1, padded_data2 = s[correlation].op.input_tensors + s[padded_data1].compute_inline() + s[padded_data2].compute_inline() + + # create cache stage + s[correlation].set_scope('local') + AA = s.cache_read(padded_data1, 'shared', [correlation]) + BB = s.cache_read(padded_data2, 'shared', [correlation]) + + output = s.outputs[0].output(0) + + # tile and bind spatial axes + n, f, y, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) + s[correlation].compute_at(s[output], tx) + + # tile reduction axes + n, f, y, x = s[correlation].op.axis + rc, ry, rx = s[correlation].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, correlation, rc) + ryo, ryi = cfg['tile_ry'].apply(s, correlation, ry) + rxo, rxi = cfg['tile_rx'].apply(s, correlation, rx) + s[correlation].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) + + s[AA].compute_at(s[correlation], rxo) + s[BB].compute_at(s[correlation], rxo) + + # cooperative fetching + for load in [AA, BB]: + n, f, y, x = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + + # unroll + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + +@autotvm.register_topi_schedule("correlation_nchw.cuda") +def schedule_correlation_nchw(cfg, outs): + """schedule of correlation_nchw for cuda gpu + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + outs: Array of Tensor + The computation graph description of correlation + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for correlation. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'correlation_nchw': + _schedule_correlation_nchw(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 2be4bbb456de..d0c165db01df 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -672,3 +672,20 @@ def schedule_batch_matmul(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def schedule_correlation_nchw(outs): + """Schedule for correlation_nchw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of correlation_nchw + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index bd806b9d0e83..3830bd06c68d 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -22,6 +22,7 @@ from .conv1d import * from .conv2d import * from .conv3d import * +from .correlation import * from .deformable_conv2d import * from .depthwise_conv2d import * from .elemwise import * diff --git a/topi/python/topi/nn/correlation.py b/topi/python/topi/nn/correlation.py new file mode 100644 index 000000000000..94aea55d83b9 --- /dev/null +++ b/topi/python/topi/nn/correlation.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Correlation operators""" +from tvm import te + +from .pad import pad +from ..util import get_const_tuple + + +def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply): + """Correlation operator in NCHW layout. + + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bool + operation type is either multiplication or substraction + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # pylint: disable=unnecessary-lambda, invalid-name + data_shape = get_const_tuple(data1.shape) + assert get_const_tuple(data2.shape) == data_shape, "data1 and data2 should have the same shape" + assert kernel_size > 0 and kernel_size % 2, "kernel_size should be non-negative odd number" + if isinstance(padding, (tuple, list)): + if len(padding) == 2: + pad_before_h = pad_after_h = padding[0] + pad_before_w = pad_after_w = padding[1] + elif len(padding) == 4: + pad_before_h, pad_before_w, pad_after_h, pad_after_w = padding + else: + raise ValueError("invalid padding") + elif isinstance(padding, int): + pad_before_h = pad_after_h = pad_before_w = pad_after_w = padding + else: + raise ValueError("invalid padding") + pad_before = [0, 0, pad_before_h, pad_before_w] + pad_after = [0, 0, pad_after_h, pad_after_w] + padded_data1 = pad(data1, pad_before, pad_after) + padded_data2 = pad(data2, pad_before, pad_after) + + batch, channel, height, width = data_shape + + kernel_radius = (kernel_size - 1) // 2 + border_size = max_displacement + kernel_radius + displacement_radius = max_displacement // stride2 + displacement_size = 2 * displacement_radius + 1 + + padded_width = width + pad_before_w + pad_after_w + padded_height = height + pad_before_h + pad_after_h + out_channel = displacement_size * displacement_size + out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1 + out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1 + + rc = te.reduce_axis((0, channel), name='rc') + ry = te.reduce_axis((0, kernel_size), name='ry') + rx = te.reduce_axis((0, kernel_size), name='rx') + + if is_multiply: + corr_func = lambda x, y: x * y + else: + corr_func = lambda x, y: te.abs(x - y) + + def _compute_correlation(n, q, i, j): + # location in data1 + y1 = i * stride1 + max_displacement + x1 = j * stride1 + max_displacement + # location in data2 + y2 = y1 + (te.indexdiv(q, displacement_size) - displacement_radius) * stride2 + x2 = x1 + (te.indexmod(q, displacement_size) - displacement_radius) * stride2 + return te.sum(corr_func(padded_data1[n, rc, y1 + ry, x1 + rx], + padded_data2[n, rc, y2 + ry, x2 + rx]), axis=[rc, ry, rx]) + + correlation = te.compute((batch, out_channel, out_height, out_width), lambda n, q, i, j: + _compute_correlation(n, q, i, j), tag="correlation_nchw") + return correlation / (kernel_size * kernel_size * channel) diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 36c460e671f5..511fe168c2e1 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -29,6 +29,7 @@ from .conv3d_ndhwc_python import conv3d_ndhwc_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python +from .correlation_nchw_python import correlation_nchw_python from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python diff --git a/topi/python/topi/testing/correlation_nchw_python.py b/topi/python/topi/testing/correlation_nchw_python.py new file mode 100644 index 000000000000..f0536560849b --- /dev/null +++ b/topi/python/topi/testing/correlation_nchw_python.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Convolution 3D in python""" +import numpy as np + + +def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply): + """Correlationn operator in NCHW layout. + + Parameters + ---------- + data1_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + data2_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding: int + Padding for correlation + + is_multiply: bool + operation type is either multiplication or substraction + + Returns + ------- + c_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # compute output's dimension + pad_data_height = data1.shape[2] + 2 * padding + pad_data_width = data1.shape[3] + 2 * padding + kernel_radius = (kernel_size - 1) // 2 + border_size = max_displacement + kernel_radius + out_width = (pad_data_width - border_size * 2) // stride1 + out_height = (pad_data_height - border_size * 2) // stride1 + neighborhood_grid_radius = max_displacement // stride2 + neighborhood_grid_width = neighborhood_grid_radius * 2 + 1 + out_channel = neighborhood_grid_width * neighborhood_grid_width + + out = np.zeros((data1.shape[0], out_channel, out_height, out_width)) + pad_data1 = np.zeros((data1.shape[0], data1.shape[1], + pad_data_height, pad_data_width)) + pad_data2 = np.zeros((data1.shape[0], data1.shape[1], + pad_data_height, pad_data_width)) + + pad_data1[:, :, padding:padding + data1.shape[2], + padding:padding + data1.shape[3]] = data1[:, :, :, :] + pad_data2[:, :, padding:padding + data2.shape[2], + padding:padding + data2.shape[3]] = data2[:, :, :, :] + + if is_multiply: + corr_func = lambda x, y: x * y + else: + corr_func = lambda x, y: abs(x - y) + + # pylint: disable=too-many-nested-blocks + for i in range(out_height): + for j in range(out_width): + for nbatch in range(data1.shape[0]): + # x1,y1 is the location in data1 , i,j is the location in output + x1 = j * stride1 + max_displacement + y1 = i * stride1 + max_displacement + + for q in range(out_channel): + # location in data2 + x2 = x1 + (q % neighborhood_grid_width - neighborhood_grid_radius) * stride2 + y2 = y1 + (q // neighborhood_grid_width - neighborhood_grid_radius) * stride2 + + for h in range(kernel_size): + for w in range(kernel_size): + for channel in range(data1.shape[1]): + out[nbatch, q, i, j] += corr_func(pad_data1[nbatch, channel, y1 + h, x1 + w], + pad_data2[nbatch, channel, y2 + h, x2 + w]) + + out /= float(kernel_size** 2 *data1.shape[1]) + return out diff --git a/topi/tests/python/test_topi_correlation.py b/topi/tests/python/test_topi_correlation.py new file mode 100644 index 000000000000..663564fab469 --- /dev/null +++ b/topi/tests/python/test_topi_correlation.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License +"""test of correlation operator in NCHW layout""" +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + + +_correlation_implement = { + "generic": (topi.nn.correlation_nchw, topi.generic.schedule_correlation_nchw), + "cuda": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw), +} + + +def verify_correlation_nchw(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (data_shape[0], data_shape[1], data_shape[2], data_shape[3], + kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply)) + + A = te.placeholder(data_shape, name='data1') + B = te.placeholder(data_shape, name='data2') + dtype = A.dtype + + @memoize("topi.tests.test_topi_correlation_nchw.verify_correlation_nchw") + def get_ref_data(): + a_np = np.random.uniform(size=data_shape).astype(dtype) + b_np = np.random.uniform(size=data_shape).astype(dtype) + c_np = topi.testing.correlation_nchw_python(a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply) + return a_np, b_np, c_np + + a_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch( + device, _correlation_implement) + with tvm.target.create(device): + C = fcompute(A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.empty(c_np.shape, dtype=dtype, ctx=ctx) + + func = tvm.build(s, [A, B, C], device) + func(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + + +def test_correlation_nchw(): + verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=4, + stride1=1, stride2=1, pad_size=4, is_multiply=True) + verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=5, + stride1=1, stride2=1, pad_size=5, is_multiply=True) + verify_correlation_nchw((5, 1, 4, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=1, pad_size=2, is_multiply=True) + verify_correlation_nchw((5, 1, 6, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=2, pad_size=2, is_multiply=False) + verify_correlation_nchw((5, 1, 11, 11), kernel_size=5, max_displacement=1, + stride1=1, stride2=1, pad_size=2, is_multiply=False) + + +if __name__ == "__main__": + test_correlation_nchw()