Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions graph_net/local_graph_decomposer_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
import base64
import json
import subprocess
import sys
from typing import List

from graph_net.graph_net_root import get_graphnet_root


def convert_json_to_b64_string(config) -> str:
    """Serialize *config* to JSON and return it as a base64-encoded ASCII string."""
    serialized = json.dumps(config)
    return base64.b64encode(serialized.encode()).decode()


Comment on lines +9 to +14
Copy link

Copilot AI Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The convert_json_to_b64_string function is duplicated from subgraph_decompose_and_evaluation_step.py. This creates code duplication and maintainability issues. Consider importing this function from the existing module or creating a shared utility module for common encoding/decoding functions.

Suggested change
def convert_json_to_b64_string(config) -> str:
return base64.b64encode(json.dumps(config).encode()).decode()
from graph_net.subgraph_decompose_and_evaluation_step import convert_json_to_b64_string

Copilot uses AI. Check for mistakes.
def build_decorator_config(
    framework: str,
    model_name: str,
    output_dir: str,
    split_positions: List[int],
) -> dict:
    """Build the decorator configuration for the graph decomposer run.

    The returned dictionary is base64-encoded by the caller and handed to the
    framework-specific ``run_model`` entry point.

    :param framework: Framework name (e.g. ``"paddle"`` or ``"torch"``); used
        to select the extractor and graph-decomposer script paths.
    :param model_name: Model name recorded in the decorator config.
    :param output_dir: Directory where the graph decomposer writes its output.
    :param split_positions: Positions at which the computation graph is split.
    :return: Nested dict describing the decorator and its custom extractor;
        for ``"paddle"`` it also carries graph-meta restoration settings.
    """
    root = get_graphnet_root()

    extractor_config = {
        "output_dir": output_dir,
        "split_positions": split_positions,
        "group_head_and_tail": False,
        "use_all_inputs": True,
        "chain_style": False,
    }

    # Paddle needs an extra post-extraction pass to restore graph metadata.
    if framework == "paddle":
        extractor_config.update(
            {
                "post_extract_process_path": f"{root}/graph_net/{framework}/graph_meta_restorer.py",
                "post_extract_process_class_name": "GraphMetaRestorer",
                "post_extract_process_config": {
                    "update_inplace": True,
                    "input_meta_allow_partial_update": False,
                },
            }
        )

    return {
        "decorator_path": f"{root}/graph_net/{framework}/extractor.py",
        "decorator_config": {
            "name": model_name,
            "custom_extractor_path": f"{root}/graph_net/{framework}/graph_decomposer.py",
            "custom_extractor_config": extractor_config,
        },
    }


def main(args=None):
    """Run a model under the local graph decomposer in a subprocess.

    Validates the JSON-encoded split positions, builds the decorator
    configuration, base64-encodes it, and invokes the framework-specific
    ``graph_net.<framework>.run_model`` module in a child process.  The
    current process exits with the child's return code.

    :param args: Parsed CLI namespace with attributes ``framework``,
        ``model_name``, ``model_path``, ``output_dir`` and
        ``split_positions_json``.  Defaults to the module-level ``args``
        bound by the ``__main__`` block, preserving the original call style.
    :raises ValueError: If ``split_positions_json`` is not a JSON list of ints.
    """
    if args is None:
        # Backward compatibility: the __main__ block binds a module-level
        # `args` and calls main() with no arguments.  Looking it up via
        # globals() avoids a NameError when this module is imported and
        # main(namespace) is called directly.
        args = globals()["args"]

    split_positions = json.loads(args.split_positions_json)
    if not isinstance(split_positions, list) or not all(
        isinstance(x, int) for x in split_positions
    ):
        raise ValueError(f"Invalid split positions: {split_positions}")

    decorator_config = build_decorator_config(
        framework=args.framework,
        model_name=args.model_name,
        output_dir=args.output_dir,
        split_positions=split_positions,
    )
    decorator_config_b64 = convert_json_to_b64_string(decorator_config)

    cmd = [
        sys.executable,
        "-m",
        f"graph_net.{args.framework}.run_model",
        "--model-path",
        args.model_path,
        "--decorator-config",
        decorator_config_b64,
    ]

    result = subprocess.run(cmd, text=True)
    sys.exit(result.returncode)
Comment on lines +53 to +79
Copy link

Copilot AI Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `main` function references the `args` variable before it is defined. The `args` variable is only assigned on line 92 within the `if __name__ == "__main__":` block, but the `main` function is called on line 93. This makes `main` depend on a module-level global that does not exist when the module is imported, so calling `main()` from another module raises `NameError`.

The main function should accept args as a parameter.

Copilot uses AI. Check for mistakes.


if __name__ == "__main__":
    # CLI entry point: gather the arguments that drive the local decomposer.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework", type=str, choices=["paddle", "torch"], required=True
    )
    # The remaining flags are all required plain strings.
    for flag in (
        "--model-name",
        "--model-path",
        "--output-dir",
        "--split-positions-json",
    ):
        parser.add_argument(flag, type=str, required=True)

    args = parser.parse_args()
    main()
49 changes: 15 additions & 34 deletions graph_net/subgraph_decompose_and_evaluation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,52 +322,33 @@ def run_decomposer_for_single_model(
output_dir: str,
log_path: str,
) -> bool:
Copy link

Copilot AI Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring for the run_decomposer_for_single_model function was removed during refactoring. The original docstring "Decomposes a single model." should be retained to maintain documentation consistency.

Suggested change
) -> bool:
) -> bool:
"""Decomposes a single model."""

Copilot uses AI. Check for mistakes.
"""Decomposes a single model."""

graphnet_root = get_graphnet_root()
decorator_config = {
"decorator_path": f"{graphnet_root}/graph_net/{framework}/extractor.py",
"decorator_config": {
"name": model_name,
"custom_extractor_path": f"{graphnet_root}/graph_net/{framework}/graph_decomposer.py",
"custom_extractor_config": {
"output_dir": output_dir,
"split_positions": split_positions,
"group_head_and_tail": False,
"use_all_inputs": True,
"chain_style": False,
},
},
}
if framework == "paddle":
post_process_configs = {
"post_extract_process_path": f"{graphnet_root}/graph_net/{framework}/graph_meta_restorer.py",
"post_extract_process_class_name": "GraphMetaRestorer",
"post_extract_process_config": {
"update_inplace": True,
"input_meta_allow_partial_update": False,
},
}
for key, value in post_process_configs.items():
decorator_config["decorator_config"]["custom_extractor_config"][key] = value

decorator_config_b64 = convert_json_to_b64_string(decorator_config)

print(
f"[Decomposition] model_path: {model_path}, split_positions: {split_positions}",
flush=True,
)

split_positions_json = json.dumps(split_positions)

cmd = [
sys.executable,
"-m",
f"graph_net.{framework}.run_model",
"graph_net.local_graph_decomposer_wrapper",
"--framework",
framework,
"--model-name",
model_name,
"--model-path",
model_path,
"--decorator-config",
decorator_config_b64,
"--output-dir",
output_dir,
"--split-positions-json",
split_positions_json,
]

os.makedirs(os.path.dirname(log_path), exist_ok=True)
with open(log_path, "a") as f:
result = subprocess.run(cmd, stdout=f, stderr=f, text=True)

return result.returncode == 0


Expand Down
Loading