Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wavelet exceptions and expand cwt operator docstr #6

Merged
merged 2 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions dali/kernels/signal/wavelet/mother_wavelet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace signal {
template <typename T>
HaarWavelet<T>::HaarWavelet(const std::vector<T> &args) {
if (args.size() != 0) {
throw new std::invalid_argument("HaarWavelet doesn't accept any arguments.");
throw std::invalid_argument("HaarWavelet doesn't accept any arguments.");
}
}

Expand All @@ -45,10 +45,10 @@ template class HaarWavelet<double>;
template <typename T>
GaussianWavelet<T>::GaussianWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
throw std::invalid_argument("GaussianWavelet accepts exactly one argument - n.");
}
if (args[0] < 1.0 || args[0] > 8.0) {
throw new std::invalid_argument(
throw std::invalid_argument(
"GaussianWavelet's argument n should be integer from range [1,8].");
}
this->n = args[0];
Expand Down Expand Up @@ -88,7 +88,7 @@ template class GaussianWavelet<double>;
template <typename T>
MexicanHatWavelet<T>::MexicanHatWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
throw std::invalid_argument("MexicanHatWavelet accepts exactly one argument - sigma.");
}
this->sigma = args[0];
}
Expand All @@ -104,15 +104,14 @@ template class MexicanHatWavelet<double>;

template <typename T>
MorletWavelet<T>::MorletWavelet(const std::vector<T> &args) {
if (args.size() != 1) {
throw new std::invalid_argument("MorletWavelet accepts exactly 1 argument - C.");
if (args.size() != 0) {
throw std::invalid_argument("MorletWavelet doesn't accept any arguments.");
}
this->C = args[0];
}

template <typename T>
__device__ T MorletWavelet<T>::operator()(const T &t) const {
return C * std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t);
return std::exp(-std::pow(t, 2.0) / 2.0) * std::cos(5.0 * t);
}

template class MorletWavelet<float>;
Expand All @@ -121,7 +120,7 @@ template class MorletWavelet<double>;
template <typename T>
ShannonWavelet<T>::ShannonWavelet(const std::vector<T> &args) {
if (args.size() != 2) {
throw new std::invalid_argument(
throw std::invalid_argument(
"ShannonWavelet accepts exactly 2 arguments -> fb, fc in that order.");
}
this->fb = args[0];
Expand All @@ -140,7 +139,7 @@ template class ShannonWavelet<double>;
template <typename T>
FbspWavelet<T>::FbspWavelet(const std::vector<T> &args) {
if (args.size() != 3) {
throw new std::invalid_argument(
throw std::invalid_argument(
"FbspWavelet accepts exactly 3 arguments -> m, fb, fc in that order.");
}
this->m = args[0];
Expand Down
3 changes: 0 additions & 3 deletions dali/kernels/signal/wavelet/mother_wavelet.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class MorletWavelet {
~MorletWavelet() = default;

__device__ T operator()(const T &t) const;

private:
T C;
};

template <typename T>
Expand Down
27 changes: 25 additions & 2 deletions dali/operators/signal/wavelet/cwt_op_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,31 @@

namespace dali {

DALI_SCHEMA(Cwt).DocStr("by MW").NumInput(1).NumOutput(1).AddArg("a", "costam",
type2id<float>::value);
DALI_SCHEMA(Cwt)
.DocStr(R"(Performs continuous wavelet transform on a 1D signal (for example, audio).

Result values of transform are computed for all specified scales.
Input data is expected to be one channel (shape being ``(nsamples,)``, ``(nsamples, 1)``
) of type float32.)")
.NumInput(1)
.NumOutput(1)
.AddArg("a", R"(List of scale coefficients of type float32.)", DALIDataType::DALI_FLOAT_VEC)
.AddArg("wavelet", R"(Name of mother wavelet. Currently supported wavelets' names are:
- HAAR - Haar wavelet
- GAUS - Gaussian wavelet
- MEXH - Mexican hat wavelet
- MORL - Morlet wavelet
- SHAN - Shannon wavleet
- FBSP - Frequency B-spline wavelet)", DALIDataType::DALI_WAVELET_NAME)
.AddArg("wavelet_args", R"(Additional arguments for mother wavelet. They are passed
as list of float32 values.
- HAAR - none
- GAUS - n (order of derivative)
- MEXH - sigma
- MORL - none
- SHAN - fb (bandwidth parameter > 0), fc (center frequency > 0)
- FBSP - m (order parameter >= 1), fb (bandwidth parameter > 0), fc (center frequency > 0)
)", DALIDataType::DALI_FLOAT_VEC);

template <typename T>
struct CwtImplGPU : public OpImplBase<GPUBackend> {
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/signal/wavelet/wavelet_run.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void RunForName(const DALIWaveletName &name,
RunWaveletKernel<T, FbspWavelet>(kmgr, size, device, ctx, out, a, b, span, args);
break;
default:
throw new std::invalid_argument("Unknown wavelet name.");
throw std::invalid_argument("Unknown wavelet name.");
}
}

Expand Down