Skip to content

Commit

Permalink
add architecture to hash to avoid invalid image on cubin load
Browse files Browse the repository at this point in the history
  • Loading branch information
davidma committed Apr 28, 2023
1 parent 8f47bdc commit a945ad6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def convert_type_repr(x):
return x


def make_hash(fn, **kwargs):
def make_hash(fn, arch, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
Expand All @@ -262,7 +262,7 @@ def make_hash(fn, **kwargs):
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}"
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -418,7 +418,7 @@ def compile(fn, **kwargs):
# cache manager
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"
Expand Down

0 comments on commit a945ad6

Please sign in to comment.