
Accuracy regression updating triton pin in inductor #2456

Closed · peterbell10 opened this issue Oct 6, 2023 · 11 comments

@peterbell10 (Contributor) commented Oct 6, 2023

I'm trying to update PyTorch's triton pin in pytorch/pytorch#109601 but am seeing accuracy regressions in several models.
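
(For context: repro scripts like the ones below come from the dynamo/inductor minifier. A minimal sketch of how accuracy minification is typically enabled; the flag names are as I understand them from the PyTorch minifier docs, so treat them as assumptions rather than part of this report.)

import os

# Set before compiling the failing model; the minifier then writes a
# standalone repro script (like the one below) under its save_dir.
os.environ["TORCHDYNAMO_REPRO_AFTER"] = "aot"  # minify the post-AOTAutograd graph
os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"    # level 4 = accuracy minification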

I was able to get a minimized PyTorch program that produces two kernels:

PyTorch reproducer
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.translation_validation = True
torch._inductor.config.fallback_random = True
torch._inductor.config.generate_intermediate_hooks = True



isolate_fails_code_str = None



# torch version: 2.2.0a0+gitd4f91a7
# torch cuda version: 12.1
# torch git version: d4f91a73149a1adb65c2ca676e70b63b0ad8e4ca


# CUDA Info: 
# nvcc not found
# GPU Hardware Info: 
# NVIDIA GeForce RTX 2060 : 2 


from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    
    
    def forward(self, arg0_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, sub):
        reflection_pad2d_backward = torch.ops.aten.reflection_pad2d_backward.default(arg5_1, arg0_1, [1, 1, 1, 1]);  arg5_1 = arg0_1 = None
        add = torch.ops.aten.add.Tensor(arg4_1, reflection_pad2d_backward);  arg4_1 = reflection_pad2d_backward = None
        copy = torch.ops.aten.copy.default(arg6_1, add);  arg6_1 = add = None
        clone = torch.ops.aten.clone.default(copy, memory_format = torch.contiguous_format)
        where = torch.ops.aten.where.self(arg7_1, arg3_1, clone);  arg7_1 = arg3_1 = clone = None
        copy_1 = torch.ops.aten.copy.default(copy, where);  copy = where = None
        convert_element_type = torch.ops.prims.convert_element_type.default(copy_1, torch.float32);  copy_1 = None
        mul = torch.ops.aten.mul.Tensor(convert_element_type, sub);  convert_element_type = sub = None
        sum_1 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3]);  mul = None
        return (sum_1,)
        
def load_args(reader):
    buf0 = reader.storage('74a26d59db37fc38f602a562287ebca2f4a2bf15', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf0, (1, 256, 64, 64), dtype=torch.float16, is_leaf=True)  # arg0_1
    buf1 = reader.storage('4f57a57893d372510ad16cb0336ee53827960f1f', 2, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf1, (), dtype=torch.float16, is_leaf=True)  # arg3_1
    buf2 = reader.storage('963b7771941b550e47d28b5b95547f7cb8b78599', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf2, (1, 256, 64, 64), dtype=torch.float16, is_leaf=True)  # arg4_1
    buf3 = reader.storage('4b525890684d0082ac598512d8e23878e72838c6', 2230272, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf3, (1, 256, 66, 66), dtype=torch.float16, is_leaf=True)  # arg5_1
    buf4 = reader.storage('2db25999e45cc8a42b67f8a0038d08820abc870e', 2097152, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf4, (1, 256, 64, 64), dtype=torch.float16, is_leaf=True)  # arg6_1
    buf5 = reader.storage('e41fb6b2eeea86fd1c077f333a81b1300715e22e', 1048576, device=device(type='cuda', index=0), dtype_hint=torch.bool)
    reader.tensor(buf5, (1, 256, 64, 64), dtype=torch.bool, is_leaf=True)  # arg7_1
    buf6 = reader.storage('dbf8ba93ef42b07e489dc5e66d1e168458507745', 4194304, device=device(type='cuda', index=0))
    reader.tensor(buf6, (1, 256, 64, 64), is_leaf=True)  # sub
load_args._version = 0
mod = Repro()
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        run_repro(mod, load_args, accuracy=True, command='run', save_dir='./minifier/checkpoints', tracing_mode='real', check_str=None)

and the corresponding Triton code is here:

Full reproducer
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

from torch import empty_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels


aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

torch._inductor.config.compile_threads=1


# kernel path: /tmp/torchinductor_pbell/33/c33kkoxtkl3cu7pojlucwidpgx4d6ytqyclb5f4myjcgifk2wh2k.py
# Source Nodes: [], Original ATen: []

triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@pointwise(
    size_hints=[16384, 64], tile_hint=TileHint.SQUARE,
    filename=__file__,
    meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2, 3))]},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 16384
    xnumel = 64
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x2 = xindex
    y0 = yindex % 64
    y1 = (yindex // 64)
    y3 = yindex
    tmp0 = 1 + x2
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 65, tl.int64)
    tmp4 = tmp0 <= tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = 1 + y0
    tmp7 = tmp6 >= tmp1
    tmp8 = tmp6 <= tmp3
    tmp9 = tmp7 & tmp8
    tmp10 = tmp5 & tmp9
    tmp11 = tl.load(in_ptr0 + (67 + y0 + (66*x2) + (4356*y1)), tmp10 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp12 = tl.full(tmp11.shape, 0.0, tmp11.dtype)
    tmp13 = tl.where(tmp10, tmp11, tmp12)
    tmp14 = y0
    tmp15 = tl.full([1, 1], 1, tl.int64)
    tmp16 = tmp14 >= tmp15
    tmp17 = tmp14 <= tmp15
    tmp18 = tmp16 & tmp17
    tmp19 = tmp5 & tmp18
    tmp20 = tl.load(in_ptr0 + (67 + ((-1)*y0) + (66*x2) + (4356*y1)), tmp19 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp19, tmp20, tmp21)
    tmp23 = tmp13 + tmp22
    tmp24 = tl.full([1, 1], 62, tl.int64)
    tmp25 = tmp14 >= tmp24
    tmp26 = tmp14 <= tmp24
    tmp27 = tmp25 & tmp26
    tmp28 = tmp5 & tmp27
    tmp29 = tl.load(in_ptr0 + (193 + ((-1)*y0) + (66*x2) + (4356*y1)), tmp28 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp28, tmp29, tmp30)
    tmp32 = tmp23 + tmp31
    tmp33 = x2
    tmp34 = tmp33 >= tmp15
    tmp35 = tmp33 <= tmp15
    tmp36 = tmp34 & tmp35
    tmp37 = tmp36 & tmp9
    tmp38 = tl.load(in_ptr0 + (67 + y0 + ((-66)*x2) + (4356*y1)), tmp37 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp37, tmp38, tmp39)
    tmp41 = tmp32 + tmp40
    tmp42 = tmp33 >= tmp24
    tmp43 = tmp33 <= tmp24
    tmp44 = tmp42 & tmp43
    tmp45 = tmp44 & tmp9
    tmp46 = tl.load(in_ptr0 + (8383 + y0 + ((-66)*x2) + (4356*y1)), tmp45 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp45, tmp46, tmp47)
    tmp49 = tmp41 + tmp48
    tmp50 = tmp36 & tmp18
    tmp51 = tl.load(in_ptr0 + (67 + ((-1)*y0) + ((-66)*x2) + (4356*y1)), tmp50 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp52 = tl.full(tmp51.shape, 0.0, tmp51.dtype)
    tmp53 = tl.where(tmp50, tmp51, tmp52)
    tmp54 = tmp49 + tmp53
    tmp55 = tmp36 & tmp27
    tmp56 = tl.load(in_ptr0 + (193 + ((-1)*y0) + ((-66)*x2) + (4356*y1)), tmp55 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp57 = tl.full(tmp56.shape, 0.0, tmp56.dtype)
    tmp58 = tl.where(tmp55, tmp56, tmp57)
    tmp59 = tmp54 + tmp58
    tmp60 = tmp44 & tmp18
    tmp61 = tl.load(in_ptr0 + (8383 + ((-1)*y0) + ((-66)*x2) + (4356*y1)), tmp60 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp62 = tl.full(tmp61.shape, 0.0, tmp61.dtype)
    tmp63 = tl.where(tmp60, tmp61, tmp62)
    tmp64 = tmp59 + tmp63
    tmp65 = tmp44 & tmp27
    tmp66 = tl.load(in_ptr0 + (8509 + ((-1)*y0) + ((-66)*x2) + (4356*y1)), tmp65 & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
    tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = tmp64 + tmp68
    tl.store(out_ptr0 + (x2 + (64*y3)), tmp69, xmask)
''')
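
# (Hedged reference sketch, not part of inductor's generated output.)
# The kernel above is inductor's lowering of aten.reflection_pad2d_backward
# with padding [1, 1, 1, 1], fused with a transpose into buf0's
# (1048576, 4096, 1, 64) layout. Its output can be cross-checked against
# autograd's reflection-pad gradient:
def reflection_pad2d_backward_ref(grad_out, inp):
    # grad_out: (1, 256, 66, 66), inp: (1, 256, 64, 64)
    x = inp.detach().requires_grad_(True)
    padded = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect')
    padded.backward(grad_out)
    return x.grad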

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


# kernel path: /tmp/torchinductor_pbell/cs/ccsjwloinafccyvnhjzgl5dymi7mjudbjobdudg5zehstelingg3.py
# Source Nodes: [], Original ATen: []

triton_red_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, reduction
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@reduction(
    size_hints=[256, 4096],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*i1', 1: '*fp16', 2: '*fp16', 3: '*fp16', 4: '*fp16', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_1', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(6, 7))]}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 256
    rnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp1 = tl.load(in_ptr1 + (0)).to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r3 = rindex
        r1 = rindex % 64
        r2 = (rindex // 64)
        tmp0 = tl.load(in_ptr0 + (r3 + (4096*x0)), rmask & xmask, eviction_policy='evict_first').to(tl.int1)
        tmp3 = tl.load(in_ptr2 + (r3 + (4096*x0)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp4 = tl.load(in_ptr3 + (r2 + (64*r1) + (4096*x0)), rmask & xmask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp8 = tl.load(in_ptr4 + (r3 + (4096*x0)), rmask & xmask, eviction_policy='evict_first', other=0).to(tl.float32)
        tmp5 = tmp3 + tmp4
        tmp6 = tl.where(tmp0, tmp2, tmp5)
        tmp7 = tmp6.to(tl.float32)
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp7 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(rmask & xmask, tmp13, _tmp12)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp12, xmask)
''')
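
# (Hedged reference sketch, not part of inductor's generated output.)
# Per the Repro graph below, this reduction kernel computes, per channel:
#   sum over (n, h, w) of where(mask, scalar, a + pad_grad).float() * sub
# where pad_grad is buf0 read through its transposed strides (the
# r2 + 64*r1 indexing on in_ptr3).
def fused_reduction_ref(mask, scalar, a, pad_grad, sub):
    val = torch.where(mask, scalar, a + pad_grad).to(torch.float32)
    return (val * sub).sum(dim=[0, 2, 3])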


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 256, 64, 64), (1048576, 4096, 64, 1))
    assert_size_stride(arg1_1, (), ())
    assert_size_stride(arg2_1, (1, 256, 64, 64), (1048576, 4096, 64, 1))
    assert_size_stride(arg3_1, (1, 256, 66, 66), (1115136, 4356, 66, 1))
    assert_size_stride(arg4_1, (1, 256, 64, 64), (1048576, 4096, 64, 1))
    assert_size_stride(arg5_1, (1, 256, 64, 64), (1048576, 4096, 64, 1))
    assert_size_stride(arg6_1, (1, 256, 64, 64), (1048576, 4096, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        buf0 = empty_strided((1, 256, 64, 64), (1048576, 4096, 1, 64), device='cuda', dtype=torch.float16)
        # Source Nodes: [], Original ATen: []
        stream0 = get_cuda_stream(0)
        triton_poi_fused_0.run(arg3_1, buf0, 16384, 64, grid=grid(16384, 64), stream=stream0)
        del arg3_1
        buf1 = empty_strided((256, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [], Original ATen: []
        triton_red_fused_1.run(arg5_1, arg1_1, arg2_1, buf0, arg6_1, buf1, 256, 4096, grid=grid(256), stream=stream0)
        del arg1_1
        del arg2_1
        del arg5_1
        del arg6_1
        return (buf1, )

class Repro(torch.nn.Module):
    def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
        reflection_pad2d_backward = torch.ops.aten.reflection_pad2d_backward.default(arg3_1, arg0_1, [1, 1, 1, 1]);  arg3_1 = arg0_1 = None
        add = torch.ops.aten.add.Tensor(arg2_1, reflection_pad2d_backward);  arg2_1 = reflection_pad2d_backward = None
        copy = torch.ops.aten.copy.default(arg4_1, add);  arg4_1 = add = None
        where = torch.ops.aten.where.self(arg5_1, arg1_1, copy);  arg5_1 = arg1_1 = None
        copy_1 = torch.ops.aten.copy.default(copy, where);  copy = where = None
        convert_element_type = torch.ops.prims.convert_element_type.default(copy_1, torch.float32);  copy_1 = None
        mul = torch.ops.aten.mul.Tensor(convert_element_type, arg6_1);  convert_element_type = arg6_1 = None
        sum_1 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3]);  mul = None
        return (sum_1,)


torch.manual_seed(0)
mod = Repro()
args = [
  torch.randn((1, 256, 64, 64), dtype=torch.float16, device="cuda"),
  torch.randn((), dtype=torch.float16, device="cuda"),
  torch.randn((1, 256, 64, 64), dtype=torch.float16, device="cuda"),
  torch.randn((1, 256, 66, 66), dtype=torch.float16, device="cuda"),
  torch.randn((1, 256, 64, 64), dtype=torch.float16, device="cuda"),
  torch.randint(0, 1, (1, 256, 64, 64), dtype=torch.bool, device="cuda"),
  torch.randn((1, 256, 64, 64), dtype=torch.float16, device="cuda"),
]
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dynamo.debug_utils import same_two_models
from torch._inductor.compile_fx import align_inputs

mod = make_fx(mod, tracing_mode="fake")(*args)
call_aligned = align_inputs(call, args, range(len(args)))
assert same_two_models(mod, lambda inputs: call_aligned(list(inputs)), args, only_fwd=True)

git bisect points to #2285 as the root cause of these failures, and I can see that the PR changes the LLVM IR generated for the second kernel (the one with the for loop).
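
For anyone reproducing the bisection: a minimal sketch of a git bisect run driver, assuming the full reproducer above is saved as repro.py (hypothetical filename) and Triton is rebuilt at each step; the final assert makes the script exit nonzero on bad commits.

# bisect_check.py (hypothetical): for `git bisect run python bisect_check.py`,
# exit status 0 marks the commit good and 1..124 marks it bad.
import subprocess
import sys

result = subprocess.run([sys.executable, "repro.py"])
sys.exit(0 if result.returncode == 0 else 1)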

@ThomasRaoux (Collaborator)

Could you try commenting out this line in the latest code:
https://github.com/openai/triton/blob/eed4559df2f4100b947bda39c7b4af21d9576684/lib/Target/LLVMIR/LLVMIRTranslation.cpp#L125

If that solves your problem, we can probably disable it for now, since it was meant to help further vectorization work that I ended up putting on hold.

@peterbell10 (Contributor, Author)

Yup that fixes the reproducer for me.

@ThomasRaoux (Collaborator) commented Oct 6, 2023

> Yup that fixes the reproducer for me.

OK thanks, let me send a PR to disable this pass and I'll debug offline.

@ThomasRaoux (Collaborator)

Sent: #2458

@peterbell10 (Contributor, Author)

Running on a newer commit I still see some failures, but the model this reproducer came from is now fixed, so those must be separate issues.

@ThomasRaoux (Collaborator)

I'm trying to debug this, but I'm a bit confused: the reproducer just seems to print some performance numbers. How can I reproduce the functional bug?

@peterbell10 (Contributor, Author)

@ThomasRaoux I've updated the Triton code in the issue so it runs and compares accuracy against eager PyTorch.

@ThomasRaoux (Collaborator)

Thanks @peterbell10! We have to revert the revert to avoid some perf regressions in our internal workloads (#2498). I'll be debugging this next week; from a first look, it seems like this exposes an NVPTX issue. I'll try to find a solution; let me know if you have any concerns.

ThomasRaoux reopened this Oct 14, 2023
@peterbell10 (Contributor, Author)

@ThomasRaoux do you have any updates on this, or possibly some more details on the PTX issue you're seeing?

ThomasRaoux self-assigned this Nov 3, 2023
@peterbell10 (Contributor, Author)

Looks like the underlying cause is the same as #2483, so closing this one.

@ThomasRaoux (Collaborator)

Thanks for figuring out this bug and finding a fix!

(Dozens of downstream commits referenced this issue between Dec 2023 and Oct 2024: wkpark's and mantaionut's Windows-support branches, based on Windows support PR triton-lang#2456 by @andreigh.)