[for informative purposes only] compiled cuda train-gpt.py #1
base: main
Conversation
Thank you for looking into this and setting this up! I'll try to see if we can optimize it better (up until now we didn't look at MFU at all). One thing I can think of is that we currently implement Q, K, V as three separate matrix multiplies for simplicity, but these could be batched. Also curious what effect including our normalization has on MFU. Will report back when I get a chance to look into this. Also, thanks for pointing out the typo in the example.
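For illustration, a minimal sketch of the batched QKV idea in plain PyTorch (dims and weight names are placeholders, not the actual modula modules):

```python
import torch

# Illustrative dims only (not the actual config in this repo).
B, T, D = 8, 256, 1024
x = torch.randn(B, T, D)
w_q, w_k, w_v = (torch.randn(D, D) for _ in range(3))

# Current approach: three separate GEMMs, three kernel launches.
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Batched alternative: one (D, 3D) GEMM, then split the output.
w_qkv = torch.cat([w_q, w_k, w_v], dim=1)        # (D, 3D)
q2, k2, v2 = (x @ w_qkv).chunk(3, dim=-1)        # same math, one launch
```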
Just jotting down a couple more ideas when they come to me:
compiling both gpt.regularize && gpt.normalize:
$ CUDA_VISIBLE_DEVICES=2 python examples/train-gpt.py --cuda
W0526 21:31:55.259000 140707784318976 torch/_dynamo/convert_frame.py:357] torch._dynamo hit config.cache_size_limit (8)
W0526 21:31:55.259000 140707784318976 torch/_dynamo/convert_frame.py:357] function: 'normalize' (/home/main/modula/.venv/lib/python3.10/site-packages/modula/abstract.py:92)
W0526 21:31:55.259000 140707784318976 torch/_dynamo/convert_frame.py:357] last reason: L['target_norm'] == 0.5
W0526 21:31:55.259000 140707784318976 torch/_dynamo/convert_frame.py:357] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0526 21:31:55.259000 140707784318976 torch/_dynamo/convert_frame.py:357] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
W0526 21:32:03.863000 140707784318976 torch/_dynamo/convert_frame.py:357] torch._dynamo hit config.cache_size_limit (8)
W0526 21:32:03.863000 140707784318976 torch/_dynamo/convert_frame.py:357] function: 'regularize' (/home/main/modula/.venv/lib/python3.10/site-packages/modula/abstract.py:102)
W0526 21:32:03.863000 140707784318976 torch/_dynamo/convert_frame.py:357] last reason: L['strength'] == 0.005
W0526 21:32:03.863000 140707784318976 torch/_dynamo/convert_frame.py:357] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0526 21:32:03.863000 140707784318976 torch/_dynamo/convert_frame.py:357] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
W0526 21:32:06.091000 140707784318976 torch/_dynamo/convert_frame.py:357] torch._dynamo hit config.cache_size_limit (8)
W0526 21:32:06.091000 140707784318976 torch/_dynamo/convert_frame.py:357] function: 'normalize' (/home/main/modula/.venv/lib/python3.10/site-packages/modula/atom.py:32)
W0526 21:32:06.091000 140707784318976 torch/_dynamo/convert_frame.py:357] last reason: L['target_norm'] == 0.11857166654767855
W0526 21:32:06.091000 140707784318976 torch/_dynamo/convert_frame.py:357] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0526 21:32:06.091000 140707784318976 torch/_dynamo/convert_frame.py:357] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
step: 10 train loss: 2.98 test loss: 3.03 tokens/gpu/sec: 3339.13 MFU: 22.75%
step: 20 train loss: 2.67 test loss: 2.74 tokens/gpu/sec: 6538.01 MFU: 44.55%
step: 30 train loss: 2.69 test loss: 2.69 tokens/gpu/sec: 7995.59 MFU: 54.49%
step: 40 train loss: 2.60 test loss: 2.64 tokens/gpu/sec: 8768.21 MFU: 59.75%
step: 50 train loss: 2.61 test loss: 2.61 tokens/gpu/sec: 9223.39 MFU: 62.85%
step: 60 train loss: 2.65 test loss: 2.62 tokens/gpu/sec: 9431.93 MFU: 64.27%
step: 70 train loss: 2.61 test loss: 2.61 tokens/gpu/sec: 9640.00 MFU: 65.69%
step: 80 train loss: 2.56 test loss: 2.60 tokens/gpu/sec: 9711.68 MFU: 66.18%
step: 90 train loss: 2.64 test loss: 2.61 tokens/gpu/sec: 9708.39 MFU: 66.16%
step: 100 train loss: 2.57 test loss: 2.58 tokens/gpu/sec: 9803.95 MFU: 66.81%
step: 110 train loss: 2.56 test loss: 2.56 tokens/gpu/sec: 9809.32 MFU: 66.85%
step: 120 train loss: 2.65 test loss: 2.57 tokens/gpu/sec: 9896.43 MFU: 67.44%
step: 130 train loss: 2.60 test loss: 2.58 tokens/gpu/sec: 9974.42 MFU: 67.97%
step: 140 train loss: 2.57 test loss: 2.57 tokens/gpu/sec: 10019.70 MFU: 68.28%
step: 150 train loss: 2.59 test loss: 2.58 tokens/gpu/sec: 10057.23 MFU: 68.54%
step: 160 train loss: 2.51 test loss: 2.54 tokens/gpu/sec: 10086.36 MFU: 68.73%
step: 170 train loss: 2.52 test loss: 2.57 tokens/gpu/sec: 10135.59 MFU: 69.07%
step: 180 train loss: 2.58 test loss: 2.59 tokens/gpu/sec: 10182.83 MFU: 69.39%
step: 190 train loss: 2.50 test loss: 2.58 tokens/gpu/sec: 10217.94 MFU: 69.63%
step: 200 train loss: 2.51 test loss: 2.56 tokens/gpu/sec: 10234.87 MFU: 69.75%
step: 210 train loss: 2.54 test loss: 2.59 tokens/gpu/sec: 10256.04 MFU: 69.89%
step: 220 train loss: 2.56 test loss: 2.55 tokens/gpu/sec: 10290.66 MFU: 70.13%
step: 230 train loss: 2.48 test loss: 2.55 tokens/gpu/sec: 10305.05 MFU: 70.22%
step: 240 train loss: 2.53 test loss: 2.56 tokens/gpu/sec: 10321.76 MFU: 70.34%
step: 250 train loss: 2.54 test loss: 2.53 tokens/gpu/sec: 10336.67 MFU: 70.44%
step: 260 train loss: 2.53 test loss: 2.54 tokens/gpu/sec: 10347.73 MFU: 70.52%
step: 270 train loss: 2.60 test loss: 2.55 tokens/gpu/sec: 10362.08 MFU: 70.61%
step: 280 train loss: 2.48 test loss: 2.52 tokens/gpu/sec: 10386.00 MFU: 70.78%
step: 290 train loss: 2.57 test loss: 2.55 tokens/gpu/sec: 10405.08 MFU: 70.91%
step: 300 train loss: 2.54 test loss: 2.53 tokens/gpu/sec: 10418.83 MFU: 71.00%
Compare with the forward-only compiled baseline:
$ CUDA_VISIBLE_DEVICES=2 python examples/train-gpt.py --cuda
step: 10 train loss: 3.11 test loss: 3.05 tokens/gpu/sec: 9378.40 MFU: 63.91%
step: 20 train loss: 2.73 test loss: 2.76 tokens/gpu/sec: 10191.80 MFU: 69.45%
step: 30 train loss: 2.64 test loss: 2.68 tokens/gpu/sec: 9852.90 MFU: 67.14%
step: 40 train loss: 2.66 test loss: 2.69 tokens/gpu/sec: 10145.24 MFU: 69.14%
step: 50 train loss: 2.64 test loss: 2.64 tokens/gpu/sec: 10331.91 MFU: 70.41%
step: 60 train loss: 2.73 test loss: 2.62 tokens/gpu/sec: 10332.93 MFU: 70.41%
step: 70 train loss: 2.59 test loss: 2.62 tokens/gpu/sec: 10376.13 MFU: 70.71%
step: 80 train loss: 2.67 test loss: 2.63 tokens/gpu/sec: 10437.77 MFU: 71.13%
step: 90 train loss: 2.55 test loss: 2.60 tokens/gpu/sec: 10415.25 MFU: 70.98%
step: 100 train loss: 2.52 test loss: 2.60 tokens/gpu/sec: 10437.63 MFU: 71.13%
step: 110 train loss: 2.58 test loss: 2.61 tokens/gpu/sec: 10432.99 MFU: 71.10%
step: 120 train loss: 2.54 test loss: 2.59 tokens/gpu/sec: 10445.39 MFU: 71.18%
step: 130 train loss: 2.57 test loss: 2.57 tokens/gpu/sec: 10490.28 MFU: 71.49%
step: 140 train loss: 2.51 test loss: 2.58 tokens/gpu/sec: 10527.36 MFU: 71.74%
step: 150 train loss: 2.62 test loss: 2.54 tokens/gpu/sec: 10545.53 MFU: 71.86%
step: 160 train loss: 2.64 test loss: 2.56 tokens/gpu/sec: 10549.42 MFU: 71.89%
step: 170 train loss: 2.51 test loss: 2.56 tokens/gpu/sec: 10552.99 MFU: 71.91%
step: 180 train loss: 2.56 test loss: 2.54 tokens/gpu/sec: 10531.44 MFU: 71.77%
step: 190 train loss: 2.64 test loss: 2.56 tokens/gpu/sec: 10554.32 MFU: 71.92%
step: 200 train loss: 2.55 test loss: 2.57 tokens/gpu/sec: 10567.03 MFU: 72.01%
I believe the main issue with torch compiling normalize/regularize is the use of raw Python floats in the input arguments, as torch.compile appears to guard on the scalar as a constant instead of treating the provided value as a dynamic input. Also, changing …
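A minimal sketch of that idea, using a stand-in normalize (not the actual modula implementation): passing the scalar as a 0-dim tensor makes it graph input data rather than a constant baked into the guards, so different values can reuse the same compiled graph. Raising torch._dynamo.config.cache_size_limit would only hide the symptom.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for a normalize-like function; the real modula code differs.
@torch.compile
def normalize(w, target_norm):
    return w * (target_norm / w.norm().clamp(min=1e-12))

w = torch.randn(1024, 1024, device=device)

# Python floats: dynamo guards on each distinct value (L['target_norm'] == 0.5, ...),
# so a stream of different scalars keeps triggering recompiles.
# 0-dim tensors: the value travels as data, so one compiled graph is reused.
for t in (0.5, 0.25, 0.11857166654767855):
    normalize(w, torch.tensor(t, device=device))
```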
"v100-sxm": 125e12, | ||
"6000A": 364.25e12, | ||
"4090": 165.2 * 10**12, | ||
"3090": 71 * 10**12, |
Is this correct? Wikipedia only lists the 2:1 sparse TFLOPs, but the source cited in the article shows it as 143 TFLOPs. I think that would explain why you're seeing high MFU. On an A10 (with num_blocks = 3 because of a memory spike when Adam initialises, and batch_size = 8) I see:
...
step: 100 train loss: 2.77 test loss: 2.81 tokens/gpu/sec: 14982.88 MFU: 43.46%
This assumes the A10 has a peak 16-bit throughput of 125 TFLOPs. That is closely in line with what I've seen in nanoGPT experiments, and since the code matches nanoGPT closely, that would make sense.
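For reference, the MFU figure is just achieved FLOP/s divided by an assumed peak; a rough sketch of the nanoGPT-style arithmetic (the parameter count and peak below are placeholders, not this PR's exact config):

```python
def estimate_mfu(tokens_per_sec: float, n_params: float, peak_flops: float) -> float:
    """Rough MFU: ~6 FLOPs per parameter per token for forward + backward,
    ignoring the comparatively small attention term."""
    return tokens_per_sec * 6 * n_params / peak_flops

# Placeholder numbers: a ~0.6B-parameter model at the throughput above,
# against an assumed 125e12 FLOP/s peak for the A10.
print(f"{estimate_mfu(14982.88, 0.6e9, 125e12):.2%}")  # ~43%
```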
Please always consult the appropriate NVIDIA whitepaper when looking for GPU performance numbers. If you have trouble finding these PDFs in the future, I have developed a simple frontend that aggregates a number of manually collected sources.
Regarding the observed MFU on the 3090:
- All gamer GPUs have crippled fp32 accumulation, such that the effective performance of any tensor core is half of what it should be; see the table from the link above.
- All PyTorch matmuls use fp32 accumulation for all GEMMs; this cannot be changed at runtime, so all gamer GPUs on PyTorch get half the performance they would with fp16 accumulation.
- Because of this disadvantage, gamer GPUs also have a much lower FLOPs-to-memory-bandwidth ratio than datacenter GPUs, and it is much easier to achieve high MFU on them than on their datacenter equivalents.
I hope this addresses your concerns.
OK, so if I use a Triton matmul that accumulates in fp16, should I see the MFU on the A10 increase to around what you're seeing with the 3090?
The A10 has equivalent fp16 vs fp32 accum performance, unlike the 3090. This means that the A10 will always appear to have lower MFU relative to the 3090, regardless of accumulation precision.
Please note that the assumed peak tflops used in MFU calculation is based on the reported spec sheet values for fp32 accumulation. You may interpret this as the 3090 having an "easier target" to accomplish.
How about this for a "lazy coder's MFU": measure the FLOP ceiling using just PyTorch matmuls and take that as the peak achievable without extra work in Triton or C++. I wrote this script, and it measures the A10 at 69 TFLOPs; with that as the denominator, the results are pretty much exactly what you see on the 3090. Out of interest, what FLOP ceiling do you see on the 3090 with the same script?
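The script itself isn't reproduced in this thread; here is a sketch of what such a measurement might look like, reconstructed from the benchmark output below (the names matrix1, matrix2, matmul mirror that output, but everything else is an assumption):

```python
import torch
import torch.utils.benchmark as benchmark

# Time one large bf16 GEMM and divide its 2*N^3 FLOPs by the median wall time.
N = 32768
matrix1 = torch.randn(N, N, device="cuda", dtype=torch.bfloat16)
matrix2 = torch.randn(N, N, device="cuda", dtype=torch.bfloat16)

def matmul():
    return matrix1 @ matrix2

timer = benchmark.Timer(
    stmt="matmul()",
    setup="from __main__ import matrix1, matrix2, matmul",
    label="FLOP Ceiling Measurement",
    sub_label="Precision: bfloat16",
    description=f"Matrix size: {N}x{N}",
)
measurement = timer.blocked_autorange(min_run_time=10)
print(measurement)
print(f"FLOP Ceiling: {2 * N**3 / measurement.median / 1e12:.2f} TFLOP/s")
```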
I've done the simple lazy coder's approach in the past; it's usually within 95% of the spec on gamer GPUs. These are my 3090 results, which are not unexpected:
Running on device: cuda
<torch.utils.benchmark.utils.common.Measurement object at 0x7fa483f67b20>
FLOP Ceiling Measurement: Precision: bfloat16
Matrix size: 32768x32768
setup: from __main__ import matrix1, matrix2, matmul
Median: 1.08 s
IQR: 0.00 s (1.08 to 1.08)
10 measurements, 1 runs per measurement, 64 threads
FLOP Ceiling: 65.30 TFLOP/s
However, your observation of the A10 achieving only 69 TFLOPs surprised me, so I spun up A10 (Lambda Labs) and A10g (AWS) instances to test Triton fp16 versus torch fp32 accumulation.
A10
I am also able to replicate your script's results:
Running on device: cuda
<torch.utils.benchmark.utils.common.Measurement object at 0x7f3bfe2635b0>
FLOP Ceiling Measurement: Precision: bfloat16
Matrix size: 32768x32768
setup: from __main__ import matrix1, matrix2, matmul
Median: 1.01 s
IQR: 0.07 s (0.98 to 1.05)
99 measurements, 1 runs per measurement, 30 threads
FLOP Ceiling: 69.80 TFLOP/s
During the execution of your script, I noticed that the A10's GPU clock speed in nvtop is drastically limited at 100% utilization. When initially boosted, it goes up close to 1700MHz, but after a period of continuous execution, it settles around 1000MHz.
I believe that the A10's 150W TDP explains its reduced performance relative to the expected 125 TFLOPs spec.
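A small way to watch this while a benchmark runs in another process (assuming the nvidia-ml-py / pynvml package is available; this is just a monitoring sketch, not part of the PR):

```python
import time
import pynvml  # pip install nvidia-ml-py

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Log SM clock, power draw, and the enforced power limit once per second.
for _ in range(10):
    sm_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)  # MHz
    power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000                   # W
    limit = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000           # W
    print(f"SM clock: {sm_clock} MHz  power: {power:.0f}/{limit:.0f} W")
    time.sleep(1)

pynvml.nvmlShutdown()
```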
A10g
There is a separate spec sheet for the A10g, which indicates it has only 70 peak TFLOPs and a TDP of 300W.
At 300W, it achieves this much:
Limited to 150W, it performs as follows:
Special thanks to @neggles for quickly setting up tests on A10g, and for pointing out A10 GPU clock throttling.
This fork implements a small number of changes to benchmark the MFU and token throughput of a reasonably large GPT (7B-like dims, but only 4 layers and a 256-token vocab) on a single GPU.
On a 3090 (-pl 350), I obtain these results:
$ CUDA_VISIBLE_DEVICES=1 python examples/train-gpt.py --cuda
step: 100 train loss: 2.58 test loss: 2.59 tokens/gpu/sec: 10889.37 MFU: 74.21%
This is surprisingly decent, though fairly low compared to what a standard compiled implementation would achieve (~95%) on the same hardware.
A similar gap is observable on H100 (26.85% vs ~55%).
This PR is not (currently) meant to be merged, and is merely here to provide useful information. The code is currently hardcoded to use the 3090's FLOPs; although it isn't hard to auto-detect this, I didn't want to bloat the code diff for readability reasons.
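A minimal sketch of what that auto-detection could look like (not part of this PR's diff; the lookup keys mirror the table above, and the simple substring matching may need adjusting to the exact strings torch.cuda.get_device_name returns):

```python
import torch

# Peak 16-bit TFLOPs per device, mirroring the hardcoded table in the diff.
PEAK_FLOPS = {
    "v100-sxm": 125e12,
    "6000A": 364.25e12,
    "4090": 165.2e12,
    "3090": 71e12,
}

def detect_peak_flops(default: float = 71e12) -> float:
    """Look up peak FLOPs from the reported device name, falling back to
    a hardcoded default when the GPU isn't in the table."""
    name = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
    for key, flops in PEAK_FLOPS.items():
        if key.lower() in name:
            return flops
    return default
```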