-
Notifications
You must be signed in to change notification settings - Fork 245
/
Copy pathsubclass_fp8.py
198 lines (154 loc) · 6.52 KB
/
subclass_fp8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
from torch import Tensor
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
aten = torch.ops.aten
c10d_functional = torch.ops.c10d_functional
_c10d_functional = torch.ops._c10d_functional
DTYPE = torch.float8_e4m3fn
def quantize_fp8(input: Tensor, block_size: int):
shape = input.shape
input = input.view(-1, block_size)
scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
input = input / scale.view(-1, 1)
codes = input.to(DTYPE).view(-1)
return codes.view(shape), scale
# NOTE: FP8 sign bit is redundant for unsigned optim state.
# we may investigate how to use it to increase range/precision for unsigned optim state.
# https://arxiv.org/abs/2409.12517 uses FP8 E5M2 for 2nd Adam buffer
class OptimStateFp8(TorchAOBaseTensor):
    """Optimizer state stored as FP8 (E4M3FN) codes with block-wise absmax scales.

    The outer tensor mirrors the shape and device of ``codes``; the data itself lives in
    the inner ``codes`` and ``scale`` tensors listed in ``tensor_attrs``.
    """

    # Inner tensor attribute names, in the order __tensor_unflatten__ passes them to cls().
    tensor_attrs = ["codes", "scale"]

    @staticmethod
    def __new__(cls, codes: Tensor, scale: Tensor):
        # The wrapper tensor takes its shape and device from `codes`.
        return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device)

    def __init__(self, codes: Tensor, scale: Tensor):
        """Create quantized FP8 optimizer state.

        Args:
            codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor.
            scale: 1-D scale data for block-wise quantization (one entry per block).

        NOTE: To get block-wise scale, the original float tensor is first reshaped to
        (-1, block_size). Thus, the last dimension of the original float tensor is not
        necessarily divisible by block size. Given `codes` and `scale`, `block_size` is
        calculated as `codes.numel() // scale.numel()`.
        """
        assert codes.dtype is DTYPE
        assert scale.ndim == 1
        self.codes = codes
        self.scale = scale
        self.block_size = codes.numel() // scale.numel()

    def __tensor_flatten__(self):
        # (inner tensor attribute names, extra non-tensor metadata) -- no metadata needed.
        return self.tensor_attrs, []

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
    ):
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes
        )

    def dequantize(self, output_dtype=None):
        """Return ``codes * scale`` as a regular tensor with the shape of ``codes``.

        The multiplication is done in float32; the result is cast to ``output_dtype``
        when one is given.
        """
        float_data = self.codes.float()
        float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)
        if output_dtype is not None:
            float_data = float_data.to(output_dtype)
        return float_data.view(self.codes.shape)

    @classmethod
    def zeros(cls, shape, block_size: int = 256, device=None):
        """Create an all-zero state of ``shape`` with one scale per ``block_size`` elements.

        Raises:
            ValueError: if ``block_size`` is not positive, or the number of elements in
                ``shape`` is not a positive multiple of ``block_size``. Without this check
                the derived ``block_size`` would silently differ from the requested one
                (or ``__init__`` would divide by zero).
        """
        codes = torch.zeros(shape, dtype=DTYPE, device=device)
        numel = codes.numel()
        if block_size <= 0 or numel == 0 or numel % block_size != 0:
            raise ValueError(
                f"OptimStateFp8.zeros requires numel to be a positive multiple of "
                f"block_size. Received shape={tuple(codes.shape)} (numel={numel}) "
                f"and block_size={block_size}."
            )
        scale = torch.zeros(numel // block_size, device=device)
        return cls(codes, scale)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(block_size={self.block_size}, "
            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
        )
@OptimStateFp8.implements(aten.copy_.default)
def _(func, types, args, kwargs):
    """In-place ``dst.copy_(src)`` where either side may be an ``OptimStateFp8``.

    - both quantized: copy codes and scale directly (block sizes must match);
    - only ``dst`` quantized: quantize ``src`` with ``dst.block_size`` and store it;
    - only ``src`` quantized: dequantize ``src`` into the plain ``dst``.

    Returns:
        ``dst``.

    Raises:
        ValueError: if both sides are quantized with different block sizes.
    """
    dst = args[0]
    src = args[1]
    if isinstance(dst, OptimStateFp8) and isinstance(src, OptimStateFp8):
        # Explicit check instead of `assert` so it survives `python -O`; without it a
        # smaller `src.scale` could be silently broadcast by `scale.copy_`.
        if dst.block_size != src.block_size:
            raise ValueError(
                f"Cannot copy OptimStateFp8 with block_size={src.block_size} into "
                f"OptimStateFp8 with block_size={dst.block_size}."
            )
        dst.codes.copy_(src.codes)
        dst.scale.copy_(src.scale)
    elif isinstance(dst, OptimStateFp8):
        codes, scale = quantize_fp8(src, dst.block_size)
        dst.codes.copy_(codes)
        dst.scale.copy_(scale)
    else:
        dst.copy_(src.dequantize())
    return dst
@OptimStateFp8.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    """Move an ``OptimStateFp8`` to the ``device`` found in kwargs.

    Only ``device`` is read; ``dtype`` and every other kwarg are ignored, so the codes
    keep their FP8 dtype and the scale keeps its dtype.
    """
    src = args[0]
    target_device = kwargs.get("device", None)
    moved = OptimStateFp8(
        src.codes.to(device=target_device),
        src.scale.to(device=target_device),
    )
    return return_and_correct_aliasing(func, args, kwargs, moved)
@OptimStateFp8.implements(aten.lerp.Scalar)
def _(func, types, args, kwargs):
    """Run ``lerp`` on dequantized inputs and return ``func``'s result as-is.

    The output is not re-quantized; storing it back into a quantized state is
    left to ``copy_``.
    """
    plain_args = [
        arg.dequantize() if isinstance(arg, OptimStateFp8) else arg
        for arg in args
    ]
    result = func(*plain_args, **kwargs)
    return result
# this is needed for DTensor.from_local()
@OptimStateFp8.implements(aten.view.default)
def _(func, types, args, kwargs):
    """View the codes with a new shape; the flat per-block scale tensor is reused as-is."""
    src, new_shape = args
    reshaped_codes = src.codes.view(new_shape)
    return OptimStateFp8(reshaped_codes, src.scale)
@OptimStateFp8.implements(
    [
        # required by DTensor.full_tensor()
        c10d_functional.all_gather_into_tensor.default,
        _c10d_functional.all_gather_into_tensor.default,
        c10d_functional.wait_tensor.default,
        _c10d_functional.wait_tensor.default,
        # required by torch.distributed.checkpoint.save
        aten.detach.default,
    ]
)
def _(func, types, args, kwargs):
    """Apply ``func`` to codes and scale separately and rewrap the two results.

    Extra positional args and kwargs are forwarded unchanged to both calls.

    Raises:
        ValueError: if the first argument is not an ``OptimStateFp8``.
    """
    state = args[0]
    if isinstance(state, OptimStateFp8):
        extra_args = args[1:]
        # Assumes every rank uses the same block_size: the rewrapped state recomputes
        # block_size as codes.numel() // scale.numel() from the gathered tensors.
        new_codes = func(state.codes, *extra_args, **kwargs)
        new_scale = func(state.scale, *extra_args, **kwargs)
        return OptimStateFp8(new_codes, new_scale)
    raise ValueError(f"expecting a OptimStateFp8 but found {type(state)}")
# required by torch.distributed.checkpoint.save
# note that we don't actually implement pin memory for this tensor subclass
# (pin_memory argument is ignored in aten._to_copy)
@OptimStateFp8.implements(aten.is_pinned.default)
def _(func, types, args, kwargs):
    """Report pinned only when both inner tensors (codes, then scale) are pinned."""
    state = args[0]
    codes_pinned = state.codes.is_pinned()
    return codes_pinned and state.scale.is_pinned()
# required by torch.distributed.checkpoint.load when world size changes i.e. re-sharding
@OptimStateFp8.implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    """Slice an ``OptimStateFp8`` along dim 0 (step 1 only), keeping codes and scale aligned.

    Arguments follow ``aten.slice.Tensor(self, dim=0, start=None, end=None, step=1)``;
    trailing arguments may be omitted or passed as kwargs. ``start``/``end`` are
    normalized with Python slice semantics against ``x.shape[0]`` (``None``, negative
    and out-of-range values), so open-ended slices such as ``x[a:]`` are handled.

    Raises:
        ValueError: if slicing along any dim other than 0, if ``step != 1``, or if the
            normalized ``start``/``end`` do not fall on a quantization-block boundary.
    """
    kwargs = kwargs or {}

    def _get(index, name, default):
        # positional value if present, otherwise kwarg, otherwise the schema default
        return args[index] if len(args) > index else kwargs.get(name, default)

    x = args[0]
    dim = _get(1, "dim", 0)
    start = _get(2, "start", None)
    end = _get(3, "end", None)
    step = _get(4, "step", 1)

    # input validation
    if dim < 0:
        dim += x.ndim
    if dim != 0:
        raise ValueError("Only support aten.slice along the first dim")
    if step != 1:
        raise ValueError("Only support aten.slice with step=1")

    # Normalize like Python slicing. Without clamping, an open-ended `end` (which may
    # arrive as a very large integer) would fail the block-alignment check below.
    size = x.shape[0]
    start = 0 if start is None else start
    end = size if end is None else end
    if start < 0:
        start += size
    if end < 0:
        end += size
    start = min(max(start, 0), size)
    end = min(max(end, start), size)

    block_size = x.block_size
    stride = math.prod(x.shape[1:])

    # for 1 increment in x along the first dim,
    # (flattened) scale will increment by stride / block_size
    if (start * stride) % block_size != 0 or (end * stride) % block_size != 0:
        raise ValueError(
            f"Invalid start or end for shape={x.shape} and block_size={block_size}. "
            f"Make sure start and end align with block boundary. "
            f"Received start={start}, end={end}."
        )

    return OptimStateFp8(
        x.codes[start:end],
        x.scale[start * stride // block_size : end * stride // block_size],
    )
if TORCH_VERSION_AT_LEAST_2_5:
    # Register the subclass with torch.serialization's safe-globals allowlist
    # (only available from torch 2.5, hence the version guard).
    from torch.serialization import add_safe_globals

    add_safe_globals(safe_globals=[OptimStateFp8])