Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion graph_net/subgraph_decompose_and_evaluation_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,13 @@ def run_evaluation(
work_dir, "reference_device_outputs"
)

cmd = [sys.executable, "-m", f"graph_net.{framework}.{test_module_name}"] + [
test_module_path = (
f"graph_net_bench.{framework}.{test_module_name}"
if test_module_name == "test_compiler"
else f"graph_net.{framework}.{test_module_name}"
)

cmd = [sys.executable, "-m", test_module_path] + [
item
for key, value in test_module_arguments.items()
for item in (f"--{key}", str(value))
Expand Down
2 changes: 1 addition & 1 deletion graph_net/torch/test_reference_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_multi_models(args):
def main(args):
assert os.path.isdir(args.model_path)
# Support all torch compilers
valid_compilers = list(test_compiler.registry_backend.keys())
valid_compilers = list(test_compiler.compiler_backend_name2class.keys())
assert (
args.compiler in valid_compilers
), f"Compiler must be one of {valid_compilers}"
Expand Down
3 changes: 2 additions & 1 deletion graph_net/torch/test_target_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from graph_net_bench import path_utils
from graph_net_bench import test_compiler_util
from graph_net import model_path_util
from graph_net_bench.torch import test_compiler, test_reference_device
from graph_net_bench.torch import test_compiler
from graph_net.torch import test_reference_device


def parse_config_from_reference_log(log_path):
Expand Down
Loading