forked from NVIDIA/DALI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from mwdowski/mwdowski
Cwt WIP
- Loading branch information
Showing
7 changed files
with
344 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ | ||
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
namespace wavelet { | ||
|
||
/**
 * @brief Arguments passed to the CWT kernel.
 *
 * @tparam T type of the stored parameter; `float` unless specified otherwise.
 */
template <class T = float>
struct CwtArgs {
  /// Kernel parameter `a`.
  /// NOTE(review): presumably the wavelet scale - confirm once the transform is implemented.
  T a;
};
|
||
} // namespace wavelet | ||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_ARGS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <cmath> | ||
#include <complex> | ||
#include <vector> | ||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/core/format.h" | ||
#include "dali/kernels/kernel.h" | ||
#include "dali/kernels/signal/wavelets/cwt_args.h" | ||
#include "dali/kernels/signal/wavelets/cwt_gpu.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
namespace wavelet { | ||
|
||
/**
 * @brief Describes one sample of the batch for CwtKernel.
 *
 * A default-constructed descriptor has null pointers and zero size.
 */
template <class T>
struct SampleDesc {
  /// Elements read by the kernel.
  const T *in = nullptr;
  /// Elements written by the kernel.
  T *out = nullptr;
  /// Number of elements the kernel processes for this sample.
  int64_t size = 0;
};
|
||
/**
 * @brief Multiplies every element of each sample by `args.a`.
 *
 * `blockIdx.y` selects the sample descriptor. The threads of all blocks sharing
 * that `blockIdx.y` walk the sample with a step of `gridDim.x * blockDim.x * blockDim.y`
 * elements, so any `gridDim.x` covers the whole sample.
 *
 * @param sample_data array of descriptors, indexed by `blockIdx.y`
 * @param args        kernel arguments; only `a` is read
 */
template <typename T>
__global__ void CwtKernel(const SampleDesc<T> *sample_data, CwtArgs<T> args) {
  const int64_t threads_per_block = blockDim.y * blockDim.x;
  const int64_t stride = gridDim.x * threads_per_block;

  // Local copy of the descriptor of the sample assigned to this row of blocks.
  const int which_sample = blockIdx.y;
  const auto desc = sample_data[which_sample];

  // Flattened position of this thread among the threads handling the sample.
  const int64_t block_start = threads_per_block * blockIdx.x;
  const int64_t lane = threadIdx.y * blockDim.x + threadIdx.x;

  int64_t pos = block_start + lane;
  while (pos < desc.size) {
    desc.out[pos] = desc.in[pos] * args.a;
    pos += stride;
  }
}
|
||
// Destructor declared in the class and defaulted here, out of line.
template <class T>
CwtGpu<T>::~CwtGpu() = default;
|
||
/**
 * @brief Reports the output shape and scratch memory needed to process `in`.
 *
 * The output shape is a copy of the input shape. Scratch space is requested for
 * one SampleDesc per sample in host memory and one per sample in device memory.
 *
 * @param context kernel context (not used here)
 * @param in      input batch
 * @return requirements holding a single output shape and the scratch sizes
 */
template <typename T>
KernelRequirements CwtGpu<T>::Setup(KernelContext &context,
                                    const InListGPU<T, DynamicDimensions> &in) {
  const size_t sample_count = in.size();

  ScratchpadEstimator estimator;
  estimator.add<mm::memory_kind::host, SampleDesc<T>>(sample_count);
  estimator.add<mm::memory_kind::device, SampleDesc<T>>(sample_count);

  auto result_shape = in.shape;

  KernelRequirements requirements;
  requirements.scratch_sizes = estimator.sizes;
  requirements.output_shapes = {result_shape};
  return requirements;
}
|
||
/**
 * @brief Launches CwtKernel over all samples of the batch.
 *
 * Builds one SampleDesc per sample in host scratch memory, copies the
 * descriptors to device scratch memory on `context.gpu.stream` and launches the
 * kernel on the same stream. The call does not synchronize with the stream.
 *
 * @param context kernel context; provides the scratchpad and the CUDA stream
 * @param out     output batch; each sample must have the same volume as its input
 * @param in      input batch
 * @param args    kernel arguments forwarded by value to CwtKernel
 */
template <typename T>
void CwtGpu<T>::Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out,
                    const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args) {
  const int num_samples = in.size();
  // An empty batch has nothing to process; returning here also avoids the
  // division by zero in the grid size computation below and a launch with gridDim.y == 0.
  if (num_samples <= 0)
    return;

  auto *sample_data = context.scratchpad->AllocateHost<SampleDesc<T>>(num_samples);

  for (int i = 0; i < num_samples; i++) {
    auto &sample = sample_data[i];
    sample.out = out.tensor_data(i);
    sample.in = in.tensor_data(i);
    sample.size = volume(in.tensor_shape(i));
    assert(sample.size == volume(out.tensor_shape(i)));
  }

  auto *sample_data_gpu = context.scratchpad->AllocateGPU<SampleDesc<T>>(num_samples);
  CUDA_CALL(cudaMemcpyAsync(sample_data_gpu, sample_data, num_samples * sizeof(SampleDesc<T>),
                            cudaMemcpyHostToDevice, context.gpu.stream));

  // 32x32 threads per block; grid.x = blocks per sample, grid.y = sample index.
  dim3 block(32, 32);
  const int blocks_per_sample = std::max(32, 1024 / num_samples);
  dim3 grid(blocks_per_sample, num_samples);
  CwtKernel<T><<<grid, block, 0, context.gpu.stream>>>(sample_data_gpu, args);
  // A kernel launch does not return a status; an invalid configuration
  // (e.g. grid.y above the device limit) is only visible through the last error.
  CUDA_CALL(cudaGetLastError());
}
|
||
// Explicit instantiations: the member definitions live in this translation unit,
// so the supported element types are instantiated here.
template class CwtGpu<float>;
template class CwtGpu<double>;
|
||
} // namespace wavelet | ||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ | ||
#define DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ | ||
|
||
#include <memory> | ||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/core/format.h" | ||
#include "dali/core/util.h" | ||
#include "dali/kernels/kernel.h" | ||
#include "dali/kernels/signal/wavelets/cwt_args.h" | ||
|
||
namespace dali { | ||
namespace kernels { | ||
namespace signal { | ||
namespace wavelet { | ||
|
||
template <typename T = float> | ||
class DLL_PUBLIC CwtGpu { | ||
public: | ||
static_assert(std::is_floating_point<T>::value, "Only floating point types are supported"); | ||
|
||
DLL_PUBLIC ~CwtGpu(); | ||
|
||
DLL_PUBLIC KernelRequirements Setup(KernelContext &context, | ||
const InListGPU<T, DynamicDimensions> &in); | ||
|
||
DLL_PUBLIC void Run(KernelContext &context, const OutListGPU<T, DynamicDimensions> &out, | ||
const InListGPU<T, DynamicDimensions> &in, const CwtArgs<T> &args); | ||
}; | ||
|
||
} // namespace wavelet | ||
} // namespace signal | ||
} // namespace kernels | ||
} // namespace dali | ||
|
||
#endif // DALI_KERNELS_SIGNAL_WAVELET_CWT_GPU_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Register the files of this directory with the lists kept by the parent scope.
# NOTE(review): the collect_* helpers are defined elsewhere in the project; presumably
# they append this directory's headers, sources and test sources to the named
# variables - confirm against their definitions.
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ | ||
#define DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ | ||
|
||
#include <memory> | ||
#include <vector> | ||
#include "dali/core/common.h" | ||
#include "dali/kernels/signal/wavelets/cwt_args.h" | ||
#include "dali/pipeline/operator/common.h" | ||
#include "dali/pipeline/operator/operator.h" | ||
#include "dali/pipeline/util/operator_impl_utils.h" | ||
|
||
namespace dali { | ||
|
||
template <typename Backend> | ||
class Cwt : public Operator<Backend> { | ||
public: | ||
explicit Cwt(const OpSpec &spec) : Operator<Backend>(spec) { | ||
if (!spec.HasArgument("a")) { | ||
DALI_ENFORCE("`a` argument must be provided."); | ||
} | ||
args_.a = spec.GetArgument<float>("a"); | ||
} | ||
|
||
protected: | ||
bool CanInferOutputs() const override { | ||
return true; | ||
} | ||
|
||
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override { | ||
assert(impl_ != nullptr); | ||
return impl_->SetupImpl(output_desc, ws); | ||
} | ||
|
||
void RunImpl(Workspace &ws) override { | ||
assert(impl_ != nullptr); | ||
impl_->RunImpl(ws); | ||
} | ||
|
||
USE_OPERATOR_MEMBERS(); | ||
using Operator<Backend>::RunImpl; | ||
|
||
kernels::KernelManager kmgr_; | ||
kernels::signal::wavelets::CwtArgs<float> args_; | ||
|
||
std::unique_ptr<OpImplBase<Backend>> impl_; | ||
DALIDataType type_ = DALI_NO_TYPE; | ||
}; | ||
|
||
} // namespace dali | ||
|
||
#endif // DALI_OPERATORS_SIGNAL_WAVELETS_CWT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
#include "dali/core/static_switch.h" | ||
#include "dali/kernels/kernel_manager.h" | ||
#include "dali/kernels/kernel_params.h" | ||
#include "dali/kernels/signal/wavelets/cwt_args.h" | ||
#include "dali/kernels/signal/wavelets/cwt_gpu.h" | ||
#include "dali/operators/signal/wavelets/cwt_op.h" | ||
#include "dali/pipeline/data/views.h" | ||
|
||
namespace dali { | ||
|
||
// Schema of the Cwt operator: one input, one output and a float argument `a`.
// TODO(review): "by MW" and "costam" look like placeholder documentation strings;
// replace them with real descriptions before release.
DALI_SCHEMA(Cwt)
    .DocStr("by MW")
    .NumInput(1)
    .NumOutput(1)
    .AddArg("a", "costam", type2id<float>::value);
|
||
/**
 * @brief GPU implementation of the Cwt operator for element type `T`.
 *
 * Owns a kernel manager with a single CwtGpu<T> instance and the kernel
 * arguments given at construction.
 *
 * @tparam T element type of the processed data
 */
template <typename T>
struct CwtImplGPU : public OpImplBase<GPUBackend> {
 public:
  // The kernel headers declare these types in namespace `wavelet` (singular);
  // `wavelets` is only the directory name.
  using CwtArgs = kernels::signal::wavelet::CwtArgs<T>;
  using CwtKernel = kernels::signal::wavelet::CwtGpu<T>;

