Skip to content

Commit

Permalink
[Fix] Support writing subbyte data to global memory (#415)
Browse files Browse the repository at this point in the history
Fix #400
  • Loading branch information
yaoyaoding authored and vadiklyutiy committed Dec 19, 2024
1 parent 9ac0516 commit b1f6177
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 75 deletions.
119 changes: 87 additions & 32 deletions python/hidet/ir/builders/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
from typing import Union, Sequence, List, cast, Optional

from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt, ForMappingStmt, ForStmtAttr
from hidet.ir.stmt import DeclareStmt, BufferStoreStmt, AssignStmt
from hidet.ir.stmt import DeclareStmt, BufferStoreStmt, AssignStmt, ReturnStmt, WhileStmt, BreakStmt
from hidet.ir.expr import Expr, Var, var, convert
from hidet.ir.mapping import RepeatTaskMapping
from hidet.ir.dtypes import int32
from hidet.ir.mapping import TaskMapping, repeat_map

ScopedStmt = Union[IfStmt, ForStmt, LetStmt, ForMappingStmt]
ScopedStmt = Union[IfStmt, ForStmt, LetStmt, ForMappingStmt, WhileStmt]


class StmtScope:
def __init__(self, sb: 'StmtBuilder', stmts: Union[Sequence[ScopedStmt], ScopedStmt], ret=None):
if isinstance(stmts, (IfStmt, ForStmt, LetStmt, ForMappingStmt)):
def __init__(self, sb, stmts: Union[ScopedStmt, Sequence[ScopedStmt]], ret=None):
if not isinstance(stmts, Sequence):
stmts = [stmts]
self.sb = sb
self.stmts = stmts

assert all(isinstance(stmt, (IfStmt, ForStmt, LetStmt, ForMappingStmt, WhileStmt)) for stmt in stmts)

self.sb: StmtBuilder = sb
self.stmts: List[ScopedStmt] = list(stmts)
self.ret = ret

def __enter__(self):
Expand All @@ -39,6 +42,52 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.sb.exit_body()


class ElseIfScope:
    """Context manager that builds one ``else if`` branch of an if-chain.

    On entry the new ``IfStmt`` is pushed temporarily onto the enclosing scope and a fresh
    scope is opened to collect the statements of its then-body. On exit the collected statements
    become the then-body, the temporary entry is removed again, and the new ``IfStmt`` is attached
    as the ``else_body`` of the last ``IfStmt`` in the preceding chain.

    Raises
    ------
    RuntimeError
        If the statement preceding the ``with`` block is not an ``IfStmt`` (including the case
        where the enclosing scope has no statements yet), or if the preceding chain already ends
        with a non-``IfStmt`` else-body.
    """

    def __init__(self, sb, stmt: IfStmt):
        self.sb: StmtBuilder = sb
        self.stmt: IfStmt = stmt

    def __enter__(self):
        enclosing = self.sb.scope_stack[-1]
        # An empty enclosing scope has no preceding statement at all; report that misuse with the
        # same RuntimeError as any other instead of letting the ``[-1]`` lookup raise IndexError.
        if not enclosing or not isinstance(enclosing[-1], IfStmt):
            raise RuntimeError('else_if() must be called after if_then() or else_if()')

        # put the current one into the scope stack
        enclosing.append(self.stmt)
        self.sb.scope_stack.append([])

    def __exit__(self, exc_type, exc_val, exc_tb):
        # exit the then body of the current if statement
        self.stmt.then_body = SeqStmt(self.sb.scope_stack.pop())
        # drop the temporary entry pushed in __enter__; the statement is linked into the chain below
        self.sb.scope_stack[-1].pop()

        # walk to the tail of the existing if / else-if chain
        cur: IfStmt = self.sb.scope_stack[-1][-1]
        while cur.else_body is not None:
            if not isinstance(cur.else_body, IfStmt):
                # the chain already ends with a plain else-body, nothing can follow it
                raise RuntimeError('else_if() must be called after if_then() or else_if()')
            cur = cur.else_body

        cur.else_body = self.stmt


class OtherwiseScope:
    """Context manager that builds the final ``else`` body of an if-chain.

    On entry a fresh scope is opened to collect the statements of the else-body. On exit the
    collected statements are wrapped in a ``SeqStmt`` and attached as the ``else_body`` of the last
    ``IfStmt`` in the chain that precedes the ``with`` block.

    Raises
    ------
    RuntimeError
        If the statement preceding the ``with`` block is not an ``IfStmt`` (including the case
        where the enclosing scope has no statements yet), or if the preceding chain already ends
        with a non-``IfStmt`` else-body.
    """

    def __init__(self, sb):
        self.sb: StmtBuilder = sb

    def __enter__(self):
        enclosing = self.sb.scope_stack[-1]
        # An empty enclosing scope has no preceding statement at all; report that misuse with the
        # same RuntimeError as any other instead of letting the ``[-1]`` lookup raise IndexError.
        if not enclosing or not isinstance(enclosing[-1], IfStmt):
            raise RuntimeError('otherwise() must be called after if_then() or else_if()')
        self.sb.scope_stack.append([])

    def __exit__(self, exc_type, exc_val, exc_tb):
        else_body = SeqStmt(self.sb.scope_stack.pop())

        # walk to the tail of the existing if / else-if chain
        cur: IfStmt = self.sb.scope_stack[-1][-1]
        while cur.else_body is not None:
            if not isinstance(cur.else_body, IfStmt):
                # the chain already ends with a plain else-body, a second one cannot be attached
                raise RuntimeError('otherwise() must be called after if_then() or else_if()')
            cur = cur.else_body
        cur.else_body = else_body


class StmtBuilder:
def __init__(self):
# the structure of scope_stack:
Expand All @@ -65,40 +114,45 @@ def _name_index_vars(num_vars: int) -> List[str]:
iter_names = [f'i{idx}' for idx in range(num_vars)]
return iter_names

def let(self, v: Union[str, Var], value: Union[int, Expr]) -> StmtScope:
if isinstance(v, str):
v = var(v)
return StmtScope(self, stmts=LetStmt(v, value), ret=v)

# singleton statements
def declare(self, v: Var, init: Optional[Expr] = None, scope=None):
self.append(DeclareStmt(v, init, scope=scope))
return v

def buffer_store(self, buf: Expr, indices: Sequence[Union[Expr, int]], value: Expr):
self.append(BufferStoreStmt(buf, convert(indices), value))

def assign(self, dst: Var, value: Expr):
self.append(AssignStmt(dst, value))

def brk(self):
self.append(BreakStmt())

# scope statements
def let(self, v: Union[str, Var], value: Union[int, Expr]) -> StmtScope:
if isinstance(v, str):
v = var(v)
return StmtScope(self, stmts=LetStmt(v, value), ret=v)

def lets(self, bind_vars: Sequence[Union[str, Var]], values: Sequence[Union[int, Expr]]) -> StmtScope:
assert len(bind_vars) == len(values)
bind_vars = [var(v) if isinstance(v, str) else v for v in bind_vars]
bind_values = [convert(value) for value in values]
seq_let_stmt = LetStmt(bind_vars, bind_values, body=1)
return StmtScope(self, stmts=seq_let_stmt, ret=bind_vars)
return StmtScope(self, stmts=LetStmt(bind_vars, bind_values, body=1), ret=bind_vars)

def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], attr: str = '.') -> StmtScope:
    """Open a scope wrapping a single ``ForStmt``.

    Parameters
    ----------
    v:
        The loop variable, or a name from which a variable is created via ``var``.
    extent:
        The loop extent, forwarded to ``ForStmt`` unchanged.
    attr:
        Attribute string handed to ``ForStmtAttr.parse`` with ``num_loops=1``; the first parsed
        attribute is the one attached to the loop.

    Returns
    -------
    StmtScope
        A scope built around the ``ForStmt`` with the loop variable as its ``ret`` value.
    """
    loop_var = var(v) if isinstance(v, str) else v
    loop_attr = ForStmtAttr.parse(attr, num_loops=1)[0]
    loop_stmt = ForStmt(loop_var, extent, attr=loop_attr)
    return StmtScope(self, stmts=loop_stmt, ret=loop_var)

