18 changes: 17 additions & 1 deletion graph_net/tensor_meta.py
@@ -21,6 +21,22 @@ class TensorMeta
     max_val: int | None
     min_val: int | None
 
+    @classmethod
+    def reset_tensor_metas_by_original_name(cls, mut_file_path, const_file_path):
+        mut_tensor_metas = cls.unserialize_from_py_file(mut_file_path)
+        const_tensor_metas = cls.unserialize_from_py_file(const_file_path)
+        name2const_tensor_meta = {tensor.name: tensor for tensor in const_tensor_metas}
+
+        def get_name(tensor_meta):
+            old_name = getattr(tensor_meta, "original_name", None)
+            return old_name if old_name is not None else tensor_meta.name
+
+        new_tensor_metas = [
+            name2const_tensor_meta.get(get_name(mut_tensor_meta), mut_tensor_meta)
+            for mut_tensor_meta in mut_tensor_metas
+        ]
+        cls.save_tensor_metas(mut_file_path, new_tensor_metas)
+
     @classmethod
     def unserialize_from_py_file(cls, file_path: str) -> list["TensorMeta"]:
         return [
@@ -104,7 +120,7 @@ def update_shape_safely(self, shape):
         self.data = extended_tensor_data[:size]
 
     @classmethod
-    def save_tensor_metas(cls, file_path: str, tensor_metas: list):
+    def save_tensor_metas(cls, file_path: str | Path, tensor_metas: list):
         py_code = "\n\n".join(
             tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
         )
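
For context, a minimal usage sketch of the new classmethod, assuming meta files that unserialize_from_py_file can parse; the paths below are hypothetical placeholders, not taken from this PR. The call rewrites the mutable file in place, replacing each tensor meta with the parent entry whose name matches its original_name (falling back to name), and keeping unmatched entries as they are:

# Hedged sketch: paths are illustrative placeholders.
from graph_net.tensor_meta import TensorMeta

# Rewrites workspace/subgraph_0/input_meta.py in place: every tensor meta whose
# original_name (or name, if original_name is absent) matches an entry in the
# parent file is replaced by that parent entry; others are kept unchanged.
TensorMeta.reset_tensor_metas_by_original_name(
    mut_file_path="workspace/subgraph_0/input_meta.py",
    const_file_path="workspace/parent_model/input_meta.py",
)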
29 changes: 20 additions & 9 deletions graph_net/torch/sample_pass/subgraph_generator.py
@@ -1,5 +1,6 @@
 from graph_net.sample_pass.sample_pass import SamplePass
 from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
+from graph_net.tensor_meta import TensorMeta
 import graph_net.subgraph_range_util as range_util
 import os
 import shutil
@@ -145,31 +146,41 @@ def __init__(
         else:
             submodule_name = f"{parent_graph_model_name}_start{subgraph_start}_end{subgraph_end}_{self.seq_no}"
         self.model_name = submodule_name
+        self.workspace_path = os.path.join(
+            self.config["output_dir"], parent_graph_rel_model_path, "_decomposed"
+        )
         self.builtin_extractor = BuiltinGraphExtractor(
             name=submodule_name,
             dynamic=False,
             mut_graph_codes=[],
             placeholder_auto_rename=False,
-            workspace_path=os.path.join(
-                self.config["output_dir"], parent_graph_rel_model_path, "_decomposed"
-            ),
+            workspace_path=self.workspace_path,
         )
         self._save_subgraph_sources()
 
     def _get_model_path(self) -> Path:
-        return (
-            Path(self.config["output_dir"])
-            / self.parent_graph_rel_model_path
-            / "_decomposed"
-            / self.model_name
-        )
+        return Path(self.workspace_path) / self.model_name
 
     def forward(self, *args):
         if not self.extracted:
             self.builtin_extractor(self.submodule, args)
+            self._reset_tensor_metas_by_parent()
             self.extracted = True
         return self.submodule(*args)
 
+    def _reset_tensor_metas_by_parent(self):
+        parent_model_path = (
+            Path(self.parent_graph_model_path_root) / self.parent_graph_rel_model_path
+        )
+        TensorMeta.reset_tensor_metas_by_original_name(
+            mut_file_path=self._get_model_path() / "input_meta.py",
+            const_file_path=parent_model_path / "input_meta.py",
+        )
+        TensorMeta.reset_tensor_metas_by_original_name(
+            mut_file_path=self._get_model_path() / "weight_meta.py",
+            const_file_path=parent_model_path / "weight_meta.py",
+        )
+
     def _save_subgraph_sources(self):
         sources_json_obj = self._get_sources_json_obj()
         model_path = self._get_model_path()
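
Taken together, the first call to forward() now extracts the submodule and then re-syncs its serialized input and weight metas with the parent graph. A standalone sketch of that reset step, assuming the _decomposed directory layout used above; the helper name reset_subgraph_metas and the concrete paths are illustrative, not part of this PR:

# Hedged sketch mirroring SubgraphGenerator._reset_tensor_metas_by_parent;
# helper name and example paths are assumptions for illustration only.
from pathlib import Path
from graph_net.tensor_meta import TensorMeta

def reset_subgraph_metas(subgraph_dir: Path, parent_dir: Path) -> None:
    # After a subgraph is extracted, re-sync its input/weight metas from the
    # parent graph's metas, matched by original_name (falling back to name).
    for meta_file in ("input_meta.py", "weight_meta.py"):
        TensorMeta.reset_tensor_metas_by_original_name(
            mut_file_path=subgraph_dir / meta_file,
            const_file_path=parent_dir / meta_file,
        )

# e.g. reset_subgraph_metas(
#     Path("out/parent_model/_decomposed/parent_model_start0_end5_0"),
#     Path("models/parent_model"),
# )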