Skip to content

Commit

Permalink
[MXNET-100] Fix buggy type inference in Correlation (apache#10135)
Browse files Browse the repository at this point in the history
* fix buggy type inference in correlation

* add test for mutual type inference
  • Loading branch information
haojin2 authored and ashokei committed Mar 27, 2018
1 parent dc5add5 commit c7e164f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/operator/correlation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) overr
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
int dtype = (*in_type)[0];
type_assign(&(*in_type)[1], dtype);
type_assign(&(*out_type)[0], dtype);
type_assign(&(*out_type)[1], dtype);
type_assign(&(*out_type)[2], dtype);
type_assign(&dtype, (*in_type)[1]);
type_assign(&dtype, (*out_type)[0]);
type_assign(&dtype, (*out_type)[1]);
type_assign(&dtype, (*out_type)[2]);

TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
TYPE_ASSIGN_CHECK(*in_type, 1, dtype);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,7 +2198,29 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2

@with_seed()
def test_correlation():
def test_infer_type(dtype):
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
corr = mx.sym.Correlation(data1=a, data2=b)
arg_type1, out_type1, _ = corr.infer_type(a=dtype)
if arg_type1[0] != np.dtype(dtype) and arg_type1[1] != np.dtype(dtype) and out_type1[0] != np.dtype(dtype):
msg = npt.npt.build_err_msg([a, b],
err_msg="Inferred type from a is not as expected, "
"Expected :%s %s %s, Got: %s %s %s"
% (dtype, dtype, dtype, arg_type1[0], arg_type1[1], out_type1[0]),
names=['a', 'b'])
raise AssertionError(msg)
arg_type2, out_type2, _ = corr.infer_type(b=dtype)
if arg_type2[0] != np.dtype(dtype) and arg_type2[1] != np.dtype(dtype) and out_type2[0] != np.dtype(dtype):
msg = npt.npt.build_err_msg([a, b],
err_msg="Inferred type from b is not as expected, "
"Expected :%s %s %s, Got: %s %s %s"
% (dtype, dtype, dtype, arg_type1[0], arg_type1[1], out_type1[0]),
names=['a', 'b'])
raise AssertionError(msg)

for dtype in ['float16', 'float32', 'float64']:
test_infer_type(dtype)
unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True, dtype = dtype)
Expand Down

0 comments on commit c7e164f

Please sign in to comment.