Per Batch Padded Dataset #281
Conversation
Needs a one-line change but is otherwise ready for review.
@@ -16,8 +16,11 @@ data:
#train_data_path: <PATH>
#val_data_path: <PATH>
#test_data_path: <PATH>
dataset:
  name: cached
I think I'd prefer `type` over `name`. But I see why you don't want to use that as an attribute. Maybe `method` or `processing`?
I used `name` to be consistent with the other uses of discriminated unions in our config.
`method` would be fine, but I think we should only use one keyword for this purpose.
I'll change it to `processing`.
My reasoning against `name` was that `dataset:name` reminds me of `train`, `test`, or `SPICE`, etc.
ah yeah I can see how this might be confusing. thanks, I changed it 👍
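For illustration, a discriminated union keyed on a single field could look roughly like this in a pydantic-style config; the class names, field names, and discriminator values below are hypothetical, not apax's actual schema:

from typing import Literal, Union

from pydantic import BaseModel, Field

class CachedDataset(BaseModel):
    # The discriminator value selects this variant during validation.
    processing: Literal["cached"] = "cached"

class PerBatchPaddedDataset(BaseModel):
    processing: Literal["pbp"] = "pbp"
    num_workers: int = 10

class DataConfig(BaseModel):
    # pydantic dispatches on the `processing` field instead of trying
    # each union member in turn.
    dataset: Union[CachedDataset, PerBatchPaddedDataset] = Field(
        discriminator="processing"
    )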
"""Dataset which pads everything (atoms, neighbors) | ||
to the next larges power of two. | ||
This limits the compute wasted due to padding at the (negligible) | ||
cost of some recompilations. | ||
The NL is computed on-the-fly in parallel for `num_workers` of batches. | ||
Does not use tf.data. |
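As a rough sketch of the padding rule described in that docstring (not the actual apax implementation), padding a batch of position arrays to the next power of two might look like:

import numpy as np

def next_power_of_two(n: int) -> int:
    # Smallest power of two >= n, e.g. 60 -> 64.
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

def pad_batch_positions(positions_list):
    # Pad every (n_atoms, 3) array in the batch to the same power-of-two
    # atom count, so jit only recompiles when the padded size changes.
    target = next_power_of_two(max(p.shape[0] for p in positions_list))
    return np.stack(
        [np.pad(p, ((0, target - p.shape[0]), (0, 0))) for p in positions_list]
    )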
I assume `PB` stands for `parallel batches`. Maybe mention that once somewhere in the docstring, so that it is clear. I would also not write `MP` but `materials project`.
The class name stands for PerBatchPadded. I guess I can just write it out. Same for materials project.
actually, I don't see "PB" written anywhere
apax/data/input_pipeline.py
Outdated
n_epochs,
n_jit_steps=1,
buffer_size=20,
num_workers=10,
Does it make sense to set `num_workers` to None and get the default from the number of available cores?
Probably. I'll update the default.
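A minimal sketch of such a default (the helper name is hypothetical, not the actual apax code):

import os
from typing import Optional

def resolve_num_workers(num_workers: Optional[int]) -> int:
    # Fall back to the number of available CPU cores when unspecified.
    if num_workers is None:
        return os.cpu_count() or 1
    return num_workers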
apax/data/input_pipeline.py
Outdated
if n_jit_steps > 1:
    raise "PerBatchPaddedDataset is not yet compatible with multi step jit"
I'm not sure, but in general it's better to raise a concrete exception, like `raise TypeError(msg...)` here, isn't it?
Yes, not doing so was not intended. Thanks for pointing it out.
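Note that in Python 3, `raise` with a bare string is itself a `TypeError` (exceptions must derive from `BaseException`), so the original message would never reach the user. The fixed check might look like the following; `NotImplementedError` is one reasonable choice for an unimplemented combination, but the actual type is up to the authors:

if n_jit_steps > 1:
    raise NotImplementedError(
        "PerBatchPaddedDataset is not yet compatible with multi step jit"
    )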
def transpose_dict_of_lists(dict_of_lists: dict):
    list_of_dicts = []
    keys = list(dict_of_lists.keys())

    for i in range(len(dict_of_lists[keys[0]])):
        data = {k: dict_of_lists[k][i] for k in keys}
        list_of_dicts.append(data)

    return list_of_dicts
missing unit test
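Such a test could be as simple as the following pytest-style sketch:

from apax.data.input_pipeline import transpose_dict_of_lists

def test_transpose_dict_of_lists():
    dict_of_lists = {"a": [1, 2, 3], "b": [4, 5, 6]}
    expected = [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]
    assert transpose_dict_of_lists(dict_of_lists) == expected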
pre-commit.ci autofix (for more information, see https://pre-commit.ci)
I have not tested it locally, but it looks good 👍
Adds a dataset which does not use tf.data and pads samples per batch instead of padding everything to the largest size.
This is very advantageous for training datasets containing samples of very different sizes.
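To illustrate the advantage with made-up numbers (hypothetical sizes, and samples grouped by size for clarity; not a benchmark):

import numpy as np

# Hypothetical atom counts for eight samples of very different sizes.
n_atoms = np.array([7, 8, 9, 10, 58, 60, 62, 64])

# Padding everything to the largest sample wastes a lot of compute:
global_waste = int(np.sum(n_atoms.max() - n_atoms))  # 234 padded atoms

# Per-batch padding (batch size 4) only pads to each batch's own maximum:
batches = [n_atoms[:4], n_atoms[4:]]
per_batch_waste = sum(int(np.sum(b.max() - b)) for b in batches)  # 6 + 12 = 18

print(global_waste, per_batch_waste)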