Skip to content

Commit

Permalink
Use dtype dependent precision (#844)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Stancsics <martin.stancsics@quantco.com>
Co-authored-by: Luca Bittarello <15511539+lbittarello@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 8, 2024
1 parent df6f372 commit 959288d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ Changelog
- New fitted attributes ``col_means_`` and ``col_stds_`` for classes :class:`~glum.GeneralizedLinearRegressor` and :class:`~glum.GeneralizedLinearRegressorCV`.
- :class:`~glum.GeneralizedLinearRegressor` now prints more informative logs when fitting with ``alpha_search=True`` and ``verbose=True``.

**Bug fix:
**Bug fixes:**

- Fixed a bug where :meth:`glum.GeneralizedLinearRegressor.fit` would raise a ``dtype`` mismatch error if fit with ``alpha_search=True``.
- Use data type (``float64`` or ``float32``) dependent precision in solvers.

3.0.2 - 2024-06-25
------------------
Expand Down
5 changes: 3 additions & 2 deletions src/glum/_cd_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def enet_coordinate_descent_gram(int[::1] active_set,
bint has_lower_bounds,
floating[:] lower_bounds,
bint has_upper_bounds,
floating[:] upper_bounds):
floating[:] upper_bounds,
floating eps):
"""Cython version of the coordinate descent algorithm
for Elastic-Net regression
We minimize
Expand Down Expand Up @@ -162,7 +163,7 @@ def enet_coordinate_descent_gram(int[::1] active_set,
else:
P1_ii = P1[ii - intercept]

if Q[active_set_ii, active_set_ii] == 0.0:
if Q[active_set_ii, active_set_ii] <= eps:
continue

w_ii = w[ii] # Store previous value
Expand Down
2 changes: 1 addition & 1 deletion src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray:
If values are zeros, return val.
"""
zeros = np.where(np.abs(arr) < 1e-7)
zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps))
with np.errstate(divide="ignore"):
one_over = 1 / arr
one_over[zeros] = val
Expand Down
11 changes: 9 additions & 2 deletions src/glum/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _cd_solver(state, data, active_hessian):
data._lower_bounds,
data.has_upper_bounds,
data._upper_bounds,
np.finfo(state.coef.dtype).eps * 16,
)
return new_coef - state.coef, n_cycles

Expand Down Expand Up @@ -546,6 +547,9 @@ def __init__(self, coef, data):
self.line_search_runtime = None
self.quadratic_update_runtime = None

# used in the line-search Armijo stopping criterion
self.large_number = 1e30 if data.X.dtype == np.float32 else 1e43

def _record_iteration(self):
self.n_iter += 1

Expand Down Expand Up @@ -759,7 +763,9 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
"""
# line search parameters
(beta, sigma) = (0.5, 0.0001)
eps = 16 * np.finfo(state.obj_val.dtype).eps # type: ignore
# Use np.finfo(state.coef.dtype).eps instead of np.finfo(state.obj_val).eps, as
# state.obj_val is np.float64, even if the data is np.float32.
eps = 16 * np.finfo(state.coef.dtype).eps # type: ignore

# line search by sequence beta^k, k=0, 1, ..
# F(w + lambda d) - F(w) <= lambda * bound
Expand Down Expand Up @@ -792,7 +798,8 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
)
# 1. Check Armijo / sufficient decrease condition.
loss_improvement = obj_val_wd - state.obj_val
if mu_wd.max() < 1e43 and loss_improvement <= factor * bound:

if mu_wd.max() < state.large_number and loss_improvement <= factor * bound:
break
# 2. Deal with relative loss differences around machine precision.
tiny_loss = np.abs(state.obj_val * eps) # type: ignore
Expand Down

0 comments on commit 959288d

Please sign in to comment.