diff --git a/csrc/core/device.h b/csrc/core/device.h
index f337ad4dcd..c1713162e1 100644
--- a/csrc/core/device.h
+++ b/csrc/core/device.h
@@ -208,9 +208,9 @@ class MMDEPLOY_API Event {
 
   void *GetNative(ErrorCode *ec = nullptr);
 
-  bool operator==(const Event &other);
+  bool operator==(const Event &other) const { return impl_ == other.impl_; }
 
-  bool operator!=(const Event &other);
+  bool operator!=(const Event &other) const { return !(*this == other); }
 
   explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
 
@@ -285,6 +285,10 @@ class MMDEPLOY_API Buffer {
 
   Allocator GetAllocator() const;
 
+  bool operator==(const Buffer &other) const { return impl_ == other.impl_; }
+
+  bool operator!=(const Buffer &other) const { return !(*this == other); }
+
   explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
 
  private:
diff --git a/csrc/core/utils/device_utils.cpp b/csrc/core/utils/device_utils.cpp
index 265d3488e8..561cc2f09f 100644
--- a/csrc/core/utils/device_utils.cpp
+++ b/csrc/core/utils/device_utils.cpp
@@ -1,6 +1,9 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
 #include "device_utils.h"
+
+#include "core/logger.h"
+
 namespace mmdeploy {
 
 Result<Mat> MakeAvailableOnDevice(const Mat& src, const Device& device, Stream& stream) {
@@ -26,4 +29,14 @@ Result<Tensor> MakeAvailableOnDevice(const Tensor& src, const Device& device, St
   return dst;
 }
 
+SyncOnScopeExit::~SyncOnScopeExit() {
+  if (active_ && stream_) {
+    if (!stream_.Wait()) {
+      MMDEPLOY_ERROR("Implicit stream synchronization failed.");
+    } else {
+      MMDEPLOY_DEBUG("Implicit stream synchronization succeeded.");
+    }
+  }
+}
+
 }  // namespace mmdeploy
diff --git a/csrc/core/utils/device_utils.h b/csrc/core/utils/device_utils.h
index 65422664e8..b5822f125a 100644
--- a/csrc/core/utils/device_utils.h
+++ b/csrc/core/utils/device_utils.h
@@ -3,6 +3,8 @@
 #ifndef MMDEPLOY_TRANSFORM_UTILS_H
 #define MMDEPLOY_TRANSFORM_UTILS_H
 
+#include <utility>
+
 #include "core/mat.h"
 #include "core/tensor.h"
 
@@ -26,6 +28,26 @@ MMDEPLOY_API Result<Mat> MakeAvailableOnDevice(const Mat& src, const Device& dev
  */
 MMDEPLOY_API Result<Tensor> MakeAvailableOnDevice(const Tensor& src, const Device& device,
                                                   Stream& stream);
+
+
+// Calls stream.Wait() on destruction if active is true. This class is used to force a wait
+// operation before intermediate variables go out of scope. Add the variables in question to the
+// trailing parameter pack to ensure correctness (this makes sure SyncOnScopeExit is created later,
+// and thus destructed earlier, than those variables).
+class MMDEPLOY_API SyncOnScopeExit {
+ public:
+  template <typename... Ts>
+  explicit SyncOnScopeExit(Stream& stream, bool active, Ts&&...) noexcept
+      : stream_(stream), active_(active) {}
+
+  ~SyncOnScopeExit();
+
+ private:
+  bool active_;
+  Stream& stream_;
+};
+
 }  // namespace mmdeploy
 
 #endif  // MMDEPLOY_TRANSFORM_UTILS_H
diff --git a/csrc/core/value.h b/csrc/core/value.h
index 3241330565..4ee0119ddc 100644
--- a/csrc/core/value.h
+++ b/csrc/core/value.h
@@ -50,6 +50,11 @@ class ValueRef;
 template <typename T>
 class ValueIterator {
  public:
+  using value_type = Value;
+  using difference_type = std::ptrdiff_t;
+  using pointer = value_type*;
+  using reference = value_type&;
+  using iterator_category = std::bidirectional_iterator_tag;
   using object_iterator_t = typename T::Object::iterator;
   using array_iterator_t = typename T::Array::iterator;
   ValueIterator() = default;
diff --git a/csrc/preprocess/cpu/crop_impl.cpp b/csrc/preprocess/cpu/crop_impl.cpp
index 121e1f42c9..740b3416ae 100644
--- a/csrc/preprocess/cpu/crop_impl.cpp
+++ b/csrc/preprocess/cpu/crop_impl.cpp
@@ -18,6 +18,8 @@ class CenterCropImpl : public ::mmdeploy::CenterCropImpl {
                            int right) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
 
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     cv::Mat mat = Tensor2CVMat(src_tensor);
     cv::Mat cropped_mat = Crop(mat, top, left, bottom, right);
     return CVMat2Tensor(cropped_mat);
diff --git a/csrc/preprocess/cpu/default_format_bundle_impl.cpp b/csrc/preprocess/cpu/default_format_bundle_impl.cpp
index efee3cc47b..37b62617ae 100644
--- a/csrc/preprocess/cpu/default_format_bundle_impl.cpp
+++ b/csrc/preprocess/cpu/default_format_bundle_impl.cpp
@@ -14,6 +14,9 @@ class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
  protected:
   Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto data_type = src_tensor.desc().data_type;
 
     if (img_to_float && data_type == DataType::kINT8) {
@@ -27,6 +30,9 @@ class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
 
   Result<Tensor> HWC2CHW(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto shape = src_tensor.shape();
     int height = shape[1];
     int width = shape[2];
diff --git a/csrc/preprocess/cpu/image2tensor_impl.cpp b/csrc/preprocess/cpu/image2tensor_impl.cpp
index ddd3e339b2..3e497f56d7 100644
--- a/csrc/preprocess/cpu/image2tensor_impl.cpp
+++ b/csrc/preprocess/cpu/image2tensor_impl.cpp
@@ -14,6 +14,9 @@ class ImageToTensorImpl : public ::mmdeploy::ImageToTensorImpl {
  protected:
   Result<Tensor> HWC2CHW(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto shape = src_tensor.shape();
     int height = shape[1];
     int width = shape[2];
diff --git a/csrc/preprocess/cpu/normalize_impl.cpp b/csrc/preprocess/cpu/normalize_impl.cpp
index 1be6d53c9f..3e56c111a8 100644
--- a/csrc/preprocess/cpu/normalize_impl.cpp
+++ b/csrc/preprocess/cpu/normalize_impl.cpp
@@ -18,6 +18,9 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
  protected:
   Result<Tensor> NormalizeImage(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto mat = Tensor2CVMat(src_tensor);
     auto dst_mat = Normalize(mat, arg_.mean, arg_.std, arg_.to_rgb, true);
     return CVMat2Tensor(dst_mat);
diff --git a/csrc/preprocess/cpu/pad_impl.cpp b/csrc/preprocess/cpu/pad_impl.cpp
index c75ba4139b..5bf2ce9cc3 100644
--- a/csrc/preprocess/cpu/pad_impl.cpp
+++ b/csrc/preprocess/cpu/pad_impl.cpp
@@ -26,6 +26,9 @@ class PadImpl : public ::mmdeploy::PadImpl {
  protected:
   Result<Tensor> PadImage(const Tensor& img, const std::array<int, 4>& padding) override {
     OUTCOME_TRY(auto tensor, MakeAvailableOnDevice(img, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, tensor.buffer() != img.buffer(), tensor);
+
     cv::Mat dst_mat = Pad(Tensor2CVMat(tensor), padding[1], padding[0], padding[3], padding[2],
                           border_type_, arg_.pad_val);
     return CVMat2Tensor(dst_mat);
diff --git a/csrc/preprocess/cpu/resize_impl.cpp b/csrc/preprocess/cpu/resize_impl.cpp
index 3079197f7a..9e03e5b792 100644
--- a/csrc/preprocess/cpu/resize_impl.cpp
+++ b/csrc/preprocess/cpu/resize_impl.cpp
@@ -20,6 +20,8 @@ class ResizeImpl final : public ::mmdeploy::ResizeImpl {
   Result<Tensor> ResizeImage(const Tensor& img, int dst_h, int dst_w) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(img, device_, stream_));
 
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != img.buffer(), src_tensor);
+
     auto src_mat = Tensor2CVMat(src_tensor);
     auto dst_mat = Resize(src_mat, dst_h, dst_w, arg_.interpolation);
 
diff --git a/csrc/preprocess/cuda/crop_impl.cpp b/csrc/preprocess/cuda/crop_impl.cpp
index eb6f64f835..e33a3e3c69 100644
--- a/csrc/preprocess/cuda/crop_impl.cpp
+++ b/csrc/preprocess/cuda/crop_impl.cpp
@@ -23,6 +23,8 @@ class CenterCropImpl : public ::mmdeploy::CenterCropImpl {
                            int right) override {
     OUTCOME_TRY(auto device_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
 
+    SyncOnScopeExit sync(stream_, device_tensor.buffer() != tensor.buffer(), device_tensor);
+
     auto stream = GetNative<cudaStream_t>(stream_);
     auto desc = device_tensor.desc();
diff --git a/csrc/preprocess/cuda/default_format_bundle_impl.cpp b/csrc/preprocess/cuda/default_format_bundle_impl.cpp
index 2091d4ac8d..0cc143a2ad 100644
--- a/csrc/preprocess/cuda/default_format_bundle_impl.cpp
+++ b/csrc/preprocess/cuda/default_format_bundle_impl.cpp
@@ -21,6 +21,9 @@ class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl
  protected:
   Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto data_type = src_tensor.data_type();
     auto h = tensor.shape(1);
     auto w = tensor.shape(2);
@@ -45,6 +48,9 @@ class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl
 
   Result<Tensor> HWC2CHW(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto h = tensor.shape(1);
     auto w = tensor.shape(2);
     auto c = tensor.shape(3);
diff --git a/csrc/preprocess/cuda/image2tensor_impl.cpp b/csrc/preprocess/cuda/image2tensor_impl.cpp
index a57219dc6c..28876a2254 100644
--- a/csrc/preprocess/cuda/image2tensor_impl.cpp
+++ b/csrc/preprocess/cuda/image2tensor_impl.cpp
@@ -18,6 +18,9 @@ class ImageToTensorImpl final : public ::mmdeploy::ImageToTensorImpl {
  protected:
   Result<Tensor> HWC2CHW(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto h = tensor.shape(1);
     auto w = tensor.shape(2);
     auto c = tensor.shape(3);
diff --git a/csrc/preprocess/cuda/load_impl.cpp b/csrc/preprocess/cuda/load_impl.cpp
index e7ffe506d2..cec80b401c 100644
--- a/csrc/preprocess/cuda/load_impl.cpp
+++ b/csrc/preprocess/cuda/load_impl.cpp
@@ -25,8 +25,7 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
   Tensor Mat2Tensor(const mmdeploy::Mat& mat) {
     TensorDesc desc{
         mat.buffer().GetDevice(), mat.type(), {1, mat.height(), mat.width(), mat.channel()}, ""};
-    shared_ptr<void> data(mat.data<void>(), [mat = mat](void* p) {});
-    return Tensor(desc, data);
+    return Tensor(std::move(desc), mat.buffer());
   }
 
  protected:
@@ -39,6 +38,9 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
     cudaStream_t stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
     Mat dst_mat(src_mat.height(), src_mat.width(), PixelFormat::kBGR, src_mat.type(), device_);
+
+    SyncOnScopeExit sync(stream_, true, src_mat, dst_mat);
+
     ppl::common::RetCode ret = 0;
 
     int src_h = src_mat.height();
@@ -97,6 +99,9 @@ class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
     cudaStream_t stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
     Mat dst_mat(src_mat.height(), src_mat.width(), PixelFormat::kGRAYSCALE, src_mat.type(),
                 device_);
+
+    SyncOnScopeExit sync(stream_, true, src_mat, dst_mat);
+
     ppl::common::RetCode ret = 0;
 
     int src_h = src_mat.height();
diff --git a/csrc/preprocess/cuda/normalize_impl.cpp b/csrc/preprocess/cuda/normalize_impl.cpp
index 48e6647990..c337305670 100644
--- a/csrc/preprocess/cuda/normalize_impl.cpp
+++ b/csrc/preprocess/cuda/normalize_impl.cpp
@@ -21,6 +21,9 @@ class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
  protected:
   Result<Tensor> NormalizeImage(const Tensor& tensor) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     auto src_desc = src_tensor.desc();
     int h = (int)src_desc.shape[1];
     int w = (int)src_desc.shape[2];
diff --git a/csrc/preprocess/cuda/pad_impl.cpp b/csrc/preprocess/cuda/pad_impl.cpp
index 77781c0485..511b95b2f3 100644
--- a/csrc/preprocess/cuda/pad_impl.cpp
+++ b/csrc/preprocess/cuda/pad_impl.cpp
@@ -37,6 +37,8 @@ class PadImpl : public ::mmdeploy::PadImpl {
   Result<Tensor> PadImage(const Tensor& img, const array<int, 4>& padding) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(img, device_, stream_));
 
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != img.buffer(), src_tensor);
+
     auto desc = src_tensor.desc();
     int height = desc.shape[1];
     int width = desc.shape[2];
diff --git a/csrc/preprocess/cuda/resize_impl.cpp b/csrc/preprocess/cuda/resize_impl.cpp
index 8a37664801..2c6d07c0d6 100644
--- a/csrc/preprocess/cuda/resize_impl.cpp
+++ b/csrc/preprocess/cuda/resize_impl.cpp
@@ -23,6 +23,9 @@ class ResizeImpl final : public ::mmdeploy::ResizeImpl {
  protected:
   Result<Tensor> ResizeImage(const Tensor& tensor, int dst_h, int dst_w) override {
     OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
+
+    SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor);
+
     TensorDesc dst_desc{
         device_, src_tensor.data_type(), {1, dst_h, dst_w, src_tensor.shape(3)}, src_tensor.name()};
     Tensor dst_tensor(dst_desc);
diff --git a/csrc/preprocess/transform/compose.cpp b/csrc/preprocess/transform/compose.cpp
index a52b6848a1..be147e20d1 100644
--- a/csrc/preprocess/transform/compose.cpp
+++ b/csrc/preprocess/transform/compose.cpp
@@ -34,14 +34,16 @@ Compose::Compose(const Value& args, int version) : Transform(args) {
 
 Result<Value> Compose::Process(const Value& input) {
   Value output = input;
+  Value::Array intermediates;
   for (auto& transform : transforms_) {
-    auto t = transform->Process(output);
-    OUTCOME_TRY(stream_.Wait());
-    if (!t) {
-      return t;
+    OUTCOME_TRY(auto t, transform->Process(output));
+    if (auto it = t.find("__data__"); it != t.end()) {
+      std::move(it->begin(), it->end(), std::back_inserter(intermediates));
+      it->array().clear();
     }
-    output = std::move(t).value();
+    output = std::move(t);
   }
+  OUTCOME_TRY(stream_.Wait());
   return std::move(output);
 }
diff --git a/csrc/preprocess/transform/crop.cpp b/csrc/preprocess/transform/crop.cpp
index 1ea8867cab..39badbb05b 100644
--- a/csrc/preprocess/transform/crop.cpp
+++ b/csrc/preprocess/transform/crop.cpp
@@ -47,7 +47,6 @@ Result<Value> CenterCropImpl::Process(const Value& input) {
 
     auto& shape = dst_tensor.desc().shape;
 
-    output[key] = dst_tensor;
     output["img_shape"] = {shape[0], shape[1], shape[2], shape[3]};
     if (input.contains("scale_factor")) {
       // image has been processed by `Resize` transform before.
@@ -61,6 +60,8 @@ Result<Value> CenterCropImpl::Process(const Value& input) {
       output["offset"].push_back(x1);
       output["offset"].push_back(y1);
     }
+
+    SetTransformData(output, key, std::move(dst_tensor));
   }
 
   MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
diff --git a/csrc/preprocess/transform/default_format_bundle.cpp b/csrc/preprocess/transform/default_format_bundle.cpp
index 7dbcbfa736..abe6b01dd4 100644
--- a/csrc/preprocess/transform/default_format_bundle.cpp
+++ b/csrc/preprocess/transform/default_format_bundle.cpp
@@ -20,9 +20,8 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
   Value output = input;
   if (input.contains("img")) {
     Tensor in_tensor = input["img"].get<Tensor>();
-    OUTCOME_TRY(output["img"], ToFloat32(in_tensor, arg_.img_to_float));
+    OUTCOME_TRY(auto tensor, ToFloat32(in_tensor, arg_.img_to_float));
 
-    Tensor tensor = output["img"].get<Tensor>();
     // set default meta keys
     if (!output.contains("pad_shape")) {
       for (auto v : tensor.shape()) {
@@ -42,7 +41,8 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
     }
 
     // transpose
-    OUTCOME_TRY(output["img"], HWC2CHW(tensor));
+    OUTCOME_TRY(tensor, HWC2CHW(tensor));
+    SetTransformData(output, "img", std::move(tensor));
   }
 
   MMDEPLOY_DEBUG("DefaultFormatBundle output: {}", to_json(output).dump(2));
diff --git a/csrc/preprocess/transform/image2tensor.cpp b/csrc/preprocess/transform/image2tensor.cpp
index e2ccd3bb5d..163a73a2f4 100644
--- a/csrc/preprocess/transform/image2tensor.cpp
+++ b/csrc/preprocess/transform/image2tensor.cpp
@@ -26,7 +26,8 @@ Result<Value> ImageToTensorImpl::Process(const Value& input) {
     assert(shape.size() == 4);
     assert(shape[3] == 1 || shape[3] == 3);
 
-    OUTCOME_TRY(output[key], HWC2CHW(src_tensor));
+    OUTCOME_TRY(auto dst, HWC2CHW(src_tensor));
+    SetTransformData(output, key, std::move(dst));
   }  // for key
 
   MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
   return output;
diff --git a/csrc/preprocess/transform/load.cpp b/csrc/preprocess/transform/load.cpp
index 462c70a837..2d9a647975 100644
--- a/csrc/preprocess/transform/load.cpp
+++ b/csrc/preprocess/transform/load.cpp
@@ -44,12 +44,14 @@ Result<Value> PrepareImageImpl::Process(const Value& input) {
 
   OUTCOME_TRY(auto tensor, std::move(res));
 
-  output["img"] = tensor;
   for (auto v : tensor.desc().shape) {
     output["img_shape"].push_back(v);
   }
   output["ori_shape"] = {1, src_mat.height(), src_mat.width(), src_mat.channel()};
   output["img_fields"].push_back("img");
+
+  SetTransformData(output, "img", std::move(tensor));
+
   MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
 
   return output;
diff --git a/csrc/preprocess/transform/normalize.cpp b/csrc/preprocess/transform/normalize.cpp
index 7fc9c2ad31..7318a8ac31 100644
--- a/csrc/preprocess/transform/normalize.cpp
+++ b/csrc/preprocess/transform/normalize.cpp
@@ -65,7 +65,8 @@ Result<Value> NormalizeImpl::Process(const Value& input) {
     assert(desc.shape.size() == 4 /*n, h, w, c*/);
     assert(desc.shape[3] == arg_.mean.size());
 
-    OUTCOME_TRY(output[key], NormalizeImage(tensor));
+    OUTCOME_TRY(auto dst, NormalizeImage(tensor));
+    SetTransformData(output, key, std::move(dst));
 
     for (auto& v : arg_.mean) {
       output["img_norm_cfg"]["mean"].push_back(v);
diff --git a/csrc/preprocess/transform/pad.cpp b/csrc/preprocess/transform/pad.cpp
index 4e0ac60583..293348347a 100644
--- a/csrc/preprocess/transform/pad.cpp
+++ b/csrc/preprocess/transform/pad.cpp
@@ -48,25 +48,23 @@ Result<Value> PadImpl::Process(const Value& input) {
     assert(tensor.desc().shape[0] == 1);
     assert(tensor.desc().shape[3] == 3 || tensor.desc().shape[3] == 1);
 
-    int height = tensor.desc().shape[1];
-    int width = tensor.desc().shape[2];
+    int height = tensor.shape(1);
+    int width = tensor.shape(2);
+
+    std::array<int, 4> padding{0, 0, 0, 0};
     if (arg_.pad_to_square) {
-      int max_size = std::max(tensor.desc().shape[1], tensor.desc().shape[2]);
-      std::array<int, 4> padding{0, 0, max_size - width, max_size - height};
-      OUTCOME_TRY(output_tensor, PadImage(tensor, padding));
+      int max_size = std::max(tensor.shape(1), tensor.shape(2));
+      padding = {0, 0, max_size - width, max_size - height};
       output["pad_fixed_size"].push_back(max_size);
       output["pad_fixed_size"].push_back(max_size);
     } else if (arg_.size[0] != 0 && arg_.size[1] != 0) {
-      output_tensor = tensor;
-      std::array<int, 4> padding{0, 0, arg_.size[1] - width, arg_.size[0] - height};
-      OUTCOME_TRY(output_tensor, PadImage(tensor, padding));
+      padding = {0, 0, arg_.size[1] - width, arg_.size[0] - height};
       output["pad_fixed_size"].push_back(arg_.size[0]);
       output["pad_fixed_size"].push_back(arg_.size[1]);
     } else if (arg_.size_divisor != 1) {
       auto pad_h = (height + arg_.size_divisor - 1) / arg_.size_divisor * arg_.size_divisor;
       auto pad_w = (width + arg_.size_divisor - 1) / arg_.size_divisor * arg_.size_divisor;
-      std::array<int, 4> padding{0, 0, pad_w - width, pad_h - height};
-      OUTCOME_TRY(output_tensor, PadImage(tensor, padding));
+      padding = {0, 0, pad_w - width, pad_h - height};
       output["pad_size_divisor"] = arg_.size_divisor;
       output["pad_fixed_size"].push_back(pad_h);
       output["pad_fixed_size"].push_back(pad_w);
@@ -75,10 +73,18 @@ Result<Value> PadImpl::Process(const Value& input) {
       output["pad_fixed_size"].push_back(height);
       output["pad_fixed_size"].push_back(width);
     }
-    output[key] = output_tensor;
-    for (auto& v : output_tensor.desc().shape) {
+
+    if (std::count(begin(padding), end(padding), 0) != 4) {
+      OUTCOME_TRY(output_tensor, PadImage(tensor, padding));
+    } else {
+      output_tensor = tensor;
+    }
+
+    for (auto& v : output_tensor.shape()) {
       output["pad_shape"].push_back(v);
     }
+
+    SetTransformData(output, key, std::move(output_tensor));
   }
 
   MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
diff --git a/csrc/preprocess/transform/resize.cpp b/csrc/preprocess/transform/resize.cpp
index 98398e3dce..34292331a9 100644
--- a/csrc/preprocess/transform/resize.cpp
+++ b/csrc/preprocess/transform/resize.cpp
@@ -106,9 +106,9 @@ Result<Value> ResizeImpl::Process(const Value& input) {
     auto h_scale = dst_h * 1.0 / h;
     output["scale_factor"] = {w_scale, h_scale, w_scale, h_scale};
     output["img_shape"] = {1, dst_h, dst_w, desc.shape[3]};
-    // output["pad_shape"] = output["img_shape"];
     output["keep_ratio"] = arg_.keep_ratio;
-    output[key] = dst_img;
+
+    SetTransformData(output, key, std::move(dst_img));
   }
 
   MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
diff --git a/csrc/preprocess/transform/transform.h b/csrc/preprocess/transform/transform.h
index ba96e91a14..64456d41e3 100644
--- a/csrc/preprocess/transform/transform.h
+++ b/csrc/preprocess/transform/transform.h
@@ -65,6 +65,12 @@ class MMDEPLOY_API Transform : public Module {
   std::vector<std::string> candidate_platforms_;
 };
 
+template <typename Key, typename Val>
+void SetTransformData(Value& dst, Key&& key, Val val) {
+  dst[std::forward<Key>(key)] = val;
+  dst["__data__"].push_back(std::move(val));
+}
+
 MMDEPLOY_DECLARE_REGISTRY(Transform);
 
 }  // namespace mmdeploy
diff --git a/mmdeploy/backend/sdk/wrapper.py b/mmdeploy/backend/sdk/wrapper.py
index 8b09d53673..338e86641c 100644
--- a/mmdeploy/backend/sdk/wrapper.py
+++ b/mmdeploy/backend/sdk/wrapper.py
@@ -2,6 +2,7 @@
 import mmdeploy_python as c_api
 
 from mmdeploy.utils import Backend
+from mmdeploy.utils.timer import TimeCounter
 from ..base import BACKEND_WRAPPER, BaseWrapper
 
 
@@ -14,6 +15,7 @@ def __init__(self, model_file, task_name, device):
         # TODO: get device id somewhere
         self.handle = creator(model_file, device, 0)
 
+    @TimeCounter.count_time()
     def invoke(self, imgs):
         return self.handle(imgs)
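
For reference, the pattern this patch introduces inside each kernel implementation looks roughly like the sketch below. It is illustrative only, not part of the patch: `ExampleImpl` and `ApplyKernel` are hypothetical stand-ins, while `MakeAvailableOnDevice`, `SyncOnScopeExit`, `SetTransformData`, and `OUTCOME_TRY` are the names used in the diff above.

    // Illustrative sketch -- mirrors the pattern added by this patch, not actual repo code.
    Result<Tensor> ExampleImpl::Apply(const Tensor& img, Value& output) {
      // May enqueue an asynchronous host-to-device copy on stream_.
      OUTCOME_TRY(auto src, MakeAvailableOnDevice(img, device_, stream_));

      // Passing `src` to the constructor forces the guard to be declared after `src`,
      // so its destructor -- which calls stream_.Wait() only when a copy actually
      // happened -- runs before `src` goes out of scope.
      SyncOnScopeExit sync(stream_, src.buffer() != img.buffer(), src);

      OUTCOME_TRY(auto dst, ApplyKernel(src));  // hypothetical async kernel launch on stream_

      // Record the result and track it in "__data__" so it stays alive until Compose syncs.
      SetTransformData(output, "img", dst);
      return dst;
    }

Compose::Process then moves every transform's "__data__" entries into a single `intermediates` array and performs one `stream_.Wait()` after the whole pipeline, instead of waiting once per transform as before.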