[Graph][Operator] Fix reduce bug, add uint8x4 #372

Merged · 3 commits · Oct 31, 2023
1 change: 1 addition & 0 deletions python/hidet/backend/codegen.py
@@ -621,6 +621,7 @@ def visit_DataType(self, t: DataType):
'float32x4': '__m128',
'float32x8': '__m256',
'int8x4': 'char4',
'uint8x4': 'uint4',
}

self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
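For context, a minimal sketch of how a name-to-C-type table like the one above can be consulted when emitting a vector dtype; the helper function below is hypothetical, and only the table entries come from this diff.

# Hypothetical illustration of the table lookup performed in visit_DataType;
# the entries mirror the mapping extended by this commit.
VECTOR_CTYPE = {
    'float32x4': '__m128',
    'float32x8': '__m256',
    'int8x4': 'char4',
    'uint8x4': 'uint4',
}

def ctype_for(dtype_name: str) -> str:
    # Fall back to the dtype name itself when no vector mapping exists (assumption).
    return VECTOR_CTYPE.get(dtype_name, dtype_name)

print(ctype_for('uint8x4'))  # -> 'uint4'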
16 changes: 16 additions & 0 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -1024,13 +1024,29 @@ def ne(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor
return a != b


@register_function(operator.mod)
def mod(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
return a % b


@register_function(operator.lshift)
def lshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
return a << b


@register_function(operator.rshift)
def rshift(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
return a >> b


@register_function(torch.rsqrt)
def rsqrt(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.rsqrt(..., out=...)")
return ops.rsqrt(x)


@register_function(operator.pow)
@register_function(torch.pow)
@register_method(torch.Tensor.pow)
def tensor_pow(self: Union[Tensor, Number], exponent: Union[Tensor, Number]) -> Tensor:
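With operator.mod, operator.lshift, and operator.rshift registered, models that use %, <<, and >> can now be traced through the torch frontend. A minimal sketch, assuming hidet is installed and exposes its usual 'hidet' backend name for torch.compile:

import torch

class ShiftMod(torch.nn.Module):
    # Uses the three Python operators registered above: <<, >>, and %.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return ((x << 1) >> 2) % 7

model = ShiftMod().cuda()
x = torch.arange(16, dtype=torch.int32, device='cuda')
compiled = torch.compile(model, backend='hidet')  # backend name is an assumption here
y = compiled(x)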
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/reduce/reduce.py
@@ -15,7 +15,7 @@
from hidet.lang import grid
from hidet.lang.cuda import blockIdx, threadIdx, register_tensor
from hidet.ir.type import DataType
from hidet.ir.dtypes.vector import VectorType
from hidet.ir.dtypes.vector import VectorType, vectorize
from hidet.ir.library import tune
from ..arithmetic import square, sqrt
from ..utils import Task, Operator, Tensor, TensorNode, IRModule, ReduceType
@@ -85,7 +85,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule:
num_eles: int = 4 // xdtype.nbytes
if is_constant(shape[-1]) and shape[-1] % num_eles == 0:
lanes = num_eles
vtype = VectorType(xdtype, lanes)
vtype = vectorize(xdtype, lanes)
read_shape = shape[:]
read_shape[-1] /= lanes
block_size = (read_shape[-1] + warp_size - 1) // warp_size * warp_size
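The fix replaces a direct VectorType(xdtype, lanes) construction with the vectorize() helper. A hedged sketch of the observable difference, based on the table-driven lookup shown in vector.py below; that the ad-hoc instance was the source of the reduce bug is an inference, not something stated in the PR:

from hidet.ir.dtypes import float32, vectorize
from hidet.ir.dtypes.vector import VectorType

# vectorize() returns the predefined vector type registered for (float32, 4) ...
vt_registered = vectorize(float32, 4)
# ... while direct construction yields a fresh, unregistered instance.
vt_adhoc = VectorType(float32, 4)
print(vt_registered is vt_adhoc)  # False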
3 changes: 2 additions & 1 deletion python/hidet/ir/dtypes/__init__.py
@@ -15,7 +15,7 @@
from .floats import float16, float32, float64, bfloat16, tfloat32
from .floats import f16, f32, f64, bf16, tf32
from .boolean import boolean
from .vector import float16x2, float32x4, float32x8, int8x4, vectorize
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, vectorize
from .vector import f16x2, f32x4, f32x8
from .complex import complex64, complex128
from .promotion import promote_type
@@ -42,6 +42,7 @@
'float32x8': float32x8,
'float16x2': float16x2,
'int8x4': int8x4,
'uint8x4': uint8x4,
}

sname2dtype = {
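A quick way to confirm the new export, assuming the registry entry above keys the dtype by its name attribute:

from hidet.ir.dtypes import uint8x4, int8x4

print(uint8x4.name)  # expected 'uint8x4', matching the registry key above (assumption)
print(int8x4.name)   # 'int8x4'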
13 changes: 11 additions & 2 deletions python/hidet/ir/dtypes/vector.py
@@ -12,7 +12,7 @@
from typing import Any, Sequence
from hidet.ir.type import DataType
from .floats import float32, float16
from .integer import int8
from .integer import int8, uint8


class VectorType(DataType):
@@ -75,6 +75,9 @@ def max_value(self):
int8x4 = VectorType(int8, 4)
i8x4 = int8x4

uint8x4 = VectorType(uint8, 4)
u8x4 = uint8x4

float32x4 = VectorType(float32, 4)
f32x4 = float32x4

@@ -86,7 +89,13 @@


def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
table = {(float32, 4): float32x4, (float32, 8): float32x8, (float16, 2): float16x2, (int8, 4): int8x4}
table = {
(float32, 4): float32x4,
(float32, 8): float32x8,
(float16, 2): float16x2,
(int8, 4): int8x4,
(uint8, 4): uint8x4,
}
if (base_dtype, num_lanes) in table:
return table[(base_dtype, num_lanes)]
else:
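A small sanity check of the extended table, assuming uint8 is exported from hidet.ir.dtypes alongside the other scalar dtypes:

from hidet.ir.dtypes import uint8, uint8x4, vectorize

vt = vectorize(uint8, 4)
print(vt is uint8x4)  # True: (uint8, 4) now resolves to the predefined uint8x4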