Skip to content

Commit

Permalink
[RUNTIME] Fixed JIT bug that leg some constexpr values to be override…
Browse files Browse the repository at this point in the history
…n by specialization parameters (#742)
  • Loading branch information
ptillet authored Oct 5, 2022
1 parent 77c752d commit bdfdb9a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/triton/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ def make_triton_ir(fn, signature, specialization, constants):
gscope = fn.__globals__.copy()
function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)])
tys = list(signature.values())
new_constants = {k: True if tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1}
new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16}
all_constants = constants.copy()
all_constants.update(new_constants)
Expand Down
6 changes: 3 additions & 3 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,12 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
except KeyError:
# build dict of constant values
args = [{args}]
configs = self._get_config(*args),
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
configs = self._get_config(*all_args),
constants = self._make_constants(constexpr_key)
constants.update({{i: None for i, arg in enumerate(args) if arg is None}})
constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
constants.update({{i: 1 for i in configs[0].equal_to_1}})
# build kernel signature -- doesn't include specialized arguments
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
# build stub signature -- includes arguments that are specialized
for i, arg in constants.items():
Expand Down

0 comments on commit bdfdb9a

Please sign in to comment.