[Fix] Optimize preprocess & fix potential use-after-free (#229)
* hold async data and wait only at the end of the pipeline

* fix use-after-free bugs

* fix wording

* bypass trivial cases for Pad to avoid ppl.cv's bug

* fix pad

* fix lint

* cleanup

* fix DefaultFormatBundle

* fix all cpu preprocess impl

* suppress log

* fix dynamic library build & add comments for SyncOnScopeExit
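For context on the "fix use-after-free bugs" bullet: the preprocessing transforms enqueue asynchronous work on a stream, and intermediate device buffers created inside a transform could be destroyed before that work ran. A minimal illustrative sketch of the hazard (hypothetical names such as SomeTransformImpl and LaunchAsyncKernel, not code from this commit):

// Hypothetical sketch of the bug pattern this commit removes.
Result<Tensor> SomeTransformImpl::Process(const Tensor& host_img) {
  // Enqueues an async host-to-device copy on stream_; `src` owns a device buffer.
  OUTCOME_TRY(auto src, MakeAvailableOnDevice(host_img, device_, stream_));
  auto dst = LaunchAsyncKernel(src, stream_);  // kernel will read src's buffer later
  return dst;
  // `src` is destroyed on return; if nothing waits on stream_ first, the queued
  // kernel can read a freed buffer. The fix: SyncOnScopeExit waits before the
  // intermediates die, and Compose holds pipeline data until one final Wait().
}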
lzhangzz authored Mar 28, 2022
1 parent fee55f3 commit 73cf3b5
Showing 27 changed files with 138 additions and 30 deletions.
8 changes: 6 additions & 2 deletions csrc/core/device.h
@@ -208,9 +208,9 @@ class MMDEPLOY_API Event {

void *GetNative(ErrorCode *ec = nullptr);

bool operator==(const Event &other);
bool operator==(const Event &other) const { return impl_ == other.impl_; }

bool operator!=(const Event &other);
bool operator!=(const Event &other) const { return !(*this == other); }

explicit operator bool() const noexcept { return static_cast<bool>(impl_); }

@@ -285,6 +285,10 @@ class MMDEPLOY_API Buffer {

Allocator GetAllocator() const;

bool operator==(const Buffer &other) const { return impl_ == other.impl_; }

bool operator!=(const Buffer &other) const { return !(*this == other); }

explicit operator bool() const noexcept { return static_cast<bool>(impl_); }

private:
13 changes: 13 additions & 0 deletions csrc/core/utils/device_utils.cpp
@@ -1,6 +1,9 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "device_utils.h"

#include "core/logger.h"

namespace mmdeploy {

Result<Mat> MakeAvailableOnDevice(const Mat& src, const Device& device, Stream& stream) {
@@ -26,4 +29,14 @@ Result<Tensor> MakeAvailableOnDevice(const Tensor& src, const Device& device, St
return dst;
}

SyncOnScopeExit::~SyncOnScopeExit() {
if (active_ && stream_) {
if (!stream_.Wait()) {
MMDEPLOY_ERROR("Implicit stream synchronization failed.");
} else {
MMDEPLOY_DEBUG("Implicit stream synchronization succeeded.");
}
}
}

} // namespace mmdeploy
22 changes: 22 additions & 0 deletions csrc/core/utils/device_utils.h
@@ -3,6 +3,8 @@
#ifndef MMDEPLOY_TRANSFORM_UTILS_H
#define MMDEPLOY_TRANSFORM_UTILS_H

#include <utility>

#include "core/mat.h"
#include "core/tensor.h"

@@ -26,6 +28,26 @@ MMDEPLOY_API Result<Mat> MakeAvailableOnDevice(const Mat& src, const Device& dev
*/
MMDEPLOY_API Result<Tensor> MakeAvailableOnDevice(const Tensor& src, const Device& device,
Stream& stream);


// Calls stream.Wait() on destruction if `active` is true. This class is used to force a wait
// operation before intermediate variables go out of scope. Add the variables in question to the
// trailing parameter pack to ensure correctness (this makes sure SyncOnScopeExit is created later,
// and therefore destructed earlier, than those variables).

class MMDEPLOY_API SyncOnScopeExit {
public:
template <typename... Ts>
explicit SyncOnScopeExit(Stream& stream, bool active, Ts&&...) noexcept
: stream_(stream), active_(active) {}

~SyncOnScopeExit();

private:
bool active_;
Stream& stream_;
};

} // namespace mmdeploy

#endif // MMDEPLOY_TRANSFORM_UTILS_H
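The ordering note above is the crux of SyncOnScopeExit; a minimal usage sketch (illustrative only, mirroring the call sites later in this diff):

// `sync` is declared after `src_tensor`, so it is destructed first: its destructor
// calls stream_.Wait() (when active) while src_tensor's buffer is still alive.
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
// ... enqueue async work that reads src_tensor ...
// scope exit: sync waits on the stream, then src_tensor is released safely.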
5 changes: 5 additions & 0 deletions csrc/core/value.h
@@ -50,6 +50,11 @@ class ValueRef;
template <typename T>
class ValueIterator {
public:
using value_type = Value;
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;
using iterator_category = std::bidirectional_iterator_tag;
using object_iterator_t = typename T::Object::iterator;
using array_iterator_t = typename T::Array::iterator;
ValueIterator() = default;
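These iterator typedefs presumably exist so that ValueIterator satisfies std::iterator_traits and works with standard algorithms, e.g. the std::move / std::back_inserter call in the reworked Compose::Process below (illustrative, not part of the diff):

// Without the typedefs above, std::iterator_traits<ValueIterator<...>> would be
// empty in C++17 and this algorithm call would fail to compile.
std::move(data.begin(), data.end(), std::back_inserter(intermediates));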
2 changes: 2 additions & 0 deletions csrc/preprocess/cpu/crop_impl.cpp
@@ -18,6 +18,8 @@ class CenterCropImpl : public ::mmdeploy::CenterCropImpl {
int right) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

cv::Mat mat = Tensor2CVMat(src_tensor);
cv::Mat cropped_mat = Crop(mat, top, left, bottom, right);
return CVMat2Tensor(cropped_mat);
6 changes: 6 additions & 0 deletions csrc/preprocess/cpu/default_format_bundle_impl.cpp
@@ -14,6 +14,9 @@ class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
protected:
Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto data_type = src_tensor.desc().data_type;

if (img_to_float && data_type == DataType::kINT8) {
@@ -27,6 +30,9 @@ class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {

Result<Tensor> HWC2CHW(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto shape = src_tensor.shape();
int height = shape[1];
int width = shape[2];
3 changes: 3 additions & 0 deletions csrc/preprocess/cpu/image2tensor_impl.cpp
@@ -14,6 +14,9 @@ class ImageToTensorImpl : public ::mmdeploy::ImageToTensorImpl {
protected:
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto shape = src_tensor.shape();
int height = shape[1];
int width = shape[2];
3 changes: 3 additions & 0 deletions csrc/preprocess/cpu/normalize_impl.cpp
@@ -18,6 +18,9 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
protected:
Result<Tensor> NormalizeImage(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto mat = Tensor2CVMat(src_tensor);
auto dst_mat = Normalize(mat, arg_.mean, arg_.std, arg_.to_rgb, true);
return CVMat2Tensor(dst_mat);
3 changes: 3 additions & 0 deletions csrc/preprocess/cpu/pad_impl.cpp
@@ -26,6 +26,9 @@ class PadImpl : public ::mmdeploy::PadImpl {
protected:
Result<Tensor> PadImage(const Tensor& img, const std::array<int, 4>& padding) override {
OUTCOME_TRY(auto tensor, MakeAvailableOnDevice(img, device_, stream_));

SyncOnScopeExit sync(stream_, tensor.buffer() != img.buffer(), tensor);

cv::Mat dst_mat = Pad(Tensor2CVMat(tensor), padding[1], padding[0], padding[3], padding[2],
border_type_, arg_.pad_val);
return CVMat2Tensor(dst_mat);
2 changes: 2 additions & 0 deletions csrc/preprocess/cpu/resize_impl.cpp
@@ -20,6 +20,8 @@ class ResizeImpl final : public ::mmdeploy::ResizeImpl {
Result<Tensor> ResizeImage(const Tensor& img, int dst_h, int dst_w) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(img, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != img.buffer(), src_tensor);

auto src_mat = Tensor2CVMat(src_tensor);
auto dst_mat = Resize(src_mat, dst_h, dst_w, arg_.interpolation);

2 changes: 2 additions & 0 deletions csrc/preprocess/cuda/crop_impl.cpp
@@ -23,6 +23,8 @@ class CenterCropImpl : public ::mmdeploy::CenterCropImpl {
int right) override {
OUTCOME_TRY(auto device_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, device_tensor.buffer() != tensor.buffer(), device_tensor);

auto stream = GetNative<cudaStream_t>(stream_);
auto desc = device_tensor.desc();

6 changes: 6 additions & 0 deletions csrc/preprocess/cuda/default_format_bundle_impl.cpp
@@ -21,6 +21,9 @@ class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl
protected:
Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto data_type = src_tensor.data_type();
auto h = tensor.shape(1);
auto w = tensor.shape(2);
@@ -45,6 +48,9 @@

Result<Tensor> HWC2CHW(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto h = tensor.shape(1);
auto w = tensor.shape(2);
auto c = tensor.shape(3);
3 changes: 3 additions & 0 deletions csrc/preprocess/cuda/image2tensor_impl.cpp
@@ -18,6 +18,9 @@ class ImageToTensorImpl final : public ::mmdeploy::ImageToTensorImpl {
protected:
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto h = tensor.shape(1);
auto w = tensor.shape(2);
auto c = tensor.shape(3);
9 changes: 7 additions & 2 deletions csrc/preprocess/cuda/load_impl.cpp
@@ -25,8 +25,7 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
Tensor Mat2Tensor(const mmdeploy::Mat& mat) {
TensorDesc desc{
mat.buffer().GetDevice(), mat.type(), {1, mat.height(), mat.width(), mat.channel()}, ""};
shared_ptr<void> data(mat.data<void>(), [mat = mat](void* p) {});
return Tensor(desc, data);
return Tensor(std::move(desc), mat.buffer());
}

protected:
Expand All @@ -39,6 +38,9 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {

cudaStream_t stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
Mat dst_mat(src_mat.height(), src_mat.width(), PixelFormat::kBGR, src_mat.type(), device_);

SyncOnScopeExit sync(stream_, true, src_mat, dst_mat);

ppl::common::RetCode ret = 0;

int src_h = src_mat.height();
@@ -97,6 +99,9 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
cudaStream_t stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
Mat dst_mat(src_mat.height(), src_mat.width(), PixelFormat::kGRAYSCALE, src_mat.type(),
device_);

SyncOnScopeExit sync(stream_, true, src_mat, dst_mat);

ppl::common::RetCode ret = 0;

int src_h = src_mat.height();
3 changes: 3 additions & 0 deletions csrc/preprocess/cuda/normalize_impl.cpp
@@ -21,6 +21,9 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
protected:
Result<Tensor> NormalizeImage(const Tensor& tensor) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

auto src_desc = src_tensor.desc();
int h = (int)src_desc.shape[1];
int w = (int)src_desc.shape[2];
2 changes: 2 additions & 0 deletions csrc/preprocess/cuda/pad_impl.cpp
@@ -37,6 +37,8 @@ class PadImpl : public ::mmdeploy::PadImpl {
Result<Tensor> PadImage(const Tensor& img, const array<int, 4>& padding) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(img, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != img.buffer(), src_tensor);

auto desc = src_tensor.desc();
int height = desc.shape[1];
int width = desc.shape[2];
3 changes: 3 additions & 0 deletions csrc/preprocess/cuda/resize_impl.cpp
@@ -23,6 +23,9 @@ class ResizeImpl final : public ::mmdeploy::ResizeImpl {
protected:
Result<Tensor> ResizeImage(const Tensor& tensor, int dst_h, int dst_w) override {
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);

TensorDesc dst_desc{
device_, src_tensor.data_type(), {1, dst_h, dst_w, src_tensor.shape(3)}, src_tensor.name()};
Tensor dst_tensor(dst_desc);
12 changes: 7 additions & 5 deletions csrc/preprocess/transform/compose.cpp
@@ -34,14 +34,16 @@ Compose::Compose(const Value& args, int version) : Transform(args) {

Result<Value> Compose::Process(const Value& input) {
Value output = input;
Value::Array intermediates;
for (auto& transform : transforms_) {
auto t = transform->Process(output);
OUTCOME_TRY(stream_.Wait());
if (!t) {
return t;
OUTCOME_TRY(auto t, transform->Process(output));
if (auto it = t.find("__data__"); it != t.end()) {
std::move(it->begin(), it->end(), std::back_inserter(intermediates));
it->array().clear();
}
output = std::move(t).value();
output = std::move(t);
}
OUTCOME_TRY(stream_.Wait());
return std::move(output);
}

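The "__data__" array drained by the new Compose::Process above is filled by the SetTransformData calls added to the transforms below; its definition is not part of this diff, but a plausible sketch consistent with this usage would be:

// Plausible sketch only; the real SetTransformData is not shown in this excerpt.
inline void SetTransformData(Value& output, const std::string& key, Tensor tensor) {
  output[key] = tensor;                              // result stays visible under its key
  output["__data__"].push_back(std::move(tensor));   // kept alive until the final stream_.Wait()
}

Compose then moves every transform's "__data__" entries into `intermediates`, so all intermediate buffers outlive the single Wait() at the end of the pipeline instead of requiring a wait after every transform.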
3 changes: 2 additions & 1 deletion csrc/preprocess/transform/crop.cpp
@@ -47,7 +47,6 @@ Result<Value> CenterCropImpl::Process(const Value& input) {

auto& shape = dst_tensor.desc().shape;

output[key] = dst_tensor;
output["img_shape"] = {shape[0], shape[1], shape[2], shape[3]};
if (input.contains("scale_factor")) {
// image has been processed by `Resize` transform before.
@@ -61,6 +60,8 @@
output["offset"].push_back(x1);
output["offset"].push_back(y1);
}

SetTransformData(output, key, std::move(dst_tensor));
}

MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
6 changes: 3 additions & 3 deletions csrc/preprocess/transform/default_format_bundle.cpp
@@ -20,9 +20,8 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
Value output = input;
if (input.contains("img")) {
Tensor in_tensor = input["img"].get<Tensor>();
OUTCOME_TRY(output["img"], ToFloat32(in_tensor, arg_.img_to_float));
OUTCOME_TRY(auto tensor, ToFloat32(in_tensor, arg_.img_to_float));

Tensor tensor = output["img"].get<Tensor>();
// set default meta keys
if (!output.contains("pad_shape")) {
for (auto v : tensor.shape()) {
@@ -42,7 +41,8 @@
}

// transpose
OUTCOME_TRY(output["img"], HWC2CHW(tensor));
OUTCOME_TRY(tensor, HWC2CHW(tensor));
SetTransformData(output, "img", std::move(tensor));
}

MMDEPLOY_DEBUG("DefaultFormatBundle output: {}", to_json(output).dump(2));
3 changes: 2 additions & 1 deletion csrc/preprocess/transform/image2tensor.cpp
@@ -26,7 +26,8 @@ Result<Value> ImageToTensorImpl::Process(const Value& input) {
assert(shape.size() == 4);
assert(shape[3] == 1 || shape[3] == 3);

OUTCOME_TRY(output[key], HWC2CHW(src_tensor));
OUTCOME_TRY(auto dst, HWC2CHW(src_tensor));
SetTransformData(output, key, std::move(dst));
} // for key
MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
return output;
4 changes: 3 additions & 1 deletion csrc/preprocess/transform/load.cpp
@@ -44,12 +44,14 @@ Result<Value> PrepareImageImpl::Process(const Value& input) {

OUTCOME_TRY(auto tensor, std::move(res));

output["img"] = tensor;
for (auto v : tensor.desc().shape) {
output["img_shape"].push_back(v);
}
output["ori_shape"] = {1, src_mat.height(), src_mat.width(), src_mat.channel()};
output["img_fields"].push_back("img");

SetTransformData(output, "img", std::move(tensor));

MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));

return output;
3 changes: 2 additions & 1 deletion csrc/preprocess/transform/normalize.cpp
@@ -65,7 +65,8 @@ Result<Value> NormalizeImpl::Process(const Value& input) {
assert(desc.shape.size() == 4 /*n, h, w, c*/);
assert(desc.shape[3] == arg_.mean.size());

OUTCOME_TRY(output[key], NormalizeImage(tensor));
OUTCOME_TRY(auto dst, NormalizeImage(tensor));
SetTransformData(output, key, std::move(dst));

for (auto& v : arg_.mean) {
output["img_norm_cfg"]["mean"].push_back(v);