def if_then(self, cond: Union[bool, Expr]) -> StmtScope:
return StmtScope(self, stmts=[IfStmt(cond)], ret=None)
return StmtScope(self, stmts=IfStmt(cond), ret=None)

def else_if(self, cond: Union[bool, Expr]) -> ElseIfScope:
return ElseIfScope(self, IfStmt(cond))

def otherwise(self) -> StmtScope:
assert len(self.scope_stack[-1]) > 0
if_stmt = self.scope_stack[-1].pop()
assert isinstance(if_stmt, IfStmt)
assert if_stmt.then_body is not None
assert if_stmt.else_body is None
return StmtScope(self, stmts=if_stmt, ret=None)
def otherwise(self) -> OtherwiseScope:
return OtherwiseScope(self)

def for_mapping(
self,
Expand All @@ -109,25 +163,26 @@ def for_mapping(
if worker is None:
if not isinstance(mapping, RepeatTaskMapping):
raise ValueError('worker must be specified for non-repeat mapping')
worker = 0
worker = int32.zero
if iter_names is None:
iter_names = self._name_index_vars(len(mapping.task_shape))
iter_vars = [var(name) for name in iter_names]
return StmtScope(self, stmts=ForMappingStmt(iter_vars, mapping, worker, cast(Stmt, None)), ret=iter_vars)

def for_grid(self, shape: List[Union[Expr, int]]) -> StmtScope:
iter_names = self._name_index_vars(len(shape))
iter_vars = [var(name) for name in iter_names]
mapping = repeat_map(shape)
return StmtScope(self, stmts=ForMappingStmt(iter_vars, mapping, int32(0), cast(Stmt, None)), ret=iter_vars)

def assign(self, dst: Var, value: Expr):
self.append(AssignStmt(dst, value))
return self.for_mapping(mapping=repeat_map(shape), iter_names=self._name_index_vars(len(shape)), worker=0)

def for_range(self, extent: Union[Expr, int]):
iter_var = var('i')
return StmtScope(self, stmts=ForStmt(iter_var, extent), ret=iter_var)

def while_loop(self, cond: Expr):
return StmtScope(self, stmts=WhileStmt(cond, body=None), ret=None)

def ret(self, value: Optional[Expr] = None):
self.append(ReturnStmt(value))

# utils
def append(self, stmt: Union[Stmt, Expr, Sequence[Stmt]]):
if stmt is None:
return
Expand All @@ -140,16 +195,16 @@ def append(self, stmt: Union[Stmt, Expr, Sequence[Stmt]]):
for s in stmt:
self.append(s)

def enter_body(self, stmt: Union[IfStmt, ForStmt, LetStmt]):
def enter_body(self, stmt: Union[IfStmt, ForStmt, LetStmt, WhileStmt]):
self.scope_stack[-1].append(stmt)
self.scope_stack.append([])

def exit_body(self):
body = SeqStmt(self.scope_stack.pop())
assert len(self.scope_stack) > 0
last_stmt = self.scope_stack[-1][-1]
if isinstance(last_stmt, (ForStmt, LetStmt)):
assert last_stmt.body is None or last_stmt.body == 1
if isinstance(last_stmt, (ForStmt, LetStmt, WhileStmt)):
assert last_stmt.body is None
last_stmt.body = body
elif isinstance(last_stmt, IfStmt):
if last_stmt.then_body is None:
Expand Down
112 changes: 69 additions & 43 deletions python/hidet/transforms/lower_integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,16 @@
# pylint: disable=unused-variable
from typing import Dict, List, Union, Tuple

from hidet.ir.tools import infer_type, simplify
from hidet.ir.tools import TypeInfer, simplify
from hidet.ir.type import BaseType, DataType, TensorType, TensorPointerType, PointerType
from hidet.ir.dtypes import i32
from hidet.ir.dtypes import i32, boolean
from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, bitwise_not
from hidet.ir.stmt import (
Stmt,
DeclareStmt,
AssignStmt,
LetStmt,
EvaluateStmt,
BufferStoreStmt,
SeqStmt,
BlackBoxStmt,
WhileStmt,
DeclareScope,
)

from hidet.ir.primitives.cuda.atomic import atomic_cas
from hidet.ir.stmt import Stmt, DeclareStmt, AssignStmt, LetStmt, EvaluateStmt, BufferStoreStmt, SeqStmt, DeclareScope
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter
from hidet.ir.builders import StmtBuilder
from hidet.transforms import Pass
from hidet.utils.py import is_power_of_two

Expand Down Expand Up @@ -72,6 +62,7 @@ class LowerIntegerSubbyteRewriter(IRRewriter):
# int4b c = a + b ==> not allowed
def __init__(self):
super().__init__()
self.type_infer = TypeInfer()
self.old2new: Dict[Var, Var] = {}
self.stmts: List[Stmt] = []
self.var2scope: Dict[Var, DeclareScope] = {}
Expand All @@ -81,7 +72,7 @@ def auto_var(self, v: Var = None, hint: str = None, e: Expr = None):
if v is not None:
self.stmts.append(DeclareStmt(v))
return v
v_ty = infer_type(e)
v_ty = self.type_infer(e)
v = var(hint, v_ty)
self.stmts.append(DeclareStmt(v, e))
return v
Expand Down Expand Up @@ -119,7 +110,7 @@ def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr):
idx = simplify(offset // divisor)
offset_ = simplify(offset % divisor)
mask = storage_ty.constant(dtype.bits_mask)
return (base[idx] >> (offset_ * dtype_bits)) & mask
return storage_ty((base[idx] >> (offset_ * dtype_bits)) & mask)

def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Expr):
storage_ty = dtype.storage
Expand All @@ -130,29 +121,54 @@ def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Ex
raise TypeError(f"data type not supported yet(got:{dtype})")
idx = simplify(offset // divisor)
offset_ = simplify(offset % divisor)
value_ty = infer_type(value)
assert value_ty == storage_ty
value_ty = self.type_infer(value)
assert value_ty == storage_ty, f"expected {storage_ty}, but got {value_ty} (value={value})"
mask = storage_ty.constant(dtype.bits_mask)
item = self.auto_var(hint="item", e=value & mask)
updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(mask << (offset_ * dtype_bits)))
new_bits = self.auto_var(hint="new_bits", e=item << (offset_ * dtype_bits))
updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(storage_ty(mask << (offset_ * dtype_bits))))
new_bits = self.auto_var(hint="new_bits", e=storage_ty(item << (offset_ * dtype_bits)))

