Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to pass model_cv_kwargs to ModelSpec.cross_validate #2566

Closed
wants to merge 4 commits into from

Commits on Jul 9, 2024

  1. Refactor BestModelSelector to operate on ModelSpecs (facebook#2557)

    Summary:
    Pull Request resolved: facebook#2557
    
    `BestModelSelector` was previously limited to selecting the best out of a given dictionary of CV diagnostics that were computed in `ModelSpec.cross_validate`. This setup limited extensibility, since any change would require updating `ModelSpec` code to the diagnostics that are computed.
    
    This diff refactors `BestModelSelector` to directly operate on the `ModelSpecs`. This new modular design will let each `BestModelSelector` class compute the necessary diagnostics internally, without locking us up to any pre-specified list.
    
    Other minor changes:
    - Removed `CallableEnum` and subclasses and replaced these with a single `ReductionCriterion` enum.
    - Split off `BestModelSelector` into a separate file to avoid circular imports.
    
    Differential Revision: D59249657
    saitcakmak authored and facebook-github-bot committed Jul 9, 2024
    Configuration menu
    Copy the full SHA
    bf55e2b View commit details
    Browse the repository at this point in the history
  2. Add storage support for BestModelSelector (facebook#2561)

    Summary:
    Pull Request resolved: facebook#2561
    
    Context: BestModelSelector is used to pick the model to gen from in GNode with multiple ModelSpecs.
    
    This diff adds storage support for the BestModelSelector, making it possible to serialize & deserialize GenerationNodes with multiple ModelSpecs & a BestModelSelector.
    
    Differential Revision: D59355084
    
    Reviewed By: Balandat
    saitcakmak authored and facebook-github-bot committed Jul 9, 2024
    Configuration menu
    Copy the full SHA
    a2bcceb View commit details
    Browse the repository at this point in the history
  3. Improve typing of *_kwargs fields in ModelSpec (facebook#2565)

    Summary:
    Pull Request resolved: facebook#2565
    
    These were typed as `Optional[Dict[str, Any]]` but immediately made into `Dict[str, Any]` in `__post_init__`. The internal usage also included a bunch of `self.*_kwargs or {}`, presumably to make Pyre happy with the optional field.
    
    This diff updates the type-hints to `Dict[str, Any]` and keeps these arguments as optional during initialization using `field(default_factory=dict)` as the default (which will assign an empty dict as the default). The change is also extended to `GenerationStep`.
    
    I kept the `__post_init__` method in place to keep backwards compatibility with any previous usage of `None`.
    
    Differential Revision: D59403745
    
    Reviewed By: Balandat
    saitcakmak authored and facebook-github-bot committed Jul 9, 2024
    Configuration menu
    Copy the full SHA
    59739e3 View commit details
    Browse the repository at this point in the history
  4. Add an option to pass model_cv_kwargs to ModelSpec.cross_validate (fa…

    …cebook#2566)
    
    Summary:
    Pull Request resolved: facebook#2566
    
    This allows us to customize the kwargs passed to `cross_validate` during call time (in addition to `ModelSpec.model_cv_kwargs`, which is specified at initialization). This can be used in `SingleDiagnosticBestModelSelector` to customize how the CV diagnostics are computed.
    
    To ensure we don't return cached results that were computed using different kwargs, we also store the last `cv_kwargs` used and re-compute CV if they changed.
    
    Differential Revision: D59406177
    saitcakmak authored and facebook-github-bot committed Jul 9, 2024
    Configuration menu
    Copy the full SHA
    8e86ddc View commit details
    Browse the repository at this point in the history