Skip to content

Commit 3ed20a8

Browse files
committed
parametrize the arch and kernel options
1 parent b5b6b98 commit 3ed20a8

File tree

1 file changed

+77
-77
lines changed

1 file changed

+77
-77
lines changed

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 77 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -75,83 +75,83 @@ def test_invalid_arch_error():
7575
nvjitlink.create(1, ["-arch=sm_XX"])
7676

7777

78-
def test_create_and_destroy():
79-
for option in ARCHITECTURES:
80-
handle = nvjitlink.create(1, [f"-arch={option}"])
81-
assert handle != 0
82-
nvjitlink.destroy(handle)
83-
84-
85-
def test_complete_empty():
86-
for option in ARCHITECTURES:
87-
handle = nvjitlink.create(1, [f"-arch={option}"])
88-
nvjitlink.complete(handle)
89-
nvjitlink.destroy(handle)
90-
91-
92-
def test_add_data():
93-
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
94-
handle = nvjitlink.create(1, [f"-arch={option}"])
95-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
96-
nvjitlink.complete(handle)
97-
nvjitlink.destroy(handle)
98-
99-
100-
def test_add_file(tmp_path):
101-
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
102-
handle = nvjitlink.create(1, [f"-arch={option}"])
103-
file_path = tmp_path / "test_file.cubin"
104-
file_path.write_bytes(ptx_bytes)
105-
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
106-
nvjitlink.complete(handle)
107-
nvjitlink.destroy(handle)
108-
109-
110-
def test_get_error_log():
111-
for option in ARCHITECTURES:
112-
handle = nvjitlink.create(1, [f"-arch={option}"])
113-
nvjitlink.complete(handle)
114-
log_size = nvjitlink.get_error_log_size(handle)
115-
log = bytearray(log_size)
116-
nvjitlink.get_error_log(handle, log)
117-
assert len(log) == log_size
118-
nvjitlink.destroy(handle)
119-
120-
121-
def test_get_info_log():
122-
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
123-
handle = nvjitlink.create(1, [f"-arch={option}"])
124-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
125-
nvjitlink.complete(handle)
126-
log_size = nvjitlink.get_info_log_size(handle)
127-
log = bytearray(log_size)
128-
nvjitlink.get_info_log(handle, log)
129-
assert len(log) == log_size
130-
nvjitlink.destroy(handle)
131-
132-
133-
def test_get_linked_cubin():
134-
for option, ptx_bytes in zip(ARCHITECTURES, ptx_kernel_bytes):
135-
handle = nvjitlink.create(1, [f"-arch={option}"])
136-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
137-
nvjitlink.complete(handle)
138-
cubin_size = nvjitlink.get_linked_cubin_size(handle)
139-
cubin = bytearray(cubin_size)
140-
nvjitlink.get_linked_cubin(handle, cubin)
141-
assert len(cubin) == cubin_size
142-
nvjitlink.destroy(handle)
143-
144-
145-
def test_get_linked_ptx():
146-
for option in ARCHITECTURES:
147-
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
148-
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, empty_kernel_ltoir, len(empty_kernel_ltoir), "test_data")
149-
nvjitlink.complete(handle)
150-
ptx_size = nvjitlink.get_linked_ptx_size(handle)
151-
ptx = bytearray(ptx_size)
152-
nvjitlink.get_linked_ptx(handle, ptx)
153-
assert len(ptx) == ptx_size
154-
nvjitlink.destroy(handle)
78+
@pytest.mark.parametrize("option", ARCHITECTURES)
79+
def test_create_and_destroy(option):
80+
handle = nvjitlink.create(1, [f"-arch={option}"])
81+
assert handle != 0
82+
nvjitlink.destroy(handle)
83+
84+
85+
@pytest.mark.parametrize("option", ARCHITECTURES)
86+
def test_complete_empty(option):
87+
handle = nvjitlink.create(1, [f"-arch={option}"])
88+
nvjitlink.complete(handle)
89+
nvjitlink.destroy(handle)
90+
91+
92+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
93+
def test_add_data(option, ptx_bytes):
94+
handle = nvjitlink.create(1, [f"-arch={option}"])
95+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
96+
nvjitlink.complete(handle)
97+
nvjitlink.destroy(handle)
98+
99+
100+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
101+
def test_add_file(option, ptx_bytes, tmp_path):
102+
handle = nvjitlink.create(1, [f"-arch={option}"])
103+
file_path = tmp_path / "test_file.cubin"
104+
file_path.write_bytes(ptx_bytes)
105+
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
106+
nvjitlink.complete(handle)
107+
nvjitlink.destroy(handle)
108+
109+
110+
@pytest.mark.parametrize("option", ARCHITECTURES)
111+
def test_get_error_log(option):
112+
handle = nvjitlink.create(1, [f"-arch={option}"])
113+
nvjitlink.complete(handle)
114+
log_size = nvjitlink.get_error_log_size(handle)
115+
log = bytearray(log_size)
116+
nvjitlink.get_error_log(handle, log)
117+
assert len(log) == log_size
118+
nvjitlink.destroy(handle)
119+
120+
121+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
122+
def test_get_info_log(option, ptx_bytes):
123+
handle = nvjitlink.create(1, [f"-arch={option}"])
124+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
125+
nvjitlink.complete(handle)
126+
log_size = nvjitlink.get_info_log_size(handle)
127+
log = bytearray(log_size)
128+
nvjitlink.get_info_log(handle, log)
129+
assert len(log) == log_size
130+
nvjitlink.destroy(handle)
131+
132+
133+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
134+
def test_get_linked_cubin(option, ptx_bytes):
135+
handle = nvjitlink.create(1, [f"-arch={option}"])
136+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
137+
nvjitlink.complete(handle)
138+
cubin_size = nvjitlink.get_linked_cubin_size(handle)
139+
cubin = bytearray(cubin_size)
140+
nvjitlink.get_linked_cubin(handle, cubin)
141+
assert len(cubin) == cubin_size
142+
nvjitlink.destroy(handle)
143+
144+
145+
@pytest.mark.parametrize("option", ARCHITECTURES)
146+
def test_get_linked_ptx(option):
147+
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
148+
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, empty_kernel_ltoir, len(empty_kernel_ltoir), "test_data")
149+
nvjitlink.complete(handle)
150+
ptx_size = nvjitlink.get_linked_ptx_size(handle)
151+
ptx = bytearray(ptx_size)
152+
nvjitlink.get_linked_ptx(handle, ptx)
153+
assert len(ptx) == ptx_size
154+
nvjitlink.destroy(handle)
155155

156156

157157
def test_package_version():

0 commit comments

Comments
 (0)