Skip to content

Commit

Permalink
Fix safety-bug of functional.linear (#34696)
Browse files Browse the repository at this point in the history
* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear

* Fix safety-bug of functional.linear
  • Loading branch information
Ray2020BD authored Aug 12, 2021
1 parent 589d13c commit 0e28c8b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/operators/math/blas_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,12 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB) const {
PADDLE_ENFORCE_NOT_NULL(
A, platform::errors::InvalidArgument("Pointer A should not be null."));
PADDLE_ENFORCE_NOT_NULL(
B, platform::errors::InvalidArgument("Pointer B should not be null."));
PADDLE_ENFORCE_NOT_NULL(
C, platform::errors::InvalidArgument("Pointer C should not be null."));
#ifdef PADDLE_WITH_MKLML
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
Expand Down
9 changes: 9 additions & 0 deletions python/paddle/fluid/tests/unittests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def test_error(self, place=paddle.CPUPlace()):
np.testing.assert_array_almost_equal(res_f, res_nn)
np.testing.assert_array_almost_equal(res_nn, res_np)

def test_error_dummy_input(self, place=paddle.CPUPlace()):
with self.assertRaises(ValueError):
x_arr = np.array([], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(x_arr, (0, 4, 4, 4)), dtype='float32')
weight = paddle.zeros([4, 4, 4], dtype='float32')
bias = paddle.to_tensor([], dtype='float32')
paddle.nn.functional.linear(x, weight, bias=bias)


if __name__ == "__main__":
unittest.main()

0 comments on commit 0e28c8b

Please sign in to comment.