diff --git a/sycl/include/CL/sycl/usm/usm_allocator.hpp b/sycl/include/CL/sycl/usm/usm_allocator.hpp index f821175205c03..e24a37eeff0d4 100644 --- a/sycl/include/CL/sycl/usm/usm_allocator.hpp +++ b/sycl/include/CL/sycl/usm/usm_allocator.hpp @@ -56,16 +56,16 @@ class usm_allocator { /// object. /// \param Val is a value to initialize the newly constructed object. template < - usm::alloc AllocT = AllocKind, + typename... ArgsT, usm::alloc AllocT = AllocKind, typename std::enable_if::type = 0> - void construct(pointer Ptr, const_reference Val) { - new (Ptr) value_type(Val); + void construct(pointer Ptr, ArgsT &&... Args) { + new (Ptr) value_type(std::forward(Args)...); } template < - usm::alloc AllocT = AllocKind, + typename... ArgsT, usm::alloc AllocT = AllocKind, typename std::enable_if::type = 0> - void construct(pointer, const_reference) { + void construct(pointer, ArgsT &&...) { throw feature_not_supported( "Device pointers do not support construct on host", PI_INVALID_OPERATION); @@ -87,7 +87,9 @@ class usm_allocator { usm::alloc AllocT = AllocKind, typename std::enable_if::type = 0> void destroy(pointer) { - // This method must be a NOP for device pointers. + throw feature_not_supported( + "Device pointers do not support construct on host", + PI_INVALID_OPERATION); } /// Note:: AllocKind == alloc::device is not allowed. diff --git a/sycl/test/usm/allocator_vector.cpp b/sycl/test/usm/allocator_vector.cpp index 265c071e1cf0e..66b707715fa51 100644 --- a/sycl/test/usm/allocator_vector.cpp +++ b/sycl/test/usm/allocator_vector.cpp @@ -18,6 +18,7 @@ #include +#include #include using namespace cl::sycl; @@ -91,15 +92,14 @@ int main() { if (dev.get_info()) { usm_allocator alloc(ctxt, dev); - std::vector vec(alloc); - vec.resize(N); + auto AllocDeleter = [&](int *ptr) { alloc.deallocate(ptr, N); }; + std::unique_ptr mem(alloc.allocate(N), + AllocDeleter); - int *res = &vec[0]; - int *vals = &vec[0]; + int *vals = mem.get(); auto e0 = q.submit([=](handler &h) { h.single_task([=]() { - res[0] = 0; for (int i = 0; i < N; i++) { vals[i] = i; } @@ -110,7 +110,7 @@ int main() { h.depends_on(e0); h.single_task([=]() { for (int i = 1; i < N; i++) { - res[0] += vals[i]; + vals[0] += vals[i]; } }); }); @@ -119,7 +119,7 @@ int main() { int answer = (N * (N - 1)) / 2; int result; - q.memcpy(&result, res, sizeof(int)); + q.memcpy(&result, vals, sizeof(int)); q.wait(); if (result != answer)