Commit 870284f
[pt2e] Avoid getting model device once per node (#2695)
**Summary:** Previously, we called `assert_and_get_unique_device` once per node in both prepare and convert. This is expensive and unnecessary: the model device is the same across all nodes, so we should call it once at the beginning and reuse the result for every node. This is the torchao version of pytorch/pytorch#159901.

Note: The prepare path is not completely migrated yet, since we are blocked on the PyTorch PR being merged. It differs from convert in that it still calls utility functions from `torch.ao.quantization.fx`.

**Test Plan:**

```
python test/quantization/pt2e/test_quantize_pt2e.py
```
1 parent 8555713 · commit 870284f
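The core of the change is a simple hoist: compute the model device once per pass instead of once per graph node. A minimal sketch of the pattern, assuming a single-device model (`get_unique_device` and `annotate_nodes` below are hypothetical stand-ins, not the actual torchao helpers):

```python
from typing import Optional

import torch


def get_unique_device(model: torch.nn.Module) -> Optional[torch.device]:
    # Simplified stand-in for torchao.utils._assert_and_get_unique_device:
    # scan every parameter and buffer and assert they agree on one device.
    devices = {p.device for p in model.parameters()} | {
        b.device for b in model.buffers()
    }
    assert len(devices) <= 1, f"expected at most one device, got {devices}"
    return next(iter(devices)) if devices else None


def annotate_nodes(model: torch.fx.GraphModule) -> None:
    # Before this commit, the lookup above effectively ran inside the loop,
    # i.e. one full parameter/buffer scan per graph node. Hoisting it out
    # makes the scan happen once per pass; each node reuses the result.
    model_device = get_unique_device(model)
    for node in model.graph.nodes:
        node.meta["model_device"] = model_device  # reuse, no per-node scan


if __name__ == "__main__":
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    annotate_nodes(gm)
    print(gm.graph)
```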

4 files changed: +63 −11 lines changed

torchao/quantization/pt2e/convert.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -49,9 +49,7 @@
 )
 from torch.ao.quantization.fx.utils import (
     _get_module,
-    assert_and_get_unique_device,
     collect_producer_nodes,
-    create_getattr_from_value,
     graph_module_from_producer_nodes,
     node_arg_is_weight,
 )
@@ -74,6 +72,8 @@
 
 from torchao.quantization.pt2e import FROM_NODE_KEY
 from torchao.quantization.pt2e.observer import _is_activation_post_process
+from torchao.quantization.pt2e.utils import create_getattr_from_value
+from torchao.utils import _assert_and_get_unique_device
 
 __all__ = [
     "convert",
@@ -129,6 +129,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node working with decomposed Tensor
@@ -255,7 +256,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
             # sure that the default overload can be used.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -402,6 +407,7 @@ def _replace_observer_with_quantize_dequantize_node(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node
@@ -482,7 +488,11 @@
             # For scale and zero_point values we register them as buffers in the root module.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -780,6 +790,7 @@ def convert_weighted_module(
     backend_config: BackendConfig,
     is_decomposed: bool = False,
     is_reference: bool = False,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
@@ -868,7 +879,10 @@
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
-        device = assert_and_get_unique_device(float_module)
+        if model_device is not None:
+            device = model_device
+        else:
+            device = _assert_and_get_unique_device(float_module)
         if device:
             weight_post_process.to(device)
 
@@ -1071,6 +1085,7 @@ def convert(
     root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
     qat_module_classes = get_qat_module_classes(backend_config)
     fused_module_classes = get_fused_module_classes(backend_config)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in list(model.graph.nodes):
         if node.op == "placeholder":
@@ -1118,6 +1133,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
                 else:
                     _replace_observer_with_quantize_dequantize_node(
@@ -1126,6 +1142,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
             elif isinstance(mod, DeQuantStub):
                 _replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1155,6 +1172,7 @@
                 backend_config,
                 is_decomposed,
                 is_reference,
+                model_device,
             )
 
     # remove deadcode after converting observers to quant/dequant ops
```

torchao/quantization/pt2e/observer.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -1908,10 +1908,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
         else:
             scale, zero_point = self.calculate_qparams()
             scale_node = create_getattr_from_value(
-                model, model.graph, "_scale", scale
+                model,
+                model.graph,
+                "_scale",
+                scale,
+                scale.device if isinstance(scale, torch.Tensor) else None,
             )
             zero_point_node = create_getattr_from_value(
-                model, model.graph, "_zero_point", zero_point
+                model,
+                model.graph,
+                "_zero_point",
+                zero_point,
+                zero_point.device if isinstance(zero_point, torch.Tensor) else None,
             )
 
             q_node = model.graph.call_function(
```
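The `isinstance` guard in this hunk exists because qparams may be plain Python scalars rather than tensors, and only tensors carry a device. A standalone illustration of that guard (toy values, not the observer code):

```python
import torch

# Tensors know their device; Python floats/ints do not, so the observer
# passes None in that case and lets create_getattr_from_value fall back
# to its own device lookup.
for qparam in (torch.tensor(0.1), 0.1):
    device = qparam.device if isinstance(qparam, torch.Tensor) else None
    print(type(qparam).__name__, "->", device)
```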

torchao/quantization/pt2e/prepare.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -38,6 +38,7 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
+from torchao.utils import _assert_and_get_unique_device
 
 # TODO: make pt2e folder private?
 __all__ = [
@@ -408,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -426,6 +428,7 @@
                 named_modules,
                 obs_or_fq_map,
                 is_qat,
+                model_device,
             )
             new_arg_to_return.append(new_inner_arg)
         return type(arg)(new_arg_to_return)
@@ -478,6 +481,7 @@
         return maybe_obs_node
 
     assert isinstance(model.graph, Graph)
+    # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
    new_arg = _insert_obs_or_fq(
        arg, input_edge_obs_or_fq, model, named_modules, model.graph
    )
@@ -491,6 +495,7 @@ def _maybe_insert_input_observers_for_node(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -517,6 +522,7 @@
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
         new_args.append(new_arg)
 
@@ -541,9 +547,11 @@ def _maybe_insert_output_observer_for_node(
     graph: Graph,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Optional[Node]:
     if node in obs_or_fq_map:
         output_act_obs_or_fq = obs_or_fq_map[node]
+        # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
         new_output = _insert_obs_or_fq(
             node, output_act_obs_or_fq, model, named_modules, graph
         )
@@ -563,6 +571,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ):
     this_node_quantization_annotation = (
         node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
@@ -578,6 +587,7 @@
         named_modules,
         obs_or_fq_map,
         is_qat,
+        model_device,
     )
 
     output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -586,7 +596,13 @@
 
     # this returns the new observer node if it was needed
     maybe_output_obs_node = _maybe_insert_output_observer_for_node(
-        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+        node,
+        model,
+        named_modules,
+        model.graph,
+        obs_or_fq_map,
+        is_qat,
+        model_device,
     )
 
     if maybe_output_obs_node is None:
@@ -634,11 +650,16 @@ def prepare(
     )
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
         _maybe_insert_input_and_output_observers_for_node(
-            node, model, obs_or_fq_map, is_qat
+            node,
+            model,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
     model = GraphModule(model, model.graph)
```

torchao/quantization/pt2e/utils.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -525,15 +525,20 @@ def get_attr_name(i: int):
 
 
 def create_getattr_from_value(
-    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+    module: torch.nn.Module,
+    graph: Graph,
+    prefix: str,
+    value: Any,
+    device: Optional[torch.device] = None,
 ) -> Node:
     """
     Given a value of any type, creates a getattr node corresponding to the value and
     registers the value as a buffer to the module.
     """
     get_new_attr_name = get_new_attr_name_with_prefix(prefix)
     attr_name = get_new_attr_name(module)
-    device = _assert_and_get_unique_device(module)
+    if device is None:
+        device = _assert_and_get_unique_device(module)
     new_value = (
         value.detach().clone()
         if isinstance(value, torch.Tensor)
```
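With the new optional `device` parameter, a caller that has already resolved the model device can skip the per-call scan entirely. A hedged usage sketch, assuming the import paths and signature shown in the diff above:

```python
import torch
from torch.fx import symbolic_trace

from torchao.quantization.pt2e.utils import create_getattr_from_value
from torchao.utils import _assert_and_get_unique_device

model = symbolic_trace(torch.nn.Linear(8, 8))
# One parameter/buffer scan for the whole pass...
model_device = _assert_and_get_unique_device(model)

# ...then every qparam buffer reuses it. Omitting the last argument
# (or passing None) falls back to the old behavior: a fresh
# _assert_and_get_unique_device(module) call inside the helper.
scale_node = create_getattr_from_value(
    model, model.graph, "_scale", torch.tensor(0.05), model_device
)
print(scale_node)
```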
