SplitBatchRunner #1713
base: pytorch
Conversation
```python
for i in range(0, batch_size, self._max_batch_size):
    batch_args, batch_kwargs = alf.nest.map_structure(
        lambda x: x[i:i + self._max_batch_size], (args, kwargs))
    outputs.append(
```
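For intuition, the slicing step behaves like this sketch (plain PyTorch without alf; the tensors and max_batch_size here are made up). alf.nest.map_structure applies the lambda to every leaf tensor of the nested (args, kwargs):

```python
import torch

# Illustrative inputs: every leaf tensor shares the same batch dimension.
args = (torch.randn(10, 4),)
kwargs = {'mask': torch.ones(10, dtype=torch.bool)}
max_batch_size = 4

for i in range(0, 10, max_batch_size):
    # Equivalent of the map_structure slicing for this flat example.
    batch_args = tuple(x[i:i + max_batch_size] for x in args)
    batch_kwargs = {k: v[i:i + max_batch_size] for k, v in kwargs.items()}
    print(batch_args[0].shape, batch_kwargs['mask'].shape)
# Prints chunks of size 4, 4, then the remainder of size 2.
```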
Are we supposed to get the same results from the full batch and the split batches? What if there are dropout and batch norm layers in the model?
As the docstring notes, dropout is not supported for training.
BatchNorm will also be affected, but if the batch size is not too small, its effect on training should be small.
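For instance (a minimal check, not part of this PR), BatchNorm in training mode normalizes with per-chunk statistics, so the split result differs from the full-batch result:

```python
import torch

bn = torch.nn.BatchNorm1d(3).train()
x = torch.randn(8, 3)

full = bn(x)
split = torch.cat([bn(x[:4]), bn(x[4:])], dim=0)

# Each chunk is normalized with its own mean/var, so the outputs differ.
print(torch.allclose(full, split))  # False
```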
added notes
alf/utils/lean_function.py
```python
Args:
    model (nn.Module): the model to run
    max_batch_size (int): the maximum batch size to run the model.
```
Need to add a comment about what happens if max_batch_size <= 0.
added
```python
return alf.nest.map_structure(lambda *x: torch.cat(x, dim=0), *outputs)
```
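The merge step zips the corresponding leaves of every per-chunk output nest and concatenates them along the batch dimension. A flat illustration (made-up shapes, plain PyTorch):

```python
import torch

# Hypothetical per-chunk outputs: each is a (logits, value) pair.
outputs = [(torch.randn(4, 2), torch.randn(4)),
           (torch.randn(2, 2), torch.randn(2))]

# Equivalent of map_structure(lambda *x: torch.cat(x, dim=0), *outputs)
# for this flat pair structure:
merged = tuple(torch.cat(leaves, dim=0) for leaves in zip(*outputs))
print(merged[0].shape, merged[1].shape)  # torch.Size([6, 2]) torch.Size([6])
```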
```python
def original_forward(self, *args, **kwargs):
```
Where is this function used?
This is meant to be used by users of this class.
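For example (a toy stand-in, not the ALF implementation), a user who knows a batch already fits in memory could bypass the splitting:

```python
import torch

class ToyRunner(torch.nn.Module):
    """Illustrative stand-in for SplitBatchRunner."""

    def __init__(self, model):
        super().__init__()
        self._model = model

    def original_forward(self, *args, **kwargs):
        # Run the wrapped model directly, with no batch splitting.
        return self._model(*args, **kwargs)

model = torch.nn.Linear(4, 2)
runner = ToyRunner(model)
y = runner.original_forward(torch.randn(8, 4))  # bypasses splitting
```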
```python
class SplitBatchRunner(torch.nn.Module):
    """Split the input into smaller batches and run the model on each batch.
```
Need to clarify that lean_function will be used.
comment added
Sometimes we want to use a large batch for training, but there is not enough memory. In this situation, we can split the large batch into smaller batches and use lean_function to run them. This reduces the memory usage because the intermediate tensors will all be released.
SplitBatchRunner is implemented to make this process very simple.
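As a rough illustration of the whole idea, here is a minimal sketch (single-tensor inputs only, no alf.nest, no lean_function, and the max_batch_size <= 0 convention is an assumption, not taken from this PR):

```python
import torch

class SplitBatchRunnerSketch(torch.nn.Module):
    """Minimal illustration: split the batch, run chunks, concat results.

    Not the ALF implementation: it supports only a single tensor input
    and omits the lean_function memory optimization.
    """

    def __init__(self, model, max_batch_size):
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def forward(self, x):
        # Assumed convention: max_batch_size <= 0 means "do not split".
        if self._max_batch_size <= 0 or x.shape[0] <= self._max_batch_size:
            return self._model(x)
        outputs = [
            self._model(x[i:i + self._max_batch_size])
            for i in range(0, x.shape[0], self._max_batch_size)
        ]
        return torch.cat(outputs, dim=0)

model = torch.nn.Linear(4, 2)
runner = SplitBatchRunnerSketch(model, max_batch_size=3)
y = runner(torch.randn(10, 4))
print(y.shape)  # torch.Size([10, 2])
```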