Skip to content

Commit

Permalink
[Primitives] Add CUDA primitives: prmt, lop3, f16x2 sub and fma, and …
Browse files Browse the repository at this point in the history
…barrier (#414)

Add primitives:
- `prmt`
- `lop3`
- `sub_f16x2`, `fma_f16x2`
- `barrier`

See the tests and function documentation for the usage of each
primitive.
  • Loading branch information
yaoyaoding authored Aug 15, 2024
1 parent 3e323e3 commit b1681b7
Show file tree
Hide file tree
Showing 10 changed files with 519 additions and 15 deletions.
19 changes: 7 additions & 12 deletions python/hidet/ir/primitives/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,11 @@
from .memcpy import memcpy_async
from .errchk import check_cuda_error
from .cp_async import cp_async, cp_async_commit_group, cp_async_wait_group, cp_async_wait_all
from .barrier import (
mbarrier_arrive,
mbarrier_arrive_and_expect_tx,
mbarrier_expect_transaction,
mbarrier_complete_transaction,
mbarrier_init,
mbarrier_invalidate,
mbarrier_test_wait,
mbarrier_try_wait,
mbarrier_wait,
cp_async_barrier_arrive,
)
from .barrier import mbarrier_arrive, mbarrier_arrive_and_expect_tx, mbarrier_expect_transaction
from .barrier import mbarrier_complete_transaction, mbarrier_init, mbarrier_invalidate, mbarrier_test_wait
from .barrier import mbarrier_try_wait, mbarrier_wait, cp_async_barrier_arrive
from .barrier import barrier_sync, barrier_arrive
from .tensor_map import create_tensor_map
from .half import sub_f16x2, fma_f16x2
from .lop3 import lop3
from .prmt import prmt
98 changes: 95 additions & 3 deletions python/hidet/ir/primitives/cuda/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
from typing import Optional
from typing import Optional, Union
from hidet.utils import initialize
from hidet.ir.expr import Constant, Expr
from hidet.ir.stmt import asm
from hidet.ir.func import Function
from hidet.ir.primitives.func import register_primitive_function
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
from hidet.ir.primitives.cuda.funcs import call_cuda
from hidet.lang import script, i32, u32, u64, attrs
from hidet.lang import script, i32, u32, u64, int32, attrs


@initialize()
Expand Down Expand Up @@ -233,6 +233,36 @@ def cuda_cp_async_barrier_arrive(mbar: ~u64):
register_primitive_function(name=cuda_cp_async_barrier_arrive.name, func_or_type=cuda_cp_async_barrier_arrive)


@initialize()
def register_barrier():
    """
    Register the named-barrier primitive functions.

    For every combination of ``aligned`` in (False, True) and ``mode`` in ('arrive', 'sync', 'sync_all'), a
    primitive function named ``barrier_{mode}`` (or ``barrier_{mode}_aligned``) is defined and registered:

    - ``sync_all``: takes the barrier id only and emits ``barrier.sync{.aligned} %0;``
    - ``sync``: takes the barrier id and a thread count and emits ``barrier.sync{.aligned} %0, %1;``
    - ``arrive``: takes the barrier id and a thread count and emits ``barrier.arrive{.aligned} %0, %1;``
    """
    for aligned in [False, True]:
        inst_suffix = '.aligned' if aligned else ''

        for mode in ['arrive', 'sync', 'sync_all']:
            func_name = 'barrier_{}{}'.format(mode, '_aligned' if aligned else '')

            if mode == 'sync_all':
                # no thread count operand: every thread of the CTA participates
                template = 'barrier.sync{} %0;'.format(inst_suffix)

                @script
                def barrier_func(barrier: int32):
                    attrs.func_kind = 'cuda_internal'
                    attrs.func_name = func_name

                    asm(template, inputs=[barrier], is_volatile=True)

            else:
                # The instruction must follow the mode: 'arrive' has to emit the non-blocking
                # `barrier.arrive`, not `barrier.sync`, otherwise `barrier_arrive` would block the thread.
                template = 'barrier.{}{} %0, %1;'.format(mode, inst_suffix)

                @script
                def barrier_func(barrier: int32, count: int32):
                    attrs.func_kind = 'cuda_internal'
                    attrs.func_name = func_name

                    asm(template, inputs=[barrier, count], is_volatile=True)

            assert isinstance(barrier_func, Function)
            register_primitive_function(name=barrier_func.name, func_or_type=barrier_func)


def mbarrier_init(mbar: Expr, arrive_count: Expr):
"""
Init a barrier
Expand Down Expand Up @@ -346,3 +376,65 @@ def cp_async_barrier_arrive(mbar: Expr):
cp async barrier arrive
"""
return call_cuda('cp_async_barrier_arrive', [mbar])


def _barrier(barrier: Union[int, Expr], count: Optional[Union[int, Expr]], aligned: bool, mode: str):
    """
    Build a call to one of the ``barrier_*`` primitive functions.

    Parameters
    ----------
    barrier: Union[int, Expr]
        The named barrier, passed as the first argument of the primitive.
    count: Optional[Union[int, Expr]]
        The thread count. It is passed as the second argument, or omitted entirely when it is None.
    aligned: bool
        Whether to call the ``_aligned`` variant of the primitive.
    mode: str
        The middle part of the primitive name, e.g. 'sync', 'sync_all' or 'arrive'.

    Returns
    -------
    ret:
        The result of `call_primitive_func` for the primitive named ``barrier_{mode}`` or
        ``barrier_{mode}_aligned``.
    """
    name_suffix = '_aligned' if aligned else ''
    operands = [barrier] if count is None else [barrier, count]
    return call_primitive_func('barrier_{}{}'.format(mode, name_suffix), args=operands)


def barrier_sync(barrier: Union[int, Expr], count: Optional[Union[int, Expr]] = None, aligned: bool = False):
    """
    Synchronize threads of a CTA at a named barrier.

    The calling threads wait at the named barrier. When `count` is omitted, the 'sync_all' variant of the
    primitive (without a thread count operand) is called; otherwise the 'sync' variant is called with `count`.

    See Also
    --------
    The PTX ISA documentation for the `barrier` instruction:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar-barrier

    Parameters
    ----------
    barrier: Union[int, Expr]
        The named barrier to synchronize on. This must be an integer from 0 to 15.
    count: Optional[Union[int, Expr]]
        The number of threads to synchronize. If not provided, all threads in the CTA will synchronize.
    aligned: bool
        When specified, it indicates that all threads in the CTA will execute the same barrier instruction.

    Returns
    -------
    ret:
        The call to the selected barrier primitive.
    """
    if count is None:
        return _barrier(barrier, count, aligned, mode='sync_all')
    return _barrier(barrier, count, aligned, mode='sync')


def barrier_arrive(barrier: Union[int, Expr], count: Union[int, Expr], aligned: bool = False):
    """
    Signal arrival at a named barrier within a CTA.

    The threads mark their arrival at the named barrier but are not meant to be blocked. This calls the
    'arrive' variant of the barrier primitive.

    See Also
    --------
    The PTX ISA documentation for the `barrier` instruction:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar-barrier

    Parameters
    ----------
    barrier: Union[int, Expr]
        The named barrier to arrive at. This must be an integer from 0 to 15.
    count: Union[int, Expr]
        The number of threads to synchronize.
    aligned: bool
        When specified, it indicates that all threads in the CTA will execute the same barrier instruction.

    Returns
    -------
    ret:
        The call to the selected barrier primitive.
    """
    return _barrier(barrier=barrier, count=count, aligned=aligned, mode='arrive')
68 changes: 68 additions & 0 deletions python/hidet/ir/primitives/cuda/half.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from hidet.ir.expr import Expr
from hidet.ir.func import Function
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
from hidet.utils import initialize


@initialize()
def register_functions():
    """
    Define and register the ``sub_f16x2`` and ``fma_f16x2`` primitive functions.

    Both primitives take a `void_p` destination followed by `uint32` operands. The destination is cast to a
    `uint32` pointer and its first element is used as the output operand of the inline assembly.
    """
    from hidet.lang import attrs, script, asm, cast  # pylint: disable=import-outside-toplevel
    from hidet.lang.types import uint32, void_p

    # wraps the `sub.f16x2` instruction
    @script
    def sub_f16x2_(d: void_p, a: uint32, b: uint32):
        attrs.func_kind = 'cuda_internal'
        attrs.func_name = 'sub_f16x2'

        asm('sub.f16x2 %0, %1, %2;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b], is_volatile=True)

    # wraps the `fma.rn.f16x2` instruction
    @script
    def fma_f16x2_(d: void_p, a: uint32, b: uint32, c: uint32):
        attrs.func_kind = 'cuda_internal'
        attrs.func_name = 'fma_f16x2'

        asm('fma.rn.f16x2 %0, %1, %2, %3;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True)

    # register both primitives under the names assigned through `attrs.func_name`
    for primitive in (sub_f16x2_, fma_f16x2_):
        assert isinstance(primitive, Function)
        register_primitive_function(name=primitive.name, func_or_type=primitive)


def sub_f16x2(d: Expr, a: Expr, b: Expr):
    """
    Subtract two f16x2 values and store the result in `d`.

    `d` is expected to be a uint32 pointer while `a` and `b` are uint32 values; all of them are interpreted
    as f16x2.

    Parameters
    ----------
    d: Expr
        The pointer to the f16x2 result, stored with uint32 data type.
    a: Expr
        The first f16x2 operand stored with uint32 data type.
    b: Expr
        The second f16x2 operand stored with uint32 data type.

    Returns
    -------
    ret:
        The call to the 'sub_f16x2' primitive function.
    """
    operands = [d, a, b]
    return call_primitive_func('sub_f16x2', args=operands)


def fma_f16x2(d: Expr, a: Expr, b: Expr, c: Expr):
    """
    Multiply two f16x2 values, add a third f16x2 value, and store the result in `d`.

    `d` is expected to be a uint32 pointer while `a`, `b`, and `c` are uint32 values; all of them are
    interpreted as f16x2.

    Parameters
    ----------
    d: Expr
        The pointer to the f16x2 result, stored with uint32 data type.
    a: Expr
        The first f16x2 operand stored with uint32 data type.
    b: Expr
        The second f16x2 operand stored with uint32 data type.
    c: Expr
        The third f16x2 operand stored with uint32 data type.

    Returns
    -------
    ret:
        The call to the 'fma_f16x2' primitive function.
    """
    operands = [d, a, b, c]
    return call_primitive_func('fma_f16x2', args=operands)
36 changes: 36 additions & 0 deletions python/hidet/ir/primitives/cuda/lop3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from hidet.ir.expr import Expr, cast
from hidet.ir.stmt import asm
from hidet.ir.dtypes import uint32


def lop3(d: Expr, a: Expr, b: Expr, c: Expr, *, imm_lut: int):
    """
    Perform a logical operation on three 32-bit values and store the result in `d`.

    The logical operation is determined by the immediate value `imm_lut`.

    See the PTX ISA documentation for the `lop3` instruction for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3

    Parameters
    ----------
    d: Expr
        The pointer to the 32-bit result. It is cast to a uint32 pointer and its first element is written.
    a: Expr
        The first 32-bit operand.
    b: Expr
        The second 32-bit operand.
    c: Expr
        The third 32-bit operand.
    imm_lut: int
        The immediate value that determines the logical operation. Given logical operation `f(a, b, c)`, the
        immediate value `imm_lut` should be set to `f(0xF0, 0xCC, 0xAA)` to indicate the logical operation.
        Must be in the range [0, 255].

    Returns
    -------
    ret:
        The inline assembly statement created by `asm`.
    """
    assert 0 <= imm_lut <= 255, 'imm_lut must be in [0, 255], got {}'.format(imm_lut)

    # `imm_lut` is baked into the instruction text as an immediate. It is therefore not passed as an asm
    # operand: the template only references %0..%3, so an extra input would be an unused operand.
    return asm(
        'lop3.b32 %0, %1, %2, %3, {};'.format(imm_lut),
        outputs=[cast(d, ~uint32)[0]],
        inputs=[a, b, c],
        is_volatile=True,
    )
65 changes: 65 additions & 0 deletions python/hidet/ir/primitives/cuda/prmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Optional

from hidet.ir.expr import Expr
from hidet.ir.func import Function
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func
from hidet.utils import initialize


def resolve_func_name(mode: Optional[str] = None) -> str:
    """
    Get the name of the prmt primitive function for the given permutation mode.

    Parameters
    ----------
    mode: Optional[str]
        The permutation mode, or None for the default mode.

    Returns
    -------
    ret: str
        'prmt_b32' when `mode` is None, otherwise 'prmt_b32_' followed by `mode`.
    """
    return 'prmt_b32' if mode is None else 'prmt_b32_{}'.format(mode)


def resolve_inst_template(mode: Optional[str] = None) -> str:
    """
    Get the inline assembly template of the prmt instruction for the given permutation mode.

    Parameters
    ----------
    mode: Optional[str]
        The permutation mode, or None for the default mode.

    Returns
    -------
    ret: str
        The instruction template with four operand placeholders (%0 to %3). The mode, when given, is appended
        to the instruction name as a `.{mode}` qualifier.
    """
    return 'prmt.b32 %0, %1, %2, %3;' if mode is None else 'prmt.b32.{} %0, %1, %2, %3;'.format(mode)


@initialize()
def register_functions():
    """
    Define and register one prmt primitive function per permutation mode.

    The modes are the default one (None) plus 'f4e', 'b4e', 'rc8', 'ecl', 'ecr' and 'rc16'. Each primitive is
    named by `resolve_func_name(mode)` and emits the instruction given by `resolve_inst_template(mode)`.
    """
    from hidet.lang import attrs, script, asm, cast  # pylint: disable=import-outside-toplevel
    from hidet.lang.types import uint32, void_p

    supported_modes = (None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16')

    for mode in supported_modes:
        inst_template = resolve_inst_template(mode)

        # the destination is cast to a uint32 pointer and its first element receives the result
        @script
        def prmt_primitive(d: void_p, a: uint32, b: uint32, c: uint32):
            attrs.func_kind = 'cuda_internal'
            attrs.func_name = resolve_func_name(mode)

            asm(inst_template, outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True)

        assert isinstance(prmt_primitive, Function)
        register_primitive_function(name=prmt_primitive.name, func_or_type=prmt_primitive)


def prmt(d: Expr, a: Expr, b: Expr, c: Expr, *, mode: Optional[str] = None):
    """
    Perform a byte-level permutation operation on two 32-bit values and store the result in `d`.

    The permutation operation is determined by the permutation mode `mode`.

    See also the PTX ISA documentation for the `prmt` instruction for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt

    Parameters
    ----------
    d: Expr
        The pointer to the 32-bit result.
    a: Expr
        The first uint32 operand.
    b: Expr
        The second uint32 operand.
    c: Expr
        The control operand.
    mode: Optional[str]
        The permutation mode, one of 'f4e', 'b4e', 'rc8', 'ecl', 'ecr' and 'rc16'. If not provided, the
        default mode is used.

    Returns
    -------
    ret:
        The call to the prmt primitive function of the requested mode.
    """
    assert mode in (None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16')
    primitive_name = resolve_func_name(mode)
    return call_primitive_func(primitive_name, args=[d, a, b, c])
55 changes: 55 additions & 0 deletions python/hidet/testing/capture_stdout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import ContextManager
import contextlib
import os
import tempfile


class CapturedStdout:
    """
    Holder of the text captured by `capture_stdout`.

    The `content` attribute is an empty string until `set_output` is called, which `capture_stdout` does when
    its context exits.
    """

    def __init__(self):
        self.content: str = ""

    def __str__(self):
        return self.content

    def set_output(self, content: str):
        """Replace the stored captured text with `content`."""
        self.content = content


@contextlib.contextmanager
def capture_stdout() -> ContextManager[CapturedStdout]:
    """
    Capture the content that has been written to stdout (file descriptor 1) in the context.

    We did not use `contextlib.redirect_stdout` nor similar functionality in pytest because it does not work with
    `printf(...)` in c/c++.

    File descriptor 1 is redirected to a temporary file while the context is active. Python's `sys.stdout` is
    flushed before the redirection and before the restoration, so that text buffered by Python lands on the
    correct side of the redirection. File descriptor 1 is always restored when the context exits, even if the
    body raises an exception, and the text captured so far is still stored in the yielded object.

    usage:
    ```
    with capture_stdout() as captured:
        print("hello world")
    assert captured.content == "hello world\n"
    ```

    Yields
    ------
    captured: CapturedStdout
        The object whose `content` is filled with the captured text once the context exits.
    """
    import sys  # pylint: disable=import-outside-toplevel

    def flush_python_stdout():
        # sys.stdout may be None (e.g., no console attached); there is nothing to flush in that case
        if sys.stdout is not None:
            sys.stdout.flush()

    captured_stdout = CapturedStdout()

    with tempfile.TemporaryFile(mode='w+') as f:
        new_fd = f.fileno()

        # text printed before entering the context must go to the real stdout, not to the capture file
        flush_python_stdout()

        # the os functions raise OSError on failure, thus no return-value checks are needed
        original_fd = os.dup(1)
        try:
            os.dup2(new_fd, 1)
            yield captured_stdout
        finally:
            try:
                # push the text still buffered by python into the capture file before switching back
                flush_python_stdout()
            finally:
                os.dup2(original_fd, 1)
                os.close(original_fd)

            f.flush()
            os.fsync(new_fd)
            f.seek(0)
            captured_stdout.set_output(f.read())
Loading

0 comments on commit b1681b7

Please sign in to comment.