diff --git a/python/hidet/ir/builders/stmt_builder.py b/python/hidet/ir/builders/stmt_builder.py index ed825dc72..544512a7d 100644 --- a/python/hidet/ir/builders/stmt_builder.py +++ b/python/hidet/ir/builders/stmt_builder.py @@ -12,21 +12,24 @@ from typing import Union, Sequence, List, cast, Optional from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt, ForMappingStmt, ForStmtAttr -from hidet.ir.stmt import DeclareStmt, BufferStoreStmt, AssignStmt +from hidet.ir.stmt import DeclareStmt, BufferStoreStmt, AssignStmt, ReturnStmt, WhileStmt, BreakStmt from hidet.ir.expr import Expr, Var, var, convert from hidet.ir.mapping import RepeatTaskMapping from hidet.ir.dtypes import int32 from hidet.ir.mapping import TaskMapping, repeat_map -ScopedStmt = Union[IfStmt, ForStmt, LetStmt, ForMappingStmt] +ScopedStmt = Union[IfStmt, ForStmt, LetStmt, ForMappingStmt, WhileStmt] class StmtScope: - def __init__(self, sb: 'StmtBuilder', stmts: Union[Sequence[ScopedStmt], ScopedStmt], ret=None): - if isinstance(stmts, (IfStmt, ForStmt, LetStmt, ForMappingStmt)): + def __init__(self, sb, stmts: Union[ScopedStmt, Sequence[ScopedStmt]], ret=None): + if not isinstance(stmts, Sequence): stmts = [stmts] - self.sb = sb - self.stmts = stmts + + assert all(isinstance(stmt, (IfStmt, ForStmt, LetStmt, ForMappingStmt, WhileStmt)) for stmt in stmts) + + self.sb: StmtBuilder = sb + self.stmts: List[ScopedStmt] = list(stmts) self.ret = ret def __enter__(self): @@ -39,6 +42,52 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.sb.exit_body() +class ElseIfScope: + def __init__(self, sb, stmt: IfStmt): + self.sb: StmtBuilder = sb + self.stmt: IfStmt = stmt + + def __enter__(self): + if not isinstance(self.sb.scope_stack[-1][-1], IfStmt): + raise RuntimeError('else_if() must be called after if_then() or else_if()') + + # put the current one into the scope stack + self.sb.scope_stack[-1].append(self.stmt) + self.sb.scope_stack.append([]) + + def __exit__(self, exc_type, exc_val, exc_tb): + # exit the then body of the current if statement + self.stmt.then_body = SeqStmt(self.sb.scope_stack.pop()) + self.sb.scope_stack[-1].pop() + + cur: IfStmt = self.sb.scope_stack[-1][-1] + while cur.else_body is not None: + if not isinstance(cur.else_body, IfStmt): + raise RuntimeError('else_if() must be called after if_then() or else_if()') + cur = cur.else_body + + cur.else_body = self.stmt + + +class OtherwiseScope: + def __init__(self, sb): + self.sb: StmtBuilder = sb + + def __enter__(self): + if not isinstance(self.sb.scope_stack[-1][-1], IfStmt): + raise RuntimeError('otherwise() must be called after if_then() or else_if()') + self.sb.scope_stack.append([]) + + def __exit__(self, exc_type, exc_val, exc_tb): + else_body = SeqStmt(self.sb.scope_stack.pop()) + cur: IfStmt = self.sb.scope_stack[-1][-1] + while cur.else_body is not None: + if not isinstance(cur.else_body, IfStmt): + raise RuntimeError('otherwise() must be called after if_then() or else_if()') + cur = cur.else_body + cur.else_body = else_body + + class StmtBuilder: def __init__(self): # the structure of scope_stack: @@ -65,11 +114,7 @@ def _name_index_vars(num_vars: int) -> List[str]: iter_names = [f'i{idx}' for idx in range(num_vars)] return iter_names - def let(self, v: Union[str, Var], value: Union[int, Expr]) -> StmtScope: - if isinstance(v, str): - v = var(v) - return StmtScope(self, stmts=LetStmt(v, value), ret=v) - + # singleton statements def declare(self, v: Var, init: Optional[Expr] = None, scope=None): self.append(DeclareStmt(v, init, scope=scope)) return v @@ -77,12 +122,23 @@ def declare(self, v: Var, init: Optional[Expr] = None, scope=None): def buffer_store(self, buf: Expr, indices: Sequence[Union[Expr, int]], value: Expr): self.append(BufferStoreStmt(buf, convert(indices), value)) + def assign(self, dst: Var, value: Expr): + self.append(AssignStmt(dst, value)) + + def brk(self): + self.append(BreakStmt()) + + # scope statements + def let(self, v: Union[str, Var], value: Union[int, Expr]) -> StmtScope: + if isinstance(v, str): + v = var(v) + return StmtScope(self, stmts=LetStmt(v, value), ret=v) + def lets(self, bind_vars: Sequence[Union[str, Var]], values: Sequence[Union[int, Expr]]) -> StmtScope: assert len(bind_vars) == len(values) bind_vars = [var(v) if isinstance(v, str) else v for v in bind_vars] bind_values = [convert(value) for value in values] - seq_let_stmt = LetStmt(bind_vars, bind_values, body=1) - return StmtScope(self, stmts=seq_let_stmt, ret=bind_vars) + return StmtScope(self, stmts=LetStmt(bind_vars, bind_values, body=1), ret=bind_vars) def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], attr: str = '.') -> StmtScope: if isinstance(v, str): @@ -90,15 +146,13 @@ def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], attr: str = '.' return StmtScope(self, stmts=ForStmt(v, extent, attr=ForStmtAttr.parse(attr, num_loops=1)[0]), ret=v) def if_then(self, cond: Union[bool, Expr]) -> StmtScope: - return StmtScope(self, stmts=[IfStmt(cond)], ret=None) + return StmtScope(self, stmts=IfStmt(cond), ret=None) + + def else_if(self, cond: Union[bool, Expr]) -> ElseIfScope: + return ElseIfScope(self, IfStmt(cond)) - def otherwise(self) -> StmtScope: - assert len(self.scope_stack[-1]) > 0 - if_stmt = self.scope_stack[-1].pop() - assert isinstance(if_stmt, IfStmt) - assert if_stmt.then_body is not None - assert if_stmt.else_body is None - return StmtScope(self, stmts=if_stmt, ret=None) + def otherwise(self) -> OtherwiseScope: + return OtherwiseScope(self) def for_mapping( self, @@ -109,25 +163,26 @@ def for_mapping( if worker is None: if not isinstance(mapping, RepeatTaskMapping): raise ValueError('worker must be specified for non-repeat mapping') - worker = 0 + worker = int32.zero if iter_names is None: iter_names = self._name_index_vars(len(mapping.task_shape)) iter_vars = [var(name) for name in iter_names] return StmtScope(self, stmts=ForMappingStmt(iter_vars, mapping, worker, cast(Stmt, None)), ret=iter_vars) def for_grid(self, shape: List[Union[Expr, int]]) -> StmtScope: - iter_names = self._name_index_vars(len(shape)) - iter_vars = [var(name) for name in iter_names] - mapping = repeat_map(shape) - return StmtScope(self, stmts=ForMappingStmt(iter_vars, mapping, int32(0), cast(Stmt, None)), ret=iter_vars) - - def assign(self, dst: Var, value: Expr): - self.append(AssignStmt(dst, value)) + return self.for_mapping(mapping=repeat_map(shape), iter_names=self._name_index_vars(len(shape)), worker=0) def for_range(self, extent: Union[Expr, int]): iter_var = var('i') return StmtScope(self, stmts=ForStmt(iter_var, extent), ret=iter_var) + def while_loop(self, cond: Expr): + return StmtScope(self, stmts=WhileStmt(cond, body=None), ret=None) + + def ret(self, value: Optional[Expr] = None): + self.append(ReturnStmt(value)) + + # utils def append(self, stmt: Union[Stmt, Expr, Sequence[Stmt]]): if stmt is None: return @@ -140,7 +195,7 @@ def append(self, stmt: Union[Stmt, Expr, Sequence[Stmt]]): for s in stmt: self.append(s) - def enter_body(self, stmt: Union[IfStmt, ForStmt, LetStmt]): + def enter_body(self, stmt: Union[IfStmt, ForStmt, LetStmt, WhileStmt]): self.scope_stack[-1].append(stmt) self.scope_stack.append([]) @@ -148,8 +203,8 @@ def exit_body(self): body = SeqStmt(self.scope_stack.pop()) assert len(self.scope_stack) > 0 last_stmt = self.scope_stack[-1][-1] - if isinstance(last_stmt, (ForStmt, LetStmt)): - assert last_stmt.body is None or last_stmt.body == 1 + if isinstance(last_stmt, (ForStmt, LetStmt, WhileStmt)): + assert last_stmt.body is None last_stmt.body = body elif isinstance(last_stmt, IfStmt): if last_stmt.then_body is None: diff --git a/python/hidet/transforms/lower_integer_subbyte.py b/python/hidet/transforms/lower_integer_subbyte.py index 9742b4e2f..b969f5322 100644 --- a/python/hidet/transforms/lower_integer_subbyte.py +++ b/python/hidet/transforms/lower_integer_subbyte.py @@ -12,26 +12,16 @@ # pylint: disable=unused-variable from typing import Dict, List, Union, Tuple -from hidet.ir.tools import infer_type, simplify +from hidet.ir.tools import TypeInfer, simplify from hidet.ir.type import BaseType, DataType, TensorType, TensorPointerType, PointerType -from hidet.ir.dtypes import i32 +from hidet.ir.dtypes import i32, boolean from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, bitwise_not -from hidet.ir.stmt import ( - Stmt, - DeclareStmt, - AssignStmt, - LetStmt, - EvaluateStmt, - BufferStoreStmt, - SeqStmt, - BlackBoxStmt, - WhileStmt, - DeclareScope, -) - +from hidet.ir.primitives.cuda.atomic import atomic_cas +from hidet.ir.stmt import Stmt, DeclareStmt, AssignStmt, LetStmt, EvaluateStmt, BufferStoreStmt, SeqStmt, DeclareScope from hidet.ir.func import Function from hidet.ir.module import IRModule from hidet.ir.functors import IRRewriter +from hidet.ir.builders import StmtBuilder from hidet.transforms import Pass from hidet.utils.py import is_power_of_two @@ -72,6 +62,7 @@ class LowerIntegerSubbyteRewriter(IRRewriter): # int4b c = a + b ==> not allowed def __init__(self): super().__init__() + self.type_infer = TypeInfer() self.old2new: Dict[Var, Var] = {} self.stmts: List[Stmt] = [] self.var2scope: Dict[Var, DeclareScope] = {} @@ -81,7 +72,7 @@ def auto_var(self, v: Var = None, hint: str = None, e: Expr = None): if v is not None: self.stmts.append(DeclareStmt(v)) return v - v_ty = infer_type(e) + v_ty = self.type_infer(e) v = var(hint, v_ty) self.stmts.append(DeclareStmt(v, e)) return v @@ -119,7 +110,7 @@ def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr): idx = simplify(offset // divisor) offset_ = simplify(offset % divisor) mask = storage_ty.constant(dtype.bits_mask) - return (base[idx] >> (offset_ * dtype_bits)) & mask + return storage_ty((base[idx] >> (offset_ * dtype_bits)) & mask) def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Expr): storage_ty = dtype.storage @@ -130,29 +121,54 @@ def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Ex raise TypeError(f"data type not supported yet(got:{dtype})") idx = simplify(offset // divisor) offset_ = simplify(offset % divisor) - value_ty = infer_type(value) - assert value_ty == storage_ty + value_ty = self.type_infer(value) + assert value_ty == storage_ty, f"expected {storage_ty}, but got {value_ty} (value={value})" mask = storage_ty.constant(dtype.bits_mask) item = self.auto_var(hint="item", e=value & mask) - updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(mask << (offset_ * dtype_bits))) - new_bits = self.auto_var(hint="new_bits", e=item << (offset_ * dtype_bits)) + updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(storage_ty(mask << (offset_ * dtype_bits)))) + new_bits = self.auto_var(hint="new_bits", e=storage_ty(item << (offset_ * dtype_bits))) + + from hidet.ir.dtypes import u32, u16, u8 - from hidet.ir.dtypes import u32, u16 + if base not in self.var2scope: + raise NotImplementedError('Can not determine the scope of {}'.format(base)) if self.var2scope[base].is_memory(): - if not any(storage_ty is ty for ty in [i32, u32, u16]): - raise NotImplementedError( - "writing subbyte data to memory requires the storage type must be" - " int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})" - ) - original = self.auto_var(hint="original", e=storage_ty.zero) - updated = self.auto_var(hint="updated", e=storage_ty.zero) - body = [] - body.append(AssignStmt(original, base[idx])) - body.append(AssignStmt(updated, (original & updated_mask) | new_bits)) - body.append(BlackBoxStmt("atomicCAS({}, {}, {});", ~base[idx], original, updated)) - body = SeqStmt(body) - self.stmts.append(WhileStmt(original == updated, body)) + if storage_ty.nbits == 8: + sb = StmtBuilder() + + original = sb.declare(Var("original", u16)) + result = sb.declare(Var("result", u16)) + with sb.while_loop(boolean.true): + updated_value = ( + ((((original >> (8 * (idx % 2))) & u8(0xFF)) & updated_mask) | new_bits) << (8 * (idx % 2)) + ) | (original & (u16(0xFF) << (1 - idx % 2))) + updated = sb.declare(Var("updated", storage_ty), init=updated_value) + u16_ptr = cast(cast(base, ~u16) + (idx >> 1), ~u16) + sb.assign(original, value=u16_ptr[0]) + sb.assign(result, value=atomic_cas(u16_ptr, compare=original, value=updated)) + with sb.if_then(result == original): + sb.brk() + self.stmts.append(sb.finish()) + else: + if not any(storage_ty is ty for ty in [i32, u32, u16]): + raise NotImplementedError( + "writing subbyte data to memory requires the storage type must be" + f" int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})" + ) + + sb = StmtBuilder() + + original = sb.declare(Var("original", storage_ty)) + result = sb.declare(Var("result", storage_ty)) + with sb.while_loop(boolean.true): + sb.assign(original, value=base[idx]) + updated = sb.declare(Var("updated", storage_ty), init=(original & updated_mask) | new_bits) + sb.assign(result, value=atomic_cas(~base[idx], compare=original, value=updated)) + with sb.if_then(result == original): + sb.brk() + + self.stmts.append(sb.finish()) else: assert self.var2scope[base].is_register() original = self.auto_var(hint="original", e=base[idx]) @@ -199,7 +215,7 @@ def visit_Constant(self, e: Constant): def visit_TensorElement(self, e: TensorElement): if isinstance(e.base, Var): - base_ty = infer_type(e.base) + base_ty = self.type_infer(e.base) if is_pointer_type(base_ty): dtype = get_pointer_base_type(base_ty) if is_integer_subbyte(dtype): @@ -213,7 +229,7 @@ def visit_Address(self, e: Address): if isinstance(e.expr, TensorElement): base = e.expr.base if isinstance(base, Var): - base_ty = infer_type(base) + base_ty = self.type_infer(base) if is_pointer_type(base_ty): dtype = get_pointer_base_type(base_ty) if is_integer_subbyte(dtype): @@ -237,7 +253,7 @@ def _cast_int(self, dtype: DataType, expr: Expr): return (int_data << shift) >> shift def visit_Cast(self, e: Cast): - expr_ty = infer_type(e.expr) + expr_ty = self.type_infer(e.expr) if is_integer_subbyte(expr_ty): if is_integer_subbyte(e.target_type): raise NotImplementedError(f"casting from {expr_ty} to {e.target_type} is not supported yet") @@ -278,8 +294,8 @@ def _subbyte_pointer_add(self, dtype: DataType, ptr: Union[Expr, Tuple[Expr]], o return ptr, offset def visit_Add(self, e: Add): - a_ty = infer_type(e.a) - b_ty = infer_type(e.b) + a_ty = self.type_infer(e.a) + b_ty = self.type_infer(e.b) if isinstance(a_ty, PointerType) and is_integer_subbyte(a_ty.base_type): self.recursive_depth += 1 a = self.visit(e.a) @@ -317,9 +333,12 @@ def visit_LetStmt(self, stmt: LetStmt): def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): if isinstance(stmt.buf, Var): - buf_ty = infer_type(stmt.buf) - if isinstance(buf_ty, TensorType): - dtype = buf_ty.dtype + buf_ty = self.type_infer(stmt.buf) + if isinstance(buf_ty, (TensorType, PointerType)): + if isinstance(buf_ty, TensorType): + dtype = buf_ty.dtype + else: + dtype = buf_ty.base_type if is_integer_subbyte(dtype): buf = self.visit(stmt.buf) indices = self.visit(stmt.indices) @@ -330,6 +349,13 @@ def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): self.append_stmt(super().visit_BufferStoreStmt(stmt)) return self.flatten_stmts(self.flush_stmts()) + def visit_Function(self, func: Function): + params = self.visit(func.params) + for param in params: + if is_pointer_type(param.type): + self.var2scope[param] = DeclareScope.Global + return super().visit_Function(func) + class LowerIntegerSubbytePass(Pass): def process_func(self, func: Function) -> Function: diff --git a/tests/ir/test_int_subbyte.py b/tests/ir/test_int_subbyte.py index 118aed2a4..72f77ad47 100644 --- a/tests/ir/test_int_subbyte.py +++ b/tests/ir/test_int_subbyte.py @@ -11,6 +11,7 @@ # limitations under the License. import pytest +import torch import hidet @@ -122,6 +123,29 @@ def func(out: f32[2, 2]): np.testing.assert_equal(data.cpu().numpy(), groundtruth) +def test_write_int4_to_global_memory(): + from hidet.lang import attrs + from hidet.lang.cuda import threadIdx + from hidet.ir.dtypes import int4b, uint4b + + with hidet.script_module() as script_module: + + @hidet.script + def kernel(d: ~uint4b): + attrs.func_kind = 'cuda_kernel' + attrs.cuda.grid_dim = 1 + attrs.cuda.block_dim = 32 + + i = threadIdx.x + if i < 8: + d[i] = uint4b(i % 0xF) + + func = script_module.build() + d_int32 = torch.empty([1], dtype=torch.int32, device='cuda') + func(d_int32) + assert d_int32.item() == 0x76543210 + + if __name__ == "__main__": hidet.option.cache_dir("./demo_int_subbyte") hidet.option.save_lower_ir(True)