-
Notifications
You must be signed in to change notification settings - Fork 45
Fix dtype_generalizer.py #543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,9 @@ | |
| from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs | ||
| from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module | ||
|
|
||
| from graph_net.sample_pass.sample_pass import SamplePass | ||
| from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin | ||
|
|
||
| # Weights that must remain float32 for numerical stability | ||
| FLOAT32_PRESERVED_WEIGHTS = { | ||
| "running_mean", | ||
|
|
@@ -51,7 +54,7 @@ | |
| } | ||
|
|
||
|
|
||
| class InitDataTypeGeneralizationPasses: | ||
| class InitDataTypeGeneralizationPasses(SamplePass, ResumableSamplePassMixin): | ||
| """ | ||
| Step 1: Initialize data type generalization passes for a computation graph. | ||
|
|
||
|
|
@@ -66,6 +69,8 @@ class InitDataTypeGeneralizationPasses: | |
| """ | ||
|
|
||
| def __init__(self, config: Dict[str, Any]): | ||
| super().__init__(config) | ||
|
|
||
| self.config = config | ||
| self.dtype_list = config.get("dtype_list", ["float16", "bfloat16"]) | ||
| self.model_path_prefix = config.get("model_path_prefix", "") | ||
|
|
@@ -78,7 +83,29 @@ def __init__(self, config: Dict[str, Any]): | |
| f"Invalid dtype: {dtype}. Must be one of {valid_dtypes}" | ||
| ) | ||
|
|
||
| def declare_config( | ||
| self, | ||
| dtype_list: list, | ||
| model_path_prefix: str, | ||
| output_dir: str, | ||
| resume: bool = False, | ||
| limits_handled_models: int = None, | ||
| ): | ||
| pass | ||
|
|
||
| def sample_handled(self, rel_model_path: str) -> bool: | ||
| dst_model_path = Path(self.config["model_path_prefix"]) / rel_model_path | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 代码逻辑:通过判断GraphNet/samples/*/graph_net.json的data_type_generalization_passes是否有值来判断sample是否已被处理 |
||
| graph_net_json_path = dst_model_path / "graph_net.json" | ||
| with open(graph_net_json_path, "r", encoding="utf-8") as f: | ||
| data = json.load(f) | ||
| if data.get("data_type_generalization_passes"): | ||
| return True | ||
| return False | ||
|
|
||
| def __call__(self, model_path: str) -> None: | ||
| self.resumable_handle_sample(model_path) | ||
|
|
||
| def resume(self, model_path: str) -> None: | ||
| """ | ||
| Initialize dtype passes for the given model. | ||
|
|
||
|
|
@@ -90,9 +117,9 @@ def __call__(self, model_path: str) -> None: | |
| model_path = str(Path(self.model_path_prefix) / model_path) | ||
|
|
||
| # Parse the computation graph | ||
| # traced_model = parse_immutable_model_path_into_sole_graph_module(model_path) | ||
| module, inputs = get_torch_module_and_inputs(model_path) | ||
| traced_model = parse_sole_graph_module(module, inputs) | ||
|
|
||
| ShapeProp(traced_model).propagate(*inputs) | ||
|
|
||
| # Test which dtype passes work | ||
|
|
@@ -195,7 +222,7 @@ def _save_dtype_pass_names( | |
| update_json(model_path, kDataTypeGeneralizationPasses, dtype_pass_names) | ||
|
|
||
|
|
||
| class ApplyDataTypeGeneralizationPasses: | ||
| class ApplyDataTypeGeneralizationPasses(SamplePass, ResumableSamplePassMixin): | ||
| """ | ||
| Step 2: Apply data type generalization passes to generate new samples. | ||
|
|
||
|
|
@@ -213,6 +240,8 @@ class ApplyDataTypeGeneralizationPasses: | |
| """ | ||
|
|
||
| def __init__(self, config: Dict[str, Any]): | ||
| super().__init__(config) | ||
|
|
||
| self.config = config | ||
| self.output_dir = config.get("output_dir") | ||
| if not self.output_dir: | ||
|
|
@@ -228,6 +257,18 @@ def __init__(self, config: Dict[str, Any]): | |
| ) | ||
| self.model_runnable_predicator = self._make_model_runnable_predicator(config) | ||
|
|
||
| def declare_config( | ||
| self, | ||
| output_dir: str, | ||
| model_path_prefix: str, | ||
| model_runnable_predicator_filepath: str, | ||
| model_runnable_predicator_class_name: str, | ||
| model_runnable_predicator_config: dict, | ||
| resume: bool = False, | ||
| limits_handled_models: int = None, | ||
| ): | ||
| pass | ||
|
|
||
| def _make_model_runnable_predicator(self, config: Dict[str, Any]): | ||
| """Create model runnable predicator from config.""" | ||
| module = load_module(config["model_runnable_predicator_filepath"]) | ||
|
|
@@ -238,7 +279,24 @@ def _make_model_runnable_predicator(self, config: Dict[str, Any]): | |
| predicator_config = config.get("model_runnable_predicator_config", {}) | ||
| return cls(predicator_config) | ||
|
|
||
| def __call__(self, model_path: str) -> List[str]: | ||
| def sample_handled(self, rel_model_path: str) -> bool: | ||
| model_path = Path(self.config["model_path_prefix"]) / rel_model_path | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 代码逻辑:例如,在/tmp/dtype_gen_samples生成了resnet18_float16和resnet18_bfloat16,通过判断resnet18_float16和resnet18_bfloat16含有model.py的数量是否等于dtype_pass_names的数量来判断样本是否被处理 |
||
| dtype_pass_names = self._read_dtype_pass_names(model_path) | ||
| num_generated = 0 | ||
| for pass_name in dtype_pass_names: | ||
| dtype = pass_name.replace("dtype_generalization_pass_", "") | ||
| rel_model_path = rel_model_path + "_" + dtype | ||
| rel_generated_model_path = Path(*Path(rel_model_path).parts[1:]) | ||
| generated_model_path = ( | ||
| Path(self.config["output_dir"]) / rel_generated_model_path | ||
| ) | ||
| num_generated += len(list(generated_model_path.rglob("model.py"))) | ||
| return num_generated == len(dtype_pass_names) | ||
|
|
||
| def __call__(self, rel_model_path: str): | ||
| self.resumable_handle_sample(rel_model_path) | ||
|
|
||
| def resume(self, model_path: str) -> List[str]: | ||
| """ | ||
| Apply dtype passes to generate new samples. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtype_list是list[str], 但是class SamplePass的_check_config_declaration_parameters方法会检查type必须为list,所以这里修改为list,而不是list[str]