diff --git a/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h b/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
index c9a7cbd3241..734aeb5a504 100644
--- a/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
+++ b/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
@@ -23,6 +23,7 @@
 #include "dali/kernels/imgproc/resample/params.h"
 #include "dali/operators/image/resize/resize_op_impl.h"
 #include "dali/operators/nvcvop/nvcvop.h"
+#include "dali/core/nvtx.h"
 
 namespace dali {
 
@@ -33,12 +34,13 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
   static_assert(spatial_ndim == 2 || spatial_ndim == 3, "Only 2D and 3D resizing is supported");
 
-  /// Dimensionality of each separate frame. If input contains no channel dimension, one is added
   static constexpr int frame_ndim = spatial_ndim + 1;
 
   void Setup(TensorListShape<> &out_shape, const TensorListShape<> &in_shape, int first_spatial_dim,
              span<const kernels::ResamplingParams> params) override {
+    first_spatial_dim_ = first_spatial_dim;
+
     // Calculate output shape of the input, as supplied (sequences, planar images, etc)
     GetResizedShape(out_shape, in_shape, params, spatial_ndim, first_spatial_dim);
 
@@ -49,37 +51,46 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
     // effective frames (from videos, channel planes, etc).
     GetResizedShape(out_shape_, in_shape_, make_cspan(params_), 0);
 
-    // Create a map of non-empty samples
-    SetFrameIdxs();
-
     // Now that we know how many logical frames there are, calculate batch subdivision.
     CalculateMinibatchPartition(minibatch_size_);
 
+    CalculateSourceSamples(in_shape, first_spatial_dim);
+
     SetupKernel();
   }
 
-  // Set the frame_idx_ map with indices of samples that are not empty
-  void SetFrameIdxs() {
-    frame_idx_.clear();
-    frame_idx_.reserve(in_shape_.num_samples());
-    for (int i = 0; i < in_shape_.num_samples(); ++i) {
-      if (volume(out_shape_.tensor_shape_span(i)) != 0 &&
-          volume(in_shape_.tensor_shape_span(i)) != 0) {
-        frame_idx_.push_back(i);
+  // Assign each minibatch a range of frames in the original input/output TensorLists
+  void CalculateSourceSamples(const TensorListShape<> &original_shape, int first_spatial_dim) {
+    int64_t sample_id = 0;
+    int64_t frame_offset = 0;
+    for (auto &mb : minibatches_) {
+      auto v = original_shape[sample_id].num_elements();
+      while (v == 0) {
+        sample_id++;
+        v = original_shape[sample_id].num_elements();
+      }
+      mb.sample_offset = sample_id;
+      mb.frame_offset = frame_offset;
+      frame_offset = mb.frame_offset + mb.count;
+      int frames_n = num_frames(original_shape[sample_id], first_spatial_dim);
+      while (frame_offset >= frames_n) {
+        frame_offset -= frames_n;
+        if (++sample_id >= original_shape.num_samples()) {
+          break;
+        }
+        frames_n = num_frames(original_shape[sample_id], first_spatial_dim);
       }
-      total_frames_ = frame_idx_.size();
     }
   }
 
-  // get the index of a frame in the DALI TensorList
-  int frame_idx(int f) {
-    return frame_idx_[f];
+  int64_t num_frames(const TensorShape<> &shape, int first_spatial_dim) {
+    return volume(&shape[0], &shape[first_spatial_dim]);
   }
 
   void SetupKernel() {
-    kernels::KernelContext ctx;
     rois_.resize(total_frames_);
-    workspace_reqs_ = {};
+    workspace_reqs_[0] = {};
+    workspace_reqs_[1] = {};
     std::vector<HQResizeTensorShapeI> mb_input_shapes(minibatch_size_);
     std::vector<HQResizeTensorShapeI> mb_output_shapes(minibatch_size_);
     auto *rois_ptr = rois_.data();
@@ -88,15 +99,14 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
       int end = mb.start + mb.count;
       for (int i = mb.start, j = 0; i < end; i++, j++) {
-        auto f_id = frame_idx(i);
-        rois_ptr[j] = GetRoi(params_[f_id]);
+        rois_ptr[j] = GetRoi(params_[i]);
         for (int d = 0; d < spatial_ndim; ++d) {
-          mb_input_shapes[j].extent[d] = static_cast<int>(in_shape_.tensor_shape_span(f_id)[d]);
+          mb_input_shapes[j].extent[d] = static_cast<int>(in_shape_.tensor_shape_span(i)[d]);
           mb_output_shapes[j].extent[d] =
-              static_cast<int>(out_shape_.tensor_shape_span(f_id)[d]);
+              static_cast<int>(out_shape_.tensor_shape_span(i)[d]);
         }
       }
-      int num_channels = in_shape_[frame_idx(0)][frame_ndim - 1];
+      int num_channels = in_shape_[0][frame_ndim - 1];
       HQResizeTensorShapesI mb_input_shape{mb_input_shapes.data(), mb.count, spatial_ndim,
                                            num_channels};
       HQResizeTensorShapesI mb_output_shape{mb_output_shapes.data(), mb.count, spatial_ndim,
@@ -104,14 +114,14 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
       mb.rois = HQResizeRoisF{mb.count, spatial_ndim, rois_ptr};
       rois_ptr += mb.count;
 
-      auto param = params_[frame_idx(mb.start)][0];
+      auto param = params_[mb.start][0];
       mb.min_interpolation = GetInterpolationType(param.min_filter);
       mb.mag_interpolation = GetInterpolationType(param.mag_filter);
       mb.antialias = param.min_filter.antialias || param.mag_filter.antialias;
 
       auto ws_req = resize_op_.getWorkspaceRequirements(mb.count, mb_input_shape, mb_output_shape,
                                                         mb.min_interpolation, mb.mag_interpolation,
                                                         mb.antialias, mb.rois);
-      workspace_reqs_ = nvcvop::MaxWorkspaceRequirements(workspace_reqs_, ws_req);
+      workspace_reqs_[mb_idx % 2] = cvcuda::MaxWorkspaceReq(workspace_reqs_[mb_idx % 2], ws_req);
     }
   }
 
@@ -146,36 +156,43 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
   void RunResize(Workspace &ws, TensorList<GPUBackend> &output,
                  const TensorList<GPUBackend> &input) override {
-    TensorList<GPUBackend> in_frames;
-    in_frames.ShareData(input);
-    in_frames.Resize(in_shape_);
-    PrepareInput(in_frames);
-
-    TensorList<GPUBackend> out_frames;
-    out_frames.ShareData(output);
-    out_frames.Resize(out_shape_);
-    PrepareOutput(out_frames);
-
     kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream()));
+    auto allocator = nvcvop::GetScratchpadAllocator(scratchpad);
 
-    auto workspace_mem = op_workspace_.Allocate(workspace_reqs_, scratchpad);
+    auto workspace_mem = AllocateWorkspaces(scratchpad);
 
     for (size_t b = 0; b < minibatches_.size(); b++) {
       MiniBatch &mb = minibatches_[b];
-      resize_op_(ws.stream(), workspace_mem, mb.input, mb.output, mb.min_interpolation,
+      auto reqs = nvcv::TensorBatch::CalcRequirements(mb.count);
+      auto mb_output = nvcv::TensorBatch(reqs, allocator);
+      auto mb_input = nvcv::TensorBatch(reqs, allocator);
+      nvcvop::PushFramesToBatch(mb_input, input, first_spatial_dim_, mb.sample_offset,
+                                mb.frame_offset, mb.count, sample_layout_);
+      nvcvop::PushFramesToBatch(mb_output, output, first_spatial_dim_, mb.sample_offset,
+                                mb.frame_offset, mb.count, sample_layout_);
+      resize_op_(ws.stream(), workspace_mem[b % 2], mb_input, mb_output, mb.min_interpolation,
                  mb.mag_interpolation, mb.antialias, mb.rois);
     }
   }
 
+  std::array<cvcuda::Workspace, 2> AllocateWorkspaces(kernels::Scratchpad &scratchpad) {
+    std::array<cvcuda::Workspace, 2> result;
+    result[0] = op_workspace_.Allocate(workspace_reqs_[0], scratchpad);
+    if (minibatches_.size() > 1) {
+      result[1] = op_workspace_.Allocate(workspace_reqs_[1], scratchpad);
+    }
+    return result;
+  }
+
   void CalculateMinibatchPartition(int minibatch_size) {
+    total_frames_ = in_shape_.num_samples();
     std::vector<std::pair<int, int>> continuous_ranges;
-    kernels::FilterDesc min_filter_desc = params_[frame_idx(0)][0].min_filter;
-    kernels::FilterDesc mag_filter_desc = params_[frame_idx(0)][0].mag_filter;
+    kernels::FilterDesc min_filter_desc = params_[0][0].min_filter;
+    kernels::FilterDesc mag_filter_desc = params_[0][0].mag_filter;
     int start_id = 0;
     for (int i = 0; i < total_frames_; i++) {
-      if (params_[frame_idx(i)][0].min_filter != min_filter_desc ||
-          params_[frame_idx(i)][0].mag_filter != mag_filter_desc) {
+      if (params_[i][0].min_filter != min_filter_desc ||
+          params_[i][0].mag_filter != mag_filter_desc) {
         // we break the range if different filter types are used
         continuous_ranges.emplace_back(start_id, i);
         start_id = i;
@@ -204,60 +221,31 @@ class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
   }
 
   TensorListShape<frame_ndim> in_shape_, out_shape_;
-  std::vector<int> frame_idx_;  // map of absolute frame indices in the input TensorList
-  int total_frames_ = 0;  // number of non-empty frames
+  int total_frames_;  // number of non-empty frames
+  std::vector<kernels::ResamplingParamsND<spatial_ndim>> params_;
+  int first_spatial_dim_;
 
   cvcuda::HQResize resize_op_{};
   nvcvop::NVCVOpWorkspace op_workspace_;
-  cvcuda::WorkspaceRequirements workspace_reqs_{};
+  std::array<cvcuda::WorkspaceRequirements, 2> workspace_reqs_{};
   std::vector<HQResizeRoiF> rois_;
   const TensorLayout sample_layout_ = (spatial_ndim == 2) ? "HWC" : "DHWC";
 
+  std::vector<nvcv::Tensor> in_frames_;
+  std::vector<nvcv::Tensor> out_frames_;
+
   struct MiniBatch {
     int start, count;
-    nvcv::TensorBatch input;
-    nvcv::TensorBatch output;
     NVCVInterpolationType min_interpolation;
     NVCVInterpolationType mag_interpolation;
     bool antialias;
     HQResizeRoisF rois;
+    int64_t sample_offset;  // id of a starting sample in the original IOs
+    int64_t frame_offset;   // id of a starting frame in the starting sample
   };
 
   std::vector<MiniBatch> minibatches_;
-
-  void PrepareInput(const TensorList<GPUBackend> &input) {
-    for (auto &mb : minibatches_) {
-      int curr_capacity = mb.input ? mb.input.capacity() : 0;
-      if (mb.count > curr_capacity) {
-        int new_capacity = std::max(mb.count, curr_capacity * 2);
-        auto reqs = nvcv::TensorBatch::CalcRequirements(new_capacity);
-        mb.input = nvcv::TensorBatch(reqs);
-      } else {
-        mb.input.clear();
-      }
-      for (int i = mb.start; i < mb.start + mb.count; ++i) {
-        mb.input.pushBack(nvcvop::AsTensor(input[frame_idx(i)], sample_layout_));
-      }
-    }
-  }
-
-  void PrepareOutput(const TensorList<GPUBackend> &out) {
-    for (auto &mb : minibatches_) {
-      int curr_capacity = mb.output ? mb.output.capacity() : 0;
-      if (mb.count > curr_capacity) {
-        int new_capacity = std::max(mb.count, curr_capacity * 2);
-        auto reqs = nvcv::TensorBatch::CalcRequirements(new_capacity);
-        mb.output = nvcv::TensorBatch(reqs);
-      } else {
-        mb.output.clear();
-      }
-      for (int i = mb.start; i < mb.start + mb.count; ++i) {
-        mb.output.pushBack(nvcvop::AsTensor(out[frame_idx(i)], sample_layout_));
-      }
-    }
-  }
-
   int minibatch_size_;
 };
diff --git a/dali/operators/image/resize/resize_op_impl.h b/dali/operators/image/resize/resize_op_impl.h
index 89111cf9038..300f533052f 100644
--- a/dali/operators/image/resize/resize_op_impl.h
+++ b/dali/operators/image/resize/resize_op_impl.h
@@ -63,7 +63,8 @@ void GetFrameShapesAndParams(
   for (int i = 0; i < N; i++) {
     auto in_sample_shape = in_shape.tensor_shape_span(i);
-    total_frames += volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
+    if (volume(in_sample_shape) > 0)
+      total_frames += volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
   }
 
   frame_params.resize(total_frames);
@@ -72,10 +73,11 @@ void GetFrameShapesAndParams(
   int ndim = in_shape.sample_dim();
   for (int i = 0, flat_frame_idx = 0; i < N; i++) {
     auto in_sample_shape = in_shape.tensor_shape_span(i);
+    if (volume(in_sample_shape) == 0) {
+      continue;  // skip empty samples
+    }
     // Collapse leading dimensions, if any, as frame dim. This handles channel-first.
     int seq_len = volume(&in_sample_shape[0], &in_sample_shape[first_spatial_dim]);
-    if (seq_len == 0)
-      continue;  // skip empty sequences
 
     TensorShape<frame_ndim> frame_shape;
     frame_shape.resize(frame_ndim);
diff --git a/dali/operators/nvcvop/nvcvop.cc b/dali/operators/nvcvop/nvcvop.cc
index a82982eb71a..a11a4b2a832 100644
--- a/dali/operators/nvcvop/nvcvop.cc
+++ b/dali/operators/nvcvop/nvcvop.cc
@@ -14,8 +14,8 @@
 #include "dali/operators/nvcvop/nvcvop.h"
-
 #include
+#include
 
 namespace dali::nvcvop {
 
@@ -208,11 +208,11 @@ nvcv::Tensor AsTensor(ConstSampleView<GPUBackend> sample, TensorLayout layout,
   return AsTensor(const_cast<void *>(sample.raw_data()), shape, sample.type(), layout);
 }
 
-nvcv::Tensor AsTensor(void *data, const TensorShape<> shape, DALIDataType daliDType,
+nvcv::Tensor AsTensor(void *data, const TensorShape<> &shape, DALIDataType daliDType,
                       TensorLayout layout) {
   auto dtype = GetDataType(daliDType, 1);
   nvcv::TensorDataStridedCuda::Buffer inBuf;
-  inBuf.basePtr = reinterpret_cast<NVCVByte *>(const_cast<void *>(data));
+  inBuf.basePtr = static_cast<NVCVByte *>(const_cast<void *>(data));
   inBuf.strides[shape.size() - 1] = dtype.strideBytes();
   for (int d = shape.size() - 2; d >= 0; --d) {
     inBuf.strides[d] = shape[d + 1] * inBuf.strides[d + 1];
@@ -225,13 +225,68 @@ nvcv::Tensor AsTensor(void *data, const TensorShape<> shape, DALIDataT
   return nvcv::TensorWrapData(inData);
 }
 
-void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
-                        TensorLayout layout) {
-  for (int s = 0; s < t_list.num_samples(); ++s) {
-    batch.pushBack(AsTensor(t_list[s], layout));
+nvcv::Tensor AsTensor(const void *data, span<const int64_t> shape_data, const nvcv::DataType &dtype,
+                      const nvcv::TensorLayout &layout) {
+  int ndim = shape_data.size();
+  nvcv::TensorDataStridedCuda::Buffer inBuf;
+  inBuf.basePtr = static_cast<NVCVByte *>(const_cast<void *>(data));
+  inBuf.strides[ndim - 1] = dtype.strideBytes();
+  for (int d = ndim - 2; d >= 0; --d) {
+    inBuf.strides[d] = shape_data[d + 1] * inBuf.strides[d + 1];
+  }
+  nvcv::TensorShape out_shape(shape_data.data(), ndim, layout);
+  nvcv::TensorDataStridedCuda inData(out_shape, dtype, inBuf);
+  return nvcv::TensorWrapData(inData);
+}
+
+int64_t calc_num_frames(const TensorShape<> &shape, int first_spatial_dim) {
+  return (first_spatial_dim > 0) ?
+         volume(&shape[0], &shape[first_spatial_dim]) :
+         1;
+}
+
+void PushFramesToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
+                       int first_spatial_dim, int64_t starting_sample, int64_t frame_offset,
+                       int64_t num_frames, const TensorLayout &layout) {
+  int ndim = layout.ndim();
+  auto nvcv_layout = nvcv::TensorLayout(layout.c_str());
+  auto dtype = GetDataType(t_list.type());
+
+  std::vector<nvcv::Tensor> tensors;
+  tensors.reserve(num_frames);
+
+  const auto &input_shape = t_list.shape();
+  int64_t sample_id = starting_sample - 1;
+  auto type_size = dtype.strideBytes();
+  std::vector<int64_t> frame_shape(ndim, 1);
+
+  auto frame_stride = 0;
+  int sample_nframes = 0;
+  const uint8_t *data = nullptr;
+
+  for (int64_t i = 0; i < num_frames; ++i) {
+    if (frame_offset == sample_nframes) {
+      frame_offset = 0;
+      do {
+        ++sample_id;
+        auto sample_shape = input_shape[sample_id];
+        DALI_ENFORCE(sample_id < t_list.num_samples());
+        std::copy(&sample_shape[first_spatial_dim], &sample_shape[input_shape.sample_dim()],
+                  frame_shape.begin());
+        frame_stride = volume(frame_shape) * type_size;
+        sample_nframes = calc_num_frames(sample_shape, first_spatial_dim);
+      } while (sample_nframes * frame_stride == 0);  // we skip empty samples
+      data =
+          static_cast<const uint8_t *>(t_list.raw_tensor(sample_id)) + frame_stride * frame_offset;
+    }
+    tensors.push_back(AsTensor(data, make_span(frame_shape), dtype, nvcv_layout));
+    data += frame_stride;
+    frame_offset++;
   }
+  batch.pushBack(tensors.begin(), tensors.end());
 }
+
 cvcuda::Workspace NVCVOpWorkspace::Allocate(const cvcuda::WorkspaceRequirements &reqs,
                                             kernels::Scratchpad &scratchpad) {
   auto *hostBuffer = scratchpad.AllocateHost<uint8_t>(reqs.hostMem.size, reqs.hostMem.alignment);
@@ -248,4 +303,21 @@ cvcuda::Workspace NVCVOpWorkspace::Allocate(const cvcuda::WorkspaceRequirements
   return workspace_;
 }
 
+nvcv::Allocator GetScratchpadAllocator(kernels::Scratchpad &scratchpad) {
+  auto hostAllocator = nvcv::CustomHostMemAllocator(
+      [&](int64_t size, int32_t align) { return scratchpad.AllocateHost<uint8_t>(size, align); },
+      [](void *, int64_t, int32_t) {});
+
+  auto pinnedAllocator = nvcv::CustomHostPinnedMemAllocator(
+      [&](int64_t size, int32_t align) { return scratchpad.AllocatePinned<uint8_t>(size, align); },
+      [](void *, int64_t, int32_t) {});
+
+  auto gpuAllocator = nvcv::CustomCudaMemAllocator(
+      [&](int64_t size, int32_t align) { return scratchpad.AllocateGPU<uint8_t>(size, align); },
+      [](void *, int64_t, int32_t) {});
+
+  return nvcv::CustomAllocator(std::move(hostAllocator), std::move(pinnedAllocator),
+                               std::move(gpuAllocator));
+}
+
 }  // namespace dali::nvcvop
diff --git a/dali/operators/nvcvop/nvcvop.h b/dali/operators/nvcvop/nvcvop.h
index 998fd9ba108..407f9d0edec 100644
--- a/dali/operators/nvcvop/nvcvop.h
+++ b/dali/operators/nvcvop/nvcvop.h
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -34,6 +35,7 @@
 #include "dali/pipeline/operator/sequence_operator.h"
 #include "dali/core/cuda_event_pool.h"
 
+
 namespace dali::nvcvop {
 
 /**
@@ -112,7 +114,7 @@ nvcv::Tensor AsTensor(SampleView<GPUBackend> sample, TensorLayout layout = "",
 nvcv::Tensor AsTensor(ConstSampleView<GPUBackend> sample, TensorLayout layout = "",
                       const std::optional<TensorShape<>> &reshape = std::nullopt);
 
-nvcv::Tensor AsTensor(void *data, const TensorShape<> shape, DALIDataType dtype,
+nvcv::Tensor AsTensor(void *data, const TensorShape<> &shape, DALIDataType dtype,
                       TensorLayout layout);
 
 /**
@@ -133,8 +135,26 @@ void AllocateImagesLike(nvcv::ImageBatchVarShape &output, const TensorList<GPUBackend> &t_list);
 
-void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
-                        TensorLayout layout);
+/**
+ * @brief Push a range of frames from the input TensorList as samples in the output TensorBatch.
+ *
+ * The input TensorList is interpreted as a sequence of frames, where the innermost dimensions,
+ * starting from `first_spatial_dim`, are the frames' dimensions.
+ *
+ * The range of frames is determined by the `starting_sample`, `frame_offset`
+ * and `num_frames` arguments.
+ * `starting_sample` is the index of the first source sample in the input TensorList.
+ * All samples before it are skipped.
+ * `frame_offset` is the index of the first frame to take from the starting sample.
+ * `num_frames` is the total number of frames that will be pushed to the output TensorBatch.
+ *
+ * @param batch output TensorBatch
+ * @param t_list input TensorList
+ * @param layout layout of the output TensorBatch
+ */
+void PushFramesToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
+                       int first_spatial_dim, int64_t starting_sample, int64_t frame_offset,
+                       int64_t num_frames, const TensorLayout &layout);
+
 
 class NVCVOpWorkspace {
  public:
@@ -165,17 +185,10 @@ class NVCVOpWorkspace {
   int device_id_{};
 };
 
-inline cvcuda::WorkspaceRequirements MaxWorkspaceRequirements(
-    const cvcuda::WorkspaceRequirements &a, const cvcuda::WorkspaceRequirements &b) {
-  cvcuda::WorkspaceRequirements max;
-  max.hostMem.size = std::max(a.hostMem.size, b.hostMem.size);
-  max.hostMem.alignment = std::max(a.hostMem.alignment, b.hostMem.alignment);
-  max.pinnedMem.size = std::max(a.pinnedMem.size, b.pinnedMem.size);
-  max.pinnedMem.alignment = std::max(a.pinnedMem.alignment, b.pinnedMem.alignment);
-  max.cudaMem.size = std::max(a.cudaMem.size, b.cudaMem.size);
-  max.cudaMem.alignment = std::max(a.cudaMem.alignment, b.cudaMem.alignment);
-  return max;
-}
+/**
+ * @brief Create an NVCV allocator using the given scratchpad.
+ */
+nvcv::Allocator GetScratchpadAllocator(kernels::Scratchpad &scratchpad);
 
 /**
  * @brief A base class for the CVCUDA operators.
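
Note on the frame-range bookkeeping introduced by this patch: CalculateSourceSamples records, for every minibatch, the sample and the in-sample frame at which it starts, and PushFramesToBatch then walks that range frame by frame, skipping empty samples. The sketch below illustrates only that indexing arithmetic in isolation; the MiniBatchRange struct, the frames_per_sample vector and the hard-coded counts are hypothetical stand-ins, not DALI or CV-CUDA types.

// Standalone sketch of the per-minibatch frame-range computation (cf. CalculateSourceSamples).
// Hypothetical example: not DALI code; sample shapes are reduced to frame counts per sample.
#include <cstdint>
#include <iostream>
#include <vector>

struct MiniBatchRange {
  int64_t count;          // number of frames in this minibatch
  int64_t sample_offset;  // first (non-empty) sample the minibatch starts in
  int64_t frame_offset;   // first frame within that sample
};

int main() {
  // Frames per sample of a toy batch; 0 marks an empty sample that must be skipped.
  std::vector<int64_t> frames_per_sample = {3, 0, 5, 2};  // 10 non-empty frames in total
  std::vector<MiniBatchRange> minibatches = {{4, 0, 0}, {4, 0, 0}, {2, 0, 0}};  // counts only

  int64_t sample_id = 0, frame_offset = 0;
  for (auto &mb : minibatches) {
    while (frames_per_sample[sample_id] == 0)  // skip empty samples
      ++sample_id;
    mb.sample_offset = sample_id;
    mb.frame_offset = frame_offset;
    frame_offset += mb.count;
    // Advance past every sample fully consumed by this minibatch.
    while (sample_id < static_cast<int64_t>(frames_per_sample.size()) &&
           frame_offset >= frames_per_sample[sample_id]) {
      frame_offset -= frames_per_sample[sample_id];
      ++sample_id;
    }
  }

  for (const auto &mb : minibatches)
    std::cout << "count=" << mb.count << " sample_offset=" << mb.sample_offset
              << " frame_offset=" << mb.frame_offset << "\n";
  // Prints: count=4 sample_offset=0 frame_offset=0
  //         count=4 sample_offset=2 frame_offset=1   (sample 1 is empty and skipped)
  //         count=2 sample_offset=3 frame_offset=0
  return 0;
}

Each triple is exactly what RunResize hands to PushFramesToBatch for one nvcv::TensorBatch, which is why the operator no longer needs the per-frame frame_idx_ map that the patch removes.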