Skip to content

Commit

Permalink
Fix wavelet exceptions and expand cwt operator docstr
Browse files Browse the repository at this point in the history
Wavelet constructor exceptions are now being handled correctly.
Morlet wavelet C argument has been removed.
The CWT operator docstr has been updated.
  • Loading branch information
JakubO committed Jul 3, 2023
1 parent 20d5d7e commit 7580be3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 14 deletions.
19 changes: 9 additions & 10 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace signal {
template <typename T>
HaarWavelet<T>::HaarWavelet(const std::vector<T> &args) {
if (args.size() != 0) {
throw new std::invalid_argument("HaarWavelet doesn't accept any arguments.");
throw std::invalid_argument("HaarWavelet doesn't accept any arguments.");
}
}

Expand All @@ -45,10 +45,10 @@ template class HaarWavelet<double>;
template <typename T>
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
}
if (args[0] < 1.0 || args[0] > 8.0) {
throw new std::invalid_argument(
throw std::invalid_argument(
"GaussianWavelet's argument n should be integer from range [1,8].");
}
this->n = args[0];
Expand Down Expand Up @@ -88,7 +88,7 @@ template class GaussianWavelet<double>;
template <typename T>
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
}
this->sigma = args[0];
}
Expand All @@ -104,15 +104,14 @@ template class MexicanHatWavelet<double>;

template <typename T>
MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C.");
if (args.size() != 0) {
throw std::invalid_argument("MorletWavelet doesn't accept any arguments.");
}
this->C = args[0];
}

template <typename T>
__device__ T MorletWavelet<T>::operator()(const T &t) const {
return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t);
return std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t);
}

template class MorletWavelet<float>;
Expand All @@ -121,7 +120,7 @@ template class MorletWavelet<double>;
template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
if (args.size() != 2) {
throw new std::invalid_argument(
throw std::invalid_argument(
"ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
}
this->fb = args[0];
Expand All @@ -140,7 +139,7 @@ template class ShannonWavelet<double>;
template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
if (args.size() != 3) {
throw new std::invalid_argument(
throw std::invalid_argument(
"FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
}
this->m = args[0];
Expand Down
3 changes: 0 additions & 3 deletions dali/kernels/signal/wavelet/mother_wavelet.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class MorletWavelet {
~MorletWavelet() = default;

__device__ T operator()(const T &t) const;

private:
T C;
};

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/signal/wavelet/wavelet_run.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void RunForName(const DALIWaveletName &name,
RunWaveletKernel<T, FbspWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
default:
throw new std::invalid_argument("Unknown wavelet name.");
throw std::invalid_argument("Unknown wavelet name.");
}
}

Expand Down

0 comments on commit 7580be3

Please sign in to comment.