diff --git a/benchmark/python/utils/benchmark_utils.py b/benchmark/python/utils/benchmark_utils.py index ec29309..d048c24 100644 --- a/benchmark/python/utils/benchmark_utils.py +++ b/benchmark/python/utils/benchmark_utils.py @@ -65,12 +65,15 @@ def run_benchmark( ): """Run benchmark with specified backends.""" output = [] + output_type = None with torch.no_grad(): for backend in backends: match backend: case Backends.TORCH_SPARSE_EAGER: - output.append(torch_net(*sparse_inputs)) + sparse_out = torch_net(*sparse_inputs) + output_type = sparse_out.layout + output.append(sparse_out) runtime_results.append( timer( "torch_net(*sparse_inputs)", @@ -133,8 +136,16 @@ def run_benchmark( output.append( torch.sparse_csr_tensor(*sp_out, size=dense_out.shape) ) + # Check MPACT and torch eager both return sparse csr output + # only when torch sparse eager has been run. + if output_type: + assert output_type == torch.sparse_csr else: output.append(torch.from_numpy(sp_out)) + # Check MPACT and torch eager both return dense output + # only when torch sparse eager has been run. + if output_type: + assert output_type == torch.strided invoker, f = mpact_jit_compile(torch_net, *sparse_inputs) compile_time_results.append( timer(