diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py
index cd1e03a0f..4f4c5b04c 100755
--- a/graph_net/tensor_meta.py
+++ b/graph_net/tensor_meta.py
@@ -21,6 +21,22 @@ class TensorMeta:
     max_val: int | None
     min_val: int | None
 
+    @classmethod
+    def reset_tensor_metas_by_original_name(cls, mut_file_path, const_file_path):
+        mut_tensor_metas = cls.unserialize_from_py_file(mut_file_path)
+        const_tensor_metas = cls.unserialize_from_py_file(const_file_path)
+        name2const_tensor_meta = {tensor.name: tensor for tensor in const_tensor_metas}
+
+        def get_name(tensor_meta):
+            old_name = getattr(tensor_meta, "original_name", None)
+            return old_name if old_name is not None else tensor_meta.name
+
+        new_tensor_metas = [
+            name2const_tensor_meta.get(get_name(mut_tensor_meta), mut_tensor_meta)
+            for mut_tensor_meta in mut_tensor_metas
+        ]
+        cls.save_tensor_metas(mut_file_path, new_tensor_metas)
+
     @classmethod
     def unserialize_from_py_file(cls, file_path: str) -> list["TensorMeta"]:
         return [
@@ -104,7 +120,7 @@ def update_shape_safely(self, shape):
         self.data = extended_tensor_data[:size]
 
     @classmethod
-    def save_tensor_metas(cls, file_path: str, tensor_metas: list):
+    def save_tensor_metas(cls, file_path: str | Path, tensor_metas: list):
         py_code = "\n\n".join(
             tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
         )
diff --git a/graph_net/torch/sample_pass/subgraph_generator.py b/graph_net/torch/sample_pass/subgraph_generator.py
index 82f1d3b66..4e5d773f5 100644
--- a/graph_net/torch/sample_pass/subgraph_generator.py
+++ b/graph_net/torch/sample_pass/subgraph_generator.py
@@ -1,5 +1,6 @@
 from graph_net.sample_pass.sample_pass import SamplePass
 from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
+from graph_net.tensor_meta import TensorMeta
 import graph_net.subgraph_range_util as range_util
 import os
 import shutil
@@ -145,31 +146,41 @@ def __init__(
         else:
             submodule_name = f"{parent_graph_model_name}_start{subgraph_start}_end{subgraph_end}_{self.seq_no}"
         self.model_name = submodule_name
+        self.workspace_path = os.path.join(
+            self.config["output_dir"], parent_graph_rel_model_path, "_decomposed"
+        )
         self.builtin_extractor = BuiltinGraphExtractor(
             name=submodule_name,
             dynamic=False,
             mut_graph_codes=[],
             placeholder_auto_rename=False,
-            workspace_path=os.path.join(
-                self.config["output_dir"], parent_graph_rel_model_path, "_decomposed"
-            ),
+            workspace_path=self.workspace_path,
         )
         self._save_subgraph_sources()
 
     def _get_model_path(self) -> Path:
-        return (
-            Path(self.config["output_dir"])
-            / self.parent_graph_rel_model_path
-            / "_decomposed"
-            / self.model_name
-        )
+        return Path(self.workspace_path) / self.model_name
 
     def forward(self, *args):
         if not self.extracted:
             self.builtin_extractor(self.submodule, args)
+            self._reset_tensor_metas_by_parent()
             self.extracted = True
         return self.submodule(*args)
 
+    def _reset_tensor_metas_by_parent(self):
+        parent_model_path = (
+            Path(self.parent_graph_model_path_root) / self.parent_graph_rel_model_path
+        )
+        TensorMeta.reset_tensor_metas_by_original_name(
+            mut_file_path=self._get_model_path() / "input_meta.py",
+            const_file_path=parent_model_path / "input_meta.py",
+        )
+        TensorMeta.reset_tensor_metas_by_original_name(
+            mut_file_path=self._get_model_path() / "weight_meta.py",
+            const_file_path=parent_model_path / "weight_meta.py",
+        )
+
     def _save_subgraph_sources(self):
         sources_json_obj = self._get_sources_json_obj()
         model_path = self._get_model_path()
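
For context, the new TensorMeta.reset_tensor_metas_by_original_name classmethod rewrites a subgraph's input_meta.py / weight_meta.py so that each tensor meta is replaced by the parent graph's meta with the matching name, falling back to the subgraph's own meta when no match exists. The standalone sketch below re-implements that remapping on a simplified stand-in class to illustrate the lookup order (original_name first, then name); FakeTensorMeta and reset_by_original_name are illustrative names only and are not part of graph_net.

# Illustrative sketch only (not part of graph_net): a standalone re-implementation
# of the remapping done by TensorMeta.reset_tensor_metas_by_original_name, using a
# simplified stand-in dataclass instead of the real TensorMeta.
from dataclasses import dataclass

@dataclass
class FakeTensorMeta:  # hypothetical stand-in for graph_net's TensorMeta
    name: str
    original_name: str | None = None
    min_val: int | None = None
    max_val: int | None = None

def reset_by_original_name(mut_metas, const_metas):
    # Index the parent graph's metas by tensor name.
    name2const = {m.name: m for m in const_metas}

    def lookup_name(m):
        # Prefer the name the tensor had in the parent graph, if recorded.
        return m.original_name if m.original_name is not None else m.name

    # Swap in the parent's meta when a name match exists; otherwise keep the
    # subgraph's own meta unchanged.
    return [name2const.get(lookup_name(m), m) for m in mut_metas]

# The subgraph renamed "conv1_weight" to "arg0" but recorded the original name,
# so it picks up the parent's value range instead of its own empty one.
parent = [FakeTensorMeta(name="conv1_weight", min_val=-1, max_val=1)]
child = [FakeTensorMeta(name="arg0", original_name="conv1_weight")]
print(reset_by_original_name(child, parent))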