Generalize code in tutorials: triton-lang#2, triton-lang#3, triton-lang#4
Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
anmyachev committed Nov 29, 2024
1 parent cc89dac commit be20a1f
Showing 3 changed files with 20 additions and 14 deletions.
12 changes: 7 additions & 5 deletions python/tutorials/02-fused-softmax.py
@@ -27,6 +27,8 @@
import triton.language as tl
from triton.runtime import driver

DEVICE = "cuda"


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
@@ -110,7 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
# %%
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

-device = torch.cuda.current_device()
+device = getattr(torch, DEVICE).current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
@@ -189,7 +191,7 @@ def softmax(x):
# This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
-x = torch.randn(1823, 781, device='cuda')
+x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -221,9 +223,9 @@ def softmax(x):
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
-x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-stream = torch.cuda.Stream()
-torch.cuda.set_stream(stream)
+x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
+stream = getattr(torch, DEVICE).Stream()
+getattr(torch, DEVICE).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
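The change above replaces direct torch.cuda.* calls with a lookup through the DEVICE string. A minimal sketch of how that dispatch works, assuming a backend module with the same interface as torch.cuda (for example torch.xpu) is present; the availability check and the "backend" name below are illustrative and not part of the commit:

import torch

# Illustrative selection; the commit itself simply hardcodes DEVICE = "cuda".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if DEVICE != "cpu":
    backend = getattr(torch, DEVICE)    # resolves to torch.cuda, torch.xpu, ...
    device = backend.current_device()   # same call the tutorial now makes
    stream = backend.Stream()           # backend-specific stream object
    backend.set_stream(stream)          # route subsequent work to that stream
    print(f"running on {DEVICE}:{device}")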
14 changes: 8 additions & 6 deletions python/tutorials/03-matrix-multiplication.py
@@ -154,6 +154,8 @@
import triton
import triton.language as tl

DEVICE = "cuda"


def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
@@ -355,8 +357,8 @@ def matmul(a, b, activation=""):
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
-a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
+b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
@@ -373,8 +375,8 @@ def matmul(a, b, activation=""):
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
torch.manual_seed(0)
a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
a = a.to(torch.float8_e5m2)
# pre-transpose b for efficiency.
b = b.T
@@ -423,8 +425,8 @@ def matmul(a, b, activation=""):

@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
-a = torch.randn((M, K), device='cuda', dtype=torch.float16)
-b = torch.randn((K, N), device='cuda', dtype=torch.float16)
+a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
+b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
if TORCH_HAS_FP8 and fp8_inputs:
a = a.to(torch.float8_e5m2)
b = b.T
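The matmul tutorial already queries the active Triton backend via is_cuda(). A hedged sketch of how the DEVICE string could be derived from that same query instead of being hardcoded; the backend-to-device mapping below is an illustrative assumption, not part of this commit:

import triton

# Query the backend the active Triton driver targets, as is_cuda()/is_hip() do.
_backend = triton.runtime.driver.active.get_current_target().backend
# Map it to a torch device string; ROCm ("hip") builds of torch expose GPUs
# through the "cuda" device type.
DEVICE = {"cuda": "cuda", "hip": "cuda", "xpu": "xpu"}.get(_backend, "cpu")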
8 changes: 5 additions & 3 deletions python/tutorials/04-low-memory-dropout.py
@@ -38,6 +38,8 @@
import triton
import triton.language as tl

DEVICE = "cuda"


@triton.jit
def _dropout(
@@ -71,10 +73,10 @@ def dropout(x, x_keep, p):


# Input tensor
-x = torch.randn(size=(10, )).cuda()
+x = getattr(torch.randn(size=(10, )), DEVICE)()
# Dropout mask
p = 0.5
-x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
+x_keep = getattr((torch.rand(size=(10, )) > p).to(torch.int32), DEVICE)()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
@@ -138,7 +140,7 @@ def seeded_dropout(x, p, seed):
return output


-x = torch.randn(size=(10, )).cuda()
+x = getattr(torch.randn(size=(10, )), DEVICE)()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
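The dropout tutorial moves host tensors by looking up the device-named method dynamically. A minimal sketch of what that call resolves to, with the equivalent .to(DEVICE) form shown only for comparison (the comparison line is an illustration, not part of the commit):

import torch

DEVICE = "cuda"

x = torch.randn(size=(10, ))
x_dev = getattr(x, DEVICE)()   # resolves to x.cuda() when DEVICE == "cuda"
x_alt = x.to(DEVICE)           # equivalent move that accepts any device string
assert x_dev.device == x_alt.device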
