From a2419af27c6b2c2e31e8c55dea3e643f8b277ccf Mon Sep 17 00:00:00 2001 From: "huangchenhui.yellow" Date: Mon, 1 Jul 2024 22:46:30 +0800 Subject: [PATCH] refine chore code --- runtime/python/brt/utils.py | 2 ++ tests/compatibilty_test/execute.py | 4 ++-- tests/compatibilty_test/utils.py | 35 ------------------------------ 3 files changed, 4 insertions(+), 37 deletions(-) delete mode 100644 tests/compatibilty_test/utils.py diff --git a/runtime/python/brt/utils.py b/runtime/python/brt/utils.py index ddc12b898..7e3739c82 100644 --- a/runtime/python/brt/utils.py +++ b/runtime/python/brt/utils.py @@ -16,6 +16,8 @@ def brt_dtype_to_torch_dtype(dtype: brt.DType): return torch.float64 if dtype == brt.DType.bool: return torch.bool + if dtype == brt.DType.int1: + return torch.bool if dtype == brt.DType.int8: return torch.int8 if dtype == brt.DType.int16: diff --git a/tests/compatibilty_test/execute.py b/tests/compatibilty_test/execute.py index ea386008d..202c248ca 100644 --- a/tests/compatibilty_test/execute.py +++ b/tests/compatibilty_test/execute.py @@ -13,6 +13,7 @@ # ============================================================================== import brt +from brt.utils import brt_dtype_to_torch_dtype import torch import numpy as np @@ -20,7 +21,6 @@ import re from reporting import TestResult -from utils import (mlir_type_to_torch_dtype) class BRTBackend: @@ -47,7 +47,7 @@ def _generate_torch_outputs(self): for offset in self.session.get_output_arg_offsets(): outputs.append( torch.empty(self.session.get_static_shape(offset), - dtype=mlir_type_to_torch_dtype( + dtype=brt_dtype_to_torch_dtype( self.session.get_data_type(offset)), device=self.device)) return outputs diff --git a/tests/compatibilty_test/utils.py b/tests/compatibilty_test/utils.py deleted file mode 100644 index 10050a260..000000000 --- a/tests/compatibilty_test/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch - -def mlir_type_to_torch_dtype(mlir_type): - if str(mlir_type) in ["DType.float64"]: - return torch.float64 - if str(mlir_type) in ["DType.float32"]: - return torch.float32 - if str(mlir_type) in ["DType.float16"]: - return torch.float16 - if str(mlir_type) in ["DType.int64", "DType.index"]: - return torch.int64 - if str(mlir_type) in ["DType.int32"]: - return torch.int32 - if str(mlir_type) in ["DType.int16"]: - return torch.int16 - if str(mlir_type) in ["DType.int8"]: - return torch.int8 - if str(mlir_type) in ["DType.int1", "DType.bool"]: - return torch.bool - raise NotImplementedError("unsupported mlir type {}".format( - str(mlir_type)))