Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__/
26 changes: 25 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,25 @@
# GraphNet
# GraphNet

## 📌 项目简介
GraphNet —— 一个面向编译器开发的大规模数据集,旨在为研究者提供一个统一、开放的实验平台。其中包含大量来自真实模型的计算图,方便评估不同编译器Pass的优化效果。

通过 GraphNet,用户可以:

1. 快速测试不同编译器策略的通用优化效果
2. 训练AI-for-system模型以自动生成编译器优化Pass


## 计算图抽取Demo
### torch
```
export PYTHONPATH=$PYTHONPATH:/path/to/your/GraphNet/repo
python3 -m graph_net.torch.extractor.vision_model_extractor --key resnet18 --model-path /path/to/your/extracted/graph_net/sample
```

## 计算图运行Demo
### torch
```
export PYTHONPATH=$PYTHONPATH:/path/to/your/GraphNet/repo
python3 -m graph_net.torch.runner.single_device_runner --model-path /path/to/your/extracted/graph_net/sample
```

Binary file not shown.
Binary file not shown.
180 changes: 180 additions & 0 deletions graph_net/torch/extractor/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import re
import torch
import torch.nn as nn
from collections import OrderedDict
import uuid
import json
import os
import argparse

dyn_template = """
%MODULE
"""

def convert_param_name(original_name):
if original_name.endswith(('.weight', '.bias')):
prefix = 'p_'
base_name = original_name

elif any(x in original_name for x in ['running_mean', 'running_var', 'num_batches_tracked']):
prefix = 'b_'
base_name = original_name
else:
raise ValueError(f"Unrecognized parameter type: {original_name}")

if '.' in base_name:
parts = base_name.split('.')
if len(parts) == 2 and not parts[0].startswith('layer'):
return prefix + parts[0] + '_' + parts[1]
else:
# layer1.0 -> layer1___0___
pattern = r'(layer\d+)\.(\d+)\.'
replacement = r'\1___\2___'
converted = re.sub(pattern, replacement, base_name)
converted = converted.replace('.', '_')
return f"{prefix}getattr_l__self___{converted}"
else:
return prefix + base_name

def indent_with_tab(code: str) -> str:
lines = code.splitlines()
indented_lines = [f" {line}" for line in lines]
return "\n".join(indented_lines)

def apply_templates(code: str) -> str:
code = indent_with_tab(code)
code = code.replace(" GraphModule()", "class GraphModule(torch.nn.Module):")
code = code.replace(" \n" * 3, "\n")
py_code = dyn_template.replace('%MODULE', code)
return py_code


def convert_state_and_inputs(state_dict, example_inputs):
def tensor_info(tensor):
is_float = tensor.dtype.is_floating_point
return {
"shape": list(tensor.shape),
"dtype": str(tensor.dtype),
"device": str(tensor.device),
"mean": float(tensor.mean().item()) if is_float else None,
"std": float(tensor.std().item()) if is_float else None,
}

def process_tensor(tensor):
if not isinstance(tensor, torch.Tensor):
return {"type": "unknown", "value": tensor}

info = tensor_info(tensor)
if tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
if tensor.numel() < 1024:
return {"type": "small_int_tensor", "data": tensor.clone(), "info": info}
else:
return {"type": "big_int_tensor", "data": tensor.clone(), "info": info}
elif tensor.numel() < 1024:
return {"type": "small_tensor", "data": tensor.clone(), "info": info}
else:
return {"type": "random_tensor", "info": info}

if isinstance(example_inputs, torch.Tensor):
processed_inputs = process_tensor(example_inputs)
elif isinstance(example_inputs, (list, tuple)):
processed_inputs = [process_tensor(t) for t in example_inputs]
else:
processed_inputs = {"type": "unknown", "value": example_inputs}

processed_weights = {}
for key, tensor in state_dict.items():
data_value = None
data_type = "random_tensor"
if tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
if tensor.numel() < 1024:
data_type = "small_int_tensor"
data_value = tensor.clone()
else:
data_type = "big_int_tensor"

info = tensor_info(tensor)
processed_weights[key] = {"info": info}
processed_weights[key]["data"] = data_value
processed_weights[key]["type"] = data_type

# dynamic_shapes = extract_dynamic_shapes(example_inputs)
return {
"input_info": processed_inputs,
"weight_info": processed_weights,
"dynamic_shapes": None
}

def save_constraints_text(converted, file_path):
lines = []
if converted["dynamic_shapes"] is not None:
raise NotImplementedError("Handling constraints is not implemented yet.")
with open(file_path, 'w') as f:
f.write("\n".join(lines))

def save_converted_to_text(converted, file_path):
def generate_uid():
return str(uuid.uuid4()).replace('-', '')

def format_data(data):
if data is None:
return "None"
elif isinstance(data, torch.Tensor):
if data.dtype.is_floating_point:
return "[{}]".format(", ".join(f'{x:.6f}' for x in data.tolist()))
else:
return "[{}]".format(", ".join(f'{x}' for x in data.tolist()))
else:
return repr(data)

lines = [[],[]]

