@@ -2948,11 +2948,10 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
29482948@unittest .skipIf (not torch_version_at_least ("2.7.0" ), "Requires torch 2.7+" )
29492949class TestQuantizePT2EAffineQuantization (PT2EQuantizationTestCase ):
29502950 def test_channel_group_quantization (self ):
2951- from torchao .quantization import PerGroup , PerToken
29522951 from torchao .quantization .pt2e ._affine_quantization import (
29532952 AffineQuantizedMinMaxObserver ,
29542953 )
2955- from torchao .quantization .pt2e .observer import MappingType
2954+ from torchao .quantization .pt2e .observer import MappingType , PerGroup , PerToken
29562955
29572956 class BackendAQuantizer (Quantizer ):
29582957 def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
@@ -3032,13 +3031,13 @@ def forward(self, x):
30323031 def test_dynamic_affine_act_per_channel_weights (self ):
30333032 import operator
30343033
3035- from torchao .quantization import PerToken
30363034 from torchao .quantization .pt2e ._affine_quantization import (
30373035 AffineQuantizedMovingAverageMinMaxObserver ,
30383036 )
30393037 from torchao .quantization .pt2e .observer import (
30403038 MappingType ,
30413039 PerChannelMinMaxObserver ,
3040+ PerToken ,
30423041 )
30433042
30443043 class BackendAQuantizer (Quantizer ):
@@ -3123,14 +3122,12 @@ def forward(self, x):
31233122 def test_dynamic_per_tok_act_per_group_weights (self ):
31243123 import operator
31253124
3126- from torchao .quantization import PerGroup , PerToken
3127-
31283125 # TODO: merge into torchao observer
31293126 from torchao .quantization .pt2e ._affine_quantization import (
31303127 AffineQuantizedMinMaxObserver ,
31313128 AffineQuantizedPlaceholderObserver ,
31323129 )
3133- from torchao .quantization .pt2e .observer import MappingType
3130+ from torchao .quantization .pt2e .observer import MappingType , PerGroup , PerToken
31343131
31353132 class BackendAQuantizer (Quantizer ):
31363133 def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
0 commit comments