Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates to AxClient.fit_model #2580

Closed
wants to merge 1 commit into from

Conversation

saitcakmak
Copy link
Contributor

Summary:
I came across AxClient.fit_model while working on something else. It appeared to be a method that was split off from get_model_predictions, with leftover error messages that did not fit well with it as a standalone non-protected method.

This diff clears up the error messages and simplifies the definition of AxClient.fit_model.

Differential Revision: D59778614

@facebook-github-bot facebook-github-bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Jul 15, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778614

@codecov-commenter
Copy link

codecov-commenter commented Jul 15, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.22%. Comparing base (6859016) to head (b56f7b5).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2580   +/-   ##
=======================================
  Coverage   95.22%   95.22%           
=======================================
  Files         489      489           
  Lines       47647    47642    -5     
=======================================
- Hits        45372    45369    -3     
+ Misses       2275     2273    -2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -1197,9 +1198,10 @@ def get_model_predictions(
"argument to `get_model_predictions`."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few lines above, it says that AxClient only supports one-arm trials. Is that saying that AxClient doesn't support BatchTrials? I didn't think that was the case. See e.g. here:

"Selecting a GenerationStrategy when using BatchTrials is in beta. "

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know about AxClient more generally but this seems to be a limitation of the get_model_predictions method. I'll clarify this in the docstring.

# can be performed without the need to call get_next_trial(), we update the
# model with all attached data. Note that this method keeps track of previously
# seen trials and will update the model if there is newly attached data.
self.generation_strategy._maybe_transition_to_next_node()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why might we need to transition to the next node?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could go either way. With this call, we fit the model that would be used to generate the next trial. Without it, we would just re-train (with updated data) the model that was used to fit the last trial. I'll add a note about this to the docstring.

Copy link
Contributor

@esantorella esantorella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may just reflect my ignorance of this area of the codebase, but could you clear up support for BatchTrials and add a comment explaining the use of ax_client.generation_strategy._maybe_transition_to_next_node()?

Summary:
Pull Request resolved: facebook#2580

I came across `AxClient.fit_model` while working on something else. It appeared to be a method that was split off from `get_model_predictions`, with leftover error messages that did not fit well with it as a standalone non-protected method.

This diff clears up the error messages and simplifies the definition of `AxClient.fit_model`.

Differential Revision: D59778614
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778614

Copy link
Contributor

@esantorella esantorella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 56c9b46.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed Do not delete this pull request or issue due to inactivity. fb-exported Merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants