Skip to content

Commit 8535b24

Browse files
Ilya Stepykinbader
authored andcommitted
[SYCL] Fix accessor construction from a buffer.
Allow accessor to be constructed from a buffer with a non-default allocator. Signed-off-by: Ilya Stepykin <ilya.stepykin@intel.com>
1 parent b63a96f commit 8535b24

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

sycl/include/CL/sycl/accessor.hpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -741,11 +741,11 @@ class accessor :
741741
using reference = DataT &;
742742
using const_reference = const DataT &;
743743

744-
template <int Dims = Dimensions>
745-
accessor(
746-
detail::enable_if_t<Dims == 0 && ((!IsPlaceH && IsHostBuf) ||
747-
(IsPlaceH && (IsGlobalBuf || IsConstantBuf))),
748-
buffer<DataT, 1>> &BufferRef)
744+
template <int Dims = Dimensions, typename AllocatorT,
745+
typename detail::enable_if_t<
746+
Dims == 0 && ((!IsPlaceH && IsHostBuf) ||
747+
(IsPlaceH && (IsGlobalBuf || IsConstantBuf)))>* = nullptr>
748+
accessor(buffer<DataT, 1, AllocatorT> &BufferRef)
749749
#ifdef __SYCL_DEVICE_ONLY__
750750
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange) {
751751
#else
@@ -762,11 +762,11 @@ class accessor :
762762
#endif
763763
}
764764

765-
template <int Dims = Dimensions>
766-
accessor(
767-
buffer<DataT, 1> &BufferRef,
768-
detail::enable_if_t<Dims == 0 && (!IsPlaceH && (IsGlobalBuf || IsConstantBuf)),
769-
handler> &CommandGroupHandler)
765+
template <int Dims = Dimensions, typename AllocatorT>
766+
accessor(buffer<DataT, 1, AllocatorT> &BufferRef,
767+
detail::enable_if_t<Dims == 0 &&
768+
(!IsPlaceH && (IsGlobalBuf || IsConstantBuf)),
769+
handler> &CommandGroupHandler)
770770
#ifdef __SYCL_DEVICE_ONLY__
771771
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange) {
772772
}
@@ -781,11 +781,12 @@ class accessor :
781781
}
782782
#endif
783783

784-
template <int Dims = Dimensions,
785-
typename = detail::enable_if_t<
784+
template <int Dims = Dimensions, typename AllocatorT,
785+
typename detail::enable_if_t<
786786
(Dims > 0) && ((!IsPlaceH && IsHostBuf) ||
787-
(IsPlaceH && (IsGlobalBuf || IsConstantBuf)))>>
788-
accessor(buffer<DataT, Dimensions> &BufferRef)
787+
(IsPlaceH && (IsGlobalBuf || IsConstantBuf)))>
788+
* = nullptr>
789+
accessor(buffer<DataT, Dimensions, AllocatorT> &BufferRef)
789790
#ifdef __SYCL_DEVICE_ONLY__
790791
: impl(id<Dimensions>(), BufferRef.get_range(), BufferRef.MemRange) {
791792
}
@@ -803,10 +804,11 @@ class accessor :
803804
}
804805
#endif
805806

806-
template <int Dims = Dimensions,
807+
template <int Dims = Dimensions, typename AllocatorT,
807808
typename = detail::enable_if_t<
808809
(Dims > 0) && (!IsPlaceH && (IsGlobalBuf || IsConstantBuf))>>
809-
accessor(buffer<DataT, Dimensions> &BufferRef, handler &CommandGroupHandler)
810+
accessor(buffer<DataT, Dimensions, AllocatorT> &BufferRef,
811+
handler &CommandGroupHandler)
810812
#ifdef __SYCL_DEVICE_ONLY__
811813
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange) {
812814
}
@@ -821,12 +823,12 @@ class accessor :
821823
}
822824
#endif
823825

