Skip to content

Commit

Permalink
Merge pull request triton-lang#14 from ROCmSoftwarePlatform/fix_vectorization
Browse files Browse the repository at this point in the history

fix test_vectorization and test_load_cache_modifier
  • Loading branch information
rsanthanam-amd authored Oct 28, 2022
2 parents 7fce2bc + 8d9572b commit 48fcd8c
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,13 +1039,18 @@ def kernel(X, stride_xm, stride_xn,
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx

if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
else:
# TODO add rocm gcn assert
pass

# ---------------
# test dot
Expand Down Expand Up @@ -1306,16 +1311,20 @@ def _kernel(dst, src, CACHE: tl.constexpr):
tl.store(dst + offsets, x)

pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
else:
# TODO add rocm gcn assert
pass


@pytest.mark.parametrize("N", [16, 10, 11, 1024])
Expand All @@ -1329,11 +1338,15 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
x = tl.load(src + offsets, mask=offsets < N)
tl.store(dst + offsets, x, mask=offsets < N)
pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
if torch.version.hip is None:
ptx = pgm.asm["ptx"]
if N % 16 == 0:
assert "ld.global.v4.b32" in ptx
else:
assert "ld.global.b32" in ptx
else:
assert "ld.global.b32" in ptx
#TODO add rocm assert
pass
# triton.testing.assert_almost_equal(dst, src[:N])
# ---------------
# test store
Expand Down

0 comments on commit 48fcd8c

Please sign in to comment.