Skip to content

Commit

Permalink
parametrize the arch and kernel options
Browse files Browse the repository at this point in the history
  • Loading branch information
ksimpson-work committed Nov 26, 2024
1 parent b5b6b98 commit 3ed20a8
Showing 1 changed file with 77 additions and 77 deletions.
154 changes: 77 additions & 77 deletions cuda_bindings/tests/test_nvjitlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,83 +75,83 @@ def test_invalid_arch_error():
nvjitlink.create(1, ["-arch=sm_XX"])


def test_create_and_destroy():
for option in ARCHITECTURES:
handle = nvjitlink.create(1, [f"-arch={option}"])
assert handle != 0
nvjitlink.destroy(handle)


def test_complete_empty():
for option in ARCHITECTURES:
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


def test_add_data():
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


def test_add_file(tmp_path):
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
file_path = tmp_path / "test_file.cubin"
file_path.write_bytes(ptx_bytes)
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


def test_get_error_log():
for option in ARCHITECTURES:
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.complete(handle)
log_size = nvjitlink.get_error_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_error_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)


def test_get_info_log():
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
log_size = nvjitlink.get_info_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_info_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)


def test_get_linked_cubin():
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
cubin_size = nvjitlink.get_linked_cubin_size(handle)
cubin = bytearray(cubin_size)
nvjitlink.get_linked_cubin(handle, cubin)
assert len(cubin) == cubin_size
nvjitlink.destroy(handle)


def test_get_linked_ptx():
for option in ARCHITECTURES:
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, empty_kernel_ltoir, len(empty_kernel_ltoir), "test_data")
nvjitlink.complete(handle)
ptx_size = nvjitlink.get_linked_ptx_size(handle)
ptx = bytearray(ptx_size)
nvjitlink.get_linked_ptx(handle, ptx)
assert len(ptx) == ptx_size
nvjitlink.destroy(handle)
@pytest.mark.parametrize("option", ARCHITECTURES)
def test_create_and_destroy(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
assert handle != 0
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_complete_empty(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_data(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_file(option, ptx_bytes, tmp_path):
handle = nvjitlink.create(1, [f"-arch={option}"])
file_path = tmp_path / "test_file.cubin"
file_path.write_bytes(ptx_bytes)
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
nvjitlink.complete(handle)
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_get_error_log(option):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.complete(handle)
log_size = nvjitlink.get_error_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_error_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_get_info_log(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
log_size = nvjitlink.get_info_log_size(handle)
log = bytearray(log_size)
nvjitlink.get_info_log(handle, log)
assert len(log) == log_size
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_get_linked_cubin(option, ptx_bytes):
handle = nvjitlink.create(1, [f"-arch={option}"])
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
nvjitlink.complete(handle)
cubin_size = nvjitlink.get_linked_cubin_size(handle)
cubin = bytearray(cubin_size)
nvjitlink.get_linked_cubin(handle, cubin)
assert len(cubin) == cubin_size
nvjitlink.destroy(handle)


@pytest.mark.parametrize("option", ARCHITECTURES)
def test_get_linked_ptx(option):
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, empty_kernel_ltoir, len(empty_kernel_ltoir), "test_data")
nvjitlink.complete(handle)
ptx_size = nvjitlink.get_linked_ptx_size(handle)
ptx = bytearray(ptx_size)
nvjitlink.get_linked_ptx(handle, ptx)
assert len(ptx) == ptx_size
nvjitlink.destroy(handle)


def test_package_version():
Expand Down

0 comments on commit 3ed20a8

Please sign in to comment.