#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

logger: logging.Logger = logging.getLogger()

try:
    # Open-source build: importing fbgemm_gpu loads and registers the ops.
    # pyre-ignore[21]
    from fbgemm_gpu import open_source  # noqa: F401
except Exception:
    # Internal build: load the sparse ops libraries from the Buck targets.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")

# Finite ranges of the half-precision formats (fp16 max is 65504; bf16 keeps
# fp32's exponent width, so its max is ~3.39e38).
TORCH_HALF_MIN: float = torch.finfo(torch.float16).min
TORCH_HALF_MAX: float = torch.finfo(torch.float16).max

TORCH_BFLOAT16_MIN: float = torch.finfo(torch.bfloat16).min
TORCH_BFLOAT16_MAX: float = torch.finfo(torch.bfloat16).max


def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp to the finite fp16 range before casting, so out-of-range fp32
    values saturate instead of overflowing to +/-inf."""
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def fp32_to_bf16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp to the finite bf16 range before casting to bfloat16."""
    return torch.clamp(tensor, TORCH_BFLOAT16_MIN, TORCH_BFLOAT16_MAX).bfloat16()


def fp32_to_hfp8_with_clamp(
    tensor: torch.Tensor, ebits: int = 4, mbits: int = 3, bias: int = 15
) -> torch.Tensor:
    """Quantize fp32 to hybrid fp8 (HFP8) with the given exponent/mantissa
    split, saturating at the format's largest finite magnitude."""
    # Largest finite HFP8 value: the formula takes the largest usable biased
    # exponent to be 2^ebits - 2 and the largest mantissa to be 2 - 2^-mbits.
    # For the defaults (ebits=4, mbits=3, bias=15):
    # 2^(16 - 2 - 15) * (2 - 2^-3) = 0.5 * 1.875 = 0.9375.
    max_pos: float = (2 ** ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits))
    return torch.ops.fbgemm.FloatToHFP8Quantized(
        tensor.contiguous(),
        ebits,
        bias,
        max_pos,
    )


def fp16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
    return tensor.float()


def bf16_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
    # view(torch.bfloat16) reinterprets the raw bits, so this also accepts
    # bf16 data carried in a same-width container dtype (e.g. int16).
    return tensor.view(torch.bfloat16).float()


def hfp8_to_fp32(tensor: torch.Tensor, ebits: int = 4, bias: int = 15) -> torch.Tensor:
    """Dequantize HFP8 bytes back to fp32; `ebits` and `bias` must match the
    values used for quantization."""
    return torch.ops.fbgemm.HFP8QuantizedToFloat(
        tensor.contiguous().view(torch.uint8),
        ebits,
        bias,
    )


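# Illustrative usage sketch (an addition, not part of the original module):
# round-trips a tensor through the clamped fp16/bf16 paths and logs the mean
# absolute error. The HFP8 path is left out because it depends on the fbgemm
# op registration above. `_roundtrip_demo` is a hypothetical helper name.
def _roundtrip_demo(num: int = 1024) -> None:
    x = torch.randn(num) * 100.0
    fp16_mae = torch.mean(torch.abs(x - fp16_to_fp32(fp32_to_fp16_with_clamp(x))))
    bf16_mae = torch.mean(torch.abs(x - bf16_to_fp32(fp32_to_bf16_with_clamp(x))))
    logger.info(
        "fp16 round-trip MAE: {}, bf16 round-trip MAE: {}.".format(fp16_mae, bf16_mae)
    )

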
def measure_fp16_quant_error(input_tensor: torch.Tensor) -> None:
    # TODO: log to tensorboard

    num_nan_fp32_tensor = torch.numel(input_tensor[torch.isnan(input_tensor)])
    logger.info(
        "num NaN in fp32 tensor: {}, ratio: {}.".format(
            num_nan_fp32_tensor, num_nan_fp32_tensor / torch.numel(input_tensor)
        )
    )

    logger.info(
        "fp32 tensor profile: min: {}, max: {}, min abs: {}, max abs: {}.".format(
            torch.min(input_tensor),
            torch.max(input_tensor),
            torch.min(torch.abs(input_tensor)),
            torch.max(torch.abs(input_tensor)),
        )
    )

    fp16_tensor = fp32_to_fp16_with_clamp(input_tensor)
    num_nan_fp16_tensor = torch.numel(fp16_tensor[torch.isnan(fp16_tensor)])

    logger.info(
        "num NaN in fp16 tensor: {}, ratio: {}.".format(
            num_nan_fp16_tensor, num_nan_fp16_tensor / torch.numel(input_tensor)
        )
    )

    diff = torch.abs(input_tensor - fp16_tensor.float())
    # Relative error is NaN where the input is 0 (0/0); those entries are
    # filtered out below.
    rel_diff = diff / torch.abs(input_tensor)
    logger.info(
        "fp32_to_fp16 abs error: min={}, max={}, avg={}.".format(
            torch.min(diff), torch.max(diff), torch.mean(diff)
        )
    )

    rel_diff_not_nan = rel_diff[torch.logical_not(torch.isnan(rel_diff))]
    logger.info(
        "fp32_to_fp16 rel error: min={}, max={}, avg={}.".format(
            torch.min(rel_diff_not_nan),
            torch.max(rel_diff_not_nan),
            torch.mean(rel_diff_not_nan),
        )
    )

    # rel error == 1 means the fp16 value is exactly 0, i.e. the input
    # underflowed below the smallest fp16 subnormal (~6e-8).
    rel_diff_1_idx = torch.where(rel_diff == 1.0)
    fp32_rel_err_1_vals = input_tensor[rel_diff_1_idx]
    if torch.numel(fp32_rel_err_1_vals) > 0:
        fp32_rel_err_1_vals = torch.abs(fp32_rel_err_1_vals)
        logger.info(
            "fp32_to_fp16 rel error == 1: fp32 min: {}, fp32 max: {}, fp32 avg: {}.".format(
                torch.min(fp32_rel_err_1_vals),
                torch.max(fp32_rel_err_1_vals),
                torch.mean(fp32_rel_err_1_vals),
            )
        )

        subrange_ratio = torch.numel(fp16_tensor[rel_diff_1_idx]) / torch.numel(
            fp16_tensor
        )
        logger.info("sub fp16 range ratio: {}".format(subrange_ratio))
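

if __name__ == "__main__":
    # Smoke test (illustrative addition, not part of the original module):
    # the 1e-8 scale puts most values below the smallest fp16 subnormal, so
    # the underflow (rel error == 1) branch above is exercised.
    logging.basicConfig(level=logging.INFO)
    measure_fp16_quant_error(torch.randn(4096) * 1e-8)
    _roundtrip_demo()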