|
 
 # Note: Performance
 # Float8 experimental is intended to be run under `torch.compile` for competitive performance
+import contextlib
+import functools
+from typing import Optional
 
+import torch
 import torch.nn as nn
+from torch._logging import warning_once
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
 
-def build_fp8_linear(model: nn.Module, job_config: JobConfig):
+@contextlib.contextmanager
+def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    import float8_experimental.config as config
+
+    prev = config.enable_fsdp_fp8_all_gather
+    torch.distributed.barrier()
+    config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
+    try:
+        yield
+    finally:
+        torch.distributed.barrier()
+        config.enable_fsdp_fp8_all_gather = prev
+
+
+@functools.lru_cache(None)
+def is_sm90_or_later():
+    # Float8 is only supported on H100+ GPUs
+    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+def maybe_build_fp8_linear(
+    model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False
+):
     """
     This function converts the linear layers to `Float8Linear`. Note that today,
     only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    use_fp8_linear = job_config.training.fp8_linear
+    enable_float8_linear = job_config.training.enable_float8_linear
+    if not enable_float8_linear:
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Failed to swap to Float8Linear because SM90 or later is not available",
+        )
+        return
     try:
-        from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
+
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        enable_fsdp_float8_all_gather = (
+            job_config.training.enable_fsdp_float8_all_gather and dp_enabled
+        )
+        with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
+            swap_linear_with_float8_linear(
+                model, scaling_type_w=TensorScalingType.DYNAMIC
+            )
+        logger.info(
+            f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
+        )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if use_fp8_linear:
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        swap_linear_with_float8_linear(model, Float8Linear)
-        logger.info("Swapped to Float8Linear layers")
+
+
+def maybe_precompute_fp8_dynamic_scale_for_fsdp(
+    model: nn.Module, job_config: JobConfig
+):
+    if not (
+        job_config.training.enable_float8_linear
+        and job_config.training.enable_fsdp_float8_all_gather
+        and job_config.training.precompute_float8_dynamic_scale_for_fsdp
+    ):
+        return
+    if not is_sm90_or_later():
+        warning_once(
+            logger,
+            "Skipped precomputing fp8 scales because SM90 or later is not available",
+        )
+        return
+    from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+
+    precompute_float8_dynamic_scale_for_fsdp(model)
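
For context, a minimal usage sketch, not part of this diff, of how the setup-time helper might be wired into a training script. The import path torchtitan.float8_linear and the prepare_model wrapper are assumptions.

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import maybe_build_fp8_linear  # assumed module path

def prepare_model(model: torch.nn.Module, job_config: JobConfig, dp_enabled: bool):
    # Swap eligible nn.Linear modules for Float8Linear in place; the helper is a
    # no-op unless training.enable_float8_linear is set and an SM90+ GPU is present.
    maybe_build_fp8_linear(model, job_config, dp_enabled)
    # Per the note at the top of the file, float8_experimental is intended to be
    # run under torch.compile for competitive performance.
    return torch.compile(model)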
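
A companion sketch, also illustrative only, of where the per-iteration helper fits when FSDP float8 all-gather is enabled; the loop variables below (optimizer, loss_fn, batch, labels) are hypothetical.

from torchtitan.float8_linear import (  # assumed module path
    maybe_precompute_fp8_dynamic_scale_for_fsdp,
)

def train_step(model, optimizer, loss_fn, batch, labels, job_config):
    loss = loss_fn(model(batch), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # After the optimizer updates the high-precision weights, optionally precompute
    # the dynamic float8 scales used for FSDP's fp8 all-gather on the next forward.
    maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config)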