Skip to content

Commit

Permalink
Merge pull request #1 from mwdowski/mwdowski
Browse files Browse the repository at this point in the history
Cwt WIP
  • Loading branch information
mwdowski authored May 18, 2023
2 parents 9d6e0b0 + 6bb49f5 commit 359d79c
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 0 deletions.
33 changes: 33 additions & 0 deletions dali/kernels/signal/wavelet/cwt_args.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

template <typename T = float>
struct CwtArgs {
T a;
};

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_
98 changes: 98 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <complex>
#include <vector>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"
#include "dali/kernels/signal/wavelets/cwt_gpu.h"

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

template <typename T>
struct SampleDesc {
const T *in = nullptr;
T *out = nullptr;
int64_t size = 0;
};

template <typename T>
__global__ void CwtKernel(const SampleDesc<T> *sample_data, CwtArgs<T> args) {
const int64_t block_size = blockDim.y * blockDim.x;
const int64_t grid_size = gridDim.x * block_size;
const int sample_idx = blockIdx.y;
const auto sample = sample_data[sample_idx];
const int64_t offset = block_size * blockIdx.x;
const int64_t tid = threadIdx.y * blockDim.x + threadIdx.x;

for (int64_t idx = offset + tid; idx < sample.size; idx += grid_size) {
sample.out[idx] = sample.in[idx] * args.a;
}
}

template <typename T>
CwtGpu<T>::~CwtGpu() = default;

template <typename T>
KernelRequirements CwtGpu<T>::Setup(KernelContext &context,
const InListGPU<T, DynamicDimensions> &in) {
auto out_shape = in.shape;
const size_t num_samples = in.size();
ScratchpadEstimator se;
se.add<mm::memory_kind::host, SampleDesc<T>>(num_samples);
se.add<mm::memory_kind::device, SampleDesc<T>>(num_samples);
KernelRequirements req;
req.scratch_sizes = se.sizes;
req.output_shapes = {out_shape};
return req;
}

template <typename T>
void CwtGpu<T>::Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args) {
auto num_samples = in.size();
auto *sample_data = context.scratchpad->AllocateHost<SampleDesc<T>>(num_samples);

for (int i = 0; i < num_samples; i++) {
auto &sample = sample_data[i];
sample.out = out.tensor_data(i);
sample.in = in.tensor_data(i);
sample.size = volume(in.tensor_shape(i));
assert(sample.size == volume(out.tensor_shape(i)));
}

auto *sample_data_gpu = context.scratchpad->AllocateGPU<SampleDesc<T>>(num_samples);
CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc<T>),
cudaMemcpyHostToDevice, context.gpu.stream));

dim3 block(32, 32);
auto blocks_per_sample = std::max(32, 1024 / num_samples);
dim3 grid(blocks_per_sample, num_samples);
CwtKernel<T><<<grid, block, 0, context.gpu.stream>>>(sample_data_gpu, args);
}

template class CwtGpu<float>;
template class CwtGpu<double>;

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali
50 changes: 50 additions & 0 deletions dali/kernels/signal/wavelet/cwt_gpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_

#include <memory>
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/core/format.h"
#include "dali/core/util.h"
#include "dali/kernels/kernel.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"

namespace dali {
namespace kernels {
namespace signal {
namespace wavelet {

template <typename T = float>
class DLL_PUBLIC CwtGpu {
public:
static_assert(std::is_floating_point<T>::value, "Only floating point types are supported");

DLL_PUBLIC ~CwtGpu();

DLL_PUBLIC KernelRequirements Setup(KernelContext &context,
const InListGPU<T, DynamicDimensions> &in);

DLL_PUBLIC void Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args);
};

} // namespace wavelet
} // namespace signal
} // namespace kernels
} // namespace dali

#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_
1 change: 1 addition & 0 deletions dali/operators/signal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_subdirectory(decibel)
if (BUILD_FFTS)
add_subdirectory(fft)
endif()
add_subdirectory(wavelet)

collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
Expand Down
17 changes: 17 additions & 0 deletions dali/operators/signal/wavelet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE)
65 changes: 65 additions & 0 deletions dali/operators/signal/wavelet/cwt_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_
#define DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_

#include <memory>
#include <vector>
#include "dali/core/common.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"
#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/util/operator_impl_utils.h"

namespace dali {

template <typename Backend>
class Cwt : public Operator<Backend> {
public:
explicit Cwt(const OpSpec &spec) : Operator<Backend>(spec) {
if (!spec.HasArgument("a")) {
DALI_ENFORCE("`a` argument must be provided.");
}
args_.a = spec.GetArgument<float>("a");
}

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
assert(impl_ != nullptr);
return impl_->SetupImpl(output_desc, ws);
}

void RunImpl(Workspace &ws) override {
assert(impl_ != nullptr);
impl_->RunImpl(ws);
}

USE_OPERATOR_MEMBERS();
using Operator<Backend>::RunImpl;

kernels::KernelManager kmgr_;
kernels::signal::wavelets::CwtArgs<float> args_;

std::unique_ptr<OpImplBase<Backend>> impl_;
DALIDataType type_ = DALI_NO_TYPE;
};

} // namespace dali

#endif // DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_
80 changes: 80 additions & 0 deletions dali/operators/signal/wavelet/cwt_op_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <utility>
#include <vector>
#include "dali/core/static_switch.h"
#include "dali/kernels/kernel_manager.h"
#include "dali/kernels/kernel_params.h"
#include "dali/kernels/signal/wavelets/cwt_args.h"
#include "dali/kernels/signal/wavelets/cwt_gpu.h"
#include "dali/operators/signal/wavelets/cwt_op.h"
#include "dali/pipeline/data/views.h"

namespace dali {

DALI_SCHEMA(Cwt).DocStr("by MW").NumInput(1).NumOutput(1).AddArg("a", "costam",
type2id<float>::value);

template <typename T>
struct CwtImplGPU : public OpImplBase<GPUBackend> {
public:
using CwtArgs = kernels::signal::wavelets::CwtArgs<T>;
using CwtKernel = kernels::signal::wavelets::CwtGpu<T>;

explicit CwtImplGPU(CwtArgs args) : args_(std::move(args)) {
kmgr_cwt_.Resize<CwtKernel>(1);
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
const auto &input = ws.Input<GPUBackend>(0);
auto in_view = view<const T>(input);

auto type = type2id<T>::value;

kernels::KernelContext ctx;
ctx.gpu.stream = ws.stream();

auto &req = kmgr_cwt_.Setup<CwtKernel>(0, ctx, in_view);
output_desc.resize(1);
output_desc[0].type = type;
output_desc[0].shape = req.output_shapes[0];

return true;
}

void RunImpl(Workspace &ws) override {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);

auto in_view = view<const T>(input);
auto out_view = view<T>(output);

kernels::KernelContext ctx;
ctx.gpu.stream = ws.stream();

kmgr_cwt_.Run<CwtKernel>(0, ctx, out_view, in_view, args_);
}

private:
CwtArgs args_;
kernels::KernelManager kmgr_cwt_;
std::vector<OutputDesc> cwt_out_desc_;
TensorList<GPUBackend> cwt_out_;
};

DALI_REGISTER_OPERATOR(Cwt, Cwt<GPUBackend>, GPU);

} // namespace dali

0 comments on commit 359d79c

Please sign in to comment.