from hidet.ir.dtypes import u32, u16, u8

from hidet.ir.dtypes import u32, u16
if base not in self.var2scope:
raise NotImplementedError('Can not determine the scope of {}'.format(base))

if self.var2scope[base].is_memory():
if not any(storage_ty is ty for ty in [i32, u32, u16]):
raise NotImplementedError(
"writing subbyte data to memory requires the storage type must be"
" int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})"
)
original = self.auto_var(hint="original", e=storage_ty.zero)
updated = self.auto_var(hint="updated", e=storage_ty.zero)
body = []
body.append(AssignStmt(original, base[idx]))
body.append(AssignStmt(updated, (original & updated_mask) | new_bits))
body.append(BlackBoxStmt("atomicCAS({}, {}, {});", ~base[idx], original, updated))
body = SeqStmt(body)
self.stmts.append(WhileStmt(original == updated, body))
if storage_ty.nbits == 8:
sb = StmtBuilder()

original = sb.declare(Var("original", u16))
result = sb.declare(Var("result", u16))
with sb.while_loop(boolean.true):
updated_value = (
((((original >> (8 * (idx % 2))) & u8(0xFF)) & updated_mask) | new_bits) << (8 * (idx % 2))
) | (original & (u16(0xFF) << (1 - idx % 2)))
updated = sb.declare(Var("updated", storage_ty), init=updated_value)
u16_ptr = cast(cast(base, ~u16) + (idx >> 1), ~u16)
sb.assign(original, value=u16_ptr[0])
sb.assign(result, value=atomic_cas(u16_ptr, compare=original, value=updated))
with sb.if_then(result == original):
sb.brk()
self.stmts.append(sb.finish())
else:
if not any(storage_ty is ty for ty in [i32, u32, u16]):
raise NotImplementedError(
"writing subbyte data to memory requires the storage type must be"
f" int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})"
)

