Checkpoint removal 2 #250
base: main
Conversation
@adamlerer @lw this is super hacky at the moment. I'll add some config options, but I wanted to get your feedback on this first before I went too far down a rabbit hole.
This change seems conceptually good to me. Of course, the code looks incorrect for the distributed case, but I assume you will fix that up for the final version.
We will probably need to test this change with some distributed runs, unless you can make it very clearly correct "by inspection".
self.checkpoint_manager.write(
    entity, part, embs.detach(), optimizer.state_dict()
)
self._write_single_embedding(holder, entity, part)
self.embedding_storage_freelist[entity].add(embs.storage())
io_bytes += embs.numel() * embs.element_size()  # ignore optim state
# these variables are holding large objects; let them be freed
del embs
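(The body of _write_single_embedding is not shown in this excerpt. As a rough guess at the shape of the refactor, assuming the helper looks up the tensors itself and delegates to the existing checkpoint manager, it might look something like the sketch below; the attribute names are illustrative, not the PR's actual code.)

def _write_single_embedding(self, holder, entity, part):
    # Illustrative guess: fetch the embedding tensor and its optimizer for
    # this (entity, partition) pair from the holder / trainer state ...
    embs = holder.partitioned_embeddings[entity, part]
    optimizer = self.trainer.partitioned_optimizers[entity, part]
    # ... and delegate to the checkpoint manager, as the removed inline
    # code used to do.
    self.checkpoint_manager.write(
        entity, part, embs.detach(), optimizer.state_dict()
    )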
How do these lines work if you don't define embs and optimizer any more?
for entity, part in parts:
    self._write_single_embedding(self.holder, entity, part)

def _write_stats(self, bucket: Optional[Bucket], stats: Optional[BucketStats]):
This naming is misleading... I think it does more than write stats in the distributed scheduler.
Even if it's "clearly correct by inspection", I'd still feel better if you all were able to run some distributed versions as well.
Types of changes
Motivation and Context / Related issue
With 1bn entities, dumping the embedding table and reloading it was taking as long as training on a single 10bn-edge chunk. This PR reduces the amount of checkpointing to save that cost, which drastically speeds up single-instance embedding.
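The author mentions adding config options for this. Purely as an illustration of the idea, not this PR's actual code (the flag name and helper below are invented), reducing checkpoint frequency could look roughly like:

from typing import Callable, Dict, Tuple

import torch


def maybe_write_embeddings(
    embeddings: Dict[Tuple[str, int], torch.Tensor],
    write: Callable[[str, int, torch.Tensor], None],
    bucket_idx: int,
    checkpoint_every_n_buckets: int,
) -> None:
    # Skip the expensive dump-and-reload for most buckets; with 1bn entities
    # that I/O can rival training on a 10bn-edge chunk.
    if checkpoint_every_n_buckets <= 0:
        return  # never write intermediate checkpoints
    if (bucket_idx + 1) % checkpoint_every_n_buckets != 0:
        return  # keep embeddings in memory for now
    for (entity, part), embs in embeddings.items():
        write(entity, part, embs.detach())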
How Has This Been Tested (if it applies)
Verified that this worked at runtime
Checklist