diff --git a/csrcs/fastdeploy/vision/common/processors/resize.h b/csrcs/fastdeploy/vision/common/processors/resize.h
index 137007997f..5b6e9c0257 100644
--- a/csrcs/fastdeploy/vision/common/processors/resize.h
+++ b/csrcs/fastdeploy/vision/common/processors/resize.h
@@ -41,6 +41,16 @@ class Resize : public Processor {
                float scale_h = -1.0, int interp = 1, bool use_scale = false,
                ProcLib lib = ProcLib::OPENCV_CPU);
 
+  bool SetWidthAndHeight(int width, int height) {
+    width_ = width;
+    height_ = height;
+    return true;
+  }
+
+  std::tuple<int, int> GetWidthAndHeight() {
+    return std::make_tuple(width_, height_);
+  }
+
  private:
   int width_;
   int height_;
@@ -49,5 +59,5 @@ class Resize : public Processor {
   int interp_ = 1;
   bool use_scale_ = false;
 };
-} // namespace vision
-} // namespace fastdeploy
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/common/result.cc b/csrcs/fastdeploy/vision/common/result.cc
index 0ef077f0ce..1a9a6dbfeb 100644
--- a/csrcs/fastdeploy/vision/common/result.cc
+++ b/csrcs/fastdeploy/vision/common/result.cc
@@ -140,11 +140,24 @@ std::string FaceDetectionResult::Str() {
 }
 
 void SegmentationResult::Clear() {
-  std::vector<std::vector<int64_t>>().swap(masks);
+  std::vector<uint8_t>().swap(label_map);
+  std::vector<float>().swap(score_map);
+  std::vector<int64_t>().swap(shape);
+  contain_score_map = false;
 }
 
-void SegmentationResult::Resize(int64_t height, int64_t width) {
-  masks.resize(height, std::vector<int64_t>(width));
+void SegmentationResult::Reserve(int size) {
+  label_map.reserve(size);
+  if (contain_score_map) {
+    score_map.reserve(size);
+  }
+}
+
+void SegmentationResult::Resize(int size) {
+  label_map.resize(size);
+  if (contain_score_map) {
+    score_map.resize(size);
+  }
 }
 
 std::string SegmentationResult::Str() {
@@ -153,11 +166,24 @@ std::string SegmentationResult::Str() {
   for (size_t i = 0; i < 10; ++i) {
     out += "[";
     for (size_t j = 0; j < 10; ++j) {
-      out = out + std::to_string(masks[i][j]) + ", ";
+      out = out + std::to_string(label_map[i * 10 + j]) + ", ";
     }
     out += ".....]\n";
   }
   out += "...........\n";
+  if (contain_score_map) {
+    out += "SegmentationResult Score map 10 rows x 10 cols: \n";
+    for (size_t i = 0; i < 10; ++i) {
+      out += "[";
+      for (size_t j = 0; j < 10; ++j) {
+        out = out + std::to_string(score_map[i * 10 + j]) + ", ";
+      }
+      out += ".....]\n";
+    }
+    out += "...........\n";
+  }
+  out += "result shape is: [" + std::to_string(shape[0]) + " " +
+         std::to_string(shape[1]) + "]";
   return out;
 }
diff --git a/csrcs/fastdeploy/vision/common/result.h b/csrcs/fastdeploy/vision/common/result.h
index 4900d394d8..f2b20f623b 100644
--- a/csrcs/fastdeploy/vision/common/result.h
+++ b/csrcs/fastdeploy/vision/common/result.h
@@ -84,13 +84,18 @@ struct FASTDEPLOY_DECL FaceDetectionResult : public BaseResult {
 
 struct FASTDEPLOY_DECL SegmentationResult : public BaseResult {
   // mask
-  std::vector<std::vector<int64_t>> masks;
+  std::vector<uint8_t> label_map;
+  std::vector<float> score_map;
+  std::vector<int64_t> shape;
+  bool contain_score_map = false;
 
   ResultType type = ResultType::SEGMENTATION;
 
   void Clear();
 
-  void Resize(int64_t height, int64_t width);
+  void Reserve(int size);
+
+  void Resize(int size);
 
   std::string Str();
 };
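// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): how the new flattened result is
// indexed. Assumes a SegmentationResult filled by Postprocess with
// shape = {height, width}; the helper name LabelAt is hypothetical.
#include <cstdint>
#include "fastdeploy/vision/common/result.h"

static uint8_t LabelAt(const fastdeploy::vision::SegmentationResult& res,
                       int64_t h, int64_t w) {
  // label_map stores H x W labels in row-major order; when contain_score_map
  // is true, score_map uses the same flat index and holds the score of the
  // predicted class at that pixel.
  return res.label_map[h * res.shape[1] + w];
}
// ---------------------------------------------------------------------------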
diff --git a/csrcs/fastdeploy/vision/ppseg/model.cc b/csrcs/fastdeploy/vision/ppseg/model.cc
index 268d85f7d3..7f692c6a71 100644
--- a/csrcs/fastdeploy/vision/ppseg/model.cc
+++ b/csrcs/fastdeploy/vision/ppseg/model.cc
@@ -11,8 +11,8 @@ Model::Model(const std::string& model_file, const std::string& params_file,
              const std::string& config_file, const RuntimeOption& custom_option,
              const Frontend& model_format) {
   config_file_ = config_file;
-  valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
-  valid_gpu_backends = {Backend::ORT, Backend::PDINFER};
+  valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
+  valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
   runtime_option = custom_option;
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
@@ -65,6 +65,7 @@ bool Model::BuildPreprocessPipelineFromConfig() {
       const auto& target_size = op["target_size"];
       int resize_width = target_size[0].as<int>();
       int resize_height = target_size[1].as<int>();
+      is_resized = true;
      processors_.push_back(
          std::make_shared<Resize>(resize_width, resize_height));
    }
@@ -74,49 +75,140 @@ bool Model::BuildPreprocessPipelineFromConfig() {
   return true;
 }
 
-bool Model::Preprocess(Mat* mat, FDTensor* output) {
+bool Model::Preprocess(Mat* mat, FDTensor* output,
+                       std::map<std::string, std::array<int, 2>>* im_info) {
   for (size_t i = 0; i < processors_.size(); ++i) {
+    if (processors_[i]->Name().compare("Resize") == 0) {
+      auto processor = dynamic_cast<Resize*>(processors_[i].get());
+      int resize_width = -1;
+      int resize_height = -1;
+      std::tie(resize_width, resize_height) = processor->GetWidthAndHeight();
+      if (is_vertical_screen && (resize_width > resize_height)) {
+        if (!processor->SetWidthAndHeight(resize_height, resize_width)) {
+          FDERROR << "Failed to set Resize processor width and height "
+                  << processors_[i]->Name() << "." << std::endl;
+        }
+      }
+    }
     if (!(*(processors_[i].get()))(mat)) {
       FDERROR << "Failed to process image data in " << processors_[i]->Name()
               << "." << std::endl;
       return false;
     }
   }
+
+  // Record the output shape of the preprocessed image
+  (*im_info)["output_shape"] = {static_cast<int>(mat->Height()),
+                                static_cast<int>(mat->Width())};
+
   mat->ShareWithTensor(output);
   output->shape.insert(output->shape.begin(), 1);
   output->name = InputInfoOfRuntime(0).name;
   return true;
 }
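// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): what the "is_vertical_screen"
// branch above does to the Resize step. The target_size values are
// hypothetical and only stand in for what deploy.yaml might configure.
#include <utility>

static std::pair<int, int> TargetForInput(bool is_vertical_screen,
                                          int cfg_width, int cfg_height) {
  // A landscape model may be exported with e.g. target_size: [2048, 1024].
  // For a portrait ("vertical screen") input the width and height are swapped
  // so the resized image keeps the input's orientation.
  if (is_vertical_screen && cfg_width > cfg_height) {
    return {cfg_height, cfg_width};
  }
  return {cfg_width, cfg_height};
}
// ---------------------------------------------------------------------------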
-bool Model::Postprocess(const FDTensor& infer_result,
-                        SegmentationResult* result) {
-  FDASSERT(infer_result.dtype == FDDataType::INT64,
-           "Require the data type of output is int64, but now it's " +
-               Str(const_cast<FDDataType&>(infer_result.dtype)) + ".");
+bool Model::Postprocess(FDTensor& infer_result, SegmentationResult* result,
+                        std::map<std::string, std::array<int, 2>>* im_info) {
+  // PaddleSeg has three types of inference output:
+  //     1. output with argmax and without softmax: 3-D matrix CHW, Channel is
+  //     always 1, each element is the classified label_id, INT64 type.
+  //     2. output without argmax and without softmax: 4-D matrix NCHW, N is
+  //     always 1, Channel is the number of classes, each element is the logit
+  //     of a class, FP32 type.
+  //     3. output without argmax and with softmax: 4-D matrix NCHW, the
+  //     result of 2 passed through a softmax layer.
+  // FastDeploy output:
+  //     1. label_map
+  //     2. score_map (optional)
+  //     3. shape: 2-D HW
+  FDASSERT(infer_result.dtype == FDDataType::INT64 ||
+               infer_result.dtype == FDDataType::FP32,
+           "Require the data type of output is int64 or fp32, but now it's " +
+               Str(infer_result.dtype) + ".");
   result->Clear();
-  std::vector<int64_t> output_shape = infer_result.shape;
-  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
-                                std::multiplies<int>());
-  const int64_t* infer_result_buffer =
-      reinterpret_cast<const int64_t*>(infer_result.data.data());
-  int64_t height = output_shape[1];
-  int64_t width = output_shape[2];
-  result->Resize(height, width);
-  for (int64_t i = 0; i < height; i++) {
-    int64_t begin = i * width;
-    int64_t end = (i + 1) * width - 1;
-    std::copy(infer_result_buffer + begin, infer_result_buffer + end,
-              result->masks[i].begin());
+
+  if (infer_result.shape.size() == 4) {
+    FDASSERT(infer_result.shape[0] == 1, "Only support batch size = 1.");
+    // output without argmax
+    result->contain_score_map = true;
+    utils::NCHW2NHWC<float_t>(infer_result);
   }
+
+  // for the resized mat below
+  FDTensor new_infer_result;
+  Mat* mat = nullptr;
+  if (is_resized) {
+    cv::Mat temp_mat;
+    utils::FDTensor2FP32CVMat(temp_mat, infer_result,
+                              result->contain_score_map);
+
+    // original image shape
+    auto iter_ipt = (*im_info).find("input_shape");
+    FDASSERT(iter_ipt != im_info->end(),
+             "Cannot find input_shape from im_info.");
+    int ipt_h = iter_ipt->second[0];
+    int ipt_w = iter_ipt->second[1];
+
+    mat = new Mat(temp_mat);
+
+    Resize::Run(mat, ipt_w, ipt_h, -1, -1, 1);
+    mat->ShareWithTensor(&new_infer_result);
+    new_infer_result.shape.insert(new_infer_result.shape.begin(), 1);
+    result->shape = new_infer_result.shape;
+  } else {
+    result->shape = infer_result.shape;
+  }
+  int out_num =
+      std::accumulate(result->shape.begin(), result->shape.begin() + 3, 1,
+                      std::multiplies<int>());
+  // NCHW remove N or CHW remove C
+  result->shape.erase(result->shape.begin());
+  result->Resize(out_num);
+  if (result->contain_score_map) {
+    // output with label_map and score_map
+    float_t* infer_result_buffer = nullptr;
+    if (is_resized) {
+      infer_result_buffer = static_cast<float_t*>(new_infer_result.Data());
+    } else {
+      infer_result_buffer = static_cast<float_t*>(infer_result.Data());
+    }
+    // argmax
+    utils::ArgmaxScoreMap(infer_result_buffer, result, with_softmax);
+    result->shape.erase(result->shape.begin() + 2);
+  } else {
+    // output only with label_map
+    if (is_resized) {
+      float_t* infer_result_buffer =
+          static_cast<float_t*>(new_infer_result.Data());
+      for (int i = 0; i < out_num; i++) {
+        result->label_map[i] =
+            static_cast<uint8_t>(*(infer_result_buffer + i));
+      }
+    } else {
+      const int64_t* infer_result_buffer =
+          reinterpret_cast<const int64_t*>(infer_result.Data());
+      for (int i = 0; i < out_num; i++) {
+        result->label_map[i] =
+            static_cast<uint8_t>(*(infer_result_buffer + i));
+      }
+    }
+  }
+  delete mat;
+  mat = nullptr;
   return true;
 }
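// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the shape bookkeeping that
// Postprocess above performs, written out for one hypothetical
// 1 x 512 x 512 x 19 output (NHWC after NCHW2NHWC).
#include <cstdint>
#include <vector>

static void ShapeWalkthrough() {
  std::vector<int64_t> shape = {1, 512, 512, 19};
  int64_t out_num = shape[0] * shape[1] * shape[2];  // 1 * 512 * 512 pixels
  shape.erase(shape.begin());                        // {512, 512, 19}
  // label_map / score_map are resized to out_num entries; after the argmax
  // over the class dimension the trailing 19 is dropped:
  shape.erase(shape.begin() + 2);                    // {512, 512} -> result->shape
  (void)out_num;
}
// ---------------------------------------------------------------------------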
 bool Model::Predict(cv::Mat* im, SegmentationResult* result) {
   Mat mat(*im);
   std::vector<FDTensor> processed_data(1);
-  if (!Preprocess(&mat, &(processed_data[0]))) {
+
+  std::map<std::string, std::array<int, 2>> im_info;
+
+  // Record the shape of the image and the shape of the preprocessed image
+  im_info["input_shape"] = {static_cast<int>(mat.Height()),
+                            static_cast<int>(mat.Width())};
+  im_info["output_shape"] = {static_cast<int>(mat.Height()),
+                             static_cast<int>(mat.Width())};
+
+  if (!Preprocess(&mat, &(processed_data[0]), &im_info)) {
     FDERROR << "Failed to preprocess input data while using model:"
             << ModelName() << "." << std::endl;
     return false;
@@ -127,7 +219,7 @@ bool Model::Predict(cv::Mat* im, SegmentationResult* result) {
             << std::endl;
     return false;
   }
-  if (!Postprocess(infer_result[0], result)) {
+  if (!Postprocess(infer_result[0], result, &im_info)) {
     FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
             << std::endl;
     return false;
diff --git a/csrcs/fastdeploy/vision/ppseg/model.h b/csrcs/fastdeploy/vision/ppseg/model.h
index c0ca5a70d0..72f8dbc645 100644
--- a/csrcs/fastdeploy/vision/ppseg/model.h
+++ b/csrcs/fastdeploy/vision/ppseg/model.h
@@ -18,14 +18,22 @@ class FASTDEPLOY_DECL Model : public FastDeployModel {
 
   virtual bool Predict(cv::Mat* im, SegmentationResult* result);
 
+  bool with_softmax = false;
+
+  bool is_vertical_screen = false;
+
  private:
   bool Initialize();
 
   bool BuildPreprocessPipelineFromConfig();
 
-  bool Preprocess(Mat* mat, FDTensor* outputs);
+  bool Preprocess(Mat* mat, FDTensor* outputs,
+                  std::map<std::string, std::array<int, 2>>* im_info);
+
+  bool Postprocess(FDTensor& infer_result, SegmentationResult* result,
+                   std::map<std::string, std::array<int, 2>>* im_info);
 
-  bool Postprocess(const FDTensor& infer_result, SegmentationResult* result);
+  bool is_resized = false;
 
   std::vector<std::shared_ptr<Processor>> processors_;
   std::string config_file_;
diff --git a/csrcs/fastdeploy/vision/ppseg/ppseg_pybind.cc b/csrcs/fastdeploy/vision/ppseg/ppseg_pybind.cc
index 60022f914b..949c274875 100644
--- a/csrcs/fastdeploy/vision/ppseg/ppseg_pybind.cc
+++ b/csrcs/fastdeploy/vision/ppseg/ppseg_pybind.cc
@@ -20,11 +20,16 @@ void BindPPSeg(pybind11::module& m) {
   pybind11::class_<vision::ppseg::Model, FastDeployModel>(ppseg_module, "Model")
       .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                           Frontend>())
-      .def("predict", [](vision::ppseg::Model& self, pybind11::array& data) {
-        auto mat = PyArrayToCvMat(data);
-        vision::SegmentationResult res;
-        self.Predict(&mat, &res);
-        return res;
-      });
+      .def("predict",
+           [](vision::ppseg::Model& self, pybind11::array& data) {
+             auto mat = PyArrayToCvMat(data);
+             vision::SegmentationResult* res = new vision::SegmentationResult();
+             self.Predict(&mat, res);
+             return res;
+           })
+      .def_readwrite("with_softmax", &vision::ppseg::Model::with_softmax)
+      .def_readwrite("is_vertical_screen",
+                     &vision::ppseg::Model::is_vertical_screen);
 }
 }  // namespace fastdeploy
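// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): minimal C++ usage of the
// updated API, assuming the umbrella header "fastdeploy/vision.h", the UNet
// files used in model_zoo/vision/ppseg below, and the constructor's default
// runtime option / model format arguments.
#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::vision::ppseg::Model model("./unet_Cityscapes/model.pdmodel",
                                         "./unet_Cityscapes/model.pdiparams",
                                         "./unet_Cityscapes/deploy.yaml");
  cv::Mat im = cv::imread("cityscapes_demo.png");
  fastdeploy::vision::SegmentationResult res;
  if (!model.Predict(&im, &res)) {
    return -1;
  }
  // res.shape = {H, W}; res.label_map holds H * W uint8_t labels.
  std::cout << res.Str() << std::endl;
  return 0;
}
// ---------------------------------------------------------------------------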
+ +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace utils { + +void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result, + bool contain_score_map) { + // output with argmax channel is 1 + int channel = 1; + int height = infer_result.shape[1]; + int width = infer_result.shape[2]; + + if (contain_score_map) { + // output without argmax and convent to NHWC + channel = infer_result.shape[3]; + } + // create FP32 cvmat + if (infer_result.dtype == FDDataType::INT64) { + FDWARNING << "The PaddleSeg model is exported with argmax. Inference " + "result type is " + + Str(infer_result.dtype) + + ". If you want the edge of segmentation image more " + "smoother. Please export model with --without_argmax " + "--with_softmax." + << std::endl; + int64_t chw = channel * height * width; + int64_t* infer_result_buffer = static_cast(infer_result.Data()); + std::vector float_result_buffer(chw); + mat = cv::Mat(height, width, CV_32FC(channel)); + int index = 0; + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { + mat.at(i, j) = + static_cast(infer_result_buffer[index++]); + } + } + } else if (infer_result.dtype == FDDataType::FP32) { + mat = cv::Mat(height, width, CV_32FC(channel), infer_result.Data()); + } +} + +} // namespace utils +} // namespace vision +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/utils/utils.h b/csrcs/fastdeploy/vision/utils/utils.h index e95e7e10b5..fb9c874d53 100644 --- a/csrcs/fastdeploy/vision/utils/utils.h +++ b/csrcs/fastdeploy/vision/utils/utils.h @@ -51,6 +51,73 @@ std::vector TopKIndices(const T* array, int array_size, int topk) { return res; } +template +void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result, + bool with_softmax) { + int64_t height = result->shape[0]; + int64_t width = result->shape[1]; + int64_t num_classes = result->shape[2]; + int index = 0; + for (size_t i = 0; i < height; ++i) { + for (size_t j = 0; j < width; ++j) { + int64_t s = (i * width + j) * num_classes; + T max_class_score = std::max_element( + infer_result_buffer + s, infer_result_buffer + s + num_classes); + int label_id = std::distance(infer_result_buffer + s, max_class_score); + if (label_id >= 255) { + FDWARNING << "label_id is stored by uint8_t, now the value is bigger " + "than 255, it's " + << static_cast(label_id) << "." 
diff --git a/csrcs/fastdeploy/vision/utils/utils.h b/csrcs/fastdeploy/vision/utils/utils.h
index e95e7e10b5..fb9c874d53 100644
--- a/csrcs/fastdeploy/vision/utils/utils.h
+++ b/csrcs/fastdeploy/vision/utils/utils.h
@@ -51,6 +51,73 @@ std::vector<int32_t> TopKIndices(const T* array, int array_size, int topk) {
   return res;
 }
 
+template <typename T>
+void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result,
+                    bool with_softmax) {
+  int64_t height = result->shape[0];
+  int64_t width = result->shape[1];
+  int64_t num_classes = result->shape[2];
+  int index = 0;
+  for (size_t i = 0; i < height; ++i) {
+    for (size_t j = 0; j < width; ++j) {
+      int64_t s = (i * width + j) * num_classes;
+      T max_class_score = std::max_element(
+          infer_result_buffer + s, infer_result_buffer + s + num_classes);
+      int label_id = std::distance(infer_result_buffer + s, max_class_score);
+      if (label_id >= 255) {
+        FDWARNING << "label_id is stored as uint8_t, but the value is bigger "
+                     "than 255, it's "
+                  << static_cast<int>(label_id) << "." << std::endl;
+      }
+      result->label_map[index] = static_cast<uint8_t>(label_id);
+
+      if (with_softmax) {
+        double_t total = 0;
+        for (int k = 0; k < num_classes; k++) {
+          total += exp(*(infer_result_buffer + s + k) - *max_class_score);
+        }
+        double_t softmax_class_score = 1 / total;
+        result->score_map[index] = static_cast<float>(softmax_class_score);
+      } else {
+        result->score_map[index] = static_cast<float>(*max_class_score);
+      }
+      index++;
+    }
+  }
+}
+
+template <typename T>
+void NCHW2NHWC(FDTensor& infer_result) {
+  T* infer_result_buffer = reinterpret_cast<T*>(infer_result.MutableData());
+  int num = infer_result.shape[0];
+  int channel = infer_result.shape[1];
+  int height = infer_result.shape[2];
+  int width = infer_result.shape[3];
+  int chw = channel * height * width;
+  int wc = width * channel;
+  int wh = width * height;
+  std::vector<T> hwc_data(chw);
+  int index = 0;
+  for (int n = 0; n < num; n++) {
+    for (int c = 0; c < channel; c++) {
+      for (int h = 0; h < height; h++) {
+        for (int w = 0; w < width; w++) {
+          hwc_data[n * chw + h * wc + w * channel + c] =
+              *(infer_result_buffer + index);
+          index++;
+        }
+      }
+    }
+  }
+  std::memcpy(infer_result.MutableData(), hwc_data.data(),
+              num * chw * sizeof(T));
+  infer_result.shape = {num, height, width, channel};
+}
+
+void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
+                        bool contain_score_map);
+
 void NMS(DetectionResult* output, float iou_threshold = 0.5);
 
 void NMS(FaceDetectionResult* result, float iou_threshold = 0.5);
diff --git a/csrcs/fastdeploy/vision/vision_pybind.cc b/csrcs/fastdeploy/vision/vision_pybind.cc
index 79aa876351..3648105a44 100644
--- a/csrcs/fastdeploy/vision/vision_pybind.cc
+++ b/csrcs/fastdeploy/vision/vision_pybind.cc
@@ -59,7 +59,10 @@ void BindVision(pybind11::module& m) {
       .def("__str__", &vision::FaceDetectionResult::Str);
 
   pybind11::class_<vision::SegmentationResult>(m, "SegmentationResult")
       .def(pybind11::init())
-      .def_readwrite("masks", &vision::SegmentationResult::masks)
+      .def_readwrite("label_map", &vision::SegmentationResult::label_map)
+      .def_readwrite("score_map", &vision::SegmentationResult::score_map)
+      .def_readwrite("shape", &vision::SegmentationResult::shape)
       .def("__repr__", &vision::SegmentationResult::Str)
       .def("__str__", &vision::SegmentationResult::Str);
diff --git a/csrcs/fastdeploy/vision/visualize/segmentation.cc b/csrcs/fastdeploy/vision/visualize/segmentation.cc
index b1b142fc08..1fba09a131 100644
--- a/csrcs/fastdeploy/vision/visualize/segmentation.cc
+++ b/csrcs/fastdeploy/vision/visualize/segmentation.cc
@@ -25,14 +25,14 @@ void Visualize::VisSegmentation(const cv::Mat& im,
                                 const SegmentationResult& result,
                                 cv::Mat* vis_img, const int& num_classes) {
   auto color_map = GetColorMap(num_classes);
-  int64_t height = result.masks.size();
-  int64_t width = result.masks[1].size();
+  int64_t height = result.shape[0];
+  int64_t width = result.shape[1];
   *vis_img = cv::Mat::zeros(height, width, CV_8UC3);
 
   int64_t index = 0;
   for (int i = 0; i < height; i++) {
     for (int j = 0; j < width; j++) {
-      int category_id = static_cast<int>(result.masks[i][j]);
+      int category_id = result.label_map[index++];
       vis_img->at<cv::Vec3b>(i, j)[0] = color_map[3 * category_id + 0];
       vis_img->at<cv::Vec3b>(i, j)[1] = color_map[3 * category_id + 1];
       vis_img->at<cv::Vec3b>(i, j)[2] = color_map[3 * category_id + 2];
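// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): why ArgmaxScoreMap above can
// compute the softmax score of the winning class as 1 / sum_k exp(x_k - x_max):
// softmax(x_max) = exp(x_max) / sum_k exp(x_k) = 1 / sum_k exp(x_k - x_max),
// and subtracting x_max keeps every exponent <= 0, which avoids overflow.
#include <algorithm>
#include <cmath>
#include <vector>

static float SoftmaxOfMax(const std::vector<float>& logits) {
  float x_max = *std::max_element(logits.begin(), logits.end());
  double total = 0.0;
  for (float x : logits) {
    total += std::exp(x - x_max);
  }
  return static_cast<float>(1.0 / total);
}
// ---------------------------------------------------------------------------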
diff --git a/fastdeploy/vision/ppseg/__init__.py b/fastdeploy/vision/ppseg/__init__.py
index b580c01455..dbc826722d 100644
--- a/fastdeploy/vision/ppseg/__init__.py
+++ b/fastdeploy/vision/ppseg/__init__.py
@@ -35,3 +35,25 @@ def __init__(self,
 
     def predict(self, input_image):
         return self._model.predict(input_image)
+
+    @property
+    def with_softmax(self):
+        return self._model.with_softmax
+
+    @with_softmax.setter
+    def with_softmax(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `with_softmax` must be of type bool."
+        self._model.with_softmax = value
+
+    @property
+    def is_vertical_screen(self):
+        return self._model.is_vertical_screen
+
+    @is_vertical_screen.setter
+    def is_vertical_screen(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_vertical_screen` must be of type bool."
+        self._model.is_vertical_screen = value
diff --git a/model_zoo/vision/ppseg/ppseg_unet.py b/model_zoo/vision/ppseg/ppseg_unet.py
index c279e0a8fd..750e2167ba 100644
--- a/model_zoo/vision/ppseg/ppseg_unet.py
+++ b/model_zoo/vision/ppseg/ppseg_unet.py
@@ -5,18 +5,8 @@
 # Download the model and the test image
 model_url = "https://github.com/felixhjh/Fastdeploy-Models/raw/main/unet_Cityscapes.tar.gz"
 test_jpg_url = "https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png"
-fd.download(model_url, ".", show_progress=True)
+fd.download_and_decompress(model_url, ".")
 fd.download(test_jpg_url, ".", show_progress=True)
-
-try:
-    tar = tarfile.open("unet_Cityscapes.tar.gz", "r:gz")
-    file_names = tar.getnames()
-    for file_name in file_names:
-        tar.extract(file_name, ".")
-    tar.close()
-except Exception as e:
-    raise Exception(e)
-
 # Load the model
 model = fd.vision.ppseg.Model("./unet_Cityscapes/model.pdmodel",
                               "./unet_Cityscapes/model.pdiparams",
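// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the two Python properties added
// above are plain read/write members on the C++ model, so the same switches
// can be set directly in C++ before Predict is called.
#include "fastdeploy/vision.h"

static void ConfigureSegModel(fastdeploy::vision::ppseg::Model* model) {
  // Apply softmax in postprocessing when the exported model has no softmax layer.
  model->with_softmax = true;
  // Swap the configured Resize target for portrait ("vertical screen") inputs.
  model->is_vertical_screen = true;
}
// ---------------------------------------------------------------------------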