Skip to content

Commit

Permalink
Case7:paddle.distribution.Beta:fix beta(true stack) (#51847)
Browse files Browse the repository at this point in the history
  • Loading branch information
Difers authored Mar 22, 2023
1 parent 65c6d2e commit 32baca9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
9 changes: 9 additions & 0 deletions paddle/phi/kernels/cpu/stack_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ void StackKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);

auto x_dims = x[0]->dims();
for (int i = 0; i < x_dims.size(); i++) {
PADDLE_ENFORCE_GT(x_dims[i],
0,
phi::errors::InvalidArgument(
"The dims of Input(X) should be greater than 0"));
}

int n = static_cast<int>(x.size());
T* y_data = dev_ctx.template Alloc<T>(out);
std::vector<const T*> x_datas(n);
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/funcs/stack_and_unstack.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ void StackRawKernel(const Context& ctx,

// Split x dim from axis to matrix of shape [x_row, x_col], and the output
// tensor's shape is [x_row, out_col].
int64_t x_row = 1;
int64_t x_row = 1, x_row_bak = 1;
for (int i = 0; i < axis; ++i) {
x_row *= x[0]->dims()[i];
}
int64_t x_col = x[0]->numel() / x_row;
x_row_bak = x_row == 0 ? 1 : x_row;
int64_t x_col = x[0]->numel() / x_row_bak;
int64_t out_col = x_col * num;

if (out->numel() < std::numeric_limits<int32_t>::max()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def test_sample_shape(self):
== case.get('expect')
)

def test_errors(self):
with self.assertRaises(ValueError):
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32')
paddle.distribution.Beta(alpha=x, beta=x)


if __name__ == '__main__':
unittest.main()

0 comments on commit 32baca9

Please sign in to comment.