From a8458e67293910eac141e8d19f43518b614ca87f Mon Sep 17 00:00:00 2001 From: huangjianhui <852142024@qq.com> Date: Thu, 21 Jul 2022 15:38:21 +0800 Subject: [PATCH] Add new model PaddleSeg (#30) * Support new model PaddleSeg * Fix conflict * PaddleSeg add visulization function * fix bug * Fix BindPPSeg wrong name * Fix variable name * Update by comments * Add ppseg-unet example python version Co-authored-by: Jason --- examples/vision/ppseg_unet.cc | 59 ++++++++ fastdeploy/vision.h | 1 + fastdeploy/vision/__init__.py | 1 + fastdeploy/vision/common/result.cc | 22 +++ fastdeploy/vision/common/result.h | 13 ++ fastdeploy/vision/ppseg/__init__.py | 37 +++++ fastdeploy/vision/ppseg/model.cc | 140 ++++++++++++++++++ fastdeploy/vision/ppseg/model.h | 35 +++++ fastdeploy/vision/ppseg/ppseg_pybind.cc | 30 ++++ fastdeploy/vision/vision_pybind.cc | 8 + fastdeploy/vision/visualize/__init__.py | 5 + fastdeploy/vision/visualize/segmentation.cc | 46 ++++++ fastdeploy/vision/visualize/visualize.h | 7 +- .../vision/visualize/visualize_pybind.cc | 21 ++- model_zoo/vision/ppseg/ppseg_unet.py | 36 +++++ 15 files changed, 453 insertions(+), 8 deletions(-) create mode 100644 examples/vision/ppseg_unet.cc create mode 100644 fastdeploy/vision/ppseg/__init__.py create mode 100644 fastdeploy/vision/ppseg/model.cc create mode 100644 fastdeploy/vision/ppseg/model.h create mode 100644 fastdeploy/vision/ppseg/ppseg_pybind.cc create mode 100644 fastdeploy/vision/visualize/segmentation.cc create mode 100644 model_zoo/vision/ppseg/ppseg_unet.py diff --git a/examples/vision/ppseg_unet.cc b/examples/vision/ppseg_unet.cc new file mode 100644 index 0000000000..cb33611ad4 --- /dev/null +++ b/examples/vision/ppseg_unet.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" +#include "yaml-cpp/yaml.h" + +int main() { + namespace vis = fastdeploy::vision; + + std::string model_file = "../resources/models/unet_Cityscapes/model.pdmodel"; + std::string params_file = + "../resources/models/unet_Cityscapes/model.pdiparams"; + std::string config_file = "../resources/models/unet_Cityscapes/deploy.yaml"; + std::string img_path = "../resources/images/cityscapes_demo.png"; + std::string vis_path = "../resources/outputs/vis.jpeg"; + + auto model = vis::ppseg::Model(model_file, params_file, config_file); + if (!model.Initialized()) { + std::cerr << "Init Failed." << std::endl; + return -1; + } + + cv::Mat im = cv::imread(img_path); + cv::Mat vis_im; + + vis::SegmentationResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Prediction Failed." << std::endl; + return -1; + } else { + std::cout << "Prediction Done!" << std::endl; + } + + // 输出预测框结果 + std::cout << res.Str() << std::endl; + + YAML::Node cfg = YAML::LoadFile(config_file); + int num_classes = 19; + if (cfg["Deploy"]["num_classes"]) { + num_classes = cfg["Deploy"]["num_classes"].as(); + } + + // 可视化预测结果 + vis::Visualize::VisSegmentation(im, res, &vis_im, num_classes); + cv::imwrite(vis_path, vis_im); + std::cout << "Inference Done! Saved: " << vis_path << std::endl; + return 0; +} diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index 68c0881cac..d539482a72 100644 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -19,6 +19,7 @@ #include "fastdeploy/vision/meituan/yolov6.h" #include "fastdeploy/vision/ppcls/model.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" +#include "fastdeploy/vision/ppseg/model.h" #include "fastdeploy/vision/ultralytics/yolov5.h" #include "fastdeploy/vision/wongkinyiu/yolor.h" #include "fastdeploy/vision/wongkinyiu/yolov7.h" diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py index 6acbf0c376..08b0d68124 100644 --- a/fastdeploy/vision/__init__.py +++ b/fastdeploy/vision/__init__.py @@ -16,6 +16,7 @@ from . import evaluation from . import ppcls from . import ppdet +from . import ppseg from . import ultralytics from . import meituan from . import megvii diff --git a/fastdeploy/vision/common/result.cc b/fastdeploy/vision/common/result.cc index ece0973c0c..06a85ea454 100644 --- a/fastdeploy/vision/common/result.cc +++ b/fastdeploy/vision/common/result.cc @@ -72,5 +72,27 @@ std::string DetectionResult::Str() { return out; } +void SegmentationResult::Clear() { + std::vector>().swap(masks); +} + +void SegmentationResult::Resize(int64_t height, int64_t width) { + masks.resize(height, std::vector(width)); +} + +std::string SegmentationResult::Str() { + std::string out; + out = "SegmentationResult Image masks 10 rows x 10 cols: \n"; + for (size_t i = 0; i < 10; ++i) { + out += "["; + for (size_t j = 0; j < 10; ++j) { + out = out + std::to_string(masks[i][j]) + ", "; + } + out += ".....]\n"; + } + out += "...........\n"; + return out; +} + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/result.h b/fastdeploy/vision/common/result.h index 22227a26cb..7ff104250f 100644 --- a/fastdeploy/vision/common/result.h +++ b/fastdeploy/vision/common/result.h @@ -56,5 +56,18 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult { std::string Str(); }; +struct FASTDEPLOY_DECL SegmentationResult : public BaseResult { + // mask + std::vector> masks; + + ResultType type = ResultType::SEGMENTATION; + + void Clear(); + + void Resize(int64_t height, int64_t width); + + std::string Str(); +}; + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/ppseg/__init__.py b/fastdeploy/vision/ppseg/__init__.py new file mode 100644 index 0000000000..b580c01455 --- /dev/null +++ b/fastdeploy/vision/ppseg/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from ... import FastDeployModel, Frontend +from ... import fastdeploy_main as C + + +class Model(FastDeployModel): + def __init__(self, + model_file, + params_file, + config_file, + backend_option=None, + model_format=Frontend.PADDLE): + super(Model, self).__init__(backend_option) + + assert model_format == Frontend.PADDLE, "PaddleSeg only support model format of Frontend.Paddle now." + self._model = C.vision.ppseg.Model(model_file, params_file, + config_file, self._runtime_option, + model_format) + assert self.initialized, "PaddleSeg model initialize failed." + + def predict(self, input_image): + return self._model.predict(input_image) diff --git a/fastdeploy/vision/ppseg/model.cc b/fastdeploy/vision/ppseg/model.cc new file mode 100644 index 0000000000..268d85f7d3 --- /dev/null +++ b/fastdeploy/vision/ppseg/model.cc @@ -0,0 +1,140 @@ +#include "fastdeploy/vision/ppseg/model.h" +#include "fastdeploy/vision.h" +#include "fastdeploy/vision/utils/utils.h" +#include "yaml-cpp/yaml.h" + +namespace fastdeploy { +namespace vision { +namespace ppseg { + +Model::Model(const std::string& model_file, const std::string& params_file, + const std::string& config_file, const RuntimeOption& custom_option, + const Frontend& model_format) { + config_file_ = config_file; + valid_cpu_backends = {Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER}; + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool Model::Initialize() { + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; + return false; + } + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool Model::BuildPreprocessPipelineFromConfig() { + processors_.clear(); + YAML::Node cfg; + processors_.push_back(std::make_shared()); + try { + cfg = YAML::LoadFile(config_file_); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file_ + << ", maybe you should check this file." << std::endl; + return false; + } + + if (cfg["Deploy"]["transforms"]) { + auto preprocess_cfg = cfg["Deploy"]["transforms"]; + for (const auto& op : preprocess_cfg) { + FDASSERT(op.IsMap(), + "Require the transform information in yaml be Map type."); + if (op["type"].as() == "Normalize") { + std::vector mean = {0.5, 0.5, 0.5}; + std::vector std = {0.5, 0.5, 0.5}; + if (op["mean"]) { + mean = op["mean"].as>(); + } + if (op["std"]) { + std = op["std"].as>(); + } + processors_.push_back(std::make_shared(mean, std)); + + } else if (op["type"].as() == "Resize") { + const auto& target_size = op["target_size"]; + int resize_width = target_size[0].as(); + int resize_height = target_size[1].as(); + processors_.push_back( + std::make_shared(resize_width, resize_height)); + } + } + processors_.push_back(std::make_shared()); + } + return true; +} + +bool Model::Preprocess(Mat* mat, FDTensor* output) { + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." << std::endl; + return false; + } + } + mat->ShareWithTensor(output); + output->shape.insert(output->shape.begin(), 1); + output->name = InputInfoOfRuntime(0).name; + return true; +} + +bool Model::Postprocess(const FDTensor& infer_result, + SegmentationResult* result) { + FDASSERT(infer_result.dtype == FDDataType::INT64, + "Require the data type of output is int64, but now it's " + + Str(const_cast(infer_result.dtype)) + + "."); + result->Clear(); + std::vector output_shape = infer_result.shape; + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + const int64_t* infer_result_buffer = + reinterpret_cast(infer_result.data.data()); + int64_t height = output_shape[1]; + int64_t width = output_shape[2]; + result->Resize(height, width); + for (int64_t i = 0; i < height; i++) { + int64_t begin = i * width; + int64_t end = (i + 1) * width - 1; + std::copy(infer_result_buffer + begin, infer_result_buffer + end, + result->masks[i].begin()); + } + + return true; +} + +bool Model::Predict(cv::Mat* im, SegmentationResult* result) { + Mat mat(*im); + std::vector processed_data(1); + if (!Preprocess(&mat, &(processed_data[0]))) { + FDERROR << "Failed to preprocess input data while using model:" + << ModelName() << "." << std::endl; + return false; + } + std::vector infer_result(1); + if (!Infer(processed_data, &infer_result)) { + FDERROR << "Failed to inference while using model:" << ModelName() << "." + << std::endl; + return false; + } + if (!Postprocess(infer_result[0], result)) { + FDERROR << "Failed to postprocess while using model:" << ModelName() << "." + << std::endl; + return false; + } + return true; +} + +} // namespace ppseg +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ppseg/model.h b/fastdeploy/vision/ppseg/model.h new file mode 100644 index 0000000000..c0ca5a70d0 --- /dev/null +++ b/fastdeploy/vision/ppseg/model.h @@ -0,0 +1,35 @@ +#pragma once +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { +namespace ppseg { + +class FASTDEPLOY_DECL Model : public FastDeployModel { + public: + Model(const std::string& model_file, const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE); + + std::string ModelName() const { return "ppseg"; } + + virtual bool Predict(cv::Mat* im, SegmentationResult* result); + + private: + bool Initialize(); + + bool BuildPreprocessPipelineFromConfig(); + + bool Preprocess(Mat* mat, FDTensor* outputs); + + bool Postprocess(const FDTensor& infer_result, SegmentationResult* result); + + std::vector> processors_; + std::string config_file_; +}; +} // namespace ppseg +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ppseg/ppseg_pybind.cc b/fastdeploy/vision/ppseg/ppseg_pybind.cc new file mode 100644 index 0000000000..60022f914b --- /dev/null +++ b/fastdeploy/vision/ppseg/ppseg_pybind.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindPPSeg(pybind11::module& m) { + auto ppseg_module = + m.def_submodule("ppseg", "Module to deploy PaddleSegmentation."); + pybind11::class_(ppseg_module, "Model") + .def(pybind11::init()) + .def("predict", [](vision::ppseg::Model& self, pybind11::array& data) { + auto mat = PyArrayToCvMat(data); + vision::SegmentationResult res; + self.Predict(&mat, &res); + return res; + }); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 0334303ce6..22c4f0bc2e 100644 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -19,6 +19,7 @@ namespace fastdeploy { void BindPPCls(pybind11::module& m); void BindPPDet(pybind11::module& m); void BindWongkinyiu(pybind11::module& m); +void BindPPSeg(pybind11::module& m); void BindUltralytics(pybind11::module& m); void BindMeituan(pybind11::module& m); void BindMegvii(pybind11::module& m); @@ -42,8 +43,15 @@ void BindVision(pybind11::module& m) { .def("__repr__", &vision::DetectionResult::Str) .def("__str__", &vision::DetectionResult::Str); + pybind11::class_(m, "SegmentationResult") + .def(pybind11::init()) + .def_readwrite("masks", &vision::SegmentationResult::masks) + .def("__repr__", &vision::SegmentationResult::Str) + .def("__str__", &vision::SegmentationResult::Str); + BindPPCls(m); BindPPDet(m); + BindPPSeg(m); BindUltralytics(m); BindWongkinyiu(m); BindMeituan(m); diff --git a/fastdeploy/vision/visualize/__init__.py b/fastdeploy/vision/visualize/__init__.py index 384ec2768f..7d1bcc8926 100644 --- a/fastdeploy/vision/visualize/__init__.py +++ b/fastdeploy/vision/visualize/__init__.py @@ -19,3 +19,8 @@ def vis_detection(im_data, det_result, line_size=1, font_size=0.5): C.vision.Visualize.vis_detection(im_data, det_result, line_size, font_size) + + +def vis_segmentation(im_data, seg_result, vis_im_data, num_classes=1000): + C.vision.Visualize.vis_segmentation(im_data, seg_result, vis_im_data, + num_classes) diff --git a/fastdeploy/vision/visualize/segmentation.cc b/fastdeploy/vision/visualize/segmentation.cc new file mode 100644 index 0000000000..b1b142fc08 --- /dev/null +++ b/fastdeploy/vision/visualize/segmentation.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef ENABLE_VISION_VISUALIZE + +#include "fastdeploy/vision/visualize/visualize.h" +#include "opencv2/highgui.hpp" +#include "opencv2/imgproc/imgproc.hpp" + +namespace fastdeploy { +namespace vision { + +void Visualize::VisSegmentation(const cv::Mat& im, + const SegmentationResult& result, + cv::Mat* vis_img, const int& num_classes) { + auto color_map = GetColorMap(num_classes); + int64_t height = result.masks.size(); + int64_t width = result.masks[1].size(); + *vis_img = cv::Mat::zeros(height, width, CV_8UC3); + + int64_t index = 0; + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { + int category_id = static_cast(result.masks[i][j]); + vis_img->at(i, j)[0] = color_map[3 * category_id + 0]; + vis_img->at(i, j)[1] = color_map[3 * category_id + 1]; + vis_img->at(i, j)[2] = color_map[3 * category_id + 2]; + } + } + cv::addWeighted(im, .5, *vis_img, .5, 0, *vis_img); +} + +} // namespace vision +} // namespace fastdeploy +#endif diff --git a/fastdeploy/vision/visualize/visualize.h b/fastdeploy/vision/visualize/visualize.h index 6fffa521a6..1eb212c2b9 100644 --- a/fastdeploy/vision/visualize/visualize.h +++ b/fastdeploy/vision/visualize/visualize.h @@ -27,8 +27,11 @@ class FASTDEPLOY_DECL Visualize { static const std::vector& GetColorMap(int num_classes = 1000); static void VisDetection(cv::Mat* im, const DetectionResult& result, int line_size = 2, float font_size = 0.5f); + static void VisSegmentation(const cv::Mat& im, + const SegmentationResult& result, + cv::Mat* vis_img, const int& num_classes = 1000); }; -} // namespace vision -} // namespace fastdeploy +} // namespace vision +} // namespace fastdeploy #endif diff --git a/fastdeploy/vision/visualize/visualize_pybind.cc b/fastdeploy/vision/visualize/visualize_pybind.cc index 66ffc74f9f..5d5eb2388d 100644 --- a/fastdeploy/vision/visualize/visualize_pybind.cc +++ b/fastdeploy/vision/visualize/visualize_pybind.cc @@ -18,11 +18,20 @@ namespace fastdeploy { void BindVisualize(pybind11::module& m) { pybind11::class_(m, "Visualize") .def(pybind11::init<>()) - .def_static("vis_detection", [](pybind11::array& im_data, - vision::DetectionResult& result, - int line_size, float font_size) { - auto im = PyArrayToCvMat(im_data); - vision::Visualize::VisDetection(&im, result, line_size, font_size); + .def_static("vis_detection", + [](pybind11::array& im_data, vision::DetectionResult& result, + int line_size, float font_size) { + auto im = PyArrayToCvMat(im_data); + vision::Visualize::VisDetection(&im, result, line_size, + font_size); + }) + .def_static("vis_segmentation", [](pybind11::array& im_data, + vision::SegmentationResult& result, + pybind11::array& vis_im_data, + const int& num_classes) { + cv::Mat im = PyArrayToCvMat(im_data); + cv::Mat vis_im = PyArrayToCvMat(vis_im_data); + vision::Visualize::VisSegmentation(im, result, &vis_im, num_classes); }); } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/model_zoo/vision/ppseg/ppseg_unet.py b/model_zoo/vision/ppseg/ppseg_unet.py new file mode 100644 index 0000000000..c279e0a8fd --- /dev/null +++ b/model_zoo/vision/ppseg/ppseg_unet.py @@ -0,0 +1,36 @@ +import fastdeploy as fd +import cv2 +import tarfile + +# 下载模型和测试图片 +model_url = "https://github.com/felixhjh/Fastdeploy-Models/raw/main/unet_Cityscapes.tar.gz" +test_jpg_url = "https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png" +fd.download(model_url, ".", show_progress=True) +fd.download(test_jpg_url, ".", show_progress=True) + +try: + tar = tarfile.open("unet_Cityscapes.tar.gz", "r:gz") + file_names = tar.getnames() + for file_name in file_names: + tar.extract(file_name, ".") + tar.close() +except Exception as e: + raise Exception(e) + +# 加载模型 +model = fd.vision.ppseg.Model("./unet_Cityscapes/model.pdmodel", + "./unet_Cityscapes/model.pdiparams", + "./unet_Cityscapes/deploy.yaml") + +# 预测图片 +im = cv2.imread("./cityscapes_demo.png") +result = model.predict(im) + +vis_im = im.copy() +# 可视化结果 +fd.vision.visualize.vis_segmentation(im, result, vis_im) +cv2.imwrite("vis_img.png", vis_im) + +# 输出预测结果 +print(result) +print(model.runtime_option)