  /**
   * @brief Stores the arguments and creates one kernel instance.
   *
   * @param args kernel arguments used for every run
   */
  explicit CwtImplGPU(CwtArgs args) : args_(std::move(args)) {
    kmgr_cwt_.Resize<CwtKernel>(1);
  }

  /**
   * @brief Runs kernel setup on input 0 and describes the single output.
   *
   * The output type is `T` and the output shape is the one reported by the kernel.
   *
   * @param output_desc resized to one entry and filled in
   * @param ws          workspace providing the input and the stream
   * @return always `true`
   */
  bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
    const auto &input = ws.Input<GPUBackend>(0);
    auto in_view = view<const T>(input);

    auto type = type2id<T>::value;

    kernels::KernelContext ctx;
    ctx.gpu.stream = ws.stream();

    auto &req = kmgr_cwt_.Setup<CwtKernel>(0, ctx, in_view);
    output_desc.resize(1);
    output_desc[0].type = type;
    output_desc[0].shape = req.output_shapes[0];

    return true;
  }

  /**
   * @brief Runs the kernel from input 0 to output 0 on the workspace stream.
   *
   * @param ws workspace providing the input, the output and the stream
   */
  void RunImpl(Workspace &ws) override {
    const auto &input = ws.Input<GPUBackend>(0);
    auto &output = ws.Output<GPUBackend>(0);

    auto in_view = view<const T>(input);
    auto out_view = view<T>(output);

    kernels::KernelContext ctx;
    ctx.gpu.stream = ws.stream();

    kmgr_cwt_.Run<CwtKernel>(0, ctx, out_view, in_view, args_);
  }

 private:
  CwtArgs args_;
  kernels::KernelManager kmgr_cwt_;
  // NOTE(review): the two members below are not used by any code in this struct.
  std::vector<OutputDesc> cwt_out_desc_;
  TensorList<GPUBackend> cwt_out_;
};
|
||
// Registers Cwt<GPUBackend> as the GPU implementation of the `Cwt` operator.
DALI_REGISTER_OPERATOR(Cwt, Cwt<GPUBackend>, GPU);
|
||
} // namespace dali |