Skip to content

Commit

Permalink
Fix implicit type conversion error and enable FP32 training
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Nov 28, 2020
1 parent 0bf55d4 commit bda29cc
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/cuml/experimental/linear_model/lars.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,16 @@ class Lars(Base, RegressorMixin):
""" Remove mean and scale each feature column. """
x_mean = cp.zeros(self.n_cols, dtype=self.dtype)
x_scale = cp.ones(self.n_cols, dtype=self.dtype)
y_mean = 0.0
y_mean = self.dtype.type(0.0)
X = cp.asarray(X_m)
y = cp.asarray(y_m)
if self.fit_intercept:
y_mean = cp.mean(y)
y = y - y_mean
if self.normalize:
x_mean = cp.mean(X, axis=0)
x_scale = cp.sqrt(cp.var(X, axis=0) * X.shape[0])
x_scale = cp.sqrt(cp.var(X, axis=0) *
self.dtype.type(X.shape[0]))
x_scale[x_scale==0] = 1
X = (X - x_mean) / x_scale
return X, y, x_mean, x_scale, y_mean
Expand Down Expand Up @@ -204,10 +205,12 @@ class Lars(Base, RegressorMixin):
def _fit_cpp(self, X, y, Gram, x_scale):
""" Fit lars model using cpp solver"""
cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
X_m, _, _, _ = input_to_cuml_array(X, order='F')
X_m, _, _, _ = input_to_cuml_array(X, check_dtype=self.dtype,
order='F')
cdef uintptr_t X_ptr = X_m.ptr
cdef int n_rows = X.shape[0]
cdef uintptr_t y_ptr = input_to_cuml_array(y).array.ptr
cdef uintptr_t y_ptr = \
input_to_cuml_array(y, check_dtype=self.dtype).array.ptr
cdef int max_iter = self.n_nonzero_coefs
self.beta_ = CumlArray.zeros(max_iter, dtype=self.dtype)
cdef uintptr_t beta_ptr = self.beta_.ptr
Expand Down Expand Up @@ -269,8 +272,7 @@ class Lars(Base, RegressorMixin):
self._set_output_type(X)

X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array(
X, check_dtype=[np.float32, np.float64], order='F',
convert_to_dtype=np.float64)
X, check_dtype=[np.float32, np.float64], order='F')

conv_dtype = self.dtype if convert_dtype else None
y_m, _, _, _ = input_to_cuml_array(
Expand Down

0 comments on commit bda29cc

Please sign in to comment.