Skip to content

Commit 8172419

Browse files
stiepancyyever
authored andcommitted
Add Laplacian GPU operator (NVIDIA#3644)
* Add Laplacian GPU operator * Move LaplacianWindows to kernels * Add slow attr to some of Laplacian Python tests Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
1 parent 0e46c92 commit 8172419

21 files changed

+954
-227
lines changed

dali/kernels/imgproc/convolution/laplacian_gpu.cuh

+14-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ namespace laplacian {
4040
* @brief Computes convolution to obtain partial derivative in one of the dimensions.
4141
* Convolution consists of `axes` windows, each to convolve along one dimension of the input data,
4242
* where `deriv_axis`-th window is supposed to compute partial derivative along that axis,
43-
* whereas the remaining windows should perform smoothing. If no smoothing is necessary in a whole
44-
* batch, you can prevent smoothing convolutions form running by passing empty lists for
43+
* whereas the remaining windows should perform smoothing. If no smoothing is necessary in
44+
* the whole batch, you can prevent smoothing convolutions from running by passing empty lists for
4545
* `window_sizes[i]` such that `i != deriv_axis`.
4646
*/
4747
template <typename Out, typename In, typename W, int axes, int deriv_axis, bool has_channels,
@@ -61,6 +61,18 @@ struct PartialDerivGpu {
6161
return false;
6262
}
6363

64+
/**
65+
* @param ctx Kernel context, used for scratch-pad.
66+
* @param in_shape List of input shapes, used by underlying convolution kernels to infer
67+
* intermediate buffer sizes.
68+
* @param window_sizes For given `i`, `window_sizes[i]` contains per-sample window sizes
69+
* to be applied in a convolution along `i-th` axis. The length of
70+
* `window_sizes[deriv_axis]` must be equal to the input batch size.
71+
* Lists for other axes must either all have length equal to the input
72+
* batch size or all be empty. In the latter case, smoothing convolutions
73+
* will be omitted, i.e. only one convolution, along `deriv_axis`
74+
* will be applied.
75+
*/
6476
KernelRequirements Setup(KernelContext& ctx, const TensorListShape<ndim>& in_shape,
6577
const std::array<TensorListShape<1>, axes>& window_sizes) {
6678
has_smoothing_ = HasSmoothing(window_sizes);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef DALI_KERNELS_IMGPROC_CONVOLUTION_LAPLACIAN_WINDOWS_H_
16+
#define DALI_KERNELS_IMGPROC_CONVOLUTION_LAPLACIAN_WINDOWS_H_
17+
18+
#include <vector>
19+
20+
#include "dali/core/tensor_view.h"
21+
22+
namespace dali {
23+
namespace kernels {
24+
25+
template <typename T>
26+
class LaplacianWindows {
27+
public:
28+
explicit LaplacianWindows(int max_window_size) : smooth_computed_{1}, deriv_computed_{1} {
29+
Resize(max_window_size);
30+
*smoothing_views_[0](0) = 1;
31+
*deriv_views_[0](0) = 1;
32+
}
33+
34+
TensorView<StorageCPU, const T, 1> GetDerivWindow(int window_size) {
35+
assert(1 <= window_size && window_size <= max_window_size_);
36+
assert(window_size % 2 == 1);
37+
auto window_idx = window_size / 2;
38+
PrepareSmoothingWindow(window_size - 2);
39+
PrepareDerivWindow(window_size);
40+
return deriv_views_[window_idx];
41+
}
42+
43+
TensorView<StorageCPU, const T, 1> GetSmoothingWindow(int window_size) {
44+
assert(1 <= window_size && window_size <= max_window_size_);
45+
assert(window_size % 2 == 1);
46+
auto window_idx = window_size / 2;
47+
PrepareSmoothingWindow(window_size);
48+
return smoothing_views_[window_idx];
49+
}
50+
51+
private:
52+
/**
53+
* @brief Smoothing window of size 2n + 1 is [1, 2, 1] conv composed with itself n - 1 times
54+
* so that the window has appropriate size: it boils down to computing binominal coefficients:
55+
* (1 + 1) ^ (2n).
56+
*/
57+
inline void PrepareSmoothingWindow(int window_size) {
58+
for (; smooth_computed_ < window_size; smooth_computed_++) {
59+
auto cur_size = smooth_computed_ + 1;
60+
auto cur_idx = cur_size / 2;
61+
auto &prev_view = smoothing_views_[cur_size % 2 == 0 ? cur_idx - 1 : cur_idx];
62+
auto &view = smoothing_views_[cur_idx];
63+
auto prev_val = *prev_view(0);
64+
*view(0) = prev_val;
65+
for (int j = 1; j < cur_size - 1; j++) {
66+
auto val = *prev_view(j);
67+
*view(j) = prev_val + *prev_view(j);
68+
prev_val = val;
69+
}
70+
*view(cur_size - 1) = prev_val;
71+
}
72+
}
73+
74+
/**
75+
* @brief Derivative window of size 3 is [1, -2, 1] (which is [1, -1] composed with itself).
76+
* Bigger windows are convolutions of smoothing windows with [1, -2, 1].
77+
*/
78+
inline void PrepareDerivWindow(int window_size) {
79+
for (; deriv_computed_ < window_size; deriv_computed_++) {
80+
auto cur_size = deriv_computed_ + 1;
81+
auto cur_idx = cur_size / 2;
82+
auto &prev_view = cur_size % 2 == 0 ? smoothing_views_[cur_idx - 1] : deriv_views_[cur_idx];
83+
auto &view = deriv_views_[cur_idx];
84+
auto prev_val = *prev_view(0);
85+
*view(0) = -prev_val;
86+
for (int j = 1; j < cur_size - 1; j++) {
87+
auto val = *prev_view(j);
88+
*view(j) = prev_val - *prev_view(j);
89+
prev_val = val;
90+
}
91+
*view(cur_size - 1) = prev_val;
92+
}
93+
}
94+
95+
void Resize(int max_window_size) {
96+
assert(1 <= max_window_size && max_window_size % 2 == 1);
97+
max_window_size_ = max_window_size;
98+
int num_windows = (max_window_size + 1) / 2;
99+
int num_elements = num_windows * num_windows;
100+
smoothing_memory_.resize(num_elements);
101+
deriv_memory_.resize(num_elements);
102+
smoothing_views_.resize(num_windows);
103+
deriv_views_.resize(num_windows);
104+
int offset = 0;
105+
int window_size = 1;
106+
for (int i = 0; i < num_windows; i++) {
107+
smoothing_views_[i] = {&smoothing_memory_[offset], {window_size}};
108+
deriv_views_[i] = {&deriv_memory_[offset], {window_size}};
109+
offset += window_size;
110+
window_size += 2;
111+
}
112+
}
113+
114+
int smooth_computed_, deriv_computed_;
115+
int max_window_size_;
116+
std::vector<T> smoothing_memory_;
117+
std::vector<T> deriv_memory_;
118+
std::vector<TensorView<StorageCPU, T, 1>> smoothing_views_;
119+
std::vector<TensorView<StorageCPU, T, 1>> deriv_views_;
120+
};
121+
122+
} // namespace kernels
123+
} // namespace dali
124+
125+
#endif // DALI_KERNELS_IMGPROC_CONVOLUTION_LAPLACIAN_WINDOWS_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gtest/gtest.h>
16+
#include <cmath>
17+
#include <opencv2/imgproc.hpp>
18+
19+
#include "dali/kernels/common/utils.h"
20+
#include "dali/test/tensor_test_utils.h"
21+
#include "dali/test/test_tensors.h"
22+
23+
#include "dali/kernels/imgproc/convolution/laplacian_windows.h"
24+
25+
namespace dali {
26+
namespace kernels {
27+
28+
void CheckDerivWindow(int window_size, LaplacianWindows<float> &windows) {
29+
cv::Mat d, s;
30+
cv::getDerivKernels(d, s, 2, 0, window_size, true, CV_32F);
31+
const auto &window_view = windows.GetDerivWindow(window_size);
32+
float d_scale = std::exp2f(-window_size + 3);
33+
for (int i = 0; i < window_size; i++) {
34+
EXPECT_NEAR(window_view.data[i] * d_scale, d.at<float>(i), 1e-6f)
35+
<< "window_size: " << window_size << ", position: " << i;
36+
}
37+
}
38+
39+
void CheckSmoothingWindow(int window_size, LaplacianWindows<float> &windows) {
40+
cv::Mat d, s;
41+
cv::getDerivKernels(d, s, 2, 0, window_size, true, CV_32F);
42+
const auto &window_view = windows.GetSmoothingWindow(window_size);
43+
float s_scale = std::exp2f(-window_size + 1);
44+
for (int i = 0; i < window_size; i++) {
45+
EXPECT_NEAR(window_view.data[i] * s_scale, s.at<float>(i), 1e-6f)
46+
<< "window_size: " << window_size << ", position: " << i;
47+
}
48+
}
49+
50+
// Requests derivative windows in increasing order of sizes: 3, 5, ..., 31.
TEST(LaplacianWindowsTest, GetDerivWindows) {
  const int largest_window = 31;
  LaplacianWindows<float> windows{largest_window};
  int window_size = 3;
  while (window_size <= largest_window) {
    CheckDerivWindow(window_size, windows);
    window_size += 2;
  }
}
57+
58+
// Requests smoothing windows in increasing order of sizes: 3, 5, ..., 31.
TEST(LaplacianWindowsTest, GetSmoothingWindows) {
  const int largest_window = 31;
  LaplacianWindows<float> windows{largest_window};
  int window_size = 3;
  while (window_size <= largest_window) {
    CheckSmoothingWindow(window_size, windows);
    window_size += 2;
  }
}
65+
66+
// Requests windows in decreasing order of sizes: 31, 29, ..., 3, checking both
// the derivative and the smoothing window for each size.
TEST(LaplacianWindowsTest, CheckPrecomputed) {
  const int largest_window = 31;
  LaplacianWindows<float> windows{largest_window};
  int window_size = largest_window;
  while (window_size >= 3) {
    CheckDerivWindow(window_size, windows);
    CheckSmoothingWindow(window_size, windows);
    window_size -= 2;
  }
}
74+
75+
} // namespace kernels
76+
} // namespace dali

dali/operators/image/convolution/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
1+
# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
add_subdirectory(gaussian_blur_gpu)
16+
add_subdirectory(laplacian_gpu)
1617

1718
# Get all the source files and dump test files
1819
collect_headers(DALI_INST_HDRS PARENT_SCOPE)

dali/operators/image/convolution/laplacian.cc

+19-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "dali/core/static_switch.h"
2222
#include "dali/kernels/imgproc/convolution/laplacian_cpu.h"
23+
#include "dali/kernels/imgproc/convolution/laplacian_windows.h"
2324
#include "dali/kernels/kernel_manager.h"
2425
#include "dali/operators/image/convolution/laplacian.h"
2526
#include "dali/pipeline/data/views.h"
@@ -107,8 +108,12 @@ class LaplacianOpCpu : public OpImplBase<CPUBackend> {
107108
using Kernel = kernels::LaplacianCpu<Out, In, float, axes, has_channels>;
108109
static constexpr int ndim = Kernel::ndim;
109110

110-
explicit LaplacianOpCpu(const OpSpec& spec, const DimDesc& dim_desc)
111-
: spec_{spec}, args{spec}, dim_desc_{dim_desc} {}
111+
/**
112+
* @param spec Pointer to a persistent OpSpec object,
113+
* which is guaranteed to be alive for the entire lifetime of this object
114+
*/
115+
explicit LaplacianOpCpu(const OpSpec* spec, const DimDesc& dim_desc)
116+
: spec_{*spec}, args{*spec}, dim_desc_{dim_desc}, lap_windows_{maxWindowSize} {}
112117

113118
bool SetupImpl(std::vector<OutputDesc>& output_desc, const workspace_t<CPUBackend>& ws) override {
114119
const auto& input = ws.template Input<CPUBackend>(0);
@@ -126,7 +131,9 @@ class LaplacianOpCpu : public OpImplBase<CPUBackend> {
126131
const auto& window_sizes = args.GetWindowSizes(sample_idx);
127132
for (int i = 0; i < axes; i++) {
128133
for (int j = 0; j < axes; j++) {
129-
windows_[sample_idx][i][j] = lap_windows_.GetWindow(window_sizes[i][j], i == j);
134+
auto window_size = window_sizes[i][j];
135+
windows_[sample_idx][i][j] = i == j ? lap_windows_.GetDerivWindow(window_size) :
136+
lap_windows_.GetSmoothingWindow(window_size);
130137
}
131138
}
132139
}
@@ -182,19 +189,20 @@ class LaplacianOpCpu : public OpImplBase<CPUBackend> {
182189

183190
LaplacianArgs<axes> args;
184191
DimDesc dim_desc_;
192+
kernels::LaplacianWindows<float> lap_windows_;
185193

186194
kernels::KernelManager kmgr_;
187195
kernels::KernelContext ctx_;
188196

189-
LaplacianWindows<float> lap_windows_;
190197
// windows_[i][j] is a window used in convolution along j-th axis in the i-th partial derivative
191198
std::vector<std::array<std::array<TensorView<StorageCPU, const float, 1>, axes>, axes>> windows_;
192199
};
193200

194-
195201
} // namespace laplacian
196202

197-
bool Laplacian::SetupImpl(std::vector<OutputDesc>& output_desc, const workspace_t<CPUBackend>& ws) {
203+
template <>
204+
bool Laplacian<CPUBackend>::SetupImpl(std::vector<OutputDesc>& output_desc,
205+
const workspace_t<CPUBackend>& ws) {
198206
const auto& input = ws.template Input<CPUBackend>(0);
199207
auto layout = input.GetLayout();
200208
auto dim_desc = ParseAndValidateDim(input.shape().sample_dim(), layout);
@@ -211,10 +219,10 @@ bool Laplacian::SetupImpl(std::vector<OutputDesc>& output_desc, const workspace_
211219
BOOL_SWITCH(dim_desc.is_channel_last(), HasChannels, (
212220
if (dtype == input.type()) {
213221
using LaplacianSame = laplacian::LaplacianOpCpu<In, In, Axes, HasChannels>;
214-
impl_ = std::make_unique<LaplacianSame>(spec_, dim_desc);
222+
impl_ = std::make_unique<LaplacianSame>(&spec_, dim_desc);
215223
} else {
216224
using LaplacianFloat = laplacian::LaplacianOpCpu<float, In, Axes, HasChannels>;
217-
impl_ = std::make_unique<LaplacianFloat>(spec_, dim_desc);
225+
impl_ = std::make_unique<LaplacianFloat>(&spec_, dim_desc);
218226
}
219227
)); // NOLINT
220228
), DALI_FAIL("Axis count out of supported range.")); // NOLINT
@@ -224,10 +232,11 @@ bool Laplacian::SetupImpl(std::vector<OutputDesc>& output_desc, const workspace_
224232
return impl_->SetupImpl(output_desc, ws);
225233
}
226234

227-
void Laplacian::RunImpl(workspace_t<CPUBackend>& ws) {
235+
template <>
236+
void Laplacian<CPUBackend>::RunImpl(workspace_t<CPUBackend>& ws) {
228237
impl_->RunImpl(ws);
229238
}
230239

231-
DALI_REGISTER_OPERATOR(Laplacian, Laplacian, CPU);
240+
DALI_REGISTER_OPERATOR(Laplacian, Laplacian<CPUBackend>, CPU);
232241

233242
} // namespace dali

0 commit comments

Comments
 (0)