From 959288d0c554f89710b0b8bbc989802fb5dfdc6f Mon Sep 17 00:00:00 2001 From: Malte Londschien <61679398+mlondschien@users.noreply.github.com> Date: Fri, 8 Nov 2024 08:42:24 +0100 Subject: [PATCH] Use dtype dependent precision (#844) Co-authored-by: Martin Stancsics Co-authored-by: Luca Bittarello <15511539+lbittarello@users.noreply.github.com> --- CHANGELOG.rst | 3 ++- src/glum/_cd_fast.pyx | 5 +++-- src/glum/_glm.py | 2 +- src/glum/_solvers.py | 11 +++++++++-- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b57d73cb..4e1813ed 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,9 +16,10 @@ Changelog - New fitted attributes ``col_means_`` and ``col_stds_`` for classes :class:`~glum.GeneralizedLinearRegressor` and :class:`~glum.GeneralizedLinearRegressorCV`. - :class:`~glum.GeneralizedLinearRegressor` now prints more informative logs when fitting with ``alpha_search=True`` and ``verbose=True``. -**Bug fix: +**Bug fixes:** - Fixed a bug where :meth:`glum.GeneralizedLinearRegressor.fit` would raise a ``dtype`` mismatch error if fit with ``alpha_search=True``. +- Use data type (``float64`` or ``float32``) dependent precision in solvers. 
3.0.2 - 2024-06-25 ------------------ diff --git a/src/glum/_cd_fast.pyx b/src/glum/_cd_fast.pyx index bd3910cb..7dea8f1c 100644 --- a/src/glum/_cd_fast.pyx +++ b/src/glum/_cd_fast.pyx @@ -117,7 +117,8 @@ def enet_coordinate_descent_gram(int[::1] active_set, bint has_lower_bounds, floating[:] lower_bounds, bint has_upper_bounds, - floating[:] upper_bounds): + floating[:] upper_bounds, + floating eps): """Cython version of the coordinate descent algorithm for Elastic-Net regression We minimize @@ -162,7 +163,7 @@ def enet_coordinate_descent_gram(int[::1] active_set, else: P1_ii = P1[ii - intercept] - if Q[active_set_ii, active_set_ii] == 0.0: + if Q[active_set_ii, active_set_ii] <= eps: continue w_ii = w[ii] # Store previous value diff --git a/src/glum/_glm.py b/src/glum/_glm.py index 1e594694..aadccaf9 100644 --- a/src/glum/_glm.py +++ b/src/glum/_glm.py @@ -452,7 +452,7 @@ def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray: If values are zeros, return val. """ - zeros = np.where(np.abs(arr) < 1e-7) + zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps)) with np.errstate(divide="ignore"): one_over = 1 / arr one_over[zeros] = val diff --git a/src/glum/_solvers.py b/src/glum/_solvers.py index 853e8fd0..43b3de50 100644 --- a/src/glum/_solvers.py +++ b/src/glum/_solvers.py @@ -70,6 +70,7 @@ def _cd_solver(state, data, active_hessian): data._lower_bounds, data.has_upper_bounds, data._upper_bounds, + np.finfo(state.coef.dtype).eps * 16, ) return new_coef - state.coef, n_cycles @@ -546,6 +547,9 @@ def __init__(self, coef, data): self.line_search_runtime = None self.quadratic_update_runtime = None + # used in the line-search Armijo stopping criterion + self.large_number = 1e30 if data.X.dtype == np.float32 else 1e43 + def _record_iteration(self): self.n_iter += 1 @@ -759,7 +763,9 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray): """ # line search parameters (beta, sigma) = (0.5, 0.0001) - eps = 16 * 
np.finfo(state.obj_val.dtype).eps # type: ignore + # Use np.finfo(state.coef.dtype).eps instead of np.finfo(state.obj_val), as + # state.obj_val is np.float64, even if the data is np.float32. + eps = 16 * np.finfo(state.coef.dtype).eps # type: ignore # line search by sequence beta^k, k=0, 1, .. # F(w + lambda d) - F(w) <= lambda * bound @@ -792,7 +798,8 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray): ) # 1. Check Armijo / sufficient decrease condition. loss_improvement = obj_val_wd - state.obj_val - if mu_wd.max() < 1e43 and loss_improvement <= factor * bound: + + if mu_wd.max() < state.large_number and loss_improvement <= factor * bound: break # 2. Deal with relative loss differences around machine precision. tiny_loss = np.abs(state.obj_val * eps) # type: ignore