
Commit 5091d35

Lint fixes;
1 parent 0601b5c commit 5091d35

6 files changed (+21, -11 lines)

ruff.toml

Lines changed: 7 additions & 0 deletions
@@ -2,6 +2,8 @@
 # We plan to add files in chunks using the 'include' list below.
 # To add a new path: Simply add it to the 'include' list.
 # Example: To lint all files in every subfolder of 'test', add "test/**/*"
+# To exclude a file type: Simply add it to the 'exclude' list.
+# Example: To exclude all markdown files, add "**/*.md"
 include = [
     "torchao/float8/inference.py",
     "torchao/float8/float8_utils.py",
@@ -10,4 +12,9 @@ include = [
     "torchao/float8/float8_tensor.py",
     "torchao/quantization/linear_activation_weight_observer.py",
     "test/quantization/test_observer.py",
+    "torchao/dtypes/*"
+]
+
+exclude = [
+    "**/*.md"
 ]
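The net effect is that everything under torchao/dtypes/ is now in the lint scope while markdown files are kept out of it. As a rough illustration of how such glob patterns classify paths, here is a small sketch using Python's fnmatch; ruff's own matcher may differ in edge cases, and the paths below are hypothetical:

from fnmatch import fnmatch

# Hypothetical paths, used only to illustrate the include/exclude globs above.
paths = [
    "torchao/dtypes/affine_quantized_tensor.py",
    "torchao/float8/inference.py",
    "docs/source/overview.md",
]

for path in paths:
    included = fnmatch(path, "torchao/dtypes/*")
    excluded = fnmatch(path, "**/*.md")
    print(f"{path}: included={included}, excluded={excluded}")

Only the first path matches the new include entry, and only the markdown file matches the exclude entry.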

torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 __all__ = [
     "NF4Tensor",
     "to_nf4",
-    "UInt4Tensor"
+    "UInt4Tensor",
     "AffineQuantizedTensor",
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 7 deletions
@@ -1,8 +1,6 @@
 import torch
 from typing import Tuple, Optional, Union
 import torchao.ops
-from collections import defaultdict
-import functools
 import math
 from torchao.quantization.quant_primitives import (
     choose_qparams_affine,
@@ -36,10 +34,10 @@
     _is_float8_type
 )
 import logging
+from torchao.float8.inference import Float8MMConfig
 
 logger = logging.getLogger(__name__)
 
-from torchao.float8.inference import Float8MMConfig
 aten = torch.ops.aten
 
 
@@ -1024,7 +1022,6 @@ def from_plain(
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
-        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
         return cls(packed_weight, scale_and_zero, False, layout_type)
 
@@ -1232,7 +1229,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
-    assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
+    assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
     assert input_tensor.shape[-1] == weight_tensor.shape[1], (
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
@@ -1292,7 +1289,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
     # per channel int8 weight only quantizated mm
     w_vals_int8_t = weight_tensor.layout_tensor.int_data.t()
     scale = weight_tensor.layout_tensor.scale
-    orig_dtype = input_tensor.dtype
     m = torch.mm(
         input_tensor.reshape(-1, input_tensor.shape[-1]),
         w_vals_int8_t.to(input_tensor.dtype),
@@ -1424,7 +1420,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
     )
 
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
-    from torchao.sparsity.marlin import marlin_24_workspace, const
+    from torchao.sparsity.marlin import marlin_24_workspace
     assert isinstance(weight_tensor, AffineQuantizedTensor)
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
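Most of these edits drop imports and a local variable (orig_dtype) that were assigned but never used, and move the Float8MMConfig import up into the module's import block. The assert-message fix deserves a note: block_size is not a name in that function's scope, and because the f-string message is only evaluated when the assertion fails, the old code would raise a NameError at exactly the moment the error message was needed. A small standalone sketch of the failure mode, using a stand-in function rather than torchao code:

def check_buggy(weight_block_size):
    # Bug: the message references an undefined name, so a failing assert
    # raises NameError while formatting the message.
    assert weight_block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"

def check_fixed(weight_block_size):
    # Fix: reference the value that is actually in scope.
    assert weight_block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_block_size}"

try:
    check_buggy((32, 64))
except NameError as exc:
    print("buggy:", exc)

try:
    check_fixed((32, 64))
except AssertionError as exc:
    print("fixed:", exc)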

torchao/dtypes/fpx/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -1 +1,9 @@
 from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP
+
+__all__ = [
+    "FpxTensorCoreAQTLayout",
+    "FpxTensorCoreLayoutType",
+    "to_scaled_tc_fpx",
+    "from_scaled_tc_fpx",
+    "_SPLIT_K_MAP",
+]
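Declaring __all__ here marks the names imported from .fpx as intentional re-exports, which is what keeps an unused-import lint (ruff's F401) from flagging them, and it also pins down what from torchao.dtypes.fpx import * exposes. A sketch of the same pattern on a hypothetical package; names and layout are illustrative, not torchao's:

# mypkg/__init__.py (hypothetical)
# mypkg/core.py defines Widget and make_widget; this file only re-exports them.
from .core import Widget, make_widget

# Listing the names marks the imports above as deliberate re-exports
# and fixes what `from mypkg import *` brings in.
__all__ = [
    "Widget",
    "make_widget",
]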

torchao/dtypes/fpx/fpx.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from torchao.dtypes.utils import (
     LayoutType,
 )
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
 from dataclasses import dataclass
 from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls
 
torchao/dtypes/uintx/bitpacking.py

Lines changed: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def unpack_cpu(data: List[torch.Tensor],
             output_narrow = output.narrow(dim, j * group_size, group_size)
             group = data[i] & unpack_mask[bit_size][j]
             shift_amt = j * bit_size - rel_pos
-            output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos)))
+            output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, shift_amt)))
     return output
 
 # these are faster on the GPU
@@ -193,7 +193,7 @@ def _unpack(data, element_size, scale, dim):
 
     for i in range(scale):
        shift_amt = element_size * i
-       chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits)
+       unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits)
 
     return unpacked_data
 
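Both fixes remove dead results rather than change behavior: the first reuses the shift_amt computed on the previous line instead of recomputing it, and the second drops a chunk = binding whose value was never read (the in-place copy_ already writes into unpacked_data). For orientation, the underlying operation is ordinary shift-and-mask unpacking of sub-byte values; a minimal standalone sketch for 2-bit elements packed into uint8, illustrative only and not torchao's actual packing layout:

import torch

def pack_2bit(values: torch.Tensor) -> torch.Tensor:
    # Pack each group of four 2-bit values (0..3) into one uint8 byte.
    v = values.to(torch.uint8).reshape(-1, 4)
    return v[:, 0] | (v[:, 1] << 2) | (v[:, 2] << 4) | (v[:, 3] << 6)

def unpack_2bit(packed: torch.Tensor) -> torch.Tensor:
    # Recover the four values per byte by shifting and masking.
    parts = [(packed >> shift) & 0b11 for shift in (0, 2, 4, 6)]
    return torch.stack(parts, dim=-1).reshape(-1)

vals = torch.tensor([0, 1, 2, 3, 3, 2, 1, 0], dtype=torch.uint8)
assert torch.equal(unpack_2bit(pack_2bit(vals)), vals)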
