[Float8Configs] Make named tuples have better docs + public
stack-info: PR: #808, branch: drisspg/stack/9
drisspg committed Sep 4, 2024
1 parent f5703b0 commit ec334af
Showing 3 changed files with 37 additions and 27 deletions.
1 change: 1 addition & 0 deletions ruff.toml
@@ -7,4 +7,5 @@ include = [
"torchao/float8/float8_utils.py",
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
]
1 change: 1 addition & 0 deletions torchao/float8/__init__.py
@@ -40,6 +40,7 @@
"Float8GemmConfig",
"Float8LinearConfig",
"CastConfig",
"ScaledMMConfig",
# top level UX
"convert_to_float8_training",
"linear_requires_sync",
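With ScaledMMConfig added to __all__, the config type becomes part of the package's public surface. A minimal usage sketch, assuming __init__.py also imports the name from float8_tensor.py (that part of the file is outside the hunk shown above):

    from torchao.float8 import ScaledMMConfig

    # Emulate the matmuls in fp32; leave the remaining options at their defaults.
    config = ScaledMMConfig(emulate=True)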
62 changes: 35 additions & 27 deletions torchao/float8/float8_tensor.py
@@ -4,15 +4,13 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import enum
from collections import namedtuple
from typing import Dict, Optional
from typing import Dict, Optional, NamedTuple

import torch

import torch.distributed._functional_collectives as funcol
from torchao.float8.float8_utils import (
    e4m3_dtype,
    tensor_to_amax,
    to_fp8_saturated,
)
from torch.distributed._tensor import DTensor
@@ -47,31 +45,41 @@
# to configure all three gemms, also not user facing


# ScaledMMConfig is a namedtuple that defines the configuration for the scaled_mm in the forward and backward pass.
# emulate: whether to emulate the matmuls in fp32
# use_fast_accum: whether to use the fast-accumulation option for scaled_mm
# fp8_output: whether to output the result of the scaled_mm in fp8
# pad_inner_dim: whether to pad the inner dimension of a and b with 0s. This is needed for matmuls not aligned to 16.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)
class ScaledMMConfig(NamedTuple):
    """
    Configuration for the scaled_mm in the forward and backward pass.
    Attributes:
        emulate (bool): Whether to emulate the matmuls in fp32.
        use_fast_accum (bool): Whether to use the fast-accumulation option for scaled_mm.
        fp8_output (bool): Whether to output the result of the scaled_mm in fp8.
        pad_inner_dim (bool): Whether to pad the inner dimension of a and b with 0s.
            This is needed for matmuls not aligned to 16.
    """

# The object below is not user facing and exists for convenience,
# to allow Float8Tensor to use
# the right config based on which gemm from gemms with outputs
# `output`, `grad_input`, `grad_weight` is
# being called.
LinearMMConfig = namedtuple(
    "LinearMMConfig",
    ["output", "grad_input", "grad_weight"],
    defaults=[
        ScaledMMConfig(False, True, False, False),
        ScaledMMConfig(False, False, False, False),
        ScaledMMConfig(False, False, False, False),
    ],
)
    emulate: bool = False
    use_fast_accum: bool = False
    fp8_output: bool = False
    pad_inner_dim: bool = False


class LinearMMConfig(NamedTuple):
    """
    Configuration for different gemm operations in LinearMM.
    This configuration is not user-facing and exists for convenience,
    allowing Float8Tensor to use the right config based on which gemm
    from gemms with outputs `output`, `grad_input`, `grad_weight` is being called.
    Attributes:
        output (ScaledMMConfig): Configuration for the output gemm.
        grad_input (ScaledMMConfig): Configuration for the grad_input gemm.
        grad_weight (ScaledMMConfig): Configuration for the grad_weight gemm.
    """

    output: ScaledMMConfig = ScaledMMConfig(False, True, False, False)
    grad_input: ScaledMMConfig = ScaledMMConfig(False, False, False, False)
    grad_weight: ScaledMMConfig = ScaledMMConfig(False, False, False, False)


class GemmInputRole(enum.Enum):
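For reference, a brief usage sketch of the two new classes (not part of this commit; the field values below are illustrative):

    from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig

    # Per-gemm settings: pad inner dims for shapes not aligned to 16.
    padded = ScaledMMConfig(pad_inner_dim=True)

    # One ScaledMMConfig per gemm: output, grad_input, grad_weight.
    linear_config = LinearMMConfig(
        output=ScaledMMConfig(use_fast_accum=True),
        grad_input=padded,
        grad_weight=padded,
    )

    # NamedTuple fields are readable by name or by position.
    assert linear_config.output.use_fast_accum
    assert linear_config[1] is padded

This mirrors the defaults above: fast accumulation is enabled only for the output gemm, while the gradient gemms keep the conservative settings.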
