diff --git a/dpctl-capi/source/dpctl_sycl_usm_interface.cpp b/dpctl-capi/source/dpctl_sycl_usm_interface.cpp index 66a796db7f..be1becb1f1 100644 --- a/dpctl-capi/source/dpctl_sycl_usm_interface.cpp +++ b/dpctl-capi/source/dpctl_sycl_usm_interface.cpp @@ -44,9 +44,18 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(void, DPCTLSyclUSMRef) __dpctl_give DPCTLSyclUSMRef DPCTLmalloc_shared(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { - auto Q = unwrap(QRef); - auto Ptr = malloc_shared(size, *Q); - return wrap(Ptr); + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + try { + auto Q = unwrap(QRef); + auto Ptr = malloc_shared(size, *Q); + return wrap(Ptr); + } catch (feature_not_supported const &fns) { + std::cerr << fns.what() << '\n'; + return nullptr; + } } __dpctl_give DPCTLSyclUSMRef @@ -54,14 +63,29 @@ DPCTLaligned_alloc_shared(size_t alignment, size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { - auto Q = unwrap(QRef); - auto Ptr = aligned_alloc_shared(alignment, size, *Q); - return wrap(Ptr); + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + try { + auto Q = unwrap(QRef); + auto Ptr = aligned_alloc_shared(alignment, size, *Q); + return wrap(Ptr); + } catch (feature_not_supported const &fns) { + std::cerr << fns.what() << '\n'; + return nullptr; + } } __dpctl_give DPCTLSyclUSMRef DPCTLmalloc_host(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + // SYCL 2020 spec: for devices without aspect::usm_host_allocations: + // undefined behavior auto Q = unwrap(QRef); auto Ptr = malloc_host(size, *Q); return wrap(Ptr); @@ -72,6 +96,12 @@ DPCTLaligned_alloc_host(size_t alignment, size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + // SYCL 2020 spec: for devices without aspect::usm_host_allocations: + // undefined behavior auto Q = 
unwrap(QRef); auto Ptr = aligned_alloc_host(alignment, size, *Q); return wrap(Ptr); @@ -80,9 +110,18 @@ DPCTLaligned_alloc_host(size_t alignment, __dpctl_give DPCTLSyclUSMRef DPCTLmalloc_device(size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { - auto Q = unwrap(QRef); - auto Ptr = malloc_device(size, *Q); - return wrap(Ptr); + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + try { + auto Q = unwrap(QRef); + auto Ptr = malloc_device(size, *Q); + return wrap(Ptr); + } catch (feature_not_supported const &fns) { + std::cerr << fns.what() << '\n'; + return nullptr; + } } __dpctl_give DPCTLSyclUSMRef @@ -90,14 +129,31 @@ DPCTLaligned_alloc_device(size_t alignment, size_t size, __dpctl_keep const DPCTLSyclQueueRef QRef) { - auto Q = unwrap(QRef); - auto Ptr = aligned_alloc_device(alignment, size, *Q); - return wrap(Ptr); + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return nullptr; + } + try { + auto Q = unwrap(QRef); + auto Ptr = aligned_alloc_device(alignment, size, *Q); + return wrap(Ptr); + } catch (feature_not_supported const &fns) { + std::cerr << fns.what() << '\n'; + return nullptr; + } } void DPCTLfree_with_queue(__dpctl_take DPCTLSyclUSMRef MRef, __dpctl_keep const DPCTLSyclQueueRef QRef) { + if (!QRef) { + std::cerr << "Input QRef is nullptr\n"; + return; + } + if (!MRef) { + std::cerr << "Input MRef is nullptr, nothing to free\n"; + return; + } auto Ptr = unwrap(MRef); auto Q = unwrap(QRef); free(Ptr, *Q); @@ -106,6 +162,14 @@ void DPCTLfree_with_queue(__dpctl_take DPCTLSyclUSMRef MRef, void DPCTLfree_with_context(__dpctl_take DPCTLSyclUSMRef MRef, __dpctl_keep const DPCTLSyclContextRef CRef) { + if (!CRef) { + std::cerr << "Input CRef is nullptr\n"; + return; + } + if (!MRef) { + std::cerr << "Input MRef is nullptr, nothing to free\n"; + return; + } auto Ptr = unwrap(MRef); auto C = unwrap(CRef); free(Ptr, *C); @@ -114,6 +178,14 @@ void DPCTLfree_with_context(__dpctl_take DPCTLSyclUSMRef MRef, const char 
*DPCTLUSM_GetPointerType(__dpctl_keep const DPCTLSyclUSMRef MRef, __dpctl_keep const DPCTLSyclContextRef CRef) { + if (!CRef) { + std::cerr << "Input CRef is nullptr\n"; + return "unknown"; + } + if (!MRef) { + std::cerr << "Input MRef is nullptr\n"; + return "unknown"; + } auto Ptr = unwrap(MRef); auto C = unwrap(CRef); @@ -134,6 +206,15 @@ DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(__dpctl_keep const DPCTLSyclUSMRef MRef, __dpctl_keep const DPCTLSyclContextRef CRef) { + if (!CRef) { + std::cerr << "Input CRef is nullptr\n"; + return nullptr; + } + if (!MRef) { + std::cerr << "Input MRef is nullptr\n"; + return nullptr; + } + auto Ptr = unwrap(MRef); auto C = unwrap(CRef); diff --git a/dpctl/memory/_memory.pxd b/dpctl/memory/_memory.pxd index 1a2a573e66..1e9b796ff7 100644 --- a/dpctl/memory/_memory.pxd +++ b/dpctl/memory/_memory.pxd @@ -55,6 +55,13 @@ cdef public class _Memory [object Py_MemoryObject, type Py_MemoryType]: DPCTLSyclUSMRef p, SyclContext ctx) @staticmethod cdef public bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx) + @staticmethod + cdef public object create_from_usm_pointer_size_qref( + DPCTLSyclUSMRef USMRef, + Py_ssize_t nbytes, + DPCTLSyclQueueRef QRef, + object memory_owner=* + ) cdef public class MemoryUSMShared(_Memory) [object PyMemoryUSMSharedObject, diff --git a/dpctl/memory/_memory.pyx b/dpctl/memory/_memory.pyx index 44211d155e..3838651f08 100644 --- a/dpctl/memory/_memory.pyx +++ b/dpctl/memory/_memory.pyx @@ -32,6 +32,7 @@ from dpctl._backend cimport ( # noqa: E211 DPCTLaligned_alloc_device, DPCTLaligned_alloc_host, DPCTLaligned_alloc_shared, + DPCTLContext_Delete, DPCTLfree_with_queue, DPCTLmalloc_device, DPCTLmalloc_host, @@ -39,9 +40,11 @@ from dpctl._backend cimport ( # noqa: E211 DPCTLQueue_Copy, DPCTLQueue_Create, DPCTLQueue_Delete, + DPCTLQueue_GetContext, DPCTLQueue_Memcpy, DPCTLSyclContextRef, DPCTLSyclDeviceRef, + DPCTLSyclUSMRef, DPCTLUSM_GetPointerDevice, DPCTLUSM_GetPointerType, ) @@ -106,7 +109,8 @@ def 
_to_memory(unsigned char[::1] b, str usm_kind): else: raise ValueError( "Unrecognized usm_kind={} stored in the " - "pickle".format(usm_kind)) + "pickle".format(usm_kind) + ) res.copy_from_host(b) return res @@ -214,7 +218,7 @@ cdef class _Memory: self.memory_ptr, ctx.get_context_ref() ) if kind == b'device': - raise ValueError('USM Device memory is not host accessible') + raise ValueError("USM Device memory is not host accessible") buffer.buf = self.memory_ptr buffer.format = 'B' # byte buffer.internal = NULL # see References @@ -431,6 +435,65 @@ cdef class _Memory: return usm_type + @staticmethod + cdef object create_from_usm_pointer_size_qref( + DPCTLSyclUSMRef USMRef, Py_ssize_t nbytes, + DPCTLSyclQueueRef QRef, object memory_owner=None + ): + r""" + Create appropriate `MemoryUSM*` object from pre-allocated + USM memory bound to SYCL context in the reference SYCL queue. + + Memory will be freed by `MemoryUSM*` object for default + value of memory_owner keyword. The non-default value should + be an object whose dealloc slot frees the memory. + + The object may not be a no-op dummy Python object to + delay freeing the memory until later times. 
+ """ + cdef const char *usm_type + cdef DPCTLSyclContextRef CRef = NULL + cdef DPCTLSyclQueueRef QRef_copy = NULL + cdef _Memory _mem + cdef object mem_ty + if nbytes <= 0: + raise ValueError("Number of bytes must be positive") + if (QRef is NULL): + raise TypeError("Argument DPCTLSyclQueueRef is NULL") + CRef = DPCTLQueue_GetContext(QRef) + if (CRef is NULL): + raise ValueError("Could not retrieve context from QRef") + usm_type = DPCTLUSM_GetPointerType(USMRef, CRef) + DPCTLContext_Delete(CRef) + if usm_type == b"shared": + mem_ty = MemoryUSMShared + elif usm_type == b"device": + mem_ty = MemoryUSMDevice + elif usm_type == b"host": + mem_ty = MemoryUSMHost + else: + raise ValueError( + "Argument pointer is not bound to " + "context in the given queue" + ) + res = _Memory.__new__(_Memory) + _mem = <_Memory> res + _mem._cinit_empty() + _mem.memory_ptr = USMRef + _mem.nbytes = nbytes + QRef_copy = DPCTLQueue_Copy(QRef) + if QRef_copy is NULL: + raise ValueError("Referenced queue could not be copied.") + try: + _mem.queue = SyclQueue._create(QRef_copy) # consumes the copy + except dpctl.SyclQueueCreationError as sqce: + raise ValueError( + "SyclQueue object could not be created from " + "copy of referenced queue" + ) from sqce + _mem.refobj = memory_owner + return mem_ty(res) + cdef class MemoryUSMShared(_Memory): """