This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-891] Support tuple of scales in upsample operator #15811

Status: Open. benhe2011 wants to merge 91 commits into base branch master; the changes shown are from 80 of the 91 commits.

Commits (91)
6f86613
added support for tuple dimensions except for the UpSamplingForward function
benhe2011 Aug 6, 2019
e59245d
modified mshadow to support separate dimension scales
benhe2011 Aug 6, 2019
a0781af
add make script
benhe2011 Aug 7, 2019
a10a840
save patches for tuple support
benhe2011 Aug 7, 2019
3f7f64f
fixed e scales member bug
benhe2011 Aug 7, 2019
3cf44c6
fix e scales member var bug
benhe2011 Aug 7, 2019
3e54dec
added initializations for plan struct
benhe2011 Aug 7, 2019
9659057
fixed syntax error
benhe2011 Aug 7, 2019
2b18eea
add print msg to make sure up-to-date
benhe2011 Aug 7, 2019
c8c3d01
moved update check msg
benhe2011 Aug 7, 2019
592e3aa
fixed some more syntax bugs
benhe2011 Aug 7, 2019
76209ae
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 8, 2019
7f67ba8
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 8, 2019
9a44fab
modified test cases
benhe2011 Aug 8, 2019
65f7ac3
removed files created for testing upsampling
benhe2011 Aug 8, 2019
496f1d6
removed some unnecessary extra lines
benhe2011 Aug 8, 2019
4924b98
put repeated code into scaleComp function
benhe2011 Aug 9, 2019
c96172f
change scalepointer to array
benhe2011 Aug 9, 2019
4c1d31d
fix scale_hw syntax error
benhe2011 Aug 9, 2019
9b01e02
change scaleComp to return dynamically allocated array
benhe2011 Aug 9, 2019
4858738
deleted commented out sections
benhe2011 Aug 9, 2019
37ba2e7
formatting
benhe2011 Aug 9, 2019
81b54cf
modified test cases
benhe2011 Aug 9, 2019
ccca549
formatting and added case in test_operator to handle if no values are…
benhe2011 Aug 10, 2019
b2a8d0b
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 10, 2019
ac13554
removed test script
benhe2011 Aug 10, 2019
611df32
some whitespace reformatting in upsampling-inl.h
benhe2011 Aug 10, 2019
9ce3127
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 10, 2019
328fdf2
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 11, 2019
21572b8
more whitespace reformatting based on github checks
benhe2011 Aug 11, 2019
238a7cf
more whitespace formatting (lines 136/137)
benhe2011 Aug 12, 2019
2154f4c
inserted print msg for scala test debugging (throwaway)
benhe2011 Aug 13, 2019
d02fbe5
removed throwaway scala test case print statement
benhe2011 Aug 13, 2019
77d296c
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 13, 2019
64b705f
update temporary testing file
benhe2011 Aug 14, 2019
a2a6b46
update temp testing file again
benhe2011 Aug 14, 2019
435fce4
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 14, 2019
9668f1e
modified scala demo to help with debugging scala test case
benhe2011 Aug 14, 2019
049e708
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 14, 2019
4336d6e
modified scala demo to debug upsampling scala test case
benhe2011 Aug 14, 2019
2f2bb7b
temporary changes to scala demo for debugging
benhe2011 Aug 14, 2019
23834d9
removed edits to hello world scala demo code
benhe2011 Aug 15, 2019
8229db4
Merge branch 'master' of git://github.com/apache/incubator-mxnet into…
benhe2011 Aug 15, 2019
dbd387e
minimal changes to scala and clojure tests for Upsampling operator so…
benhe2011 Aug 16, 2019
1ea1e7c
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 16, 2019
2f2809e
removed temporary upsampling testing file
benhe2011 Aug 16, 2019
531cf4d
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 16, 2019
1b0d39a
added Vandana's 14942 PR edits
benhe2011 Aug 16, 2019
15cdf1f
Revert "added Vandana's 14942 PR edits"
benhe2011 Aug 16, 2019
bcbfc90
clojure syntax error fix
benhe2011 Aug 16, 2019
0dbdbb5
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 18, 2019
6a31cb3
change clojure test for Upsampling operator so that only single shape…
benhe2011 Aug 18, 2019
a931a1b
clojure test change for tuple-supported upsampling operator syntax fix
benhe2011 Aug 18, 2019
407a6b9
modified clojure test for tuple-supported upsampling operator syntax …
benhe2011 Aug 18, 2019
c8f952b
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 18, 2019
5bd1463
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 19, 2019
f354034
used scaleComp function to replace code block in upsampling.cc
benhe2011 Aug 19, 2019
706f39a
fix syntax error from previous commit
benhe2011 Aug 19, 2019
b214bb2
removed some comments
benhe2011 Aug 19, 2019
7a7d613
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 19, 2019
5b409c7
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 20, 2019
83e44c5
empty commit
benhe2011 Aug 20, 2019
12741af
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 21, 2019
5050cbb
fixed bug when multiple shapes are passed
benhe2011 Aug 21, 2019
ec77909
remodified upsampling tests
benhe2011 Aug 21, 2019
6e75cab
fixed test_operator parameter list bug
benhe2011 Aug 22, 2019
93f7955
fixed test_operator.py test cases for upsampling
benhe2011 Aug 22, 2019
e8940c1
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 22, 2019
c69baae
removed onnx changes to be put in a different PR
benhe2011 Aug 22, 2019
d76ac2a
removed extra line
benhe2011 Aug 22, 2019
dd4121c
removed some comments
benhe2011 Aug 22, 2019
74c10c2
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 22, 2019
6cfc3a9
removed an onnx test case
benhe2011 Aug 23, 2019
b3299c2
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 23, 2019
0cfe594
changed scaleComp to return vector for scale values
benhe2011 Aug 23, 2019
b3dc940
fixed small syntax error
benhe2011 Aug 23, 2019
43a1621
added explanation for upsampling test cases in test_operator
benhe2011 Aug 24, 2019
f8be93a
added explanation for test_operator upsampling test
benhe2011 Aug 26, 2019
b461a39
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 26, 2019
d3eff5a
fixed syntax error in test_operator
benhe2011 Aug 26, 2019
d687d4d
scaleComp tshape debugging test
benhe2011 Aug 27, 2019
53a79f0
fix syntax error
benhe2011 Aug 27, 2019
413f0f3
more scaleComp tuple debug testing
benhe2011 Aug 27, 2019
00d226b
changed scaleComp in upsampling implementation to return TShape
benhe2011 Aug 27, 2019
9c033e3
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 27, 2019
c233e2d
some space formatting, changed ::fmod operation in upsampling.cc, add…
benhe2011 Aug 27, 2019
435b330
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 27, 2019
7cf5f6c
modified test cases in test_operator.py for upsampling
benhe2011 Aug 28, 2019
ea0ca72
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Aug 28, 2019
baf3357
modified test cases for upsampling
benhe2011 Aug 28, 2019
f4df285
Merge branch 'master' of git://github.com/apache/incubator-mxnet
benhe2011 Sep 3, 2019
34 changes: 21 additions & 13 deletions 3rdparty/mshadow/mshadow/extension/spatial_upsampling_nearest.h
@@ -11,7 +11,7 @@
namespace mshadow {
namespace expr {

/*! \brief nearest neighboor upsampling
/*! \brief nearest neighbor upsampling
* out(x, y) = in(int(x / scale_x), int(y / scale_y))
* \tparam SrcExp source expression
* \tparam DType data type
@@ -24,47 +24,55 @@ struct UpSamplingNearestExp :
/*! \brief source operand */
const SrcExp &src_;
/*! \brief up sampling scale */
index_t scale_;
index_t scale_h_;
index_t scale_w_;

/*! \brief constructor */
UpSamplingNearestExp(const SrcExp &src, index_t scale)
: src_(src), scale_(scale) {
UpSamplingNearestExp(const SrcExp &src, index_t scale_h, index_t scale_w)
: src_(src), scale_h_(scale_h), scale_w_(scale_w) {
this->shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
this->shape_[srcdim - 2] *= scale_;
this->shape_[srcdim - 1] *= scale_;
this->shape_[srcdim - 2] *= scale_h;
this->shape_[srcdim - 1] *= scale_w;
}
};


template<typename SrcExp, typename DType, int etype>
inline UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
upsampling_nearest(const Exp<SrcExp, DType, etype> &src, index_t scale) {
upsampling_nearest(const Exp<SrcExp, DType, etype> &src, index_t scale_h, index_t scale_w) {
TypeCheckPass<ExpInfo<SrcExp>::kDim >= 2>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale);
return UpSamplingNearestExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>(src.self(), scale_h, scale_w);
}

template<typename SrcExp, typename DType, int srcdim>
struct Plan<UpSamplingNearestExp<SrcExp, DType, srcdim>, DType> {
public:
explicit Plan(const UpSamplingNearestExp<SrcExp, DType, srcdim> &e)
: src_(MakePlan(e.src_)),
scale_(e.scale_),
scale_h_(e.scale_h_),
scale_w_(e.scale_w_),
new_height_(e.shape_[srcdim - 2]),
src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / e.scale_)) {}
new_width_(e.shape_[srcdim - 1]),
src_height_(static_cast<index_t>(e.shape_[srcdim - 2] / e.scale_h_)),
src_width_(static_cast<index_t>(e.shape_[srcdim - 1] / e.scale_w_)) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
const index_t x = j;
const index_t y = i % new_height_;
const index_t c = i / new_height_;
const index_t h = static_cast<index_t>(y / scale_);
const index_t w = static_cast<index_t>(x / scale_);
const index_t h = static_cast<index_t>(y / scale_h_);
const index_t w = static_cast<index_t>(x / scale_w_);
return src_.Eval(c * src_height_ + h, w);
}

private:
Plan<SrcExp, DType> src_;
const index_t scale_;
const index_t scale_h_;
const index_t scale_w_;
const index_t new_height_;
const index_t new_width_;
const index_t src_height_;
const index_t src_width_;
};
} // namespace expr
} // namespace mshadow
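For intuition, here is a minimal NumPy sketch of what the rewritten Plan::Eval computes once the scale is split: a nearest-neighbor lookup with independent height and width factors. The function name and the NCHW layout below are illustrative assumptions, not code from this patch.

import numpy as np

def upsample_nearest(x, scale_h, scale_w):
    # x has shape (N, C, H, W); mirrors out(y, x) = in(y // scale_h, x // scale_w).
    n, c, h, w = x.shape
    out = np.empty((n, c, h * scale_h, w * scale_w), dtype=x.dtype)
    for y in range(h * scale_h):
        for xx in range(w * scale_w):
            out[:, :, y, xx] = x[:, :, y // scale_h, xx // scale_w]
    return out

a = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
print(upsample_nearest(a, 2, 3).shape)  # (1, 1, 4, 6)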
@@ -590,7 +590,7 @@
scale (range 1 4)
num-shape (range 1 4)
base (range 1 4)]
(let [shape-vecs (mapv (fn [i] [1 3 (* base root-scale (int (Math/pow scale (- (dec num-shape) i))))
(* base root-scale (int (Math/pow scale (- (dec num-shape) i))))])
(range 0 num-shape))]
(check-nearest-up-sampling-with-shape {:shape-vecs shape-vecs :scale scale :root-scale root-scale})))))
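The Clojure loop above generates a pyramid of shapes whose spatial sizes shrink by a factor of scale at each step. For concreteness, a hypothetical Python rendering of the same shape computation:

base, root_scale, scale, num_shape = 1, 2, 2, 3
shape_vecs = [(1, 3,
               base * root_scale * scale ** (num_shape - 1 - i),
               base * root_scale * scale ** (num_shape - 1 - i))
              for i in range(num_shape)]
print(shape_vecs)  # [(1, 3, 8, 8), (1, 3, 4, 4), (1, 3, 2, 2)]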
76 changes: 52 additions & 24 deletions src/operator/nn/upsampling-inl.h
@@ -48,16 +48,18 @@ enum UpSamplingMultiInputMode {kConcat, kSum};
} // namespace up_enum

struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
TShape scale;
int num_filter;
int sample_type;
int num_args;
int multi_input_mode;
uint64_t workspace;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.set_range(1, 1000)
.describe("Up sampling scale");
.set_default(TShape())
Contributor: Is this going to break backward compatibility?

Author: I believe it is still backwards compatible, as it supports a scale as either an integer or a tuple.

Contributor: Can you use mxnet::TShape to clarify that you're not using nnvm::TShape, or vice versa? Can you please explain how this will work for passing in a scalar argument? I don't see a TShape constructor taking a single scalar value.

Author (benhe2011, Aug 27, 2019): Yep, made changes. Not sure if there's any historical or significant reason for the 1000 range limit, though.

.describe("Up sampling scale. Integer or tuple of integers. "
"Different scale per dimension is allowed only for "
"nearest neighbor upsampling.");
DMLC_DECLARE_FIELD(num_filter)
.describe("Input filter. Only used by bilinear sample_type."
"Since bilinear upsampling uses deconvolution, num_filters "
@@ -84,6 +86,21 @@ struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
}
}; // struct UpSamplingParam

inline std::vector<int> scaleComp(const UpSamplingParam &param) {
std::vector<int> scaleArr{ 1, 1 };
if (param.scale.ndim() == 1) {
scaleArr[0] = param.scale[0];
scaleArr[1] = param.scale[0];
} else if (param.scale.ndim() == 2) {
scaleArr[0] = param.scale[0];
scaleArr[1] = param.scale[1];
} else if (param.scale.ndim() == 4) {
scaleArr[0] = param.scale[2];
scaleArr[1] = param.scale[3];
}
return scaleArr;
}
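A minimal Python mirror of scaleComp, to make the normalization explicit; this is a sketch of the C++ helper above, not an MXNet API:

def scale_comp(scale):
    # Normalize a scale tuple to a (scale_h, scale_w) pair.
    # A 1-tuple is broadcast to both axes; a 4-tuple is read as an
    # NCHW shape, so H and W are taken from positions 2 and 3.
    if len(scale) == 1:
        return scale[0], scale[0]
    if len(scale) == 2:
        return scale[0], scale[1]
    if len(scale) == 4:
        return scale[2], scale[3]
    return 1, 1  # fallback, matching the { 1, 1 } initialization

print(scale_comp((2,)), scale_comp((2, 3)), scale_comp((1, 1, 2, 3)))
# (2, 2) (2, 3) (2, 3)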

template<typename xpu, typename DType>
void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
const std::vector<TBlob> &in_data,
@@ -103,21 +120,27 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam &param,
for (int i = 0; i < param.num_args; ++i) {
Tensor<xpu, 4, DType> data = in_data[i].get<xpu, 4, DType>(s);
int end = begin + data.size(1);
int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2);
// 3rd dimension of TBlob
int scale_h = out_data[up_enum::kOut].size(2)/in_data[i].size(2);
// 4th dimension of TBlob
int scale_w = out_data[up_enum::kOut].size(3)/in_data[i].size(3);
if (param.multi_input_mode == up_enum::kSum) {
if (i == 0) {
Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale));
Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale_h, scale_w));
} else {
out += upsampling_nearest(data, scale);
out += upsampling_nearest(data, scale_h, scale_w);
}
} else {
Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale));
Assign(slice<1>(out, begin, end),
req[up_enum::kOut],
upsampling_nearest(data, scale_h, scale_w));
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> data = in_data[up_enum::kData].get<xpu, 4, DType>(s);
Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale));
std::vector<int> scale_hw = scaleComp(param);
Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale_hw[0], scale_hw[1]));
}
}

@@ -136,44 +159,49 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam &param,
Tensor<xpu, 4, DType> input_grad = in_grad[i].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
int end = begin + input_grad.size(1);
int scale = grad.size(2)/in_shape[0];
int scale_h = grad.size(2)/in_shape[0];
int scale_w = grad.size(3)/in_shape[1];
if (param.multi_input_mode == up_enum::kSum) {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(grad,
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
} else {
Assign(input_grad, req[i],
pool<mshadow::red::sum>(slice<1>(grad, begin, end),
in_shape,
scale,
scale,
scale,
scale));
scale_h,
scale_w,
scale_h,
scale_w));
}
begin = end;
}
} else {
Tensor<xpu, 4, DType> input_grad = in_grad[up_enum::kData].get<xpu, 4, DType>(s);
mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]);
std::vector<int> scale_hw = scaleComp(param);
Assign(input_grad, req[up_enum::kData],
pool<mshadow::red::sum>(grad,
in_shape,
param.scale,
param.scale,
param.scale,
param.scale));
scale_hw[0],
scale_hw[1],
scale_hw[0],
scale_hw[1]));
}
}
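The backward pass reduces each scale_h x scale_w block of the output gradient back to one input cell by sum pooling with kernel equal to stride. A NumPy sketch of that reduction (illustrative; NCHW layout assumed):

import numpy as np

def upsample_nearest_grad(grad, scale_h, scale_w):
    # Each input cell was copied to a scale_h x scale_w output block,
    # so its gradient is the sum over that block.
    n, c, oh, ow = grad.shape
    h, w = oh // scale_h, ow // scale_w
    return grad.reshape(n, c, h, scale_h, w, scale_w).sum(axis=(3, 5))

g = np.ones((1, 1, 4, 6), dtype=np.float32)
print(upsample_nearest_grad(g, 2, 3))  # every entry is 2 * 3 = 6.0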

static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) {
DeconvolutionParam p = DeconvolutionParam();
int kernel = 2 * param.scale - param.scale % 2;
int stride = param.scale;
int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
std::vector<int> scale_hw = scaleComp(param);
CHECK_EQ(scale_hw[0], scale_hw[1]) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
int kernel = static_cast<int>(2.0 * scale_hw[0] - ::fmod(scale_hw[0], 2));
int stride = scale_hw[0];
int pad = static_cast<int>(ceil((scale_hw[0] - 1) / 2.));
p.workspace = param.workspace;
p.num_group = param.num_filter;
p.num_filter = param.num_filter;
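GetDeconvolutionParam derives the deconvolution geometry from the (now required to be square) scale via kernel = 2s - (s mod 2), stride = s, pad = ceil((s - 1) / 2). A quick arithmetic check of those formulas for a few scales:

import math

for s in (1, 2, 3, 4):
    kernel = 2 * s - s % 2
    stride = s
    pad = math.ceil((s - 1) / 2)
    print(f"scale={s}: kernel={kernel}, stride={stride}, pad={pad}")
# scale=2: kernel=4, stride=2, pad=1; scale=3: kernel=5, stride=3, pad=1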
13 changes: 9 additions & 4 deletions src/operator/nn/upsampling.cc
@@ -37,13 +37,16 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
CHECK_GE(in_shape->size(), 1U);
const mxnet::TShape &dshape = (*in_shape)[0];
mxnet::TShape oshape = dshape;
std::vector<int> scale_hw = scaleComp(param_);
int scale_h = scale_hw[0];
int scale_w = scale_hw[1];
if (param_.sample_type == up_enum::kNearest) {
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
oshape[1] = 0;
for (auto& shape : *in_shape) {
CHECK_EQ(shape.ndim(), 4U) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale;
int oh = dshape[2]*scale_h, ow = dshape[3]*scale_w;
CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \
"does not divide output height of " << oh;
CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \
@@ -58,17 +61,19 @@
}
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
CHECK_EQ(scale_h, scale_w) <<
"UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling";
CHECK_EQ(dshape.ndim(), 4U) << \
"UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
if (!shape_is_known(dshape)) return false;
int kernel = 2 * param_.scale - param_.scale % 2;
int kernel = static_cast<int>(2.0 * scale_h - ::fmod(scale_h, 2));
SHAPE_ASSIGN_CHECK(*in_shape,
up_enum::kWeight,
mshadow::Shape4(dshape[1], 1, kernel, kernel));
oshape = dshape;
}
oshape[2] = dshape[2] * param_.scale;
oshape[3] = dshape[3] * param_.scale;
oshape[2] = dshape[2] * scale_h;
oshape[3] = dshape[3] * scale_w;
out_shape->clear();
out_shape->push_back(oshape);
return true;
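UpSamplingShape scales height and width independently and requires every input's spatial size to divide the output size derived from the first input. A small Python mirror of those checks (hypothetical helper, not MXNet code):

def check_nearest_shapes(shapes, scale_h, scale_w):
    # shapes are NCHW; the output size comes from the first input.
    oh, ow = shapes[0][2] * scale_h, shapes[0][3] * scale_w
    for _, _, h, w in shapes:
        assert oh % h == 0, "input height must divide output height"
        assert ow % w == 0, "input width must divide output width"
    return oh, ow

print(check_nearest_shapes([(1, 3, 4, 4), (1, 3, 8, 8)], 2, 2))  # (8, 8)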
39 changes: 30 additions & 9 deletions tests/python/unittest/test_operator.py
@@ -1559,15 +1559,23 @@ def check_deconvolution_forward_with_bias(shape=(1, 16, 5, 5), num_filter=32, nu
def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)}
arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)}

up = mx.sym.UpSampling(*[mx.sym.Variable('arg_%d'%i) for i in range(len(shapes))], sample_type='nearest', scale=root_scale)
exe = up.bind(default_context(), args=arr, args_grad=arr_grad)
exe.forward(is_train=True)
exe.backward(exe.outputs)
for k in range(len(shapes)):
name = 'arg_%d'%k
assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4)

out = arr_grad[name].asnumpy()
root_h = root_w = 1
if type(root_scale) is int:
root_h = root_w = root_scale
elif len(root_scale) == 1:
root_h = root_w = root_scale[0]
elif len(root_scale) >= 2:
root_h = root_scale[0]
root_w = root_scale[1]
exp = arr[name].asnumpy()*root_h*root_w*scale**(2*k)
assert_allclose(exp, out, rtol=1e-4)

def check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter):
def _init_bilinear(arr, f):
@@ -1597,16 +1605,29 @@ def _init_bilinear(arr, f):
assert out.shape == data_shape[:2] + target_shape


"""
The test cases include integer, tuple,
and empty tuple scales on up to 3 shapes
at once with the shapes having various sizes
for their heights and widths
"""
@with_seed()
def test_nearest_upsampling():
for root_scale in [1,2,3]:
for scale in [1,2,3]:
for num_shape in [1,2,3]:
for base in [1,2,3]:
shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)]
for root_scale in [1, 2, (3,), (2, 3), (3, 2), (1, 1), (5, 1), (2, 2), ()]:
for scale in [1, 2, 3]:
for num_shape in [1, 2, 3]:
for base in [1, 2, 3]:
root_h = root_w = 1
if type(root_scale) is int:
root_h = root_w = root_scale
elif len(root_scale) == 1:
root_h = root_w = root_scale[0]
elif len(root_scale) >= 2:
root_h = root_scale[0]
root_w = root_scale[1]
shapes = [(1, 3, base*root_h*scale**(num_shape-1-i), base*root_w*scale**(num_shape-1-i)) for i in range(num_shape)]
check_nearest_upsampling_with_shape(shapes, scale, root_scale)


@with_seed()
def test_bilinear_upsampling():
rootscale = [2,3]
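Assuming the behavior this PR proposes, both the legacy integer scale and the new tuple form would be accepted; a hypothetical end-to-end usage sketch (the exact keyword handling is an assumption, not confirmed by the diff):

import mxnet as mx

data = mx.nd.random.uniform(shape=(1, 3, 4, 4))
# Legacy integer scale (backward compatible): 4x4 -> 8x8
out_int = mx.nd.UpSampling(data, scale=2, sample_type='nearest', num_args=1)
# New (scale_h, scale_w) tuple, nearest mode only: 4x4 -> 8x12
out_hw = mx.nd.UpSampling(data, scale=(2, 3), sample_type='nearest', num_args=1)
print(out_int.shape, out_hw.shape)  # (1, 3, 8, 8) (1, 3, 8, 12)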