diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py index 4f4c5b04c..9744616d8 100755 --- a/graph_net/tensor_meta.py +++ b/graph_net/tensor_meta.py @@ -92,7 +92,13 @@ def _get_classes(cls, file_path, name="unnamed"): spec = imp.spec_from_file_location(name, file_path) unnamed = imp.module_from_spec(spec) spec.loader.exec_module(unnamed) - yield from inspect.getmembers(unnamed, inspect.isclass) + classes = inspect.getmembers(unnamed, inspect.isclass) + file_content = Path(file_path).read_text() + class2pos = { + cls_pair: file_content.find(cls_pair[1].__name__) for cls_pair in classes + } + classes.sort(key=lambda x: class2pos[x]) + return classes @classmethod def _get_classes_order_preserved(cls, file_path, name="unnamed"): diff --git a/graph_net/torch/decompose_util.py b/graph_net/torch/decompose_util.py index c986b629d..85b5d4ac7 100755 --- a/graph_net/torch/decompose_util.py +++ b/graph_net/torch/decompose_util.py @@ -83,6 +83,13 @@ def get_end_node_idx(range_idx): return i + 1 raise NotImplementedError("Dead code.") + new_node2original_node = {} + for node in gm.graph.nodes: + new_node2original_node[node] = node + + def sort_key(node): + return new_node2original_node[node].name + num_subgraphs = len(range_idx2submodule_body_nodes) for range_idx in range(num_subgraphs): use_all_inputs = use_all_inputs and range_idx == 0 @@ -98,7 +105,13 @@ def get_end_node_idx(range_idx): chain_style=chain_style, use_all_inputs=use_all_inputs, ) - yield start, end, submodule_input_nodes + + def get_input_nodes(range_idx): + if use_all_inputs: + return submodule_input_nodes + return sorted(submodule_input_nodes, key=sort_key) + + yield start, end, get_input_nodes(range_idx) def convert_to_submodules_graph( @@ -221,6 +234,8 @@ def sort_key(node): identity_node_set = set(identity_nodes) def get_input_nodes(range_idx): + if use_all_inputs: + return submodule_input_nodes return sorted(submodule_input_nodes, key=sort_key) def get_output_nodes(range_idx): @@ -338,9 +353,7 @@ def _get_submodule_inputs_and_outputs( ) if use_all_inputs: node_list = list(gm.graph.nodes) - input_nodes, _ = _get_minimal_submodule_inputs_and_outputs( - gm=gm, start_node_idx=start_node_idx, end_node_idx=len(node_list) - ) + input_nodes = [node for node in node_list if node.op == "placeholder"] else: input_nodes = minimal_input_nodes return input_nodes, minimal_output_nodes, []