make float8 README.md examples standalone
Updates the two float8 README.md examples (for dynamic and delayed scaling) to be standalone

Test plan: copy-paste each code sample and execute it; it runs successfully.
vkuzo authored Sep 4, 2024
1 parent f5703b0 commit 75d5a3d
Showing 1 changed file with 38 additions and 46 deletions: torchao/float8/README.md
# torchao.float8

This is a workflow for accelerating training with float8 in native PyTorch
according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, debuggable with native PyTorch tooling,
and composable with key systems such as autograd, ```torch.compile``` and distributed.
With ``torch.compile`` on, initial results show
throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.

:warning: <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.</em>

:warning: <em>The codebase is stable, but backwards compatibility is not yet guaranteed.</em>

# Single GPU User API

We provide three per-tensor scaling strategies: dynamic, delayed and static.

## float8 linear with dynamic scaling

This is the most accurate recipe as every tensor is scaled dynamically.
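
For intuition: with dynamic scaling, the scale is recomputed from the current tensor's absolute maximum every time the tensor is cast to float8. A rough sketch of the idea follows (not torchao's internal code; the e4m3 max value of 448 and the clamping epsilon are illustrative assumptions):

```python
import torch

def dynamic_fp8_scale(t: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    # map the tensor's current abs-max onto the float8 max representable value
    amax = torch.clamp(t.abs().max(), min=1e-12)
    return fp8_max / amax
```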

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```
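
To sanity-check the conversion, one can inspect the module types after `convert_to_float8_training` has run. A minimal sketch, assuming the same `m` and `module_filter_fn` as above and that `Float8Linear` is exposed from `torchao.float8`:

```python
# assumes `m` was converted as in the example above
from torchao.float8 import Float8Linear

for name, child in m.named_modules():
    if isinstance(child, torch.nn.Linear):  # Float8Linear subclasses nn.Linear
        print(name, type(child).__name__)
# with the filter above, module "0" should print as Float8Linear while
# module "1" should remain a plain Linear
```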

## float8 linear with delayed scaling

This is theoretically the most performant recipe as it minimizes memory reads.
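
For intuition: with delayed scaling, the scale used to cast a tensor is derived from amax values recorded in earlier iterations rather than from the current tensor, so the cast does not need an extra pass over the data. A rough sketch of the idea follows (not torchao's internal code; the e4m3 max value and clamping are illustrative assumptions):

```python
import torch

def delayed_fp8_scale(amax_history: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    # amax_history holds abs-max values recorded during earlier iterations;
    # the "stale" scale derived from it is reused for the current cast.
    stale_amax = torch.clamp(amax_history.max(), min=1e-12)
    return fp8_max / stale_amax
```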

```python
import torch
import torch.nn as nn
from torchao.float8 import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# configure delayed scaling
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    # enable_amax_init=False, # only needed for autocast + compile + FSDP + float8 delayed
    # enable_pre_and_post_forward=False # only needed for autocast + compile + FSDP + float8 delayed
)

# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
convert_to_float8_training(m, config=config)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
```
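
One optional tweak (an assumption on our part, not something this example requires): since `sync_float8_amax_and_scale_history` walks every `Float8Linear` in the model, it can itself be wrapped in `torch.compile` to reduce its overhead on larger models. A hedged sketch:

```python
import torch
from torchao.float8 import sync_float8_amax_and_scale_history

# compile the sync step once, outside the training loop
compiled_sync = torch.compile(sync_float8_amax_and_scale_history)
# then call compiled_sync(m) in the training loop above, in place of
# sync_float8_amax_and_scale_history(m)
```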