diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 868d4f52a..54613e5b0 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -209,7 +209,7 @@ def pad_tensor_for_matmul(
 
     Args:
         tensor: The tensor to pad.
-        both: Whether to pad both dimensions or just the second dimension.
+        dims: Dimensions to pad.
 
     Returns:
         torch.Tensor: The padded tensor.
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index f441009c4..ccf83d7ce 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -151,7 +151,7 @@ def from_float(
         Create an nn.Linear with fp8 compute from another nn.Linear
 
         Args:
-            mod (torch.nn.Linear): nn.Linear to convert
+            module (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
         """
         forward_config = ScaledMMConfig(
diff --git a/torchao/ops.py b/torchao/ops.py
index 4fcc8681a..cb337aabb 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -115,7 +115,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens
     Args:
         packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32
         scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize
-        qGroupSize: int
+        group_size: int
         inner_k_tiles: int
 
     Returns:
@@ -158,4 +158,4 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles
     torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1")
     torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2")
 
-    return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
\ No newline at end of file
+    return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
index 6ebe458a4..6ec933435 100644
--- a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
+++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
@@ -12,7 +12,7 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
     '''
     Apply int N-bit weight only quantization to a linear layer.
     Args:
-        `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
     Usage:
         from torchao.quantization import quantize_
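
A minimal usage sketch for the renamed `group_size` parameter, assuming the prototype module is importable from the file path shown above; the toy model is illustrative, not part of this change:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

    # Toy model; any module containing nn.Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(128, 256))

    # Int8 weight-only quantization with per-group granularity of 32 columns.
    quantize_(model, intN_weight_only(group_size=32, n=8, symmetric=False))
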
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 1ac97de3c..bd4656f6c 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -33,7 +33,7 @@ class MappingType(Enum):
     """How floating point number is mapped to integer number
 
-    symmetric mapping means floating point range is symetrically mapped to integer range
+    symmetric mapping means floating point range is symmetrically mapped to integer range
     let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
     we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
     e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
 
@@ -167,7 +167,7 @@ def quantize_affine(
         output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
         quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
         quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
             if zero_point is in integer domain, zero point is added to the quantized integer value during
             quantization
             if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
@@ -287,11 +287,11 @@ def dequantize_affine(
             e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
         scale (Tensor): quantization parameter for affine quantization
         zero_point (Tensor): quantization parameter for affine quantization
-        dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+        input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for input Tensor
         quant_min (Optional[int]): minimum quantized value for input Tensor
         quant_max (Optional[int]): maximum quantized value for input Tensor
         output_dtype (torch.dtype): dtype for output Tensor, default is fp32
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
             if zero_point is in integer domain, zero point is added to the quantized integer value during
             quantization
             if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
@@ -413,7 +413,7 @@ def fake_quantize_affine(
         quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
         quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
         quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
             if zero_point is in integer domain, zero point is added to the quantized integer value during
             quantization
             if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
@@ -549,7 +549,7 @@ def choose_qparams_affine(
             If we don't need zero to be exactly representable, we won't do rounding and clamping
             for zero_point
 
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
            if zero_point is in integer domain, zero point is added to the quantized integer value during
            quantization
            if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
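
A small worked sketch of the arithmetic the corrected MappingType and zero_point_domain docstrings describe; it mirrors the formulas in plain PyTorch rather than calling the library primitives, and the tensor values are illustrative:

    import torch

    # Symmetric mapping from the MappingType docstring: fp range (-3.5, 10.2)
    # is widened to (-10.2, 10.2), then mapped onto the int4 range (-8, 7).
    scale = (10.2 - (-10.2)) / (7 - (-8))  # 20.4 / 15 = 1.36

    # Integer zero_point domain: the zero point is added to the rounded
    # integer value during quantization (here 0, as symmetric mapping implies).
    x = torch.tensor([-3.5, 0.0, 10.2])
    zero_point = 0
    q = torch.clamp(torch.round(x / scale) + zero_point, min=-8, max=7)

    # Dequantization undoes it: subtract the zero point, then rescale.
    x_hat = (q - zero_point) * scale  # tensor([-4.08, 0.00, 9.52])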