@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
9595
9696 tensor_data_names = ["qdata" , "scale" ]
9797 tensor_attribute_names = []
98+ optional_tensor_data_names = ["test_only_data" ]
9899 optional_tensor_attribute_names = [
99100 "block_size" ,
100101 "mm_config" ,
@@ -109,6 +110,7 @@ def __new__(
109110 cls ,
110111 qdata : torch .Tensor ,
111112 scale : torch .Tensor ,
113+ test_only_data : Optional [torch .Tensor ] = None ,
112114 block_size : Optional [List [int ]] = None ,
113115 mm_config : Optional [Float8MMConfig ] = None ,
114116 hp_value_lb : Optional [float ] = None ,
@@ -128,6 +130,7 @@ def __init__(
128130 self ,
129131 qdata : torch .Tensor ,
130132 scale : torch .Tensor ,
133+ test_only_data : Optional [torch .Tensor ] = None ,
131134 block_size : Optional [List [int ]] = None ,
132135 mm_config : Optional [Float8MMConfig ] = None ,
133136 hp_value_lb : Optional [float ] = None ,
@@ -138,6 +141,7 @@ def __init__(
138141 ):
139142 self .qdata = qdata
140143 self .scale = scale
144+ self .test_only_data = test_only_data
141145 self .block_size = block_size
142146 self .mm_config = mm_config
143147 self .hp_value_lb = hp_value_lb
0 commit comments