22
22
#include " dali/imgcodec/util/convert_gpu.h"
23
23
#include " dali/core/static_switch.h"
24
24
#include " dali/imgcodec/registry.h"
25
+ #include " dali/pipeline/util/for_each_thread.h"
25
26
26
27
namespace dali {
27
28
namespace imgcodec {
@@ -63,8 +64,10 @@ NvJpeg2000DecoderInstance::NvJpeg2000DecoderInstance(
63
64
nvjpeg2k_handle_ = NvJpeg2kHandle (&nvjpeg2k_dev_alloc_, &nvjpeg2k_pin_alloc_);
64
65
DALI_ENFORCE (nvjpeg2k_handle_, " NvJpeg2kHandle initalization failed" );
65
66
66
- for (auto &res : per_thread_resources_)
67
- res = {nvjpeg2k_handle_, device_memory_padding, device_id_};
67
+ ForEachThread (*tp_, [&](int tid) noexcept {
68
+ CUDA_CALL (cudaSetDevice (device_id));
69
+ per_thread_resources_[tid] = {nvjpeg2k_handle_, device_memory_padding, device_id_};
70
+ });
68
71
69
72
for (const auto &thread_id : tp_->GetThreadIds ()) {
70
73
if (device_memory_padding > 0 ) {
@@ -81,9 +84,18 @@ NvJpeg2000DecoderInstance::NvJpeg2000DecoderInstance(
81
84
}
82
85
83
86
NvJpeg2000DecoderInstance::~NvJpeg2000DecoderInstance () {
87
+ tp_->WaitForWork ();
84
88
for (const auto &res : per_thread_resources_)
85
89
CUDA_CALL (cudaStreamSynchronize (res.cuda_stream ));
86
- for (const auto &thread_id : tp_->GetThreadIds ())
90
+
91
+ ForEachThread (*tp_, [&](int tid) {
92
+ auto &res = per_thread_resources_[tid];
93
+ res.tile_dec_res .clear ();
94
+ res.nvjpeg2k_decode_state .reset ();
95
+ res.intermediate_buffer .free ();
96
+ });
97
+
98
+ for (auto thread_id : tp_->GetThreadIds ())
87
99
nvjpeg_memory::DeleteAllBuffers (thread_id);
88
100
}
89
101
@@ -157,8 +169,9 @@ bool NvJpeg2000DecoderInstance::DecodeJpeg2000(ImageSource *in, void *out, const
157
169
ctx.nvjpeg2k_stream , &output_image, ctx.cuda_stream );
158
170
return check_status (ret, in);
159
171
} else {
160
- // Decode tile by tile: nvjpeg2kDecodeImage seems to be bugged
161
172
auto &image_info = ctx.image_info ;
173
+
174
+ // Decode tile by tile: nvjpeg2kDecodeImage doesn't work properly with ROI
162
175
auto &roi = ctx.roi ;
163
176
std::array tile_shape = {image_info.tile_height , image_info.tile_width };
164
177
@@ -185,11 +198,10 @@ bool NvJpeg2000DecoderInstance::DecodeJpeg2000(ImageSource *in, void *out, const
185
198
186
199
if (begin_x < end_x && begin_y < end_y) {
187
200
const TileDecodingResources &per_tile_ctx = ctx.tile_dec_res [state_idx];
188
- state_idx = (state_idx + 1 ) % ctx.tile_dec_res .size ();
189
201
190
202
CUDA_CALL (cudaEventSynchronize (per_tile_ctx.decode_event ));
191
203
192
- NvJpeg2kDecodeParams params;
204
+ auto ¶ms = per_tile_ctx. params ;
193
205
CUDA_CALL (nvjpeg2kDecodeParamsSetDecodeArea (params, begin_x, end_x, begin_y, end_y));
194
206
195
207
auto output_image = PrepareOutputArea (out, pixel_data, pitch_in_bytes, output_offset_x,
@@ -208,6 +220,7 @@ bool NvJpeg2000DecoderInstance::DecodeJpeg2000(ImageSource *in, void *out, const
208
220
return check_status (ret, in);
209
221
210
222
CUDA_CALL (cudaEventRecord (per_tile_ctx.decode_event , ctx.cuda_stream ));
223
+ state_idx = (state_idx + 1 ) % ctx.tile_dec_res .size ();
211
224
}
212
225
}
213
226
}
@@ -225,11 +238,11 @@ DecodeResult NvJpeg2000DecoderInstance::DecodeImplTask(int thread_idx,
225
238
Context ctx (opts, roi, res);
226
239
DecodeResult result = {false , nullptr };
227
240
241
+ CUDA_CALL (cudaEventSynchronize (ctx.decode_event ));
242
+
228
243
if (!ParseJpeg2000Info (in, ctx))
229
244
return result;
230
245
231
- CUDA_CALL (cudaEventSynchronize (ctx.decode_event ));
232
-
233
246
const int64_t channels = ctx.shape [0 ];
234
247
DALIImageType format = channels == 1 ? DALI_GRAY : DALI_RGB;
235
248
bool is_processing_needed =
@@ -241,6 +254,10 @@ DecodeResult NvJpeg2000DecoderInstance::DecodeImplTask(int thread_idx,
241
254
auto decode_out = out;
242
255
if (is_processing_needed) {
243
256
int64_t type_size = dali::TypeTable::GetTypeInfo (ctx.pixel_type ).size ();
257
+ size_t new_size = volume (ctx.shape ) * type_size;
258
+ if (new_size > res.intermediate_buffer .capacity ()) {
259
+ CUDA_CALL (cudaStreamSynchronize (ctx.cuda_stream ));
260
+ }
244
261
res.intermediate_buffer .resize (volume (ctx.shape ) * type_size);
245
262
decode_out = {res.intermediate_buffer .data (), ctx.shape , ctx.pixel_type };
246
263
}
0 commit comments