sb = StmtBuilder()

original = sb.declare(Var("original", storage_ty))
result = sb.declare(Var("result", storage_ty))
with sb.while_loop(boolean.true):
sb.assign(original, value=base[idx])
updated = sb.declare(Var("updated", storage_ty), init=(original & updated_mask) | new_bits)
sb.assign(result, value=atomic_cas(~base[idx], compare=original, value=updated))
with sb.if_then(result == original):
sb.brk()

self.stmts.append(sb.finish())
else:
assert self.var2scope[base].is_register()
original = self.auto_var(hint="original", e=base[idx])
Expand Down Expand Up @@ -199,7 +215,7 @@ def visit_Constant(self, e: Constant):

def visit_TensorElement(self, e: TensorElement):
if isinstance(e.base, Var):
base_ty = infer_type(e.base)
base_ty = self.type_infer(e.base)
if is_pointer_type(base_ty):
dtype = get_pointer_base_type(base_ty)
if is_integer_subbyte(dtype):
Expand All @@ -213,7 +229,7 @@ def visit_Address(self, e: Address):
if isinstance(e.expr, TensorElement):
base = e.expr.base
if isinstance(base, Var):
base_ty = infer_type(base)
base_ty = self.type_infer(base)
if is_pointer_type(base_ty):
dtype = get_pointer_base_type(base_ty)
if is_integer_subbyte(dtype):
Expand All @@ -237,7 +253,7 @@ def _cast_int(self, dtype: DataType, expr: Expr):
return (int_data << shift) >> shift

