Skip to content
This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Checkpoint removal 2 #250

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

tmarkovich
Copy link
Contributor

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and Context / Related issue

With 1bn entities, dumping the embedding table and reloading it was taking as long as training on a single 10bn edge chunk. This PR reduces the amount of check pointing to save us that cost which drastically speeds up single-instance embedding.

How Has This Been Tested (if it applies)

Verified that this worked at runtime

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
  • All tests passed, and additional code has been covered with new tests.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 14, 2022
@tmarkovich tmarkovich marked this pull request as draft February 14, 2022 14:51
@tmarkovich
Copy link
Contributor Author

@adamlerer @lw this is super hacky at the moment. I'll add some config options, but I wanted to get your feedback on this first before I went too far down a rabbit hole.

Copy link
Contributor

@adamlerer adamlerer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems conceptually good to me. Of course, the code looks incorrect for the distributed case, but I assume you will fix that up for the final version.

We will probably need to test this change with some distributed runs. Unless you can make it very clearly correct "by inspection".

self.checkpoint_manager.write(
entity, part, embs.detach(), optimizer.state_dict()
)
self._write_single_embedding(holder, entity, part)
self.embedding_storage_freelist[entity].add(embs.storage())
io_bytes += embs.numel() * embs.element_size() # ignore optim state
# these variables are holding large objects; let them be freed
del embs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do these lines work if you don't define embs and optimizer any more?

for entity, part in parts:
self._write_single_embedding(self.holder, entity, part)

def _write_stats(self, bucket: Optional[Bucket], stats: Optional[BucketStats]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This naming is misleading... I think it does more than write stats in the distributed scheduler.

@tmarkovich
Copy link
Contributor Author

This change seems conceptually good to me. Of course, the code looks incorrect for the distributed case, but I assume you will fix that up for the final version.

We will probably need to test this change with some distributed runs. Unless you can make it very clearly correct "by inspection".

Even if it's "clearly correct by inspection", I'd still feel better if you all were able to run some distributed versions as well.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants