18 changes: 0 additions & 18 deletions graph_net/torch/constraint_util.py
@@ -61,21 +61,3 @@ def __call__(self, model_path):
decorator_config = b64_encoded_bytes.decode("utf-8")
cmd = f"{sys.executable} -m graph_net.torch.run_model --model-path {model_path} --decorator-config {decorator_config}"
return os.system(cmd) == 0


class FusibleSubgraphPredicator:
def __init__(self, config=None):
if config is None:
config = {}
self.config = config

def __call__(self, model_path):
import json
import base64

json_string = json.dumps(self.config)
json_bytes = json_string.encode("utf-8")
b64_encoded_bytes = base64.b64encode(json_bytes)
predicator_config = b64_encoded_bytes.decode("utf-8")
cmd = f"{sys.executable} -m graph_net.model_path_handler --model-path {model_path} --handler-config {predicator_config}"
return os.system(cmd) == 0
34 changes: 34 additions & 0 deletions graph_net/torch/fully_fusible_graph_predicator.py
@@ -0,0 +1,34 @@
import traceback
import logging
from graph_net.torch.graph_decomposer import NaiveDecomposerExtractor
from graph_net.torch.graph_fusibility_status import (
GraphFusibilityStatus,
GraphFusibility,
)

logger = logging.getLogger(__name__)


class FullyFusibleGraphPredicator:
def __init__(self, config=None):
if config is None:
config = {}
self.config = config
handler_config = self.config["handler_config"]
self.decomposer_extractor = NaiveDecomposerExtractor(handler_config)

def __call__(self, model_path):
try:
self.decomposer_extractor(model_path)
except GraphFusibilityStatus as status:
if status.graph_fusibility == GraphFusibility.kFullyFusible:
return True
elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
return False
else:
raise NotImplementedError(f"{status.graph_fusibility=}")
except Exception:
print("\n--- Custom Error Handler ---")
traceback.print_exc()
print("--------------------------\n")
return False
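
For reference, a minimal usage sketch of the new predicator, mirroring how fully_fusible_subgraph_extractor.py drives it in the next file. The contents of handler_config below are illustrative assumptions; only the presence of the "handler_config" key is required by the constructor.

from graph_net.torch.fully_fusible_graph_predicator import FullyFusibleGraphPredicator

# Sketch only: the inner handler_config keys are assumed, not the real schema.
config = {
    "handler_config": {
        # forwarded verbatim to NaiveDecomposerExtractor
    }
}
predicator = FullyFusibleGraphPredicator(config)
# True  -> extraction raised GraphFusibility.kFullyFusible
# False -> kNotFullyFusible, or any other exception during extraction
is_fully_fusible = predicator("/path/to/model_dir")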
29 changes: 13 additions & 16 deletions graph_net/torch/fully_fusible_subgraph_extractor.py
@@ -1,8 +1,9 @@
import os
from pathlib import Path
import graph_net
import tempfile
import shutil
from graph_net.torch import constraint_util
from graph_net.torch import fully_fusible_graph_predicator
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
import logging
@@ -60,20 +61,16 @@ def _get_sub_ranges(self):
), f"Invalid range generated: start={start_pos}, end={end_pos}, max={self.config['max_nodes']}"
yield start_pos, end_pos

def _handle_success(
self, temp_dir: str, start_pos: int, end_pos: int, model_name
) -> str:
target_name = f"{model_name}_start{start_pos}_end{end_pos}"
def _handle_success(self, temp_dir: str, rel_model_path: str) -> str:
subdirs = list(Path(temp_dir).iterdir())
assert len(subdirs) == 1
temp_dir = str(subdirs[0])
target_path = os.path.join(
self.config["output_dir"],
target_name,
rel_model_path,
)
os.makedirs(target_path, exist_ok=True)
# shutil.move(temp_dir, target_path)
for item in os.listdir(temp_dir):
source = os.path.join(temp_dir, item)
destination = os.path.join(target_path, item)
shutil.move(source, destination)
shutil.copytree(temp_dir, target_path, dirs_exist_ok=True)
return target_path

def _build_decompose_config(
@@ -90,7 +87,7 @@ def _build_decompose_config(
"split_positions": [start_pos, end_pos],
"group_head_and_tail": False,
"post_extract_process_path": f"{graph_net_root}/torch/post_extract_process_count_kernels.py",
"post_extract_process_class_name": "GraphFullyFusible",
"post_extract_process_class_name": "ThrowExitStatusIfGraphFullyFusible",
},
}
return check_fusible_config
@@ -106,14 +103,14 @@ def __call__(self, rel_model_path):
check_fusible_config = self._build_decompose_config(
temp_dir, start_pos, end_pos, self.config["model_path_prefix"]
)
predicator = constraint_util.FusibleSubgraphPredicator(
predicator = fully_fusible_graph_predicator.FullyFusibleGraphPredicator(
check_fusible_config
)
logger.warning("fully_fusible_graph_predicator-begin")
success = predicator(model_path)
logger.warning("fully_fusible_graph_predicator-end")
if success:
target_path = self._handle_success(
temp_dir, start_pos, end_pos, os.path.basename(model_path)
)
target_path = self._handle_success(temp_dir, rel_model_path)
print(
f"SUCCESS in finding the biggest fully fusible subgraph. Result saved to: {target_path}"
)
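
The per-item shutil.move loop removed from _handle_success above is replaced by a single shutil.copytree call. A small standalone sketch of the behavior that replacement relies on, namely that dirs_exist_ok=True (Python 3.8+) merges the source tree into an already existing target directory:

import os
import shutil
import tempfile

src = tempfile.mkdtemp()
dst = tempfile.mkdtemp()  # target already exists, as it can in _handle_success
open(os.path.join(src, "model.py"), "w").close()

# With dirs_exist_ok=True the existing dst is reused and the contents are
# merged into it, which is what the removed per-item move loop achieved.
shutil.copytree(src, dst, dirs_exist_ok=True)
assert os.path.exists(os.path.join(dst, "model.py"))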
13 changes: 13 additions & 0 deletions graph_net/torch/graph_fusibility_status.py
@@ -0,0 +1,13 @@
from enum import Enum


class GraphFusibility(Enum):
kFullyFusible = "fully_fusible"
kNotFullyFusible = "not_fully_fusible"


class GraphFusibilityStatus(Exception):
def __init__(self, graph_fusibility: GraphFusibility):
message = f"{graph_fusibility=}"
super().__init__(message)
self.graph_fusibility = graph_fusibility
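
graph_fusibility_status.py lets code buried deep inside the extraction pipeline report its verdict by raising an exception instead of calling sys.exit, so in-process callers can recover it. A minimal sketch of the raise/catch pattern the other files in this change follow:

from graph_net.torch.graph_fusibility_status import (
    GraphFusibilityStatus,
    GraphFusibility,
)

def check_subgraph():
    # Deep inside the post-extract process: report the verdict by raising.
    raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)

try:
    check_subgraph()
except GraphFusibilityStatus as status:
    # The caller converts the status into whatever it needs (bool, exit code, ...).
    fully_fusible = status.graph_fusibility == GraphFusibility.kFullyFusible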
37 changes: 31 additions & 6 deletions graph_net/torch/post_extract_process_count_kernels.py
@@ -1,12 +1,18 @@
import traceback
from graph_net.torch import utils
import importlib.util
import torch
import sys
from typing import Type
from torch.profiler import profile, record_function, ProfilerActivity

from graph_net.torch.graph_fusibility_status import (
GraphFusibilityStatus,
GraphFusibility,
)

class GraphFullyFusible:

class ThrowExitStatusIfGraphFullyFusible:
def __init__(self, config):
self.config = config

@@ -16,7 +22,7 @@ def __call__(self, model_path=None):
# atexit.register(callback)
torch._dynamo.reset()
if model_path is None:
sys.exit(1)
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
# model
model_class = load_class_from_file(
f"{model_path}/model.py", class_name="GraphModule"
@@ -33,17 +39,36 @@ def __call__(self, model_path=None):
try:
model(**state_dict)
except Exception:
sys.exit(1)
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
# try to compile the model
try:
compiled_model = torch.compile(model)
except Exception:
sys.exit(1)
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)
compiled_num_of_kernels = count_kernels(compiled_model, state_dict)
if compiled_num_of_kernels == 1:
sys.exit(0)
raise GraphFusibilityStatus(GraphFusibility.kFullyFusible)
else:
sys.exit(1)
raise GraphFusibilityStatus(GraphFusibility.kNotFullyFusible)


class GraphFullyFusible:
def __init__(self, config):
self.predicator = ThrowExitStatusIfGraphFullyFusible(config)

def __call__(self, model_path=None):
try:
self.predicator(model_path)
except GraphFusibilityStatus as status:
if status.graph_fusibility == GraphFusibility.kFullyFusible:
sys.exit(0)
elif status.graph_fusibility == GraphFusibility.kNotFullyFusible:
sys.exit(1)
else:
raise NotImplementedError(f"{status.graph_fusibility=}")
except Exception:
traceback.print_exc()
sys.exit(1)


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
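
The renamed ThrowExitStatusIfGraphFullyFusible is wired in through the post_extract_process_path / post_extract_process_class_name entries built in _build_decompose_config above. A hedged sketch of that hand-off; the run_post_extract_hook loader here is hypothetical and stands in for whatever graph_net.model_path_handler and the decomposer actually do with those two config entries:

import importlib.util

def run_post_extract_hook(config, extracted_model_path):
    # Hypothetical loader: import the module named in the config, instantiate
    # the configured class, and call it on the freshly extracted subgraph dir.
    spec = importlib.util.spec_from_file_location(
        "post_extract_process", config["post_extract_process_path"]
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    process = getattr(module, config["post_extract_process_class_name"])(config)
    # ThrowExitStatusIfGraphFullyFusible raises GraphFusibilityStatus here;
    # the legacy-style GraphFullyFusible wrapper exits with 0/1 instead.
    process(model_path=extracted_model_path)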