
SplitBatchRunner #1713

Open. emailweixu wants to merge 2 commits into base: pytorch

Conversation

emailweixu (Contributor)

Sometimes we want to use a large batch for training, but there is not enough memory. In this situation, we can split the large batch into smaller batches and use lean_function to run each small batch. This reduces memory usage because the intermediate tensors are all released.

SplitBatchRunner is implemented to make this process very simple.
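A minimal usage sketch (the Linear model, batch sizes, and shapes here are illustrative, not from the PR; the constructor arguments follow the Args documented below):

    import torch
    import torch.nn as nn

    model = nn.Linear(128, 10)
    # Wrap the model so each forward pass runs in sub-batches of at
    # most 256 rows; the per-sub-batch outputs are concatenated back
    # along dim 0.
    runner = SplitBatchRunner(model, max_batch_size=256)

    big_batch = torch.randn(1024, 128)
    out = runner(big_batch)  # same shape as model(big_batch): [1024, 10]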

for i in range(0, batch_size, self._max_batch_size):
    # Slice every leaf tensor of (args, kwargs) to the current sub-batch.
    batch_args, batch_kwargs = alf.nest.map_structure(
        lambda x: x[i:i + self._max_batch_size], (args, kwargs))
    outputs.append(  # call truncated in this diff excerpt; each sub-batch's output is collected here


Are we supposed to get the same results from the full batch and the split batch? What if there are dropout and batch norm layers in the model?

emailweixu (Author)

As the doc-string notes, dropout is not supported for training. BatchNorm will also be affected, but if the batch size is not too small, its effect on training should be small.
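To see why BatchNorm is affected: in training mode it normalizes with the statistics of whatever batch it sees, so per-sub-batch statistics differ from full-batch statistics. A small self-contained demonstration (not from the PR):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    bn = nn.BatchNorm1d(4).train()
    x = torch.randn(8, 4)

    full = bn(x)                                  # normalized with stats of all 8 rows
    split = torch.cat([bn(x[:4]), bn(x[4:])], 0)  # each half uses its own stats
    print(torch.allclose(full, split))            # False in general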

emailweixu (Author)

Added notes.


Args:
    model (nn.Module): the model to run.
    max_batch_size (int): the maximum batch size for running the model.
Collaborator

Need to add a comment on what happens if max_batch_size <= 0.

emailweixu (Author)

Added.


# Concatenate the per-sub-batch outputs back into full-batch outputs.
return alf.nest.map_structure(lambda *x: torch.cat(x, dim=0), *outputs)

def original_forward(self, *args, **kwargs):
Collaborator

Where is this function used?

emailweixu (Author)

This is meant to be called by users of this class.
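For reference, the concatenation idiom in the excerpt above applies a function across corresponding leaves of several nests. A small sketch of what it does (the (logits, value) outputs here are made up for illustration):

    import torch
    import alf

    # Suppose each sub-batch forward returned a (logits, value) tuple.
    outputs = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]

    # map_structure(f, *nests) applies f to corresponding leaves, so
    # each leaf is concatenated along the batch dimension.
    logits, value = alf.nest.map_structure(
        lambda *x: torch.cat(x, dim=0), *outputs)
    print(logits.shape, value.shape)  # [12, 10] and [12, 1]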



class SplitBatchRunner(torch.nn.Module):
    """Split the input into smaller batches and run the model on each batch.
Collaborator

Need to clarify that lean_function will be used.

emailweixu (Author)

Comment added.
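For context on the memory saving: the description says lean_function releases intermediate tensors, which is the same trade-off standard gradient checkpointing makes (free activations after forward, recompute them in backward). A sketch of that mechanism with PyTorch's built-in API, as an analogy rather than ALF's actual implementation:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    model = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 10))
    x = torch.randn(64, 128, requires_grad=True)

    # Intermediate activations inside `model` are freed after the forward
    # pass and recomputed during backward, trading compute for memory.
    y = checkpoint(model, x, use_reentrant=False)
    y.sum().backward()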
