Skip to content

Commit d663afb

Browse files
committed
address review comments
1 parent c4dccff commit d663afb

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

cuda_core/tests/test_memory.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88

9-
from cuda import cuda
9+
try:
10+
from cuda.bindings import driver
11+
except ImportError:
12+
from cuda import cuda as driver
13+
1014
from cuda.core.experimental import Device
1115
from cuda.core.experimental._memory import Buffer, MemoryResource
1216
from cuda.core.experimental._utils import handle_return
@@ -65,11 +69,11 @@ def __init__(self, device):
6569
self.device = device
6670

6771
def allocate(self, size, stream=None) -> Buffer:
68-
ptr = handle_return(cuda.cuMemAllocManaged(size, cuda.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value))
72+
ptr = handle_return(driver.cuMemAllocManaged(size, driver.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value))
6973
return Buffer(ptr=ptr, size=size, mr=self)
7074

7175
def deallocate(self, ptr, size, stream=None):
72-
handle_return(cuda.cuMemFree(ptr))
76+
handle_return(driver.cuMemFree(ptr))
7377

7478
@property
7579
def is_device_accessible(self) -> bool:
@@ -88,11 +92,11 @@ def __init__(self, device):
8892
self.device = device
8993

9094
def allocate(self, size, stream=None) -> Buffer:
91-
ptr = handle_return(cuda.cuMemAllocHost(size))
95+
ptr = handle_return(driver.cuMemAllocHost(size))
9296
return Buffer(ptr=ptr, size=size, mr=self)
9397

9498
def deallocate(self, ptr, size, stream=None):
95-
handle_return(cuda.cuMemFreeHost(ptr))
99+
handle_return(driver.cuMemFreeHost(ptr))
96100

97101
@property
98102
def is_device_accessible(self) -> bool:

0 commit comments

Comments
 (0)