Skip to content

Commit

Permalink
fix bugs of paddle.linalg.lstsq (#44280)
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang authored Jul 13, 2022
1 parent 7cf72a3 commit 2af286a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
23 changes: 16 additions & 7 deletions paddle/fluid/operators/lstsq_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
true,
batch_count,
m,
n,
nrhs,
k,
x_data,
x_stride,
Expand Down Expand Up @@ -137,14 +137,17 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {

// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);

phi::TriangularSolveKernel<T, Context>(
phi_dev_ctx, trans_r, new_y, true, true, false, solution);
phi_dev_ctx, res_r, new_y, true, true, false, solution);

// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx,
batch_count,
n,
n,
m,
min_mn,
x_data,
n,
Expand Down Expand Up @@ -183,15 +186,18 @@ void BatchedOrmqr<platform::CUDADeviceContext, float>(
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());

for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
float* other_working_ptr = &other[i * other_stride];

handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());

// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnSormqr(handle,
Expand Down Expand Up @@ -249,15 +255,18 @@ void BatchedOrmqr<platform::CUDADeviceContext, double>(
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());

for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
double* other_working_ptr = &other[i * other_stride];

handle = dev_ctx.cusolver_dn_handle();
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());

// compute ormgr
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDormqr(handle,
Expand Down
26 changes: 23 additions & 3 deletions python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ def init_config(self):
self._input_shape_2 = (5, 8)


class LinalgLstsqTestCase3(LinalgLstsqTestCase):
    """Batched lstsq case: float64 inputs solved with the "gels" driver.

    Exercises a batch of 10 over-determined systems (7x3 coefficient
    matrices against 7x6 right-hand sides).
    """

    def init_config(self):
        # Driver and numeric settings first, then the batched problem shapes.
        self.driver = "gels"
        self.dtype = 'float64'
        self.rcond = 1e-15
        self._input_shape_1 = (10, 7, 3)
        self._input_shape_2 = (10, 7, 6)


class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):

def init_config(self):
Expand All @@ -192,7 +202,17 @@ def init_config(self):
self.rcond = None
self.driver = "gels"
self._input_shape_1 = (10, 5)
self._input_shape_2 = (10, 2)
self._input_shape_2 = (10, 8)


class LinalgLstsqTestCaseGelsFloat64(LinalgLstsqTestCase):
    """Batched wide system (2x8 matrices, 2x15 right-hand sides) solved
    with the "gels" driver in float64 precision."""

    def init_config(self):
        # NOTE(review): the original body set dtype='float32', contradicting
        # the class name and the sibling *Float64 cases; aligned to float64.
        self.dtype = 'float64'
        self.rcond = None
        self.driver = "gels"
        self._input_shape_1 = (3, 2, 8)
        self._input_shape_2 = (3, 2, 15)


class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
Expand Down Expand Up @@ -230,9 +250,9 @@ class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelss"
self.driver = "gels"
self._input_shape_1 = (10, 8, 6)
self._input_shape_2 = (10, 8, 2)
self._input_shape_2 = (10, 8, 10)


class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
Expand Down

0 comments on commit 2af286a

Please sign in to comment.