Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions graph_net/tensor_meta.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def serialize_to_py_str(self) -> str:
lines = [
(f"class {self.record_class_name}:"),
(f'\tname = "{self.name}"'),
*(
[f'\toriginal_name = "{self.original_name}"']
if self.original_name is not None
else []
),
(f"\tshape = {self.shape}"),
(f'\tdtype = "{self.dtype}"'),
(f'\tdevice = "{self.device}"'),
Expand Down
27 changes: 27 additions & 0 deletions graph_net/test/graph_variable_rename_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
# Smoke test: run GraphVariableRenamer over one sample model through
# graph_net.model_path_handler.  The handler config is passed as a
# base64-encoded JSON blob on the command line.
set -e

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")

# input model path
MODEL_NAME=resnet18
MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
config_json_str=$(cat <<EOF
{
  "handler_path": "$GRAPH_NET_ROOT/torch/graph_variable_renamer.py",
  "handler_class_name": "GraphVariableRenamer",
  "handler_config": {
    "model_path_prefix": "$GRAPH_NET_ROOT/../",
    "data_input_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
    "data_input_predicator_class_name": "NaiveDataInputPredicator",
    "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
    "model_runnable_predicator_class_name": "ModelRunnablePredicator",
    "output_dir": "/tmp/graph_variable_rename_workspace"
  }
}
EOF
)
# printf with a quoted expansion keeps the JSON byte-exact before encoding;
# an unquoted `echo $config_json_str` would word-split and glob-expand it.
# NOTE(review): `base64 -w 0` is GNU coreutils; BSD/macOS base64 lacks -w.
CONFIG=$(printf '%s' "$config_json_str" | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path "samples/$MODEL_PATH_IN_SAMPLES" --handler-config="$CONFIG"
# python3 -m graph_net.model_path_handler --model-path-list "$GRAPH_NET_ROOT/config/decomposition_error_tmp_torch_samples_list.txt" --handler-config="$CONFIG"
199 changes: 199 additions & 0 deletions graph_net/torch/graph_variable_renamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import os
import torch
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
from graph_net.tensor_meta import TensorMeta
from pathlib import Path
import shutil
from graph_net.torch.utils import apply_templates
from graph_net.imp_util import load_module
import inspect


class GraphVariableRenamer:
"""
Used by graph_net.model_path_handler
"""

def __init__(self, config: dict = None):
if config is None:
config = {}
self.config = self._make_config(**config)
self.data_input_predicator = self._make_data_input_predicator(self.config)
self.model_runnable_predicator = self._make_model_runnable_predicator(
self.config
)

def _make_data_input_predicator(self, config):
module = load_module(config["data_input_predicator_filepath"])
cls = getattr(module, config["data_input_predicator_class_name"])
return cls(config["data_input_predicator_config"])

def _make_model_runnable_predicator(self, config):
module = load_module(config["model_runnable_predicator_filepath"])
cls = getattr(module, config["model_runnable_predicator_class_name"])
return cls(config["model_runnable_predicator_config"])

def _make_config(
self,
data_input_predicator_filepath,
model_runnable_predicator_filepath,
output_dir="./tmp/graph_variable_renamer_dir",
filter_path=None,
filter_config=None,
post_extract_process_path=None,
post_extract_process_class_name=None,
post_extract_process_config=None,
data_input_predicator_class_name="DataInputPredicator",
model_runnable_predicator_class_name="ModelRunner",
data_input_predicator_config=None,
model_runnable_predicator_config=None,
model_path_prefix="",
**kwargs,
):
if post_extract_process_config is None:
post_extract_process_config = {}
if data_input_predicator_config is None:
data_input_predicator_config = {}
if model_runnable_predicator_config is None:
model_runnable_predicator_config = {}
return {
"output_dir": output_dir,
"filter_path": filter_path,
"filter_config": filter_config if filter_config is not None else {},
"post_extract_process_path": post_extract_process_path,
"post_extract_process_class_name": post_extract_process_class_name,
"post_extract_process_config": post_extract_process_config,
"data_input_predicator_filepath": data_input_predicator_filepath,
"data_input_predicator_class_name": data_input_predicator_class_name,
"data_input_predicator_config": data_input_predicator_config,
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
"model_runnable_predicator_config": model_runnable_predicator_config,
"model_path_prefix": model_path_prefix,
}

def __call__(self, rel_model_path):
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
module, inputs = get_torch_module_and_inputs(src_model_path)
gm = parse_sole_graph_module(module, inputs)
gm = self.rename_graph_variables(gm, inputs, src_model_path)
dst_model_path = os.path.realpath(
os.path.join(self.config["output_dir"], rel_model_path)
)
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
self._update_model_py_file(gm, dst_model_path)
self._update_weight_meta_py_file(src_model_path, dst_model_path)
self._update_input_meta_py_file(src_model_path, dst_model_path)
self._try_run(dst_model_path)

def _try_run(self, model_path):
assert self.model_runnable_predicator(
model_path
), f"{model_path} is not a runnable model"

def _update_model_py_file(self, graph_module, model_path):
py_code = apply_templates(graph_module.code)
(Path(model_path) / "model.py").write_text(py_code)

def _update_weight_meta_py_file(self, src_model_path, dst_model_path):
old_name_to_new_name = self._get_original_name_to_new_name(
src_model_path, dst_model_path
)
tensor_metas = TensorMeta.unserialize_from_py_file(
os.path.join(src_model_path, "weight_meta.py"),
)
for weight_meta in tensor_metas:
assert weight_meta.name in old_name_to_new_name
if weight_meta.original_name is None:
weight_meta.original_name = weight_meta.name
weight_meta.name = old_name_to_new_name[weight_meta.name]
py_code = "\n\n".join(
[weight_meta.serialize_to_py_str() for weight_meta in tensor_metas]
)
(Path(dst_model_path) / "weight_meta.py").write_text(py_code)

def _update_input_meta_py_file(self, src_model_path, dst_model_path):
old_name_to_new_name = self._get_original_name_to_new_name(
src_model_path, dst_model_path
)
tensor_metas = TensorMeta.unserialize_from_py_file(
os.path.join(src_model_path, "input_meta.py"),
)
for input_meta in tensor_metas:
assert input_meta.name in old_name_to_new_name
if input_meta.original_name is None:
input_meta.original_name = input_meta.name
input_meta.name = old_name_to_new_name[input_meta.name]
py_code = "\n\n".join(
[input_meta.serialize_to_py_str() for input_meta in tensor_metas]
)
(Path(dst_model_path) / "input_meta.py").write_text(py_code)

def _get_original_name_to_new_name(self, src_model_path, dst_model_path):
src_model = self._get_model(src_model_path)
dst_model = self._get_model(dst_model_path)
old_name_and_new_name_pairs = zip(
self._get_input_names_from_signature(src_model),
self._get_input_names_from_signature(dst_model),
strict=True,
)
return {
old_name: new_name for old_name, new_name in old_name_and_new_name_pairs
}

def _get_model(self, model_path):
py_module = load_module(os.path.join(model_path, "model.py"))
GraphModule = getattr(py_module, "GraphModule")
GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
return GraphModule()

def _get_input_names_from_signature(self, module):
return inspect.signature(module.forward).parameters

def rename_graph_variables(
self, gm: torch.fx.GraphModule, sample_inputs, model_path
):
in_cnt = 0
w_cnt = 0
tmp_cnt = 0

arg_iter = iter(sample_inputs)
for node in gm.graph.nodes:
if "original_name" not in node.meta:
node.meta["original_name"] = node.name

if node.op == "placeholder":
real_arg = next(arg_iter)
is_weight = not self.data_input_predicator(model_path, node.name)
if node.type is not None:
if isinstance(node.type, type) and issubclass(
node.type, torch.nn.parameter.Parameter
):
is_weight = True
elif real_arg is not None:
if isinstance(real_arg, torch.nn.Parameter):
is_weight = True

if is_weight:
new_name = f"w_{w_cnt}"
w_cnt += 1
else:
new_name = f"in_{in_cnt}"
in_cnt += 1

node.name = new_name
node.target = new_name

elif node.op == "get_attr":
node.name = f"w_{w_cnt}"
w_cnt += 1

elif node.op != "output":
node.name = f"tmp_{tmp_cnt}"
tmp_cnt += 1

gm.graph.lint()
gm.recompile()
return gm