-
Notifications
You must be signed in to change notification settings - Fork 252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Not for land] Integrate float8nocompile, an experimental feature for high performance #778
base: main
Are you sure you want to change the base?
Conversation
float8 training in eager mode
@vkuzo here is the PoC of the torchtitan + float8nocompile integration, and the training performance benchmarking results |
I think a good way to go here is to mark this PR "not for land" for now
Here is the data I think is important:
|
I think we should take care to mention/keep in mind that the MFU is with respect to peak bf16 TFLOPS. Direct comparison of TFLOPS might make more sense when comparing bf16 vs. fp8 runs since the fp8 runs are not actually doing every computation in fp8 either (e.g. SDPA or final linear). |
Makes sense, is there an existing way to log TFLOPS/sec during training w/ torchtitan? Searching around I don't see one |
@danielvegamyhre I think the TFLOPS is the same as MFU without the peak TFLOPS denominator (and extra factor of 100): Line 365 in 90567fc
In other words, it is just |
ah of course, thanks! Updated the PR description to include TFLOPS instead of MFU |
based on the results shared so far, I think it would be interesting to add one additional experiment branch: production float8 + torch.compile on just the |
…or only compiling linear layers
@vkuzo I updated the PR description with benchmarks comparing against prod float8 + compiling linear layers only. Also included benchmarks for no AC vs full AC. |
Summary
This PR contains a proof-of-concept integration of an experimental feature I've been working on in torchao: float8nocompile (official name TBD, naming things is hard!).
It is an implementation of float8 conversion with tensorwise dynamic scaling that uses handwritten Triton kernels to achieve high performance, rather than requiring torch.compile.
Benchmarking training performance
Model: Llama3 8b on 4 H100s with batch size of 1, seq len of 4096.
TL;DR is the prototype slightly beats the baseline we want to match (prod float8 + compiling only linear layers) for both full AC and no AC.
No AC (seq len 4096) - 8 H100s
Full AC (seq len 4096) - 8 GPUs
Benchmarking single linear layer forward+backward performance
Tested used a single linear layer of size (4096,4096) with different input sizes.
Performance benchmarks show the float8nocompile implementation is beating torch.compile by 1.72-4.45% depending on the input tensor size.
Note: this PR depends on this stack of PRs being merged into torchao, and those changes being included into a release (which the user installs).