66# this software and related documentation outside the terms of the EULA
77# is strictly prohibited.
88
9- from cuda import cuda
9+ try :
10+ from cuda .bindings import driver
11+ except ImportError :
12+ from cuda import cuda as driver
13+
1014from cuda .core .experimental import Device
1115from cuda .core .experimental ._memory import Buffer , MemoryResource
1216from cuda .core .experimental ._utils import handle_return
@@ -65,11 +69,11 @@ def __init__(self, device):
6569 self .device = device
6670
6771 def allocate (self , size , stream = None ) -> Buffer :
68- ptr = handle_return (cuda .cuMemAllocManaged (size , cuda .CUmemAttach_flags .CU_MEM_ATTACH_GLOBAL .value ))
72+ ptr = handle_return (driver .cuMemAllocManaged (size , driver .CUmemAttach_flags .CU_MEM_ATTACH_GLOBAL .value ))
6973 return Buffer (ptr = ptr , size = size , mr = self )
7074
7175 def deallocate (self , ptr , size , stream = None ):
72- handle_return (cuda .cuMemFree (ptr ))
76+ handle_return (driver .cuMemFree (ptr ))
7377
7478 @property
7579 def is_device_accessible (self ) -> bool :
@@ -88,11 +92,11 @@ def __init__(self, device):
8892 self .device = device
8993
9094 def allocate (self , size , stream = None ) -> Buffer :
91- ptr = handle_return (cuda .cuMemAllocHost (size ))
95+ ptr = handle_return (driver .cuMemAllocHost (size ))
9296 return Buffer (ptr = ptr , size = size , mr = self )
9397
9498 def deallocate (self , ptr , size , stream = None ):
95- handle_return (cuda .cuMemFreeHost (ptr ))
99+ handle_return (driver .cuMemFreeHost (ptr ))
96100
97101 @property
98102 def is_device_accessible (self ) -> bool :
0 commit comments