Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-100] Fix buggy type inference in Correlation #10135

Merged
merged 2 commits into from
Mar 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/operator/correlation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) overr
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
int dtype = (*in_type)[0];
type_assign(&(*in_type)[1], dtype);
type_assign(&(*out_type)[0], dtype);
type_assign(&(*out_type)[1], dtype);
type_assign(&(*out_type)[2], dtype);
type_assign(&dtype, (*in_type)[1]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just call ElemwiseType<2, 3>(attrs, in_attrs, out_attrs) to achieve the same purpose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not have attrs argument, I guess what we have here is the only way to go.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it just trying to set all of the inputs to the first non -1 output type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, all input and output types to the inferred type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry your statement there is kind of vague. can you please explain what you’re trying to do here? pretend i am 5 years old...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, you can just pass a blank NodAttrs if necessary

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, you can just pass a blank NodAttrs if necessary

type_assign(&dtype, (*out_type)[0]);
type_assign(&dtype, (*out_type)[1]);
type_assign(&dtype, (*out_type)[2]);

TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
TYPE_ASSIGN_CHECK(*in_type, 1, dtype);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,29 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2

@with_seed()
def test_correlation():
def test_infer_type(dtype):
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
corr = mx.sym.Correlation(data1=a, data2=b)
arg_type1, out_type1, _ = corr.infer_type(a=dtype)
if arg_type1[0] != np.dtype(dtype) and arg_type1[1] != np.dtype(dtype) and out_type1[0] != np.dtype(dtype):
msg = npt.npt.build_err_msg([a, b],
err_msg="Inferred type from a is not as expected, "
"Expected :%s %s %s, Got: %s %s %s"
% (dtype, dtype, dtype, arg_type1[0], arg_type1[1], out_type1[0]),
names=['a', 'b'])
raise AssertionError(msg)
arg_type2, out_type2, _ = corr.infer_type(b=dtype)
if arg_type2[0] != np.dtype(dtype) and arg_type2[1] != np.dtype(dtype) and out_type2[0] != np.dtype(dtype):
msg = npt.npt.build_err_msg([a, b],
err_msg="Inferred type from b is not as expected, "
"Expected :%s %s %s, Got: %s %s %s"
% (dtype, dtype, dtype, arg_type1[0], arg_type1[1], out_type1[0]),
names=['a', 'b'])
raise AssertionError(msg)

for dtype in ['float16', 'float32', 'float64']:
test_infer_type(dtype)
unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True, dtype = dtype)
Expand Down