@@ -579,24 +579,23 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
579579
580580 input_tensor = node .input [0 ]
581581 output_tensor = node .output [0 ]
582- is_output_producer = False
583582
584- # If removed cast node is producing a network output, we need to update the node producing the cast
585- # Network output name should not be changed
586- for output in self . model . graph . output :
587- if output . name == output_tensor :
588- is_output_producer = True
589- producers = utils . get_producer_nodes ( self . model , input_tensor )
590- for producer in producers :
591- for i , prod_out in enumerate ( producer . output ) :
592- if prod_out == input_tensor :
593- producer . output [ i ] = output_tensor
594- consumers = utils . get_consumer_nodes ( self . model , prod_out )
595- if len ( consumers ) > 1 :
596- self . _replace_tensor_name (consumers , prod_out , output_tensor )
597- if (
598- not is_output_producer
599- ): # Reconnect consumers of the cast output to use the cast input instead
583+ # Check if the cast output is also a graph output
584+ is_output_producer = any ( output . name == output_tensor for output in self . model . graph . output )
585+
586+ # If the removed cast node is producing a network output, we need to update the node producing the cast, as
587+ # the network output name should not be changed
588+ if is_output_producer :
589+ producers = utils . get_producer_nodes ( self . model , input_tensor )
590+ for producer in producers :
591+ for i , prod_out in enumerate ( producer . output ) :
592+ if prod_out == input_tensor :
593+ producer . output [ i ] = output_tensor
594+ consumers = utils . get_consumer_nodes ( self . model , prod_out )
595+ if len (consumers ) > 1 :
596+ self . _replace_tensor_name ( consumers , prod_out , output_tensor )
597+ else :
598+ # Reconnect consumers of the cast output to use the cast input instead
600599 consumers = utils .get_consumer_nodes (self .model , output_tensor )
601600 for consumer in consumers :
602601 for i , input_name in enumerate (consumer .input ):
0 commit comments