From 3895d7bb33ecaa1bc060789ad8505f3c93ee6a53 Mon Sep 17 00:00:00 2001
From: AL
Date: Sun, 29 Oct 2023 15:08:14 -0400
Subject: [PATCH 1/3] .

---
 .../hidet/graph/frontend/torch/register_functions.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py
index 3fe918c22..c7d16ea72 100644
--- a/python/hidet/graph/frontend/torch/register_functions.py
+++ b/python/hidet/graph/frontend/torch/register_functions.py
@@ -1023,6 +1023,18 @@ def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor
 def ne(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
     return a != b
 
+@register_function(operator.mod)
+def mod(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
+    return a % b
+
+@register_function(operator.lshift)
+def lshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
+    return a << b
+
+@register_function(operator.rshift)
+def rshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
+    return a >> b
+
 
 @register_function(torch.rsqrt)
 def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:

From 3c1ceb3923dbea2fe9fcab9b329fb1420a335ce3 Mon Sep 17 00:00:00 2001
From: Allan Lin
Date: Sun, 29 Oct 2023 15:10:48 -0400
Subject: [PATCH 2/3] format/lint

---
 python/hidet/graph/frontend/torch/register_functions.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py
index c7d16ea72..29285232e 100644
--- a/python/hidet/graph/frontend/torch/register_functions.py
+++ b/python/hidet/graph/frontend/torch/register_functions.py
@@ -1023,14 +1023,17 @@ def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor
 def ne(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
     return a != b
 
+
 @register_function(operator.mod)
 def mod(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
     return a % b
 
+
 @register_function(operator.lshift)
 def lshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
     return a << b
 
+
 @register_function(operator.rshift)
 def rshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
     return a >> b

From 1d1a6531e33b4da9e9bba599aac41b89c5dc0cbc Mon Sep 17 00:00:00 2001
From: Allan Lin
Date: Mon, 30 Oct 2023 11:34:26 -0400
Subject: [PATCH 3/3] fix bugs

---
 python/hidet/backend/codegen.py                        |  1 +
 .../graph/frontend/torch/register_functions.py         |  1 +
 python/hidet/graph/ops/reduce/reduce.py                |  4 ++--
 python/hidet/ir/dtypes/__init__.py                     |  3 ++-
 python/hidet/ir/dtypes/vector.py                       | 13 +++++++++++--
 5 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py
index e5e474636..d653976ed 100644
--- a/python/hidet/backend/codegen.py
+++ b/python/hidet/backend/codegen.py
@@ -621,6 +621,7 @@ def visit_DataType(self, t: DataType):
             'float32x4': '__m128',
             'float32x8': '__m256',
             'int8x4': 'char4',
+            'uint8x4': 'uint4',
         }
 
         self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']

diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py
index 29285232e..99d24c6ae 100644
--- a/python/hidet/graph/frontend/torch/register_functions.py
+++ b/python/hidet/graph/frontend/torch/register_functions.py
@@ -1046,6 +1046,7 @@ def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
     return ops.rsqrt(x)
 
 
+@register_function(operator.pow)
 @register_function(torch.pow)
 @register_method(torch.Tensor.pow)
 def tensor_pow(self: Union[Tensor, Number], exponent: Union[Tensor, Number]) -> Tensor:

diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py
index ca669dadb..aaf95142c 100644
--- a/python/hidet/graph/ops/reduce/reduce.py
+++ b/python/hidet/graph/ops/reduce/reduce.py
@@ -15,7 +15,7 @@
 from hidet.lang import grid
 from hidet.lang.cuda import blockIdx, threadIdx, register_tensor
 from hidet.ir.type import DataType
-from hidet.ir.dtypes.vector import VectorType
+from hidet.ir.dtypes.vector import VectorType, vectorize
 from hidet.ir.library import tune
 from ..arithmetic import square, sqrt
 from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
@@ -85,7 +85,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule:
         num_eles: int = 4 // xdtype.nbytes
         if is_constant(shape[-1]) and shape[-1] % num_eles == 0:
             lanes = num_eles
-            vtype = VectorType(xdtype, lanes)
+            vtype = vectorize(xdtype, lanes)
         read_shape = shape[:]
         read_shape[-1] /= lanes
         block_size = (read_shape[-1] + warp_size - 1) // warp_size * warp_size

diff --git a/python/hidet/ir/dtypes/__init__.py b/python/hidet/ir/dtypes/__init__.py
index 31391385b..b51d485b9 100644
--- a/python/hidet/ir/dtypes/__init__.py
+++ b/python/hidet/ir/dtypes/__init__.py
@@ -15,7 +15,7 @@
 from .floats import float16, float32, float64, bfloat16, tfloat32
 from .floats import f16, f32, f64, bf16, tf32
 from .boolean import boolean
-from .vector import float16x2, float32x4, float32x8, int8x4, vectorize
+from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, vectorize
 from .vector import f16x2, f32x4, f32x8
 from .complex import complex64, complex128
 from .promotion import promote_type
@@ -42,6 +42,7 @@
     'float32x8': float32x8,
     'float16x2': float16x2,
     'int8x4': int8x4,
+    'uint8x4': uint8x4,
 }
 
 sname2dtype = {

diff --git a/python/hidet/ir/dtypes/vector.py b/python/hidet/ir/dtypes/vector.py
index 98326bea9..7f6d2b347 100644
--- a/python/hidet/ir/dtypes/vector.py
+++ b/python/hidet/ir/dtypes/vector.py
@@ -12,7 +12,7 @@
 from typing import Any, Sequence
 from hidet.ir.type import DataType
 from .floats import float32, float16
-from .integer import int8
+from .integer import int8, uint8
 
 
 class VectorType(DataType):
@@ -75,6 +75,9 @@ def max_value(self):
 int8x4 = VectorType(int8, 4)
 i8x4 = int8x4
 
+uint8x4 = VectorType(uint8, 4)
+u8x4 = uint8x4
+
 float32x4 = VectorType(float32, 4)
 f32x4 = float32x4
 
@@ -86,7 +89,13 @@ def max_value(self):
 
 
 def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
-    table = {(float32, 4): float32x4, (float32, 8): float32x8, (float16, 2): float16x2, (int8, 4): int8x4}
+    table = {
+        (float32, 4): float32x4,
+        (float32, 8): float32x8,
+        (float16, 2): float16x2,
+        (int8, 4): int8x4,
+        (uint8, 4): uint8x4,
+    }
     if (base_dtype, num_lanes) in table:
         return table[(base_dtype, num_lanes)]
     else:
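
A minimal usage sketch (not part of the patches above; the module and tensor names are made
up for illustration): a toy model whose forward uses the Python operators that these commits
register with hidet's torch frontend. It assumes hidet is installed, a CUDA device is
available, and hidet's 'hidet' backend is registered with torch.compile.

import torch


class BitOps(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = x % y           # traced as operator.mod
        b = (x << 2) >> 1   # traced as operator.lshift / operator.rshift
        return a + b ** 2   # may lower to operator.pow or Tensor.pow; both are registered


x = torch.randint(1, 16, (8,), dtype=torch.int32, device='cuda')
y = torch.randint(1, 16, (8,), dtype=torch.int32, device='cuda')

eager = BitOps()
compiled = torch.compile(eager, backend='hidet')  # dispatches through register_functions.py
torch.testing.assert_close(compiled(x, y), eager(x, y))

Without these registrations, hidet's frontend would flag the corresponding FX nodes as
unsupported. On the dtype side, the third commit lets vectorize(uint8, 4) resolve to the new
uint8x4 type rather than falling through to the unsupported-combination branch, which the
warp-reduce schedule relies on when it packs four one-byte elements into a single four-byte
read.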