ENH: programmatic validations and error handling #209

Open
adriangb opened this issue Feb 28, 2021 · 0 comments

adriangb (Owner) commented Feb 28, 2021

From #208 (comment):

I think giving users better errors and validating their inputs like you are doing here can be a very valuable part of SciKeras, but currently it is done in an ad-hoc manner via _check_model_compatibility, etc. I think if we add more of these types of things, it would be nice to have an organized interface for it.

It would be good to organize these checks. We can split them into two categories:

  • Checks before fit is called. These include checking whether the model is compiled, whether it has a loss, whether the number of outputs matches the target, etc.
  • Error handling when fit or predict raises: translating cryptic Keras/TF errors into user-friendly errors with suggestions for fixing them.

I envision something like this, inspired by Pydantic:

wrappers.py

from __future__ import annotations  # lets the annotations below refer to BaseWrapper
from typing import Any, Callable, Dict, Generator, Literal


class BaseWrapper(...):

    def __get_pre_fit_validators__(self) -> Generator[Callable[[BaseWrapper, Dict[str, Any]], None], None, None]:
        yield validate_compiled  # accepts self & self.model_.fit kwargs
        yield validate_has_loss
        yield validate_outputs_match_target
        ...

    def __handle_keras_exceptions__(self) -> Generator[Callable[[Exception, BaseWrapper, Dict[str, Any], Literal["fit", "predict"]], None], None, None]:
        yield catch_some_cryptic_error  # can pass by just returning without raising
        yield catch_some_other_cryptic_error

    def _fit_keras_model(self, ...):
        fit_kwargs = ...
        # every validator raises a descriptive error if its check fails
        for validator in self.__get_pre_fit_validators__():
            validator(self, fit_kwargs)
        try:
            self.model_.fit(..., **fit_kwargs)
        except Exception as e:
            # each handler either raises a friendlier error or returns to pass
            for handler in self.__handle_keras_exceptions__():
                handler(e, self, fit_kwargs, "fit")
            raise  # re-raise the original error if no handler raised one
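A minimal, self-contained sketch of how the validator and handler hooks could interact at runtime, assuming a duck-typed model object with a `fit` method. `FakeModel`, `ToyWrapper`, `validate_compiled` and `catch_bad_epochs` here are hypothetical stand-ins, not SciKeras or Keras APIs:

```python
from typing import Any, Callable, Dict, Iterator, Literal

Validator = Callable[["ToyWrapper", Dict[str, Any]], None]
Handler = Callable[[Exception, "ToyWrapper", Dict[str, Any], Literal["fit", "predict"]], None]


class FakeModel:
    """Hypothetical stand-in for a Keras model; only tracks compile state."""

    def __init__(self, compiled: bool) -> None:
        self.compiled = compiled

    def fit(self, **kwargs: Any) -> None:
        # Simulates a cryptic low-level failure for a bad argument.
        if kwargs.get("epochs", 1) < 1:
            raise ValueError("cryptic internal error: epochs")


def validate_compiled(wrapper: "ToyWrapper", fit_kwargs: Dict[str, Any]) -> None:
    # Pre-fit check: fail fast before any training starts.
    if not wrapper.model_.compiled:
        raise ValueError("The model must be compiled before calling fit.")


def catch_bad_epochs(
    exc: Exception,
    wrapper: "ToyWrapper",
    fit_kwargs: Dict[str, Any],
    method: Literal["fit", "predict"],
) -> None:
    # Translate one known cryptic error; return silently for anything else.
    if method == "fit" and "epochs" in str(exc):
        raise ValueError(f"epochs must be >= 1, got {fit_kwargs.get('epochs')}") from exc


class ToyWrapper:
    def __init__(self, model: FakeModel) -> None:
        self.model_ = model

    def __get_pre_fit_validators__(self) -> Iterator[Validator]:
        yield validate_compiled

    def __handle_keras_exceptions__(self) -> Iterator[Handler]:
        yield catch_bad_epochs

    def fit(self, **fit_kwargs: Any) -> "ToyWrapper":
        for validator in self.__get_pre_fit_validators__():
            validator(self, fit_kwargs)
        try:
            self.model_.fit(**fit_kwargs)
        except Exception as e:
            for handler in self.__handle_keras_exceptions__():
                handler(e, self, fit_kwargs, "fit")
            raise  # no handler recognized the error: re-raise it unchanged
        return self
```

Validators fail fast before any training starts, while each handler only has to recognize the errors it knows about: returning without raising falls through to the next handler and, eventually, to the bare `raise` that re-raises the original exception.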

And then we can split out SciKeras' default checks into their own modules:

_utils.validators.py

def validate_compiled(wrapper: BaseWrapper, fit_kwargs: Dict[str, Any]):
    ...
...
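As an illustration of what a module-level validator could look like, here is a sketch of a hypothetical `validate_validation_split` that inspects only the fit kwargs. It assumes the kwargs are a plain dict that may contain Keras' `validation_split` argument, and it assumes the accepted range is [0, 1) (Keras documents it as a fraction of the training data):

```python
from typing import Any, Dict


def validate_validation_split(wrapper: Any, fit_kwargs: Dict[str, Any]) -> None:
    """Hypothetical pre-fit validator for the ``validation_split`` fit argument.

    Assumes the valid range is [0, 1): it is the fraction of the training data
    held out for validation, so 1.0 would leave nothing to train on.
    """
    split = fit_kwargs.get("validation_split", 0.0)  # absent means no split
    if not 0.0 <= split < 1.0:
        raise ValueError(
            f"validation_split must be in [0, 1), got {split}; "
            "it is the fraction of the training data held out for validation."
        )
```

Because the validator never touches the model, it can be unit-tested without building or compiling anything.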

_utils.error_handlers.py

def catch_some_cryptic_error(exception: Exception, wrapper: BaseWrapper, fit_kwargs: Dict[str, Any], method: Literal["fit", "predict"]):
    ...
...
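A sketch of one such handler, `catch_shape_mismatch` (hypothetical name). The regular expression assumes TensorFlow reports the problem as a message like `Shapes (None, 1) and (None, 3) are incompatible`; the exact wording varies between TF versions, so both the pattern and the suggested fix are assumptions:

```python
import re
from typing import Any, Dict, Literal

# Assumed message format; the text TensorFlow actually emits may differ.
_SHAPE_MISMATCH = re.compile(r"Shapes \((.*?)\) and \((.*?)\) are incompatible")


def catch_shape_mismatch(
    exception: Exception,
    wrapper: Any,
    fit_kwargs: Dict[str, Any],
    method: Literal["fit", "predict"],  # which call raised the exception
) -> None:
    """Hypothetical handler: re-raise a shape mismatch with a suggested fix."""
    match = _SHAPE_MISMATCH.search(str(exception))
    if match is None:
        return  # not ours: let the next handler (or the final bare raise) deal with it
    raise ValueError(
        f"During {method}, Keras reported incompatible shapes ({match.group(1)}) "
        f"and ({match.group(2)}), which usually means the targets do not match the "
        "model's output layer. If your targets are integer class labels, try "
        "loss='sparse_categorical_crossentropy' or one-hot encode them."
    ) from exception
```

Raising with `from exception` keeps the original Keras/TF traceback attached as `__cause__`, so the friendlier message does not hide the underlying error.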

Adding new checks (e.g., checks in KerasClassifier that don't make sense in BaseWrapper) also becomes more explicit:

wrappers.py

class KerasClassifier(BaseWrapper):

    def __get_pre_fit_validators__(self) -> Generator[Callable[[BaseWrapper, Dict[str, Any]], None], None, None]:
        yield from super().__get_pre_fit_validators__()
        yield some_check_that_assumes_classification
        ...
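The `yield from super()` pattern can be checked in isolation with a runnable sketch. `BaseWrapperSketch`, `ClassifierSketch` and `validate_at_least_two_classes` are hypothetical names, and the two base validators are empty placeholders:

```python
from typing import Any, Callable, Dict, Iterator

Validator = Callable[[Any, Dict[str, Any]], None]


def validate_compiled(wrapper: Any, fit_kwargs: Dict[str, Any]) -> None:
    """Placeholder for the base compile check."""


def validate_has_loss(wrapper: Any, fit_kwargs: Dict[str, Any]) -> None:
    """Placeholder for the base loss check."""


def validate_at_least_two_classes(wrapper: Any, fit_kwargs: Dict[str, Any]) -> None:
    # Hypothetical classification-only check: y needs two or more distinct labels.
    if len(set(fit_kwargs.get("y", []))) < 2:
        raise ValueError("A classifier needs at least two distinct classes in y.")


class BaseWrapperSketch:
    def __get_pre_fit_validators__(self) -> Iterator[Validator]:
        yield validate_compiled
        yield validate_has_loss


class ClassifierSketch(BaseWrapperSketch):
    def __get_pre_fit_validators__(self) -> Iterator[Validator]:
        yield from super().__get_pre_fit_validators__()  # base checks run first
        yield validate_at_least_two_classes
```

Subclasses inherit every base check for free and only append what is specific to them; a subclass that needs to drop a base check would instead filter the parent's generator.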