diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 1e2dc43bd4..112193c86a 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,24 +1,26 @@
-function(add_fastdeploy_executable field url model)
+function(add_fastdeploy_executable FIELD CC_FILE)
   # temp target name/file var in function scope
-  set(TEMP_TARGET_FILE ${PROJECT_SOURCE_DIR}/examples/${field}/${url}_${model}.cc)
-  set(TEMP_TARGET_NAME ${field}_${url}_${model})
+  set(TEMP_TARGET_FILE ${CC_FILE})
+  string(REGEX MATCHALL "[0-9A-Za-z_]*.cc" FILE_NAME ${CC_FILE})
+  string(REGEX REPLACE ".cc" "" FILE_PREFIX ${FILE_NAME})
+  set(TEMP_TARGET_NAME ${FIELD}_${FILE_PREFIX})
   if (EXISTS ${TEMP_TARGET_FILE} AND TARGET fastdeploy)
     add_executable(${TEMP_TARGET_NAME} ${TEMP_TARGET_FILE})
     target_link_libraries(${TEMP_TARGET_NAME} PUBLIC fastdeploy)
-    message(STATUS "Found source file: [${field}/${url}_${model}.cc], ADD!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
-  else ()
-    message(WARNING "Can not found source file: [${field}/${url}_${model}.cc], SKIP!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
+    message(STATUS "  Added FastDeploy Executable : ${TEMP_TARGET_NAME}")
   endif()
   unset(TEMP_TARGET_FILE)
   unset(TEMP_TARGET_NAME)
 endfunction()
 
 # vision examples
-if (WITH_VISION_EXAMPLES)
-  add_fastdeploy_executable(vision ultralytics yolov5)
-  add_fastdeploy_executable(vision meituan yolov6)
-  add_fastdeploy_executable(vision wongkinyiu yolov7)
-  add_fastdeploy_executable(vision megvii yolox)
+if(WITH_VISION_EXAMPLES AND EXISTS ${PROJECT_SOURCE_DIR}/examples/vision)
+  message(STATUS "")
+  message(STATUS "*************FastDeploy Examples Summary**********")
+  file(GLOB ALL_VISION_EXAMPLE_SRCS ${PROJECT_SOURCE_DIR}/examples/vision/*.cc)
+  foreach(_CC_FILE ${ALL_VISION_EXAMPLE_SRCS})
+    add_fastdeploy_executable(vision ${_CC_FILE})
+  endforeach()
 endif()
 
-# other examples ...
\ No newline at end of file
+# other examples ...
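+# Note (illustrative): with the glob above, a source file such as
+# examples/vision/ppdet_ppyoloe.cc is built into the executable target
+# `vision_ppdet_ppyoloe`, assuming WITH_VISION_EXAMPLES is ON and the
+# `fastdeploy` target exists.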
diff --git a/examples/vision/ppdet_ppyoloe.cc b/examples/vision/ppdet_ppyoloe.cc
new file mode 100644
index 0000000000..b234021c92
--- /dev/null
+++ b/examples/vision/ppdet_ppyoloe.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel";
+  std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams";
+  std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml";
+  std::string img_path = "test.jpeg";
+  std::string vis_path = "vis.jpeg";
+
+  auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file);
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed." << std::endl;
+    return -1;
+  }
+
+  cv::Mat im = cv::imread(img_path);
+  cv::Mat vis_im = im.clone();
+
+  vis::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // Print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // Visualize the prediction results
+  vis::Visualize::VisDetection(&vis_im, res);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
+  return 0;
+}
diff --git a/fastdeploy/__init__.py b/fastdeploy/__init__.py
index 500e7cc42a..68006c1bed 100644
--- a/fastdeploy/__init__.py
+++ b/fastdeploy/__init__.py
@@ -17,7 +17,7 @@
 from .fastdeploy_runtime import *
 from . import fastdeploy_main as C
 from . import vision
-from .download import download
+from .download import download, download_and_decompress
 
 
 def TensorInfoStr(tensor_info):
diff --git a/fastdeploy/download.py b/fastdeploy/download.py
index e00af098df..67f21d8e76 100644
--- a/fastdeploy/download.py
+++ b/fastdeploy/download.py
@@ -156,7 +156,7 @@ def decompress(fname):
 
 def url2dir(url, path, rename=None):
     full_name = download(url, path, rename, show_progress=True)
-    print("SDK is donwloaded, now extracting...")
+    print("File is downloaded, now extracting...")
     if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count("zip") > 0:
         return decompress(full_name)
diff --git a/fastdeploy/utils/utils.h b/fastdeploy/utils/utils.h
index 1b9f625b5e..9312084265 100644
--- a/fastdeploy/utils/utils.h
+++ b/fastdeploy/utils/utils.h
@@ -64,6 +64,10 @@ class FASTDEPLOY_DECL FDLogger {
   bool verbose_ = true;
 };
 
+#ifndef __REL_FILE__
+#define __REL_FILE__ __FILE__
+#endif
+
 #define FDERROR                                                              \
   FDLogger(true, "[ERROR]")                                                  \
       << __REL_FILE__ << "(" << __LINE__ << ")::" << __FUNCTION__ << "\t"
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index ac3f006c0a..cafe310c70 100644
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -16,6 +16,7 @@
 #include "fastdeploy/core/config.h"
 #ifdef ENABLE_VISION
 #include "fastdeploy/vision/ppcls/model.h"
+#include "fastdeploy/vision/ppdet/ppyoloe.h"
 #include "fastdeploy/vision/ultralytics/yolov5.h"
 #include "fastdeploy/vision/wongkinyiu/yolov7.h"
 #include "fastdeploy/vision/meituan/yolov6.h"
diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py
index 7122bede0b..6acbf0c376 100644
--- a/fastdeploy/vision/__init__.py
+++ b/fastdeploy/vision/__init__.py
@@ -15,6 +15,7 @@
 from . import evaluation
 from . import ppcls
+from . import ppdet
 from . import ultralytics
 from . import meituan
 from . import megvii
diff --git a/fastdeploy/vision/common/processors/convert.cc b/fastdeploy/vision/common/processors/convert.cc
new file mode 100644
index 0000000000..a7ca6de07a
--- /dev/null
+++ b/fastdeploy/vision/common/processors/convert.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/common/processors/convert.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+Convert::Convert(const std::vector<float>& alpha,
+                 const std::vector<float>& beta) {
+  FDASSERT(alpha.size() == beta.size(),
+           "Convert: requires the size of alpha equal to the size of beta.");
+  FDASSERT(alpha.size() != 0,
+           "Convert: requires the size of alpha and beta > 0.");
+  alpha_.assign(alpha.begin(), alpha.end());
+  beta_.assign(beta.begin(), beta.end());
+}
+
+bool Convert::CpuRun(Mat* mat) {
+  cv::Mat* im = mat->GetCpuMat();
+  std::vector<cv::Mat> split_im;
+  cv::split(*im, split_im);
+  for (int c = 0; c < im->channels(); c++) {
+    split_im[c].convertTo(split_im[c], CV_32FC1, alpha_[c], beta_[c]);
+  }
+  cv::merge(split_im, *im);
+  return true;
+}
+
+#ifdef ENABLE_OPENCV_CUDA
+bool Convert::GpuRun(Mat* mat) {
+  cv::cuda::GpuMat* im = mat->GetGpuMat();
+  std::vector<cv::cuda::GpuMat> split_im;
+  cv::cuda::split(*im, split_im);
+  for (int c = 0; c < im->channels(); c++) {
+    split_im[c].convertTo(split_im[c], CV_32FC1, alpha_[c], beta_[c]);
+  }
+  cv::cuda::merge(split_im, *im);
+  return true;
+}
+#endif
+
+bool Convert::Run(Mat* mat, const std::vector<float>& alpha,
+                  const std::vector<float>& beta, ProcLib lib) {
+  auto c = Convert(alpha, beta);
+  return c(mat, lib);
+}
+
+}  // namespace vision
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/vision/common/processors/convert.h b/fastdeploy/vision/common/processors/convert.h
new file mode 100644
index 0000000000..5d5a5276f5
--- /dev/null
+++ b/fastdeploy/vision/common/processors/convert.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "fastdeploy/vision/common/processors/base.h"
+
+namespace fastdeploy {
+namespace vision {
+class Convert : public Processor {
+ public:
+  Convert(const std::vector<float>& alpha, const std::vector<float>& beta);
+
+  bool CpuRun(Mat* mat);
+#ifdef ENABLE_OPENCV_CUDA
+  bool GpuRun(Mat* mat);
+#endif
+  std::string Name() { return "Convert"; }
+
+  // Compute `result = mat * alpha + beta` directly by channel.
+  // The default behavior is the same as OpenCV's convertTo method.
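+  //
+  // Example (illustrative): to scale a 3-channel uint8 image into [0, 1],
+  //   Convert::Run(&mat, {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f},
+  //                {0.0f, 0.0f, 0.0f});
+  // which is how the YOLOv5/YOLOv6 preprocessing in this patch replaces the
+  // previous Normalize step.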
+  static bool Run(Mat* mat, const std::vector<float>& alpha,
+                  const std::vector<float>& beta,
+                  ProcLib lib = ProcLib::OPENCV_CPU);
+
+ private:
+  std::vector<float> alpha_;
+  std::vector<float> beta_;
+};
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/common/processors/transform.h b/fastdeploy/vision/common/processors/transform.h
index 12eec8d72d..08073b4e42 100644
--- a/fastdeploy/vision/common/processors/transform.h
+++ b/fastdeploy/vision/common/processors/transform.h
@@ -17,6 +17,7 @@
 #include "fastdeploy/vision/common/processors/cast.h"
 #include "fastdeploy/vision/common/processors/center_crop.h"
 #include "fastdeploy/vision/common/processors/color_space_convert.h"
+#include "fastdeploy/vision/common/processors/convert.h"
 #include "fastdeploy/vision/common/processors/hwc2chw.h"
 #include "fastdeploy/vision/common/processors/normalize.h"
 #include "fastdeploy/vision/common/processors/pad.h"
diff --git a/fastdeploy/vision/meituan/yolov6.cc b/fastdeploy/vision/meituan/yolov6.cc
index 8f37bf89c6..8ac7377194 100644
--- a/fastdeploy/vision/meituan/yolov6.cc
+++ b/fastdeploy/vision/meituan/yolov6.cc
@@ -25,14 +25,14 @@ namespace meituan {
 void LetterBox(Mat* mat, std::vector<int> size, std::vector<float> color,
                bool _auto, bool scale_fill = false, bool scale_up = true,
                int stride = 32) {
-  float scale = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
-                         size[0] * 1.0f / static_cast<float>(mat->Width()));
+  float scale = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
+                         size[0] * 1.0f / static_cast<float>(mat->Width()));
   if (!scale_up) {
     scale = std::min(scale, 1.0f);
   }
 
   int resize_h = int(round(static_cast<float>(mat->Height()) * scale));
-  int resize_w = int(round(static_cast<float>(mat->Width()) * scale));
+  int resize_w = int(round(static_cast<float>(mat->Width()) * scale));
 
   int pad_w = size[0] - resize_w;
   int pad_h = size[1] - resize_h;
@@ -85,13 +85,13 @@ bool YOLOv6::Initialize() {
   is_scale_up = false;
   stride = 32;
   max_wh = 4096.0f;
-
+
   if (!InitRuntime()) {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
-  // Check if the input shape is dynamic after Runtime already initialized,
-  // Note that, We need to force is_mini_pad 'false' to keep static
+  // Check if the input shape is dynamic after Runtime already initialized,
+  // Note that, We need to force is_mini_pad 'false' to keep static
   // shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
   is_dynamic_input_ = false;
   auto shape = InputInfoOfRuntime(0).shape;
@@ -102,7 +102,7 @@ bool YOLOv6::Initialize() {
       break;
     }
   }
-  if (!is_dynamic_input_) {
+  if (!is_dynamic_input_) {
     is_mini_pad = false;
   }
   return true;
@@ -111,15 +111,15 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
                         std::map<std::string, std::array<float, 2>>* im_info) {
   // process after image load
-  float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
-                         size[0] * 1.0f / static_cast<float>(mat->Width()));
+  float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
+                         size[0] * 1.0f / static_cast<float>(mat->Width()));
   if (ratio != 1.0) {
     int interp = cv::INTER_AREA;
     if (ratio > 1.0) {
       interp = cv::INTER_LINEAR;
     }
     int resize_h = int(round(static_cast<float>(mat->Height()) * ratio));
-    int resize_w = int(round(static_cast<float>(mat->Width()) * ratio));
+    int resize_w = int(round(static_cast<float>(mat->Width()) * ratio));
     Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
   }
   // yolov6's preprocess steps
@@ -129,8 +129,12 @@ bool YOLOv6::Preprocess(Mat* mat, FDTensor* output,
   LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, is_scale_up,
             stride);
   BGR2RGB::Run(mat);
-  Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
-                 std::vector<float>(mat->Channels(), 1.0));
+  // Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
+  //                std::vector<float>(mat->Channels(), 1.0));
+  // Compute `result = mat * alpha + beta` directly by channel
+  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
+  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
+  Convert::Run(mat, alpha, beta);
 
   // Record output shape of preprocessed image
   (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
diff --git a/fastdeploy/vision/ppcls/model.cc b/fastdeploy/vision/ppcls/model.cc
index 915cb97512..c4e5b767c7 100644
--- a/fastdeploy/vision/ppcls/model.cc
+++ b/fastdeploy/vision/ppcls/model.cc
@@ -1,3 +1,16 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 #include "fastdeploy/vision/ppcls/model.h"
 
 #include "fastdeploy/vision/utils/utils.h"
@@ -135,6 +148,6 @@ bool Model::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
   return true;
 }
 
-} // namespace ppcls
-} // namespace vision
-} // namespace fastdeploy
+}  // namespace ppcls
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ppcls/model.h b/fastdeploy/vision/ppcls/model.h
index 36841d74c6..265f92d32b 100644
--- a/fastdeploy/vision/ppcls/model.h
+++ b/fastdeploy/vision/ppcls/model.h
@@ -1,7 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #pragma once
 #include "fastdeploy/fastdeploy_model.h"
-#include "fastdeploy/vision/common/result.h"
 #include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
 
 namespace fastdeploy {
 namespace vision {
diff --git a/fastdeploy/vision/ppcls/ppcls_pybind.cc b/fastdeploy/vision/ppcls/ppcls_pybind.cc
index ef3fffee8e..1abc0b2b7c 100644
--- a/fastdeploy/vision/ppcls/ppcls_pybind.cc
+++ b/fastdeploy/vision/ppcls/ppcls_pybind.cc
@@ -14,7 +14,7 @@
 #include "fastdeploy/pybind/main.h"
 
 namespace fastdeploy {
-void BindPpClsModel(pybind11::module& m) {
+void BindPPCls(pybind11::module& m) {
   auto ppcls_module = m.def_submodule("ppcls", "Module to deploy PaddleClas.");
   pybind11::class_<vision::ppcls::Model, FastDeployModel>(ppcls_module, "Model")
       .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
diff --git a/fastdeploy/vision/ppdet/ppdet_pybind.cc b/fastdeploy/vision/ppdet/ppdet_pybind.cc
new file mode 100644
--- /dev/null
+++ b/fastdeploy/vision/ppdet/ppdet_pybind.cc
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindPPDet(pybind11::module& m) {
+  auto ppdet_module =
+      m.def_submodule("ppdet", "Module to deploy PaddleDetection.");
+  pybind11::class_<vision::ppdet::PPYOLOE, FastDeployModel>(ppdet_module,
+                                                            "PPYOLOE")
+      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
+                          Frontend>())
+      .def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data,
+                         float conf_threshold, float nms_iou_threshold) {
+        auto mat = PyArrayToCvMat(data);
+        vision::DetectionResult res;
+        self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
+        return res;
+      });
+}
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ppdet/ppyoloe.cc b/fastdeploy/vision/ppdet/ppyoloe.cc
new file mode 100644
index 0000000000..c215ecb0ca
--- /dev/null
+++ b/fastdeploy/vision/ppdet/ppyoloe.cc
@@ -0,0 +1,170 @@
+#include "fastdeploy/vision/ppdet/ppyoloe.h"
+#include "fastdeploy/vision/utils/utils.h"
+#include "yaml-cpp/yaml.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ppdet {
+
+PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
+                 const std::string& config_file,
+                 const RuntimeOption& custom_option,
+                 const Frontend& model_format) {
+  config_file_ = config_file;
+  valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
+  valid_gpu_backends = {Backend::ORT, Backend::PDINFER};
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+bool PPYOLOE::Initialize() {
+  if (!BuildPreprocessPipelineFromConfig()) {
+    std::cout << "Failed to build preprocess pipeline from configuration file."
+              << std::endl;
+    return false;
+  }
+  if (!InitRuntime()) {
+    std::cout << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
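+// The PaddleDetection export config (infer_cfg.yml) is expected to provide at
+// least the fields parsed below. A minimal sketch with illustrative values:
+//
+//   arch: YOLO
+//   Preprocess:
+//   - type: Resize
+//     keep_ratio: false
+//     target_size: [640, 640]
+//     interp: 2
+//   - type: NormalizeImage
+//     mean: [0.485, 0.456, 0.406]
+//     std: [0.229, 0.224, 0.225]
+//     is_scale: true
+//   - type: Permute
+//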
+bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
+  processors_.clear();
+  YAML::Node cfg;
+  try {
+    cfg = YAML::LoadFile(config_file_);
+  } catch (YAML::BadFile& e) {
+    std::cout << "Failed to load yaml file " << config_file_
+              << ", maybe you should check this file." << std::endl;
+    return false;
+  }
+
+  if (cfg["arch"].as<std::string>() != "YOLO") {
+    std::cout << "Require the arch of model is YOLO, but arch defined in "
+                 "config file is "
+              << cfg["arch"].as<std::string>() << "." << std::endl;
+    return false;
+  }
+  processors_.push_back(std::make_shared<BGR2RGB>());
+
+  for (const auto& op : cfg["Preprocess"]) {
+    std::string op_name = op["type"].as<std::string>();
+    if (op_name == "NormalizeImage") {
+      auto mean = op["mean"].as<std::vector<float>>();
+      auto std = op["std"].as<std::vector<float>>();
+      bool is_scale = op["is_scale"].as<bool>();
+      processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
+    } else if (op_name == "Resize") {
+      bool keep_ratio = op["keep_ratio"].as<bool>();
+      auto target_size = op["target_size"].as<std::vector<int>>();
+      int interp = op["interp"].as<int>();
+      FDASSERT(target_size.size() == 2,
+               "Require size of target_size be 2, but now it's " +
+                   std::to_string(target_size.size()) + ".");
+      FDASSERT(!keep_ratio,
+               "Only support keep_ratio is false while deploy "
+               "PaddleDetection model.");
+      int width = target_size[1];
+      int height = target_size[0];
+      processors_.push_back(
+          std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
+    } else if (op_name == "Permute") {
+      processors_.push_back(std::make_shared<HWC2CHW>());
+    } else {
+      std::cout << "Unexpected preprocess operator: " << op_name << "."
+                << std::endl;
+      return false;
+    }
+  }
+  return true;
+}
+
+bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
+  int origin_w = mat->Width();
+  int origin_h = mat->Height();
+  for (size_t i = 0; i < processors_.size(); ++i) {
+    if (!(*(processors_[i].get()))(mat)) {
+      std::cout << "Failed to process image data in " << processors_[i]->Name()
+                << "." << std::endl;
+      return false;
+    }
+  }
+
+  outputs->resize(2);
+  (*outputs)[0].name = InputInfoOfRuntime(0).name;
+  mat->ShareWithTensor(&((*outputs)[0]));
+
+  // reshape to [1, c, h, w]
+  (*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
+
+  (*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
+  float* ptr = static_cast<float*>((*outputs)[1].MutableData());
+  // scale_factor = preprocessed size / original size, per dimension
+  ptr[0] = mat->Height() * 1.0 / origin_h;
+  ptr[1] = mat->Width() * 1.0 / origin_w;
+  return true;
+}
+
+bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
+                          DetectionResult* result, float conf_threshold,
+                          float nms_threshold) {
+  FDASSERT(infer_result[1].shape[0] == 1,
+           "Only support batch = 1 in FastDeploy now.");
+  int box_num = 0;
+  if (infer_result[1].dtype == FDDataType::INT32) {
+    box_num = *(static_cast<int32_t*>(infer_result[1].Data()));
+  } else if (infer_result[1].dtype == FDDataType::INT64) {
+    box_num = *(static_cast<int64_t*>(infer_result[1].Data()));
+  } else {
+    FDASSERT(
+        false,
+        "The output box_num of PPYOLOE model should be type of int32/int64.");
+  }
+  result->Reserve(box_num);
+  float* box_data = static_cast<float*>(infer_result[0].Data());
+  for (size_t i = 0; i < box_num; ++i) {
+    if (box_data[i * 6 + 1] < conf_threshold) {
+      continue;
+    }
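+    // Each detection occupies 6 floats laid out as
+    // [class_id, score, x1, y1, x2, y2]; the box is stored below as
+    // [x, y, w, h] (top-left corner plus width/height).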
+    result->label_ids.push_back(box_data[i * 6]);
+    result->scores.push_back(box_data[i * 6 + 1]);
+    result->boxes.emplace_back(
+        std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
+                             box_data[i * 6 + 4] - box_data[i * 6 + 2],
+                             box_data[i * 6 + 5] - box_data[i * 6 + 3]});
+  }
+  return true;
+}
+
+bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result,
+                      float conf_threshold, float iou_threshold) {
+  Mat mat(*im);
+  std::vector<FDTensor> processed_data;
+  if (!Preprocess(&mat, &processed_data)) {
+    FDERROR << "Failed to preprocess input data while using model:"
+            << ModelName() << "." << std::endl;
+    return false;
+  }
+
+  std::vector<FDTensor> infer_result;
+  if (!Infer(processed_data, &infer_result)) {
+    FDERROR << "Failed to inference while using model:" << ModelName() << "."
+            << std::endl;
+    return false;
+  }
+
+  if (!Postprocess(infer_result, result, conf_threshold, iou_threshold)) {
+    FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
+            << std::endl;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace ppdet
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ppdet/ppyoloe.h b/fastdeploy/vision/ppdet/ppyoloe.h
new file mode 100644
index 0000000000..a3db268ca4
--- /dev/null
+++ b/fastdeploy/vision/ppdet/ppyoloe.h
@@ -0,0 +1,44 @@
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ppdet {
+
+class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
+ public:
+  PPYOLOE(const std::string& model_file, const std::string& params_file,
+          const std::string& config_file,
+          const RuntimeOption& custom_option = RuntimeOption(),
+          const Frontend& model_format = Frontend::PADDLE);
+
+  std::string ModelName() const { return "PaddleDetection/PPYOLOE"; }
+
+  virtual bool Initialize();
+
+  virtual bool BuildPreprocessPipelineFromConfig();
+
+  virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
+
+  virtual bool Postprocess(std::vector<FDTensor>& infer_result,
+                           DetectionResult* result, float conf_threshold,
+                           float nms_threshold);
+
+  virtual bool Predict(cv::Mat* im, DetectionResult* result,
+                       float conf_threshold = 0.5, float nms_threshold = 0.7);
+
+ private:
+  std::vector<std::shared_ptr<Processor>> processors_;
+  std::string config_file_;
+  // PaddleDetection can export model without nms
+  // This flag will help us to handle the different
+  // situation
+  bool has_nms_;
+};
+}  // namespace ppdet
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/ultralytics/yolov5.cc b/fastdeploy/vision/ultralytics/yolov5.cc
index 193cfe9794..0b7e50e735 100644
--- a/fastdeploy/vision/ultralytics/yolov5.cc
+++ b/fastdeploy/vision/ultralytics/yolov5.cc
@@ -87,8 +87,8 @@ bool YOLOv5::Initialize() {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
-  // Check if the input shape is dynamic after Runtime already initialized,
-  // Note that, We need to force is_mini_pad 'false' to keep static
+  // Check if the input shape is dynamic after Runtime already initialized,
+  // Note that, We need to force is_mini_pad 'false' to keep static
   // shape after padding (LetterBox) when the is_dynamic_shape is 'false'.
   is_dynamic_input_ = false;
   auto shape = InputInfoOfRuntime(0).shape;
@@ -99,7 +99,7 @@ bool YOLOv5::Initialize() {
       break;
     }
   }
-  if (!is_dynamic_input_) {
+  if (!is_dynamic_input_) {
     is_mini_pad = false;
   }
   return true;
@@ -126,8 +126,12 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output,
   LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, is_scale_up,
             stride);
   BGR2RGB::Run(mat);
-  Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
-                 std::vector<float>(mat->Channels(), 1.0));
+  // Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
+  //                std::vector<float>(mat->Channels(), 1.0));
+  // Compute `result = mat * alpha + beta` directly by channel
+  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
+  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
+  Convert::Run(mat, alpha, beta);
 
   // Record output shape of preprocessed image
   (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
@@ -198,6 +202,11 @@ bool YOLOv5::Postprocess(
       result->scores.push_back(confidence);
     }
   }
+
+  if (result->boxes.size() == 0) {
+    return true;
+  }
+
   utils::NMS(result, nms_iou_threshold);
 
   // scale the boxes to the origin image shape
diff --git a/fastdeploy/vision/utils/sort_det_res.cc b/fastdeploy/vision/utils/sort_det_res.cc
index e4a0db9761..93dbb69694 100644
--- a/fastdeploy/vision/utils/sort_det_res.cc
+++ b/fastdeploy/vision/utils/sort_det_res.cc
@@ -68,7 +68,11 @@ void MergeSort(DetectionResult* result, size_t low, size_t high) {
 
 void SortDetectionResult(DetectionResult* result) {
   size_t low = 0;
-  size_t high = result->scores.size() - 1;
+  size_t high = result->scores.size();
+  if (high == 0) {
+    return;
+  }
+  high = high - 1;
   MergeSort(result, low, high);
 }
 
diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc
index 41ada5541a..0334303ce6 100644
--- a/fastdeploy/vision/vision_pybind.cc
+++ b/fastdeploy/vision/vision_pybind.cc
@@ -16,7 +16,8 @@
 
 namespace fastdeploy {
 
-void BindPpClsModel(pybind11::module& m);
+void BindPPCls(pybind11::module& m);
+void BindPPDet(pybind11::module& m);
 void BindWongkinyiu(pybind11::module& m);
 void BindUltralytics(pybind11::module& m);
 void BindMeituan(pybind11::module& m);
@@ -41,13 +42,14 @@ void BindVision(pybind11::module& m) {
       .def("__repr__", &vision::DetectionResult::Str)
       .def("__str__", &vision::DetectionResult::Str);
 
-  BindPpClsModel(m);
+  BindPPCls(m);
+  BindPPDet(m);
   BindUltralytics(m);
   BindWongkinyiu(m);
   BindMeituan(m);
   BindMegvii(m);
 #ifdef ENABLE_VISION_VISUALIZE
   BindVisualize(m);
-#endif
+#endif
 }
-} // namespace fastdeploy
+}  // namespace fastdeploy
diff --git a/fastdeploy/vision/visualize/detection.cc b/fastdeploy/vision/visualize/detection.cc
index d0c4116148..5b5538bff7 100644
--- a/fastdeploy/vision/visualize/detection.cc
+++ b/fastdeploy/vision/visualize/detection.cc
@@ -43,7 +43,7 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
     }
     std::string text = id + "," + score;
     int font = cv::FONT_HERSHEY_SIMPLEX;
-    cv::Size text_size = cv::getTextSize(text, font, font_size, 0.5, nullptr);
+    cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
     cv::Point origin;
     origin.x = rect.x;
     origin.y = rect.y;
@@ -52,7 +52,7 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
                   text_size.width, text_size.height);
     cv::rectangle(*im, rect, rect_color, line_size);
     cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
-                0.5);
+                1);
   }
 }
 
diff --git a/model_zoo/vision/ppyoloe/README.md b/model_zoo/vision/ppyoloe/README.md
new file mode 100644
index 0000000000..42d18104ad
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/README.md
@@ -0,0 +1,52 @@
+# PaddleDetection/PPYOLOE Deployment Example
+
+- The currently supported PaddleDetection version is [release/2.4](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4)
+
+This document describes how to quickly deploy and run inference with [PPYOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe). This directory is organized as follows:
+```
+.
+├── cpp                 # C++ code directory
+│   ├── CMakeLists.txt  # CMakeLists file for building the C++ code
+│   ├── README.md       # C++ build and deployment guide
+│   └── ppyoloe.cc      # C++ example code
+├── README.md           # PPYOLOE deployment guide
+└── ppyoloe.py          # Python example code
+```
+
+## Install FastDeploy
+
+Install FastDeploy with the command below. Note that this example uses the `vision-cpu` package; `vision-gpu` can be installed instead if needed.
+```
+# Install the fastdeploy-python tool
+pip install fastdeploy-python
+```
+
+## Python Deployment
+
+Running the following command automatically downloads the PPYOLOE model and a test image:
+```
+python ppyoloe.py
+```
+
+After it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection results are printed as follows:
+```
+DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
+162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
+414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
+163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
+267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
+581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
+104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
+348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
+364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
+75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
+328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
+504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
+379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
+25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
+```
+
+## Other Documents
+
+- [C++ Deployment](./cpp/README.md)
+- [PPYOLOE API Reference](./api.md)
diff --git a/model_zoo/vision/ppyoloe/api.md b/model_zoo/vision/ppyoloe/api.md
new file mode 100644
index 0000000000..1c5cbcaadb
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/api.md
@@ -0,0 +1,74 @@
+# PPYOLOE API Reference
+
+## Python API
+
+### PPYOLOE Class
+```
+fastdeploy.vision.ppdet.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE)
+```
+Loads and initializes a PPYOLOE model. Both model_file and params_file must be provided; currently only the Paddle format is supported for model_format.
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameters file
+> * **config_file**(str): Path of the inference configuration file
+> * **runtime_option**(RuntimeOption): Backend inference options; None (default) uses the default configuration
+> * **model_format**(Frontend): Model format
+
+#### predict Function
+> ```
+> PPYOLOE.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
+> ```
+> Prediction interface: takes an image and returns the detection results directly.
+>
+> **Parameters**
+>
+> > * **image_data**(np.ndarray): Input data, in HWC, BGR format
+> > * **conf_threshold**(float): Confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**(float): IoU threshold used during NMS (has no effect when NMS is already part of the exported model)
+
+See [ppyoloe.py](./ppyoloe.py) for example code.
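+
+A minimal usage sketch (assuming the model archive and test image from the
+README have already been downloaded and extracted):
+```
+import fastdeploy as fd
+import cv2
+
+model = fd.vision.ppdet.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel",
+                                "ppyoloe_crn_l_300e_coco/model.pdiparams",
+                                "ppyoloe_crn_l_300e_coco/infer_cfg.yml")
+result = model.predict(cv2.imread("000000014439_640x640.jpg"), conf_threshold=0.5)
+print(result)
+```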
+
+
+## C++ API
+
+### PPYOLOE Class
+```
+fastdeploy::vision::ppdet::PPYOLOE(
+        const string& model_file,
+        const string& params_file,
+        const string& config_file,
+        const RuntimeOption& runtime_option = RuntimeOption(),
+        const Frontend& model_format = Frontend::PADDLE)
+```
+Loads and initializes a PPYOLOE model. Both model_file and params_file must be provided; currently only the Paddle format is supported for model_format.
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameters file
+> * **config_file**(str): Path of the inference configuration file
+> * **runtime_option**(RuntimeOption): Backend inference options; the default value uses the default configuration
+> * **model_format**(Frontend): Model format
+
+#### Predict Function
+> ```
+> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result,
+>                  float conf_threshold = 0.25,
+>                  float nms_iou_threshold = 0.5)
+> ```
+> Prediction interface: takes an image and returns the detection results directly.
+>
+> **Parameters**
+>
+> > * **im**: Input image, in HWC, BGR format
+> > * **result**: Detection results, including the detection boxes and the confidence of each box
+> > * **conf_threshold**: Confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**: IoU threshold used during NMS (has no effect when NMS is already part of the exported model)
+
+See [cpp/ppyoloe.cc](cpp/ppyoloe.cc) for example code.
+
+## Other APIs
+
+- [RuntimeOption configuration for model deployment](../../../docs/api/runtime_option.md)
diff --git a/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt b/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..e681566517
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/cpp/CMakeLists.txt
@@ -0,0 +1,17 @@
+PROJECT(ppyoloe_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
+
+# For environments built with the old C++ ABI, enable compatible compilation with:
+# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+
+# Path of the downloaded and extracted FastDeploy library
+set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header dependencies
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(ppyoloe_demo ${PROJECT_SOURCE_DIR}/ppyoloe.cc)
+# Link against the FastDeploy libraries
+target_link_libraries(ppyoloe_demo ${FASTDEPLOY_LIBS})
diff --git a/model_zoo/vision/ppyoloe/cpp/README.md b/model_zoo/vision/ppyoloe/cpp/README.md
new file mode 100644
index 0000000000..1027c2eeb2
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/cpp/README.md
@@ -0,0 +1,39 @@
+# Building the PPYOLOE Example
+
+
+```
+# Download and extract the prebuilt library
+wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
+tar xvf fastdeploy-linux-x64-0.0.3.tgz
+
+# Build the example code
+mkdir build && cd build
+cmake ..
+make -j
+
+# Download the model and test image
+wget https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz
+tar xvf ppyoloe_crn_l_300e_coco.tgz
+wget https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg
+
+# Run
+./ppyoloe_demo
+```
+
+After it finishes, the visualized result is saved locally as `vis.jpeg` and the detection boxes are printed to the terminal as shown below:
+```
+DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
+162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
+414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
+163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
+267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
+581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
+104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
+348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
+364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
+75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
+328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
+504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
+379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
+25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
+```
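+
+> Note: `FASTDEPLOY_INSTALL_DIR` in [CMakeLists.txt](./CMakeLists.txt) must point
+> to the directory actually produced by extracting the archive above; adjust the
+> version in that path if the downloaded package name differs.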
diff --git a/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc b/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc
new file mode 100644
index 0000000000..e63f29e62a
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/cpp/ppyoloe.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel";
+  std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams";
+  std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml";
+  std::string img_path = "000000014439_640x640.jpg";
+  std::string vis_path = "vis.jpeg";
+
+  auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file);
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed." << std::endl;
+    return -1;
+  }
+
+  cv::Mat im = cv::imread(img_path);
+  cv::Mat vis_im = im.clone();
+
+  vis::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // Print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // Visualize the prediction results
+  vis::Visualize::VisDetection(&vis_im, res);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
+  return 0;
+}
diff --git a/model_zoo/vision/ppyoloe/ppyoloe.py b/model_zoo/vision/ppyoloe/ppyoloe.py
new file mode 100644
index 0000000000..7d79dfd8cf
--- /dev/null
+++ b/model_zoo/vision/ppyoloe/ppyoloe.py
@@ -0,0 +1,24 @@
+import fastdeploy as fd
+import cv2
+
+# Download the model and test image
+model_url = "https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz"
+test_jpg_url = "https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg"
+fd.download_and_decompress(model_url, ".")
+fd.download(test_jpg_url, ".", show_progress=True)
+
+# Load the model
+model = fd.vision.ppdet.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel",
+                                "ppyoloe_crn_l_300e_coco/model.pdiparams",
+                                "ppyoloe_crn_l_300e_coco/infer_cfg.yml")
+
+# Predict on the image
+im = cv2.imread("000000014439_640x640.jpg")
+result = model.predict(im, conf_threshold=0.5)
+
+# Visualize the result
+fd.vision.visualize.vis_detection(im, result)
+cv2.imwrite("vis_result.jpg", im)
+
+# Print the prediction results
+print(result)
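+
+# Note: predict() also accepts an nms_iou_threshold argument (see api.md); it
+# has no effect for this exported model because NMS is already part of it.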
diff --git a/setup.py b/setup.py
index f0ff3f16de..e76f057b1c 100644
--- a/setup.py
+++ b/setup.py
@@ -49,7 +49,8 @@
 setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
 setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
 setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
-setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
+setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
+                                            "/usr/local/cuda")
 
 TOP_DIR = os.path.realpath(os.path.dirname(__file__))
 SRC_DIR = os.path.join(TOP_DIR, "fastdeploy")
@@ -325,17 +326,32 @@ def run(self):
         shutil.copy("LICENSE", "fastdeploy")
 
         depend_libs = list()
-        # modify the search path of libraries
-        command = "patchelf --set-rpath '$ORIGIN/libs/' .setuptools-cmake-build/fastdeploy_main.cpython-36m-x86_64-linux-gnu.so"
-        # The sw_64 not suppot patchelf, so we just disable that.
-        if platform.machine() != 'sw_64' and platform.machine() != 'mips64':
-            assert os.system(command) == 0, "patch fastdeploy_main.cpython-36m-x86_64-linux-gnu.so failed, the command: {}".format(command)
+        if platform.system().lower() == "linux":
+            for f in os.listdir(".setuptools-cmake-build"):
+                full_name = os.path.join(".setuptools-cmake-build", f)
+                if not os.path.isfile(full_name):
+                    continue
+                if not full_name.count("fastdeploy_main.cpython-"):
+                    continue
+                if not full_name.endswith(".so"):
+                    continue
+                # modify the search path of libraries
+                command = "patchelf --set-rpath '$ORIGIN/libs/' {}".format(
+                    full_name)
+                # The sw_64 does not support patchelf, so we just disable it there.
+                if platform.machine() != 'sw_64' and platform.machine(
+                ) != 'mips64':
+                    assert os.system(
+                        command
+                    ) == 0, "patch {} failed, the command: {}".format(
+                        full_name, command)
 
         for f in os.listdir(".setuptools-cmake-build"):
             if not os.path.isfile(os.path.join(".setuptools-cmake-build", f)):
                 continue
             if f.count("libfastdeploy") > 0:
-                shutil.copy(os.path.join(".setuptools-cmake-build", f), "fastdeploy/libs")
+                shutil.copy(
+                    os.path.join(".setuptools-cmake-build", f),
+                    "fastdeploy/libs")
         for dirname in os.listdir(".setuptools-cmake-build/third_libs/install"):
             for lib in os.listdir(
                     os.path.join(".setuptools-cmake-build/third_libs/install",