|
1 | 1 | import torch |
2 | 2 | from typing import Tuple, Optional, Union |
3 | 3 | import torchao.ops |
4 | | -from collections import defaultdict |
5 | | -import functools |
6 | 4 | import math |
7 | 5 | from torchao.quantization.quant_primitives import ( |
8 | 6 | choose_qparams_affine, |
|
36 | 34 | _is_float8_type |
37 | 35 | ) |
38 | 36 | import logging |
| 37 | +from torchao.float8.inference import Float8MMConfig |
39 | 38 |
|
40 | 39 | logger = logging.getLogger(__name__) |
41 | 40 |
|
42 | | -from torchao.float8.inference import Float8MMConfig |
43 | 41 | aten = torch.ops.aten |
44 | 42 |
|
45 | 43 |
|
@@ -1024,7 +1022,6 @@ def from_plain( |
1024 | 1022 | packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) |
1025 | 1023 | scale = scale.reshape(int_data.shape[0], -1) |
1026 | 1024 | zero_point = zero_point.reshape(int_data.shape[0], -1) |
1027 | | - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros |
1028 | 1025 | scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) |
1029 | 1026 | return cls(packed_weight, scale_and_zero, False, layout_type) |
1030 | 1027 |
|
@@ -1232,7 +1229,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): |
1232 | 1229 |
|
1233 | 1230 |
|
1234 | 1231 | def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): |
1235 | | - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" |
| 1232 | + assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" |
1236 | 1233 | assert input_tensor.shape[-1] == weight_tensor.shape[1], ( |
1237 | 1234 | f"need input_tensor shape: {input_tensor.shape} final" |
1238 | 1235 | f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " |
@@ -1292,7 +1289,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): |
1292 | 1289 | # per channel int8 weight only quantized mm |
1293 | 1290 | w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() |
1294 | 1291 | scale = weight_tensor.layout_tensor.scale |
1295 | | - orig_dtype = input_tensor.dtype |
1296 | 1292 | m = torch.mm( |
1297 | 1293 | input_tensor.reshape(-1, input_tensor.shape[-1]), |
1298 | 1294 | w_vals_int8_t.to(input_tensor.dtype), |
@@ -1424,7 +1420,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, |
1424 | 1420 | ) |
1425 | 1421 |
|
1426 | 1422 | def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): |
1427 | | - from torchao.sparsity.marlin import marlin_24_workspace, const |
| 1423 | + from torchao.sparsity.marlin import marlin_24_workspace |
1428 | 1424 | assert isinstance(weight_tensor, AffineQuantizedTensor) |
1429 | 1425 |
|
1430 | 1426 | sparse_w_int4 = weight_tensor.layout_tensor.int_data |
|
0 commit comments