[AMD][BACKEND] Switch to code object v5 #5005

Draft: wants to merge 7 commits into main
4 changes: 2 additions & 2 deletions .github/workflows/integration-tests.yml
@@ -308,7 +308,7 @@ jobs:
runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}}
name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}})
container:
image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4
image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
@@ -382,7 +382,7 @@ jobs:
id: amd-install-triton
run: |
echo "PATH is '$PATH'"
pip uninstall -y triton
pip uninstall -y triton pytorch-triton-rocm
cd python
pip install -v -e '.[tests]'
- name: Clean up after an unsuccessful build
4 changes: 2 additions & 2 deletions .github/workflows/integration-tests.yml.in
@@ -357,7 +357,7 @@ jobs:
name: Integration-Tests (${{matrix.runner[1] == 'gfx90a' && 'mi210' || 'mi300x'}})

container:
image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4
image: rocmshared/pytorch:rocm6.2.2_ubuntu22.04_py3.10_pytorch_2.5.1_asan
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root

steps:
@@ -384,7 +384,7 @@ jobs:
id: amd-install-triton
run: |
echo "PATH is '$PATH'"
pip uninstall -y triton
pip uninstall -y triton pytorch-triton-rocm
cd python
pip install -v -e '.[tests]'

2 changes: 1 addition & 1 deletion README.md
@@ -234,5 +234,5 @@ Supported Platforms:
Supported Hardware:

- NVIDIA GPUs (Compute Capability 8.0+)
- AMD GPUs (ROCm 5.2+)
- AMD GPUs (ROCm 6.2+)
Collaborator: So ROCm 5.x is no longer supported?

Contributor (author): Due to the switch to code object v5, test_optimize_thread_locality and some device_print tests (including the new test) segfault on ROCm 6.1 and 6.0; I haven't tested 5.x versions.

Collaborator: I think we need to chat a bit regarding when to land this then, given it's a breaking change.

- Under development: CPUs
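
Since the support discussion above hinges on which code object version a build actually produces, here is a minimal sketch (not part of this PR) for checking the ABI-version byte of a compiled kernel's ELF binary. The `asm["hsaco"]` key and the 2 → v4 / 3 → v5 mapping are assumptions based on the AMD backend and the LLVM AMDGPU ELF documentation.

```python
# Minimal sketch (not from this PR): read the ELF e_ident[EI_ABIVERSION] byte
# of a compiled Triton kernel to see which AMD code object version it targets.
# Assumes the AMD backend exposes the final binary as compiled_kernel.asm["hsaco"];
# the 2 -> v4, 3 -> v5 mapping follows the LLVM AMDGPU ELF documentation.
def hsaco_code_object_version(hsaco: bytes) -> int:
    assert hsaco[:4] == b"\x7fELF", "not an ELF code object"
    ei_abiversion = hsaco[8]  # e_ident[EI_ABIVERSION]
    return {2: 4, 3: 5}.get(ei_abiversion, ei_abiversion)


# Hypothetical usage: h = some_kernel[(1, )](...); print(hsaco_code_object_version(h.asm["hsaco"]))
```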
18 changes: 16 additions & 2 deletions python/test/unit/language/print_helper.py
@@ -90,9 +90,17 @@ def kernel_print_pointer(X, Y, BLOCK: tl.constexpr):
tl.device_print("ptr ", X + tl.arange(0, BLOCK))


@triton.jit
def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.constexpr):
off_x = tl.arange(0, BLOCK_SIZE_X)
off_y = tl.arange(0, BLOCK_SIZE_Y)
x = tl.load(X + off_x[:, None] * BLOCK_SIZE_Y + off_y[None, :])
tl.device_print("", x)


def test_print(func: str, data_type: str, device: str):
N = 128 # This value should match with test_print in test_subprocess.py.
# TODO(antiagainst): Currently the warp count is chosen to make sure wedon't have multiple
# TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple
# threads printing duplicated messages due to broadcasting. Improve print op lowering logic
# to filter out duplicated data range.
num_warps = N // get_current_target_warp_size()
@@ -128,12 +136,18 @@ def test_print(func: str, data_type: str, device: str):
kernel_device_print_hex[(1, )](x, y, num_warps=num_warps, BLOCK=N)
elif func == "device_print_pointer":
kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N)
elif func == "device_print_2d_tensor":
BLOCK_SIZE_X = num_warps
BLOCK_SIZE_Y = get_current_target_warp_size()
x_2d_tensor = x.reshape((BLOCK_SIZE_X, BLOCK_SIZE_Y))
kernel_print_2d_tensor[(1, )](x_2d_tensor, y, num_warps=num_warps, BLOCK_SIZE_X=BLOCK_SIZE_X,
BLOCK_SIZE_Y=BLOCK_SIZE_Y)
else:
assert f"Unknown kernel: {func}"

if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args" and \
func != "device_print_pointer" and func != "device_print_scalar":
func != "device_print_pointer" and func != "device_print_scalar" and func != "device_print_2d_tensor":
assert_close(y, x)

# Wait until driver complete all the jobs for the device_print, especially test_subprocess
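
For reference, a standalone sketch of what the new 2D-tensor device_print exercise does outside the test harness; the block shapes and num_warps choice are illustrative rather than taken from the helper, and the exact output whitespace is an assumption based on the expectations added in test_subprocess.py below.

```python
# Standalone sketch (not part of this PR): print a 2D tensor from a kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def print_2d(X, BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr):
    off_x = tl.arange(0, BLOCK_X)
    off_y = tl.arange(0, BLOCK_Y)
    x = tl.load(X + off_x[:, None] * BLOCK_Y + off_y[None, :])
    tl.device_print("", x)  # one "pid (0, 0, 0) idx (x, y): value" line per element


x = torch.arange(128, dtype=torch.int32, device="cuda").reshape(2, 64)
print_2d[(1, )](x, BLOCK_X=2, BLOCK_Y=64, num_warps=2)
torch.cuda.synchronize()  # flush device-side printf before the host exits
```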
13 changes: 5 additions & 8 deletions python/test/unit/language/test_core.py
@@ -2513,21 +2513,18 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
def test_optimize_thread_locality(op, BLOCK_N, N, num_pid_n, device):

@triton.jit
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.constexpr):
def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_n = tl.num_programs(1)
local = INITIALIZE_PATCH
off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), NUM_PID_N):
for start_n in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n):
off_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * N + off_n[None, :]
x = tl.load(Xs)
local = ACCUMULATE_PATCH
tl.store(Y + off_m * NUM_PID_N + pid_n, local)
# the following segfaults AMD backend following #3492
# really unclear why; the llvm-ir and kernel arguments are
# identical !
# tl.store(Y + off_m * tl.num_programs(1) + pid_n, local)
tl.store(Y + off_m * num_pid_n + pid_n, local)

initialize_patch = {
'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)',
@@ -2549,7 +2546,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
BLOCK_M = 32
x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device)
y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device)
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n)
h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N)
if not is_interpreter():
assert h.asm['ttgir'].count(
'"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"
10 changes: 10 additions & 0 deletions python/test/unit/language/test_subprocess.py
@@ -4,6 +4,8 @@
import sys
from collections import Counter

import triton

import pytest

dir_path = os.path.dirname(os.path.realpath(__file__))
@@ -35,6 +37,7 @@ def is_interpreter():
("device_print_pointer", "int32"),
("device_print_negative", "int32"),
("device_print_uint", "uint32"),
("device_print_2d_tensor", "int32"),
])
def test_print(func_type: str, data_type: str, device: str):
proc = subprocess.run(
@@ -101,6 +104,13 @@ def test_print(func_type: str, data_type: str, device: str):
elif func_type == "device_print_pointer":
for i in range(N):
expected_lines[f"pid (0, 0, 0) idx ({i:3}) ptr: 0x"] = 1
elif func_type == "device_print_2d_tensor":
warp_size = triton.runtime.driver.active.get_current_target().warp_size
x_dim = N // warp_size
y_dim = warp_size
for x in range(x_dim):
for y in range(y_dim):
expected_lines[f"pid (0, 0, 0) idx ({x}, {y:2}): {(x * y_dim + y)}"] = 1

actual_lines = Counter()
for line in outs:
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
@@ -312,7 +312,7 @@ def make_llir(src, metadata, options):
# Set various control constants on the LLVM module so that device
# libraries can resolve references to them.
amd.set_isa_version(llvm_mod, options.arch)
amd.set_abi_version(llvm_mod, 400)
amd.set_abi_version(llvm_mod, 500)
amd.set_bool_control_constant(llvm_mod, "__oclc_finite_only_opt", False)
amd.set_bool_control_constant(llvm_mod, "__oclc_correctly_rounded_sqrt32", True)
amd.set_bool_control_constant(llvm_mod, "__oclc_unsafe_math_opt", False)
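
On the compiler.py change: per the surrounding comment, the value passed to set_abi_version becomes a control constant (presumably __oclc_ABI_version) that the ROCm device libraries resolve against, with 400 selecting code object v4 semantics and 500 selecting v5. A trivial sketch of that mapping, with a hypothetical helper name not present in the PR:

```python
# Hypothetical helper (not in the PR): map a target code object version to the
# ABI control constant value that make_llir() passes to amd.set_abi_version().
_ABI_VERSION_FOR_COV = {4: 400, 5: 500}


def abi_version_for(code_object_version: int) -> int:
    if code_object_version not in _ABI_VERSION_FOR_COV:
        raise ValueError(f"unsupported code object version: {code_object_version}")
    return _ABI_VERSION_FOR_COV[code_object_version]
```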
@@ -74,6 +74,7 @@ std::shared_ptr<Metric>
convertActivityToMetric(const roctracer_record_t *activity) {
std::shared_ptr<Metric> metric;
switch (activity->kind) {
case kHipVdiCommandTask:
case kHipVdiCommandKernel: {
if (activity->begin_ns < activity->end_ns) {
metric = std::make_shared<KernelMetric>(
@@ -135,7 +136,7 @@ void processActivity(RoctracerProfiler::CorrIdToExternIdMap &corrIdToExternId,
const roctracer_record_t *record, bool isAPI,
bool isGraph) {
switch (record->kind) {
case 0x11F1: // Task - kernel enqueued by graph launch
case kHipVdiCommandTask:
case kHipVdiCommandKernel: {
processActivityKernel(corrIdToExternId, externId, dataSet, record, isAPI,
isGraph);
@@ -169,6 +170,7 @@ std::pair<bool, bool> matchKernelCbId(uint32_t cbId) {
case HIP_API_ID_hipModuleLaunchCooperativeKernel:
case HIP_API_ID_hipModuleLaunchCooperativeKernelMultiDevice:
case HIP_API_ID_hipGraphExecDestroy:
case HIP_API_ID_hipGraphInstantiateWithFlags:
case HIP_API_ID_hipGraphInstantiate: {
isRuntimeApi = true;
break;
@@ -300,6 +302,13 @@ void RoctracerProfiler::RoctracerProfilerPimpl::apiCallback(
pImpl->StreamToCaptureCount[Stream]++;
break;
}
case HIP_API_ID_hipGraphInstantiateWithFlags: {
hipGraph_t Graph = data->args.hipGraphInstantiateWithFlags.graph;
hipGraphExec_t GraphExec =
*(data->args.hipGraphInstantiateWithFlags.pGraphExec);
pImpl->GraphExecToGraph[GraphExec] = Graph;
break;
}
case HIP_API_ID_hipGraphInstantiate: {
hipGraph_t Graph = data->args.hipGraphInstantiate.graph;
hipGraphExec_t GraphExec = *(data->args.hipGraphInstantiate.pGraphExec);
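
On the profiler side, the new kHipVdiCommandTask and hipGraphInstantiateWithFlags handling is exercised by HIP graph launches. One way to trigger that path from Python is PyTorch's graph capture on ROCm; the claim that this lowers to hipGraphInstantiateWithFlags under the hood is an assumption, not verified here.

```python
# Sketch (not part of this PR): capture and replay a HIP graph from PyTorch on
# ROCm. Assumption: torch.cuda.CUDAGraph maps to HIP graphs and instantiates the
# executable via hipGraphInstantiateWithFlags, which the callback above now records.
import torch

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):   # kernels launched here are recorded, not executed
    y.copy_(x * 2.0)

g.replay()                  # launches the instantiated graph executable
torch.cuda.synchronize()
```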