4949)
5050from torch .ao .quantization .fx .utils import (
5151 _get_module ,
52- assert_and_get_unique_device ,
5352 collect_producer_nodes ,
54- create_getattr_from_value ,
5553 graph_module_from_producer_nodes ,
5654 node_arg_is_weight ,
5755)
7371
7472from torchao .quantization .pt2e import FROM_NODE_KEY
7573from torchao .quantization .pt2e .observer import _is_activation_post_process
76- from torchao .utils import TORCH_VERSION_AT_LEAST_2_6
74+ from torchao .quantization .pt2e .utils import create_getattr_from_value
75+ from torchao .utils import (
76+ TORCH_VERSION_AT_LEAST_2_6 ,
77+ _assert_and_get_unique_device ,
78+ )
7779
7880if TORCH_VERSION_AT_LEAST_2_6 :
7981 from torch .fx .traceback import NodeSource , NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
132134 modules : dict [str , torch .nn .Module ],
133135 node_name_to_scope : dict [str , tuple [str , type ]],
134136 node_name_to_qconfig : dict [str , QConfigAny ],
137+ model_device : Optional [torch .device ] = None ,
135138) -> None :
136139 """Replace activation_post_process module call node with quantize and
137140 dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
260263 # sure that the default overload can be used.
261264 # TODO: maybe need more complex attr name here
262265 qparam_node = create_getattr_from_value (
263- model , graph , module_path + prefix + key , value_or_node
266+ model ,
267+ graph ,
268+ module_path + prefix + key ,
269+ value_or_node ,
270+ model_device ,
264271 )
265272 quantize_op_inputs .append (qparam_node )
266273 else :
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
407414 modules : dict [str , torch .nn .Module ],
408415 node_name_to_scope : dict [str , tuple [str , type ]],
409416 node_name_to_qconfig : dict [str , QConfigAny ],
417+ model_device : Optional [torch .device ] = None ,
410418) -> None :
411419 """Replace activation_post_process module call node with quantize and
412420 dequantize node
@@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node(
487495 # For scale and zero_point values we register them as buffers in the root module.
488496 # TODO: maybe need more complex attr name here
489497 qparam_node = create_getattr_from_value (
490- model , graph , module_path + prefix + key , value_or_node
498+ model ,
499+ graph ,
500+ module_path + prefix + key ,
501+ value_or_node ,
502+ model_device ,
491503 )
492504 quantize_op_inputs .append (qparam_node )
493505 else :
@@ -785,6 +797,7 @@ def convert_weighted_module(
785797 backend_config : BackendConfig ,
786798 is_decomposed : bool = False ,
787799 is_reference : bool = False ,
800+ model_device : Optional [torch .device ] = None ,
788801) -> None :
789802 """Convert a weighted module to reference quantized module in the model
790803 If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@ def convert_weighted_module(
873886 is_ptq = weight_post_process is None
874887 if is_ptq :
875888 weight_post_process = qconfig .weight () # type: ignore[union-attr, operator]
876- device = assert_and_get_unique_device (float_module )
889+ if model_device is not None :
890+ device = model_device
891+ else :
892+ device = _assert_and_get_unique_device (float_module )
877893 if device :
878894 weight_post_process .to (device )
879895
@@ -1076,6 +1092,7 @@ def convert(
10761092 root_module_classes = tuple (root_module_to_quantized_reference_module .keys ())
10771093 qat_module_classes = get_qat_module_classes (backend_config )
10781094 fused_module_classes = get_fused_module_classes (backend_config )
1095+ model_device = _assert_and_get_unique_device (model )
10791096
10801097 for node in list (model .graph .nodes ):
10811098 if node .op == "placeholder" :
@@ -1123,6 +1140,7 @@ def convert(
11231140 modules ,
11241141 node_name_to_scope ,
11251142 node_name_to_qconfig ,
1143+ model_device ,
11261144 )
11271145 else :
11281146 _replace_observer_with_quantize_dequantize_node (
@@ -1131,6 +1149,7 @@ def convert(
11311149 modules ,
11321150 node_name_to_scope ,
11331151 node_name_to_qconfig ,
1152+ model_device ,
11341153 )
11351154 elif isinstance (mod , DeQuantStub ):
11361155 _replace_observer_or_dequant_stub_with_dequantize_node (
@@ -1160,6 +1179,7 @@ def convert(
11601179 backend_config ,
11611180 is_decomposed ,
11621181 is_reference ,
1182+ model_device ,
11631183 )
11641184
11651185 # remove deadcode after converting observers to quant/dequant ops
0 commit comments