1212import torch .fx
1313import torch .nn .functional as F
1414from executorch .backends .arm .common .debug import get_node_debug_info
15+ from executorch .backends .arm .common .type import ensure_type
1516from executorch .backends .arm .quantizer import QuantizationConfig
1617from torch ._subclasses import FakeTensor
1718
@@ -510,7 +511,8 @@ def any_or_hardtanh_min_zero(n: Node):
510511 torch .ops .aten .minimum .default ,
511512 torch .ops .aten .maximum .default ,
512513 ):
513- shared_qspec = SharedQuantizationSpec ((node .args [0 ], node )) # type: ignore[arg-type]
514+ lhs_node = ensure_type (Node , node .args [0 ])
515+ shared_qspec = SharedQuantizationSpec ((lhs_node , node ))
514516 quant_properties .quant_inputs = [
515517 _QuantProperty (0 , input_act_qspec ),
516518 _QuantProperty (
@@ -520,22 +522,24 @@ def any_or_hardtanh_min_zero(n: Node):
520522 ]
521523 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
522524 elif node .target in (torch .ops .aten .where .self ,):
523- shared_qspec = SharedQuantizationSpec (node .args [1 ]) # type: ignore[arg-type]
525+ true_node = ensure_type (Node , node .args [1 ])
526+ shared_qspec = SharedQuantizationSpec (true_node )
524527 quant_properties .quant_inputs = [
525528 _QuantProperty (1 , shared_qspec ),
526529 _QuantProperty (2 , shared_qspec ),
527530 ]
528531 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
529532 elif node .target in _one_to_one_shared_input_or_input_act_qspec :
533+ input_node = ensure_type (Node , node .args [0 ])
530534 input_qspec = (
531- SharedQuantizationSpec (node . args [ 0 ]) # type: ignore[arg-type]
532- if is_output_annotated (node . args [ 0 ]) # type: ignore[arg-type]
535+ SharedQuantizationSpec (input_node )
536+ if is_output_annotated (input_node )
533537 else input_act_qspec
534538 )
535539 quant_properties .quant_inputs = [_QuantProperty (0 , input_qspec )]
536540 quant_properties .quant_output = _QuantProperty (
537541 0 ,
538- SharedQuantizationSpec ((node . args [ 0 ] , node )), # type: ignore[arg-type]
542+ SharedQuantizationSpec ((input_node , node )),
539543 )
540544 elif node .target in (
541545 torch .ops .aten .cat .default ,
@@ -550,26 +554,24 @@ def any_or_hardtanh_min_zero(n: Node):
550554 )
551555 if len (node .args [0 ]) == 0 :
552556 raise ValueError ("Expected non-empty list for node.args[0]" )
553-
554- shared_qspec = SharedQuantizationSpec ((node . args [0 ][ 0 ] , node )) # type: ignore[arg-type]
557+ inputs = [ ensure_type ( Node , element ) for element in node . args [ 0 ]]
558+ shared_qspec = SharedQuantizationSpec ((inputs [0 ], node ))
555559 quant_properties .quant_inputs = [
556560 _QuantProperty (
557561 0 ,
558- [
559- input_act_qspec if n == node .args [0 ][0 ] else shared_qspec # type: ignore[misc]
560- for n in node .args [0 ]
561- ],
562+ [input_act_qspec if n == inputs [0 ] else shared_qspec for n in inputs ],
562563 )
563564 ]
564565 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
565566 elif node .target in _one_to_one :
566567 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
567568 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
568569 elif node .target in _one_to_one_shared_input_qspec :
570+ input_node = ensure_type (Node , node .args [0 ])
569571 quant_properties .quant_inputs = [_QuantProperty (0 , input_act_qspec )]
570572 quant_properties .quant_output = _QuantProperty (
571573 0 ,
572- SharedQuantizationSpec ((node . args [ 0 ] , node )), # type: ignore[arg-type]
574+ SharedQuantizationSpec ((input_node , node )),
573575 )
574576 elif node .target in [
575577 torch .ops .aten .eq .Tensor ,
@@ -578,7 +580,8 @@ def any_or_hardtanh_min_zero(n: Node):
578580 torch .ops .aten .le .Tensor ,
579581 torch .ops .aten .lt .Tensor ,
580582 ]:
581- shared_qspec = SharedQuantizationSpec ((node .args [0 ], node )) # type: ignore[arg-type]
583+ input_node = ensure_type (Node , node .args [0 ])
584+ shared_qspec = SharedQuantizationSpec ((input_node , node ))
582585 quant_properties .quant_inputs = [
583586 _QuantProperty (0 , input_act_qspec ),
584587 _QuantProperty (
@@ -596,9 +599,10 @@ def any_or_hardtanh_min_zero(n: Node):
596599 quant_properties .quant_inputs = []
597600 quant_properties .quant_output = _QuantProperty (0 , output_act_qspec )
598601 elif node .target in [operator .getitem ]:
599- if not is_output_annotated (node .args [0 ]): # type: ignore[arg-type]
602+ input_node = ensure_type (Node , node .args [0 ])
603+ if not is_output_annotated (input_node ):
600604 return None
601- shared_qspec = SharedQuantizationSpec (node . args [ 0 ]) # type: ignore[arg-type]
605+ shared_qspec = SharedQuantizationSpec (input_node )
602606 quant_properties .quant_inputs = [_QuantProperty (0 , shared_qspec )]
603607 quant_properties .quant_output = _QuantProperty (0 , shared_qspec )
604608 else :
0 commit comments