
[Not for land] Integrate float8nocompile, an experimental feature for high performance #778

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.10.15
10 changes: 10 additions & 0 deletions torchtitan/config_manager.py
@@ -567,6 +567,16 @@ def __init__(self):
default="dynamic",
help="float8 scaling for input, dynamic (default) or delayed",
)
self.parser.add_argument(
"--float8.float8nocompile",
action="store_true",
help="use the float8nocompile prototype implementation",
)
self.parser.add_argument(
"--float8.float8nocompile_no_precompute_for_backward",
action="store_true",
help="use activation checkpointing with float8nocompile linear layers",
)

# communications library settings
self.parser.add_argument(
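Reviewer note (not part of the diff): a minimal sketch of how the two new flags are expected to surface on the parsed config, assuming torchtitan's usual mapping of `--section.option` flags to `job_config.section.option` attributes, which is the same mapping float8.py relies on below; `JobConfig.parse_args` taking an explicit argv list is an assumption based on existing usage.

# Hypothetical usage sketch; the flag names are the ones added above.
from torchtitan.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args(
    [
        "--float8.enable_float8_linear",
        "--float8.float8nocompile",
        "--float8.float8nocompile_no_precompute_for_backward",
    ]
)
assert job_config.float8.float8nocompile
assert job_config.float8.float8nocompile_no_precompute_for_backward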
62 changes: 54 additions & 8 deletions torchtitan/float8.py
@@ -13,7 +13,7 @@
# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance

from typing import List, Union
from typing import Callable, List, Union

import torch
import torch.nn as nn
@@ -47,6 +47,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
"torchao is not installed. Please install it to use float8 linear layers."
) from e

self.use_float8nocompile = float8_config.float8nocompile
self.ac_config = job_config.activation_checkpoint

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
@@ -90,14 +93,40 @@ def convert_to_float8_training(self, model: nn.Module):
if not self.enabled:
return

from torchao.float8 import convert_to_float8_training
if self.use_float8nocompile:
logger.info("Using float8nocompile prototype")
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
convert_to_float8_nocompile_training,
)

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)
# for full AC or no AC
no_precompute_for_backward = self.ac_config.mode == "full"
convert_to_float8_nocompile_training(
model,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
no_precompute_for_backward=no_precompute_for_backward,
)

# for selective per layer AC
if (
self.ac_config.mode == "selective"
and self.ac_config.selective_ac_option.isdigit()
):
no_precompute_for_backward_every_nth_layer(
model,
int(self.ac_config.selective_ac_option),
)
else:
logger.info("Using float8 training")
from torchao.float8 import convert_to_float8_training

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)
logger.info(
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
f"{self.config.enable_fsdp_float8_all_gather}"
@@ -145,3 +174,20 @@ def sync_float8_amax_and_scale_history(
models = [model] if isinstance(model, nn.Module) else model
for m in models:
self._sync_float8_amax_and_scale_history(m)


def no_precompute_for_backward_every_nth_layer(model: nn.Module, n: int):
"""Set no_precompute_for_backward to True for every nth layer in the model."""
for layer_idx, (layer_id, layer) in enumerate(model.layers.named_children()):
if layer_idx % n == 0:
logger.info(f"Enabling no_precompute_for_backward for layer {layer_id}")
_enable_no_precompute_for_backward(layer)


def _enable_no_precompute_for_backward(model: nn.Module):
"""Recursively set no_precompute_for_backward to True for all linear layers in the given model."""
for child_layer in model.children():
if isinstance(child_layer, nn.Linear):
child_layer.no_precompute_for_backward = True
else:
_enable_no_precompute_for_backward(child_layer)
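For clarity (not part of the diff), a toy illustration of the every-nth-layer helper above. The model only needs a `layers` container of child blocks; every n-th block gets `no_precompute_for_backward = True` stamped onto each of its nn.Linear submodules, and how that attribute is consumed is left to the float8nocompile prototype. The ToyModel below is hypothetical.

import torch.nn as nn

from torchtitan.float8 import no_precompute_for_backward_every_nth_layer


class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 4):
        super().__init__()
        # Mirrors the `model.layers` container the helper iterates over.
        self.layers = nn.ModuleDict(
            {str(i): nn.Sequential(nn.Linear(8, 8), nn.ReLU()) for i in range(num_layers)}
        )


model = ToyModel()
no_precompute_for_backward_every_nth_layer(model, n=2)

flagged = [
    layer_id
    for layer_id, layer in model.layers.named_children()
    if getattr(layer[0], "no_precompute_for_backward", False)
]
print(flagged)  # layers 0 and 2 were flagged -> ['0', '2']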
21 changes: 16 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -7,6 +7,7 @@
# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

import os
from collections import defaultdict

import torch
@@ -299,11 +300,21 @@ def apply_compile(model: nn.Module):
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
repeated structure. Alternatively one can compile the whole model (after applying DP).
"""
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)

logger.info("Compiling each TransformerBlock with torch.compile")
compile_linear_only = bool(os.environ.get("TORCHTITAN_COMPILE_LINEAR_ONLY", False))

if compile_linear_only:
logger.info("Compiling linear layers with torch.compile")
for name, child in model.named_children():
if isinstance(child, torch.nn.Linear):
new_child = torch.compile(child)
setattr(model, name, new_child)
else:
apply_compile(child)
else:
logger.info("Compiling each TransformerBlock with torch.compile")
for layer_id, transformer_block in model.layers.named_children():
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)
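Usage note on the new escape hatch (my reading of the check above, not stated in the PR): bool(os.environ.get(...)) treats any non-empty string as true, so setting TORCHTITAN_COMPILE_LINEAR_ONLY to "0" still enables linear-only compilation; only leaving the variable unset keeps the default per-TransformerBlock path. A minimal sketch with a hypothetical toy module:

import os

import torch.nn as nn

from torchtitan.parallelisms.parallelize_llama import apply_compile

# Any non-empty value enables the linear-only branch ("0" included).
os.environ["TORCHTITAN_COMPILE_LINEAR_ONLY"] = "1"

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
apply_compile(toy)  # each nn.Linear child is swapped for a torch.compile-wrapped module

# Unsetting the variable restores the default path, which compiles each
# block in `model.layers` with fullgraph=True.
os.environ.pop("TORCHTITAN_COMPILE_LINEAR_ONLY", None)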


def apply_fsdp(
2 changes: 2 additions & 0 deletions train_configs/llama3_8b.toml
@@ -56,3 +56,5 @@ selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac ba

[float8]
enable_float8_linear = false
float8nocompile = false # TODO: should this go in [experimental]?
float8nocompile_no_precompute_for_backward = false