
Commit a6f5488

jianyuh authored and facebook-github-bot committed
Extract the quantization utils function in FBGEMM (#1204)
Summary:
Pull Request resolved: #1204

This will be better shared between TorchRec and HPC:

- We will refactor https://www.internalfb.com/code/fbsource/[history]/fbcode/caffe2/torch/fb/hpc/quantized_comms_lib.py to extract the common components into FBGEMM.
- FBGEMM is open source, so TorchRec can call the utilities from it.
- Reuse the quantize utils functions and dedup the code.

This part of the code is landable.

Reviewed By: YLGH

Differential Revision: D37799807

fbshipit-source-id: ced3c98efd096985db02c287449fa48939fd3da3
1 parent 49061a2 commit a6f5488

File tree

1 file changed: +124 −0 lines changed
@@ -0,0 +1,124 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

logger: logging.Logger = logging.getLogger()

try:
    # pyre-ignore[21]
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")

TORCH_HALF_MIN: float = torch.finfo(torch.float16).min  # -65504.0
TORCH_HALF_MAX: float = torch.finfo(torch.float16).max  # 65504.0

TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min  # ~ -3.39e38
TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max  # ~ 3.39e38


def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    # Clamp to fp16's representable range so out-of-range values saturate
    # instead of overflowing to inf.
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def fp32_to_bf16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    # bf16 shares fp32's exponent range, so the clamp only matters at the
    # very edge of fp32's range; it keeps the two paths symmetric.
    return torch.clamp(tensor, TORCH_BFLOAT16_MIN, TORCH_BFLOAT16_MAX).bfloat16()


def fp32_to_hfp8_with_clamp(
    tensor: torch.Tensor, ebits: int = 4, mbits: int = 3, bias: int = 15
) -> torch.Tensor:
    # Largest positive value representable in this hybrid fp8 format: the
    # maximum exponent is (2**ebits - 2 - bias) and a full mantissa
    # contributes a factor of (2 - 2**-mbits).
    max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
    return torch.ops.fbgemm.FloatToHFP8Quantized(
        tensor.contiguous(),
        ebits,
        bias,
        max_pos,
    )
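
# Worked example for the defaults (added for illustration; the arithmetic
# follows directly from the formula above): with ebits=4, mbits=3, bias=15
# the largest exponent is (1 << 4) - 2 - 15 = -1, so
# max_pos = 2**-1 * (2 - 2**-3) = 0.5 * 1.875 = 0.9375, i.e. inputs
# saturate to about [-0.9375, 0.9375] under these parameters.
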
def fp16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.float()


def bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
    # The payload may arrive typed as something other than bf16 (e.g. raw
    # bytes from comms), so reinterpret the bits as bf16 before widening.
    return tensor.view(torch.bfloat16).float()


def hfp8_to_fp32(tensor: torch.Tensor, ebits: int = 4, bias: int = 15) -> torch.Tensor:
    # HFP8 values are stored bytewise; reinterpret as uint8 before dequantizing.
    return torch.ops.fbgemm.HFP8QuantizedToFloat(
        tensor.contiguous().view(torch.uint8),
        ebits,
        bias,
    )


def measure_fp16_quant_error(input_tensor: torch.Tensor) -> None:
    # TODO: log to tensorboard

    num_nan_fp32_tensor = torch.numel(input_tensor[torch.isnan(input_tensor)])
    logger.info(
        "num NaN in fp32 tensor: {}, ratio: {}.".format(
            num_nan_fp32_tensor, num_nan_fp32_tensor / torch.numel(input_tensor)
        )
    )

    logger.info(
        "fp32 tensor profile: min: {}, max: {}, min abs:{}, max abs:{}.".format(
            torch.min(input_tensor),
            torch.max(input_tensor),
            torch.min(torch.abs(input_tensor)),
            torch.max(torch.abs(input_tensor)),
        )
    )

    fp16_tensor = fp32_to_fp16_with_clamp(input_tensor)
    num_nan_fp16_tensor = torch.numel(fp16_tensor[torch.isnan(fp16_tensor)])

    logger.info(
        "num NaN in fp16 tensor: {}, ratio: {}.".format(
            num_nan_fp16_tensor, num_nan_fp16_tensor / torch.numel(input_tensor)
        )
    )

    diff = torch.abs(input_tensor - fp16_tensor.float())
    rel_diff = diff / torch.abs(input_tensor)
    logger.info(
        "fp32_to_fp16 abs error: min={}, max={}, avg={}.".format(
            torch.min(diff), torch.max(diff), torch.mean(diff)
        )
    )

    rel_diff_not_nan = rel_diff[torch.logical_not(torch.isnan(rel_diff))]
    logger.info(
        "fp32_to_fp16 rel error: min={}, max={}, avg={}.".format(
            torch.min(rel_diff_not_nan),
            torch.max(rel_diff_not_nan),
            torch.mean(rel_diff_not_nan),
        )
    )

    # rel error == 1.0 means the fp16 value flushed to zero: the fp32 input
    # was nonzero but below fp16's representable range.
    rel_diff_1_idx = torch.where(rel_diff == 1.0)
    fp32_rel_err_1_vals = input_tensor[rel_diff_1_idx]
    if torch.numel(fp32_rel_err_1_vals) > 0:
        fp32_rel_err_1_vals = torch.abs(fp32_rel_err_1_vals)
        logger.info(
            "fp32_to_fp16 rel error == 1: fp32 min:{}, fp32 max:{}, fp32 avg:{}.".format(
                torch.min(fp32_rel_err_1_vals),
                torch.max(fp32_rel_err_1_vals),
                torch.mean(fp32_rel_err_1_vals),
            )
        )

        subrange_ratio = torch.numel(fp16_tensor[rel_diff_1_idx]) / torch.numel(
            fp16_tensor
        )
        logger.info("sub fp16 range ratio: {}".format(subrange_ratio))
