
Fix batch size #354

Merged: 4 commits merged into develop on Aug 19, 2024
Conversation

rg936672 (Contributor) commented on Aug 2, 2024

PR Type

  • Bugfix
  • Tests

Description

Closes #265. In batch mode, the GP's internal training data has to be reset at every training iteration to avoid it complaining that "You must train on the training inputs!".
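For illustration, here is a minimal sketch of the pattern this PR adopts (not the Vanguard implementation itself; the model, likelihood, and batching scheme below are assumed for the example):

import torch
import gpytorch

def fit_in_batches(model, likelihood, train_x, train_y, batch_size, n_iters, lr=0.1):
    """Fit an exact GP on random mini-batches of the training data."""
    model.train()
    likelihood.train()
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    for _ in range(n_iters):
        idx = torch.randperm(train_x.shape[0])[:batch_size]
        batch_x, batch_y = train_x[idx], train_y[idx]
        # Point the exact GP's internal training data at the current batch; without
        # this, calling model(batch_x) in train mode raises the
        # "You must train on the training inputs!" error.
        model.set_train_data(batch_x, batch_y, strict=False)
        optimiser.zero_grad()
        loss = -mll(model(batch_x), batch_y)
        loss.backward()
        optimiser.step()
    return model

Here `model` is any gpytorch.models.ExactGP and `likelihood` a gpytorch.likelihoods.GaussianLikelihood.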

How Has This Been Tested?

The base usage unit test has been refactored to also test batch mode.

Does this PR introduce a breaking change?

No.

Screenshots

N/A

Checklist before requesting a review

  • I have made sure that my PR is not a duplicate.
  • My code follows the style guidelines of this project.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have performed a self-review of my code.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have added tests that prove my fix is effective or that my feature works.
  • New and existing unit tests pass locally with my changes.
  • Any dependent changes have been merged and published in downstream modules.

Reviewed change:

# update the training data to the current train_x and train_y, to avoid "You must train on the
# training data!"
self._gp.set_train_data(train_x, train_y.squeeze(dim=-1), strict=False)
# TODO: consider using get_fantasy_model() instead if possible, when using ExactGP?
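For reference, the alternative mentioned in the TODO is GPyTorch's ExactGP.get_fantasy_model, which returns a new model conditioned on additional observations via an efficient low-rank update rather than replacing the training data outright; whether that fits this mini-batch use case is exactly the open question below. A hypothetical call (batch_x and batch_y are illustrative names, not part of this PR) might look like:

# Hypothetical sketch only: augments the model's existing training set with the
# new observations instead of resetting it to the current batch.
updated_gp = self._gp.get_fantasy_model(batch_x, batch_y.squeeze(dim=-1))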
Member:

Do we have a sense for the effect on performance of not using get_fantasy_model()? Is it more of a nice to have or is it possible batching could slow down the fitting as it is currently implemented?

Contributor Author (rg936672):

Batching is still a good deal faster than running on the whole dataset - on my machine, the unbatched test_base_usage integration test takes 2.87s for a 500-training-point dataset, but the test that does the same thing with batch_size=100 takes only 0.48s.

Strangely, setting batch_size=500 makes the test take only 1.85s, and setting batch_size=499 makes the test fail, with 55% of outputs falling outside the confidence interval! This should be investigated, but might have to be its own separate issue.

I don't know enough about gpytorch to comment on whether or by how much the performance might be improved by get_fantasy_model(), though.

Member:

Interesting, that definitely sounds like it would need its own issue. I would be happy to merge after that has been set up!


db091756 (Member) left a comment:

One question. Also, is there anything that can be said about #266 given the changes here?

rg936672 (Contributor Author) replied:

> One question. Also, is there anything that can be said about #266 given the changes here?

I don't think I can say anything new on #266, as I don't have easy access to a GPU machine to test on, and .to(device) is only really relevant in the context of non-CPU computing.

rg936672 requested a review from db091756 on August 19, 2024 at 12:54
db091756 merged commit 6e85e7e into develop on Aug 19, 2024
10 checks passed
db091756 deleted the fix/batch-size branch on August 19, 2024 at 14:03
Linked issue closed by this PR: Batch mode fails unless batch_size is the length of the training data (#265)