Skip to content

Commit

Permalink
Equalization operator
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <ktokarski@nvidia.com>
  • Loading branch information
stiepan committed Jan 18, 2023
1 parent fdc7d8d commit ae107e5
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 0 deletions.
28 changes: 28 additions & 0 deletions dali/operators/image/color/equalize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/operators/image/color/equalize.h"

namespace dali {

DALI_SCHEMA(experimental__Equalize)
.DocStr(R"code(Performs grayscale/per-channel histogram equalization.
The supported inputs are images and videos of uint8_t type.)code")
.NumInput(1)
.NumOutput(1)
.InputLayout(0, {"HW", "HWC", "CHW", "FHW", "FHWC", "FCHW"})
.AllowSequences();

} // namespace dali
80 changes: 80 additions & 0 deletions dali/operators/image/color/equalize.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>

#include "dali/kernels/dynamic_scratchpad.h"
#include "dali/kernels/imgproc/color_manipulation/equalize/equalize.h"
#include "dali/kernels/kernel_manager.h"
#include "dali/operators/image/color/equalize.h"

namespace dali {

namespace equalize {

class EqualizeGPU : public Equalize<GPUBackend> {
using Kernel = kernels::equalize::EqualizeKernelGpu;

public:
explicit EqualizeGPU(const OpSpec &spec) : Equalize<GPUBackend>(spec) {
kmgr_.Resize<Kernel>(1);
}

protected:
void RunImpl(Workspace &ws) override {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
auto input_type = input.type();
auto layout = input.GetLayout();
DALI_ENFORCE(input_type == type2id<uint8_t>::value,
make_string("Unsupported input type for equalize operator: ", input_type,
". Expected input type: `uint8_t`."));
// enforced by the layouts specified in operator schema
assert(layout.size() == 2 || layout.size() == 3);
output.SetLayout(layout);
kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream()));
ctx_.gpu.stream = ws.stream();
ctx_.scratchpad = &scratchpad;
auto out_view = view<uint8_t>(output);
auto in_view = view<const uint8_t>(input);
auto out_shape = GetFlattenedShape(out_view.shape);
auto in_shape = GetFlattenedShape(in_view.shape);
TensorListView<StorageGPU, uint8_t, 2> out_view_flat{out_view.data, out_shape};
TensorListView<StorageGPU, const uint8_t, 2> in_view_flat{in_view.data, in_shape};
kmgr_.Run<Kernel>(0, ctx_, out_view_flat, in_view_flat);
}

template <int ndim>
TensorListShape<2> GetFlattenedShape(TensorListShape<ndim> shape) {
if (shape.sample_dim() == 3) { // has_channels
return collapse_dims<2>(shape, {{0, shape.sample_dim() - 1}});
} else {
int batch_size = shape.num_samples();
TensorListShape<2> ret{batch_size};
for (int sample_idx = 0; sample_idx < batch_size; sample_idx++) {
ret.set_tensor_shape(sample_idx, TensorShape<2>(shape[sample_idx].num_elements(), 1));
}
return ret;
}
}

kernels::KernelManager kmgr_;
kernels::KernelContext ctx_;
};

} // namespace equalize

DALI_REGISTER_OPERATOR(experimental__Equalize, equalize::EqualizeGPU, GPU);

} // namespace dali
68 changes: 68 additions & 0 deletions dali/operators/image/color/equalize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_IMAGE_COLOR_EQUALIZE_H_
#define DALI_OPERATORS_IMAGE_COLOR_EQUALIZE_H_

#include <vector>

#include "dali/pipeline/operator/common.h"
#include "dali/pipeline/operator/op_spec.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/sequence_operator.h"

#define EQUALIZE_SUPPORTED_TYPES (uint8_t)

