diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cu b/dali/kernels/signal/wavelet/mother_wavelet.cu index 79bc695a36f..6e1c0279969 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cu +++ b/dali/kernels/signal/wavelet/mother_wavelet.cu @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "dali/kernels/signal/wavelet/mother_wavelet.cuh" #include "dali/core/math_util.h" @@ -42,33 +43,60 @@ template class HaarWavelet; template class HaarWavelet; template -MeyerWavelet::MeyerWavelet(const std::vector &args) { - if (args.size() != 0) { - throw new std::invalid_argument("MeyerWavelet doesn't accept any arguments."); +GaussianWavelet::GaussianWavelet(const std::vector &args) { + if (args.size() != 1) { + throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n."); + } + if (args[0] < 1.0 || args[0] > 8.0) { + throw new std::invalid_argument( + "GaussianWavelet's argument n should be integer from range [1,8]."); } + this->n = args[0]; } template -__device__ T MeyerWavelet::operator()(const T &t) const { - T psi1 = (4/(3*M_PI)*t*std::cos((2*M_PI)/3*t)-1/M_PI*std::sin((4*M_PI)/3*t))/(t-16/9*std::pow(t, 3.0)); - T psi2 = (8/(3*M_PI)*t*std::cos(8*M_PI/3*t)+1/M_PI*std::sin((4*M_PI)/3)*t)/(t-64/9*std::pow(t, 3.0)); - return psi1 + psi2; +__device__ T GaussianWavelet::operator()(const T &t) const { + T expTerm = std::exp(-std::pow(t, 2.0)); + T sqrtTerm = 1.2533141373155001; // std::sqrt(M_PI/2.0) + switch (static_cast(n)) { + case 1: + return -2.0*t*expTerm/std::sqrt(sqrtTerm); + case 2: + return (-4.0*std::pow(t, 2.0)+2.0)*expTerm/std::sqrt(3.0*sqrtTerm); + case 3: + return (8.0*std::pow(t, 3.0)-12.0*t)*expTerm/std::sqrt(15.0*sqrtTerm); + case 4: + return (-48.0*std::pow(t, 2.0)+16.0*std::pow(t, 4.0)+12.0)*expTerm/std::sqrt(105.0*sqrtTerm); + case 5: + return (-32.0*std::pow(t, 5.0)+160.0*std::pow(t, 3.0)-120.0*t)* + expTerm/std::sqrt(945.0*sqrtTerm); + case 6: + return (-64.0*std::pow(t, 6.0)+480.0*std::pow(t, 4.0)-720.0*std::pow(t, 2.0)+120.0)* + expTerm/std::sqrt(10395.0*sqrtTerm); + case 7: + return (128.0*std::pow(t, 7.0)-1344.0*std::pow(t, 5.0)+3360.0*std::pow(t, 3.0)-1680.0*t)* + expTerm/std::sqrt(135135.0*sqrtTerm); + case 8: + return (256.0*std::pow(t, 8.0)-3584.0*std::pow(t, 6.0)+13440.0*std::pow(t, 4.0)-13440.0* + std::pow(t, 2.0)+1680.0)*expTerm/std::sqrt(2027025.0*sqrtTerm); + } } -template class MeyerWavelet; -template class MeyerWavelet; +template class GaussianWavelet; +template class GaussianWavelet; template MexicanHatWavelet::MexicanHatWavelet(const std::vector &args) { if (args.size() != 1) { throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma."); } - this->sigma = *args.begin(); + this->sigma = args[0]; } template __device__ T MexicanHatWavelet::operator()(const T &t) const { - return 2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2*std::pow(sigma, 2.0))); + return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))* + std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0))); } template class MexicanHatWavelet; @@ -79,12 +107,12 @@ MorletWavelet::MorletWavelet(const std::vector &args) { if (args.size() != 1) { throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C."); } - this->C = *args.begin(); + this->C = args[0]; } template __device__ T MorletWavelet::operator()(const T &t) const { - return C * std::exp(-std::pow(t, 2.0)) * std::cos(5 * t); + return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t); } template class MorletWavelet; @@ -92,14 +120,18 @@ template class MorletWavelet; template ShannonWavelet::ShannonWavelet(const std::vector &args) { - if (args.size() != 0) { - throw new std::invalid_argument("ShannonWavelet doesn't accept any arguments."); + if (args.size() != 2) { + throw new std::invalid_argument( + "ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order."); } + this->fb = args[0]; + this->fc = args[1]; } template __device__ T ShannonWavelet::operator()(const T &t) const { - return sinc(t - 0.5) - 2 * sinc(2 * t - 1); + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::sin(t*fb*(T)(M_PI))/(t*fb*(T)(M_PI)); } template class ShannonWavelet; @@ -107,22 +139,24 @@ template class ShannonWavelet; template FbspWavelet::FbspWavelet(const std::vector &args) { - if (args.size() != 0) { - throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); + if (args.size() != 3) { + throw new std::invalid_argument( + "FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order."); } - this->m = *args.begin(); - this->fb = *(args.begin()+1); - this->fc = *(args.begin()+2); + this->m = args[0]; + this->fb = args[1]; + this->fc = args[2]; } template __device__ T FbspWavelet::operator()(const T &t) const { - return std::sqrt(fb)*std::pow(sinc(t/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*t); + auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb); + return t == 0.0 ? res : res*std::pow(std::sin((T)(M_PI)*t*fb/m)/((T)(M_PI)*t*fb/m), m); } template class FbspWavelet; template class FbspWavelet; } // namespace signal -} // namespace kernel +} // namespace kernels } // namespace dali diff --git a/dali/kernels/signal/wavelet/mother_wavelet.cuh b/dali/kernels/signal/wavelet/mother_wavelet.cuh index 52388b97e6e..1e618be69f1 100644 --- a/dali/kernels/signal/wavelet/mother_wavelet.cuh +++ b/dali/kernels/signal/wavelet/mother_wavelet.cuh @@ -15,14 +15,14 @@ #ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ #define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ +#include + #include "dali/core/common.h" #include "dali/core/error_handling.h" #include "dali/core/format.h" #include "dali/core/util.h" #include "dali/kernels/kernel.h" -#include - namespace dali { namespace kernels { namespace signal { @@ -37,37 +37,24 @@ class HaarWavelet { "Data type should be floating point"); public: HaarWavelet() = default; - HaarWavelet(const std::vector &args); + explicit HaarWavelet(const std::vector &args); ~HaarWavelet() = default; __device__ T operator()(const T &t) const; }; -template -class MeyerWavelet { - static_assert(std::is_floating_point::value, - "Data type should be floating point"); - public: - MeyerWavelet() = default; - MeyerWavelet(const std::vector &args); - ~MeyerWavelet() = default; - - __device__ T operator()(const T &t) const; -}; - template class GaussianWavelet { static_assert(std::is_floating_point::value, "Data type should be floating point"); public: GaussianWavelet() = default; - GaussianWavelet(const std::vector &args); + explicit GaussianWavelet(const std::vector &args); ~GaussianWavelet() = default; __device__ T operator()(const T &t) const; - private: - uint8_t N; + T n; }; template @@ -76,7 +63,7 @@ class MexicanHatWavelet { "Data type should be floating point"); public: MexicanHatWavelet() = default; - MexicanHatWavelet(const std::vector &args); + explicit MexicanHatWavelet(const std::vector &args); ~MexicanHatWavelet() = default; __device__ T operator()(const T &t) const; @@ -91,7 +78,7 @@ class MorletWavelet { "Data type should be floating point"); public: MorletWavelet() = default; - MorletWavelet(const std::vector &args); + explicit MorletWavelet(const std::vector &args); ~MorletWavelet() = default; __device__ T operator()(const T &t) const; @@ -106,10 +93,14 @@ class ShannonWavelet { "Data type should be floating point"); public: ShannonWavelet() = default; - ShannonWavelet(const std::vector &args); + explicit ShannonWavelet(const std::vector &args); ~ShannonWavelet() = default; __device__ T operator()(const T &t) const; + + private: + T fb; + T fc; }; template @@ -118,7 +109,7 @@ class FbspWavelet { "Data type should be floating point"); public: FbspWavelet() = default; - FbspWavelet(const std::vector &args); + explicit FbspWavelet(const std::vector &args); ~FbspWavelet() = default; __device__ T operator()(const T &t) const; @@ -130,7 +121,7 @@ class FbspWavelet { }; } // namespace signal -} // namespace kernel +} // namespace kernels } // namespace dali #endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_ diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cu b/dali/kernels/signal/wavelet/wavelet_gpu.cu index d8b07667e00..a5ab81a5dfa 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cu +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cu @@ -42,18 +42,17 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { auto x = std::pow(2.0, a); if (a == 0.0) { shm[b_id] = sample.in[t_id]; - } - else { + } else { shm[b_id] = x * sample.in[t_id]; shm[1024] = std::pow(2.0, a / 2.0); } + __syncthreads(); for (int i = 0; i < sample.size_b; ++i) { const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id; auto b = sample.b[i]; if (b == 0.0) { sample.out[out_id] = wavelet(shm[b_id]); - } - else { + } else { sample.out[out_id] = wavelet(shm[b_id] - b); } if (a != 0.0) { @@ -65,7 +64,8 @@ __global__ void ComputeWavelet(const SampleDesc* sample_data, W wavelet) { // translate input range information to input samples template __global__ void ComputeInputSamples(const SampleDesc* sample_data) { - const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; + const int64_t block_size = blockDim.x * blockDim.y; + const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x; auto& sample = sample_data[blockIdx.y]; if (t_id >= sample.size_in) return; sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate; @@ -106,8 +106,9 @@ DLL_PUBLIC void WaveletGpu::Run(KernelContext &ctx, sample.b = b.tensor_data(i); sample.size_b = b.shape.tensor_size(i); sample.span = span; - sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate); - CUDA_CALL(cudaMalloc(&(sample.in), sizeof(T) * sample.size_in)); + sample.size_in = + std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1; + sample.in = ctx.scratchpad->AllocateGPU(sample.size_in); max_size_in = std::max(max_size_in, sample.size_in); } @@ -128,13 +129,15 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh const TensorListShape<> &b_shape, const WaveletSpan &span) { int N = a_shape.num_samples(); - int in_size = std::ceil((span.end - span.begin) * span.sampling_rate); + int in_size = std::ceil((span.end - span.begin) * span.sampling_rate) + 1; TensorListShape<> out_shape(N, 3); TensorShape<> tshape; for (int i = 0; i < N; i++) { - // output tensor will be 3-dimensional of shape: + // output tensor will be 3-dimensional of shape: // a coeffs x b coeffs x signal samples - tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), b_shape.tensor_shape(i).num_elements(), in_size}); + tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), + b_shape.tensor_shape(i).num_elements(), + in_size}); out_shape.set_tensor_shape(i, tshape); } return out_shape; @@ -142,8 +145,8 @@ TensorListShape<> WaveletGpu::GetOutputShape(const TensorListShape<> &a_sh template class WaveletGpu; template class WaveletGpu; -template class WaveletGpu; -template class WaveletGpu; +template class WaveletGpu; +template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; template class WaveletGpu; diff --git a/dali/kernels/signal/wavelet/wavelet_gpu.cuh b/dali/kernels/signal/wavelet/wavelet_gpu.cuh index 0026ff58c43..45c54e8b251 100644 --- a/dali/kernels/signal/wavelet/wavelet_gpu.cuh +++ b/dali/kernels/signal/wavelet/wavelet_gpu.cuh @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ -#define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ +#ifndef DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ +#define DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ #include #include +#include #include "dali/core/common.h" #include "dali/core/error_handling.h" #include "dali/core/format.h" @@ -26,13 +27,16 @@ // makes sure both tensors have the same number of samples and // that they're one-dimensional -#define ENFORCE_SHAPES(a_shape, b_shape) do { \ -DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(),"a and b tensors must have the same amount of samples."); \ -for (int i = 0; i < a_shape.num_samples(); ++i) { \ - DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, "Tensor of a coeffs should be 1-dimensional."); \ - DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, "Tensor of b coeffs should be 1-dimensional."); \ -} \ -} while(0); +#define ENFORCE_SHAPES(a_shape, b_shape) do { \ +DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \ + "a and b tensors must have the same amount of samples."); \ +for (int i = 0; i < a_shape.num_samples(); ++i) { \ + DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \ + "Tensor of a coeffs should be 1-dimensional."); \ + DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \ + "Tensor of b coeffs should be 1-dimensional."); \ +} \ +} while (0); namespace dali { namespace kernels { @@ -90,6 +94,7 @@ class DLL_PUBLIC WaveletGpu { static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape, const TensorListShape<> &b_shape, const WaveletSpan &span); + private: W wavelet_; }; @@ -98,4 +103,4 @@ class DLL_PUBLIC WaveletGpu { } // namespace kernels } // namespace dali -#endif // DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_ +#endif // DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_ diff --git a/dali/operators/signal/wavelet/wavelet_name.h b/dali/operators/signal/wavelet/wavelet_name.h new file mode 100644 index 00000000000..5e53713bbae --- /dev/null +++ b/dali/operators/signal/wavelet/wavelet_name.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ +#define DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ + +namespace dali { + +/** + * @brief Supported wavelet names + */ +enum DALIWaveletName { + DALI_HAAR = 0, + DALI_GAUS = 1, + DALI_MEXH = 2, + DALI_MORL = 3, + DALI_SHAN = 4, + DALI_FBSP = 5 +}; + +} // namespace dali + +#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_NAME_H_ diff --git a/dali/operators/signal/wavelet/wavelet_run.h b/dali/operators/signal/wavelet/wavelet_run.h index 93b2c1840d0..def362cff75 100644 --- a/dali/operators/signal/wavelet/wavelet_run.h +++ b/dali/operators/signal/wavelet/wavelet_run.h @@ -45,7 +45,7 @@ void RunWaveletKernel(kernels::KernelManager &kmgr, // translates wavelet name to type and runs RunWaveletKernel() for that type template -void RunForName(const std::string &name, +void RunForName(const DALIWaveletName &name, kernels::KernelManager &kmgr, size_t size, size_t device, @@ -55,29 +55,36 @@ void RunForName(const std::string &name, TensorListView &b, const kernels::signal::WaveletSpan &span, const std::vector &args) { - if (name == "HAAR") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MEY") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MEXH") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "MORL") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "SHAN") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else if (name == "FBSP") { - RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); - } - else { + switch (name) { + case DALIWaveletName::DALI_HAAR: + using kernels::signal::HaarWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + case DALIWaveletName::DALI_GAUS: + using kernels::signal::GaussianWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + case DALIWaveletName::DALI_MEXH: + using kernels::signal::MexicanHatWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + case DALIWaveletName::DALI_MORL: + using kernels::signal::MorletWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + case DALIWaveletName::DALI_SHAN: + using kernels::signal::ShannonWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + case DALIWaveletName::DALI_FBSP: + using kernels::signal::FbspWavelet; + RunWaveletKernel(kmgr, size, device, ctx, out, a, b, span, args); + break; + default: throw new std::invalid_argument("Unknown wavelet name."); } } -} // namespace dali +} // namespace dali -#endif // DALI_OPERATORS_SIGNAL_WAVELET_RUN_H_ \ No newline at end of file +#endif // DALI_OPERATORS_SIGNAL_WAVELET_WAVELET_RUN_H_ diff --git a/dali/pipeline/data/types.h b/dali/pipeline/data/types.h index 0efa36e5a19..eed79432c6c 100644 --- a/dali/pipeline/data/types.h +++ b/dali/pipeline/data/types.h @@ -30,6 +30,7 @@ #include "dali/core/float16.h" #include "dali/core/cuda_error.h" #include "dali/core/tensor_layout.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" #ifdef DALI_BUILD_PROTO3 #include "dali/operators/reader/parser/tf_feature.h" @@ -123,6 +124,7 @@ enum DALIDataType : int { DALI_PYTHON_OBJECT = 24, DALI_TENSOR_LAYOUT_VEC = 25, DALI_DATA_TYPE_VEC = 26, + DALI_WAVELET_NAME = 27, DALI_DATATYPE_END = 1000 }; @@ -202,6 +204,9 @@ inline const char *GetBuiltinTypeName(DALIDataType t) { case DALI_INTERP_TYPE: return "DALIInterpType"; break; + case DALI_WAVELET_NAME: + return "DALIWaveletName"; + break; case DALI_TENSOR_LAYOUT: return "TensorLayout"; break; @@ -557,24 +562,25 @@ DLL_PUBLIC inline bool IsValidType(const TypeInfo &type) { DALI_REGISTER_TYPE_IMPL(Type, dtype); // Instantiate some basic types -DALI_REGISTER_TYPE(NoType, DALI_NO_TYPE); -DALI_REGISTER_TYPE(uint8_t, DALI_UINT8); -DALI_REGISTER_TYPE(uint16_t, DALI_UINT16); -DALI_REGISTER_TYPE(uint32_t, DALI_UINT32); -DALI_REGISTER_TYPE(uint64_t, DALI_UINT64); -DALI_REGISTER_TYPE(int8_t, DALI_INT8); -DALI_REGISTER_TYPE(int16_t, DALI_INT16); -DALI_REGISTER_TYPE(int32_t, DALI_INT32); -DALI_REGISTER_TYPE(int64_t, DALI_INT64); -DALI_REGISTER_TYPE(float16, DALI_FLOAT16); -DALI_REGISTER_TYPE(float, DALI_FLOAT); -DALI_REGISTER_TYPE(double, DALI_FLOAT64); -DALI_REGISTER_TYPE(bool, DALI_BOOL); -DALI_REGISTER_TYPE(string, DALI_STRING); -DALI_REGISTER_TYPE(DALIImageType, DALI_IMAGE_TYPE); -DALI_REGISTER_TYPE(DALIDataType, DALI_DATA_TYPE); -DALI_REGISTER_TYPE(DALIInterpType, DALI_INTERP_TYPE); -DALI_REGISTER_TYPE(TensorLayout, DALI_TENSOR_LAYOUT); +DALI_REGISTER_TYPE(NoType, DALI_NO_TYPE); +DALI_REGISTER_TYPE(uint8_t, DALI_UINT8); +DALI_REGISTER_TYPE(uint16_t, DALI_UINT16); +DALI_REGISTER_TYPE(uint32_t, DALI_UINT32); +DALI_REGISTER_TYPE(uint64_t, DALI_UINT64); +DALI_REGISTER_TYPE(int8_t, DALI_INT8); +DALI_REGISTER_TYPE(int16_t, DALI_INT16); +DALI_REGISTER_TYPE(int32_t, DALI_INT32); +DALI_REGISTER_TYPE(int64_t, DALI_INT64); +DALI_REGISTER_TYPE(float16, DALI_FLOAT16); +DALI_REGISTER_TYPE(float, DALI_FLOAT); +DALI_REGISTER_TYPE(double, DALI_FLOAT64); +DALI_REGISTER_TYPE(bool, DALI_BOOL); +DALI_REGISTER_TYPE(string, DALI_STRING); +DALI_REGISTER_TYPE(DALIImageType, DALI_IMAGE_TYPE); +DALI_REGISTER_TYPE(DALIDataType, DALI_DATA_TYPE); +DALI_REGISTER_TYPE(DALIInterpType, DALI_INTERP_TYPE); +DALI_REGISTER_TYPE(DALIWaveletName, DALI_WAVELET_NAME); +DALI_REGISTER_TYPE(TensorLayout, DALI_TENSOR_LAYOUT); #ifdef DALI_BUILD_PROTO3 diff --git a/dali/python/backend_impl.cc b/dali/python/backend_impl.cc index 262c2ee9070..1706532d8c2 100644 --- a/dali/python/backend_impl.cc +++ b/dali/python/backend_impl.cc @@ -27,6 +27,7 @@ #include "dali/operators.h" #include "dali/kernels/kernel.h" #include "dali/operators/reader/parser/tfrecord_parser.h" +#include "dali/operators/signal/wavelet/wavelet_name.h" #include "dali/pipeline/data/copy_to_external.h" #include "dali/pipeline/data/dltensor.h" #include "dali/pipeline/data/tensor.h" @@ -121,8 +122,7 @@ py::dict ArrayInterfaceRepr(Tensor &t) { d["shape"] = py::tuple(py_shape(t)); // tuple of (raw_data_pointer, if_data_is_read_only) tup[0] = py::reinterpret_borrow(PyLong_FromVoidPtr(t.raw_mutable_data())); - // if we make it readonly, it prevents us from sharing memory with PyTorch tensor - tup[1] = false; + tup[1] = true; d["data"] = tup; if (std::is_same::value) { // see https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html @@ -1672,6 +1672,7 @@ PYBIND11_MODULE(backend_impl, m) { .value("IMAGE_TYPE", DALI_IMAGE_TYPE) .value("DATA_TYPE", DALI_DATA_TYPE) .value("INTERP_TYPE", DALI_INTERP_TYPE) + .value("WAVELET_NAME", DALI_WAVELET_NAME) .value("TENSOR_LAYOUT", DALI_TENSOR_LAYOUT) .value("PYTHON_OBJECT", DALI_PYTHON_OBJECT) .value("_TENSOR_LAYOUT_VEC", DALI_TENSOR_LAYOUT_VEC) @@ -1716,6 +1717,16 @@ PYBIND11_MODULE(backend_impl, m) { .value("INTERP_GAUSSIAN", DALI_INTERP_GAUSSIAN) .export_values(); + // DALIWaveletName + py::enum_(types_m, "DALIWaveletName", "Wavelet name\n") + .value("HAAR", DALI_HAAR) + .value("GAUS", DALI_GAUS) + .value("MEXH", DALI_MEXH) + .value("MORL", DALI_MORL) + .value("SHAN", DALI_SHAN) + .value("FBSP", DALI_FBSP) + .export_values(); + // Operator node py::class_(m, "OpNode") .def("instance_name", @@ -1998,6 +2009,7 @@ PYBIND11_MODULE(backend_impl, m) { DALI_OPSPEC_ADDARG(DALIDataType) DALI_OPSPEC_ADDARG(DALIImageType) DALI_OPSPEC_ADDARG(DALIInterpType) + DALI_OPSPEC_ADDARG(DALIWaveletName) #ifdef DALI_BUILD_PROTO3 DALI_OPSPEC_ADDARG(TFFeature) #endif diff --git a/dali/python/nvidia/dali/types.py b/dali/python/nvidia/dali/types.py index f4362fd2249..5d56077ee62 100644 --- a/dali/python/nvidia/dali/types.py +++ b/dali/python/nvidia/dali/types.py @@ -16,7 +16,8 @@ from enum import Enum, unique import re -from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, DALIInterpType +from nvidia.dali.backend_impl.types import DALIDataType, DALIImageType, \ + DALIInterpType, DALIWaveletName # TODO: Handle forwarding imports from backend_impl from nvidia.dali.backend_impl.types import * # noqa: F401, F403 @@ -63,6 +64,8 @@ def _not_implemented(val): DALIDataType.DATA_TYPE: ("nvidia.dali.types.DALIDataType", lambda x: DALIDataType(int(x))), DALIDataType.INTERP_TYPE: ("nvidia.dali.types.DALIInterpType", lambda x: DALIInterpType(int(x))), + DALIDataType.WAVELET_NAME: + ("nvidia.dali.types.DALIWaveletName", lambda x: DALIWaveletName(int(x))), DALIDataType.TENSOR_LAYOUT: (":ref:`layout str`", lambda x: str(x)), DALIDataType.PYTHON_OBJECT: ("object", lambda x: x), DALIDataType._TENSOR_LAYOUT_VEC: