-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Primitives] Add CUDA primitives: prmt, lop3, f16x2 sub and fma, and …
…barrier (#414) Add primitives: - `prmt` - `lop3` - `sub_f16x2`, `fma_f16x2` - `barrier` See the tests and function documentation for the usage of each primitive.
- Loading branch information
1 parent
3e323e3
commit b1681b7
Showing
10 changed files
with
519 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from hidet.ir.expr import Expr | ||
from hidet.ir.func import Function | ||
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func | ||
from hidet.utils import initialize | ||
|
||
|
||
@initialize() | ||
def register_functions(): | ||
from hidet.lang import attrs, script, asm, cast # pylint: disable=import-outside-toplevel | ||
from hidet.lang.types import uint32, void_p | ||
|
||
@script | ||
def sub_f16x2_(d: void_p, a: uint32, b: uint32): | ||
attrs.func_kind = 'cuda_internal' | ||
attrs.func_name = 'sub_f16x2' | ||
|
||
asm('sub.f16x2 %0, %1, %2;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b], is_volatile=True) | ||
|
||
@script | ||
def fma_f16x2_(d: void_p, a: uint32, b: uint32, c: uint32): | ||
attrs.func_kind = 'cuda_internal' | ||
attrs.func_name = 'fma_f16x2' | ||
|
||
asm('fma.rn.f16x2 %0, %1, %2, %3;', outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True) | ||
|
||
funcs = [sub_f16x2_, fma_f16x2_] | ||
for func in funcs: | ||
assert isinstance(func, Function) | ||
register_primitive_function(name=func.name, func_or_type=func) | ||
|
||
|
||
def sub_f16x2(d: Expr, a: Expr, b: Expr): | ||
""" | ||
Subtract two f16x2 values and store the result in `d`. | ||
Expect `d` to be an uint32 pointer while `a` an `b` are uint32 values, all of them will be interpreted as f16x2. | ||
Parameters | ||
---------- | ||
d: Expr | ||
The pointer to the f16x2 result, stored with uint32 data type. | ||
a: Expr | ||
The first f16x2 operand stored with uint32 data type. | ||
b: Expr | ||
The second f16x2 operand stored with uint32 data type. | ||
""" | ||
return call_primitive_func('sub_f16x2', args=[d, a, b]) | ||
|
||
|
||
def fma_f16x2(d: Expr, a: Expr, b: Expr, c: Expr): | ||
""" | ||
Multiply two f16x2 values and add the third f16x2 value and store the result in `d`. | ||
Expect `d` to be an uint32 pointer while `a`, `b`, and `c` are uint32 values, all of them will be interpreted as | ||
f16x2. | ||
Parameters | ||
---------- | ||
d: Expr | ||
The pointer to the f16x2 result, stored with uint32 data type. | ||
a: Expr | ||
The first f16x2 operand stored with uint32 data type. | ||
b: Expr | ||
The second f16x2 operand stored with uint32 data type. | ||
c: Expr | ||
The third f16x2 operand stored with uint32 data type. | ||
""" | ||
return call_primitive_func('fma_f16x2', args=[d, a, b, c]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from hidet.ir.expr import Expr, cast | ||
from hidet.ir.stmt import asm | ||
from hidet.ir.dtypes import uint32 | ||
|
||
|
||
def lop3(d: Expr, a: Expr, b: Expr, c: Expr, *, imm_lut: int): | ||
""" | ||
Perform a logical operation on three 32-bit values and store the result in `d`. | ||
The logical operation is determined by the immediate value `imm_lut`. | ||
See the PTX ISA documentation for the `lop3` instruction for more information: | ||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3 | ||
Parameters | ||
---------- | ||
d: Expr | ||
The pointer to the 32-bit result. | ||
a: Expr | ||
The first 32-bit operand. | ||
b: Expr | ||
The second 32-bit operand. | ||
c: Expr | ||
The third 32-bit operand. | ||
imm_lut: int | ||
The immediate value that determines the logical operation. Given logical operation `f(a, b, c)`, the | ||
immediate value `imm_lut` should be set to `f(0xF0, 0xCC, 0xAA)` to indicate the logical operation. | ||
""" | ||
assert 0 <= imm_lut <= 255 | ||
|
||
return asm( | ||
'lop3.b32 %0, %1, %2, %3, {};'.format(imm_lut), | ||
outputs=[cast(d, ~uint32)[0]], | ||
inputs=[a, b, c, imm_lut], | ||
is_volatile=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import Optional | ||
|
||
from hidet.ir.expr import Expr | ||
from hidet.ir.func import Function | ||
from hidet.ir.primitives.func import register_primitive_function, call_primitive_func | ||
from hidet.utils import initialize | ||
|
||
|
||
def resolve_func_name(mode: Optional[str] = None) -> str: | ||
if mode is None: | ||
return 'prmt_b32' | ||
else: | ||
return 'prmt_b32_{}'.format(mode) | ||
|
||
|
||
def resolve_inst_template(mode: Optional[str] = None) -> str: | ||
if mode is None: | ||
return 'prmt.b32 %0, %1, %2, %3;' | ||
else: | ||
return 'prmt.b32.{} %0, %1, %2, %3;'.format(mode) | ||
|
||
|
||
@initialize() | ||
def register_functions(): | ||
from hidet.lang import attrs, script, asm, cast # pylint: disable=import-outside-toplevel | ||
from hidet.lang.types import uint32, void_p | ||
|
||
for mode in [None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16']: | ||
template = resolve_inst_template(mode) | ||
|
||
@script | ||
def prmt_primitive(d: void_p, a: uint32, b: uint32, c: uint32): | ||
attrs.func_kind = 'cuda_internal' | ||
attrs.func_name = resolve_func_name(mode) | ||
|
||
asm(template, outputs=[cast(d, ~uint32)[0]], inputs=[a, b, c], is_volatile=True) | ||
|
||
assert isinstance(prmt_primitive, Function) | ||
register_primitive_function(name=prmt_primitive.name, func_or_type=prmt_primitive) | ||
|
||
|
||
def prmt(d: Expr, a: Expr, b: Expr, c: Expr, *, mode: Optional[str] = None): | ||
""" | ||
Perform a byte-level permutation operation on two 32-bit values and store the result in `d`. | ||
The permutation operation is determined by the permutation mode `mode`. | ||
See Also the PTX ISA documentation for the `prmt` instruction for more information: | ||
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt | ||
Parameters | ||
---------- | ||
d: Expr | ||
The pointer to the 32-bit result. | ||
a: Expr | ||
The first uint32 operand. | ||
b: Expr | ||
The second uint32 operand. | ||
c: Expr | ||
The control operand. | ||
mode: Optional[str] | ||
The permutation mode. If not provided, the default mode is used. | ||
""" | ||
assert mode in [None, 'f4e', 'b4e', 'rc8', 'ecl', 'ecr', 'rc16'] | ||
return call_primitive_func(resolve_func_name(mode), args=[d, a, b, c]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import ContextManager | ||
import contextlib | ||
import os | ||
import tempfile | ||
|
||
|
||
class CapturedStdout: | ||
def __init__(self): | ||
self.content: str = "" | ||
|
||
def __str__(self): | ||
return self.content | ||
|
||
def set_output(self, content: str): | ||
self.content = content | ||
|
||
|
||
@contextlib.contextmanager | ||
def capture_stdout() -> ContextManager[CapturedStdout]: | ||
""" | ||
capture the content that has been printed to stdout in the context | ||
We did not use `contextlib.redirect_stdout` nor similar functionality in pytest because it does not work with | ||
`printf(...)` in c/c++. | ||
usage: | ||
``` | ||
with capture_stdout() as captured: | ||
print("hello world") | ||
assert captured.content == "hello world\n" | ||
``` | ||
""" | ||
captured_stdout = CapturedStdout() | ||
|
||
with tempfile.TemporaryFile(mode='w+') as f: | ||
new_fd = f.fileno() | ||
assert new_fd != -1 | ||
|
||
original_fd = os.dup(1) | ||
assert original_fd != -1 | ||
|
||
ret = os.dup2(new_fd, 1) | ||
assert ret != -1 | ||
|
||
yield captured_stdout | ||
|
||
ret = os.dup2(original_fd, 1) | ||
assert ret != -1 | ||
|
||
os.close(original_fd) | ||
f.flush() | ||
os.fsync(new_fd) | ||
ret = f.seek(0) | ||
captured_content = f.read() | ||
captured_stdout.set_output(captured_content) |
Oops, something went wrong.