824-
template <int Dims = Dimensions,
826+
template <int Dims = Dimensions, typename AllocatorT,
825827
typename = detail::enable_if_t<
826828
(Dims > 0) && ((!IsPlaceH && IsHostBuf) ||
827829
(IsPlaceH && (IsGlobalBuf || IsConstantBuf)))>>
828-
accessor(buffer<DataT, Dimensions> &BufferRef, range<Dimensions> AccessRange,
829-
id<Dimensions> AccessOffset = {})
830+
accessor(buffer<DataT, Dimensions, AllocatorT> &BufferRef,
831+
range<Dimensions> AccessRange, id<Dimensions> AccessOffset = {})
830832
#ifdef __SYCL_DEVICE_ONLY__
831833
: impl(AccessOffset, AccessRange, BufferRef.MemRange) {
832834
}
@@ -843,11 +845,12 @@ class accessor :
843845
}
844846
#endif
845847

846-
template <int Dims = Dimensions,
848+
template <int Dims = Dimensions, typename AllocatorT,
847849
typename = detail::enable_if_t<
848850
(Dims > 0) && (!IsPlaceH && (IsGlobalBuf || IsConstantBuf))>>
849-
accessor(buffer<DataT, Dimensions> &BufferRef, handler &CommandGroupHandler,
850-
range<Dimensions> AccessRange, id<Dimensions> AccessOffset = {})
851+
accessor(buffer<DataT, Dimensions, AllocatorT> &BufferRef,
852+
handler &CommandGroupHandler, range<Dimensions> AccessRange,
853+
id<Dimensions> AccessOffset = {})
851854
#ifdef __SYCL_DEVICE_ONLY__
852855
: impl(AccessOffset, AccessRange, BufferRef.MemRange) {
853856
}

sycl/test/basic_tests/accessor/accessor.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,4 +379,52 @@ int main() {
379379
return 1;
380380
}
381381
}
382+
383+
{
384+
// Call every available accessor's constructor to ensure that they work with
385+
// a buffer with a non-default allocator.
386+
int data[] = {1, 2, 3};
387+
388+
using allocator_type = std::allocator<int>;
389+
390+
sycl::buffer<int, 1, allocator_type> buf1(&data[0], sycl::range<1>(1),
391+
allocator_type{});
392+
sycl::buffer<int, 1, allocator_type> buf2(&data[1], sycl::range<1>(1),
393+
allocator_type{});
394+
sycl::buffer<int, 1, allocator_type> buf3(&data[2], sycl::range<1>(1),
395+
allocator_type{});
396+
397+
sycl::queue queue;
398+
queue.submit([&](sycl::handler &cgh) {
399+
sycl::accessor<int, 0, sycl::access::mode::read_write,
400+
sycl::access::target::global_buffer>
401+
acc1(buf1, cgh);
402+
sycl::accessor<int, 1, sycl::access::mode::read_write,
403+
sycl::access::target::global_buffer>
404+
acc2(buf2, cgh);
405+
sycl::accessor<int, 1, sycl::access::mode::read_write,
406+
sycl::access::target::global_buffer>
407+
acc3(buf3, cgh, sycl::range<1>(1));
408+
409+
cgh.single_task<class acc_alloc_buf>([=]() {
410+
acc1 *= 2;
411+
acc2[0] *= 2;
412+
acc3[0] *= 2;
413+
});
414+
});
415+
416+
sycl::accessor<int, 0, sycl::access::mode::read,
417+
sycl::access::target::host_buffer>
418+
acc4(buf1);
419+
sycl::accessor<int, 1, sycl::access::mode::read,
420+
sycl::access::target::host_buffer>
421+
acc5(buf2);
422+
sycl::accessor<int, 1, sycl::access::mode::read,
423+
sycl::access::target::host_buffer>
424+
acc6(buf3, sycl::range<1>(1));
425+
426+
assert(acc4 == 2);
427+
assert(acc5[0] == 4);
428+
assert(acc6[0] == 6);
429+
}
382430
}

0 commit comments

Comments
 (0)