diff --git a/graph_net/tensor_meta.py b/graph_net/tensor_meta.py old mode 100644 new mode 100755 index 154077704..d5dbbdd9a --- a/graph_net/tensor_meta.py +++ b/graph_net/tensor_meta.py @@ -58,6 +58,11 @@ def serialize_to_py_str(self) -> str: lines = [ (f"class {self.record_class_name}:"), (f'\tname = "{self.name}"'), + *( + [f'\toriginal_name = "{self.original_name}"'] + if self.original_name is not None + else [] + ), (f"\tshape = {self.shape}"), (f'\tdtype = "{self.dtype}"'), (f'\tdevice = "{self.device}"'), diff --git a/graph_net/test/graph_variable_rename_test.sh b/graph_net/test/graph_variable_rename_test.sh new file mode 100755 index 000000000..bac616870 --- /dev/null +++ b/graph_net/test/graph_variable_rename_test.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( +os.path.dirname(graph_net.__file__))") + +# input model path +MODEL_NAME=resnet18 +MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME +config_json_str=$(cat <