
Commit

[FRONTEND] changed hook format and added launch metadata for external tools (#3492)
ptillet authored Mar 29, 2024
1 parent 5b9ed92 commit 88abff6
Showing 9 changed files with 178 additions and 72 deletions.
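
In outline, the new interface for external tools looks like the following. This is a minimal sketch assembled from the test_launch.py and compiler.py changes below; the kernel, the callback, and the printed fields are illustrative, not part of the diff.

import triton


def _launch_metadata(grid, kernel, args):
    # Per-kernel callback: returns extra entries to merge into the launch
    # metadata; `args` maps argument names to the values passed at launch.
    return {"grid": grid, "value": args["x"]}


@triton.jit(launch_metadata=_launch_metadata)
def kernel(x):
    pass


def enter_hook(launch_metadata):
    # Hooks now receive a single lazily evaluated object; .get() returns a
    # dict with at least "name", "function" and "stream", plus whatever the
    # kernel's launch_metadata callback contributed.
    md = launch_metadata.get()
    print(md["name"], md["grid"], md["value"])


triton.compiler.CompiledKernel.launch_enter_hook = enter_hook
kernel[(1, 3, 2)](6)
triton.compiler.CompiledKernel.launch_enter_hook = None
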
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
@@ -346,4 +346,4 @@ jobs:
- name: Run python tests on ROCM
run: |
cd python
pytest --capture=tee-sys -rfs -v -n 32 ./test/unit/language/test_core.py
pytest --capture=tee-sys -rfs -vvv -n 32 ./test/unit/language/test_core.py
34 changes: 16 additions & 18 deletions python/test/unit/language/test_core.py
@@ -14,7 +14,7 @@

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
from triton.runtime.jit import TensorWrapper, reinterpret


def is_interpreter():
@@ -2304,21 +2304,24 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
@pytest.mark.parametrize("BLOCK_N", [32, 64, 128])
@pytest.mark.parametrize("N", [512, 1024, 2048])
@pytest.mark.parametrize("num_pid_n", [2, 4])
def test_locality(op, BLOCK_N, N, num_pid_n, device):
def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device):

@triton.jit
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr):
start_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_n = tl.num_programs(1)
local = INITIALIZE_PATCH
off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n):
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N):
off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * N + off_n[None, :]
x = tl.load(Xs)
local = ACCUMULATE_PATCH
tl.store(Y + off_m * num_pid_n + pid_n, local)
tl.store(Y + off_m * NUM_PID_N + pid_n, local)
# the following segfaults AMD backend following #3492
# really unclear why; the llvm-ir and kernel arguments are
# identical !
# tl.store(Y + off_m * tl.num_programs(1) + pid_n, local)

initialize_patch = {
'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)',
@@ -2340,7 +2343,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
BLOCK_M = 32
x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device)
y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device)
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N)
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n)
if not is_interpreter():
assert h.asm['ttgir'].count(
'"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"
@@ -3764,23 +3767,18 @@ def kernel(x):
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')])
def test_value_specialization(value: int, value_type: str, device) -> None:
spec_type = None

def cache_hook(*args, **kwargs):
nonlocal spec_type
spec_type = kwargs["compile"]["signature"][0]
def repr(specialization):
spec_type = specialization.signature["VALUE"]
return f"kernel_{spec_type}"

JITFunction.cache_hook = cache_hook

@triton.jit
@triton.jit(repr=repr)
def kernel(VALUE, X):
pass

x = torch.tensor([3.14159], device=device)
pgm = kernel[(1, )](value, x)

JITFunction.cache_hook = None
assert spec_type == value_type
h = kernel[(1, )](value, x)
assert value_type in h.name


# --------------------
28 changes: 28 additions & 0 deletions python/test/unit/runtime/test_launch.py
@@ -15,6 +15,34 @@
# from typing import Tuple


def test_metadata() -> None:

used_hook = False

def _launch_metadata(grid, kernel, args):
ret = dict()
ret["grid"] = grid
ret["value"] = args["x"]
return ret

def hook(launch_metadata):
nonlocal used_hook
metadata = launch_metadata.get()
assert metadata["grid"] == (1, 3, 2)
assert metadata["value"] == 6
used_hook = True

@triton.jit(launch_metadata=_launch_metadata)
def kernel(x):
pass

# launch kernel
triton.compiler.CompiledKernel.launch_enter_hook = hook
kernel[(1, 3, 2)](6)
triton.compiler.CompiledKernel.launch_enter_hook = None
assert used_hook


def test_memory_leak() -> None:

@triton.jit
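
As a consumer-side sketch (the recorder below is hypothetical and uses only the hook surface exercised by test_metadata above), an external tool could log every launch by kernel name:

import triton

launch_log = []  # hypothetical external tool's record of kernel launches


def record_launch(launch_metadata):
    # .get() resolves any deferred launch_metadata callbacks (see LazyDict in
    # compiler.py below) and returns a plain dict.
    md = launch_metadata.get()
    launch_log.append((md["name"], {k: v for k, v in md.items() if k != "function"}))


triton.compiler.CompiledKernel.launch_enter_hook = record_launch
# ... launch @triton.jit kernels here ...
triton.compiler.CompiledKernel.launch_enter_hook = None
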
4 changes: 2 additions & 2 deletions python/triton/compiler/code_generator.py
@@ -1225,14 +1225,14 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
constants = {cst_key(key): value for key, value in specialization.constants.items()}
# visit kernel AST
gscope = fn.__globals__.copy()
function_name = '_'.join([fn.__name__, kernel_suffix(specialization.signature.values(), attrs)])
function_name = fn.repr(specialization)
tys = list(specialization.signature.values())
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}

all_constants = constants.copy()
all_constants.update(new_constants)
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in constants]
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
file_name, begin_line = _get_fn_file_line(fn)

prototype = language.function_type([], arg_types)
41 changes: 33 additions & 8 deletions python/triton/compiler/compiler.py
@@ -296,6 +296,22 @@ def make_backend(target):
return actives[0](target)


class LazyDict:

def __init__(self, data):
self.data = data
self.extras = []

def get(self) -> None:
for func, args in self.extras:
self.data = self.data | func(*args)
self.extras.clear()
return self.data

def add(self, func, args):
self.extras.append((func, args))


class CompiledKernel:

# Hooks for external tools to monitor the execution of triton kernels
@@ -306,12 +322,12 @@ class CompiledKernel:
def __init__(self, src, metadata_group, hash):
from collections import namedtuple
metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
self.metadata = json.loads(metadata_path.read_text())
KernelMetadata = namedtuple('KernelMetadata', sorted(list(self.metadata.keys())))
self.metadata = KernelMetadata(**self.metadata)
metadata = json.loads(metadata_path.read_text())
metadata['cluster_dims'] = tuple(metadata['cluster_dims'])
KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
self.metadata = KernelMetadata(**metadata)
self.src = src
self.hash = hash

self.name = self.metadata.name
# stores the text of each level of IR that was generated during compilation
asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
@@ -346,16 +362,25 @@ def __getattribute__(self, name):
self._init_handles()
return super().__getattribute__(name)

def launch_metadata(self, grid, stream, *args):
if CompiledKernel.launch_enter_hook is None:
return None
ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
return ret
args = {k: v for k, v in zip(self.src.fn.arg_names, args)}
ret.add(self.src.fn.launch_metadata, (grid, self.metadata, args))
return ret

def __getitem__(self, grid):
self._init_handles()

def runner(*args, stream=None):
if stream is None:
device = driver.active.get_current_device()
stream = driver.active.get_current_stream(device)
md = self.metadata
self.run(grid[0], grid[1], grid[2], md.num_warps, md.num_ctas, md.cluster_dims[0], md.cluster_dims[1],
md.cluster_dims[2], md.shared, stream, self.function, CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook, md, *args)
launch_metadata = self.launch_metadata(grid, stream, *args)
self.run(grid[0], grid[1], grid[2], stream, self.function, self.metadata, launch_metadata,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)

return runner
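
The LazyDict above is what launch_enter_hook now receives: the base entries are eager, but anything registered through add() only runs if a hook actually calls get(). A small usage sketch follows; the import path is assumed from this file, and the entries are illustrative.

from triton.compiler.compiler import LazyDict  # location assumed from this diff

md = LazyDict({"name": "my_kernel", "stream": 0})
md.add(lambda grid, args: {"grid": grid, "value": args["x"]}, ((1, 3, 2), {"x": 6}))
# The lambda has not run yet; it is evaluated once, on the first get().
resolved = md.get()
assert resolved == {"name": "my_kernel", "stream": 0, "grid": (1, 3, 2), "value": 6}
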
31 changes: 18 additions & 13 deletions python/triton/runtime/jit.py
@@ -423,7 +423,7 @@ def run(self, *args, grid, warmup, **kwargs):
if key not in self.cache[device]:
configs = (self._get_config(*[arg.value for arg in args]), )
constants = {
arg.param.num: arg.value
arg.param.name: arg.value
for arg in args
if arg.param.is_constexpr or arg.param.num in configs[0].equal_to_1 or arg.value is None
}
@@ -432,7 +432,7 @@ def run(self, *args, grid, warmup, **kwargs):
raise TypeError(f"Callable constexpr at index {i} is not supported")

# Build kernel signature -- doesn't include constexpr arguments.
signature = {arg.param.num: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
signature = {arg.param.name: arg.mangled_type() for arg in args if not arg.param.is_constexpr}

if self._call_hook(key, signature, device, constants, options, configs):
return None
@@ -447,7 +447,7 @@ def run(self, *args, grid, warmup, **kwargs):
kernel = self.cache[device][key]

# Verify key signature from the cache
signature = {arg.param.num: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
signature = {arg.param.name: arg.mangled_type() for arg in args if not arg.param.is_constexpr}
if kernel.src.signature != signature:
raise RuntimeError(
f"Signature mismatch for cached kernel {self.fn.__name__}:\n"\
@@ -457,16 +457,13 @@ def run(self, *args, grid, warmup, **kwargs):

if not warmup:
args = [arg.value for arg in args if not arg.param.is_constexpr]
metadata = kernel.metadata

kernel.run(grid_0, grid_1, grid_2, metadata.num_warps,
metadata.num_ctas, # number of warps/ctas per instance
metadata.cluster_dims[0], metadata.cluster_dims[1], metadata.cluster_dims[2], # cluster
metadata.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
CompiledKernel.launch_exit_hook, metadata, *args)
launch_metadata = kernel.launch_metadata(grid, stream, *args)
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.metadata, launch_metadata,
CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)
return kernel

def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None,
launch_metadata=None):
do_not_specialize = do_not_specialize if do_not_specialize else []

self.fn = fn
@@ -475,6 +472,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin
self.signature = inspect.signature(fn)
self.do_not_specialize = do_not_specialize
self.starting_line_number = inspect.getsourcelines(fn)[1]
self.repr = lambda _: fn.__name__ if repr is None else repr(_)
self.launch_metadata = launch_metadata

self.params = []
for i, param in enumerate(self.signature.parameters.values()):
@@ -529,10 +528,10 @@ def preload(self, specialization_data):
raise RuntimeError(
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
constants = {
int(key): tl.dtype(value) if tl.dtype.is_dtype(value) else value
key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
for key, value in deserialized_obj['constants'].items()
}
signature = {int(key): value for key, value in deserialized_obj['signature'].items()}
signature = {key: value for key, value in deserialized_obj['signature'].items()}
src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs']))
options = {
key: tuple(value) if isinstance(value, list) else value
@@ -581,6 +580,8 @@ def jit(fn: T) -> JITFunction[T]:
def jit(
*,
version=None,
repr: Optional[Callable] = None,
launch_metadata: Optional[Callable] = None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
@@ -592,6 +593,8 @@ def jit(
fn: Optional[T] = None,
*,
version=None,
repr: Optional[Callable] = None,
launch_metadata: Optional[Callable] = None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
@@ -626,6 +629,8 @@ def decorator(fn: T) -> JITFunction[T]:
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
repr=repr,
launch_metadata=launch_metadata,
)

if fn is not None:
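
Putting the two new jit() keyword arguments together: a minimal sketch, with the kernel body and the naming scheme chosen here purely for illustration.

import triton
import triton.language as tl


def _repr(specialization):
    # Becomes the function name in the generated IR (see ast_to_ttir above);
    # specialization.signature maps argument names to their mangled types.
    return f"add_kernel_{specialization.signature['n_elements']}"


def _launch_metadata(grid, kernel, args):
    return {"grid": grid, "n_elements": args["n_elements"]}


@triton.jit(repr=_repr, launch_metadata=_launch_metadata)
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
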
3 changes: 1 addition & 2 deletions python/triton/tools/compile.py
@@ -120,11 +120,10 @@ def constexpr(s):
# dump C stub code
suffix = kernel_suffix(signature.values(), attrs)
func_name = '_'.join([out_name, sig_hash, suffix])
triton_kernel_name = '_'.join([args.kernel_name, suffix])
hex_ = str(binascii.hexlify(ccinfo.asm["cubin"]))[2:-1]
params = {
"kernel_name": func_name,
"triton_kernel_name": triton_kernel_name,
"triton_kernel_name": args.kernel_name,
"bin_size": len(hex_),
"bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
"signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),