Skip to content

Commit

Permalink
Fix 堆栈溢出 (stack overflow) of case10: paddle.unique (#49981)
Browse files Browse the repository at this point in the history
* add axis check in UniqueRawInferMeta

* add unittest for negative axis

* simplify check for unique
  • Loading branch information
RedContritio authored Jan 31, 2023
1 parent 82edc65 commit dbfdefa
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4648,13 +4648,22 @@ void UniqueRawInferMeta(const MetaTensor& x,
if (axis_value < 0) {
axis_value += x.dims().size();
}

PADDLE_ENFORCE_LT(
axis_value,
x.dims().size(),
phi::errors::InvalidArgument("The axis(%d) should be less than "
"the dimension size(%d) of x.",
axis_value,
x.dims().size()));
PADDLE_ENFORCE_GE(
axis_value,
0,
phi::errors::InvalidArgument(
"The axis(%d) + rank(x) (%d) should be greater than or equal to 0.",
axis_value,
-x.dims().size()));

auto out_dims = x.dims();
out_dims[axis_value] = -1;
out->set_dims(out_dims);
Expand Down
26 changes: 26 additions & 0 deletions python/paddle/fluid/tests/unittests/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,32 @@ def init_config(self):
}


class TestUniqueOpAxisNeg(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.random((6, 1, 8)).astype('float64')}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
return_inverse=True,
return_counts=True,
axis=-1,
)
self.attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
"return_index": True,
"return_inverse": True,
"return_counts": True,
"axis": [-1],
"is_sorted": True,
}
self.outputs = {
'Out': unique,
'Indices': indices,
"Index": inverse,
"Counts": counts,
}


class TestUniqueOpAxis1(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')}
Expand Down

0 comments on commit dbfdefa

Please sign in to comment.