Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add default losses to KerasClassifier and KerasRegressor #208

Open
wants to merge 52 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
4be5d0c
Add default loss to KerasClassifier
stsievert Feb 27, 2021
dccd92b
update message/tests
stsievert Feb 27, 2021
1d80b57
black
stsievert Feb 27, 2021
1f45285
isort
stsievert Feb 27, 2021
9358d6e
better test
stsievert Feb 28, 2021
6cc112e
Add default loss for KerasRegressor
stsievert Feb 28, 2021
80618bf
black
stsievert Feb 28, 2021
60b2404
catch binary cross entropy
stsievert Mar 1, 2021
dcb0823
black
stsievert Mar 1, 2021
0faadd9
Clean type hints in __init__
stsievert Mar 1, 2021
0449481
isort
stsievert Mar 1, 2021
ed4c1f5
change KerasRegressor.__init__
stsievert Mar 1, 2021
c58ec74
tests run
stsievert Mar 1, 2021
e73710d
MAINT
stsievert Mar 2, 2021
4e7e09f
add right loss back
stsievert Mar 2, 2021
2e830ff
Try removing binary_crossentropy check
stsievert Mar 2, 2021
e1ea339
black
stsievert Mar 2, 2021
36e6499
remove annoying 'needs linting'
stsievert Mar 2, 2021
8310834
Uncomment error
stsievert Mar 2, 2021
6ee8b50
warn for user compiled models
stsievert Mar 2, 2021
b88b74e
Union[T, None] → Optional[T]
stsievert Mar 2, 2021
3a3a536
DOC: complete docstring
stsievert Mar 2, 2021
9808cf2
DOC: complete docstring
stsievert Mar 2, 2021
d0147ac
fix loss?
stsievert Mar 2, 2021
7243995
Revert "fix loss?"
stsievert Mar 2, 2021
9735974
Warn if compiled with wrong loss
stsievert Mar 2, 2021
8cc0474
draft at loss=None
stsievert Mar 2, 2021
b0229c5
v2
stsievert Mar 2, 2021
dccfc5e
black
stsievert Mar 2, 2021
d2e23cb
Tell mypy to use type hints
stsievert Mar 2, 2021
9c3af6b
loss=None to docs
stsievert Mar 2, 2021
5121131
whoops on type hints
stsievert Mar 2, 2021
0dfa526
Update tests/test_simple_usage.py
stsievert Mar 2, 2021
7b379d7
Update scikeras/wrappers.py
stsievert Mar 2, 2021
e4338fc
Update tests/test_simple_usage.py
stsievert Mar 2, 2021
ca69f2e
Add classifier default loss test
stsievert Mar 2, 2021
2ac57e0
Merge branch 'clf-default-loss' of https://github.com/stsievert/scike…
stsievert Mar 2, 2021
0de8abe
Better warning for (really rare) use case
stsievert Mar 2, 2021
0cf7610
update warning with more recommendations
stsievert Mar 2, 2021
d4c3eea
TST: all classification losses
stsievert Mar 4, 2021
7fab517
Re-initialize
stsievert Mar 4, 2021
59e7012
tmp
stsievert Mar 4, 2021
0386e4e
loss_name is None
stsievert Mar 4, 2021
3a46538
black
stsievert Mar 4, 2021
8f2b00b
Remove backticks
stsievert Mar 4, 2021
e80338b
typing for utils/*_name
stsievert Mar 4, 2021
7e23480
raise
stsievert Mar 4, 2021
94df48a
API: loss_name / metric_name return None
stsievert Mar 4, 2021
35e1a6c
try cce
stsievert Mar 4, 2021
f092b7a
catch loss is not None
stsievert Mar 4, 2021
5af8b4c
tmp
stsievert Mar 4, 2021
7b38bc8
typo
stsievert Mar 4, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ jobs:
- uses: pre-commit/action@v2.0.0

TestStable:
needs: Linting
name: Ubuntu / Python ${{ matrix.python-version }} / TensorFlow Stable / Scikit-Learn Stable
runs-on: ubuntu-latest
strategy:
Expand Down Expand Up @@ -55,7 +54,6 @@ jobs:
- uses: codecov/codecov-action@v1

TestDev:
needs: Linting
name: Ubuntu / Python ${{ matrix.python-version }} / TensorFlow Nightly / Scikit-Learn Nightly
runs-on: ubuntu-latest
strategy:
Expand Down Expand Up @@ -98,7 +96,6 @@ jobs:
- uses: codecov/codecov-action@v1

TestOldest:
needs: Linting
name: Ubuntu / Python ${{ matrix.python-version }} / TF ${{ matrix.tf-version }} / Scikit-Learn ${{ matrix.sklearn-version }}
runs-on: ubuntu-latest
strategy:
Expand Down Expand Up @@ -135,7 +132,6 @@ jobs:
- uses: codecov/codecov-action@v1

TestOSs:
needs: Linting
name: ${{ matrix.os }} / Python ${{ matrix.python-version }} / TF Stable / Scikit-Learn Stable
runs-on: ${{ matrix.os }}-latest
strategy:
Expand Down
1 change: 0 additions & 1 deletion docs/source/notebooks/DataTransformers.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ from sklearn.metrics import accuracy_score


class MultiOutputClassifier(KerasClassifier):

@property
def target_encoder(self):
return MultiOutputTransformer()
Expand Down
18 changes: 18 additions & 0 deletions scikeras/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Callable, List, Type, Union

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

from tensorflow.keras.callbacks import Callback as TF_Callback
from tensorflow.keras.losses import Loss as TF_Loss
from tensorflow.keras.metrics import Metric as TF_Metric
from tensorflow.keras.optimizers import Optimizer as TF_Optimizer


Model = Union[Callable[..., keras.Model], keras.Model]
RandomState = Union[int, np.random.RandomState]
Optimizer = Union[str, TF_Optimizer, Type[TF_Optimizer]]
Loss = Union[str, TF_Loss, Type[TF_Loss], Callable]
Metrics = Union[List[Union[str, TF_Metric, Type[TF_Metric], Callable]]]
Callbacks = Union[List[Union[TF_Callback, Type[TF_Callback]]]]
adriangb marked this conversation as resolved.
Show resolved Hide resolved
adriangb marked this conversation as resolved.
Show resolved Hide resolved
Empty file added scikeras/py.typed
Empty file.
36 changes: 12 additions & 24 deletions scikeras/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _camel2snake(s: str) -> str:
return "".join(["_" + c.lower() if c.isupper() else c for c in s]).lstrip("_")


def loss_name(loss: Union[str, Loss, Callable]) -> str:
def loss_name(loss: Union[str, Loss, Callable]) -> Union[None, str]:
"""Retrieves a loss's full name (eg: "mean_squared_error").

Parameters
Expand All @@ -26,7 +26,8 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
Returns
-------
str
String name of the loss.
String name of the loss (e.g., "mse") or None if the
input is not a string, Metric or callable.

Notes
-----
Expand All @@ -43,26 +44,20 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
'binary_crossentropy'
>>> loss_name(losses.binary_crossentropy)
'binary_crossentropy'

Raises
------
TypeError
If loss is not a string, tf.keras.losses.Loss instance or a callable.
>>> loss_name({"out1": "mse", "out2": "mae"})
None
"""
if isclass(loss):
loss = loss()
if not (isinstance(loss, (str, Loss)) or callable(loss)):
raise TypeError(
"``loss`` must be a string, a function, an instance of ``tf.keras.losses.Loss``"
" or a type inheriting from ``tf.keras.losses.Loss``"
)
return None
fn_or_cls = keras_loss_get(loss)
if isinstance(fn_or_cls, Loss):
return _camel2snake(fn_or_cls.__class__.__name__)
return fn_or_cls.__name__


def metric_name(metric: Union[str, Metric, Callable]) -> str:
def metric_name(metric: Union[str, Metric, Callable]) -> Union[None, str]:
"""Retrieves a metric's full name (eg: "mean_squared_error").

Parameters
Expand All @@ -74,7 +69,8 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
Returns
-------
str
Full name for Keras metric. Ex: "mean_squared_error".
Full name for Keras metric (e.g., "mean_squared_error") or None if the
input is not a string, Metric or callable.

Notes
-----
Expand All @@ -91,21 +87,13 @@ def metric_name(metric: Union[str, Metric, Callable]) -> str:
'BinaryCrossentropy'
>>> metric_name(metrics.binary_crossentropy)
'binary_crossentropy'

Raises
------
TypeError
If metric is not a string, a tf.keras.metrics.Metric instance a class
inheriting from tf.keras.metrics.Metric.
>>> metric_name({"out1": "bce", "out2": "hinge"})
None
"""
if isclass(metric):
metric = metric() # get_metric accepts instances, not classes
if not (isinstance(metric, (str, Metric)) or callable(metric)):
raise TypeError(
"``metric`` must be a string, a function, an instance of"
" ``tf.keras.metrics.Metric`` or a type inheriting from"
" ``tf.keras.metrics.Metric``"
)
return None
fn_or_cls = keras_metric_get(metric)
if isinstance(fn_or_cls, Metric):
return _camel2snake(fn_or_cls.__class__.__name__)
Expand Down
Empty file added scikeras/utils/py.typed
Empty file.
Loading