File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed
test/prototype/safetensors
torchao/prototype/safetensors Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -45,6 +45,7 @@ class TestSafeTensors(TestCase):
4545 (Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ()), False ),
4646 (Int4WeightOnlyConfig (), False ),
4747 (Int4WeightOnlyConfig (), True ),
48+ (Int4WeightOnlyConfig (int4_packing_format = "tile_packed_to_4d" ), False ),
4849 ],
4950 )
5051 def test_safetensors (self , config , act_pre_scale = False ):
Original file line number Diff line number Diff line change 66import torch
77
88import torchao
9- from torchao .quantization import Float8Tensor , Int4Tensor
9+ from torchao .quantization import (
10+ Float8Tensor ,
11+ Int4Tensor ,
12+ Int4TilePackedTo4dTensor ,
13+ )
1014from torchao .quantization .quantize_ .common import KernelPreference
1115from torchao .quantization .quantize_ .workflows import QuantizeTensorToFloat8Kwargs
1216
1317ALLOWED_CLASSES = {
1418 "Float8Tensor" : Float8Tensor ,
1519 "Int4Tensor" : Int4Tensor ,
20+ "Int4TilePackedTo4dTensor" : Int4TilePackedTo4dTensor ,
1621 "Float8MMConfig" : torchao .float8 .inference .Float8MMConfig ,
1722 "QuantizeTensorToFloat8Kwargs" : QuantizeTensorToFloat8Kwargs ,
1823 "PerRow" : torchao .quantization .PerRow ,
1924 "PerTensor" : torchao .quantization .PerTensor ,
2025 "KernelPreference" : KernelPreference ,
2126}
2227
23- ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor" , "Int4Tensor" ]
28+ ALLOWED_TENSORS_SUBCLASSES = [
29+ "Float8Tensor" ,
30+ "Int4Tensor" ,
31+ "Int4TilePackedTo4dTensor" ,
32+ ]
2433
2534__all__ = [
2635 "TensorSubclassAttributeJSONEncoder" ,
You can’t perform that action at this time.
0 commit comments