11import logging
2- from typing import Callable , Dict , List , Optional , Sequence , Set
2+ from typing import Dict , List , Optional , Sequence , Set
33
44import torch
55
@@ -55,10 +55,6 @@ def __init__(
5555 )
5656
5757 self .min_block_size = min_block_size
58- logger .debug (
59- "Initialized Capability-Based Partitioner with available Converters:\n "
60- + f"{ CONVERTERS .display_all_available_converters ()} "
61- )
6258
6359 def propose_partitions (self ) -> List [Partition ]:
6460 # Propose partitions using the default, then refine the results
@@ -114,8 +110,8 @@ def __init__(self, support_dict=None, torch_executed_ops=set()):
114110 super ().__init__ (support_dict )
115111
116112 # Initialize sets of supported/unsupported operators
117- self .supported_operators = set ()
118- self .unsupported_operators = set ()
113+ self .supported_operators = {}
114+ self .unsupported_operators = {}
119115 self .torch_executed_ops = torch_executed_ops
120116
121117 def is_node_supported (
@@ -130,12 +126,18 @@ def is_node_supported(
130126 if node in CONVERTERS and node_name not in self .torch_executed_ops :
131127 # If node is a proper, supported computational node, store the operator
132128 if not node .is_impure ():
133- self .supported_operators .add (node_name )
129+ if node_name not in self .supported_operators :
130+ self .supported_operators [node_name ] = 1
131+ else :
132+ self .supported_operators [node_name ] += 1
134133
135134 return True
136135 else :
137136 if not node .is_impure ():
138- self .unsupported_operators .add (node_name )
137+ if node_name not in self .unsupported_operators :
138+ self .unsupported_operators [node_name ] = 1
139+ else :
140+ self .unsupported_operators [node_name ] += 1
139141
140142 return False
141143
@@ -147,15 +149,16 @@ def print_support_overview(self, num_trt_blocks: Optional[int] = None):
147149
148150 # Reformat support messages for debugger to print node overview as a single string
149151 supported_nodes_str = "\n Supported Nodes:\n "
150- for node_name in self .supported_operators :
151- supported_nodes_str += f"- { node_name } \n "
152+ for node_name , count in self .supported_operators . items () :
153+ supported_nodes_str += f"- { node_name } + Operator Count: { count } \n "
152154
153155 logger .debug (supported_nodes_str )
154156
155- if len ( self .unsupported_operators ) != 0 :
157+ if self .unsupported_operators :
156158 unsupported_nodes_str = "\n Unsupported or Excluded Nodes:\n "
157- for node_name in self .unsupported_operators :
158- unsupported_nodes_str += f"- { node_name } \n "
159+ for node_name , count in self .unsupported_operators .items ():
160+ unsupported_nodes_str += f"- { node_name } + Operator Count: { count } \n "
161+
159162 logger .debug (unsupported_nodes_str )
160163 else :
161164 logger .debug ("\n All Nodes Supported\n " )
0 commit comments