def visit_Cast(self, e: Cast):
expr_ty = infer_type(e.expr)
expr_ty = self.type_infer(e.expr)
if is_integer_subbyte(expr_ty):
if is_integer_subbyte(e.target_type):
raise NotImplementedError(f"casting from {expr_ty} to {e.target_type} is not supported yet")
Expand Down Expand Up @@ -278,8 +294,8 @@ def _subbyte_pointer_add(self, dtype: DataType, ptr: Union[Expr, Tuple[Expr]], o
return ptr, offset

def visit_Add(self, e: Add):
a_ty = infer_type(e.a)
b_ty = infer_type(e.b)
a_ty = self.type_infer(e.a)
b_ty = self.type_infer(e.b)
if isinstance(a_ty, PointerType) and is_integer_subbyte(a_ty.base_type):
self.recursive_depth += 1
a = self.visit(e.a)
Expand Down Expand Up @@ -317,9 +333,12 @@ def visit_LetStmt(self, stmt: LetStmt):

def visit_BufferStoreStmt(self, stmt: BufferStoreStmt):
if isinstance(stmt.buf, Var):
buf_ty = infer_type(stmt.buf)
if isinstance(buf_ty, TensorType):
dtype = buf_ty.dtype
buf_ty = self.type_infer(stmt.buf)
if isinstance(buf_ty, (TensorType, PointerType)):
if isinstance(buf_ty, TensorType):
dtype = buf_ty.dtype
else:
dtype = buf_ty.base_type
if is_integer_subbyte(dtype):
buf = self.visit(stmt.buf)
indices = self.visit(stmt.indices)
Expand All @@ -330,6 +349,13 @@ def visit_BufferStoreStmt(self, stmt: BufferStoreStmt):
self.append_stmt(super().visit_BufferStoreStmt(stmt))
return self.flatten_stmts(self.flush_stmts())

def visit_Function(self, func: Function):
    """Register pointer-typed parameters in ``var2scope`` before rewriting the function.

    Each visited parameter whose type satisfies ``is_pointer_type`` is mapped to
    ``DeclareScope.Global``; the function is then rewritten by the parent class.
    """
    for param in self.visit(func.params):
        if not is_pointer_type(param.type):
            continue
        self.var2scope[param] = DeclareScope.Global
    return super().visit_Function(func)


class LowerIntegerSubbytePass(Pass):
def process_func(self, func: Function) -> Function:
Expand Down
Loading

0 comments on commit b1f6177

Please sign in to comment.