namespace dali {

template <typename Backend>
class Equalize : public SequenceOperator<Backend> {
public:
explicit Equalize(const OpSpec &spec) : SequenceOperator<Backend>(spec) {}

protected:
DISABLE_COPY_MOVE_ASSIGN(Equalize);
USE_OPERATOR_MEMBERS();

protected:
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
output_desc.resize(1);
output_desc[0].type = ws.GetInputDataType(0);
// output_desc[0].shape is set by ProcessOutputDesc
return true;
}

bool CanInferOutputs() const override {
return true;
}

bool ShouldExpandChannels(int input_idx) const override {
(void)input_idx;
return true;
}

// Overrides unnecessary shape coalescing for video/sequence inputs
bool ProcessOutputDesc(std::vector<OutputDesc> &output_desc, const Workspace &ws,
bool is_inferred) override {
assert(is_inferred && output_desc.size() == 1);
const auto &input = ws.Input<Backend>(0);
// The shape of data stays untouched
output_desc[0].shape = input.shape();
return true;
}
};

} // namespace dali

#endif // DALI_OPERATORS_IMAGE_COLOR_EQUALIZE_H_
69 changes: 69 additions & 0 deletions dali/test/python/operator_2/test_equalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import cv2
import numpy as np

from nvidia.dali import pipeline_def, fn, types
from test_utils import get_dali_extra_path, check_batch
from nose2.tools import params

data_root = get_dali_extra_path()
images_dir = os.path.join(data_root, 'db', 'single', 'jpeg')


def equalize_cv_baseline(img, layout):
if layout == "HW":
return cv2.equalizeHist(img)
if layout == "HWC":
img = img.transpose(2, 0, 1)
axis = 2
else:
assert layout == "CHW", f"{layout}"
axis = 0
return np.stack([cv2.equalizeHist(channel) for channel in img], axis=axis)


@pipeline_def
def images_pipeline(layout):
images, _ = fn.readers.file(name="Reader", file_root=images_dir, prefetch_queue_depth=2,
random_shuffle=True, seed=42)
if layout == "HW":
images = fn.decoders.image(images, device="mixed", output_type=types.GRAY)
images = fn.squeeze(images, axes=2)
else:
assert layout in ["HWC", "CHW"], f"{layout}"
images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
if layout == "CHW":
images = fn.transpose(images, perm=[2, 0, 1])
equalized = fn.experimental.equalize(images)
return equalized, images


@params(("HWC", 1), ("HWC", 32), ("CHW", 1), ("CHW", 7), ("HW", 253), ("HW", 128))
def test_image_pipeline(layout, batch_size):
num_iters = 2

pipe = images_pipeline(num_threads=4, device_id=0, batch_size=batch_size, layout=layout)
pipe.build()

for _ in range(num_iters):
equalized, imgs = pipe.run()
equalized = [np.array(img) for img in equalized.as_cpu()]
imgs = [np.array(img) for img in imgs.as_cpu()]
assert len(equalized) == len(imgs)
baseline = [equalize_cv_baseline(img, layout) for img in imgs]
check_batch(equalized, baseline)
1 change: 1 addition & 0 deletions dali/test/python/test_dali_cpu_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,7 @@ def video_input_pipeline(input_name):
"paste", # not supported for CPU
"experimental.audio_resample", # Alias of audio_resample (already tested)
"experimental.debayer", # not supported for CPU
"experimental.equalize", # not supported for CPU
"experimental.filter", # not supported for CPU
"experimental.inflate", # not supported for CPU
"experimental.remap", # operator is GPU-only
Expand Down
2 changes: 2 additions & 0 deletions dali/test/python/test_dali_variable_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def numba_setup_out_shape(out_shape, in_shape):
(fn.coord_transform, {'T': 2}),
(fn.coord_transform, {'M': .5}),
(fn.crop, {'crop': (5, 5)}),
(fn.experimental.equalize, {'devices': ['gpu']}),
(fn.erase, {'anchor': [0.3], 'axis_names': "H", 'normalized_anchor': True,
'shape': [0.1], 'normalized_shape': True}),
(fn.fast_resize_crop_mirror, {'crop': [5, 5], 'resize_shorter': 10, 'devices': ['cpu']}),
Expand Down Expand Up @@ -1325,6 +1326,7 @@ def get_data(batch_size):
"decoders.image_slice",
"dl_tensor_python_function",
"dump_image",
"experimental.equalize",
"element_extract",
"erase",
"erase",
Expand Down

0 comments on commit ae107e5

Please sign in to comment.