diff --git a/graph_net/torch/decompose_util.py b/graph_net/torch/decompose_util.py index acbc4f8e0..c986b629d 100755 --- a/graph_net/torch/decompose_util.py +++ b/graph_net/torch/decompose_util.py @@ -85,6 +85,7 @@ def get_end_node_idx(range_idx): num_subgraphs = len(range_idx2submodule_body_nodes) for range_idx in range(num_subgraphs): + use_all_inputs = use_all_inputs and range_idx == 0 start, end = range_idx2range[range_idx] ( submodule_input_nodes, @@ -205,6 +206,7 @@ def sort_key(node): num_subgraphs = len(range_idx2submodule_body_nodes) for range_idx in range(num_subgraphs): + use_all_inputs = use_all_inputs and range_idx == 0 ( submodule_input_nodes, submodule_output_nodes,