graph_net/test/dtype_gen_test.sh (17 changes: 11 additions & 6 deletions)
@@ -2,37 +2,42 @@

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")
-SAMPLES_ROOT="$GRAPH_NET_ROOT/../"
+GRAPHNET_ROOT="$GRAPH_NET_ROOT/../"
OUTPUT_DIR="/tmp/dtype_gen_samples"
mkdir -p "$OUTPUT_DIR"

# Step 1: Initialize dtype generalization passes (samples of torchvision)
python3 -m graph_net.apply_sample_pass \
  --model-path-list "graph_net/config/small100_torch_samples_list.txt" \
-  --sample-pass-file-path "$GRAPH_NET_ROOT/torch/dtype_generalizer.py" \
+  --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \
  --sample-pass-class-name InitDataTypeGeneralizationPasses \
  --sample-pass-config $(base64 -w 0 <<EOF
{
  "dtype_list": ["float16", "bfloat16"],
-  "model_path_prefix": "$SAMPLES_ROOT"
+  "model_path_prefix": "$GRAPHNET_ROOT",
+  "output_dir": "$OUTPUT_DIR",
+  "resume": true,
+  "limits_handled_models": null
}
EOF
)

# Step 2: Apply passes to generate samples
python3 -m graph_net.apply_sample_pass \
  --model-path-list "graph_net/config/small100_torch_samples_list.txt" \
-  --sample-pass-file-path "$GRAPH_NET_ROOT/torch/dtype_generalizer.py" \
+  --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \
  --sample-pass-class-name ApplyDataTypeGeneralizationPasses \
  --sample-pass-config $(base64 -w 0 <<EOF
{
  "output_dir": "$OUTPUT_DIR",
-  "model_path_prefix": "$SAMPLES_ROOT",
+  "model_path_prefix": "$GRAPHNET_ROOT",
  "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/torch/constraint_util.py",
  "model_runnable_predicator_class_name": "RunModelPredicator",
  "model_runnable_predicator_config": {
    "use_dummy_inputs": true
-  }
+  },
+  "resume": true,
+  "limits_handled_models": null
}
EOF
)
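A minimal sketch of the consumer side of --sample-pass-config, assuming graph_net.apply_sample_pass simply reverses the `base64 -w 0` + JSON heredoc encoding used above; the CLI's actual decoding code is not shown in this diff, and the example payload mirrors Step 1 with a made-up model_path_prefix:

import base64
import json

def decode_sample_pass_config(arg: str) -> dict:
    raw = base64.b64decode(arg).decode("utf-8")  # undo `base64 -w 0`
    return json.loads(raw)                       # heredoc body is plain JSON

# Example round trip with the Step 1 payload (shell variables already expanded):
payload = {
    "dtype_list": ["float16", "bfloat16"],
    "model_path_prefix": "/path/to/GraphNet/",  # hypothetical expansion of $GRAPHNET_ROOT
    "output_dir": "/tmp/dtype_gen_samples",
    "resume": True,
    "limits_handled_models": None,
}
encoded = base64.b64encode(json.dumps(payload).encode("utf-8")).decode("ascii")
assert decode_sample_pass_config(encoded)["resume"] is True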
@@ -39,6 +39,9 @@
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module

from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
"running_mean",
@@ -51,7 +54,7 @@
}


class InitDataTypeGeneralizationPasses:
class InitDataTypeGeneralizationPasses(SamplePass, ResumableSamplePassMixin):
"""
Step 1: Initialize data type generalization passes for a computation graph.

Expand All @@ -66,6 +69,8 @@ class InitDataTypeGeneralizationPasses:
"""

def __init__(self, config: Dict[str, Any]):
super().__init__(config)

self.config = config
self.dtype_list = config.get("dtype_list", ["float16", "bfloat16"])
self.model_path_prefix = config.get("model_path_prefix", "")
@@ -78,7 +83,29 @@ def __init__(self, config: Dict[str, Any]):
f"Invalid dtype: {dtype}. Must be one of {valid_dtypes}"
)

def declare_config(
self,
dtype_list: list,
Contributor Author: dtype_list is really a list[str], but SamplePass's _check_config_declaration_parameters method checks that the declared type must be list, so it is declared here as list rather than list[str].

model_path_prefix: str,
output_dir: str,
resume: bool = False,
limits_handled_models: int = None,
):
pass
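
# To make the comment above concrete: a hypothetical declaration checker of the
# shape it describes, comparing each declare_config annotation against the runtime
# type of the matching config value. A parameterized annotation such as list[str]
# is not a plain type object, so a checker like this rejects it, while a bare
# `list` matches the JSON array. The real SamplePass._check_config_declaration_parameters
# is not shown in this diff and may differ.
import inspect

def check_config_declaration(declare_config, config: dict) -> None:
    sig = inspect.signature(declare_config)
    for name, param in sig.parameters.items():
        if name == "self":
            continue
        expected = param.annotation
        if not isinstance(expected, type):  # e.g. list[str] is a types.GenericAlias
            raise TypeError(f"{name}: declared type must be a plain type, got {expected!r}")
        value = config.get(name)
        if value is not None and not isinstance(value, expected):
            raise TypeError(f"{name}: expected {expected.__name__}, got {type(value).__name__}")

# e.g. check_config_declaration(InitDataTypeGeneralizationPasses.declare_config, decoded_config)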

def sample_handled(self, rel_model_path: str) -> bool:
dst_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Contributor Author: Code logic: a sample is judged as already handled by checking whether data_type_generalization_passes in GraphNet/samples/*/graph_net.json has a value.

graph_net_json_path = dst_model_path / "graph_net.json"
with open(graph_net_json_path, "r", encoding="utf-8") as f:
data = json.load(f)
if data.get("data_type_generalization_passes"):
return True
return False
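
# Illustration of the Step 1 marker that sample_handled() reads: the pass records
# the working dtype passes under the "data_type_generalization_passes" key of each
# sample's graph_net.json (written via update_json further down in this diff).
# The concrete pass names below are illustrative, following the
# dtype_generalization_pass_<dtype> convention used by ApplyDataTypeGeneralizationPasses.
example_graph_net_json = {
    "data_type_generalization_passes": [
        "dtype_generalization_pass_float16",
        "dtype_generalization_pass_bfloat16",
    ]
}
# A non-empty list means the sample is treated as already processed and can be
# skipped when "resume" is true:
assert bool(example_graph_net_json.get("data_type_generalization_passes"))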

def __call__(self, model_path: str) -> None:
self.resumable_handle_sample(model_path)
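
# A guess at the contract behind resumable_handle_sample(), inferred only from how
# __call__, sample_handled, and resume are used in this diff; the real
# ResumableSamplePassMixin lives in graph_net.sample_pass.resumable_sample_pass_mixin
# and is not shown here (handling of "limits_handled_models" is omitted as well).
class ResumableSamplePassMixinSketch:
    def resumable_handle_sample(self, rel_model_path: str) -> None:
        # Skip samples whose outputs already exist, otherwise (re)process them.
        if self.config.get("resume") and self.sample_handled(rel_model_path):
            return
        self.resume(rel_model_path)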

def resume(self, model_path: str) -> None:
"""
Initialize dtype passes for the given model.

@@ -90,9 +117,9 @@ def __call__(self, model_path: str) -> None:
model_path = str(Path(self.model_path_prefix) / model_path)

# Parse the computation graph
# traced_model = parse_immutable_model_path_into_sole_graph_module(model_path)
module, inputs = get_torch_module_and_inputs(model_path)
traced_model = parse_sole_graph_module(module, inputs)

ShapeProp(traced_model).propagate(*inputs)

# Test which dtype passes work
@@ -195,7 +222,7 @@ def _save_dtype_pass_names(
update_json(model_path, kDataTypeGeneralizationPasses, dtype_pass_names)


class ApplyDataTypeGeneralizationPasses:
class ApplyDataTypeGeneralizationPasses(SamplePass, ResumableSamplePassMixin):
"""
Step 2: Apply data type generalization passes to generate new samples.

@@ -213,6 +240,8 @@ class ApplyDataTypeGeneralizationPasses:
"""

def __init__(self, config: Dict[str, Any]):
super().__init__(config)

self.config = config
self.output_dir = config.get("output_dir")
if not self.output_dir:
@@ -228,6 +257,18 @@ def __init__(self, config: Dict[str, Any]):
)
self.model_runnable_predicator = self._make_model_runnable_predicator(config)

def declare_config(
self,
output_dir: str,
model_path_prefix: str,
model_runnable_predicator_filepath: str,
model_runnable_predicator_class_name: str,
model_runnable_predicator_config: dict,
resume: bool = False,
limits_handled_models: int = None,
):
pass

def _make_model_runnable_predicator(self, config: Dict[str, Any]):
"""Create model runnable predicator from config."""
module = load_module(config["model_runnable_predicator_filepath"])
@@ -238,7 +279,24 @@ def _make_model_runnable_predicator(self, config: Dict[str, Any]):
predicator_config = config.get("model_runnable_predicator_config", {})
return cls(predicator_config)

def __call__(self, model_path: str) -> List[str]:
def sample_handled(self, rel_model_path: str) -> bool:
model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Contributor Author: Code logic: for example, resnet18_float16 and resnet18_bfloat16 are generated under /tmp/dtype_gen_samples; the sample is judged as handled by checking whether the number of model.py files found under resnet18_float16 and resnet18_bfloat16 equals the number of dtype_pass_names.

dtype_pass_names = self._read_dtype_pass_names(model_path)
num_generated = 0
for pass_name in dtype_pass_names:
dtype = pass_name.replace("dtype_generalization_pass_", "")
dtype_model_path = rel_model_path + "_" + dtype  # per-dtype name; avoid accumulating suffixes across iterations
rel_generated_model_path = Path(*Path(dtype_model_path).parts[1:])
generated_model_path = (
Path(self.config["output_dir"]) / rel_generated_model_path
)
num_generated += len(list(generated_model_path.rglob("model.py")))
return num_generated == len(dtype_pass_names)
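
# Worked example of the check above, with hypothetical paths: for
# rel_model_path = "samples/torchvision/resnet18" and passes for float16 and
# bfloat16, Step 2 is expected to have produced one model.py per dtype under
# OUTPUT_DIR, with the leading "samples/" component stripped by Path(*parts[1:]):
#   /tmp/dtype_gen_samples/torchvision/resnet18_float16/model.py
#   /tmp/dtype_gen_samples/torchvision/resnet18_bfloat16/model.py
from pathlib import Path

output_dir = Path("/tmp/dtype_gen_samples")
expected_generated_files = [
    output_dir / "torchvision" / f"resnet18_{dtype}" / "model.py"
    for dtype in ("float16", "bfloat16")
]
# sample_handled() returns True once rglob("model.py") finds one file per dtype pass.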

def __call__(self, rel_model_path: str):
self.resumable_handle_sample(rel_model_path)

def resume(self, model_path: str) -> List[str]:
"""
Apply dtype passes to generate new samples.
