Parallelize device-to-host transfers #5824
Merged
For async checkpointing, it is important to unblock training as quickly as possible. Training can only be unblocked once all data has been moved off the device and onto the host, freeing up device memory.

One bottleneck I found is that in `TransferFromServer`, the tensors are transferred with `ToLiteralSync`, meaning each tensor is transferred sequentially. In benchmarking a 2B-parameter model, parallelizing these transfers reduced the time spent in `TransferFromServer` from 5.1s to 1.8s, a ~65% reduction. There is still significant overhead from copying the resulting `xla::Literal` into a `torch.Tensor`, but that's for another PR.
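To illustrate the change in structure, here is a minimal sketch of the fan-out/join pattern: instead of issuing each blocking transfer one after another, all transfers are launched up front and joined at the end. `TransferOne` is a hypothetical stand-in for the per-buffer `ToLiteralSync` call, and the sketch uses plain `std::async` rather than the runtime's own thread pool, so it only shows the idea, not this PR's actual implementation.

```cpp
// Sketch of parallelizing independent blocking device-to-host transfers.
// TransferOne is a hypothetical stand-in for the per-buffer ToLiteralSync call.
#include <chrono>
#include <future>
#include <thread>
#include <vector>

// Pretend each device-to-host transfer is an independent blocking call.
static std::vector<char> TransferOne(int /*buffer_index*/) {
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // simulated copy
  return std::vector<char>(1024);
}

int main() {
  constexpr int kNumBuffers = 16;

  // Sequential: each transfer waits for the previous one to finish,
  // as happens when calling ToLiteralSync in a loop.
  std::vector<std::vector<char>> sequential;
  for (int i = 0; i < kNumBuffers; ++i) {
    sequential.push_back(TransferOne(i));
  }

  // Parallel: launch every transfer up front, then join. Wall time approaches
  // the duration of the slowest single transfer rather than the sum of all.
  std::vector<std::future<std::vector<char>>> pending;
  pending.reserve(kNumBuffers);
  for (int i = 0; i < kNumBuffers; ++i) {
    pending.push_back(std::async(std::launch::async, TransferOne, i));
  }
  std::vector<std::vector<char>> parallel;
  for (auto& f : pending) {
    parallel.push_back(f.get());  // blocks until that transfer completes
  }
  return 0;
}
```

Since the per-tensor transfers do not depend on each other, overlapping them hides most of the per-call latency, which is where the 5.1s to 1.8s improvement comes from.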