Fails in the middle of training with KeyError #1512

Open
mitchellnw opened this issue Apr 11, 2023 · 7 comments


@mitchellnw

mitchellnw commented Apr 11, 2023

I'm new to Triton and really enjoy using the library - thanks a lot.

I've replaced AdamW with a Triton version of AdamW, and training works great for the first ~5000 iterations. However, after that the code randomly fails with:

Traceback (most recent call last):
  File "<string>", line 21, in _opt_kernel2
KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-14de7de5c4da5794c8ca14e7e41a122d-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, torch.float32, dtype('float64'), 'fp32', 'fp32', dtype('float64'), 'fp32', 'i1', 'i32'), (1024,), (True, True, True, True, (False,), (False,), (False,), (False,), (False,), (False, True), (False, True)))

which appears similar to #1509.

The kernel is:

import triton
import triton.language as tl

@triton.jit
def _opt_kernel2(
    p_ptr,
    grad_ptr,
    exp_avg_ptr,
    exp_avg2_ptr,
    lr,
    wd,
    beta1,
    eta,
    eps,
    update_clip,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis = 0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    mask = offsets < n_elements

    # offsetted pointers

    offset_p_ptr = p_ptr + offsets
    offset_grad_ptr = grad_ptr + offsets
    offset_exp_avg_ptr = exp_avg_ptr + offsets
    offset_exp_avg2_ptr = exp_avg2_ptr + offsets

    # load, early exit if nan
    grad = tl.load(offset_grad_ptr, mask = mask)
    #if tl.max(tl.libdevice.isnan(grad), 0) == 0:
        
    p = tl.load(offset_p_ptr, mask = mask)
    exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)
    exp_avg2 = tl.load(offset_exp_avg2_ptr, mask = mask)

    # stepweight decay

    p = p * (1 - lr * wd)

    # update exp_avgs
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
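    # `beta2` is not a kernel argument; as written it has to resolve to a global in the enclosing module
    exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * grad * grad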
    p = p - eta * (exp_avg / (tl.sqrt(exp_avg2) + eps))

    # store new params and momentum running average coefficient

    tl.store(offset_p_ptr, p, mask = mask)
    tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
    tl.store(offset_exp_avg2_ptr, exp_avg2, mask = mask)
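To check the kernel's math on the host, here is a minimal NumPy sketch that emulates `_opt_kernel2` one "program" at a time: each `pid` covers `BLOCK_SIZE` offsets and masks off lanes past `n_elements`, mirroring `tl.arange` and the `mask` argument to `tl.load`/`tl.store`. The function name `opt_kernel2_reference` is hypothetical, `beta2` is passed explicitly here (an assumption, since the kernel reads it from elsewhere), and `update_clip` is omitted because the kernel body never uses it.

```python
import numpy as np

def opt_kernel2_reference(p, grad, exp_avg, exp_avg2, lr, wd, beta1, beta2,
                          eta, eps, block_size=1024):
    """Host-side emulation of _opt_kernel2 (hypothetical helper).

    The arrays are updated in place, like the kernel's tl.store calls.
    """
    n_elements = p.size
    num_programs = (n_elements + block_size - 1) // block_size  # ceil division
    for pid in range(num_programs):
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n_elements
        idx = offsets[mask]  # masked-off lanes are neither loaded nor stored
        g = grad[idx]
        q = p[idx] * (1 - lr * wd)  # stepweight decay
        m = beta1 * exp_avg[idx] + (1 - beta1) * g
        v = beta2 * exp_avg2[idx] + (1 - beta2) * g * g
        q = q - eta * (m / (np.sqrt(v) + eps))
        p[idx], exp_avg[idx], exp_avg2[idx] = q, m, v
```

Comparing this against the Triton kernel on random inputs, with `torch.allclose` or `np.allclose`, separates numerical bugs from launch or typing problems.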

If anyone else has a working triton AdamW optimizer that would also be helpful :). Thanks!

@Jokeren
Contributor

Jokeren commented Apr 11, 2023

hmm, we don't have tl.libdevice anymore. Can you try triton/main?

@mitchellnw
Author

mitchellnw commented Apr 11, 2023

Yep, I was using the pip install but will re-run with installing from source. Thanks - will let you know if that resolves!

@mitchellnw
Author

Hm, same error:

Traceback (most recent call last):
  File "<string>", line 22, in _opt_kernel2
KeyError: ('2-.-1-.-0-83ca8b715a9dc5f32dc1110973485f64-9a197408afc30884cecf5bedb3c7049b-8a97a7050a3f4cab2ea2ba8f09449b6d-bd5c6baf0724e403eab8dc030c72f2a7-cacecb5a01b695fe1eb376e18972d557-06b47813aaed5d9f42e68c0b9c8e48c0-8995dea56505a768227bc4e164ffe442-1b2b0313e0260ff80c807e598ca0d1ff-868af27d0622dee0bd2d3276f76b9332-02c1ea983827fd34c6c487bd5600ca16-293a6a48c2635776c9dee8b54603306f-9d79534bfe19ad2486925db9d67e58d2', (torch.float32, torch.float32, torch.float32, torch.float32, dtype('float64'), 'fp32', 'fp32', dtype('float64'), 'fp32', 'i1', 'i32'), (1024,), (True, True, True, True, (False,), (False,), (False,), (False,), (False,), (False, True), (False, True)), 4, 2, False)

Apologies for bothering and will also try to debug myself -- the issue is that it is somewhat hard to debug as it happens roughly 5000 iterations into training.

Posting more from the stack trace for more info:

Traceback (most recent call last):
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 974, in ast_to_ttir
    generator.visit(fn.parse())
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 874, in visit
    return super().visit(node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 184, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 874, in visit
    return super().visit(node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 253, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 148, in visit_compound_statement
    self.last_ret_type = self.visit(stm
    return super().visit(node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 351, in visit_BinOp
    rhs = self.visit(node.right)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 874, in visit
    return super().visit(node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 351, in visit_BinOp
    rhs = self.visit(node.right)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 874, in visit
    return super().visit(node)
  File "~/git/miniconda3/envs/py2/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 355, in visit_BinOp
    return self._apply_binary_method(method_name, lhs, rhs)
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 343, in _apply_binary_method
    return getattr(lhs, method_name)(rhs, _builder=self.builder)
  File "~/git/git/triton/python/triton/language/core.py", line 29, in wrapper
    return fn(*args, **kwargs)
  File "~/git/git/triton/python/triton/language/core.py", line 521, in __mul__
    return semantic.mul(self, other, _builder)
  File "~/git/git/triton/python/triton/language/semantic.py", line 167, in mul
    input, other = binary_op_type_checking_impl(input, other, builder)
  File "~/git/git/triton/python/triton/language/semantic.py", line 115, in binary_op_type_checking_impl
    check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
  File "~/git/git/triton/python/triton/language/semantic.py", line 95, in check_ptr_type_impl
    raise IncompatibleTypeErrorImpl(type_a, type_b)
triton.language.semantic.IncompatibleTypeErrorImpl: invalid operands of type pointer<fp64> and triton.language.fp32

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "~/git/miniconda3/envs/py2/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "~/git/miniconda3/envs/py2/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "~/training/main.py", line 868, in <module>
    main(sys.argv[1:])
  File "~/training/main.py", line 765, in main
    train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer,
  File "~/training/train.py", line 214, in train_one_epoch
    optimizer.step()
  File "~/git/miniconda3/envs/py2/lib/python3.8/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "~/git/miniconda3/envs/py2/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "~/training/optimizers/adamw.py", line 260, in step
    update_fn(p, grad, exp_avg, exp_avg_sq, lr, wd, beta1hat, beta2hat, self.eps, True)
  File "~/training/optimizers/adamw.py", line 159, in update_fn
    _opt_kernel2[grid](
  File "~/git/git/triton/python/triton/runtime/autotuner.py", line 110, in run
    return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 42, in _opt_kernel2
  File "~/git/git/triton/python/triton/compiler/compiler.py", line 455, in compile
    next_module = compile_kernel(module)
  File "~/git/git/triton/python/triton/compiler/compiler.py", line 375, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch))
  File "~/git/git/triton/python/triton/compiler/code_generator.py", line 983, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 39:22:
    # load, early exit if nan
    grad = tl.load(offset_grad_ptr, mask = mask)
    #if tl.max(tl.libdevice.isnan(grad), 0) == 0:

    p = tl.load(offset_p_ptr, mask = mask)
    exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)
    exp_avg2 = tl.load(offset_exp_avg2_ptr, mask = mask)

    # stepweight decay

    p = p * (1 - lr * wd)
                      ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<fp64> and triton.language.fp32')
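The signature tuple in the KeyError shows `dtype('float64')` in the `lr` and `eta` positions, while `wd`, `beta1` and `eps` are plain `'fp32'`. One reading consistent with that and with the `pointer<fp64>` message: those two arguments are objects that carry a `.dtype` attribute, such as NumPy `float64` scalars (for example, from a learning-rate schedule computed with `np.cos`). A launcher that keys anything with `.dtype` as a tensor would then compile the kernel with a pointer argument, and a schedule that only starts returning NumPy scalars after warmup would explain a "random" failure after thousands of steps. The `key_of` below is a hypothetical, simplified stand-in for that specialization logic, not Triton's code. Coercing with `float(...)` before launching is the suggested workaround under this assumption.

```python
import numpy as np

def key_of(arg):
    """Hypothetical, simplified specialization key (not Triton's actual code).

    Anything exposing a ``dtype`` attribute is keyed like a tensor, i.e. it
    would be compiled as a pointer argument; plain Python scalars get scalar
    types. Note that np.float64 subclasses float, so the order matters.
    """
    if hasattr(arg, "dtype"):
        return arg.dtype
    if isinstance(arg, bool):  # bool before int: bool subclasses int
        return "i1"
    if isinstance(arg, int):
        return "i32"
    if isinstance(arg, float):
        return "fp32"
    raise TypeError(f"unsupported argument type {type(arg)}")

def as_kernel_scalar(x):
    """Turn NumPy (or other 0-d) scalars into a plain Python float."""
    return float(x)

lr_warmup = 1e-3 * 10 / 100                            # plain Python float
lr_cosine = 0.5 * 1e-3 * (1 + np.cos(np.pi * 0.25))    # NumPy float64
print(key_of(lr_warmup))                    # fp32
print(key_of(lr_cosine))                    # float64
print(key_of(as_kernel_scalar(lr_cosine)))  # fp32
```

In `update_fn`, this would mean passing `float(lr)`, `float(beta2hat)` and so on to `_opt_kernel2[grid](...)`.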

@mitchellnw
Author

Still unable to resolve, if anyone has any advice would be much appreciated!

ptillet pushed a commit that referenced this issue Apr 29, 2023
…oad (#1593)

Closes #1556
#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>
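The idea behind this fix can be sketched in a few lines: if the target architecture is part of the hashed payload, a binary compiled for one GPU can never be looked up for another. The `cache_key` helper and its argument format are hypothetical; the actual change lives in Triton's compiler.

```python
import hashlib

def cache_key(src: str, signature: str, arch: int) -> str:
    """Hypothetical cache key: hash the kernel source, its signature, and the
    target GPU architecture together, so a cached binary built for one
    architecture is never reused on another."""
    payload = f"{src}-{signature}-{arch}".encode()
    return hashlib.sha256(payload).hexdigest()
```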
@mitchellnw
Author

I've installed from source and run the same code. It no longer shows the above error but still fails randomly in the middle of training. The error is now:

Traceback (most recent call last):
  File "~/miniconda3/envs/py9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    ast.NodeVisitor.generic_visit(self, node)
  File "~/miniconda3/envs/py9/lib/python3.9/ast.py", line 415, in generic_visit
    values = self.visit(node.value)
  File "~/git/triton/python/triton/compiler/code_generator.py", line 933, in visit
    raise IncompatibleTypeErrorImpl(type_a, type_b)
triton.language.semantic.IncompatibleTypeErrorImpl: invalid operands of type pointer<fp64> and triton.language.fp32

@ptillet
Collaborator

ptillet commented May 20, 2023

I don't think the failure is random, since this is a compilation error and the kernel would only get recompiled when it is re-specialized. It would be helpful if you could record the kernel arguments that this kernel is called with when it crashes and create a self-contained and simple repro for it.

Also, it's separately strange that the error doesn't show the cursor
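One way to capture the arguments at the moment of failure, as suggested above, is to wrap the launch and log each argument's Python type and `dtype` before re-raising. This is a minimal sketch; `record_on_failure` and `describe_args` are hypothetical helpers, and the wrapped callable would be something like `lambda *a: _opt_kernel2[grid](*a)`, with `grid` defined as in the original `update_fn`.

```python
import functools
import logging

logger = logging.getLogger("kernel_repro")

def describe_args(args):
    """(type name, dtype or 'None', short repr) for each positional argument."""
    return [(type(a).__name__, str(getattr(a, "dtype", None)), repr(a)[:80])
            for a in args]

def record_on_failure(launch, sink=None):
    """Wrap a launch callable; on any exception, record the arguments and re-raise."""
    @functools.wraps(launch)
    def wrapper(*args, **kwargs):
        try:
            return launch(*args, **kwargs)
        except Exception:
            record = {"args": describe_args(args), "kwargs": sorted(kwargs)}
            logger.error("kernel launch failed: %r", record)
            if sink is not None:
                sink.append(record)
            raise
    return wrapper
```

The type names alone (for example `float` versus `float64`) are often enough to tell why the kernel was re-specialized.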

jayfurmanek added a commit to ROCm/triton that referenced this issue May 22, 2023
* [OPTIMIZER] simplified pipeline pass (triton-lang#1582)

directly rematerialize for loop with the right values, instead of
replacing unpipelined load uses a posteriori

* [OPTIMIZER] Added kWidth attribute to DotOperandEncoding (triton-lang#1584)

This is a prerequisite for efficient mixed-precision matmul.

* [TEST] Fix test cache (triton-lang#1588)

To avoid puzzling segmentation fault problems caused by multiprocessing, this
PR:

- Uses "spawn" instead of "fork".
- Defines the `instance_descriptor` namedtuple globally.
- Makes the `kernel_sub` JITFunction defined by the child process only.

* [BACKEND] Updated slice layout semantics, updated vectorization logic used for load/store ops. (triton-lang#1587)

* [FRONTEND][BACKEND] Add the `noinline` annotation for `triton.jit` (triton-lang#1568)

# Introducing the `noinline` Parameter for Triton JIT Decorator

We're excited to introduce a new parameter, `noinline`, that can be
added to the `jit` decorator in Triton. This parameter allows developers
to specify that a particular Triton function should not be inlined into
its callers. In this post, we'll dive into the syntax, purpose, and
implementation details of this new feature.

## Syntax

To use the `noinline` parameter, simply add `noinline=True` to the `jit`
decorator for the function that you don't want to be inlined. Here's an
example:

```python
import triton
import triton.language as tl

@triton.jit(noinline=True)
def device_fn(x, y, Z):
    z = x + y
    tl.store(Z, z)

def test_noinline():
    @triton.jit
    def kernel(X, Y, Z):
        x = tl.load(X)
        y = tl.load(Y)
        device_fn(x, y, Z)
```

In this example, the `device_fn` function is decorated with
`@triton.jit(noinline=True)`, indicating that it should not be inlined
into its caller, `kernel`.

## Purpose

The `noinline` parameter serves several key purposes:

- Reducing code size: By preventing inlining, we can reduce the size of
the compiled code.
- Facilitating debugging: Keeping functions separate can make it easier
to debug the code.
- Avoiding common subexpression elimination (CSE) in certain cases: CSE
can sometimes be avoided by using the `noinline` parameter to reduce
register pressure.
- Enabling dynamic linking: This parameter makes it possible to
dynamically link Triton functions.

## Implementation

The implementation of the `noinline` parameter involves significant
changes to three analysis modules in Triton: *Allocation*, *Membar*, and
*AxisInfo*. Prior to this update, these modules assumed that all Triton
functions had been inlined into the root kernel function. With the
introduction of non-inlined functions, we've had to rework these
assumptions and make corresponding changes to the analyses.

### Call Graph and Limitations

[Figure 1]

To address the changes, we build a call graph and perform all the
analyses on the call graph instead of a single function. The call graph
is constructed by traversing the call edges and storing them in an edge
map. Roots are extracted by checking nodes with no incoming edges.

The call graph has certain limitations:

- It does not support recursive function calls, although this could be
implemented in the future.
- It does not support dynamic function calls, where the function name is
unknown at compilation time.
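The construction described above (store call edges in a map, take nodes with no incoming edges as roots) and the no-recursion limitation can be sketched as follows. The helper names and the string function names are hypothetical; the real pass works on MLIR call operations.

```python
from collections import defaultdict

def build_call_graph(calls):
    """Build caller -> callees edges from (caller, callee) pairs and return
    (edges, roots), where roots are functions with no incoming edges."""
    edges = defaultdict(list)
    nodes, has_caller = set(), set()
    for caller, callee in calls:
        edges[caller].append(callee)
        nodes.update((caller, callee))
        has_caller.add(callee)
    roots = sorted(nodes - has_caller)
    return dict(edges), roots

def has_recursion(edges):
    """Detect a cycle (a recursive call chain) with a three-color DFS."""
    WHITE, GRAY, BLACK = 0, 1, 2
    color = defaultdict(int)  # every node starts WHITE

    def visit(n):
        color[n] = GRAY  # on the current DFS path
        for m in edges.get(n, ()):
            if color[m] == GRAY or (color[m] == WHITE and visit(m)):
                return True
        color[n] = BLACK  # fully explored
        return False

    return any(color[n] == WHITE and visit(n) for n in list(edges))
```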

### Allocation

[Figure 2]

In Triton, shared memory allocation is achieved through two operations:
`triton_gpu.convert_layout` and `triton_gpu.alloc_tensor`. The
`convert_layout` operation allocates an internal tensor, which we refer
to as a *scratch* buffer, while the `alloc_tensor` operation returns an
allocated tensor and is thus known as an *explicit* buffer.

To accommodate the introduction of function calls, we are introducing a
third type of buffer called a *virtual* buffer. Similar to scratch
buffers, virtual buffers are allocated internally within the scope of a
function call, and the buffers allocated by the called functions remain
invisible to subsequent operations in the calling function. However,
virtual buffers are distinct from scratch buffers in that the call
operation itself does not allocate memory—instead, it specifies the
total amount of memory required by all the child functions being called.
The actual allocation of buffers is performed by individual operations
within these child functions. For example, when invoking edge e1, no
memory is allocated, but the total amount of memory needed by function B
is reserved. Notably, the amount of shared memory used by function B
remains fixed across its call sites due to the consideration of dynamic
control flows within each function.

An additional challenge to address is the calculation of shared memory
offsets for functions within a call graph. While we can assume a shared
memory offset starting at 0 for a single root function, this is not the
case with a call graph, where we must determine each function's starting
offset based on the call path. Although each function has a fixed memory
consumption, the starting offset may vary. For instance, in Figure 2,
the starting offset of function C through edges e1->e2 differs from that
through edges e2->e4. To handle this, we accumulate the starting offset
at each call site and pass it as an argument to the called function.
Additionally, we amend both the function declaration and call sites by
appending an offset variable.
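The offset bookkeeping can be illustrated with a simplified model, which is an assumption rather than Triton's allocator: each function uses `local[f]` bytes of its own, and every call site reserves the callee's whole fixed footprint right after them. A function's footprint is then fixed, but its starting offset depends on the call path, and the accumulated offset is what would be passed as the extra argument. `shared_mem_layout` is a hypothetical helper.

```python
def shared_mem_layout(edges, local, root):
    """Simplified model (assumption): f uses local[f] bytes, and each call site
    places the callee's footprint right after them. Returns each function's
    total footprint and a list of (call path, starting offset)."""
    totals = {}

    def total(f):
        if f not in totals:
            callees = edges.get(f, [])
            totals[f] = local[f] + max((total(c) for c in callees), default=0)
        return totals[f]

    total(root)
    offsets = []

    def walk(f, start, path):
        offsets.append((path, start))
        for c in edges.get(f, []):
            # The accumulated start offset is handed to the callee.
            walk(c, start + local[f], path + (c,))

    walk(root, 0, (root,))
    return totals, offsets
```

With `A` calling `B` and `C`, and `B` calling `C`, function `C` starts at a different offset depending on whether it is reached directly from `A` or through `B`.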

### Membar

[Figure 3]

The membar pass is dependent on the allocation analysis. Once the offset
and size of each buffer are known, we conduct a post-order traversal of
the call graph and analyze each function on an individual basis. Unlike
previous analyses, we now return buffers that remain unsynchronized at
the end of functions, allowing the calling function to perform
synchronization in cases of overlap.
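A sketch of the two ingredients named above: a post-order traversal (every callee is analyzed before its callers) and an overlap test a caller could apply to the buffers a callee reports as still unsynchronized. Both helpers are hypothetical, and buffers are modeled as half-open byte intervals `[start, end)`.

```python
def post_order(edges, root):
    """Return functions so that callees come before callers, each visited once."""
    seen, order = set(), []

    def visit(f):
        if f in seen:
            return
        seen.add(f)
        for c in edges.get(f, []):
            visit(c)
        order.append(f)

    visit(root)
    return order

def needs_barrier(pending, next_access):
    """True if any interval a callee left unsynchronized overlaps the caller's
    next shared-memory access."""
    s, e = next_access
    return any(ps < e and s < pe for ps, pe in pending)
```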

### AxisInfo

[Figure 4]

The AxisInfo analysis operates differently from both membar and
allocation, as it traverses the call graph in topological order. This is
necessary because function arguments may contain axis information that
will be utilized by callee functions. As we do not implement
optimizations like function cloning, each function has a single code
base, and the axis information for an argument is determined as a
conservative result of all axis information passed by the calling
functions.
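The topological traversal and the conservative merge can be sketched with a single tracked property, divisibility. That is an assumption for illustration, since the real AxisInfo lattice tracks more than this. Each call site passes the caller's argument multiplied by a constant factor, and a callee's argument gets the gcd over all incoming call sites, which is the largest divisibility every caller guarantees. `propagate_divisibility` is a hypothetical helper built on the standard library's `graphlib` (Python 3.9+).

```python
from graphlib import TopologicalSorter
from math import gcd

def propagate_divisibility(calls, roots):
    """calls: {(caller, callee): factor}; at that call site the caller passes
    (its own argument * factor). roots: {kernel: divisibility of its argument}.
    Returns each function's argument divisibility, merged conservatively."""
    preds = {}
    for caller, callee in calls:
        preds.setdefault(callee, set()).add(caller)
        preds.setdefault(caller, set())
    div = dict(roots)
    # static_order() yields every caller before its callees.
    for f in TopologicalSorter(preds).static_order():
        if f in div:
            continue  # root kernel: value given
        merged = 0  # gcd identity
        for caller in preds[f]:
            merged = gcd(merged, div[caller] * calls[(caller, f)])
        div[f] = merged
    return div
```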

---------

Co-authored-by: Philippe Tillet <phil@openai.com>

* [FRONTEND] add architecture to hash to avoid invalid image on cubin load (triton-lang#1593)

Closes triton-lang#1556
triton-lang#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>

* [FRONTEND] Fix calling local variables’ attribute functions in the if statement (triton-lang#1597)

If `node.func` is an `ast.Attribute`, it won't cause an early return.
(Not sure if I interpret it correctly)

triton-lang#1591

* [OPTIMIZER][BACKEND] Enabled elementwise ops (including casts) between ldmatrix and mma.sync (triton-lang#1595)

* [RUNTIME] Ensure we hold the GIL before calling into CPython API in cubin binding (triton-lang#1583)

Formatting of the diff is not the best. I only indented the whole
function, moved the creation of the py::bytes and the return out of the
scope and declared and assigned the cubin variable appropriately.
Everything else is unchanged.

Today it triggers the following error on CPython debug build:
```
Fatal Python error: _PyMem_DebugMalloc: Python memory allocator called without holding the GIL
Python runtime state: initialized

```

---------

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Philippe Tillet <phil@openai.com>

* Merge branch `llvm-head` (triton-lang#1600)

* Zahi/slice reduce rebased (triton-lang#1594)

[BACKEND] Enable slice layout support for reduce op

* [OPTIMIZER] Fix crash in loop pipelining. (triton-lang#1602)

Fixes issue triton-lang#1601.

* [FRONTEND] make torch optional (triton-lang#1604)

make torch optional to fix circular dependency issue

* [OPTIMIZER] Clean-up Utility.cpp and fixed bug in RematerializeForward (triton-lang#1608)

ConvertLayoutOp can be folded in other ConvertLayoutOp

* [BACKEND] Fixed up ConvertLayout for slices (triton-lang#1616)

* [FRONTEND] Add `tl.expand_dims` (triton-lang#1614)

This exposes `semantic.expand_dims` in the public API and builds upon it
with support for expanding multiple dimensions at once. e.g.
```python
tl.expand_dims(tl.arange(0, N), (0, -1))  # shape = [1, N, 1]
```

Compared to indexing with `None`, this API is useful because the
dimensions can be constexpr values rather than hard-coded into the
source. As a basic example
```python
@triton.jit
def max_keepdim(value, dim):
    res = tl.max(value, dim)
    return tl.expand_dims(res, dim)
```
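NumPy's `np.expand_dims` also accepts a tuple of axes, so it can serve as a host-side check of the shapes described above. This assumes `tl.expand_dims` normalizes negative dimensions the same way NumPy does (against the output rank), which matches the `[1, N, 1]` example but is not confirmed here.

```python
import numpy as np

N = 4
x = np.arange(N)
y = np.expand_dims(x, (0, -1))  # NumPy analogue of tl.expand_dims(tl.arange(0, N), (0, -1))
print(y.shape)  # (1, 4, 1)

def max_keepdim(value, dim):
    # Reduce along `dim`, then re-insert it with size 1 (like keepdims=True).
    res = np.max(value, axis=dim)
    return np.expand_dims(res, dim)
```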

* [BACKEND] Modified store op thread masking (triton-lang#1605)

* [CI] no longer runs CI job on macos-10.15 (triton-lang#1624)

* [BACKEND] Allow noinline functions to return multiple values of primitive types (triton-lang#1623)

Fix triton-lang#1621

* [BACKEND] Updated predicate for atomic ops (triton-lang#1619)

* [TEST] Added convert layout test from/to sliced blocked/mma (triton-lang#1620)

* [BACKEND] fix typo in Membar class about WAR description and refine some code (triton-lang#1629)

Co-authored-by: Philippe Tillet <phil@openai.com>

* [SETUP] Removing `torch` as a test dependency (triton-lang#1632)

circular dependency is causing troubles now that our interpreter depends
on torch 2.0 ...

* [DOCS] Fix docstrings for sphinx docs (triton-lang#1635)

* [FRONTEND] Added interpreter mode (triton-lang#1573)

Simple mechanism to run Triton kernels on PyTorch for debugging purposes
(upstream from Kernl).

Todo:
- random grid iteration
- support of atomic ops
- more unit tests
- cover new APIs?

* [CI] Build wheels for musllinux (triton-lang#1638)

Ideally you would also build source distributions so that it is in
principle possible to build `triton` on other platforms, but building
`musllinux` wheels would at least help with openai/whisper#1328.

I suspect you will also get people showing up at some point asking for
`aarch64` wheels as well. It might be worth taking a look at the
[`cibuildwheel` output
matrix](https://cibuildwheel.readthedocs.io/en/stable/#what-does-it-do)
to see what you are comfortable with shipping (particularly if you
aren't shipping source distributions).

* [FRONTEND] Fix return op related control flow issues (triton-lang#1637)

- Case 1: Return after static control flow is taken. Peel off
instructions after the first `return` for each basic block.

```python
if static_condition:
    tl.store(...)
    return
return
```

- Case 2: Return exists in both `if` and `else` branches of an inlined
`JITFunction` function

```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```

- Case 3: Return exists in a `JITFunction` from another module

```python
import module
if cond:
    a = module.func()
```

- Case 4: A chain of calls through undefined local variables

```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```

- Case 5: Call a function `func` without returning variables. `func` is
recognized as an `Expr` first instead of a `Call`.

```python
if cond:
    foo()
else:
    bar()
```

- Case 6: Call a `noinline` function. We don't need to check if the
function contains any return op.

* [CI] Upload CUDA test artifacts (triton-lang#1645)

* [FRONTEND] Add support for scalar conditions in `device_assert` (triton-lang#1641)

This sometimes happens in TorchInductor. See
pytorch/pytorch#100880.
More generally, it's useful to be able to write `tl.device_assert(False,
msg)`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>

* [FRONTEND] Hotfix for `contains_return_op` (triton-lang#1651)

`noinline` can be None, False, or True, so we have to check the callee
in the first two cases.

* [TEST] Fixed and re-enabled reduce test (triton-lang#1644)

Re-enabled reduce test after fixing the %cst stride in the ttgir, and
modifying the sweep parameters to make sure the shape per CTA is less
than or equal to the tensor shape.

* [FRONTEND] Don't call set_device in tl.dot (triton-lang#1646)

This breaks multiprocess compilation

* [TESTS] Add regression test for issue triton-lang#1601. (triton-lang#1611)

Following up on triton-lang#1603, I am adding a new file meant to contain
functional regression tests to the repository.
Let me know if another folder would be a more appropriate place for
these tests.

Co-authored-by: Philippe Tillet <phil@openai.com>

* [BUILD] Move canonicalization patterns of Load/Store to Ops.cpp. (NFC) (triton-lang#1650)

This breaks a cyclic dependency between the TritonAnalysis and the
TritonIR libraries (see triton-lang#1649). It also follows the convention from
upstream (for example, see the AMDGPU, Affine, and Arith dialects).

* [FRONTEND] Better error messages for noinline functions (triton-lang#1657)

```
at 10:18:def val_multiplier_noinline(val, i):
    return val * i

           ^
Function val_multiplier_noinline is marked noinline, but was called with non-scalar argument val:fp32[constexpr[128]]
```

* [BUILD] Add missing CMake link-time dependencies. (triton-lang#1654)

* [BACKEND] Move isSharedEncoding to TritonGPUIR. (triton-lang#1655)

This breaks a cyclic dependency between TritonAnalysis and TritonGPUIR
(see triton-lang#1649).

* [FRONTEND] Do not use exceptions to guide control flow in compilation runtime (triton-lang#1663)

Triton runtime currently relies on KeyError to check whether a kernel
has been compiled. This results in somewhat confusing backtraces when
running the kernel crashes, as the stack trace includes not only the
actual crash, but also the stack trace for the original KeyError which
was caught.
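The effect described here, and visible at the top of the tracebacks in this issue, is Python's implicit exception chaining: an error raised inside an `except KeyError:` block carries the KeyError as its `__context__` and is printed after it. A minimal sketch of both lookup styles follows; `compile_kernel` is a hypothetical stand-in that always fails.

```python
cache = {}

def compile_kernel(key):
    # Hypothetical stand-in for a compiler call that fails.
    raise RuntimeError(f"compilation failed for {key!r}")

def get_kernel_try(key):
    """Cache lookup driven by KeyError."""
    try:
        return cache[key]
    except KeyError:
        # An error raised here is chained to the KeyError
        # ("During handling of the above exception, another exception occurred").
        cache[key] = compile_kernel(key)
        return cache[key]

def get_kernel_get(key):
    """Cache lookup without using exceptions for control flow."""
    kernel = cache.get(key)
    if kernel is None:
        kernel = compile_kernel(key)  # an error here has no KeyError context
        cache[key] = kernel
    return kernel
```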

* [FRONTEND] Assert that for loop bounds must be ints (triton-lang#1664)

* [OPTIMIZER] Fix-up reduction cloning

* [DEPENDENCIES] Update LLVM to 17.0.0 (c5dede880d17) and port changes. (triton-lang#1668)

This depends on a [pending LLVM
release](ptillet/triton-llvm-releases#10).

* Implement setCalleeFromCallable in CallOp.
* Cast type to ShapedType for various getters.
* Improve TritonDialect::materializeConstant due to breaking change in
constructor of arith::ConstantOp.
* Add OpaqueProperties argument in inferReturnTypes.

Co-authored-by: Philippe Tillet <phil@openai.com>

* [OPTIMIZER] adjusted selection heuristics for when `mmaLayout.warpsPerTile[1] = 1` (triton-lang#1675)

this fixes fused attention with D_HEAD=128

* [BUILD] stop depending on dlfcn-win32 by implementing `dladdr` natively with WIN32 API (triton-lang#1674)

Co-authored-by: Philippe Tillet <phil@openai.com>

* [BUILD] minor fixes (triton-lang#1676)

Remove unused variables, fix member initializer list order.

* [FRONTEND] Differentiate between bool and int in the frontend (triton-lang#1678)

`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`,
and a `bool` constant will be converted to an `int` constant.

In Triton specifically, if a bool var is treated as an integer, it
prevents us from using the `logical_and` operator, which requires both
operands to have the same bit length.

> Cannot bitcast data-type of size 32 to data-type of size 1

By differentiating int and bool, we can make the syntax closer to
native Python. We can now use `if bool_var and condition` to
check the truthiness, and `if bool_var is True` to check identity.
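Because `bool` subclasses `int`, any type dispatch has to test `bool` first. The sketch below contrasts the two orders; the `"i1"`/`"i32"` labels mirror the type strings in the KeyError signatures earlier in this issue, and the helper names are hypothetical (real code would also pick wider integer types for large values).

```python
def naive_kind(value):
    # Wrong order: True/False also satisfy isinstance(value, int).
    if isinstance(value, int):
        return "i32"
    if isinstance(value, bool):
        return "i1"  # never reached for bools
    raise TypeError(type(value))

def scalar_kind(value):
    # Test bool first, since bool is a subclass of int.
    if isinstance(value, bool):
        return "i1"
    if isinstance(value, int):
        return "i32"
    raise TypeError(type(value))

print(isinstance(True, int))  # True
print(naive_kind(True), scalar_kind(True), scalar_kind(7))  # i32 i1 i32
```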

* [BUILD] Add deduction guide for `Interval` (triton-lang#1680)

This avoids `ctad-maybe-unsupported` warning.

* [OPS] Remove duplicated function already defined in `triton` module. (triton-lang#1679)

* IFU 230517 Resolve merge conflicts

* Fix is_hip() check

* [ROCM] Fix hardcoded warpsize in getMask

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Zahi Moudallal <128723247+zahimoud@users.noreply.github.com>
Co-authored-by: David MacLeod <macleod.david@live.co.uk>
Co-authored-by: davidma <davidma@speechmatics.com>
Co-authored-by: albanD <desmaison.alban@gmail.com>
Co-authored-by: Christian Sigg <chsigg@users.noreply.github.com>
Co-authored-by: Benjamin Chetioui <3920784+bchetioui@users.noreply.github.com>
Co-authored-by: Michaël Benesty <pommedeterresautee@users.noreply.github.com>
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
Co-authored-by: long.chen <lipracer@gmail.com>
Co-authored-by: q.yao <streetyao@live.com>
Co-authored-by: Paul Ganssle <1377457+pganssle@users.noreply.github.com>
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
Co-authored-by: Ingo Müller <github.com@ingomueller.net>
Co-authored-by: Ingo Müller <ingomueller@google.com>
Co-authored-by: George Karpenkov <cheshire@google.com>
Co-authored-by: Sophia Wisdom <sophia.wisdom1999@gmail.com>
Co-authored-by: cloudhan <cloudhan@outlook.com>
Co-authored-by: Daniil Fukalov <1671137+dfukalov@users.noreply.github.com>
@antony-frolov

Hi! Getting the same error using flash_attn_qkvpacked_func from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L832

pingzhuu pushed a commit to siliconflow/triton that referenced this issue Apr 2, 2024
…oad (triton-lang#1593)

Closes triton-lang#1556
triton-lang#1512

The current hash used for caching the cubin does not include the
architecture. This leads to the following error when compiling against
one arch and running against another (with no code changes to trigger a
recompilation).
```
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```
Was not sure what unit tests would be appropriate here (if any)

Co-authored-by: davidma <davidma@speechmatics.com>
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this issue Aug 5, 2024
…converting shared layout to dot layout. (triton-lang#1512)

Support the `repCluster` field in convert shared layout to dot layout
with parent layout of DPAS.

---------

Signed-off-by: Tiotto, Ettore <ettore.tiotto@intel.com>
Co-authored-by: Tiotto, Ettore <ettore.tiotto@intel.com>