[INTERPRETER] Implement implicit tensor conversion for assignment operators #4214

Merged: 9 commits, Jun 28, 2024
2 changes: 0 additions & 2 deletions docs/programming-guide/chapter-3/debugging.rst
@@ -70,8 +70,6 @@ The interpreter has several known limitations:
ptr = tl.load(ptr)
x = tl.load(ptr)

- Unlike the compilation mode, a scalar in interpreter mode is treated as a simple float or integer but not as a 0-d tensor. This means it lacks tensor attributes such as :code:`x.dtype`. A workaround is to explicitly convert the scalar to a tensor using :code:`tl.to_tensor(x)`, where :code:`x` is the scalar.

----------------------------
Using Third-party Tools
----------------------------
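With this change, the limitation removed above no longer applies: a scalar assigned inside an interpreted kernel is implicitly converted to a 0-d tensor, so attributes like x.dtype work without the tl.to_tensor(x) workaround. A minimal sketch of the new behavior (assumes Triton with this patch installed and TRITON_INTERPRET=1 set before the kernel is defined; a CPU tensor is sufficient in interpreter mode):

import os
os.environ["TRITON_INTERPRET"] = "1"  # must be set before @triton.jit runs

import torch
import triton
import triton.language as tl

@triton.jit
def fill_ones(out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    value = 1                                                 # implicitly becomes a 0-d int32 tensor
    block = tl.full([BLOCK], value=value, dtype=value.dtype)  # value.dtype now works in interpreter mode
    tl.store(out_ptr + offs, block, mask=offs < n)

out = torch.zeros(128, dtype=torch.int32)
fill_ones[(1,)](out, 128, BLOCK=128)
assert torch.all(out == 1)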
33 changes: 27 additions & 6 deletions python/test/unit/language/test_core.py
@@ -1702,6 +1702,27 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
assert torch.all(output == ref)


@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_store_constant_default_dtype(num_ctas, device):
"""Tests that boolean True is stored as 1"""

@triton.jit
def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
value = 1
output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
tl.store(output_ptr + offsets, output, mask=mask)

block_size = 128
ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device)
output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device)
kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)

assert torch.all(output == ref)
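To exercise this test in interpreter mode, set TRITON_INTERPRET=1 before invoking pytest on python/test/unit/language/test_core.py and select it with -k test_store_constant_default_dtype; the exact device and fixture options depend on the local test harness.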


def test_load_store_same_ptr(device):

@triton.jit()
@@ -5334,12 +5355,12 @@ def test_tl_range(device):
torch.testing.assert_close(ref_out, c, rtol=1e-2, atol=1e-1)
else:
torch.testing.assert_close(ref_out, c, rtol=1e-3, atol=1e-3)
if device in ['cuda']:
capability = torch.cuda.get_device_capability()
if capability[0] >= 8:
ptx = pgm.asm['ptx']
# check that the loop got pipelined with the right number of stages.
assert 'cp.async.wait_group 0x6' in ptx
if device in ['cuda']:
capability = torch.cuda.get_device_capability()
if capability[0] >= 8:
ptx = pgm.asm['ptx']
# check that the loop got pipelined with the right number of stages.
assert 'cp.async.wait_group 0x6' in ptx


@triton.jit(noinline=True)
37 changes: 34 additions & 3 deletions python/test/unit/language/test_line_info.py
@@ -118,9 +118,8 @@ def check_file_lines(file_lines, file_name, lineno, should_contain=True):
should_contain: whether the file name and line number should be in the file_lines
"""
for file, line in file_lines:
if lineno == -1:
if file_name in file:
return True
if lineno == -1 and file_name in file:
return True
if file_name in file and str(lineno) in line:
return should_contain
return not should_contain
@@ -169,3 +168,35 @@ def test_line_info(func: str):
elif func == "dot_combine":
assert (check_file_lines(file_lines, "test_line_info.py", 65))
assert (check_file_lines(file_lines, "test_line_info.py", 66, should_contain=False))


def is_interpreter():
import os
return os.environ.get('TRITON_INTERPRET', '0') == '1'


@pytest.mark.interpreter
@pytest.mark.parametrize("func", func_types)
def test_line_info_interpreter(func: str):
if not is_interpreter():
pytest.skip("interpreter is not enabled")

kernel = None
expected_offset = 0
if func == "single":
kernel = kernel_single
expected_offset = 12
elif func == "call":
kernel = kernel_call
expected_offset = 25
elif func == "call_noinline":
kernel = kernel_call_noinline
expected_offset = 41
elif func == "autotune":
kernel = kernel_autotune.fn
expected_offset = 52
elif func == "dot_combine":
kernel = kernel_dot_combine
expected_offset = 62
kernel._rewrite_ast()
assert kernel.ast_transformer.offset == expected_offset
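The expected offsets above are the line numbers of each kernel's def statement in test_line_info.py, as computed by get_jit_fn_file_line; the check confirms that the interpreter's AST transformer records that offset so the rewritten code still reports the original source lines.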
24 changes: 3 additions & 21 deletions python/triton/compiler/code_generator.py
@@ -10,7 +10,7 @@
from .._C.libtriton import ir
from ..language import constexpr, tensor, str_to_ty
from ..language.core import _unwrap_if_constexpr
from ..runtime.jit import _normalize_ty
from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
# ideally we wouldn't need any runtime component
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
@@ -73,24 +73,6 @@ def _check_fn_args(node, fn, args):
)


def _get_fn_file_line(fn):
base_fn = fn
while not isinstance(base_fn, JITFunction):
base_fn = base_fn.fn
file_name = base_fn.fn.__code__.co_filename
lines, begin_line = inspect.getsourcelines(base_fn.fn)
# Match the following pattern:
# @triton.autotune(...) <- foo.__code__.co_firstlineno
# @triton.heuristics(...)
# @triton.jit
# def foo(...): <- this line is the first line
for idx, line in enumerate(lines):
if line.strip().startswith("def "):
begin_line += idx
break
return file_name, begin_line


_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels


@@ -1059,7 +1041,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
prototype = language.function_type([], arg_types)
gscope = fn.__globals__
# If the callee's debug flag is not set, use the same debug setting as the caller
file_name, begin_line = _get_fn_file_line(fn)
file_name, begin_line = get_jit_fn_file_line(fn)
debug = self.debug if fn.debug is None else fn.debug
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
@@ -1282,7 +1264,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
all_constants = constants.copy()
all_constants.update(new_constants)
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
file_name, begin_line = _get_fn_file_line(fn)
file_name, begin_line = get_jit_fn_file_line(fn)

prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
8 changes: 5 additions & 3 deletions python/triton/language/core.py
@@ -115,7 +115,7 @@ def to_tensor(x, _builder=None):
return _to_tensor(x, _builder)


def _to_tensor(x, builder):
def _to_tensor(x, builder, check_type: bool = True):
if isinstance(x, bool):
return tensor(builder.get_int1(x), int1)
# Note: compile-time const integers are represented by unsigned values
@@ -129,7 +129,7 @@ def _to_tensor(x, builder):
elif 2**63 <= x < 2**64:
return tensor(builder.get_uint64(x), uint64)
else:
raise RuntimeError(f'Nonrepresentable integer {x}.')
raise ValueError(f'Nonrepresentable integer {x}.')
elif isinstance(x, float):
min_float32 = 2**-126
max_float32 = (2 - 2**-23) * 2**127
@@ -146,7 +146,9 @@ def _to_tensor(x, builder):
return _to_tensor(x.value, builder)
elif isinstance(x, tensor):
return x
assert False, f"cannot convert {x} of type {type(x)} to tensor"
if check_type:
raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
return x


# -----------------------
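For orientation, here is a standalone sketch of the scalar-to-dtype mapping that _to_tensor performs. The integer branches and float bounds not visible in the diff are reconstructed from the surrounding context and should be treated as approximate; the real function also builds an IR constant through the builder rather than returning a dtype name.

def scalar_dtype(x):
    # Sketch only: returns the Triton dtype name _to_tensor would choose for a Python scalar.
    if isinstance(x, bool):
        return "int1"
    if isinstance(x, int):
        if -2**31 <= x < 2**31:
            return "int32"
        if -2**63 <= x < 2**63:
            return "int64"
        if 2**63 <= x < 2**64:
            return "uint64"
        raise ValueError(f"Nonrepresentable integer {x}.")
    if isinstance(x, float):
        min_float32 = 2**-126
        max_float32 = (2 - 2**-23) * 2**127
        if x == 0 or min_float32 <= abs(x) <= max_float32:
            return "fp32"
        return "fp64"
    raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")

assert scalar_dtype(1) == "int32"  # why value.dtype is int32 in the new test above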
74 changes: 70 additions & 4 deletions python/triton/runtime/interpreter.py
@@ -1,3 +1,5 @@
import ast
import textwrap
import inspect
from typing import Tuple

@@ -1094,30 +1096,94 @@ def __call__(self, *args_dev, **kwargs):
self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)


class ASTTransformer(ast.NodeTransformer):

def __init__(self) -> None:
self.offset = 0

def visit_Assign(self, node):
names = []
for target in node.targets:
names += [self.visit(target)]
if len(names) > 1:
raise ValueError("Multiple assignments are not supported")
# Rewrite the assignment x = value to
# x = triton.language.core._to_tensor(value, interpreter_builder, False)
node.value = ast.Call(
func=ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()),
attr='core', ctx=ast.Load()), attr='_to_tensor', ctx=ast.Load()),
args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()),
ast.Constant(value=False)], keywords=[])
return node

def generic_visit(self, node):
# Adjust the begin line number of the node
if hasattr(node, 'lineno') and node.lineno is not None:
node.lineno += self.offset
if hasattr(node, 'end_lineno') and node.end_lineno is not None:
node.end_lineno += self.offset
return super().generic_visit(node)


class InterpretedFunction:
rewritted_fn = {}
ast_transformer = ASTTransformer()

def __init__(self, fn) -> None:
def __init__(self, fn, **kwargs) -> None:
self.fn = fn

def run(*args, **kwargs):
grid = kwargs["grid"]
return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
fn = self._rewrite_ast()
return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs)

self.run = run
self.kwargs = kwargs
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]

def _rewrite_ast(self):
if self.fn in self.rewritted_fn:
return self.rewritted_fn[self.fn]
# If an exception is raised, the function has no source code available
# (e.g., dynamically generated functions); we cannot rewrite it, so just return the original function
try:
lines, lineno = inspect.getsourcelines(self.fn)
except Exception:
self.rewritted_fn[self.fn] = self.fn
return self.fn
from .jit import get_jit_fn_file_line, JITFunction
filename, lineno = get_jit_fn_file_line(JITFunction(self.fn))
src = ''.join(lines)
src = textwrap.dedent(src)
parsed_ast = ast.parse(src)
self.ast_transformer.offset = lineno
transformed_ast = self.ast_transformer.visit(parsed_ast)
transformed_ast = ast.fix_missing_locations(transformed_ast)
compiled_code = compile(transformed_ast, filename=filename, mode='exec')
local_namespace = {**self.kwargs}
if self.fn.__name__ in local_namespace:
raise ValueError(f"Function name {self.fn.__name__} is reserved")
exec(compiled_code, globals(), local_namespace)
fn = local_namespace[self.fn.__name__].fn
self.rewritted_fn[self.fn] = fn
return fn

@property
def __name__(self):
return self.fn.__name__

def __getitem__(self, grid):
return GridExecutor(self.fn, self.arg_names, grid)
fn = self._rewrite_ast()
return GridExecutor(fn, self.arg_names, grid)

def __call__(self, *args, **kwargs):
# This is a device function call
_patch_lang(self.fn)
fn = self._rewrite_ast()
try:
return self.fn(*args, **kwargs)
return fn(*args, **kwargs)
except Exception as e:
raise InterpreterError(repr(e)) from e
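The core of the change is the assignment rewrite in ASTTransformer.visit_Assign. A self-contained sketch of the same idea, with a hypothetical wrap() standing in for triton.language.core._to_tensor (the names here are illustrative, not Triton APIs):

import ast
import inspect
import textwrap

def wrap(value):
    # Hypothetical stand-in for _to_tensor(value, interpreter_builder, False).
    return ("wrapped", value)

class AssignWrapper(ast.NodeTransformer):
    def visit_Assign(self, node):
        # Rewrite `x = value` into `x = wrap(value)`.
        node.value = ast.Call(func=ast.Name(id="wrap", ctx=ast.Load()),
                              args=[node.value], keywords=[])
        return node

def rewrite(fn):
    src = textwrap.dedent(inspect.getsource(fn))
    tree = ast.fix_missing_locations(AssignWrapper().visit(ast.parse(src)))
    namespace = {"wrap": wrap}
    exec(compile(tree, filename="<rewritten>", mode="exec"), namespace)
    return namespace[fn.__name__]

def f():
    x = 1  # becomes x = wrap(1) after the rewrite
    return x

assert rewrite(f)() == ("wrapped", 1)

InterpretedFunction._rewrite_ast applies the same transformation to the kernel's real source, compiles it against the original file name and line offset so tracebacks stay accurate, and caches the rewritten function per kernel.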
21 changes: 20 additions & 1 deletion python/triton/runtime/jit.py
@@ -847,7 +847,8 @@ def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if os.getenv("TRITON_INTERPRET", "0") == "1":
from .interpreter import InterpretedFunction
return InterpretedFunction(fn)
return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, debug=debug,
noinline=noinline, repr=repr, launch_metadata=launch_metadata)
else:
return JITFunction(
fn,
@@ -935,3 +936,21 @@ def reinterpret(tensor, dtype):
return TensorWrapper(tensor, dtype)
else:
raise TypeError(f"Cannot reinterpret a {type(tensor)}.")


def get_jit_fn_file_line(fn):
base_fn = fn
while not isinstance(base_fn, JITFunction):
base_fn = base_fn.fn
file_name = base_fn.fn.__code__.co_filename
lines, begin_line = inspect.getsourcelines(base_fn.fn)
# Match the following pattern:
# @triton.autotune(...) <- foo.__code__.co_firstlineno
# @triton.heuristics(...)
# @triton.jit
# def foo(...): <- this line is the first line
for idx, line in enumerate(lines):
if line.strip().startswith("def "):
begin_line += idx
break
return file_name, begin_line
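A standalone illustration of the scan above, with a hypothetical my_decorator standing in for the @triton.autotune/@triton.jit stack:

import functools
import inspect

def my_decorator(fn):
    # Hypothetical stand-in for @triton.autotune / @triton.jit.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@my_decorator
def foo(x):
    return x

# getsourcelines starts at the first decorator line, so advance to the `def` line.
lines, begin_line = inspect.getsourcelines(foo)
for idx, line in enumerate(lines):
    if line.strip().startswith("def "):
        begin_line += idx
        break
print(begin_line)  # line number of `def foo(x):`, not of the decorator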