feat: PyTorch debugger #1573

Merged · 31 commits · May 8, 2023
Commits (31) — the file changes shown further below are from 1 commit.
f5d3b80
feat: setup interpreter
pommedeterresautee Apr 24, 2023
9c94995
feat: random grid iteration
pommedeterresautee Apr 25, 2023
57c1955
Merge branch 'main' into feat/interpreter
pommedeterresautee Apr 25, 2023
7f4803e
feat: add atomic ops
pommedeterresautee Apr 27, 2023
9bf38ae
merge main
pommedeterresautee May 2, 2023
42cc175
fix debugger import issue
pommedeterresautee May 2, 2023
0bf7104
fix: limit module exports
pommedeterresautee May 2, 2023
11634b9
fix: export module in setup.py
pommedeterresautee May 2, 2023
c76f42f
fix: tuple -> Tuple (old Python support)
pommedeterresautee May 2, 2023
41f9519
fix: remove torch dependency
pommedeterresautee May 2, 2023
86aced7
fix: harmonize error message
pommedeterresautee May 2, 2023
90c1331
Merge branch 'main' into feat/interpreter
pommedeterresautee May 3, 2023
41488a8
Merge branch 'main' into feat/interpreter
pommedeterresautee May 3, 2023
d59ba19
fix: simplify mechanism to make torch optional
pommedeterresautee May 3, 2023
aa1f527
Merge branch 'main' into feat/interpreter
pommedeterresautee May 4, 2023
dce5db2
fix: check torch version
pommedeterresautee May 4, 2023
86d3e6c
Merge branch 'main' into feat/interpreter
pommedeterresautee May 8, 2023
047b251
feat: setup interpreter
pommedeterresautee Apr 24, 2023
bb3e4c0
feat: random grid iteration
pommedeterresautee Apr 25, 2023
1c2f5ad
feat: add atomic ops
pommedeterresautee Apr 27, 2023
6e89e6b
fix debugger import issue
pommedeterresautee May 2, 2023
2839e34
fix: limit module exports
pommedeterresautee May 2, 2023
d6c93a8
fix: export module in setup.py
pommedeterresautee May 2, 2023
ee877ba
fix: tuple -> Tuple (old Python support)
pommedeterresautee May 2, 2023
3f51fb6
fix: remove torch dependency
pommedeterresautee May 2, 2023
e492b46
fix: harmonize error message
pommedeterresautee May 2, 2023
a501314
fix: simplify mechanism to make torch optional
pommedeterresautee May 3, 2023
3ad8c43
fix: check torch version
pommedeterresautee May 4, 2023
e33280e
fix: fix assert
pommedeterresautee May 8, 2023
2772e39
Merge remote-tracking branch 'origin/feat/interpreter' into feat/inte…
pommedeterresautee May 8, 2023
7154fcb
fix: disable test_reduce_2d
pommedeterresautee May 8, 2023
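
Before the file diffs, a quick orientation sketch (condensed from the new unit test below, not additional PR code): decorating a kernel with @triton.jit(pytorch_interpreter=True) routes tl.* calls through the PyTorch-backed debugger, so each program in the launch grid executes eagerly in plain Python and can be stepped through with ordinary debugging tools.

# Condensed from python/test/unit/debugger/test_debugger.py below; assumes a CUDA device.
import torch
import triton
import triton.language as tl

@triton.jit(pytorch_interpreter=True)  # the flag introduced by this PR
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)

a, b = torch.rand(128, device="cuda"), torch.rand(128, device="cuda")
out = torch.empty(128, device="cuda")
add_kernel[lambda meta: (triton.cdiv(128, meta["BLOCK_SIZE"]),)](a, b, out, 128, BLOCK_SIZE=32)
assert torch.allclose(out, a + b, atol=1e-2, rtol=0)
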
48 changes: 48 additions & 0 deletions python/test/unit/debugger/test_debugger.py
@@ -0,0 +1,48 @@
import torch

import triton
import triton.language as tl


def test_addition():

    @triton.jit(pytorch_interpreter=True)
    def add_kernel(
        x_ptr,  # *Pointer* to first input vector
        y_ptr,  # *Pointer* to second input vector
        output_ptr,  # *Pointer* to output vector
        n_elements,  # Size of the vector
        BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
        # NOTE: `constexpr` so it can be used as a shape value
    ):
        # There are multiple 'program's processing different data. We identify which program
        # we are here
        pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
        # This program will process inputs that are offset from the initial data.
        # For instance, if you had a vector of length 256 and block_size of 64, the programs
        # would each access the elements [0:64, 64:128, 128:192, 192:256].
        # Note that offsets is a list of pointers
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Create a mask to guard memory operations against out-of-bounds accesses
        mask = offsets < n_elements
        # Load x and y from DRAM, masking out any extra elements in case the input is not a
        # multiple of the block size
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        # Write x + y back to DRAM
        tl.store(output_ptr + offsets, output, mask=mask)

    a = torch.rand((128,), device="cuda")
    b = torch.rand((128,), device="cuda")
    expected = a + b

    output = torch.empty((128,), device="cuda")

    def grid(meta):
        return (triton.cdiv(128, meta["BLOCK_SIZE"]),)

    add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)

    assert torch.allclose(expected, output, atol=1e-2, rtol=0)
Empty file.
7 changes: 7 additions & 0 deletions python/triton/debugger/core.py
@@ -0,0 +1,7 @@
import dataclasses


@dataclasses.dataclass
class ExecutionContext:
    program_id: tuple[int]
    program_size: tuple[int]
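
A small usage sketch for this dataclass (illustration only, not part of the diff): DebuggerFunction in the next file constructs one ExecutionContext per simulated program, pairing that program's id with the full launch grid.

# Illustration only: one context for program (1, 0) of a hypothetical (2, 2) grid.
from triton.debugger.core import ExecutionContext

grid = (2, 2)
ctx = ExecutionContext(program_id=(1, 0), program_size=grid)
# The TritonLangProxy receives this context, presumably backing tl.program_id /
# tl.num_programs while that program executes (tl_lang.py is not shown in this diff).
print(ctx.program_id, ctx.program_size)  # (1, 0) (2, 2)
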
154 changes: 154 additions & 0 deletions python/triton/debugger/debugger.py
@@ -0,0 +1,154 @@
import itertools

import torch

import triton
import triton.language as tl
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
                      debugger_constexpr)

tl_method_backup = {}


def get_proxy_method(proxy, name):
    method = getattr(proxy, name)

    def fun(*args, **kwarg):
        return method(*args, **kwarg)

    return fun


def attach_triton(module, proxy):
    method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
    for name in method_list:
        if hasattr(module, name):
            attr = getattr(module, name)
            tl_method_backup[name] = attr
            if callable(attr):
                setattr(module, name, get_proxy_method(proxy, name))
            else:
                setattr(module, name, getattr(proxy, name))


def detach_triton(module):
    for name, method in tl_method_backup.items():
        setattr(module, name, method)


def program_ids_from_grid(grid):
    iterator = itertools.product(*[range(v) for v in tuple(reversed(grid))])
    return map(lambda v: tuple(reversed(v)), iterator)


class DebuggerFunction:
    def __init__(self, func, grid=(1,)):
        self.func = func
        self.grid = grid

    def _is_constexpr(self, name):
        return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr

    def _get_constexpr(self):
        result = []
        for name, annotation in self.func.__annotations__.items():
            if annotation is triton.language.core.constexpr:
                result.append(name)
        return result

    def _assert_constexpr(self, **kwargs):
        constexp = self._get_constexpr()
        missing = [i for i in constexp if i not in kwargs.keys()]
        assert len(missing) == 0, f"You must specify constexpr {missing}"

    def _get_grid(self, **kwargs):
        if callable(self.grid):
            return self.grid(kwargs)
        else:
            return self.grid

    def __call__(self, *args, **kwargs):
        self._assert_constexpr(**kwargs)

        memory = MemoryMap()

        def convert_arg(v):
            name, arg = v
            if torch.is_tensor(arg):
                ptr = memory.add_tensor(arg)
                return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
            if self._is_constexpr(name):
                return debugger_constexpr(arg)
            return WrappedTensor(_primitive_to_tensor(arg))

        new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
        new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}

        grid = self._get_grid(**kwargs)
        for program_id in program_ids_from_grid(grid):
            proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
            attach_triton(tl, proxy)
            self.func(*new_args, **new_kwargs)
        detach_triton(tl)


class GridSelector:
    def __init__(self, func):
        self.func = func

    def __getitem__(self, grid):
        return DebuggerFunction(self.func, grid)

    def __call__(self, *args, **kwargs):
        return DebuggerFunction(self.func)(*args, **kwargs)


class AutotuneGridSelector:
    def __init__(self, func, autotune_params):
        self.func = func
        self.autotune_params = autotune_params

    def __getitem__(self, grid):
        return AutotuneRunner(self.func, self.autotune_params, grid)

    def __call__(self, *args, **kwargs):
        return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)


class AutotuneRunner:
    def __init__(self, func, autotune_params, grid=None):
        self.func = func
        self.autotune_params = autotune_params
        self.grid = grid

    def __call__(self, *args, **kwargs):
        assert len(self.autotune_params["configs"]) >= 1

        for config in self.autotune_params["configs"][1:]:

            def convert_arg(v):
                if torch.is_tensor(v):
                    return torch.clone(v)
                return v

            new_args = tuple(map(convert_arg, args))
            new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
            if self.grid:
                self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
            else:
                self.func(*new_args, **new_kwargs, **config.kwargs)

        main_config = self.autotune_params["configs"][0]
        if self.grid:
            self.func[self.grid](*args, **kwargs, **main_config.kwargs)
        else:
            self.func(*args, **kwargs, **main_config.kwargs)


def triton_debug_autotune(**kwars):
    def wrapper(func):
        return AutotuneGridSelector(func, kwars)

    return wrapper
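
As a quick illustration of the iteration order produced by program_ids_from_grid above (a standalone sketch, not part of the diff): the grid is reversed before taking the Cartesian product, so the first grid axis varies fastest.

# Standalone check of program_ids_from_grid (copied from the diff above) for a (2, 3) grid.
import itertools

def program_ids_from_grid(grid):
    iterator = itertools.product(*[range(v) for v in tuple(reversed(grid))])
    return map(lambda v: tuple(reversed(v)), iterator)

print(list(program_ids_from_grid((2, 3))))
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)] -- the first axis varies fastest
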
97 changes: 97 additions & 0 deletions python/triton/debugger/memory_map.py
@@ -0,0 +1,97 @@
import dataclasses
import torch


@dataclasses.dataclass
class RegisteredStorage:
    storage: torch.Storage
    dtype: torch.dtype
    size: int
    ptr: int

    @property
    def end_ptr(self) -> int:
        return self.ptr + self.size

    @property
    def access_tensor(self) -> torch.Tensor:
        return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)

    def ensure_immutable(self):
        assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size


class MemoryMap:
    storages: [RegisteredStorage]

    def __init__(self):
        self.storages = []

    def _get_registered_storage(self, pointer: torch.Tensor):
        max_pointer = torch.max(pointer).item()
        min_pointer = torch.min(pointer).item()

        registered_storage = next(
            filter(
                lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
            ),
            None,
        )
        if registered_storage is None:
            raise Exception("Storage not found or pointers spanning multiple tensors")
        registered_storage.ensure_immutable()
        return registered_storage

    def add_tensor(self, t: torch.Tensor):
        storage = t.untyped_storage()
        self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
        return t.data_ptr()

    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
    ):
        assert pointer.is_cuda
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert mask.is_cuda
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Todo: The type is wrong here, we can't determine the correct type
            return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        index_tensor = pointer - registered_storage.ptr

        block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
        block[mask] = access_tensor[index_tensor[mask]]
        return block

    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            return

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        index_tensor = pointer - registered_storage.ptr
        access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
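
To make the pointer model concrete, a hedged usage sketch (not part of the diff, requires a CUDA device): add_tensor registers a tensor's underlying storage and returns its base address, and load resolves plain int64 address tensors against the registered address ranges.

# Hypothetical usage sketch, not part of the PR; assumes a CUDA device is available.
import torch
from triton.debugger.memory_map import MemoryMap

mm = MemoryMap()
x = torch.arange(16, dtype=torch.uint8, device="cuda")
base = mm.add_tensor(x)  # registers x's underlying storage, returns its base address

# Pointers are plain int64 addresses, which is exactly what load()/store() assert on.
ptrs = base + torch.arange(16, dtype=torch.int64, device="cuda")
mask = ptrs < base + 8                    # guard, analogous to the mask of tl.load/tl.store
vals = mm.load(ptrs, mask=mask, other=0)  # masked-out positions are filled with `other`
# For a uint8 tensor, byte addresses and element offsets coincide, so vals[:8]
# should mirror x[:8]; the masked-out tail holds 0.
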