
8-bit LION, take 2 #514

Merged: 35 commits merged into main from davis/lion8b-v2 on Aug 24, 2023

Conversation

@dblalock (Contributor) commented Aug 10, 2023

Adds an 8-bit version of the LION optimizer. It also features one optional byte of auxiliary error-correction state per parameter to make pure bf16 training work.
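For intuition, here's a minimal sketch of what the update looks like with int8 momentum. This is not the fused kernel or the actual lion8b.py code; the per-tensor absmax quantization is only a stand-in for illustration:

```python
import torch

def lion8b_step_sketch(param, grad, exp_avg_q, scale, lr,
                       betas=(0.9, 0.99), wd=0.0):
    """Illustrative decoupled LION step with int8 momentum (not the real kernel)."""
    beta1, beta2 = betas
    # Dequantize the stored momentum.
    exp_avg = exp_avg_q.float() * scale
    # LION only uses the sign of an interpolation of momentum and gradient.
    update = (beta1 * exp_avg + (1 - beta1) * grad).sign_()
    # Decoupled weight decay, then the signed update.
    param.mul_(1 - lr * wd).add_(update, alpha=-lr)
    # Standard LION momentum update, then requantize back to int8.
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
    new_scale = exp_avg.abs().max().clamp(min=1e-12) / 127
    exp_avg_q.copy_((exp_avg / new_scale).round().to(torch.int8))
    return new_scale  # the caller stores the updated per-tensor scale
```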

Code changes:

  • Adds lion8b.py to llmfoundry/optim
  • Adds DecoupledLionW_8bit to llmfoundry/optim/__init__.py
  • Adds lion8b as an option in llmfoundry/optim/builders.py
  • Adds test_lion8b.py to the tests.
  • Adds mosaicml-turbo to the GPU dependencies in setup.py. This is the repo that currently holds all the CUDA kernels. These are in a separate repo for now to avoid complicating LLM foundry {install, deps, source code}.
  • Adds an optional master_weight_dtype field in train.py. If set to bf16 or fp16, the script does model.to(dtype=<that dtype>) before training. This works when we have error correction turned on (a usage sketch follows this list).
  • Tweaks config_utils.py to set FSDP's param_dtype to None if the master weights are already fp16/bf16.
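To make the intended usage concrete, here's a hedged sketch of constructing the optimizer directly. The class name and import path come from the list above; the error_correction keyword and the surrounding code are assumptions for illustration, not copied from lion8b.py:

```python
import torch
from llmfoundry.optim import DecoupledLionW_8bit

model = torch.nn.Linear(512, 512)   # stand-in for a real model
model.to(dtype=torch.bfloat16)      # what train.py does for master_weight_dtype: bf16

opt = DecoupledLionW_8bit(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.99),
    weight_decay=1e-5,
    error_correction=True,  # assumed kwarg name for the optional 1-byte aux state
)

loss = model(torch.randn(8, 512, dtype=torch.bfloat16)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```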

Non-obvious design choices:

  • Features the option to store a few extra bits per parameter to facilitate bf16 or fp16 master weights, which should save us space and hopefully squeeze out a little more MFU (the error-correction idea is sketched further below).
  • There's an undocumented _fused arg that's needed mostly for testing and maybe (?) should be part of the API in case someone wants to check whether the fused kernel is causing issues. I'd like to get rid of it once we fully trust the kernel logic.
  • This should work fine on CPUs—it just doesn't quantize. We could save some LoC if we instead said "sorry, you need a GPU to run this" but I'm worried about users wanting to prototype/debug locally.
  • We have a "_MaybeQuantizedTensor" class to hold the quantized optimizer states. This is a pretty good solution, but I don't love having a tensor-like object that doesn't implement the full Tensor interface (a rough sketch of the wrapper follows this list).
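Roughly what that wrapper does, as a stripped-down sketch (not the actual class; the quantization scheme and method names here are placeholders):

```python
import torch

class MaybeQuantizedTensorSketch:
    """Stand-in for _MaybeQuantizedTensor: holds optimizer state as int8 + scale
    when quantization is possible (CUDA), or as a plain float tensor otherwise
    (e.g. on CPU, where we just don't quantize)."""

    def __init__(self, data: torch.Tensor, try_quantize: bool = True):
        self.data = None       # unquantized fallback
        self.quantized = None  # int8 payload
        self.scale = None
        if try_quantize and data.is_cuda:
            self.scale = data.abs().max().clamp(min=1e-12) / 127
            self.quantized = (data / self.scale).round().to(torch.int8)
        else:
            self.data = data.clone().float()

    def materialize(self) -> torch.Tensor:
        """Return a float tensor regardless of how the state is stored."""
        if self.quantized is not None:
            return self.quantized.float() * self.scale
        return self.data
```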

There's enough test coverage here that I'm not super worried about these choices, but wanted to highlight them in case someone has strong opinions.
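On the bf16/fp16 master-weights point above, here's a minimal sketch of the error-correction idea, assuming a simple "carry the rounding error forward" scheme with a crude int8-plus-scale encoding; the real kernel packs its single auxiliary byte differently:

```python
import torch

def ecc_bf16_update_sketch(param_bf16, update_fp32, err_i8, err_scale):
    """Apply an update to bf16 master weights without silently dropping small
    updates: the rounding error from each bf16 cast is carried in a small
    auxiliary state and folded back in on the next step."""
    # Reconstruct the carried-over rounding error and add it to this update.
    carried = err_i8.float() * err_scale
    desired = param_bf16.float() + update_fp32 + carried
    # Round to bf16 (what the master weights actually store).
    new_param = desired.to(torch.bfloat16)
    # Whatever was lost to rounding becomes next step's carried error.
    new_err = desired - new_param.float()
    new_scale = new_err.abs().max().clamp(min=1e-12) / 127
    err_i8.copy_((new_err / new_scale).round().to(torch.int8))
    param_bf16.copy_(new_param)
    return new_scale
```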

WandB report

@dblalock dblalock changed the title from "add decoupled lion8b optimizer + tests + builder option + deps" to "[WIP] 8-bit LION, take 2" on Aug 11, 2023
@dblalock dblalock requested a review from vchiley August 14, 2023 07:37
@dblalock dblalock marked this pull request as ready for review August 14, 2023 07:38
@dblalock dblalock changed the title from "[WIP] 8-bit LION, take 2" to "8-bit LION, take 2" on Aug 14, 2023
@dblalock dblalock requested a review from vchiley August 21, 2023 03:53
@vchiley (Contributor) commented Aug 22, 2023

High-level design point
If DecoupledLionW_8bit is requested, it kind of seems like it should always be quantized.
If people are running on CPU or do not ask for the quantized version, they should be running DecoupledLionW not DecoupledLionW_8bit...

Should we make sure that the state dict of DecoupledLionW_8bit can be loaded into DecoupledLionW (and vice versa), and then error out, directing people to use DecoupledLionW when they're in a setting where DecoupledLionW_8bit isn't available? This setup would mean we could get rid of the _MaybeQuantizedTensor class (as well as potentially some other code).

Yes, this point should have been brought up a LONG time ago. I'm not saying we should do this, just opening the discussion.
The downside is obviously that it would be tougher to test quantized vs not quantized.
(The answer to this will probably be: we're going to enable the non-quantized version in the DecoupledLionW_8bit class as well 😄 )

Also, can the state dict of DecoupledLionW_8bit be loaded into DecoupledLionW (and vice versa)?
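For what it's worth, the conversion in one direction would presumably look something like this; the state-dict key names (exp_avg_q, scale, exp_avg) are assumptions for illustration, not what lion8b.py actually stores:

```python
def to_plain_lion_state_dict(sd_8bit: dict) -> dict:
    """Hypothetical: dequantize an 8-bit LION state dict so DecoupledLionW
    could load it. Assumes each param's state holds either a float 'exp_avg'
    or an (int8 'exp_avg_q', float 'scale') pair."""
    plain = {'param_groups': sd_8bit['param_groups'], 'state': {}}
    for pid, st in sd_8bit['state'].items():
        if 'exp_avg_q' in st:
            plain['state'][pid] = {'exp_avg': st['exp_avg_q'].float() * st['scale']}
        else:
            plain['state'][pid] = {'exp_avg': st['exp_avg']}
    return plain
```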

@dblalock (Contributor, Author) commented Aug 22, 2023

@vchiley How would you feel about just making this DecoupledLionW in a future PR? So we'd have one optimizer with a quantize flag instead of two classes?

Asking because this informs whether I should rip out the flag. I like the idea of avoiding redundant code, but maybe we have reasons to not do this.

@vchiley (Contributor) commented Aug 22, 2023

@bmosaicml originally wrote the LION implementation.
@bmosaicml, do you have an opinion on this?

@bmosaicml (Contributor) left a comment
Incredible work, Davis!

@dblalock (Contributor, Author) commented:
@bmosaicml so are you good with "yes, leave the quantization flag and make this the single LION implementation"?

@dblalock dblalock merged commit 795ab4a into main Aug 24, 2023
9 checks passed
@dakinggg dakinggg deleted the davis/lion8b-v2 branch October 11, 2023 21:30