Elastic Net Regularizer #49
Conversation
dask_glm/algorithms.py
Outdated
@@ -319,7 +319,7 @@ def bfgs(X, y, max_iter=500, tol=1e-14, family=Logistic):
     return beta


-def proximal_grad(X, y, regularizer=L1, lamduh=0.1, family=Logistic,
+def proximal_grad(X, y, regularizer=L1(), lamduh=0.1, family=Logistic,
Thoughts about changing this to 'l1' and looking it up in _regularizers?
I agree, it should be changed. We could also move the string-lookup logic to a function in regularizers.py.
@TomAugspurger What do you think of my latest commit? I use the base class to retrieve child classes via string.
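For illustration, a minimal sketch of what a string lookup on the base class could look like (a standalone example, not the exact code from the commit; the public name attribute and the get classmethod shown here are assumptions):

class Regularizer(object):
    """Sketch of a base class that resolves regularizers from strings."""
    name = '_base'

    @classmethod
    def get(cls, obj):
        # already a Regularizer instance: pass it through unchanged
        if isinstance(obj, cls):
            return obj
        # otherwise treat it as a name and search the registered subclasses
        for child in cls.__subclasses__():
            if child.name == obj:
                return child()
        raise KeyError('Unknown regularizer: {!r}'.format(obj))


class L2(Regularizer):
    name = 'l2'


assert isinstance(Regularizer.get('l2'), L2)
assert isinstance(Regularizer.get(L2()), L2)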
dask_glm/regularizers.py
Outdated
@@ -9,6 +9,7 @@ class Regularizer(object):
     Defines the set of methods required to create a new regularization object. This includes
     the regularization functions itself and it's gradient, hessian, and proximal operator.
     """
+    _name = '_base'
I think maybe this should be just name. I think @moody-marlin mentioned that users might define their own regularizers, so they shouldn't need to override private attributes.
I agree, made the change.
              max_iter=250, abstol=1e-4, reltol=1e-2, family=Logistic):

     pointwise_loss = family.pointwise_loss
     pointwise_gradient = family.pointwise_gradient
-    regularizer = _regularizers.get(regularizer, regularizer)  # string
+    regularizer = Regularizer.get(regularizer)
I'll be sad to see this line go, but 👍
dask_glm/regularizers.py
Outdated
    def proximal_operator(beta, t):
        return 1 / (1 + t) * beta
    Defines the set of methods required to create a new regularization object. This includes
    the regularization functions itself and it's gradient, hessian, and proximal operator.
it's -> its
dask_glm/regularizers.py
Outdated
    def f(beta):
        return (beta**2).sum()
    def proximal_operator(self, beta, t):
        """Proximal operator function for non-differentiable regularization function."""
Proximal operator for regularization function; the regularizer doesn't need to be non-differentiable, it just can be.
dask_glm/regularizers.py
Outdated
            return grad(beta, *args) + lam * L1.gradient(beta)
        return wrapped
    def hessian(self, beta):
        raise ValueError('l1 norm is not twice differentiable!')
We should probably just fix this with this PR; this should be similar to the gradient:
if np.any(np.isclose(beta, 0)):
    raise ValueError('l1 norm is not twice differentiable at 0!')
else:
    return np.zeros((beta.shape[0], beta.shape[0]))
The l1 regularizer is a straight line everywhere except at 0, where there's a kink.
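To make the suggestion concrete, here is a self-contained sketch of an l1 regularizer with the fixed hessian (illustrative only, not the actual dask_glm class; the class name and method signatures are assumptions based on the diff above):

import numpy as np

class L1Sketch(object):
    """Illustrative l1 regularizer with the hessian fix suggested above."""

    def f(self, beta):
        # l1 penalty: sum of absolute values
        return np.abs(beta).sum()

    def gradient(self, beta):
        # sign(beta) away from 0; not differentiable at 0
        if np.any(np.isclose(beta, 0)):
            raise ValueError('l1 norm is not differentiable at 0!')
        return np.sign(beta)

    def hessian(self, beta):
        # linear away from 0, so the hessian is the zero matrix there;
        # at 0 there is a kink and the hessian is undefined
        if np.any(np.isclose(beta, 0)):
            raise ValueError('l1 norm is not twice differentiable at 0!')
        return np.zeros((beta.shape[0], beta.shape[0]))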
Then do we want to switch the elastic net hessian to include the l1 side of the weighted sum? It won't have a numerical effect, except that it will raise when there are errors.
Ah yes, definitely - I missed that.
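For reference, folding the l1 side into the elastic net hessian might look roughly like this (a sketch reusing the L1Sketch class above and assuming the l2 penalty is (beta**2).sum() / 2 so its hessian is the identity; the class names and the weight attribute are illustrative, not necessarily the exact code in the PR):

import numpy as np

class L2Sketch(object):
    """Illustrative l2 regularizer, f(beta) = (beta**2).sum() / 2."""
    def hessian(self, beta):
        # hessian of (1/2) * ||beta||^2 is the identity matrix
        return np.eye(beta.shape[0])

class ElasticNetSketch(object):
    """Illustrative elastic net hessian including the l1 side."""
    def __init__(self, weight=0.5):
        self.weight = weight      # fraction of l1 in the penalty (assumed)
        self.l1 = L1Sketch()      # l1 sketch class from above
        self.l2 = L2Sketch()

    def hessian(self, beta):
        # weighted sum of the component hessians; the l1 term is zero away
        # from 0 so it adds nothing numerically, but it raises when any
        # coefficient sits at 0, matching the l1 fix above
        return (self.weight * self.l1.hessian(beta)
                + (1 - self.weight) * self.l2.hessian(beta))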
@@ -89,7 +89,7 @@ def test_basic_unreg_descent(func, kwargs, N, nchunks, family):
 @pytest.mark.parametrize('nchunks', [1, 10])
 @pytest.mark.parametrize('family', [Logistic, Normal, Poisson])
 @pytest.mark.parametrize('lam', [0.01, 1.2, 4.05])
-@pytest.mark.parametrize('reg', [L1, L2])
+@pytest.mark.parametrize('reg', [r() for r in Regularizer.__subclasses__()])
Nice.
@moody-marlin ok, implemented your changes, looks good to go.
@TomAugspurger do you have any other comments / suggestions? LGTM.
👍 on a quick skim.
This PR adds elastic net regularization, the weighted sum of l1 and l2, as an option. Couldn't find a proximal operator for it, so derived one with much help from @moody-marlin in an included notebook. It also adds an abstract base class for Regularizer and fixes the formulas for l2.
Resolves #48
Resolves #47
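For readers without the notebook handy, here is a hedged sketch of the closed form the elastic net proximal operator reduces to, assuming the penalty is f(x) = weight * ||x||_1 + (1 - weight)/2 * ||x||_2^2 (the weight parameter and this exact form are assumptions; the derivation in the PR's notebook is authoritative):

import numpy as np

def elastic_net_prox(beta, t, weight=0.5):
    """Sketch: proximal operator of weight*||.||_1 + (1-weight)/2*||.||_2^2."""
    # l1 part: elementwise soft-thresholding by weight * t
    shrunk = np.sign(beta) * np.maximum(np.abs(beta) - weight * t, 0)
    # l2 part: uniform shrinkage by 1 / (1 + (1 - weight) * t)
    return shrunk / (1 + (1 - weight) * t)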