diff --git a/python/hidet/ir/primitives/cuda/__init__.py b/python/hidet/ir/primitives/cuda/__init__.py
index 3216f8157..4038ee012 100644
--- a/python/hidet/ir/primitives/cuda/__init__.py
+++ b/python/hidet/ir/primitives/cuda/__init__.py
@@ -27,16 +27,11 @@
 from .memcpy import memcpy_async
 from .errchk import check_cuda_error
 from .cp_async import cp_async, cp_async_commit_group, cp_async_wait_group, cp_async_wait_all
-from .barrier import (
-    mbarrier_arrive,
-    mbarrier_arrive_and_expect_tx,
-    mbarrier_expect_transaction,
-    mbarrier_complete_transaction,
-    mbarrier_init,
-    mbarrier_invalidate,
-    mbarrier_test_wait,
-    mbarrier_try_wait,
-    mbarrier_wait,
-    cp_async_barrier_arrive,
-)
+from .barrier import mbarrier_arrive, mbarrier_arrive_and_expect_tx, mbarrier_expect_transaction
+from .barrier import mbarrier_complete_transaction, mbarrier_init, mbarrier_invalidate, mbarrier_test_wait
+from .barrier import mbarrier_try_wait, mbarrier_wait, cp_async_barrier_arrive
+from .barrier import barrier_sync, barrier_arrive
 from .tensor_map import create_tensor_map
+from .half import sub_f16x2, fma_f16x2
+from .lop3 import lop3
+from .prmt import prmt
diff --git a/python/hidet/ir/primitives/cuda/barrier.py b/python/hidet/ir/primitives/cuda/barrier.py
index adb849b7d..5cccd0f65 100644
--- a/python/hidet/ir/primitives/cuda/barrier.py
+++ b/python/hidet/ir/primitives/cuda/barrier.py
@@ -10,14 +10,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=line-too-long
-from typing import Optional
+from typing import Optional, Union
 from hidet.utils import initialize
 from hidet.ir.expr import Constant, Expr
 from hidet.ir.stmt import asm
 from hidet.ir.func import Function
-from hidet.ir.primitives.func import register_primitive_function
+from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
 from hidet.ir.primitives.cuda.funcs import call_cuda
-from hidet.lang import script, i32, u32, u64, attrs
+from hidet.lang import script, i32, u32, u64, int32, attrs
 
 
 @initialize()
@@ -233,6 +233,36 @@ def cuda_cp_async_barrier_arrive(mbar: ~u64):
     register_primitive_function(name=cuda_cp_async_barrier_arrive.name, func_or_type=cuda_cp_async_barrier_arrive)
 
 
+@initialize()
+def register_barrier():
+    for aligned in [False, True]:
+        for mode in ['arrive', 'sync', 'sync_all']:
+            func_name = 'barrier_{}{}'.format(mode, '_aligned' if aligned else '')
+
+            if mode == 'sync_all':
+                template = 'barrier.sync{} %0;'.format('.aligned' if aligned else '')
+
+                @script
+                def barrier_func(barrier: int32):
+                    attrs.func_kind = 'cuda_internal'
+                    attrs.func_name = func_name
+
+                    asm(template, inputs=[barrier], is_volatile=True)
+
+            else:
+                template = 'barrier.sync{} %0, %1;'.format('.aligned' if aligned else '')
+
+                @script
+                def barrier_func(barrier: int32, count: int32):
+                    attrs.func_kind = 'cuda_internal'
+                    attrs.func_name = func_name
+
+                    asm(template, inputs=[barrier, count], is_volatile=True)
+
+            assert isinstance(barrier_func, Function)
+            register_primitive_function(name=barrier_func.name, func_or_type=barrier_func)
+
+
 def mbarrier_init(mbar: Expr, arrive_count: Expr):
     """
     Init a barrier
@@ -346,3 +376,65 @@ def cp_async_barrier_arrive(mbar: Expr):
     cp async barrier arrive
     """
     return call_cuda('cp_async_barrier_arrive', [mbar])
+
+
+def _barrier(barrier: Union[int, Expr], count: Optional[Union[int, Expr]], aligned: bool, mode: str):
+    # resolve function name
+    func_name = 'barrier_{}{}'.format(mode, '_aligned' if aligned else '')
+
+    # call the function
+    args = [barrier]
+    if count is not None:
+        args.append(count)
+    return call_primitive_func(func_name, args=args)
+
+
+def barrier_sync(barrier: Union[int, Expr], count: Optional[Union[int, Expr]] = None, aligned: bool = False):
+    """
+    Performs barrier synchronization and communication within a CTA.
+
+    The threads will synchronize at the named barrier.
+
+    See Also
+    --------
+    The PTX ISA documentation for the `barrier` instruction:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar-barrier
+
+    Parameters
+    ----------
+    barrier: Union[int, Expr]
+        The named barrier to synchronize on. This must be an integer from 0 to 15.
+
+    count: Optional[Union[int, Expr]]
+        The number of threads to synchronize. If not provided, all threads in the CTA will synchronize.
+
+    aligned: bool
+        When True, indicates that all threads in the CTA will execute the same barrier instruction.
+    """
+    mode = 'sync_all' if count is None else 'sync'
+    return _barrier(barrier, count, aligned, mode=mode)
+
+
+def barrier_arrive(barrier: Union[int, Expr], count: Union[int, Expr], aligned: bool = False):
+    """
+    Performs barrier synchronization and communication within a CTA.
+
+    The threads will mark their arrival at the named barrier but will not be blocked.
+
+    See Also
+    --------
+    The PTX ISA documentation for the `barrier` instruction:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar-barrier
+
+    Parameters
+    ----------
+    barrier: Union[int, Expr]
+        The named barrier to synchronize on. This must be an integer from 0 to 15.
+
+    count: Union[int, Expr]
+        The number of threads to synchronize.
+
+    aligned: bool
+        When True, indicates that all threads in the CTA will execute the same barrier instruction.
+    """
+    return _barrier(barrier, count, aligned, mode='arrive')
diff --git a/python/hidet/ir/primitives/cuda/half.py b/python/hidet/ir/primitives/cuda/half.py
new file mode 100644
index 000000000..21bab7353
--- /dev/null
+++ b/python/hidet/ir/primitives/cuda/half.py
@@ -0,0 +1,75 @@
+from hidet.ir.expr import Expr
+from hidet.ir.func import Function
+from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
+from hidet.utils import initialize
+
+
+@initialize()
+def register_functions():
+    from hidet.lang import attrs, script, asm, cast  # pylint: disable=import-outside-toplevel
+    from hidet.lang.types import uint32, void_p
+
+    @script
+    def sub_f16x2_(d: void_p, a: uint32, b: uint32):
+        attrs.func_kind = 'cuda_internal'
+        attrs.func_name = 'sub_f16x2'
+
+        asm('sub.f16x2 %0, %1, %2;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b], is_volatile=True)
+
+    @script
+    def fma_f16x2_(d: void_p, a: uint32, b: uint32, c: uint32):
+        attrs.func_kind = 'cuda_internal'
+        attrs.func_name = 'fma_f16x2'
+
+        asm('fma.rn.f16x2 %0, %1, %2, %3;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True)
+
+    funcs = [sub_f16x2_, fma_f16x2_]
+    for func in funcs:
+        assert isinstance(func, Function)
+        register_primitive_function(name=func.name, func_or_type=func)
+
+
+def sub_f16x2(d: Expr, a: Expr, b: Expr):
+    """
+    Subtract two f16x2 values and store the result in `d`.
+
+    Expects `d` to be a uint32 pointer while `a` and `b` are uint32 values; all of them are interpreted as f16x2.
+
+    Parameters
+    ----------
+    d: Expr
+        The pointer to the f16x2 result, stored with uint32 data type.
+    a: Expr
+        The first f16x2 operand stored with uint32 data type.
+    b: Expr
+        The second f16x2 operand stored with uint32 data type.
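+
+    Examples
+    --------
+    An illustrative device-side sketch: with `d` declared as a `~uint32` pointer and `a`, `b` holding packed f16
+    pairs (as in `tests/ir/primitives/cuda/test_half.py`), one call subtracts both lanes at once:
+
+        sub_f16x2(d, a, b)  # *d = a - b, computed independently on the two f16 halves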
+    """
+    return call_primitive_func('sub_f16x2', args=[d, a, b])
+
+
+def fma_f16x2(d: Expr, a: Expr, b: Expr, c: Expr):
+    """
+    Multiply two f16x2 values, add the third f16x2 value, and store the result in `d`.
+
+    Expects `d` to be a uint32 pointer while `a`, `b`, and `c` are uint32 values; all of them are interpreted as
+    f16x2.
+
+    Parameters
+    ----------
+    d: Expr
+        The pointer to the f16x2 result, stored with uint32 data type.
+    a: Expr
+        The first f16x2 operand stored with uint32 data type.
+    b: Expr
+        The second f16x2 operand stored with uint32 data type.
+    c: Expr
+        The third f16x2 operand stored with uint32 data type.
+    """
+    return call_primitive_func('fma_f16x2', args=[d, a, b, c])
diff --git a/python/hidet/ir/primitives/cuda/lop3.py b/python/hidet/ir/primitives/cuda/lop3.py
new file mode 100644
index 000000000..76859abca
--- /dev/null
+++ b/python/hidet/ir/primitives/cuda/lop3.py
@@ -0,0 +1,43 @@
+from hidet.ir.expr import Expr, cast
+from hidet.ir.stmt import asm
+from hidet.ir.dtypes import uint32
+
+
+def lop3(d: Expr, a: Expr, b: Expr, c: Expr, *, imm_lut: int):
+    """
+    Perform a logical operation on three 32-bit values and store the result in `d`.
+
+    The logical operation is determined by the immediate value `imm_lut`.
+
+    See the PTX ISA documentation for the `lop3` instruction for more information:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3
+
+    Parameters
+    ----------
+    d: Expr
+        The pointer to the 32-bit result.
+    a: Expr
+        The first 32-bit operand.
+    b: Expr
+        The second 32-bit operand.
+    c: Expr
+        The third 32-bit operand.
+    imm_lut: int
+        The immediate value that determines the logical operation: for a logical operation `f(a, b, c)`, set
+        `imm_lut` to `f(0xF0, 0xCC, 0xAA)`.
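+
+    Examples
+    --------
+    For the logical operation `f(a, b, c) = (a & b) | c`, the immediate is `(0xF0 & 0xCC) | 0xAA = 0xEA`, so the
+    call below computes `(a & b) | c` (the operation exercised in `tests/ir/primitives/cuda/test_lop3.py`):
+
+        lop3(d, a, b, c, imm_lut=0xEA)  # *d = (a & b) | c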
+ """ + assert 0 <= imm_lut <= 255 + + return asm( + 'lop3.b32 %0, %1, %2, %3, {};'.format(imm_lut), + outputs=[cast(d, ~uint32)[0]], + inputs=[a, b, c, imm_lut], + is_volatile=True, + ) diff --git a/python/hidet/ir/primitives/cuda/prmt.py b/python/hidet/ir/primitives/cuda/prmt.py new file mode 100644 index 000000000..b693ad955 --- /dev/null +++ b/python/hidet/ir/primitives/cuda/prmt.py @@ -0,0 +1,65 @@ +from typing import Optional + +from hidet.ir.expr import Expr +from hidet.ir.func import Function +from hidet.ir.primitives.func import register_primitive_function, call_primitive_func +from hidet.utils import initialize + + +def resolve_func_name(mode: Optional[str] = None) -> str: + if mode is None: + return 'prmt_b32' + else: + return 'prmt_b32_{}'.format(mode) + + +def resolve_inst_template(mode: Optional[str] = None) -> str: + if mode is None: + return 'prmt.b32 %0, %1, %2, %3;' + else: + return 'prmt.b32.{} %0, %1, %2, %3;'.format(mode) + + +@initialize() +def register_functions(): + from hidet.lang import attrs, script, asm, cast # pylint: disable=import-outside-toplevel + from hidet.lang.types import uint32, void_p + + for mode in [None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16']: + template = resolve_inst_template(mode) + + @script + def prmt_primitive(d: void_p, a: uint32, b: uint32, c: uint32): + attrs.func_kind = 'cuda_internal' + attrs.func_name = resolve_func_name(mode) + + asm(template, outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True) + + assert isinstance(prmt_primitive, Function) + register_primitive_function(name=prmt_primitive.name, func_or_type=prmt_primitive) + + +def prmt(d: Expr, a: Expr, b: Expr, c: Expr, *, mode: Optional[str] = None): + """ + Perform a byte-level permutation operation on two 32-bit values and store the result in `d`. + + The permutation operation is determined by the permutation mode `mode`. + + See Also the PTX ISA documentation for the `prmt` instruction for more information: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt + + Parameters + ---------- + d: Expr + The pointer to the 32-bit result. + a: Expr + The first uint32 operand. + b: Expr + The second uint32 operand. + c: Expr + The control operand. + mode: Optional[str] + The permutation mode. If not provided, the default mode is used. + """ + assert mode in [None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16'] + return call_primitive_func(resolve_func_name(mode), args=[d, a, b, c]) diff --git a/python/hidet/testing/capture_stdout.py b/python/hidet/testing/capture_stdout.py new file mode 100644 index 000000000..f08f84776 --- /dev/null +++ b/python/hidet/testing/capture_stdout.py @@ -0,0 +1,55 @@ +from typing import ContextManager +import contextlib +import os +import tempfile + + +class CapturedStdout: + def __init__(self): + self.content: str = "" + + def __str__(self): + return self.content + + def set_output(self, content: str): + self.content = content + + +@contextlib.contextmanager +def capture_stdout() -> ContextManager[CapturedStdout]: + """ + capture the content that has been printed to stdout in the context + + We did not use `contextlib.redirect_stdout` nor similar functionality in pytest because it does not work with + `printf(...)` in c/c++. 
+
+    usage:
+    ```
+    with capture_stdout() as captured:
+        print("hello world")
+    assert captured.content == "hello world\n"
+    ```
+    """
+    captured_stdout = CapturedStdout()
+
+    with tempfile.TemporaryFile(mode='w+') as f:
+        new_fd = f.fileno()
+        assert new_fd != -1
+
+        original_fd = os.dup(1)
+        assert original_fd != -1
+
+        ret = os.dup2(new_fd, 1)
+        assert ret != -1
+
+        yield captured_stdout
+
+        ret = os.dup2(original_fd, 1)
+        assert ret != -1
+
+        os.close(original_fd)
+        f.flush()
+        os.fsync(new_fd)
+        ret = f.seek(0)
+        captured_content = f.read()
+        captured_stdout.set_output(captured_content)
diff --git a/tests/ir/primitives/cuda/test_barrier.py b/tests/ir/primitives/cuda/test_barrier.py
index 26500f55a..3927fceaa 100644
--- a/tests/ir/primitives/cuda/test_barrier.py
+++ b/tests/ir/primitives/cuda/test_barrier.py
@@ -25,6 +25,7 @@
     mbarrier_test_wait,
     mbarrier_try_wait,
     mbarrier_wait,
+    barrier_sync,
 )
 from hidet.ir.primitives.debug import printf
 from hidet.ir.stmt import AssertStmt, AssignStmt, BlackBoxStmt, DeclareStmt, DeclareScope, SeqStmt, WhileStmt
@@ -32,6 +33,7 @@
 from hidet.ir.dtypes import i32, u32, u64
 from hidet.lang import attrs, script
 from hidet.lang.constructs.declare import shared_tensor
+from hidet.testing.capture_stdout import capture_stdout
 
 
 def test_mbarrier_basic():
@@ -233,5 +235,89 @@ def test_mbarrier_tx_count_ops():
     hidet.cuda.synchronize()
 
 
+def test_barrier():
+    from hidet.lang import attrs, printf, asm
+    from hidet.lang.cuda import threadIdx, syncthreads
+    from hidet.ir.primitives.cuda import barrier_sync
+
+    with hidet.script_module() as script_module:
+
+        num_groups = 2
+        group_size = 32
+
+        @hidet.script
+        def with_barrier():
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = num_groups * group_size
+
+            for i in range(num_groups):
+                if threadIdx.x // group_size == i:
+                    if threadIdx.x % group_size <= 1:
+                        asm('nanosleep.u32 1024;')
+                        printf('group %d, thread %d, before sync\n', i, threadIdx.x % group_size)
+                    barrier_sync(1, group_size)
+                    if group_size - 1 - threadIdx.x % group_size <= 1:
+                        printf('group %d, thread %d, after sync\n', i, threadIdx.x % group_size)
+
+                barrier_sync(0, aligned=True)
+            syncthreads()
+
+        @hidet.script
+        def without_barrier():
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = num_groups * group_size
+
+            for i in range(num_groups):
+                if threadIdx.x // group_size == i:
+                    if threadIdx.x % group_size <= 1:
+                        asm('nanosleep.u32 1024;')
+                        printf('group %d, thread %d, before sync\n', i, threadIdx.x % group_size)
+                    if group_size - 1 - threadIdx.x % group_size <= 1:
+                        printf('group %d, thread %d, after sync\n', i, threadIdx.x % group_size)
+
+                barrier_sync(0, aligned=True)
+            syncthreads()
+
+        @hidet.script
+        def launch():
+            attrs.func_kind = 'public'
+            printf('with barrier\n')
+            with_barrier()
+            BlackBoxStmt('cudaDeviceSynchronize();')
+            printf('without barrier\n')
+            without_barrier()
+            BlackBoxStmt('cudaDeviceSynchronize();')
+
+    func = script_module.build()
+    with capture_stdout() as captured:
+        func()
+
+    assert (
+        str(captured).strip()
+        == """
+with barrier
+group 0, thread 0, before sync
+group 0, thread 1, before sync
+group 0, thread 30, after sync
+group 0, thread 31, after sync
+group 1, thread 0, before sync
+group 1, thread 1, before sync
+group 1, thread 30, after sync
+group 1, thread 31, after sync
+without barrier
+group 0, thread 30, after sync
+group 0, thread 31, after sync
+group 0, thread 0, before sync
+group 0, thread 1, before sync
+group 1, thread 30, after sync
+group 1, thread 31, after sync
+group 1, thread 0, before sync
+group 1, thread 1, before sync
+""".strip()
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/ir/primitives/cuda/test_half.py b/tests/ir/primitives/cuda/test_half.py
new file mode 100644
index 000000000..c518b2f7e
--- /dev/null
+++ b/tests/ir/primitives/cuda/test_half.py
@@ -0,0 +1,58 @@
+import hidet
+import torch
+
+
+def test_sub_f16x2():
+    from hidet.lang import attrs
+    from hidet.lang.types import uint32
+    from hidet.ir.primitives.cuda import sub_f16x2
+
+    with hidet.script_module() as script_module:
+
+        @hidet.script
+        def sub_f16x2_test(c: ~uint32, a: uint32, b: uint32):
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = 32
+
+            sub_f16x2(c, a, b)
+
+    kernel = script_module.build()
+
+    c_int32 = torch.zeros([1], dtype=torch.int32, device='cuda')
+    a = torch.asarray([3.0, 4.0], dtype=torch.float16, device='cuda')
+    b = torch.asarray([1.0, 0.0], dtype=torch.float16, device='cuda')
+    a_int32 = a.view(torch.int32).item()
+    b_int32 = b.view(torch.int32).item()
+    kernel(c_int32, a_int32, b_int32)
+    c = c_int32.view(torch.float16)
+    assert torch.allclose(c, a - b)
+
+
+def test_fma_f16x2():
+    from hidet.lang import attrs
+    from hidet.lang.types import uint32
+    from hidet.ir.primitives.cuda import fma_f16x2
+
+    with hidet.script_module() as script_module:
+
+        @hidet.script
+        def fma_f16x2_test(d: ~uint32, a: uint32, b: uint32, c: uint32):
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = 32
+
+            fma_f16x2(d, a, b, c)
+
+    kernel = script_module.build()
+
+    d_int32 = torch.zeros([1], dtype=torch.int32, device='cuda')
+    a = torch.asarray([3.0, 4.0], dtype=torch.float16, device='cuda')
+    b = torch.asarray([1.0, 5.0], dtype=torch.float16, device='cuda')
+    c = torch.asarray([33.0, 44.0], dtype=torch.float16, device='cuda')
+    a_int32 = a.view(torch.int32).item()
+    b_int32 = b.view(torch.int32).item()
+    c_int32 = c.view(torch.int32).item()
+    kernel(d_int32, a_int32, b_int32, c_int32)
+    d = d_int32.view(torch.float16)
+    assert torch.allclose(d, a * b + c)
diff --git a/tests/ir/primitives/cuda/test_lop3.py b/tests/ir/primitives/cuda/test_lop3.py
new file mode 100644
index 000000000..e4c950647
--- /dev/null
+++ b/tests/ir/primitives/cuda/test_lop3.py
@@ -0,0 +1,24 @@
+import torch
+import hidet
+
+
+def test_lop3():
+    from hidet.lang import attrs, script
+    from hidet.lang.types import uint32
+    from hidet.ir.primitives.cuda import lop3
+
+    with hidet.script_module() as script_module:
+
+        @script
+        def kernel(d_ptr: ~uint32, a: uint32, b: uint32, c: uint32):
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = 32
+
+            lop3(d_ptr, a, b, c, imm_lut=(0xF0 & 0xCC) | 0xAA)
+
+    func = script_module.build()
+
+    d = torch.empty([1], dtype=torch.int32, device='cuda')
+    func(d, 0xFFFFFFFF, 0x00FF00FF, 0x0E00EE00)
+    assert d[0] == 0x0EFFEEFF
diff --git a/tests/ir/primitives/cuda/test_prmt.py b/tests/ir/primitives/cuda/test_prmt.py
new file mode 100644
index 000000000..bf5b6a540
--- /dev/null
+++ b/tests/ir/primitives/cuda/test_prmt.py
@@ -0,0 +1,31 @@
+import torch
+import hidet
+
+
+def test_prmt():
+    from hidet.lang import attrs, script
+    from hidet.lang.types import uint32
+    from hidet.ir.primitives.cuda import prmt
+
+    with hidet.script_module() as script_module:
+
+        @script
+        def kernel(d_ptr: ~uint32, a: uint32, b: uint32, c: uint32):
+            attrs.func_kind = 'cuda_kernel'
+            attrs.cuda.grid_dim = 1
+            attrs.cuda.block_dim = 32
+
+            prmt(d=d_ptr, a=a, b=b, c=c)
+
+    func = script_module.build()
+
+    d_int32 = torch.empty([1], dtype=torch.int32, device='cuda')
+    func(d_int32, 0x00000201, 0x00000064, 0x4140)
+    d_int32 = d_int32.item()
+    assert d_int32 == 0x64026401
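+
+    # Additional sanity check for the default mode: with selector 0x0000, every target byte selects source
+    # byte 0 (the low byte of a), so the result should replicate 0x01 across all four bytes.
+    d2 = torch.empty([1], dtype=torch.int32, device='cuda')
+    func(d2, 0x00000201, 0x00000064, 0x0000)
+    assert d2.item() == 0x01010101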