@@ -566,7 +566,10 @@ def convert_initializer(
566566 to_type = self .high_precision_type ,
567567 )
568568
569- def _replace_tensor_name (self , consumers , original_tensor_name , new_tensor_name ):
569+ def _replace_tensor_name (
570+ self , consumers : list [onnx .NodeProto ], original_tensor_name : str , new_tensor_name : str
571+ ) -> None :
572+ """Replace occurrences of a tensor name in the given consumers' inputs with a new tensor name."""
570573 for consumer in consumers :
571574 for idx , inp in enumerate (consumer .input ):
572575 if inp == original_tensor_name :
@@ -583,8 +586,8 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
583586 # Check if the cast output is also a graph output
584587 is_output_producer = any (output .name == output_tensor for output in self .model .graph .output )
585588
586- # If the removed cast node is producing a network output, we need to update the node producing the cast, as
587- # the network output name should not be changed
589+ # If the removed cast node is producing a network output, update the producer of the cast input so
590+ # the network output name is preserved.
588591 if is_output_producer :
589592 producers = utils .get_producer_nodes (self .model , input_tensor )
590593 for producer in producers :
0 commit comments