Skip to content

Commit

Permalink
Optimize CPU time of JPEG lossless decoder (#4625)
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao authored Feb 1, 2023
1 parent c36e5a1 commit 9f0f7e0
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 82 deletions.
15 changes: 14 additions & 1 deletion dali/imgcodec/decoders/nvjpeg/nvjpeg.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -144,6 +144,19 @@ NvJpegDecoderInstance::PerThreadResources::~PerThreadResources() {
}
}

// Tells whether this decoder accepts the given encoded image.
// Returns false when JpegParser cannot parse the source or when the stream's SOF marker
// is SOF-3 (JPEG lossless); returns true for any other parseable JPEG.
// `ctx`, `opts` and `roi` are not consulted by this check.
bool NvJpegDecoderInstance::CanDecode(DecodeContext ctx, ImageSource *in, DecodeParams opts,
                                      const ROI &roi) {
  JpegParser parser{};
  if (parser.CanParse(in)) {
    // SOF-3 (0xff 0xc3) identifies lossless JPEG, which this decoder does not support
    const auto info = parser.GetExtendedInfo(in);
    const std::array<uint8_t, 2> lossless_sof = {0xff, 0xc3};
    return !(info.sof_marker == lossless_sof);
  }
  return false;
}

bool NvJpegDecoderInstance::SetParam(const char *name, const any &value) {
if (strcmp(name, "device_memory_padding") == 0) {
device_memory_padding_ = any_cast<size_t>(value);
Expand Down
5 changes: 4 additions & 1 deletion dali/imgcodec/decoders/nvjpeg/nvjpeg.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,9 @@ class DLL_PUBLIC NvJpegDecoderInstance : public BatchParallelDecoderImpl {
public:
explicit NvJpegDecoderInstance(int device_id, const std::map<std::string, any> &params);

using BatchParallelDecoderImpl::CanDecode;
bool CanDecode(DecodeContext ctx, ImageSource *in, DecodeParams opts, const ROI &roi) override;

// NvjpegDecoderInstance has to operate on its own thread pool instead of the
// one passed by the DecodeContext. Overriding thread pool pointer carried in
// the context argument of this variant of
Expand Down
234 changes: 158 additions & 76 deletions dali/imgcodec/decoders/nvjpeg_lossless/nvjpeg_lossless.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,115 +32,190 @@ NvJpegLosslessDecoderInstance::NvJpegLosslessDecoderInstance(
: BatchedApiDecoderImpl(device_id, params),
event_(CUDAEvent::Create(device_id)) {
DeviceGuard dg(device_id_);
// TODO(janton): use custom allocators (?)
CUDA_CALL(nvjpegCreateEx(NVJPEG_BACKEND_LOSSLESS_JPEG, NULL, NULL, 0, &nvjpeg_handle_));
CUDA_CALL(nvjpegJpegStreamCreate(nvjpeg_handle_, &jpeg_stream_));
per_thread_resources_.push_back(PerThreadResources{nvjpeg_handle_});
CUDA_CALL(nvjpegJpegStateCreate(nvjpeg_handle_, &state_));
}

// Destructor: waits on event_ and then releases the nvJPEG objects held by this instance,
// destroying the nvJPEG handle last.
// NOTE(review): this text comes from a diff rendering without +/- markers, so both the
// instance-wide jpeg_stream_ teardown and the per-thread resource teardown appear below;
// confirm which of the two belongs to the final revision.
NvJpegLosslessDecoderInstance::~NvJpegLosslessDecoderInstance() {
DeviceGuard dg(device_id_);
// event_ is recorded on the decode stream in ScheduleDecode; wait for it before teardown
CUDA_CALL(cudaEventSynchronize(event_));
CUDA_CALL(nvjpegJpegStreamDestroy(jpeg_stream_));
// Destroys every PerThreadResources entry (and its jpeg stream) before the handle goes away
per_thread_resources_.clear();
CUDA_CALL(nvjpegJpegStateDestroy(state_));
CUDA_CALL(nvjpegDestroy(nvjpeg_handle_));
}

// Creates the nvJPEG stream object held by this resource set, using the given nvJPEG handle.
// Any failure reported by nvjpegJpegStreamCreate is handled by CUDA_CALL.
NvJpegLosslessDecoderInstance::PerThreadResources::PerThreadResources(nvjpegHandle_t handle) {
CUDA_CALL(nvjpegJpegStreamCreate(handle, &jpeg_stream));
}

// Move constructor: takes over the jpeg stream of `other` and leaves `other` with a null
// stream, so that only one of the two objects destroys it.
NvJpegLosslessDecoderInstance::PerThreadResources::PerThreadResources(PerThreadResources&& other) {
  jpeg_stream = other.jpeg_stream;
  other.jpeg_stream = nullptr;
}

// Destroys the jpeg stream held by this object, if there is one.
NvJpegLosslessDecoderInstance::PerThreadResources::~PerThreadResources() {
  // A moved-from object carries a null stream and has nothing to release
  if (!jpeg_stream)
    return;
  CUDA_CALL(nvjpegJpegStreamDestroy(jpeg_stream));
}

bool NvJpegLosslessDecoderInstance::CanDecode(DecodeContext ctx, ImageSource *in, DecodeParams opts,
const ROI &roi) {
if (opts.format != DALI_ANY_DATA && opts.format != DALI_GRAY) {
return false;
}

try {
CUDA_CALL(nvjpegJpegStreamParseHeader(nvjpeg_handle_, in->RawData<unsigned char>(), in->Size(),
jpeg_stream_));

int is_supported = 0; // 0 means yes
CUDA_CALL(nvjpegDecodeBatchedSupported(nvjpeg_handle_, jpeg_stream_, &is_supported));
return is_supported == 0;
} catch (...) {
JpegParser jpeg_parser{};
if (!jpeg_parser.CanParse(in))
return false;

// This decoder only supports SOF-3 (JPEG lossless) samples
auto ext_info = jpeg_parser.GetExtendedInfo(in);
std::array<uint8_t, 2> sof3_marker = {0xff, 0xc3};
bool is_lossless_jpeg = ext_info.sof_marker == sof3_marker;
return is_lossless_jpeg;
}

// Parses the headers of all encoded streams on the thread pool carried by `ctx` and
// populates sample_meta_ (one entry per input sample) and batch_sz_.
//
// promise  results promise of the decode request; an entry is set here only for samples
//          that nvJPEG reports as unsupported or whose parsing throws (the exception is
//          stored in the result in the latter case)
// ctx      decode context; ctx.tp must not be null
// in       encoded inputs
// opts     decode parameters (dtype and use_orientation are consulted)
// rois     either empty or one ROI per sample
//
// The call returns only after every parsing task has finished.
void NvJpegLosslessDecoderInstance::Parse(DecodeResultsPromise &promise,
                                          DecodeContext ctx,
                                          cspan<ImageSource *> in,
                                          DecodeParams opts,
                                          cspan<ROI> rois) {
  int nsamples = in.size();
  assert(rois.empty() || rois.size() == nsamples);
  assert(ctx.tp != nullptr);

  // Tasks index per_thread_resources_ by thread id, so there must be one jpeg stream
  // per thread of the pool; create the missing ones.
  int nthreads = ctx.tp->NumThreads();
  if (nthreads > static_cast<int>(per_thread_resources_.size())) {
    per_thread_resources_.reserve(nthreads);
    for (int i = per_thread_resources_.size(); i < nthreads; i++)
      per_thread_resources_.emplace_back(nvjpeg_handle_);
  }
  sample_meta_.clear();
  sample_meta_.resize(nsamples);

  // temporary solution: a separate promise, used just to check when the parsing has finished
  DecodeResultsPromise parse_promise(nsamples);
  for (int i = 0; i < nsamples; i++) {
    ctx.tp->AddWork(
      [&, i](int tid) {
        auto &jpeg_stream = per_thread_resources_[tid].jpeg_stream;
        auto *sample = in[i];
        auto &meta = sample_meta_[i];
        try {
          CUDA_CALL(nvjpegJpegStreamParseHeader(nvjpeg_handle_, sample->RawData<unsigned char>(),
                                                sample->Size(), jpeg_stream));
          int is_supported = 0;  // 0 means yes
          CUDA_CALL(nvjpegDecodeBatchedSupported(nvjpeg_handle_, jpeg_stream, &is_supported));
          meta.can_decode = (is_supported == 0);
          if (!meta.can_decode) {
            // Unsupported stream: report failure now and leave the sample out of the batch
            promise.set(i, {false, nullptr});
            parse_promise.set(i, {false, nullptr});
            return;
          }

          // Postprocessing is required whenever the decoded uint16 data cannot be used
          // as the final output: different dtype, ROI, orientation or range scaling.
          meta.needs_processing = opts.dtype != DALI_UINT16;
          meta.needs_processing |= !rois.empty() && rois[i].use_roi();
          if (opts.use_orientation) {
            auto &ori = meta.orientation = JpegParser().GetInfo(sample).orientation;
            meta.needs_processing |= (ori.rotate || ori.flip_x || ori.flip_y);
          }

          unsigned int precision;
          CUDA_CALL(nvjpegJpegStreamGetSamplePrecision(jpeg_stream, &precision));
          meta.dyn_range_multiplier = 1.0f;
          if (NeedDynamicRangeScaling(precision, DALI_UINT16)) {
            meta.dyn_range_multiplier = DynamicRangeMultiplier(precision, DALI_UINT16);
            meta.needs_processing = true;
          }
          parse_promise.set(i, {true, nullptr});
        } catch (...) {
          meta.can_decode = false;
          promise.set(i, {false, std::current_exception()});
          parse_promise.set(i, {false, nullptr});
        }
      }, in[i]->Size());
  }
  ctx.tp->RunAll(false);
  parse_promise.get_future().wait_all();

  // Assign consecutive batch indices to the samples that will be passed to nvJPEG
  batch_sz_ = 0;
  for (auto &meta : sample_meta_) {
    if (meta.can_decode)
      meta.idx_in_batch = batch_sz_++;
  }
}

// Gathers the samples marked as decodable in sample_meta_ into a batch and submits it to
// nvJPEG batched decoding on ctx.stream. sample_meta_ and batch_sz_ must be populated
// beforehand; the call does nothing when batch_sz_ is not positive.
//
// s     scratchpad that provides the intermediate device buffers for samples that
//       need postprocessing
// ctx   decode context; only ctx.stream is used here
// out   output views, one per input sample
// in    encoded inputs; each decodable sample is asserted to be of kind HostMemory
// opts  not consulted here
// rois  not consulted here
void NvJpegLosslessDecoderInstance::RunDecode(kernels::DynamicScratchpad& s,
                                              DecodeContext ctx,
                                              span<SampleView<GPUBackend>> out,
                                              cspan<ImageSource *> in,
                                              DecodeParams opts,
                                              cspan<ROI> rois) {
  if (batch_sz_ <= 0)
    return;
  int nsamples = in.size();
  encoded_.clear();
  encoded_.resize(batch_sz_);
  encoded_len_.clear();
  encoded_len_.resize(batch_sz_);
  decoded_.clear();
  decoded_.resize(batch_sz_);
  for (int i = 0; i < nsamples; i++) {
    auto &meta = sample_meta_[i];
    if (!meta.can_decode)
      continue;
    auto *sample = in[i];
    auto &out_sample = out[i];
    assert(sample->Kind() == InputKind::HostMemory);
    encoded_[meta.idx_in_batch] = sample->RawData<unsigned char>();
    encoded_len_[meta.idx_in_batch] = sample->Size();
    auto &o = decoded_[meta.idx_in_batch];
    auto sh = out_sample.shape();
    // presumably the shape is HWC, making this the size of one row in bytes - TODO confirm
    o.pitch[0] = sh[1] * sh[2] * sizeof(uint16_t);
    if (meta.needs_processing) {
      // Decode into a temporary buffer; the final output is not written by this function
      int64_t nbytes = volume(sh) * sizeof(uint16_t);
      o.channel[0] = s.Allocate<mm::memory_kind::device, uint8_t>(nbytes);
    } else {
      // No postprocessing needed: decode straight into the output sample
      o.channel[0] = static_cast<uint8_t *>(out_sample.raw_mutable_data());
    }
  }

  CUDA_CALL(nvjpegDecodeBatchedInitialize(nvjpeg_handle_, state_, batch_sz_, 1,
                                          NVJPEG_OUTPUT_UNCHANGEDI_U16));
  CUDA_CALL(nvjpegDecodeBatched(nvjpeg_handle_, state_, encoded_.data(), encoded_len_.data(),
                                decoded_.data(), ctx.stream));
}


FutureDecodeResults NvJpegLosslessDecoderInstance::ScheduleDecode(DecodeContext ctx,
span<SampleView<GPUBackend>> out,
cspan<ImageSource *> in,
DecodeParams opts,
cspan<ROI> rois) {
int nsamples = in.size();
assert(out.size() == nsamples);
assert(rois.empty() || rois.size() == nsamples);
assert(ctx.tp != nullptr);

DecodeResultsPromise promise(nsamples);
auto set_promise = [&](DecodeResult result) {
for (int i = 0; i < nsamples; i++)
promise.set(i, result);
};

if (opts.format != DALI_ANY_DATA && opts.format != DALI_GRAY)
set_promise({false, std::make_exception_ptr(
std::invalid_argument("Only ANY_DATA and GRAY are supported."))});

// scratchpad should not go out of scope until we launch the postprocessing
kernels::DynamicScratchpad s({}, ctx.stream);
Parse(promise, ctx, in, opts, rois);
try {
if (opts.format != DALI_ANY_DATA && opts.format != DALI_GRAY)
throw std::invalid_argument("Only ANY_DATA and GRAY are supported.");

sample_meta_.clear();
sample_meta_.resize(nsamples);
encoded_.clear();
encoded_.resize(nsamples);
encoded_len_.clear();
encoded_len_.resize(nsamples);
decoded_.clear();
decoded_.resize(nsamples);

for (int i = 0; i < nsamples; i++) {
auto *sample = in[i];
auto &out_sample = out[i];
assert(sample->Kind() == InputKind::HostMemory);
auto *data_ptr = sample->RawData<unsigned char>();
auto data_size = sample->Size();
encoded_[i] = data_ptr;
encoded_len_[i] = data_size;
sample_meta_[i].needs_processing = opts.dtype != DALI_UINT16;
if (!rois.empty() && rois[i].use_roi()) {
sample_meta_[i].needs_processing = true;
}
if (opts.use_orientation) {
auto &ori = sample_meta_[i].orientation = JpegParser().GetInfo(in[i]).orientation;
if (ori.rotate || ori.flip_x || ori.flip_y)
sample_meta_[i].needs_processing = true;
}

CUDA_CALL(nvjpegJpegStreamParseHeader(nvjpeg_handle_, sample->RawData<unsigned char>(),
sample->Size(), jpeg_stream_));
unsigned int precision;
CUDA_CALL(nvjpegJpegStreamGetSamplePrecision(jpeg_stream_, &precision));
sample_meta_[i].dyn_range_multiplier = 1.0f;
if (NeedDynamicRangeScaling(precision, DALI_UINT16)) {
sample_meta_[i].dyn_range_multiplier = DynamicRangeMultiplier(precision, DALI_UINT16);
sample_meta_[i].needs_processing = true;
}

auto &o = decoded_[i];
auto sh = out_sample.shape();
o.pitch[0] = sh[1] * sh[2] * sizeof(uint16_t);
if (sample_meta_[i].needs_processing) {
int64_t nbytes = volume(sh) * sizeof(uint16_t);
o.channel[0] = s.Allocate<mm::memory_kind::device, uint8_t>(nbytes);
} else {
o.channel[0] = static_cast<uint8_t *>(out_sample.raw_mutable_data());
}
}

CUDA_CALL(nvjpegDecodeBatchedInitialize(nvjpeg_handle_, state_, nsamples, 1,
NVJPEG_OUTPUT_UNCHANGEDI_U16));
CUDA_CALL(nvjpegDecodeBatched(nvjpeg_handle_, state_, encoded_.data(), encoded_len_.data(),
decoded_.data(), ctx.stream));
} catch (...) {
RunDecode(s, ctx, out, in, opts, rois);
} catch(...) {
set_promise({false, std::current_exception()});
return promise.get_future();
}

Postprocess(promise, ctx, out, opts, rois);
CUDA_CALL(cudaEventRecord(event_, ctx.stream));
return promise.get_future();
Expand All @@ -149,27 +224,34 @@ FutureDecodeResults NvJpegLosslessDecoderInstance::ScheduleDecode(DecodeContext
void NvJpegLosslessDecoderInstance::Postprocess(DecodeResultsPromise &promise, DecodeContext ctx,
span<SampleView<GPUBackend>> out, DecodeParams opts,
cspan<ROI> rois) {
if (batch_sz_ <= 0)
return;
int nsamples = out.size();
for (int i = 0; i < nsamples; i++) {
if (!sample_meta_[i].needs_processing) {
for (int i = 0, j = 0; i < nsamples; i++) {
const auto &meta = sample_meta_[i];
if (!meta.can_decode)
continue; // we didn't try to decode this sample

const auto &decoded = decoded_[j++]; // decoded only has samples where can_decode == true
if (!meta.needs_processing) {
promise.set(i, {true, nullptr});
continue;
}
auto sh = out[i].shape();
SampleView<GPUBackend> decoded_view(decoded_[i].channel[0], sh, DALI_UINT16);
auto roi = rois.empty() ? ROI{} : rois[i];
SampleView<GPUBackend> decoded_view(decoded.channel[0], sh, DALI_UINT16);
DALIImageType decoded_format = sh[2] == 1 ? DALI_GRAY : DALI_ANY_DATA;
try {
Convert(out[i], "HWC", opts.format, decoded_view, "HWC", decoded_format,
ctx.stream, rois.empty() ? ROI{} : rois[i], sample_meta_[i].orientation,
sample_meta_[i].dyn_range_multiplier);
ctx.stream, roi, meta.orientation, meta.dyn_range_multiplier);
promise.set(i, {true, nullptr});
} catch (...) {
promise.set(i, {false, std::current_exception()});
}
}
}

REGISTER_DECODER("JPEG", NvJpegLosslessDecoderFactory, CUDADecoderPriority - 1);
REGISTER_DECODER("JPEG", NvJpegLosslessDecoderFactory, CUDADecoderPriority + 1);

} // namespace imgcodec
} // namespace dali
22 changes: 20 additions & 2 deletions dali/imgcodec/decoders/nvjpeg_lossless/nvjpeg_lossless.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "dali/core/cuda_stream_pool.h"
#include "dali/imgcodec/decoders/decoder_batched_api_impl.h"
#include "dali/imgcodec/decoders/nvjpeg/nvjpeg_memory.h"
#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/pipeline/data/buffer.h"

namespace dali {
Expand All @@ -42,20 +43,37 @@ class DLL_PUBLIC NvJpegLosslessDecoderInstance : public BatchedApiDecoderImpl {
cspan<ROI> rois = {}) override;

private:
// Parses encoded streams and populates SampleMeta, and batch_sz_
void Parse(DecodeResultsPromise &promise, DecodeContext ctx, cspan<ImageSource *> in,
DecodeParams opts, cspan<ROI> rois);

// Invokes nvJPEG decoding (sample_meta_ and batch_sz_ to be populated)
void RunDecode(kernels::DynamicScratchpad &s, DecodeContext ctx, span<SampleView<GPUBackend>> out,
cspan<ImageSource *> in, DecodeParams opts, cspan<ROI> rois = {});

void Postprocess(DecodeResultsPromise &promise, DecodeContext ctx,
span<SampleView<GPUBackend>> out, DecodeParams opts, cspan<ROI> rois);

nvjpegHandle_t nvjpeg_handle_;
nvjpegJpegStream_t jpeg_stream_;

struct PerThreadResources {
explicit PerThreadResources(nvjpegHandle_t handle);
PerThreadResources(PerThreadResources&& other);
~PerThreadResources();
nvjpegJpegStream_t jpeg_stream;
};
std::vector<PerThreadResources> per_thread_resources_;
CUDAEvent event_;
nvjpegJpegState_t state_;

int batch_sz_ = 0; // number of samples to be decoded by nvJPEG
struct SampleMeta {
bool can_decode;
int idx_in_batch; // only relevant if can_decode == true
bool needs_processing;
Orientation orientation;
float dyn_range_multiplier;
};

std::vector<SampleMeta> sample_meta_;
std::vector<const unsigned char*> encoded_;
std::vector<size_t> encoded_len_;
Expand Down
Loading

0 comments on commit 9f0f7e0

Please sign in to comment.