Skip to content

Commit ac75d21

Browse files
authored
Merge branch 'main' into wengshiy/int8_scaled_embedding_bag
2 parents 912a8e6 + 5e90c47 commit ac75d21

File tree

15 files changed

+992 -53 lines changed

15 files changed

+992 -53 lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2948,11 +2948,10 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
29482948
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
29492949
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
29502950
def test_channel_group_quantization(self):
2951-
from torchao.quantization import PerGroup, PerToken
29522951
from torchao.quantization.pt2e._affine_quantization import (
29532952
AffineQuantizedMinMaxObserver,
29542953
)
2955-
from torchao.quantization.pt2e.observer import MappingType
2954+
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
29562955

29572956
class BackendAQuantizer(Quantizer):
29582957
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -3032,13 +3031,13 @@ def forward(self, x):
30323031
def test_dynamic_affine_act_per_channel_weights(self):
30333032
import operator
30343033

3035-
from torchao.quantization import PerToken
30363034
from torchao.quantization.pt2e._affine_quantization import (
30373035
AffineQuantizedMovingAverageMinMaxObserver,
30383036
)
30393037
from torchao.quantization.pt2e.observer import (
30403038
MappingType,
30413039
PerChannelMinMaxObserver,
3040+
PerToken,
30423041
)
30433042

30443043
class BackendAQuantizer(Quantizer):
@@ -3123,14 +3122,12 @@ def forward(self, x):
31233122
def test_dynamic_per_tok_act_per_group_weights(self):
31243123
import operator
31253124

3126-
from torchao.quantization import PerGroup, PerToken
3127-
31283125
# TODO: merge into torchao observer
31293126
from torchao.quantization.pt2e._affine_quantization import (
31303127
AffineQuantizedMinMaxObserver,
31313128
AffineQuantizedPlaceholderObserver,
31323129
)
3133-
from torchao.quantization.pt2e.observer import MappingType
3130+
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
31343131

31353132
class BackendAQuantizer(Quantizer):
31363133
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

test/test_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@
4040
except RuntimeError:
4141
pytest.skip("torchao.ops not available")
4242

43+
from torchao.quantization import PerGroup, PerRow, PerTensor
44+
from torchao.quantization.quant_primitives import (
45+
_choose_scale_float8,
46+
_dequantize_affine_float8,
47+
_quantize_affine_float8,
48+
)
4349
from torchao.quantization.utils import (
50+
get_block_size,
4451
get_groupwise_affine_qparams,
4552
groupwise_affine_dequantize_tensor_from_qparams,
4653
groupwise_affine_quantize_tensor_from_qparams,

0 commit comments

Comments
 (0)