Wavelet computing improvements #5

Merged · 3 commits · Jun 13, 2023
80 changes: 57 additions & 23 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
@@ -13,6 +13,7 @@
// limitations under the License.

#include <cmath>
#include <vector>
#include "dali/kernels/signal/wavelet/mother_wavelet.cuh"
#include "dali/core/math_util.h"

@@ -42,33 +43,60 @@ template class HaarWavelet<float>;
template class HaarWavelet<double>;

template <typename T>
MeyerWavelet<T>::MeyerWavelet(const std::vector<T> &args) {
if (args.size() != 0) {
throw new std::invalid_argument("MeyerWavelet doesn't accept any arguments.");
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
}
if (args[0] < 1.0 || args[0] > 8.0) {
throw new std::invalid_argument(
"GaussianWavelet's argument n should be integer from range [1,8].");
}
this->n = args[0];
}

template <typename T>
__device__ T MeyerWavelet<T>::operator()(const T &t) const {
T psi1 = (4/(3*M_PI)*t*std::cos((2*M_PI)/3*t)-1/M_PI*std::sin((4*M_PI)/3*t))/(t-16/9*std::pow(t, 3.0));
T psi2 = (8/(3*M_PI)*t*std::cos(8*M_PI/3*t)+1/M_PI*std::sin((4*M_PI)/3)*t)/(t-64/9*std::pow(t, 3.0));
return psi1 + psi2;
__device__ T GaussianWavelet<T>::operator()(const T &t) const {
T expTerm = std::exp(-std::pow(t, 2.0));
T sqrtTerm = 1.2533141373155001; // std::sqrt(M_PI/2.0)
switch (static_cast<int>(n)) {
case 1:
return -2.0*t*expTerm/std::sqrt(sqrtTerm);
case 2:
return (-4.0*std::pow(t, 2.0)+2.0)*expTerm/std::sqrt(3.0*sqrtTerm);
case 3:
return (8.0*std::pow(t, 3.0)-12.0*t)*expTerm/std::sqrt(15.0*sqrtTerm);
case 4:
return (-48.0*std::pow(t, 2.0)+16.0*std::pow(t, 4.0)+12.0)*expTerm/std::sqrt(105.0*sqrtTerm);
case 5:
return (-32.0*std::pow(t, 5.0)+160.0*std::pow(t, 3.0)-120.0*t)*
expTerm/std::sqrt(945.0*sqrtTerm);
case 6:
return (-64.0*std::pow(t, 6.0)+480.0*std::pow(t, 4.0)-720.0*std::pow(t, 2.0)+120.0)*
expTerm/std::sqrt(10395.0*sqrtTerm);
case 7:
return (128.0*std::pow(t, 7.0)-1344.0*std::pow(t, 5.0)+3360.0*std::pow(t, 3.0)-1680.0*t)*
expTerm/std::sqrt(135135.0*sqrtTerm);
case 8:
return (256.0*std::pow(t, 8.0)-3584.0*std::pow(t, 6.0)+13440.0*std::pow(t, 4.0)-13440.0*
std::pow(t, 2.0)+1680.0)*expTerm/std::sqrt(2027025.0*sqrtTerm);
}
}

template class MeyerWavelet<float>;
template class MeyerWavelet<double>;
template class GaussianWavelet<float>;
template class GaussianWavelet<double>;

template <typename T>
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
}
this->sigma = *args.begin();
this->sigma = args[0];
}

template <typename T>
__device__ T MexicanHatWavelet<T>::operator()(const T &t) const {
return 2/(std::sqrt(3*sigma)*std::pow(M_PI, 0.25))*(1-std::pow(t/sigma, 2.0))*std::exp(-std::pow(t, 2.0)/(2*std::pow(sigma, 2.0)));
return 2.0/(std::sqrt(3.0*sigma)*std::pow(M_PI, 0.25))*(1.0-std::pow(t/sigma, 2.0))*
std::exp(-std::pow(t, 2.0)/(2.0*std::pow(sigma, 2.0)));
}

template class MexicanHatWavelet<float>;
@@ -79,50 +107,56 @@ MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C.");
}
this->C = *args.begin();
this->C = args[0];
}

template <typename T>
__device__ T MorletWavelet<T>::operator()(const T &t) const {
return C * std::exp(-std::pow(t, 2.0)) * std::cos(5 * t);
return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t);
}

template class MorletWavelet<float>;
template class MorletWavelet<double>;

template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
if (args.size() != 0) {
throw new std::invalid_argument("ShannonWavelet doesn't accept any arguments.");
if (args.size() != 2) {
throw new std::invalid_argument(
"ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
}
this->fb = args[0];
this->fc = args[1];
}

template <typename T>
__device__ T ShannonWavelet<T>::operator()(const T &t) const {
return sinc(t - 0.5) - 2 * sinc(2 * t - 1);
auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb);
return t == 0.0 ? res : res*std::sin(t*fb*(T)(M_PI))/(t*fb*(T)(M_PI));
}

template class ShannonWavelet<float>;
template class ShannonWavelet<double>;

template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
if (args.size() != 0) {
throw new std::invalid_argument("FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
if (args.size() != 3) {
throw new std::invalid_argument(
"FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
}
this->m = *args.begin();
this->fb = *(args.begin()+1);
this->fc = *(args.begin()+2);
this->m = args[0];
this->fb = args[1];
this->fc = args[2];
}

template <typename T>
__device__ T FbspWavelet<T>::operator()(const T &t) const {
return std::sqrt(fb)*std::pow(sinc(t/std::pow(fb, m)), m)*std::exp(2*M_PI*fc*t);
auto res = std::cos((T)(2.0*M_PI)*fc*t)*std::sqrt(fb);
return t == 0.0 ? res : res*std::pow(std::sin((T)(M_PI)*t*fb/m)/((T)(M_PI)*t*fb/m), m);
}

template class FbspWavelet<float>;
template class FbspWavelet<double>;

} // namespace signal
} // namespace kernel
} // namespace kernels
} // namespace dali
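
As a reference for the GaussianWavelet cases above: each branch is, up to sign, the n-th Hermite polynomial times exp(-t^2), normalized by sqrt((2n-1)!! * sqrt(pi/2)). A minimal host-side C++ sketch of the first two cases (the gaus1_ref/gaus2_ref names are illustrative, not part of this PR) can be used to spot-check kernel output at a few points:

// Host-side reference for cases 1 and 2 of GaussianWavelet<T>::operator() above.
// Illustrative only; not part of the DALI code in this PR.
#include <cmath>    // M_PI, as used elsewhere in this PR (may need _USE_MATH_DEFINES on MSVC)
#include <cstdio>

double gaus1_ref(double t) {
  const double sqrt_pi_over_2 = std::sqrt(M_PI / 2.0);  // 1.2533141373155001
  return -2.0 * t * std::exp(-t * t) / std::sqrt(sqrt_pi_over_2);
}

double gaus2_ref(double t) {
  const double sqrt_pi_over_2 = std::sqrt(M_PI / 2.0);
  return (-4.0 * t * t + 2.0) * std::exp(-t * t) / std::sqrt(3.0 * sqrt_pi_over_2);
}

int main() {
  for (double t : {-1.0, 0.0, 1.0}) {
    std::printf("t=%+.1f  gaus1=%+.6f  gaus2=%+.6f\n", t, gaus1_ref(t), gaus2_ref(t));
  }
}
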
37 changes: 14 additions & 23 deletions dali/kernels/signal/wavelet/mother_wavelet.cuh
@@ -15,14 +15,14 @@
#ifndef DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_

#include <vector>

#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/core/util.h"
#include "dali/kernels/kernel.h"

#include <vector>

namespace dali {
namespace kernels {
namespace signal {
@@ -37,37 +37,24 @@ class HaarWavelet {
"Data type should be floating point");
public:
HaarWavelet() = default;
HaarWavelet(const std::vector<T> &args);
explicit HaarWavelet(const std::vector<T> &args);
~HaarWavelet() = default;

__device__ T operator()(const T &t) const;
};

template <typename T>
class MeyerWavelet {
static_assert(std::is_floating_point<T>::value,
"Data type should be floating point");
public:
MeyerWavelet() = default;
MeyerWavelet(const std::vector<T> &args);
~MeyerWavelet() = default;

__device__ T operator()(const T &t) const;
};

template <typename T>
class GaussianWavelet {
static_assert(std::is_floating_point<T>::value,
"Data type should be floating point");
public:
GaussianWavelet() = default;
GaussianWavelet(const std::vector<T> &args);
explicit GaussianWavelet(const std::vector<T> &args);
~GaussianWavelet() = default;

__device__ T operator()(const T &t) const;

private:
uint8_t N;
T n;
};

template <typename T>
Expand All @@ -76,7 +63,7 @@ class MexicanHatWavelet {
"Data type should be floating point");
public:
MexicanHatWavelet() = default;
MexicanHatWavelet(const std::vector<T> &args);
explicit MexicanHatWavelet(const std::vector<T> &args);
~MexicanHatWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -91,7 +78,7 @@
"Data type should be floating point");
public:
MorletWavelet() = default;
MorletWavelet(const std::vector<T> &args);
explicit MorletWavelet(const std::vector<T> &args);
~MorletWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -106,10 +93,14 @@
"Data type should be floating point");
public:
ShannonWavelet() = default;
ShannonWavelet(const std::vector<T> &args);
explicit ShannonWavelet(const std::vector<T> &args);
~ShannonWavelet() = default;

__device__ T operator()(const T &t) const;

private:
T fb;
T fc;
};

template <typename T>
@@ -118,7 +109,7 @@
"Data type should be floating point");
public:
FbspWavelet() = default;
FbspWavelet(const std::vector<T> &args);
explicit FbspWavelet(const std::vector<T> &args);
~FbspWavelet() = default;

__device__ T operator()(const T &t) const;
@@ -130,7 +121,7 @@
};

} // namespace signal
} // namespace kernel
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_MOTHER_WAVELET_CUH_
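
The header above fixes a common contract for every wavelet functor: a host-side constructor that validates a std::vector<T> of arguments, and a __device__ operator() evaluated per time sample. A sketch of what a hypothetical new wavelet following that contract could look like (ExampleWavelet and its single "scale" argument are made up for illustration; the sketch throws std::invalid_argument by value):

// Hypothetical wavelet functor following the contract defined in mother_wavelet.cuh above.
// Not part of this PR; compile as a .cu file.
#include <stdexcept>
#include <type_traits>
#include <vector>

template <typename T>
class ExampleWavelet {
  static_assert(std::is_floating_point<T>::value,
                "Data type should be floating point");
 public:
  ExampleWavelet() = default;
  explicit ExampleWavelet(const std::vector<T> &args) {
    if (args.size() != 1) {
      throw std::invalid_argument("ExampleWavelet accepts exactly one argument - scale.");
    }
    scale = args[0];
  }
  ~ExampleWavelet() = default;

  // Evaluated on the device for every sampled time point t.
  __device__ T operator()(const T &t) const {
    return scale * t;  // placeholder formula
  }

 private:
  T scale;
};
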
27 changes: 15 additions & 12 deletions dali/kernels/signal/wavelet/wavelet_gpu.cu
@@ -42,18 +42,17 @@ __global__ void ComputeWavelet(const SampleDesc<T>* sample_data, W<T> wavelet) {
auto x = std::pow(2.0, a);
if (a == 0.0) {
shm[b_id] = sample.in[t_id];
}
else {
} else {
shm[b_id] = x * sample.in[t_id];
shm[1024] = std::pow(2.0, a / 2.0);
}
__syncthreads();
for (int i = 0; i < sample.size_b; ++i) {
const int64_t out_id = blockIdx.y * sample.size_b * sample.size_in + i * sample.size_in + t_id;
auto b = sample.b[i];
if (b == 0.0) {
sample.out[out_id] = wavelet(shm[b_id]);
}
else {
} else {
sample.out[out_id] = wavelet(shm[b_id] - b);
}
if (a != 0.0) {
@@ -65,7 +64,8 @@ __global__ void ComputeWavelet(const SampleDesc<T>* sample_data, W<T> wavelet) {
// translate input range information to input samples
template <typename T>
__global__ void ComputeInputSamples(const SampleDesc<T>* sample_data) {
const int64_t t_id = blockDim.x * blockDim.y * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x;
const int64_t block_size = blockDim.x * blockDim.y;
const int64_t t_id = block_size * blockIdx.x + threadIdx.y * blockDim.x + threadIdx.x;
auto& sample = sample_data[blockIdx.y];
if (t_id >= sample.size_in) return;
sample.in[t_id] = sample.span.begin + (T)t_id / sample.span.sampling_rate;
@@ -106,8 +106,9 @@ DLL_PUBLIC void WaveletGpu<T, W>::Run(KernelContext &ctx,
sample.b = b.tensor_data(i);
sample.size_b = b.shape.tensor_size(i);
sample.span = span;
sample.size_in = std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate);
CUDA_CALL(cudaMalloc(&(sample.in), sizeof(T) * sample.size_in));
sample.size_in =
std::ceil((sample.span.end - sample.span.begin) * sample.span.sampling_rate) + 1;
sample.in = ctx.scratchpad->AllocateGPU<T>(sample.size_in);
max_size_in = std::max(max_size_in, sample.size_in);
}

@@ -128,22 +129,24 @@ TensorListShape<> WaveletGpu<T, W>::GetOutputShape(const TensorListShape<> &a_sh
const TensorListShape<> &b_shape,
const WaveletSpan<T> &span) {
int N = a_shape.num_samples();
int in_size = std::ceil((span.end - span.begin) * span.sampling_rate);
int in_size = std::ceil((span.end - span.begin) * span.sampling_rate) + 1;
TensorListShape<> out_shape(N, 3);
TensorShape<> tshape;
for (int i = 0; i < N; i++) {
// output tensor will be 3-dimensional of shape:
// output tensor will be 3-dimensional of shape:
// a coeffs x b coeffs x signal samples
tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(), b_shape.tensor_shape(i).num_elements(), in_size});
tshape = TensorShape<>({a_shape.tensor_shape(i).num_elements(),
b_shape.tensor_shape(i).num_elements(),
in_size});
out_shape.set_tensor_shape(i, tshape);
}
return out_shape;
}

template class WaveletGpu<float, HaarWavelet>;
template class WaveletGpu<double, HaarWavelet>;
template class WaveletGpu<float, MeyerWavelet>;
template class WaveletGpu<double, MeyerWavelet>;
template class WaveletGpu<float, GaussianWavelet>;
template class WaveletGpu<double, GaussianWavelet>;
template class WaveletGpu<float, MexicanHatWavelet>;
template class WaveletGpu<double, MexicanHatWavelet>;
template class WaveletGpu<float, MorletWavelet>;
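
To make the shape rule in GetOutputShape above concrete: the span is sampled at sampling_rate with one endpoint sample added, and each (a, b) coefficient pair gets its own row of that length. A plain-C++ sketch of the arithmetic with made-up example values and no DALI types:

// Illustration of the output-shape rule used by WaveletGpu<T, W>::GetOutputShape above.
// Plain C++; values and names are illustrative only.
#include <cmath>
#include <cstdio>

int main() {
  const double begin = -1.0, end = 1.0, sampling_rate = 100.0;  // WaveletSpan fields
  const int num_a = 4, num_b = 8;                               // per-sample a/b coefficient counts

  // in_size = ceil((end - begin) * sampling_rate) + 1
  const int in_size = static_cast<int>(std::ceil((end - begin) * sampling_rate)) + 1;  // 201
  std::printf("output tensor shape: %d x %d x %d\n", num_a, num_b, in_size);  // 4 x 8 x 201
}
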
25 changes: 15 additions & 10 deletions dali/kernels/signal/wavelet/wavelet_gpu.cuh
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#ifndef DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_
#define DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_

#include <memory>
#include <string>
#include <vector>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
@@ -26,13 +27,16 @@

// makes sure both tensors have the same number of samples and
// that they're one-dimensional
#define ENFORCE_SHAPES(a_shape, b_shape) do { \
DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(),"a and b tensors must have the same amount of samples."); \
for (int i = 0; i < a_shape.num_samples(); ++i) { \
DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, "Tensor of a coeffs should be 1-dimensional."); \
DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, "Tensor of b coeffs should be 1-dimensional."); \
} \
} while(0);
#define ENFORCE_SHAPES(a_shape, b_shape) do { \
DALI_ENFORCE(a_shape.num_samples() == b_shape.num_samples(), \
"a and b tensors must have the same amount of samples."); \
for (int i = 0; i < a_shape.num_samples(); ++i) { \
DALI_ENFORCE(a_shape.tensor_shape(i).size() == 1, \
"Tensor of a coeffs should be 1-dimensional."); \
DALI_ENFORCE(b_shape.tensor_shape(i).size() == 1, \
"Tensor of b coeffs should be 1-dimensional."); \
} \
} while (0);

namespace dali {
namespace kernels {
@@ -90,6 +94,7 @@ class DLL_PUBLIC WaveletGpu {
static TensorListShape<> GetOutputShape(const TensorListShape<> &a_shape,
const TensorListShape<> &b_shape,
const WaveletSpan<T> &span);

private:
W<T> wavelet_;
};
@@ -98,4 +103,4 @@ class DLL_PUBLIC WaveletGpu {
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_GPU_CUH_
#endif // DALI_KERNELS_SIGNAL_WAVELET_WAVELET_GPU_CUH_
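
The ENFORCE_SHAPES macro above rejects inputs unless a and b hold the same number of samples and every per-sample shape is 1-dimensional. A stand-alone sketch of that check in plain C++ (shapes_ok and ShapeList are illustrative stand-ins, not DALI's TensorListShape or the macro itself):

// Plain-C++ illustration of the constraint enforced by ENFORCE_SHAPES above.
#include <cstddef>
#include <cstdio>
#include <vector>

using ShapeList = std::vector<std::vector<long long>>;  // one inner vector per sample

bool shapes_ok(const ShapeList &a_shape, const ShapeList &b_shape) {
  if (a_shape.size() != b_shape.size()) return false;  // same number of samples
  for (std::size_t i = 0; i < a_shape.size(); ++i) {
    if (a_shape[i].size() != 1) return false;          // a coeffs must be 1-dimensional
    if (b_shape[i].size() != 1) return false;          // b coeffs must be 1-dimensional
  }
  return true;
}

int main() {
  ShapeList a = {{4}, {6}};   // two samples, 1D each
  ShapeList b = {{8}, {3}};
  ShapeList bad = {{2, 2}};   // 2D shape and mismatched sample count
  std::printf("a/b:   %s\n", shapes_ok(a, b) ? "ok" : "rejected");    // ok
  std::printf("a/bad: %s\n", shapes_ok(a, bad) ? "ok" : "rejected");  // rejected
}
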