def process_tensor_info(tensor_info, name_prefix="example_input"):
data_list = None
file_index = 1
if "input_" in tensor_info["name"]:
file_index = 0
if tensor_info["type"] in ["small_tensor", "small_int_tensor"]:
data_list = tensor_info["data"].flatten()
elif tensor_info["type"] == "big_int_tensor":
data_list = f'pt-filename:xxx-key'
else:
pass
else:
if tensor_info["type"] == "small_int_tensor":
data_list = tensor_info["data"].flatten()
if tensor_info["type"] == "big_int_tensor":
raise ValueError("Unexpected cases: there are weights in big tensor of int type ")
info = tensor_info.get("info", {})
dtype = info.get("dtype", "torch.float")
shape = info.get("shape", [])
device = info.get("device", "cpu")
mean = info.get("mean", 0.0)
std = info.get("std", 1.0)
uid = f"{name_prefix}_tensor_meta_{generate_uid()}"
lines[file_index].append(f"class {uid}:")
lines[file_index].append(f"\tname = \"{tensor_info.get('name', '')}\"")
lines[file_index].append(f"\tshape = {shape}")
lines[file_index].append(f"\tdtype = \"{dtype}\"")
lines[file_index].append(f"\tdevice = \"{device}\"")
lines[file_index].append(f"\tmean = {mean}")
lines[file_index].append(f"\tstd = {std}")
lines[file_index].append(f"\tdata = {format_data(data_list)}")
lines[file_index].append("")

input_infos = converted["input_info"]
if isinstance(input_infos, dict):
input_infos = [input_infos]

for idx, input_info in enumerate(input_infos):
input_info["name"] = f"input_{idx}"
process_tensor_info(input_info, name_prefix="Program_input")

for name, weight_info in converted["weight_info"].items():
weight_info["name"] = name
process_tensor_info(weight_info, name_prefix="Program_weight")

with open(f"{file_path}/input_meta.py", 'w') as f:
f.write("\n".join(lines[0]))
with open(f"{file_path}/weight_meta.py", 'w') as f:
f.write("\n".join(lines[1]))
101 changes: 101 additions & 0 deletions graph_net/torch/extractor/vision_model_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import argparse
import os
import json
import torch
import torchvision
from torchvision import transforms
from torch.export import export
from torch import nn
import graph_net.torch.extractor.utils as utils
from graph_net.torch.extractor.utils import convert_param_name, indent_with_tab, apply_templates


def main(key, model_path):
# Normalization parameters for ImageNet
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)

# Create dummy input
batch_size = 1
height, width = 224, 224 # Standard ImageNet size
num_channels = 3
random_input = torch.rand(batch_size, num_channels, height, width)
normalized_input = normalize(random_input)

# Get and initialize model
try:
model = torchvision.models.get_model(key, weights="DEFAULT")
except ValueError as e:
print(f"Error loading model {key}: {e}")
return

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
normalized_input = normalized_input.to(device)

# Export model
try:
exported = export(model, args=(normalized_input,))
except Exception as e:
print(f"Error exporting model {key}: {e}")
return

# Process parameters
params = exported.state_dict
new_params = {
convert_param_name(k): v
for k, v in params.items()
}

# Generate and save model code
base_code = exported.graph_module.__str__()
write_code = apply_templates(base_code)

os.makedirs(model_path, exist_ok=True)

with open(f'{model_path}/model.py', 'w') as fp:
fp.write(write_code)

# Save metadata
metadata = {
"framework": "torch",
"num_devices_required": 1,
"num_nodes_required": 1
}

with open(f'{model_path}/attribute.json', 'w') as f:
json.dump(metadata, f, indent=4)

# Save tensor metadata and constraints
converted = utils.convert_state_and_inputs(params, exported.example_inputs[0])
utils.save_converted_to_text(
converted,
file_path=f'{model_path}'
)
utils.save_constraints_text(
converted,
file_path=f'{model_path}/input_tensor_constraints.py'
)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Export torchvision models to txt"
)
parser.add_argument(
"--key",
type=str,
required=True,
help="Model name from torchvision.models"
)
parser.add_argument(
"--model-path",
type=str,
required=True,
help="Directory to save the exported model"
)
args = parser.parse_args()
main(key=args.key, model_path=args.model_path)
Binary file not shown.
Binary file not shown.
50 changes: 50 additions & 0 deletions graph_net/torch/runner/single_device_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import graph_net.torch.runner.utils as utils
import argparse
import importlib.util
import torch
from pathlib import Path
from typing import Type, Any
import sys

def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
file = Path(file_path).resolve()
module_name = file.stem

with open(file_path, 'r', encoding='utf-8') as f:
original_code = f.read()
import_stmt= "import torch"
modified_code = f"{import_stmt}\n{original_code}"
spec = importlib.util.spec_from_loader(module_name, loader=None)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
compiled_code = compile(modified_code, filename=file, mode='exec')
exec(compiled_code, module.__dict__)

model_class = getattr(module, class_name, None)
return model_class

def main(model_path: str):
model_class = load_class_from_file(f"{model_path}/model.py", class_name="GraphModule")
model = model_class()

inputs_params = utils.load_converted_from_text(f'{model_path}')
inputs = inputs_params["input_info"]
inputs = [utils.replay_tensor(i) for i in inputs]
params = inputs_params["weight_info"]

state_dict = {}
for k, v in params.items():
k = utils.convert_param_name(k)
v = utils.replay_tensor(v)
state_dict[k] = v

y = model(x=inputs[0], **state_dict)[0]
print(torch.argmin(y), torch.argmax(y))
print(y.shape)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description="load and run model")
parser.add_argument("--model-path", type=str, required=True,
help="模型文件夹的路径,如'../../samples/torch/resnet18'")
args = parser.parse_args()
main(model_path=args.model_path)
Loading