diff --git a/sycl/include/CL/sycl/detail/sampler_impl.hpp b/sycl/include/CL/sycl/detail/sampler_impl.hpp index 164acd802ecb7..43dea3451396b 100644 --- a/sycl/include/CL/sycl/detail/sampler_impl.hpp +++ b/sycl/include/CL/sycl/detail/sampler_impl.hpp @@ -27,15 +27,12 @@ class sampler_impl { __spirv::OpTypeSampler *m_Sampler; sampler_impl(__spirv::OpTypeSampler *Sampler) : m_Sampler(Sampler) {} #else - cl_sampler m_Sampler = nullptr; - context m_SyclContext; std::unordered_map m_contextToSampler; private: coordinate_normalization_mode m_CoordNormMode; addressing_mode m_AddrMode; filtering_mode m_FiltMode; - bool m_ReleaseSampler; public: sampler_impl(coordinate_normalization_mode normalizationMode, diff --git a/sycl/source/detail/sampler_impl.cpp b/sycl/source/detail/sampler_impl.cpp index aa3c4af0a03ba..89e2ea4da3782 100644 --- a/sycl/source/detail/sampler_impl.cpp +++ b/sycl/source/detail/sampler_impl.cpp @@ -1,4 +1,4 @@ -//==----------------- sampler_impl.cpp - SYCL standard header file ---------==// +//==----------------- sampler_impl.cpp - SYCL sampler ----------------------==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -17,28 +17,28 @@ sampler_impl::sampler_impl(coordinate_normalization_mode normalizationMode, addressing_mode addressingMode, filtering_mode filteringMode) : m_CoordNormMode(normalizationMode), m_AddrMode(addressingMode), - m_FiltMode(filteringMode), m_ReleaseSampler(false) {} + m_FiltMode(filteringMode) {} -sampler_impl::sampler_impl(cl_sampler clSampler, const context &syclContext) - : m_Sampler(clSampler), m_SyclContext(syclContext), m_ReleaseSampler(true) { +sampler_impl::sampler_impl(cl_sampler clSampler, const context &syclContext) { - m_contextToSampler[syclContext] = m_Sampler; - CHECK_OCL_CODE(clRetainSampler(m_Sampler)); - CHECK_OCL_CODE(clGetSamplerInfo(m_Sampler, CL_SAMPLER_NORMALIZED_COORDS, + m_contextToSampler[syclContext] = clSampler; + CHECK_OCL_CODE(clRetainSampler(clSampler)); + CHECK_OCL_CODE(clGetSamplerInfo(clSampler, CL_SAMPLER_NORMALIZED_COORDS, sizeof(cl_bool), &m_CoordNormMode, nullptr)); - CHECK_OCL_CODE(clGetSamplerInfo(m_Sampler, CL_SAMPLER_ADDRESSING_MODE, + CHECK_OCL_CODE(clGetSamplerInfo(clSampler, CL_SAMPLER_ADDRESSING_MODE, sizeof(cl_addressing_mode), &m_AddrMode, nullptr)); - CHECK_OCL_CODE(clGetSamplerInfo(m_Sampler, CL_SAMPLER_FILTER_MODE, + CHECK_OCL_CODE(clGetSamplerInfo(clSampler, CL_SAMPLER_FILTER_MODE, sizeof(cl_filter_mode), &m_FiltMode, nullptr)); } sampler_impl::~sampler_impl() { - if (m_ReleaseSampler) - CHECK_OCL_CODE_NO_EXC(clReleaseSampler(m_Sampler)); - // TODO replace CHECK_OCL_CODE_NO_EXC to CHECK_OCL_CODE and - // TODO catch an exception and add it to the list of asynchronous exceptions + for (auto &Iter : m_contextToSampler) { + // TODO replace CHECK_OCL_CODE_NO_EXC to CHECK_OCL_CODE and + // TODO catch an exception and add it to the list of asynchronous exceptions + CHECK_OCL_CODE_NO_EXC(clReleaseSampler(Iter.second)); + } } cl_sampler sampler_impl::getOrCreateSampler(const context &Context) { @@ -64,7 +64,6 @@ cl_sampler sampler_impl::getOrCreateSampler(const context &Context) { static_cast(m_FiltMode), &errcode_ret); #endif CHECK_OCL_CODE(errcode_ret); - m_ReleaseSampler = true; return m_contextToSampler[Context]; }