diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index 155c6b09e..546835680 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -2,19 +2,22 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( os.path.dirname(graph_net.__file__))") -SAMPLES_ROOT="$GRAPH_NET_ROOT/../" +GRAPHNET_ROOT="$GRAPH_NET_ROOT/../" OUTPUT_DIR="/tmp/dtype_gen_samples" mkdir -p "$OUTPUT_DIR" # Step 1: Initialize dtype generalization passes (samples of torchvision) python3 -m graph_net.apply_sample_pass \ --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ - --sample-pass-file-path "$GRAPH_NET_ROOT/torch/dtype_generalizer.py" \ + --sample-pass-file-path "$GRAPH_NET_ROOT/torch/sample_pass/dtype_generalizer.py" \ --sample-pass-class-name InitDataTypeGeneralizationPasses \ --sample-pass-config $(base64 -w 0 < bool: + dst_model_path = Path(self.config["model_path_prefix"]) / rel_model_path + graph_net_json_path = dst_model_path / "graph_net.json" + with open(graph_net_json_path, "r", encoding="utf-8") as f: + data = json.load(f) + if data.get("data_type_generalization_passes"): + return True + return False + def __call__(self, model_path: str) -> None: + self.resumable_handle_sample(model_path) + + def resume(self, model_path: str) -> None: """ Initialize dtype passes for the given model. @@ -90,9 +117,9 @@ def __call__(self, model_path: str) -> None: model_path = str(Path(self.model_path_prefix) / model_path) # Parse the computation graph - # traced_model = parse_immutable_model_path_into_sole_graph_module(model_path) module, inputs = get_torch_module_and_inputs(model_path) traced_model = parse_sole_graph_module(module, inputs) + ShapeProp(traced_model).propagate(*inputs) # Test which dtype passes work @@ -195,7 +222,7 @@ def _save_dtype_pass_names( update_json(model_path, kDataTypeGeneralizationPasses, dtype_pass_names) -class ApplyDataTypeGeneralizationPasses: +class ApplyDataTypeGeneralizationPasses(SamplePass, ResumableSamplePassMixin): """ Step 2: Apply data type generalization passes to generate new samples. @@ -213,6 +240,8 @@ class ApplyDataTypeGeneralizationPasses: """ def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.config = config self.output_dir = config.get("output_dir") if not self.output_dir: @@ -228,6 +257,18 @@ def __init__(self, config: Dict[str, Any]): ) self.model_runnable_predicator = self._make_model_runnable_predicator(config) + def declare_config( + self, + output_dir: str, + model_path_prefix: str, + model_runnable_predicator_filepath: str, + model_runnable_predicator_class_name: str, + model_runnable_predicator_config: dict, + resume: bool = False, + limits_handled_models: int = None, + ): + pass + def _make_model_runnable_predicator(self, config: Dict[str, Any]): """Create model runnable predicator from config.""" module = load_module(config["model_runnable_predicator_filepath"]) @@ -238,7 +279,24 @@ def _make_model_runnable_predicator(self, config: Dict[str, Any]): predicator_config = config.get("model_runnable_predicator_config", {}) return cls(predicator_config) - def __call__(self, model_path: str) -> List[str]: + def sample_handled(self, rel_model_path: str) -> bool: + model_path = Path(self.config["model_path_prefix"]) / rel_model_path + dtype_pass_names = self._read_dtype_pass_names(model_path) + num_generated = 0 + for pass_name in dtype_pass_names: + dtype = pass_name.replace("dtype_generalization_pass_", "") + rel_model_path = rel_model_path + "_" + dtype + rel_generated_model_path = Path(*Path(rel_model_path).parts[1:]) + generated_model_path = ( + Path(self.config["output_dir"]) / rel_generated_model_path + ) + num_generated += len(list(generated_model_path.rglob("model.py"))) + return num_generated == len(dtype_pass_names) + + def __call__(self, rel_model_path: str): + self.resumable_handle_sample(rel_model_path) + + def resume(self, model_path: str) -> List[str]: """ Apply dtype passes to generate new samples.