Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions graph_net/sample_pass/fusible_subgraph_ranges_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,12 @@ def valid_fused_ops(num_ops_list: list[int]):
if is_a_range(num_ops_list)
if valid_fused_ops(num_ops_list)
]
offset = self.start_offset_in_original_graph
fusible_subgraph_ranges = [
(start, end)
for num_ops_list in num_ops_lists
for start in [num_ops_list[0] - 1]
for end in [num_ops_list[-1]]
for start in [num_ops_list[0] - 1 + offset]
for end in [num_ops_list[-1] + offset]
]

# sorted by `start`
Expand Down
119 changes: 119 additions & 0 deletions graph_net/sample_pass/group_fusible_subgraph_ranges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from graph_net.sample_pass.sample_pass import SamplePass
from pathlib import Path
import json


class GroupFusibleSubgraphRanges(SamplePass):
def __init__(self, config=None):
super().__init__(config)
self.original_graph_rel_model_path2ranges: dict[str, list[(int, int)]] = {}
self.original_graph_rel_model_path2subgraph_rel_model_paths: dict[
str, list[str]
] = {}

def declare_config(
self,
subgraph_model_path_prefix: str,
output_dir: str,
input_json_file_name: str = "fusible_subgraph_ranges.json",
output_json_file_name: str = "grouped_fusible_subgraph_ranges.json",
output_json_key: str = "subgraph_ranges",
output_json_subgraph_rel_model_path_key: str = "fusible_subgraph_relative_model_paths",
):
pass

def __call__(self, subgraph_rel_model_path: str):
model_path = (
Path(self.config["subgraph_model_path_prefix"])
/ subgraph_rel_model_path
/ self.config["input_json_file_name"]
)
subgraph_sources = json.load(open(model_path))
subgraph_ranges = subgraph_sources.get(self.config["output_json_key"], [])
original_graph_rel_model_path = self._extract_original_model_path(
subgraph_rel_model_path
)
self._collect_original_graph_rel_model_path2ranges(
original_graph_rel_model_path, subgraph_ranges
)
self._collect_original_graph_rel_model_path2subgraph_rel_model_path(
original_graph_rel_model_path,
[subgraph_rel_model_path] * len(subgraph_ranges),
)

def _extract_original_model_path(self, rel_model_path: str) -> str:
path_parts = Path(rel_model_path).parts
if "_decomposed" in path_parts:
decomposed_idx = path_parts.index("_decomposed")
return str(Path(*path_parts[:decomposed_idx]))
return rel_model_path

def _collect_original_graph_rel_model_path2subgraph_rel_model_path(
self,
original_graph_rel_model_path: str,
subgraph_rel_model_paths: list[str],
):
old = self.original_graph_rel_model_path2subgraph_rel_model_paths.get(
original_graph_rel_model_path, []
)
self.original_graph_rel_model_path2subgraph_rel_model_paths[
original_graph_rel_model_path
] = [
*old,
*subgraph_rel_model_paths,
]

def _collect_original_graph_rel_model_path2ranges(
self, original_graph_rel_model_path, subgraph_ranges
):
old_ranges = self.original_graph_rel_model_path2ranges.get(
original_graph_rel_model_path, []
)
self.original_graph_rel_model_path2ranges[original_graph_rel_model_path] = [
*old_ranges,
*subgraph_ranges,
]

def END(self, rel_model_paths: list[str]):
for (
original_graph_rel_model_path,
subgraph_ranges,
) in self.original_graph_rel_model_path2ranges.items():
subgraph_rel_model_paths = (
self.original_graph_rel_model_path2subgraph_rel_model_paths[
original_graph_rel_model_path
]
)
self._save_json(
original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
)

def _save_json(
self, original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
):
model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path
model_dir.mkdir(parents=True, exist_ok=True)

# Sort ranges by start index, and sort paths accordingly
sorted_data = sorted(
zip(subgraph_ranges, subgraph_rel_model_paths), key=lambda x: x[0][0]
)
sorted_ranges, sorted_paths = zip(*sorted_data) if sorted_data else ([], [])

ranges_json = self._get_ranges_json(list(sorted_ranges))
paths_json = self._get_paths_json(list(sorted_paths))
json_obj = {**ranges_json, **paths_json}
json_str = json.dumps(json_obj, indent=4)
(model_dir / self.config["output_json_file_name"]).write_text(json_str)

def _get_paths_json(self, subgraph_rel_model_paths: list[str]):
json_obj = {
self.config[
"output_json_subgraph_rel_model_path_key"
]: subgraph_rel_model_paths
}
return json_obj

def _get_ranges_json(self, subgraph_ranges: list[(int, int)]):
json_obj = {self.config["output_json_key"]: subgraph_ranges}
return json_obj
18 changes: 18 additions & 0 deletions graph_net/test/group_fusible_subgraph_ranges_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")

python3 -m graph_net.apply_sample_pass \
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_group_fusible_subgraph_ranges/sample_list.txt" \
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/group_fusible_subgraph_ranges.py" \
--sample-pass-class-name "GroupFusibleSubgraphRanges" \
--sample-pass-config $(base64 -w 0 <<EOF
{
"subgraph_model_path_prefix": "$GRAPH_NET_ROOT/graph_net/test/workspace_group_fusible_subgraph_ranges",
"output_dir": "/tmp/workspace_group_fusible_subgraph_ranges",
"input_json_file_name": "fusible_subgraph_ranges.json",
"output_json_file_name": "grouped_fusible_subgraph_ranges.json",
"output_json_key": "subgraph_ranges"
}
EOF
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/_decomposed/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification_start167_end183_2
samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/_decomposed/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification_start230_end288_4
samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/_decomposed/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification_start186_end207_3
samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/_decomposed/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification_start72_end130_1
samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification/_decomposed/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification_start5_end63_0
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start643_end700_11
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start126_end183_2
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start297_end354_5
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start65_end122_1
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start4_end61_0
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"subgraph_ranges": [
[
126,
129
],
[
130,
132
],
[
133,
138
],
[
145,
148
],
[
149,
151
],
[
152,
157
],
[
164,
167
],
[
168,
170
],
[
171,
176
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"subgraph_ranges": [
[
297,
300
],
[
301,
303
],
[
304,
309
],
[
316,
319
],
[
320,
322
],
[
323,
328
],
[
335,
338
],
[
339,
341
],
[
342,
347
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"subgraph_ranges": [
[
4,
7
],
[
8,
10
],
[
11,
16
],
[
23,
26
],
[
27,
29
],
[
30,
35
],
[
42,
45
],
[
46,
48
],
[
49,
54
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"subgraph_ranges": [
[
643,
646
],
[
647,
649
],
[
650,
655
],
[
655,
657
],
[
658,
661
],
[
662,
665
],
[
666,
668
],
[
669,
674
],
[
674,
676
],
[
677,
680
],
[
681,
684
],
[
685,
687
],
[
688,
693
],
[
693,
695
],
[
696,
699
]
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"subgraph_ranges": [
[
65,
68
],
[
69,
71
],
[
72,
77
],
[
84,
87
],
[
88,
90
],
[
91,
96
],
[
103,
106
],
[
107,
109
],
[
110,
115
]
]
}
Loading