[Fix] Several bugs when using HvdAllToAllEmbedding to train model and save/restore KV files. #376
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
BUG1
The id key may be multiple copies in one tensor before real shadow lookup when using HvdAllToAllEmbedding.
Here is how the bug happens.
Before ALL2ALL:
Rank0 IDs: [0,1,1,3,3] --unique--> [0,1,3]
Rank1 IDs: [0,1,2,2,3] --unique--> [0,1,2,3]
After ALL2ALL:
Rank0 IDs: [0,0,2]
Rank1 IDs: [1,1,3,3]
Rank0 has duplicated key 0, and Rank1 has duplicated keys 1 and 3. When updating the gradient, in Rank0, the key 0 will insert twice which means only one gradient takes effect and another one was covered. That may lead to under-training, same as Rank1 because of multiple copies of key 1 and key 3.
So we need to do secondary unique operation after ALL2ALL. The first time for reducing transmission overhead, the second time for training properly.
BUG2
Fix bug that de_hvd_save_model and CheckpointManager were unable to use together. Which cause by CheckpointManager compatibility code in DE hack changing the storage path in DE saveable object when runtime. The modified storage path was not wrote back after Checkpoint saving, and then when call saved_model saving function it would dump to an unexpected directory which was set in DECheckpoint class.
Others
Modify function name de_hvd_save_model to de_save_model.
Also make user more easy to use de_save_model by writing fewer code.
Type of change
Checklist:
How Has This Been Tested?
Train any model using HvdAllToAllEmbedding, and comparing the loss between set parameter with_secondary_unique True and False.
Run the new test.