Skip to content

Commit

Permalink
Revert "[FRONTEND] added support for tuples (#5220)"
Browse files Browse the repository at this point in the history
This reverts commit 9743ec0.
  • Loading branch information
whitneywhtsang committed Dec 10, 2024
1 parent 2c10050 commit 492ea92
Show file tree
Hide file tree
Showing 21 changed files with 288 additions and 635 deletions.
1 change: 0 additions & 1 deletion python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,6 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
Expand Down
49 changes: 24 additions & 25 deletions python/test/unit/language/test_compile_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def kernel():
a += 1 # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
assert "is not defined" in str(e.value), "error should mention the undefined variable"
Expand All @@ -32,7 +32,7 @@ def kernel():
0 + "a"

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 0"
Expand All @@ -47,7 +47,7 @@ def kernel():
tl.static_assert(isinstance(0, tl.tensor))

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
assert isinstance(e.value, CompileTimeAssertionFailure)
Expand All @@ -66,7 +66,7 @@ def kernel():
not (0, 0)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
assert e.value.__cause__ is None
Expand All @@ -83,7 +83,7 @@ def kernel():
1.0 << 1

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
Expand All @@ -107,7 +107,7 @@ def kernel():
nested_call()

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
inner = e.value.__cause__
Expand All @@ -130,7 +130,7 @@ def kernel():
tl.expand_dims(None, -1)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))

try:
inner = e.value.__cause__
Expand All @@ -157,7 +157,7 @@ def kernel():
a = two_returns()
a + tl.arange(0, 4) # only works if we took the first return

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


def test_not_const_annotate_no_err():
Expand All @@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
def kernel(N: int = 1):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))


@triton.jit
Expand All @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 4)

triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))

@triton.jit
def kernel2(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 8)

triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))


@triton.jit
Expand All @@ -211,7 +211,7 @@ def kernel(N: int):
returns_branched_on_non_constexpr(N)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the function call"
Expand All @@ -227,7 +227,7 @@ def kernel():
tl.arange(2, 7)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
assert str(e.value.__cause__) == "arange's range must be a power of 2"


Expand All @@ -238,7 +238,7 @@ def kernel():
tl.full((33, ), 0, dtype=tl.int64)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"


Expand All @@ -251,7 +251,7 @@ def kernel():
a = CAPTURED # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
assert "CAPTURED is not defined" in str(e.value)


Expand All @@ -265,7 +265,7 @@ def kernel():
a = GLOBAL # noqa

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
assert "global variable" in str(e.value)


Expand All @@ -279,7 +279,7 @@ def kernel():
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


CONSTEXPR_GLOBAL = tl.constexpr(42)
Expand All @@ -292,7 +292,7 @@ def kernel():
a = CONSTEXPR_GLOBAL # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


TYPE_ALIAS = tl.pointer_type(tl.int32)
Expand All @@ -305,7 +305,7 @@ def kernel():
a = TYPE_ALIAS # noqa

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


def test_global_access_in_fn_default_arg():
Expand All @@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
pass

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))


def test_defaults_assign_no_err():
Expand All @@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
def kernel(a=1, B: tl.constexpr = ""):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))


def test_where_warning(fresh_triton_cache):
Expand All @@ -337,7 +337,7 @@ def kernel():
tl.where(a, b, c)

with pytest.warns(UserWarning):
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
Expand Down Expand Up @@ -371,8 +371,7 @@ def dtype_kernel(dtype: tl.constexpr):
ctx = pytest.raises(CompilationError, match="")

with ctx as e:
triton.compile(
triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))

if dtype not in supported_dtypes:
try:
Expand All @@ -391,7 +390,7 @@ def dot_kernel():
tl.dot(a, b, max_num_imprecise_acc=128)

with pytest.raises(CompilationError) as e:
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
try:
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
except AssertionError as assertion_err:
Expand Down
23 changes: 4 additions & 19 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4407,17 +4407,15 @@ def kernel(x):
def test_value_specialization(value: int, value_type: str, device) -> None:

def repr(specialization):
ty = specialization.signature["value1"]
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
return f"kernel_{ty}_{cst}"
spec_type = specialization.signature["VALUE"]
return f"kernel_{spec_type}"

@triton.jit(repr=repr)
def kernel(value1, is_one, X):
def kernel(VALUE, X):
pass

x = torch.tensor([3.14159], device=device)
h = kernel[(1, )](value, 1, x)
assert "is_one" in h.name
h = kernel[(1, )](value, x)
assert value_type in h.name


Expand Down Expand Up @@ -6188,19 +6186,6 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


def test_dtype(device):

@triton.jit
def kernel(X):
dtype_x: tl.constexpr = X.dtype.element_ty
tl.static_assert(dtype_x == tl.int32)
tl.static_assert(dtype_x == tl.constexpr(tl.int32))
tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))

X = torch.zeros(1, dtype=torch.int32, device=device)
kernel[(1, )](X)


def test_side_effectful_scan(device):
if device != "cuda":
pytest.xfail()
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/language/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def kernel():
pass

try:
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
except Exception as e:
pytest.fail(f"triton compile failed with error: {e}")

Expand Down
100 changes: 0 additions & 100 deletions python/test/unit/language/test_tuple.py

This file was deleted.

15 changes: 9 additions & 6 deletions python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)},
constexprs={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
attrs=backend.get_attrs_descriptor(kernel.params, args),
signature={
kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
for i, arg in enumerate(args)
if i not in kernel.constexprs
},
constants={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
if not isinstance(arg, torch.Tensor)},
attrs=backend.get_attrs_descriptor(args, kernel.params),
)

context = triton._C.libtriton.ir.context()
Expand Down
Loading

0 comments on commit 492ea92

Please sign in to comment.