
Commit

Fix docstring args names (#735)
kit1980 authored Aug 23, 2024
1 parent 9860194 commit 68e4643
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion torchao/float8/float8_utils.py
@@ -209,7 +209,7 @@ def pad_tensor_for_matmul(
Args:
tensor: The tensor to pad.
- both: Whether to pad both dimensions or just the second dimension.
+ dims: Dimensions to pad.
Returns:
torch.Tensor: The padded tensor.
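For context on the rename: the docstring now matches the actual `dims` parameter, which lists the dimensions to pad. Below is a minimal sketch of the behavior described, assuming the usual fp8-matmul requirement of rounding each listed dimension up to a multiple of 16; the helper is a hypothetical re-implementation, not the torchao function.

```python
import torch

def pad_to_multiple_of_16(tensor: torch.Tensor, dims=(0, 1)) -> torch.Tensor:
    # Pad each dimension listed in `dims` up to the next multiple of 16,
    # which fp8 matmul kernels typically require. (Assumed behavior.)
    pad = []
    for d in reversed(range(tensor.dim())):  # F.pad expects last-dim-first pairs
        extra = (-tensor.size(d)) % 16 if d in dims else 0
        pad.extend([0, extra])
    return torch.nn.functional.pad(tensor, pad)

x = torch.randn(129, 70)
print(pad_to_multiple_of_16(x, dims=(0, 1)).shape)  # torch.Size([144, 80])
```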
2 changes: 1 addition & 1 deletion torchao/float8/inference.py
@@ -151,7 +151,7 @@ def from_float(
Create an nn.Linear with fp8 compute from another nn.Linear
Args:
- mod (torch.nn.Linear): nn.Linear to convert
+ module (torch.nn.Linear): nn.Linear to convert
quant_config (QuantConfig): Configuration for the weight and activation casting
"""
forward_config = ScaledMMConfig(
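The rename only aligns the docstring with the actual parameter name, `module`. As a rough illustration of the fp8 weight handling such a converted inference Linear performs, here is a per-tensor scaled cast; this is a sketch, not the real `from_float` implementation, and the scaling scheme is assumed.

```python
import torch

lin = torch.nn.Linear(64, 32)
w = lin.weight.detach()
# Scale so the largest weight magnitude maps to fp8's max representable value,
# then cast; converted inference modules perform this kind of cast internally.
scale = torch.finfo(torch.float8_e4m3fn).max / w.abs().max()
w_fp8 = (w * scale).to(torch.float8_e4m3fn)
print(w_fp8.dtype)  # torch.float8_e4m3fn
```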
4 changes: 2 additions & 2 deletions torchao/ops.py
@@ -115,7 +115,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens
Args:
packed_w: torch.tensor: 4D tensor with shape `(N / 8) x (K / (inner_k_tiles * 16)) x 32 x inner_k_tiles / 2`, dtype is torch.int32
scales_and_zeros: torch.tensor: 3D tensor with shape `numQGroups x N x 2`, dtype is torch.bfloat16 where numQGroups is K / qGroupSize
- qGroupSize: int
+ group_size: int
inner_k_tiles: int
Returns:
@@ -158,4 +158,4 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles
torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1")
torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2")

- return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
+ return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
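A worked shape example for the tiled layout these checks enforce; the concrete N, K, group_size, and inner_k_tiles values are chosen purely for illustration.

```python
import torch

N, K, inner_k_tiles, group_size = 256, 1024, 8, 128

packed_w = torch.zeros(
    N // 8, K // (inner_k_tiles * 16), 32, inner_k_tiles // 2, dtype=torch.int32
)
num_q_groups = K // group_size
scales_and_zeros = torch.zeros(num_q_groups, N, 2, dtype=torch.bfloat16)

print(packed_w.shape)          # torch.Size([32, 8, 32, 4])
print(scales_and_zeros.shape)  # torch.Size([8, 256, 2])
# dequantize_tensor_core_tiled_layout returns an (N, K) bfloat16 tensor.
```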
@@ -12,7 +12,7 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
'''
Apply int N-bit weight only quantization to a linear layer.
Args:
- `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
+ `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
`n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
Usage:
from torchao.quantization import quantize_
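Hypothetical usage following the docstring's own Usage section; the module that defines intN_weight_only is not shown in this diff, so its import is left elided.

```python
import torch
from torchao.quantization import quantize_
# from ... import intN_weight_only  # defining module not shown in this diff

model = torch.nn.Sequential(torch.nn.Linear(128, 64))
# 4-bit weight-only quantization with group size 32 (both values are among
# the documented choices)
quantize_(model, intN_weight_only(group_size=32, n=4))
```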
12 changes: 6 additions & 6 deletions torchao/quantization/quant_primitives.py
@@ -33,7 +33,7 @@
class MappingType(Enum):
"""How floating point number is mapped to integer number
- symmetric mapping means floating point range is symetrically mapped to integer range
+ symmetric mapping means floating point range is symmetrically mapped to integer range
let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
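The corrected sentence's worked example, spelled out as a tiny sketch of the symmetric scale computation:

```python
lo, hi = -3.5, 10.2           # floating point range
qmin, qmax = -8, 7            # int4 range
amax = max(abs(lo), abs(hi))  # symmetrize to (-10.2, 10.2)
scale = (amax - (-amax)) / (qmax - qmin)
print(scale)                  # 1.36
```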
@@ -167,7 +167,7 @@ def quantize_affine(
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+ zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
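A sketch of the integer-domain case described here, where the zero point is added to the rounded value during quantization; the scale, zero point, and inputs below are arbitrary.

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 2.0])
scale, zero_point = 0.1, 3
quant_min, quant_max = -8, 7
# quant = clamp(round(x / scale) + zero_point, quant_min, quant_max)
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(torch.int8)
print(q)  # tensor([-7,  3,  7,  7], dtype=torch.int8)
```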
@@ -287,11 +287,11 @@ def dequantize_affine(
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
scale (Tensor): quantization parameter for affine quantization
zero_point (Tensor): quantization parameter for affine quantization
- dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+ input_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
quant_min (Optional[int]): minimum quantized value for input Tensor
quant_max (Optional[int]): maximum quantized value for input Tensor
output_dtype (torch.dtype): dtype for output Tensor, default is fp32
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+ zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
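And the inverse for dequantize_affine with an integer-domain zero point, using the same arbitrary values as the sketch above:

```python
import torch

q = torch.tensor([-7, 3, 7], dtype=torch.int8)
scale, zero_point = 0.1, 3
# dequant = (q - zero_point) * scale
x = (q.to(torch.float32) - zero_point) * scale
print(x)  # tensor([-1.0000,  0.0000,  0.4000])
```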
@@ -413,7 +413,7 @@ def fake_quantize_affine(
quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+ zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
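fake_quantize_affine returns a floating point tensor whose values have been snapped to the quantization grid; the round trip below sketches that behavior, reusing the same arbitrary parameters.

```python
import torch

x = torch.tensor([-1.0, 0.03, 0.5, 2.0])
scale, zero_point, quant_min, quant_max = 0.1, 3, -8, 7
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
fake_q = (q - zero_point) * scale  # quantize, then immediately dequantize
print(fake_q)  # tensor([-1.0000,  0.0000,  0.4000,  0.4000])
```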
@@ -549,7 +549,7 @@ def choose_qparams_affine(
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
- zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+ zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
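A sketch of asymmetric qparam selection in the integer zero_point domain, rounding zero_point so that 0.0 remains exactly representable; the int8 range and input values are assumed for illustration.

```python
import torch

x = torch.tensor([-1.0, 0.0, 2.0])
quant_min, quant_max = -128, 127  # int8
scale = (x.max() - x.min()) / (quant_max - quant_min)
# round zero_point so that real 0.0 lands exactly on an integer grid point
zero_point = int(torch.round(quant_min - x.min() / scale))
print(scale.item(), zero_point)  # ~0.01176, -43
```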
