Add performance section to float8 README.md
vkuzo authored Sep 2, 2024
1 parent 65f660d commit bc41160
Showing 1 changed file with 51 additions and 0 deletions.
torchao/float8/README.md
@@ -122,6 +122,57 @@ We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
on using `torchao.float8` in a distributed setting.

# Performance

A common question about float8 training is: "when is a float8 linear faster than bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark-based speedup estimate on an NVIDIA H100:

<img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">

Example 1 (small shapes):
* forward input tensor size 1024x2048, linear weight size 2048x1024; M, K, N = 1024, 2048, 1024
* benchmark speedup is 0.80
* recommendation: leave this linear in bfloat16; the shapes are too small to benefit from float8 compute

Example 2 (large shapes):
* forward input tensor size 4096x8192, linear weight size 8192x16384; M, K, N = 4096, 8192, 16384
* benchmark speedup is 1.39
* recommendation: enable float8 for this linear to get a speedup
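
As a concrete illustration, here is a minimal sketch of how the shapes in the examples above map to M, K, N. The `linear_mkn` helper is hypothetical (it is not part of torchao), and the weight is written as (K, N) to match the examples above, whereas `torch.nn.Linear` stores its weight as (N, K).

```python
# Hypothetical helper: map the shapes from the examples above to M, K, N.
# Note: this README writes the weight as (K, N); torch.nn.Linear stores it as (N, K).

def linear_mkn(input_shape, weight_shape_kn):
    M, K = input_shape          # forward input activation is (M, K)
    K_w, N = weight_shape_kn    # weight written as (K, N), matching the examples above
    assert K == K_w, "inner gemm dimensions must match"
    return M, K, N

# Example 1 (small shapes): measured speedup 0.80 -> leave in bfloat16
print(linear_mkn((1024, 2048), (2048, 1024)))   # (1024, 2048, 1024)

# Example 2 (large shapes): measured speedup 1.39 -> enable float8
print(linear_mkn((4096, 8192), (8192, 16384)))  # (4096, 8192, 16384)
```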

To reproduce the raw data for the table above, you can run the following script:

```bash
python benchmarks/float8/float8_roofline.py your_output_filename.csv --gemm_time_strategy benchmarks --shape_gen_name sweep
```

## Derivation

For a bf16 linear, assume all of the time is spent in gemms. For a float8 linear, also account for the max_abs and casting overhead. We want to know when

```
bf16_gemm_time > fp8_gemm_time + fp8_overhead_time
```

Or, equivalently,

```
bf16_gemm_time - fp8_gemm_time > fp8_overhead_time
```

There are three observations we can make about the formula above:
1. LHS > 0 for large shapes, with the gemm speedup approaching 2x as M, K, N increase
2. LHS < 0 for small shapes, on NVIDIA H100 + cuBLAS
3. RHS > 0 for all shapes, bounded by memory bandwidth, framework overhead and compiler limitations

For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude and the speedup depends on M, K, N and framework and compiler behavior. For large shapes, (1) leads to speedup > 1.
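
To make the inequality concrete, here is a minimal sketch of the speedup estimate it implies. All of the times below are made-up placeholders in arbitrary units, not measurements; the real model lives in `benchmarks/float8/float8_roofline.py`.

```python
# Minimal sketch of the speedup model implied by the inequality above.
# The times are hypothetical placeholders in arbitrary units, not measurements.

def float8_linear_speedup(bf16_gemm_time, fp8_gemm_time, fp8_overhead_time):
    # float8 wins exactly when bf16_gemm_time - fp8_gemm_time > fp8_overhead_time,
    # i.e. when this ratio is greater than 1
    return bf16_gemm_time / (fp8_gemm_time + fp8_overhead_time)

# small shapes: little-to-no gemm speedup and the overhead dominates -> speedup < 1
print(float8_linear_speedup(bf16_gemm_time=1.0, fp8_gemm_time=0.9, fp8_overhead_time=0.4))   # ~0.77

# large shapes: gemm speedup approaches 2x and the overhead is comparatively small -> speedup > 1
print(float8_linear_speedup(bf16_gemm_time=10.0, fp8_gemm_time=5.5, fp8_overhead_time=1.0))  # ~1.54
```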

## Scaling type vs speedup

Delayed scaling is theoretically faster than dynamic scaling because of its reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so the observed performance of delayed scaling is close to that of dynamic scaling. As the torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant than dynamic scaling.
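
For reference, a minimal sketch of how the scaling type is selected, assuming the `convert_to_float8_training` / `Float8LinearConfig` / `CastConfig` / `ScalingType` API from the quickstart sections of this README; consult those sections for the authoritative usage, including the amax/scale sync that delayed scaling requires in the training loop.

```python
import torch

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# hypothetical toy model; dynamic scaling is the default when no config is passed
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# opt into delayed scaling for the input, weight and grad_output casts
delayed_config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
convert_to_float8_training(m, config=delayed_config)

# note: with delayed scaling, amaxes and scales must also be synced in the training
# loop (see the delayed scaling quickstart earlier in this README)
```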

## torch.compile behavior vs speedup

There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As these limitations get resolved, we expect performance to improve.
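
As a minimal sketch (same hypothetical toy model as above), the kernels in question are the ones produced when the converted model is compiled:

```python
import torch

from torchao.float8 import convert_to_float8_training

# hypothetical toy model, converted with the default dynamic scaling
m = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()
convert_to_float8_training(m)

# torch.compile is what generates the float8 scaling/casting kernels discussed above;
# without it, the casting overhead is paid in eager mode and is typically higher
m = torch.compile(m)
```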

# Testing

```bash
