skip fp8 CI on non-H100 GPUs (pytorch#465)
skip fp8 tests on non-H100 GPUs by checking
`torch.cuda.get_device_capability() >= (9, 0)`

this makes the 4-GPU CI healthy again
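
For readers skimming past the diff, here is a minimal, self-contained sketch of the gating pattern the commit message describes. It mirrors the `is_sm90_or_later` helper added below; `run_fp8_smoke_test` is a hypothetical caller shown only for illustration.

```python
import functools

import torch


@functools.lru_cache(None)
def is_sm90_or_later() -> bool:
    # H100-class GPUs report CUDA compute capability (9, 0) or newer,
    # which the float8 kernels require; cache the result since it cannot
    # change within a run.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def run_fp8_smoke_test() -> None:
    # Hypothetical caller: skip quietly on older GPUs instead of failing CI.
    if not is_sm90_or_later():
        print("skipping fp8 path: SM90 or later is not available")
        return
    # ... fp8-specific work would go here ...
```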
weifengpy authored Jul 17, 2024
1 parent ae8181b commit 3760bcf
Showing 2 changed files with 50 additions and 20 deletions.
44 changes: 40 additions & 4 deletions torchtitan/float8_linear.py
@@ -13,10 +13,12 @@
# Note: Performance
# Float8 experimental is intended to be run under `torch.compile` for competitive performance
import contextlib
import functools
from typing import Optional

import torch
import torch.nn as nn
from torch._logging import warning_once

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
@@ -36,7 +38,13 @@ def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
config.enable_fsdp_fp8_all_gather = prev


def build_fp8_linear(
@functools.lru_cache(None)
def is_sm90_or_later():
# Float8 is only supported on H100+ GPUs
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def maybe_build_fp8_linear(
model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
):
"""
@@ -46,16 +54,24 @@ def build_fp8_linear(
This will mutate the model inplace.
"""
enable_fp8_linear = job_config.training.enable_fp8_linear
enable_fsdp_fp8_all_gather = (
job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
)
if not enable_fp8_linear:
return
if not is_sm90_or_later():
warning_once(
logger,
"Failed to swap to Float8Linear because SM90 or later is not available",
)
return
try:
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_fp8_all_gather = (
job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
)
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
swap_linear_with_float8_linear(
model, scaling_type_w=TensorScalingType.DYNAMIC
@@ -67,3 +83,23 @@ def build_fp8_linear(
raise ImportError(
"float8_experimental is not installed. Please install it to use fp8 linear layers."
) from exc


def maybe_precompute_fp8_dynamic_scale_for_fsdp(
model: nn.Module, job_config: JobConfig
):
if not (
job_config.training.enable_fp8_linear
and job_config.training.enable_fsdp_fp8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
return
if not is_sm90_or_later():
warning_once(
logger,
"Skipped precomputing fp8 scales because SM90 or later is not available",
)
return
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

precompute_float8_dynamic_scale_for_fsdp(model)
26 changes: 10 additions & 16 deletions train.py
@@ -27,7 +27,10 @@
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_hf_data_loader, create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.float8_linear import (
maybe_build_fp8_linear,
maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -215,9 +218,8 @@ def loss_fn(pred, labels):
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# apply fp8 linear module swap
if job_config.training.enable_fp8_linear:
build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
# swap to Float8Linear based on fp8 config
maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

# log model size
model_param_count = get_num_params(whole_model)
@@ -398,18 +400,10 @@ def loss_fn(pred, labels):
optimizers.step()
lr_schedulers.step()

if (
job_config.training.enable_fp8_linear
and job_config.training.enable_fsdp_fp8_all_gather
and job_config.training.precompute_float8_dynamic_scale_for_fsdp
):
from float8_experimental.fsdp_utils import (
precompute_float8_dynamic_scale_for_fsdp,
)

# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
precompute_float8_dynamic_scale_for_fsdp(model)
# when fp8 config is on,
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)

losses_since_last_log.append(loss)
