diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py
index 862ff0a9a..7959f6a8a 100644
--- a/auto_round/data_type/mxfp.py
+++ b/auto_round/data_type/mxfp.py
@@ -77,6 +77,12 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
     return tensor
 
 
+def quant_element_ste(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
+    with torch.no_grad():
+        tensor_q = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
+    return (tensor_q - tensor).detach() + tensor
+
+
 def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs):
     """Quantize the given tensor using the specified parameters.
 
@@ -120,7 +126,7 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
     scale = torch.pow(2, shared_exp)
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
-    tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
+    tensor = quant_element_ste(tensor, ebits, mbits, max_norm, mantissa_rounding)
     tensor = tensor * scale
     tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len)
 
@@ -171,7 +177,7 @@ def quant_mx_rceil(
     scale = torch.pow(2, shared_exp)
     tensor = tensor / scale + v
     tensor = torch.clamp(tensor, min=-max_norm, max=max_norm)
-    tensor = quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding)
+    tensor = quant_element_ste(tensor, ebits, mbits, max_norm, mantissa_rounding)
     tensor = tensor * scale
     tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len)