Balanced batch sampler+base dataset #753
Conversation
Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`
implement get_metadata for datasets; add tests for max_atoms and balanced partitioning
```python
assert (
    len(self.paths) == 1
), f"{type(self)} does not support a list of src paths."
self.path = self.paths[0]
```
Why is a list being fed in here in the first place?
This is a good point. I don't know; I assumed it was because some datasets might take multiple files, potentially multiple LMDB paths. But I don't think this is currently used, so I will take it out.
On second pass, it looks like we do use multiple paths for ASE datasets:
https://github.com/FAIR-Chem/fairchem/blob/main/tests/core/datasets/test_ase_datasets.py#L89
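For concreteness, a sketch of the two config shapes involved (the path values here are made up; only the single-vs-list distinction comes from the assertion above and the linked test):

```python
# LMDB-style datasets expect a single src path, hence the assertion above.
lmdb_config = {"src": "data/s2ef/train/"}  # illustrative path

# ASE datasets, per the linked test, may take a list of src paths.
ase_config = {"src": ["data/part0.db", "data/part1.db"]}  # illustrative paths
```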
```diff
@@ -214,7 +212,7 @@ def get_metadata(self, num_samples: int = 100):
     }


-class SinglePointLmdbDataset(LmdbDataset[BaseData]):
+class SinglePointLmdbDataset(LmdbDataset):
```
Given the dataset changes, this may be a good time to deprecate these...
Warning: you'll break a few tests. If you break way too many, you can revert and we can do this as part of BE.
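One possible shape for such a deprecation, as a sketch only (not the PR's actual approach): warn on construction and otherwise delegate to LmdbDataset.

```python
import warnings

# Assumes LmdbDataset from the surrounding module is in scope.
class SinglePointLmdbDataset(LmdbDataset):
    def __init__(self, *args, **kwargs):
        # Emit a DeprecationWarning pointing callers at the replacement class.
        warnings.warn(
            "SinglePointLmdbDataset is deprecated; use LmdbDataset instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```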
```diff
@@ -114,9 +116,6 @@ def __init__(self, config, transform=None) -> None:
         if self.train_on_oc20_total_energies:
             with open(config["oc20_ref"], "rb") as fp:
                 self.oc20_ref = pickle.load(fp)
-        if self.config.get("lin_ref", False):
-            coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"]
-            self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False)
         self.subsample = aii(self.config.get("subsample", False), bool)
```
This functionality is technically now supported by your base dataset, right? If so, we can remove it.
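If the lin_ref handling does move into the base class, a sketch could look like this (class shape and config key mirror the deleted lines above; everything else is assumed):

```python
import numpy as np
import torch

class BaseDataset:
    def __init__(self, config):
        self.config = config
        self.lin_ref = None
        # Load linear-reference coefficients if configured, as the deleted
        # OC22 code did.
        if self.config.get("lin_ref", False):
            coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"]
            self.lin_ref = torch.nn.Parameter(
                torch.tensor(coeff), requires_grad=False
            )
```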
```diff
@@ -455,7 +456,9 @@ def load_model(self) -> None:
         self.logger.log_summary({"num_params": self.model.num_params})

         if distutils.initialized() and not self.config["noddp"]:
```
How does this PR behave when `--noddp` is defined? It would be good to test that functionality.
Added!
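A rough sketch of the kind of check being asked for (the helper names here are hypothetical placeholders, not the PR's actual test):

```python
from torch.nn.parallel import DistributedDataParallel

def test_trainer_respects_noddp():
    config = make_minimal_train_config()  # hypothetical helper
    config["noddp"] = True
    trainer = build_trainer(config)  # hypothetical helper
    # With noddp set, the model must not be wrapped in DDP.
    assert not isinstance(trainer.model, DistributedDataParallel)
```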
```python
msg = msg[:-1] + f"that are smaller than the given max_atoms {max_atoms}."
raise ValueError(msg)

indices = indices[:max_index]
```
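One subtlety behind the explicit error above, echoed by the "a[:len(a)+1] does not throw error" commit in the log below: Python slicing clamps out-of-range bounds instead of raising, so the truncation alone cannot catch a too-large request. A minimal illustration:

```python
a = [1, 2, 3]
assert a[: len(a) + 1] == [1, 2, 3]  # slice is clamped, no IndexError

# Hence a guard roughly like (names illustrative):
max_index = 5
if max_index > len(a):
    raise ValueError(f"requested {max_index} samples but only {len(a)} are available")
```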
I'm testing this on my end... If I have a config as follows:
```yaml
dataset:
  train:
    format: lmdb
    src: data/s2ef/all/train/
    first_n: 2000000
    key_mapping:
      y: energy
      force: forces
    transforms:
      normalizer:
        energy:
          mean: -0.7554450631141663
          stdev: 2.887317180633545
        forces:
          mean: 0
          stdev: 2.887317180633545
  val:
    src: data/s2ef/all/val_id_30k
```
Because we have val set up to inherit the train parameters, it will also try to grab `first_n` for validation/testing. To get around this, I would need to explicitly define `first_n` for the validation/test splits, which is not ideal. We should have it such that something like `first_n` does not get passed down to val/test. I don't know the best way to tackle this; open to ideas.
The same is true of `metadata_path`, etc. All of it is being inherited.
Totally fair, and again a great catch @mshuaibii! Thank you! Looking into it...
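The merged commit log below mentions `train_split_settings` as the eventual fix; as a rough illustration of the idea only (key names assumed), train-only keys could simply be dropped before a split inherits the shared config:

```python
# Keys that should apply to the train split only (assumed list).
TRAIN_ONLY_KEYS = ("first_n", "sample_n", "max_atoms")

def build_split_config(shared: dict, split: dict, is_train: bool) -> dict:
    """Merge a split's own settings over the shared dataset config,
    dropping train-only keys for val/test splits."""
    base = {
        k: v
        for k, v in shared.items()
        if is_train or k not in TRAIN_ONLY_KEYS
    }
    base.update(split)
    return base
```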
* Update BalancedBatchSampler to use datasets' `data_sizes` method. Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`
* Remove python 3.10 syntax
* Documentation
* Added set_epoch method
* Format
* Changed "resolved dataset" message to be a debug log to reduce log spam
* clean up batchsampler and tests
* base dataset class
* move lin_ref to base dataset
* inherit basedataset for ase dataset
* filter indices prop
* added create_dataset fn
* yaml load fix
* create dataset function instead of filtering in base
* remove filtered_indices
* make create_dataset and LMDBDatabase importable from datasets
* create_dataset cleanup
* test create_dataset
* use metadata.natoms directly and add it to subset
* use self.indices to handle shard
* rename _data_sizes
* fix Subset of metadata
* minor change to metadata, added full path option
* import updates
* implement get_metadata for datasets; add tests for max_atoms and balanced partitioning
* a[:len(a)+1] does not throw error, change to check for this
* off by one fix
* fixing tests
* plug create_dataset into trainer
* remove datasetwithsizes; fix base dataset integration; replace close_db with __del__
* lint
* add/fix test
* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)
  * adding new notebook for using fairchem models with NEBs
  * adding md tutorials
  * blocking code cells that aren't needed or take too long
* Add extra test case for local batch size = 1
* fix example
* fix test case
* reorg changes
* remove metadata_has_sizes in favor of basedataset function metadata_hasattr
* fix data_parallel typo
* fix up some tests
* rename get_metadata to sample_property_metadata
* add slow get_metadata for ase; add tests for get_metadata (ase+lmdb); add test for make lmdb metadata sizes
* add support for different backends and ddp in pytest
* fix tests and balanced batch sampler
* make default dataset lmdb
* lint
* fix tests
* test with world_size=0 by default
* fix tests
* fix tests..
* remove subsample from oc22 dataset
* remove old datasets; add test for noddp
* remove load balancing from docs
* fix docs; add train_split_settings and test for this

---------

Co-authored-by: Nima Shoghi <nimashoghi@gmail.com>
Co-authored-by: Nima Shoghi <nimashoghi@fb.com>
Co-authored-by: lbluque <lbluque@meta.com>
Co-authored-by: Brandon <bmwood@meta.com>
Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com>
Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com>
Co-authored-by: Muhammed Shuaibi <mushuaibi@gmail.com>

(cherry picked from commit 04a69b0)
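As a usage sketch of the `create_dataset` entry point referenced in this log (import path and signature are assumptions, not verified against a release):

```python
from fairchem.core.datasets import create_dataset  # import path assumed

# Build the train split from its config block; per the commits above,
# create_dataset handles index filtering (first_n, max_atoms, etc.).
config = {"format": "lmdb", "src": "data/s2ef/all/train/"}  # illustrative
train_dataset = create_dataset(config, split="train")
```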
Add BaseDataset and corresponding changes to BalancedBatchSampler.
Added BaseDataset.metadata_hasattr() to check whether the metadata associated with a dataset has a specific attribute. In the case of BalancedBatchSampler, we require 'natoms' to be present.
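A minimal sketch of the described check (how metadata is stored, and the sampler-side handling of `on_error`, are assumptions):

```python
import warnings

class BaseDataset:
    def __init__(self, config):
        self.config = config
        self._metadata = None  # populated from metadata_path when available

    def metadata_hasattr(self, attr: str) -> bool:
        """Return True if this dataset's metadata defines `attr` (e.g. 'natoms')."""
        return self._metadata is not None and hasattr(self._metadata, attr)

def check_balanced_batching(dataset: BaseDataset, on_error: str = "raise") -> bool:
    """Balanced batching needs per-sample 'natoms' metadata; roughly mirror
    the PR's `on_error` behavior (value names assumed)."""
    if dataset.metadata_hasattr("natoms"):
        return True
    msg = "BalancedBatchSampler requires dataset metadata with 'natoms'."
    if on_error == "raise":
        raise ValueError(msg)
    warnings.warn(msg)
    return False
```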