diff --git a/.gitignore b/.gitignore index c97725f9b9..9ce70a2f90 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,6 @@ work_dirs/ # install directory /install + +# the generated header files +/tests/test_csrc/test_define.h diff --git a/csrc/apis/c/model.h b/csrc/apis/c/model.h index 83fea727ba..5f7cba116d 100644 --- a/csrc/apis/c/model.h +++ b/csrc/apis/c/model.h @@ -9,6 +9,6 @@ int mmdeploy_model_create_by_path(const char* path, mm_model_t* model); int mmdeploy_model_create(const void* buffer, int size, mm_model_t* model); -void mmdeploy_model_destroy(mm_model_t* model); +void mmdeploy_model_destroy(mm_model_t model); #endif // MMDEPLOY_SRC_APIS_C_MODEL_H_ diff --git a/csrc/codebase/mmocr/resize_ocr.cpp b/csrc/codebase/mmocr/resize_ocr.cpp index a7fc9dad8c..c9c48b0732 100644 --- a/csrc/codebase/mmocr/resize_ocr.cpp +++ b/csrc/codebase/mmocr/resize_ocr.cpp @@ -19,12 +19,18 @@ class ResizeOCRImpl : public Module { public: explicit ResizeOCRImpl(const Value& args) noexcept { height_ = args.value("height", height_); - min_width_ = args.value("min_width", min_width_); - max_width_ = args.value("max_width", max_width_); + min_width_ = args.contains("min_width") && args["min_width"].is_number_integer() + ? args["min_width"].get() + : min_width_; + max_width_ = args.contains("max_width") && args["max_width"].is_number_integer() + ? args["max_width"].get() + : max_width_; keep_aspect_ratio_ = args.value("keep_aspect_ratio", keep_aspect_ratio_); + backend_ = args.contains("backend") && args["backend"].is_string() + ? args["backend"].get() + : backend_; img_pad_value_ = args.value("img_pad_value", img_pad_value_); width_downsample_ratio_ = args.value("width_downsample_ratio", width_downsample_ratio_); - backend_ = args.value("backend", backend_); stream_ = args["context"]["stream"].get(); } diff --git a/csrc/codebase/mmseg/segment.cpp b/csrc/codebase/mmseg/segment.cpp index a8f315fed1..91df317140 100644 --- a/csrc/codebase/mmseg/segment.cpp +++ b/csrc/codebase/mmseg/segment.cpp @@ -9,36 +9,24 @@ namespace mmdeploy::mmseg { -static Result VisualizeMask(const std::string &image_name, const Tensor &mask, int height, - int width, Stream &stream) { - Device cpu_device{"cpu"}; - OUTCOME_TRY(auto host_mask, MakeAvailableOnDevice(mask, cpu_device, stream)); - OUTCOME_TRY(stream.Wait()); - // cv::Mat mask_image(height, width, CV_32SC1, host_mask.data()); - // cv::imwrite(image_name + ".png", mask_image * 10); - // ofstream ofs(image_name + ".data"); - // auto _data_ptr = host_mask.data(); - // for (auto i = 0; i < height; ++i) { - // for (auto j = 0; j < width; ++j) { - // ofs << *_data_ptr++ << ", "; - // } - // ofs << "\n"; - // } - return success(); -} - class ResizeMask : public MMSegmentation { public: explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) { - classes_ = cfg["params"]["num_classes"].get(); + try { + classes_ = cfg["params"]["num_classes"].get(); + } catch (const std::exception &e) { + ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg); + throw_exception(eInvalidArgument); + } } Result operator()(const Value &preprocess_result, const Value &inference_result) { DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result); auto mask = inference_result["output"].get(); - INFO("tensor.name: {}, tensor.shape: {}", mask.name(), mask.shape()); - assert(mask.data_type() == DataType::kINT32); + INFO("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), + mask.data_type()); + assert(mask.data_type() == DataType::kINT32 || mask.data_type() == DataType::kINT64); assert(mask.shape(0) == 1); assert(mask.shape(1) == 1); @@ -46,23 +34,41 @@ class ResizeMask : public MMSegmentation { auto width = (int)mask.shape(3); auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get(); auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get(); - if (height == input_height && width == input_width) { - SegmentorOutput output{mask, input_height, input_width, classes_}; - return to_value(output); + Device host{"cpu"}; + OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_)); + stream_.Wait().value(); + if (mask.data_type() == DataType::kINT64) { + // change kINT64 to 2 INT32 + TensorDesc desc{.device = host_tensor.device(), + .data_type = DataType::kINT32, + .shape = {1, 2, height, width}, + .name = host_tensor.name()}; + Tensor _host_tensor(desc, mask.buffer()); + return MaskResize(_host_tensor, input_height, input_width); } else { - Device host{"cpu"}; - - OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_)); - host_tensor.Reshape({1, height, width, 1}); - auto mat = cpu::Tensor2CVMat(host_tensor); - auto dst = cpu::Resize(mat, input_height, input_width, "nearest"); - auto output_tensor = cpu::CVMat2Tensor(dst); + return MaskResize(host_tensor, input_height, input_width); + } + } - SegmentorOutput output{output_tensor, input_height, input_width, classes_}; + private: + Result MaskResize(Tensor &tensor, int dst_height, int dst_width) { + auto channel = tensor.shape(1); + auto height = tensor.shape(2); + auto width = tensor.shape(3); - // OUTCOME_TRY( - // VisualizeMask("resize_mask", output_tensor, input_height, input_width, - // stream_)); + // reshape tensor to convert it to cv::Mat + tensor.Reshape({1, height, width, channel}); + auto mat = cpu::Tensor2CVMat(tensor); + auto dst = cpu::Resize(mat, dst_height, dst_width, "nearest"); + if (channel == 1) { + auto output_tensor = cpu::CVMat2Tensor(dst); + SegmentorOutput output{output_tensor, dst_height, dst_width, classes_}; + return to_value(output); + } else { + cv::Mat _dst; + cv::extractChannel(dst, _dst, 0); + auto output_tensor = cpu::CVMat2Tensor(_dst); + SegmentorOutput output{output_tensor, dst_height, dst_width, classes_}; return to_value(output); } } diff --git a/csrc/core/model.h b/csrc/core/model.h index 593b08b6dd..a9ce11eff3 100644 --- a/csrc/core/model.h +++ b/csrc/core/model.h @@ -27,8 +27,7 @@ struct model_meta_info_t { struct deploy_meta_info_t { std::string version; std::vector models; - std::vector customs; - MMDEPLOY_ARCHIVE_MEMBERS(version, models, customs); + MMDEPLOY_ARCHIVE_MEMBERS(version, models); }; class ModelImpl; diff --git a/csrc/model/zip_model_impl.cpp b/csrc/model/zip_model_impl.cpp index b6aa0df460..0f1479f64c 100644 --- a/csrc/model/zip_model_impl.cpp +++ b/csrc/model/zip_model_impl.cpp @@ -1,7 +1,5 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include - #include #include @@ -10,6 +8,13 @@ #include "core/model.h" #include "core/model_impl.h" #include "zip.h" +#if __GNUC__ >= 8 +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif using nlohmann::json; @@ -62,6 +67,7 @@ class ZipModelImpl : public ModelImpl { Result ReadFile(const std::string& file_path) const override { int ret = 0; int index = -1; + auto iter = file_index_.find(file_path); if (iter == file_index_.end()) { ERROR("cannot find file {} under dir {}", file_path.c_str(), root_dir_.c_str()); @@ -103,16 +109,20 @@ class ZipModelImpl : public ModelImpl { Result InitZip() { int files = zip_get_num_files(zip_); INFO("there are {} files in sdk model file", files); - + if (files == 0) { + return Status(eFail); + } for (int i = 0; i < files; ++i) { struct zip_stat stat; zip_stat_init(&stat); zip_stat_index(zip_, i, 0, &stat); - if (stat.name[strlen(stat.name) - 1] == '/') { + fs::path path(stat.name); + auto file_name = path.filename().string(); + if (file_name == ".") { DEBUG("{}-th file name is: {}, which is a directory", i, stat.name); } else { DEBUG("{}-th file name is: {}, which is a file", i, stat.name); - file_index_[stat.name] = i; + file_index_[file_name] = i; } } return success(); diff --git a/csrc/net/trt/trt_net.cpp b/csrc/net/trt/trt_net.cpp index 3e5f1f6e02..6f4cb940a1 100644 --- a/csrc/net/trt/trt_net.cpp +++ b/csrc/net/trt/trt_net.cpp @@ -212,7 +212,7 @@ Result TRTNet::ForwardAsync(Event* event) { return Status(eNotSupported); class TRTNetCreator : public Creator { public: - const char* GetName() const override { return "trt"; } + const char* GetName() const override { return "tensorrt"; } int GetVersion() const override { return 0; } std::unique_ptr Create(const Value& args) override { auto p = std::make_unique(); diff --git a/csrc/preprocess/transform/load.cpp b/csrc/preprocess/transform/load.cpp index 90b886bdd0..671948f2d8 100644 --- a/csrc/preprocess/transform/load.cpp +++ b/csrc/preprocess/transform/load.cpp @@ -38,7 +38,9 @@ Result PrepareImageImpl::Process(const Value& input) { Value output = input; Mat src_mat = input["ori_img"].get(); - auto res = (arg_.color_type == "color" ? ConvertToBGR(src_mat) : ConvertToGray(src_mat)); + auto res = (arg_.color_type == "color" || arg_.color_type == "color_ignore_orientation" + ? ConvertToBGR(src_mat) + : ConvertToGray(src_mat)); OUTCOME_TRY(auto tensor, std::move(res)); diff --git a/demo/csrc/config/resnet50_ort/pipeline.json b/demo/csrc/config/resnet50_ort/pipeline.json index db6979dbeb..5f8d67b897 100644 --- a/demo/csrc/config/resnet50_ort/pipeline.json +++ b/demo/csrc/config/resnet50_ort/pipeline.json @@ -80,8 +80,8 @@ { "name": "postprocess", "type": "Task", - "module": "MMClsPostprocess", - "postprocess_type": "SoftmaxPost", + "module": "mmcls", + "component": "LinearClsHead", "input": [ "data", "cls_res" diff --git a/demo/csrc/config/retinanet_ort/pipeline.json b/demo/csrc/config/retinanet_ort/pipeline.json index d689c70dd4..c0b1439f1a 100644 --- a/demo/csrc/config/retinanet_ort/pipeline.json +++ b/demo/csrc/config/retinanet_ort/pipeline.json @@ -81,10 +81,10 @@ } }, { - "name": "retinanet_post", + "name": "postprocess", "type": "Task", - "module": "MMDetPostprocess", - "postprocess_type": "SingleStagePost", + "module": "mmdet", + "component": "ResizeBBox", "input": [ "prep_res", "infer_res" @@ -92,7 +92,10 @@ "output": [ "bboxes" ], - "score_thr": 0.3 + "params": { + "score_thr": 0.3, + "min_bbox_size": 10 + } } ] } diff --git a/tests/test_csrc/CMakeLists.txt b/tests/test_csrc/CMakeLists.txt index 0ce8761f34..7ab56d54c8 100644 --- a/tests/test_csrc/CMakeLists.txt +++ b/tests/test_csrc/CMakeLists.txt @@ -4,72 +4,89 @@ project(tests) include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake) -# find TC source files and related shared libraries and modules that are going to be linked +# find TC source files and related shared libraries and modules that are going +# to be linked set(MMDEPLOY_LIBS) -set(MMDEPLOY_MODULES - mmdeploy::directory_model - mmdeploy::transform_module - mmdeploy::transform - mmdeploy::net_module - mmdeploy::graph) +set(MMDEPLOY_MODULES mmdeploy::directory_model mmdeploy::transform_module + mmdeploy::transform mmdeploy::net_module mmdeploy::graph) set(TC_SRCS test_main.cpp) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/archive ARCHIVE_TC) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/core CORE_TC) -aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/graph GRAPH_TC) +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/preprocess TRANSFORM_TC) set(DEVICE_TC) -foreach (DEVICE IN LISTS MMDEPLOY_TARGET_DEVICES) - list(APPEND DEVICE_TC ${CMAKE_CURRENT_SOURCE_DIR}/device/test_${DEVICE}_device.cpp) - list(APPEND MMDEPLOY_MODULES mmdeploy::device::${DEVICE} mmdeploy::transform_impl::${DEVICE}) -endforeach () +foreach(DEVICE IN LISTS MMDEPLOY_TARGET_DEVICES) + list(APPEND DEVICE_TC + ${CMAKE_CURRENT_SOURCE_DIR}/device/test_${DEVICE}_device.cpp) + list(APPEND MMDEPLOY_MODULES mmdeploy::device::${DEVICE} + mmdeploy::transform_impl::${DEVICE}) +endforeach() + +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/net NET_TC) +foreach(BACKEND IN LISTS MMDEPLOY_TARGET_BACKENDS) + list(APPEND MMDEPLOY_MODULES mmdeploy::${BACKEND}_net) +endforeach() + +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/model MODEL_TC) +if(MMDEPLOY_ZIP_MODEL) + message(STATUS "MMDEPLOY_ZIP_MODEL: ${MMDEPLOY_ZIP_MODEL}") + list(APPEND MMDEPLOY_MODULES mmdeploy::zip_model) +endif() -set(NET_TC) -foreach (BACKEND IN LISTS MMDEPLOY_TARGET_BACKENDS) - list(APPEND MMDEPLOY_MODULES mmdeploy::${BACKEND}_net) -endforeach () set(CAPI_TC) -if ("all" IN_LIST MMDEPLOY_CODEBASES) - set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;restorer;model") - set(CODEBASES "mmcls;mmdet;mmseg;mmedit;mmocr") -else () - set(TASK_LIST "model") - set(CODEBASES "${MMDEPLOY_CODEBASES}") - if ("mmcls" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND TASK_LIST "classifier") - endif () - if ("mmdet" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND TASK_LIST "detector") - endif () - if ("mmseg" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND TASK_LIST "segmentor") - endif () - if ("mmedit" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND TASK_LIST "restorer") - endif () - if ("mmocr" IN_LIST MMDEPLOY_CODEBASES) - list(APPEND TASK_LIST "text_detector") - list(APPEND TASK_LIST "text_recognizer") - endif () -endif () -foreach (TASK ${TASK_LIST}) - list(APPEND CAPI_TC ${CMAKE_CURRENT_SOURCE_DIR}/capi/test_${TASK}.cpp) - list(APPEND MMDEPLOY_LIBS mmdeploy_${TASK}) -endforeach () +if("all" IN_LIST MMDEPLOY_CODEBASES) + set(TASK_LIST + "classifier;detector;segmentor;text_detector;text_recognizer;restorer;model" + ) + set(CODEBASES "mmcls;mmdet;mmseg;mmedit;mmocr") +else() + set(TASK_LIST "model") + set(CODEBASES "${MMDEPLOY_CODEBASES}") + if("mmcls" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "classifier") + endif() + if("mmdet" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "detector") + endif() + if("mmseg" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "segmentor") + endif() + if("mmedit" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "restorer") + endif() + if("mmocr" IN_LIST MMDEPLOY_CODEBASES) + list(APPEND TASK_LIST "text_detector") + list(APPEND TASK_LIST "text_recognizer") + endif() +endif() +foreach(TASK ${TASK_LIST}) + list(APPEND CAPI_TC ${CMAKE_CURRENT_SOURCE_DIR}/capi/test_${TASK}.cpp) + list(APPEND MMDEPLOY_LIBS mmdeploy_${TASK}) +endforeach() -# TODO(lvhan): add model test +# generate the header file +configure_file(config/test_define.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/test_define.h) -set(TC_SRCS ${TC_SRCS} ${ARCHIVE_TC} ${CORE_TC} ${GRAPH_TC} ${DEVICE_TC} ${CAPI_TC}) +set(TC_SRCS + ${TC_SRCS} + ${ARCHIVE_TC} + ${CORE_TC} + ${DEVICE_TC} + ${CAPI_TC} + ${TRANSFORM_TC} + ${MODEL_TC} + ${NET_TC}) list(APPEND MMDEPLOY_LIBS mmdeploy::core) -foreach (CODEBASE IN LISTS CODEBASES) - list(APPEND MMDEPLOY_MODULES mmdeploy::${CODEBASE}) -endforeach () +foreach(CODEBASE IN LISTS CODEBASES) + list(APPEND MMDEPLOY_MODULES mmdeploy::${CODEBASE}) +endforeach() add_executable(mmdeploy_tests ${TC_SRCS}) -target_include_directories(mmdeploy_tests PRIVATE ${CMAKE_SOURCE_DIR}/third_party/catch2) -target_link_libraries(mmdeploy_tests PRIVATE - ${MMDEPLOY_LIBS} ${OpenCV_LIBS} - -Wl,--no-as-needed - ${MMDEPLOY_MODULES} - -Wl,--as-need - ) +target_include_directories(mmdeploy_tests + PRIVATE ${CMAKE_SOURCE_DIR}/third_party/catch2) +target_include_directories(mmdeploy_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries( + mmdeploy_tests PRIVATE ${MMDEPLOY_LIBS} ${OpenCV_LIBS} -Wl,--no-as-needed + ${MMDEPLOY_MODULES} -Wl,--as-need) diff --git a/tests/test_csrc/capi/test_classifier.cpp b/tests/test_csrc/capi/test_classifier.cpp index 8a2cc86225..5e0b0f4136 100644 --- a/tests/test_csrc/capi/test_classifier.cpp +++ b/tests/test_csrc/capi/test_classifier.cpp @@ -1,36 +1,63 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include - // clang-format off #include "catch.hpp" // clang-format on #include "apis/c/classifier.h" -#include "apis/c/model.h" #include "core/logger.h" #include "opencv2/opencv.hpp" +#include "test_resource.h" using namespace std; TEST_CASE("test classifier's c api", "[classifier]") { - mm_handle_t handle{nullptr}; - auto model_path = "../../config/classifier/resnet50_t4-cuda11.1-trt7.2-fp32"; - // auto ret = mmdeploy_classifier_create_by_path(model_path, "cuda", 0, &handle); - mm_model_t model{}; - auto ret = mmdeploy_model_create_by_path(model_path, &model); - REQUIRE(ret == MM_SUCCESS); - ret = mmdeploy_classifier_create(model, "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("../../tests/data/images/dogs.jpg"); - vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; - mm_class_t* results{nullptr}; - int* result_count{nullptr}; - ret = mmdeploy_classifier_apply(handle, mats.data(), (int)mats.size(), &results, &result_count); - REQUIRE(ret == MM_SUCCESS); - INFO("label: {}, score: {}", results->label_id, results->score); - - mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); - mmdeploy_classifier_destroy(handle); + auto test = [](const std::string& device_name, const std::string& model_path, + const std::vector& img_list) { + mm_handle_t handle{nullptr}; + auto ret = + mmdeploy_classifier_create_by_path(model_path.c_str(), device_name.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto& img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } + + mm_class_t* results{nullptr}; + int* result_count{nullptr}; + ret = mmdeploy_classifier_apply(handle, mats.data(), (int)mats.size(), &results, &result_count); + REQUIRE(ret == MM_SUCCESS); + auto result_ptr = results; + INFO("model_path: {}", model_path); + for (auto i = 0; i < (int)mats.size(); ++i) { + INFO("the {}-th classification result: ", i); + for (int j = 0; j < *result_count; ++j, ++result_ptr) { + INFO("\t label: {}, score: {}", result_ptr->label_id, result_ptr->score); + } + } + + mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); + mmdeploy_classifier_destroy(handle); + }; + + auto gResources = MMDeployTestResources::Get(); + auto img_lists = gResources.LocateImageResources("mmcls/images"); + REQUIRE(!img_lists.empty()); + + for (auto& backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmcls/" + backend); + REQUIRE(!model_list.empty()); + for (auto& model_path : model_list) { + for (auto& device_name : gResources.device_names(backend)) { + test(device_name, model_path, img_lists); + } + } + } + } } diff --git a/tests/test_csrc/capi/test_detector.cpp b/tests/test_csrc/capi/test_detector.cpp index 9ad7054c2e..f7a72e5410 100644 --- a/tests/test_csrc/capi/test_detector.cpp +++ b/tests/test_csrc/capi/test_detector.cpp @@ -4,37 +4,58 @@ #include "catch.hpp" // clang-format on -#include - #include "apis/c/detector.h" +#include "core/logger.h" #include "opencv2/opencv.hpp" +#include "test_resource.h" using namespace std; TEST_CASE("test detector's c api", "[detector]") { - mm_handle_t handle{nullptr}; - auto model_path = "../../config/detector/retinanet_t4-cuda11.1-trt7.2-fp32"; - auto ret = mmdeploy_detector_create_by_path(model_path, "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("../../tests/data/images/dogs.jpg"); - vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; - - mm_detect_t* results{nullptr}; - int* result_count{nullptr}; - ret = mmdeploy_detector_apply(handle, mats.data(), (int)mats.size(), &results, &result_count); - REQUIRE(ret == MM_SUCCESS); - auto result_ptr = results; - for (auto i = 0; i < mats.size(); ++i) { - cout << "the " << i << "-th image has '" << result_count[i] << "' objects" << endl; - for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) { - auto& bbox = result_ptr->bbox; - cout << " >> bbox[" << bbox.left << ", " << bbox.top << ", " << bbox.right << ", " - << bbox.bottom << "], label_id " << result_ptr->label_id << ", score " - << result_ptr->score << endl; + auto test = [](const string &device, const string &model_path, const vector &img_list) { + mm_handle_t handle{nullptr}; + auto ret = mmdeploy_detector_create_by_path(model_path.c_str(), device.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto &img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); } - } - mmdeploy_detector_release_result(results, result_count, (int)mats.size()); - mmdeploy_detector_destroy(handle); + mm_detect_t *results{nullptr}; + int *result_count{nullptr}; + ret = mmdeploy_detector_apply(handle, mats.data(), (int)mats.size(), &results, &result_count); + REQUIRE(ret == MM_SUCCESS); + auto result_ptr = results; + for (auto i = 0; i < mats.size(); ++i) { + INFO("the '{}-th' image has '{}' objects", i, result_count[i]); + for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) { + auto &bbox = result_ptr->bbox; + INFO(" >> bbox[{}, {}, {}, {}], label_id {}, score {}", bbox.left, bbox.top, bbox.right, + bbox.bottom, result_ptr->label_id, result_ptr->score); + } + } + mmdeploy_detector_release_result(results, result_count, (int)mats.size()); + mmdeploy_detector_destroy(handle); + }; + + auto gResources = MMDeployTestResources::Get(); + auto img_lists = gResources.LocateImageResources("mmdet/images"); + REQUIRE(!img_lists.empty()); + + for (auto &backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmdet/" + backend); + REQUIRE(!model_list.empty()); + for (auto &model_path : model_list) { + for (auto &device_name : gResources.device_names(backend)) { + test(device_name, model_path, img_lists); + } + } + } + } } diff --git a/tests/test_csrc/capi/test_model.cpp b/tests/test_csrc/capi/test_model.cpp index 046d8319c4..af0a983628 100644 --- a/tests/test_csrc/capi/test_model.cpp +++ b/tests/test_csrc/capi/test_model.cpp @@ -1 +1,31 @@ // Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "apis/c/model.h" +#include "test_resource.h" + +TEST_CASE("test model c capi", "[model]") { + auto &gResource = MMDeployTestResources::Get(); + std::string model_path; + for (auto const &codebase : gResource.codebases()) { + for (auto const &backend : gResource.backends()) { + if (auto _model_list = gResource.LocateModelResources(codebase + "/" + backend); + !_model_list.empty()) { + model_path = _model_list.front(); + break; + } + } + } + + REQUIRE(!model_path.empty()); + mm_model_t model{}; + REQUIRE(mmdeploy_model_create_by_path(model_path.c_str(), &model) == MM_SUCCESS); + mmdeploy_model_destroy(model); + model = nullptr; + + REQUIRE(mmdeploy_model_create(nullptr, 0, &model) == MM_E_FAIL); + mmdeploy_model_destroy(model); +} diff --git a/tests/test_csrc/capi/test_restorer.cpp b/tests/test_csrc/capi/test_restorer.cpp index c5a322c0d3..502d377021 100644 --- a/tests/test_csrc/capi/test_restorer.cpp +++ b/tests/test_csrc/capi/test_restorer.cpp @@ -4,27 +4,54 @@ #include "catch.hpp" // clang-format on -#include - #include "apis/c/restorer.h" #include "opencv2/opencv.hpp" +#include "test_resource.h" -TEST_CASE("test restorer's c api", "[restorer]") { - mm_handle_t handle{nullptr}; - auto ret = mmdeploy_restorer_create_by_path("../../config/restorer/esrgan", "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("../../tests/data/image/demo_text_det.jpg"); - std::vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; +using namespace std; - mm_mat_t* res{}; - ret = mmdeploy_restorer_apply(handle, mats.data(), (int)mats.size(), &res); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat out(res->height, res->width, CV_8UC3, res->data); - cv::cvtColor(out, out, cv::COLOR_RGB2BGR); - cv::imwrite("test_restorer.bmp", out); - - mmdeploy_restorer_release_result(res, (int)mats.size()); - mmdeploy_restorer_destroy(handle); +TEST_CASE("test restorer's c api", "[restorer]") { + auto test = [](const string &device, const string &backend, const string &model_path, + const vector &img_list) { + mm_handle_t handle{nullptr}; + auto ret = mmdeploy_restorer_create_by_path(model_path.c_str(), device.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto &img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } + mm_mat_t *res{}; + ret = mmdeploy_restorer_apply(handle, mats.data(), (int)mats.size(), &res); + REQUIRE(ret == MM_SUCCESS); + + for (auto i = 0; i < cv_mats.size(); ++i) { + cv::Mat out(res[i].height, res[i].width, CV_8UC3, res[i].data); + cv::cvtColor(out, out, cv::COLOR_RGB2BGR); + cv::imwrite("restorer_" + backend + "_" + to_string(i) + ".bmp", out); + } + + mmdeploy_restorer_release_result(res, (int)mats.size()); + mmdeploy_restorer_destroy(handle); + }; + + auto gResources = MMDeployTestResources::Get(); + auto img_lists = gResources.LocateImageResources("mmedit/images"); + REQUIRE(!img_lists.empty()); + + for (auto &backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmedit/" + backend); + REQUIRE(!model_list.empty()); + for (auto &model_path : model_list) { + for (auto &device_name : gResources.device_names(backend)) { + test(device_name, backend, model_path, img_lists); + } + } + } + } } diff --git a/tests/test_csrc/capi/test_segmentor.cpp b/tests/test_csrc/capi/test_segmentor.cpp index b26f0bafda..b042d793c5 100644 --- a/tests/test_csrc/capi/test_segmentor.cpp +++ b/tests/test_csrc/capi/test_segmentor.cpp @@ -1,34 +1,60 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include +// clang-format off +#include "catch.hpp" +// clang-format on #include "apis/c/segmentor.h" -#include "catch.hpp" #include "opencv2/opencv.hpp" +#include "test_resource.h" using namespace std; TEST_CASE("test segmentor's c api", "[segmentor]") { - mm_handle_t handle{nullptr}; - const auto model_path = "../../config/segmentor/fcn_t4-cuda11.1-trt7.2-fp16"; - auto ret = mmdeploy_segmentor_create_by_path(model_path, "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("../../tests/data/images/dogs.jpg"); - vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; - - mm_segment_t* results{nullptr}; - int count = 0; - ret = mmdeploy_segmentor_apply(handle, mats.data(), (int)mats.size(), &results); - REQUIRE(ret == MM_SUCCESS); - REQUIRE(results != nullptr); - - auto result_ptr = results; - for (auto i = 0; i < mats.size(); ++i) { - cv::Mat mask(result_ptr->height, result_ptr->width, CV_32SC1, result_ptr->mask); - cv::imwrite("mask.png", mask * 10); + auto test = [](const string &device, const string &backend, const string &model_path, + const vector &img_list) { + mm_handle_t handle{nullptr}; + auto ret = mmdeploy_segmentor_create_by_path(model_path.c_str(), device.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto &img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } + + mm_segment_t *results{nullptr}; + int count = 0; + ret = mmdeploy_segmentor_apply(handle, mats.data(), (int)mats.size(), &results); + REQUIRE(ret == MM_SUCCESS); + REQUIRE(results != nullptr); + + auto result_ptr = results; + for (auto i = 0; i < mats.size(); ++i, ++result_ptr) { + cv::Mat mask(result_ptr->height, result_ptr->width, CV_32SC1, result_ptr->mask); + cv::imwrite("mask_" + backend + "_" + to_string(i) + ".png", mask * 10); + } + + mmdeploy_segmentor_release_result(results, (int)mats.size()); + mmdeploy_segmentor_destroy(handle); + }; + + auto gResources = MMDeployTestResources::Get(); + auto img_lists = gResources.LocateImageResources("mmseg/images"); + REQUIRE(!img_lists.empty()); + + for (auto &backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmseg/" + backend); + REQUIRE(!model_list.empty()); + for (auto &model_path : model_list) { + for (auto &device_name : gResources.device_names(backend)) { + test(device_name, backend, model_path, img_lists); + } + } + } } - - mmdeploy_segmentor_release_result(results, (int)mats.size()); - mmdeploy_segmentor_destroy(handle); } diff --git a/tests/test_csrc/capi/test_text_detector.cpp b/tests/test_csrc/capi/test_text_detector.cpp index c7925a2136..a2bdd84493 100644 --- a/tests/test_csrc/capi/test_text_detector.cpp +++ b/tests/test_csrc/capi/test_text_detector.cpp @@ -1,42 +1,66 @@ // Copyright (c) OpenMMLab. All rights reserved. - -#include -#include +// clang-format off +#include "catch.hpp" +// clang-format on #include "apis/c/text_detector.h" -#include "catch.hpp" +#include "core/logger.h" #include "opencv2/opencv.hpp" +#include "test_resource.h" using namespace std; TEST_CASE("test text detector's c api", "[text-detector]") { - mm_handle_t handle{nullptr}; - auto model_path = "../../config/text-detector/dbnet18_t4-cuda11.1-trt7.2-fp16"; - auto ret = mmdeploy_text_detector_create_by_path(model_path, "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("../../tests/data/images/ocr.jpg"); - vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; - - mm_text_detect_t* results{nullptr}; - int* result_count{nullptr}; - ret = - mmdeploy_text_detector_apply(handle, mats.data(), (int)mats.size(), &results, &result_count); - REQUIRE(ret == MM_SUCCESS); - auto result_ptr = results; - for (auto i = 0; i < mats.size(); ++i) { - cout << "the " << i << "-th image has '" << result_count[i] << "' objects" << endl; - for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) { - auto& bbox = result_ptr->bbox; - cout << ">> bbox[" << j << "].score: " << result_ptr->score << ", coordinate: "; - for (auto k = 0; k < 4; ++k) { - auto& bbox = result_ptr->bbox[k]; - cout << "(" << bbox.x << ", " << bbox.y << "), "; + auto test = [](const string& device, const string& model_path, const vector& img_list) { + mm_handle_t handle{nullptr}; + auto ret = + mmdeploy_text_detector_create_by_path(model_path.c_str(), device.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto& img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } + + mm_text_detect_t* results{nullptr}; + int* result_count{nullptr}; + ret = mmdeploy_text_detector_apply(handle, mats.data(), (int)mats.size(), &results, + &result_count); + REQUIRE(ret == MM_SUCCESS); + + auto result_ptr = results; + for (auto i = 0; i < mats.size(); ++i) { + INFO("the {}-th image has '{}' objects", i, result_count[i]); + for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) { + auto& bbox = result_ptr->bbox; + INFO(">> bbox[{}].score: {}, coordinate: ", i, result_ptr->score); + for (auto& _bbox : result_ptr->bbox) { + INFO(">> >> ({}, {})", _bbox.x, _bbox.y); + } } - cout << endl; } - } - mmdeploy_text_detector_release_result(results, result_count, (int)mats.size()); - mmdeploy_text_detector_destroy(handle); + mmdeploy_text_detector_release_result(results, result_count, (int)mats.size()); + mmdeploy_text_detector_destroy(handle); + }; + + auto& gResources = MMDeployTestResources::Get(); + auto img_list = gResources.LocateImageResources("mmocr/images"); + REQUIRE(!img_list.empty()); + + for (auto& backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmocr/textdet/" + backend); + REQUIRE(!model_list.empty()); + for (auto& model_path : model_list) { + for (auto& device_name : gResources.device_names(backend)) { + test(device_name, model_path, img_list); + } + } + } + } } diff --git a/tests/test_csrc/capi/test_text_recognizer.cpp b/tests/test_csrc/capi/test_text_recognizer.cpp index a941573be9..94f01063dc 100644 --- a/tests/test_csrc/capi/test_text_recognizer.cpp +++ b/tests/test_csrc/capi/test_text_recognizer.cpp @@ -4,101 +4,122 @@ #include "catch.hpp" // clang-format on -#include -#include - #include "apis/c/text_recognizer.h" #include "core/logger.h" #include "core/utils/formatter.h" #include "opencv2/opencv.hpp" +#include "test_resource.h" using namespace std; -static std::string ReadFileContent(const char* path) { - std::ifstream ifs(path, std::ios::binary); - ifs.seekg(0, std::ios::end); - auto size = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - std::string bin(size, 0); - ifs.read((char*)bin.data(), size); - return bin; -} - TEST_CASE("test text recognizer's c api", "[text-recognizer]") { - const auto model_path = "../../config/text-recognizer/crnn"; - - mm_handle_t handle{nullptr}; - auto ret = mmdeploy_text_recognizer_create_by_path(model_path, "cuda", 0, &handle); - REQUIRE(ret == MM_SUCCESS); - - cv::Mat mat = cv::imread("/data/verify/mmsdk/18.png"); - vector mats{{mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}}; - mats.push_back(mats.back()); - mats.push_back(mats.back()); - mats.push_back(mats.back()); - - mm_text_recognize_t* results{}; - ret = mmdeploy_text_recognizer_apply_bbox(handle, mats.data(), (int)mats.size(), nullptr, nullptr, - &results); - REQUIRE(ret == MM_SUCCESS); - - for (auto i = 0; i < mats.size(); ++i) { - std::vector score(results[i].score, results[i].score + results[i].length); - INFO("image {}, text = {}, score = {}", i, results[i].text, score); - } + auto test = [](const string& device, const string& model_path, const vector& img_list) { + mm_handle_t handle{nullptr}; + auto ret = + mmdeploy_text_recognizer_create_by_path(model_path.c_str(), device.c_str(), 0, &handle); + REQUIRE(ret == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (auto& img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } - mmdeploy_text_recognizer_release_result(results, (int)mats.size()); - mmdeploy_text_recognizer_destroy(handle); + mm_text_recognize_t* results{}; + ret = mmdeploy_text_recognizer_apply_bbox(handle, mats.data(), (int)mats.size(), nullptr, + nullptr, &results); + REQUIRE(ret == MM_SUCCESS); + + for (auto i = 0; i < mats.size(); ++i) { + std::vector score(results[i].score, results[i].score + results[i].length); + INFO("image {}, text = {}, score = {}", i, results[i].text, score); + } + + mmdeploy_text_recognizer_release_result(results, (int)mats.size()); + mmdeploy_text_recognizer_destroy(handle); + }; + + auto& gResources = MMDeployTestResources::Get(); + auto img_list = gResources.LocateImageResources("mmocr/images"); + REQUIRE(!img_list.empty()); + + for (auto& backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto model_list = gResources.LocateModelResources("mmocr/textreg/" + backend); + REQUIRE(!model_list.empty()); + for (auto& model_path : model_list) { + for (auto& device_name : gResources.device_names(backend)) { + test(device_name, model_path, img_list); + } + } + } + } } TEST_CASE("test text detector-recognizer combo", "[text-detector-recognizer]") { - const auto det_model_path = "../../config/text-detector/dbnet18_t4-cuda11.1-trt7.2-fp16"; - mm_handle_t detector{}; - REQUIRE(mmdeploy_text_detector_create_by_path(det_model_path, "cuda", 0, &detector) == - MM_SUCCESS); - - mm_handle_t recognizer{}; - const auto reg_model_path = "../../config/text-recognizer/crnn"; - REQUIRE(mmdeploy_text_recognizer_create_by_path(reg_model_path, "cuda", 0, &recognizer) == - MM_SUCCESS); - - const char* file_list[] = { - "../../tests/data/image/demo_kie.jpeg", "../../tests/data/image/demo_text_det.jpg", - "../../tests/data/image/demo_text_ocr.jpg", "../../tests/data/image/demo_text_recog.jpg"}; - - vector cv_mats; - vector mats; - for (const auto filename : file_list) { - cv::Mat mat = cv::imread(filename); - cv_mats.push_back(mat); - mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); - } + auto test = [](const std::string& device, const string& det_model_path, + const string& reg_model_path, std::vector& img_list) { + mm_handle_t detector{}; + REQUIRE(mmdeploy_text_detector_create_by_path(det_model_path.c_str(), device.c_str(), 0, + &detector) == MM_SUCCESS); + mm_handle_t recognizer{}; + REQUIRE(mmdeploy_text_recognizer_create_by_path(reg_model_path.c_str(), device.c_str(), 0, + &recognizer) == MM_SUCCESS); + + vector cv_mats; + vector mats; + for (const auto& img_path : img_list) { + cv::Mat mat = cv::imread(img_path); + REQUIRE(!mat.empty()); + cv_mats.push_back(mat); + mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MM_BGR, MM_INT8}); + } - mm_text_detect_t* bboxes{}; - int* bbox_count{}; - REQUIRE(mmdeploy_text_detector_apply(detector, mats.data(), mats.size(), &bboxes, &bbox_count) == - MM_SUCCESS); + mm_text_detect_t* bboxes{}; + int* bbox_count{}; + REQUIRE(mmdeploy_text_detector_apply(detector, mats.data(), mats.size(), &bboxes, + &bbox_count) == MM_SUCCESS); - mm_text_recognize_t* texts{}; + mm_text_recognize_t* texts{}; - REQUIRE(mmdeploy_text_recognizer_apply_bbox(recognizer, mats.data(), (int)mats.size(), bboxes, - bbox_count, &texts) == MM_SUCCESS); + REQUIRE(mmdeploy_text_recognizer_apply_bbox(recognizer, mats.data(), (int)mats.size(), bboxes, + bbox_count, &texts) == MM_SUCCESS); - int offset = 0; - for (auto i = 0; i < mats.size(); ++i) { - for (int j = 0; j < bbox_count[i]; ++j) { - auto& text = texts[offset + j]; - std::vector score(text.score, text.score + text.length); - INFO("image {}, text = {}, score = {}", i, text.text, score); + int offset = 0; + for (auto i = 0; i < mats.size(); ++i) { + for (int j = 0; j < bbox_count[i]; ++j) { + auto& text = texts[offset + j]; + std::vector score(text.score, text.score + text.length); + INFO("image {}, text = {}, score = {}", i, text.text, score); + } + offset += bbox_count[i]; } - offset += bbox_count[i]; - } - - mmdeploy_text_recognizer_release_result(texts, offset); - mmdeploy_text_detector_release_result(bboxes, bbox_count, offset); - - mmdeploy_text_recognizer_destroy(recognizer); - - mmdeploy_text_detector_destroy(detector); + mmdeploy_text_recognizer_release_result(texts, offset); + mmdeploy_text_detector_release_result(bboxes, bbox_count, offset); + + mmdeploy_text_recognizer_destroy(recognizer); + mmdeploy_text_detector_destroy(detector); + }; + + auto& gResources = MMDeployTestResources::Get(); + auto img_list = gResources.LocateImageResources("mmocr/images"); + REQUIRE(!img_list.empty()); + + for (auto& backend : gResources.backends()) { + DYNAMIC_SECTION("loop backend: " << backend) { + auto det_model_list = gResources.LocateModelResources("/mmocr/textdet/" + backend); + auto reg_model_list = gResources.LocateModelResources("/mmocr/textreg/" + backend); + REQUIRE(!det_model_list.empty()); + REQUIRE(!reg_model_list.empty()); + auto det_model_path = det_model_list.front(); + auto reg_model_path = reg_model_list.front(); + for (auto& device_name : gResources.device_names(backend)) { + test(device_name, det_model_path, reg_model_path, img_list); + } + } + } } diff --git a/tests/test_csrc/config/test_define.h.in b/tests/test_csrc/config/test_define.h.in new file mode 100644 index 0000000000..1d22b75b49 --- /dev/null +++ b/tests/test_csrc/config/test_define.h.in @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#ifndef MMDEPLOY_TEST_DEFINE_H +#define MMDEPLOY_TEST_DEFINE_H + +static constexpr const char *kBackends = "@MMDEPLOY_TARGET_BACKENDS@"; +static constexpr const char *kDevices = "@MMDEPLOY_TARGET_DEVICES@"; +static constexpr const char *kCodebases = "@CODEBASES@"; + +#endif // MMDEPLOY_TEST_DEFINE_H diff --git a/tests/test_csrc/core/test_mat.cpp b/tests/test_csrc/core/test_mat.cpp index 84748f5a29..b1ae27cb35 100644 --- a/tests/test_csrc/core/test_mat.cpp +++ b/tests/test_csrc/core/test_mat.cpp @@ -5,54 +5,15 @@ #include "catch.hpp" #include "core/logger.h" #include "core/mat.h" +#include "test_resource.h" + using namespace mmdeploy; using namespace std; -// ostream& operator << (ostream& stream, PixelFormat format) { -// switch (format) { -// case PixelFormat::kGRAYSCALE: -// stream << "gray_scale"; -// break; -// case PixelFormat::kNV12: -// stream << "nv12"; break; -// case PixelFormat::kNV21: -// stream << "nv21"; break; -// case PixelFormat::kBGR: -// stream << "bgr"; break; -// case PixelFormat::kRGB: -// stream << "rgb"; -// break; -// case PixelFormat::kBGRA: -// stream << "bgra"; -// break; -// default: -// stream << "unknown_pixel_format"; -// break; -// } -// return stream; -// } -// ostream& operator << (ostream& stream, DataType type) { -// switch (type) { -// case DataType::kFLOAT: -// stream << "float"; -// break; -// case DataType::kHALF: -// stream << "half"; -// break; -// case DataType::kINT32: -// stream << "int"; -// break; -// case DataType::kINT8: -// stream << "int8"; -// break; -// default: -// stream << "unknown_data_type"; -// break; -// } -// return stream; -// } - TEST_CASE("default mat constructor", "[mat]") { + auto gResource = MMDeployTestResources::Get(); + const Device kHost{"cpu"}; + SECTION("default constructor") { Mat mat; REQUIRE(mat.pixel_format() == PixelFormat::kGRAYSCALE); @@ -72,18 +33,21 @@ TEST_CASE("default mat constructor", "[mat]") { PixelFormat::kNV21, PixelFormat::kBGRA}; std::array data_types{DataType::kFLOAT, DataType::kHALF, DataType::kINT8, DataType::kINT32}; + int success = 0; for (auto format : pixel_formats) { for (auto data_type : data_types) { - Mat mat{100, 200, format, data_type, Device{"cpu"}}; + Mat mat{100, 200, format, data_type, kHost}; success += (mat.byte_size() > 0); } } REQUIRE(success == pixel_formats.size() * data_types.size()); - Mat mat(100, 200, pixel_formats[0], data_types[0], Device{}); - REQUIRE_THROWS(Mat{100, 200, PixelFormat(0xff), DataType::kINT8, Device{"cpu"}}); - REQUIRE_THROWS(Mat{100, 200, PixelFormat::kGRAYSCALE, DataType(0xff), Device{"cpu"}}); + for (auto &device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + REQUIRE_THROWS(Mat{100, 200, PixelFormat(0xff), DataType::kINT8, device}); + REQUIRE_THROWS(Mat{100, 200, PixelFormat::kGRAYSCALE, DataType(0xff), device}); + } } SECTION("construct with data") { @@ -91,65 +55,47 @@ TEST_CASE("default mat constructor", "[mat]") { constexpr int kCols = 200; vector data(kRows * kCols, 0); SECTION("void* data") { - Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data.data(), Device{"cpu"}}; + Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data.data(), kHost}; REQUIRE(mat.byte_size() > 0); } + SECTION("shared_ptr") { - std::shared_ptr data_ptr(data.data(), [&](void* p) {}); - Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data_ptr, Device{"cpu"}}; + std::shared_ptr data_ptr(data.data(), [&](void *p) {}); + Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data_ptr, kHost}; REQUIRE(mat.byte_size() > 0); } } } TEST_CASE("mat constructor in difference devices", "[mat]") { + auto gResource = MMDeployTestResources::Get(); + constexpr int kRows = 10; constexpr int kCols = 10; constexpr int kSize = kRows * kCols; - SECTION("host") { - vector data(kSize); - std::iota(data.begin(), data.end(), 1); + vector data(kSize); + std::iota(data.begin(), data.end(), 1); + + for (auto &device_name : gResource.device_names()) { + Device device{device_name.c_str()}; - Device host{"cpu"}; - Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, host}; - Stream stream = Stream::GetDefault(host); + // copy to device + Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, device}; + Stream stream = Stream::GetDefault(device); REQUIRE(stream.Copy(data.data(), mat.buffer(), mat.buffer().GetSize())); REQUIRE(stream.Wait()); - auto data_ptr = mat.data(); + // copy to host + vector host_data(mat.size()); + REQUIRE(stream.Copy(mat.buffer(), host_data.data(), mat.byte_size())); + REQUIRE(stream.Wait()); + + // compare data to check if they are the same int count = 0; - for (size_t i = 0; i < mat.size(); ++i) { - count += (data_ptr[i] == data[i]); + for (size_t i = 0; i < host_data.size(); ++i) { + count += (host_data[i] == data[i]); } REQUIRE(count == mat.size()); } - - SECTION("cuda") { - try { - vector data(kSize); - std::iota(data.begin(), data.end(), 1); - - Device cuda{"cuda"}; - Mat mat(kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kFLOAT, cuda); - REQUIRE(mat.byte_size() == kSize * sizeof(float)); - - Stream stream = Stream::GetDefault(cuda); - REQUIRE(stream.Copy(data.data(), mat.buffer(), mat.byte_size())); - - vector host_data(mat.size()); - REQUIRE(stream.Copy(mat.buffer(), host_data.data(), mat.byte_size())); - - REQUIRE(stream.Wait()); - - int count = 0; - REQUIRE(mat.data() != nullptr); - for (size_t i = 0; i < host_data.size(); ++i) { - count += (host_data[i] == data[i]); - } - REQUIRE(count == mat.size()); - } catch (const Exception& e) { - ERROR("exception happened: {}", e.what()); - } - } } diff --git a/tests/test_csrc/core/test_value.cpp b/tests/test_csrc/core/test_value.cpp index 0e26ad47c7..07bfe6d7ff 100644 --- a/tests/test_csrc/core/test_value.cpp +++ b/tests/test_csrc/core/test_value.cpp @@ -336,6 +336,6 @@ TEST_CASE("test speed of value", "[value]") { } TEST_CASE("test ctor of value", "[value]") { - static_assert(!std::is_constructible::value, ""); + static_assert(!std::is_constructible::value, ""); static_assert(!std::is_constructible::value, ""); } diff --git a/tests/test_csrc/graph/load_image.cpp b/tests/test_csrc/graph/load_image.cpp deleted file mode 100644 index a2e972f8a9..0000000000 --- a/tests/test_csrc/graph/load_image.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "core/mat.h" -#include "core/module.h" -#include "core/registry.h" -#include "opencv2/imgcodecs.hpp" -#include "preprocess/cpu/opencv_utils.h" - -namespace test { - -using namespace mmdeploy; - -class LoadImageModule : public mmdeploy::Module { - public: - Result Process(const Value& args) override { - auto filename = args[0]["filename"].get(); - cv::Mat img = cv::imread(filename); - if (!img.data) { - ERROR("Failed to load image: {}", filename); - return Status(eInvalidArgument); - } - auto mat = mmdeploy::cpu::CVMat2Mat(img, PixelFormat::kBGR); - return Value{{{"ori_img", mat}}}; - } -}; - -class LoadImageModuleCreator : public Creator { - public: - const char* GetName() const override { return "LoadImage"; } - int GetVersion() const override { return 0; } - std::unique_ptr Create(const Value& value) override { - return std::make_unique(); - } -}; - -REGISTER_MODULE(Module, LoadImageModuleCreator); - -} // namespace test diff --git a/tests/test_csrc/graph/test_crnn.cpp b/tests/test_csrc/graph/test_crnn.cpp index 362d4e5072..e69de29bb2 100644 --- a/tests/test_csrc/graph/test_crnn.cpp +++ b/tests/test_csrc/graph/test_crnn.cpp @@ -1,70 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -// clang-format off -#include "catch.hpp" -// clang-format on - -#include -#include -#include - -#include "archive/json_archive.h" -#include "core/graph.h" -#include "core/mat.h" -#include "core/registry.h" - -const auto json_str = R"({ - "pipeline": { - "tasks": [ - { - "name": "load", - "type": "Task", - "module": "LoadImage", - "input": ["input"], - "output": ["img"] - }, - { - "name": "cls", - "type": "Inference", - "params": { - "model": "../../config/text-recognizer/crnn", - "batch_size": 1 - }, - "input": ["img"], - "output": ["text"] - } - ], - "input": ["input"], - "output": ["img", "text"] - } -} -)"; - -TEST_CASE("test crnn", "[crnn]") { - using namespace mmdeploy; - - auto json = nlohmann::json::parse(json_str); - auto value = mmdeploy::from_json(json); - - value["context"]["device"] = Device("cuda"); - value["context"]["stream"] = Stream::GetDefault(Device(0)); - auto pipeline = Registry::Get().GetCreator("Pipeline")->Create(value); - REQUIRE(pipeline); - - graph::TaskGraph graph; - pipeline->Build(graph); - - const auto img_list = "../crnn/imglist.txt"; - - Device device{"cpu"}; - auto stream = Stream::GetDefault(device); - - std::ifstream ifs(img_list); - - std::string path; - for (int image_id = 0; ifs >> path; ++image_id) { - auto output = graph.Run({{{"filename", path}}}); - REQUIRE(output); - INFO("output: {}", output.value()); - } -} diff --git a/tests/test_csrc/graph/test_dbnet18.cpp b/tests/test_csrc/graph/test_dbnet18.cpp deleted file mode 100644 index 2df31b06cc..0000000000 --- a/tests/test_csrc/graph/test_dbnet18.cpp +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -// clang-format off -#include "catch.hpp" -// clang-format on - -#include -#include - -#include "archive/json_archive.h" -#include "core/graph.h" -#include "core/mat.h" -#include "core/registry.h" - -const auto json_str = R"({ - "pipeline": { - "tasks": [ - { - "name": "load", - "type": "Task", - "module": "LoadImage", - "input": ["input"], - "output": ["img"] - }, - { - "name": "textdet", - "type": "Inference", - "params": { - "model": "../../config/text-detector/dbnet18_t4-cuda11.1-trt7.2-fp16" - }, - "input": ["img"], - "output": ["det"] - }, - { - "name": "warp", - "type": "Task", - "module": "WarpBoxes", - "input": ["img", "det"], - "output": ["warp"] - } - ], - "input": ["input"], - "output": ["img", "det", "warp"] - } -} -)"; - -TEST_CASE("test dbnet18", "[dbnet18]") { - using namespace mmdeploy; - auto json = nlohmann::json::parse(json_str); - auto value = mmdeploy::from_json(json); - - Device device{"cuda"}; - auto stream = Stream::GetDefault(device); - value["context"]["device"] = device; - value["context"]["stream"] = stream; - - auto pipeline = Registry::Get().GetCreator("Pipeline")->Create(value); - REQUIRE(pipeline); - - graph::TaskGraph graph; - pipeline->Build(graph); - - const auto img_list = "../../dbnet18/imglist.txt"; - - std::ifstream ifs(img_list); - - std::string path; - for (int image_id = 0; ifs >> path; ++image_id) { - auto output = graph.Run({{{"filename", path}, {"image_id", image_id}}}); - REQUIRE(output); - INFO("output: {}", output.value()); - } -} diff --git a/tests/test_csrc/graph/test_imagenet.cpp b/tests/test_csrc/graph/test_imagenet.cpp deleted file mode 100644 index a82f543067..0000000000 --- a/tests/test_csrc/graph/test_imagenet.cpp +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -// clang-format off -#include "catch.hpp" -// clang-format on - -#include -#include -#include - -#include "archive/json_archive.h" -#include "core/graph.h" -#include "core/mat.h" -#include "core/operator.h" -#include "core/registry.h" -#include "core/tensor.h" - -const auto json_str = R"({ - "pipeline": { - "input": ["input", "id"], - "output": ["output"], - "tasks": [ - { - "name": "load", - "type": "Task", - "module": "LoadImage", - "input": ["input"], - "output": ["img"], - "is_thread_safe": true - }, - { - "name": "cls", - "type": "Inference", - "params": { - "model": "../../resnet50", - "batch_size": 1 - }, - "input": ["img"], - "output": ["prob"] - }, - { - "name": "accuracy", - "type": "Task", - "module": "Accuracy", - "input": ["prob", "id"], - "output": ["output"], - "gt": "/data/imagenet_val_gt.txt" - } - ] - } -} -)"; - -namespace test { - -using namespace mmdeploy; - -class AccuracyModule : public mmdeploy::Module { - public: - explicit AccuracyModule(const Value& config) { - stream_ = config["context"]["stream"].get(); - auto path = config["gt"].get(); - std::ifstream ifs(path); - if (!ifs.is_open()) { - throw_exception(eFileNotExist); - } - std::string _; - for (int clsid = -1; ifs >> _ >> clsid;) { - label_.push_back(clsid); - } - } - Result Process(const Value& input) override { - // WARN("{}", to_json(input).dump(2)); - std::vector probs(1000); - auto tensor = input[0]["probs"].get(); - auto image_id = input[1].get(); - // auto stream = Stream::GetDefault(tensor.desc().device); - OUTCOME_TRY(tensor.CopyTo(probs.data(), stream_)); - OUTCOME_TRY(stream_.Wait()); - std::vector idx(probs.size()); - iota(begin(idx), end(idx), 0); - partial_sort(begin(idx), begin(idx) + 5, end(idx), - [&](int i, int j) { return probs[i] > probs[j]; }); - // ERROR("top-1: {}", idx[0]); - auto gt = label_[image_id]; - if (idx[0] == gt) { - ++top1_; - } - if (std::find(begin(idx), begin(idx) + 5, gt) != begin(idx) + 5) { - ++top5_; - } - ++cnt_; - auto fcnt = static_cast(cnt_); - if ((image_id + 1) % 1000 == 0) { - ERROR("index: {}, top1: {}, top5: {}", image_id, top1_ / fcnt, top5_ / fcnt); - } - return Value{ValueType::kObject}; - } - - private: - int cnt_{0}; - int top1_{0}; - int top5_{0}; - Stream stream_; - std::vector label_; -}; - -class AccuracyModuleCreator : public Creator { - public: - const char* GetName() const override { return "Accuracy"; } - int GetVersion() const override { return 0; } - std::unique_ptr Create(const Value& value) override { - return std::make_unique(value); - } -}; - -REGISTER_MODULE(Module, AccuracyModuleCreator); - -} // namespace test - -TEST_CASE("test mmcls imagenet", "[imagenet]") { - - using namespace mmdeploy; - auto json = nlohmann::json::parse(json_str); - auto value = mmdeploy::from_json(json); - - // Device device{"cuda", 0}; - Device device("cpu"); - auto stream = Stream::GetDefault(device); - value["context"]["device"] = device; - value["context"]["stream"] = stream; - - auto pipeline = Registry::Get().GetCreator("Pipeline")->Create(value); - REQUIRE(pipeline); - - graph::TaskGraph graph; - pipeline->Build(graph); - - // const auto img_list = "../tests/data/config/imagenet.list"; - const auto img_list = "/data/imagenet_val.txt"; - - std::ifstream ifs(img_list); - REQUIRE(ifs.is_open()); - - int image_id = 0; - const auto batch_size = 64; - bool done{}; - while (!done) { - // if (image_id > 5000) break; - Value batch = Value::kArray; - for (int i = 0; i < batch_size; ++i) { - std::string path; - if (ifs >> path) { - batch.push_back({{{"filename", path}}, image_id++}); - } else { - done = true; - break; - } - } - if (!batch.empty()) { - batch = graph::DistribAA(batch).value(); - graph.Run(batch).value(); - } - break; - } -} diff --git a/tests/test_csrc/graph/test_ocr.cpp b/tests/test_csrc/graph/test_ocr.cpp deleted file mode 100644 index 2f078ed7c8..0000000000 --- a/tests/test_csrc/graph/test_ocr.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -// clang-format off -#include "catch.hpp" -// clang-format on - -#include -#include - -#include "archive/json_archive.h" -#include "core/graph.h" -#include "core/mat.h" -#include "core/registry.h" - -using namespace mmdeploy; - -class DrawOCR : public Module { - public: - explicit DrawOCR(const Value& config) {} - - Result Process(const Value& input) override { return Value{ValueType::kNull}; } - - private: -}; - -class DrawOCRCreator : public mmdeploy::Creator { - public: - const char* GetName() const override { return "DrawOCR"; } - int GetVersion() const override { return 0; } - std::unique_ptr Create(const Value& value) override { - return std::make_unique(value); - } -}; - -REGISTER_MODULE(Module, DrawOCRCreator); - -TEST_CASE("test ocr det & recog", "[ocr_det_recog]") { - using namespace mmdeploy; - - std::string json_str; - { - std::ifstream ifs("../../tests/data/config/ocr_det_recog.json"); - REQUIRE(ifs.is_open()); - json_str = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator()); - } - - auto json = nlohmann::json::parse(json_str); - auto value = mmdeploy::from_json(json); - - Device device{"cuda", 0}; - auto stream = Stream::GetDefault(device); - - value["context"].update({{"device", device}, {"stream", stream}}); - - auto pipeline = Registry::Get().GetCreator("Pipeline")->Create(value); - REQUIRE(pipeline); - - graph::TaskGraph graph; - pipeline->Build(graph); - - const auto img_list = "../../tests/data/config/ocr_det_recog.list"; - - std::vector files; - { - std::ifstream ifs(img_list); - std::string path; - while (ifs >> path) { - files.push_back(path); - } - } - - auto output = graph.Run({{{{"filename", files[0]}}, - {{"filename", files[1]}}, - {{"filename", files[2]}}, - {{"filename", files[3]}}}}); - REQUIRE(output); - INFO("output: {}", output.value()); -} diff --git a/tests/test_csrc/model/test_directory_model.cpp b/tests/test_csrc/model/test_directory_model.cpp new file mode 100644 index 0000000000..6ea1bacc99 --- /dev/null +++ b/tests/test_csrc/model/test_directory_model.cpp @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on +#include "core/model.h" +#include "core/model_impl.h" +#include "test_resource.h" + +using namespace mmdeploy; + +TEST_CASE("test directory model", "[model]") { + std::unique_ptr model_impl; + for (auto& entry : ModelRegistry::Get().ListEntries()) { + if (entry.name == "DirectoryModel") { + model_impl = entry.creator(); + break; + } + } + REQUIRE(model_impl); + + auto& gResource = MMDeployTestResources::Get(); + auto directory_model_list = gResource.LocateModelResources("sdk_models"); + REQUIRE(!directory_model_list.empty()); + auto model_dir = "sdk_models/good_model"; + REQUIRE(gResource.IsDir(model_dir)); + auto model_path = gResource.resource_root_path() + "/" + model_dir; + REQUIRE(!model_impl->Init(model_path).has_error()); + REQUIRE(!model_impl->ReadFile("deploy.json").has_error()); + REQUIRE(model_impl->ReadFile("not-existing-file").has_error()); + + model_dir = "sdk_models/bad_model"; + REQUIRE(gResource.IsDir(model_dir)); + model_path = gResource.resource_root_path() + "/" + model_dir; + REQUIRE(!model_impl->Init(model_path).has_error()); + REQUIRE(model_impl->ReadMeta().has_error()); +} diff --git a/tests/test_csrc/model/test_model.cpp b/tests/test_csrc/model/test_model.cpp index 8d41626cfc..b00f8c2b5f 100644 --- a/tests/test_csrc/model/test_model.cpp +++ b/tests/test_csrc/model/test_model.cpp @@ -1,138 +1,51 @@ // Copyright (c) OpenMMLab. All rights reserved. -#include - +// clang-format off #include "catch.hpp" +// clang-format on #include "core/logger.h" #include "core/model.h" #include "core/model_impl.h" +#include "test_resource.h" using namespace mmdeploy; -namespace mmdeploy { -bool operator==(const model_meta_info_t a, const model_meta_info_t b) { - return a.name == b.name && a.net == b.net && a.weights == b.weights && a.backend == b.backend && - a.batch_size == b.batch_size && a.precision == b.precision && - a.dynamic_shape == b.dynamic_shape; -} -// std::ostream& operator<<(std::ostream& os, const model_meta_info_t& a) { -// os << a.name << ", " << a.net << ", " << a.weights << ", " << a.backend -// << ", " << a.batch_size << ", " << a.precision << ", " << -// a.dynamic_shape -// << std::endl; -// return os; -// } -} // namespace mmdeploy TEST_CASE("model constructor", "[model]") { SECTION("default constructor") { Model model; REQUIRE(!model); } - SECTION("explicit constructor") { - try { - Model model("../../tests/data/model/resnet50"); - REQUIRE(model); - Model failed_model("unsupported_sdk_model_format"); - } catch (const Exception& e) { - ERROR("exception happened: {}", e.what()); - REQUIRE(true); - } + SECTION("explicit constructor with model path") { + REQUIRE_THROWS(Model{"path/to/not/existing/model"}); } + SECTION("explicit constructor with buffer") { REQUIRE_THROWS(Model{nullptr, 0}); } } -TEST_CASE("test plain model implementation", "[model]") { - Model model; - - REQUIRE(!model); - - SECTION("load failed") { REQUIRE(!model.Init("unsupported_sdk_model_format")); } - - SECTION("read meta failed") { - std::string path{"../../tests/data/model"}; - REQUIRE(!model.Init(path)); - } - - SECTION("invalid meta file") { - std::string path{"../../tests/data/model/resnet50_bad_deploy_meta"}; - REQUIRE(!model.Init(path)); - } - - SECTION("normal case") { - Result res = success(); - SECTION("plain model") { - std::string path{"../../tests/data/model/resnet50"}; - res = model.Init(path); +TEST_CASE("model init", "[model]") { + auto& gResource = MMDeployTestResources::Get(); + for (auto& codebase : gResource.codebases()) { + if (auto img_list = gResource.LocateImageResources(codebase + "/images"); !img_list.empty()) { + Model model; + REQUIRE(model.Init(img_list.front()).has_error()); + break; } - - REQUIRE(model); - REQUIRE(res); - - const deploy_meta_info_t expected_meta{ - "0.1.0", - {{"resnet50", "resnet50.engine", "resnet50.engine", "trt", 32, "INT8", false}}, - {}}; - auto meta = model.meta(); - REQUIRE(meta.version == expected_meta.version); - REQUIRE(meta.models == expected_meta.models); - REQUIRE(meta.customs == expected_meta.customs); - auto model_meta = model.GetModelConfig(meta.models[0].name); - REQUIRE(model_meta.value() == meta.models[0]); - model_meta = model.GetModelConfig("error_model_name"); - REQUIRE(model_meta.has_error()); } -} - -TEST_CASE("zip model implementation", "[model]") { - Model model; - std::string path{"../../tests/data/model/resnet50.zip"}; - auto res = model.Init(path); - if (!res.has_error()) { - const deploy_meta_info_t expected_meta{ - "0.1.0", - {{"resnet50", "resnet50.engine", "resnet50.engine", "trt", 32, "INT8", false}}, - {}}; - auto meta = model.meta(); - REQUIRE(meta.version == expected_meta.version); - REQUIRE(meta.models == expected_meta.models); - REQUIRE(meta.customs == expected_meta.customs); - auto model_meta = model.GetModelConfig(meta.models[0].name); - REQUIRE(model_meta.value() == meta.models[0]); - model_meta = model.GetModelConfig("error_model_name"); - REQUIRE(model_meta.has_error()); - } -} - -TEST_CASE("zip model from buffer", "[model]") { - Model model; - std::string path{"../../tests/data/model/resnet50.zip"}; - std::ifstream ifs(path, std::ios::binary | std::ios::in); - REQUIRE(ifs.is_open()); - std::string buffer((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - auto res = model.Init(buffer.data(), buffer.size()); - if (!res.has_error()) { - const deploy_meta_info_t expected_meta{ - "0.1.0", - {{"resnet50", "resnet50.engine", "resnet50.engine", "trt", 32, "INT8", false}}, - {}}; - auto meta = model.meta(); - REQUIRE(meta.version == expected_meta.version); - REQUIRE(meta.models == expected_meta.models); - REQUIRE(meta.customs == expected_meta.customs); - auto model_meta = model.GetModelConfig(meta.models[0].name); - REQUIRE(model_meta.value() == meta.models[0]); - model_meta = model.GetModelConfig("error_model_name"); - REQUIRE(model_meta.has_error()); + for (auto& codebase : gResource.codebases()) { + for (auto& backend : gResource.backends()) { + if (auto model_list = gResource.LocateModelResources(codebase + "/" + backend); + !model_list.empty()) { + Model model; + REQUIRE(!model.Init(model_list.front()).has_error()); + REQUIRE(!model.ReadFile("deploy.json").has_error()); + auto const& meta = model.meta(); + REQUIRE(!model.GetModelConfig(meta.models[0].name).has_error()); + REQUIRE(model.GetModelConfig("not-existing-model").has_error()); + break; + } + } } } -TEST_CASE("bad zip buffer", "[model1]") { - std::vector buffer(100); - Model model; - REQUIRE(!model.Init(buffer.data(), buffer.size())); -} - -TEST_CASE("ReadFile", "[model]") {} - TEST_CASE("ModelRegistry", "[model]") { class ANewModelImpl : public ModelImpl { Result Init(const std::string& sdk_model_path) override { return Status(eNotSupported); } @@ -145,8 +58,8 @@ TEST_CASE("ModelRegistry", "[model]") { } }; - // Test duplicated register. `ZipModel` is already registered. - (void)ModelRegistry::Get().Register("PlainModel", []() -> std::unique_ptr { + // Test duplicated register. `DirectoryModel` is already registered. + (void)ModelRegistry::Get().Register("DirectoryModel", []() -> std::unique_ptr { return std::make_unique(); }); } diff --git a/tests/test_csrc/model/test_zip_model.cpp b/tests/test_csrc/model/test_zip_model.cpp new file mode 100644 index 0000000000..48f787bdea --- /dev/null +++ b/tests/test_csrc/model/test_zip_model.cpp @@ -0,0 +1,52 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on +#include + +#include "core/logger.h" +#include "core/model.h" +#include "core/model_impl.h" +#include "test_resource.h" + +using namespace std; +using namespace mmdeploy; + +TEST_CASE("test zip model", "[zip_model]") { + std::unique_ptr model_impl; + for (auto& entry : ModelRegistry::Get().ListEntries()) { + if (entry.name == "ZipModel") { + model_impl = entry.creator(); + break; + } + } + REQUIRE(model_impl); + + auto& gResource = MMDeployTestResources::Get(); + SECTION("bad sdk model") { + auto zip_model_path = "sdk_models/not_zip_file"; + REQUIRE(gResource.IsFile(zip_model_path)); + auto model_path = gResource.resource_root_path() + "/" + zip_model_path; + REQUIRE(model_impl->Init(model_path).has_error()); + } + SECTION("bad zip buffer") { + std::vector buffer(100); + REQUIRE(model_impl->Init(buffer.data(), buffer.size()).has_error()); + } + + SECTION("good sdk model") { + auto zip_model_path = "sdk_models/good_model.zip"; + REQUIRE(gResource.IsFile(zip_model_path)); + auto model_path = gResource.resource_root_path() + "/" + zip_model_path; + REQUIRE(!model_impl->Init(model_path).has_error()); + REQUIRE(!model_impl->ReadFile("deploy.json").has_error()); + REQUIRE(model_impl->ReadFile("not-exist-file").has_error()); + REQUIRE(!model_impl->ReadMeta().has_error()); + + ifstream ifs(model_path, std::ios::binary | std::ios::in); + REQUIRE(ifs.is_open()); + string buffer((istreambuf_iterator(ifs)), istreambuf_iterator()); + REQUIRE(!model_impl->Init(buffer.data(), buffer.size()).has_error()); + } +} diff --git a/tests/test_csrc/net/test_ncnn_net.cpp b/tests/test_csrc/net/test_ncnn_net.cpp new file mode 100644 index 0000000000..98b348c19b --- /dev/null +++ b/tests/test_csrc/net/test_ncnn_net.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "core/model.h" +#include "core/net.h" +#include "test_resource.h" + +using namespace mmdeploy; + +TEST_CASE("test ncnn net", "[ncnn_net]") { + auto& gResource = MMDeployTestResources::Get(); + auto model_list = gResource.LocateModelResources("mmcls/ncnn"); + REQUIRE(!model_list.empty()); + + Model model(model_list.front()); + REQUIRE(model); + + auto backend("ncnn"); + auto creator = Registry::Get().GetCreator(backend); + REQUIRE(creator); + + Device device{"cpu"}; + auto stream = Stream::GetDefault(device); + Value net_config{{"context", {{"device", device}, {"model", model}, {"stream", stream}}}, + {"name", model.meta().models[0].name}}; + auto net = creator->Create(net_config); + REQUIRE(net); +} diff --git a/tests/test_csrc/net/test_net.cpp b/tests/test_csrc/net/test_net.cpp deleted file mode 100644 index 8b5defa447..0000000000 --- a/tests/test_csrc/net/test_net.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include -#include -#include - -#include "catch.hpp" -#include "core/model.h" -#include "core/net.h" - -using namespace mmdeploy; - -static Value ReadFileContent(const char* path) { - std::ifstream ifs(path, std::ios::binary); - ifs.seekg(0, std::ios::end); - auto size = ifs.tellg(); - ifs.seekg(0, std::ios::beg); - Value::Binary bin(size); - ifs.read((char*)bin.data(), size); - return bin; -} - -template ::value_type, - std::enable_if_t && std::is_integral_v, int> = 0> -std::string shape_string(const T& v) { - std::stringstream ss; - ss << "("; - auto first = true; - for (const auto& x : v) { - if (!first) { - ss << ", "; - } else { - first = false; - } - ss << x; - } - ss << ")"; - return ss.str(); -} - -TEST_CASE("test pplnn", "[net]") { - auto backend = "pplnn"; - Model model("../../resnet50"); - REQUIRE(model); - auto img_path = "../../sea_lion.txt"; - - auto creator = Registry::Get().GetCreator(backend); - REQUIRE(creator); - Device device{"cpu"}; - auto stream = Stream::GetDefault(device); - // clang-format off - Value net_config{ - {"context", { - {"device", device}, - {"model", model}, - {"stream", stream} - } - }, - {"name", "resnet50"} - }; - // clang-format on - auto net = creator->Create(net_config); - - std::vector img(3 * 224 * 224); - { - std::ifstream ifs(img_path); - REQUIRE(ifs.is_open()); - for (auto& x : img) { - ifs >> x; - } - } - - std::vector input_shape{{1, 3, 224, 224}}; - REQUIRE(net->Reshape(input_shape)); - - auto inputs = net->GetInputTensors().value(); - - for (auto& tensor : inputs) { - std::cout << "input: " << tensor.name() << " " << shape_string(tensor.shape()) << "\n"; - } - - REQUIRE(inputs.front().CopyFrom(img.data(), stream)); - REQUIRE(stream.Wait()); - - REQUIRE(net->Forward()); - - auto outputs = net->GetOutputTensors().value(); - - for (auto& tensor : outputs) { - std::cout << "output: " << tensor.name() << " " << shape_string(tensor.shape()) << "\n"; - } - - std::vector logits(1000); - REQUIRE(outputs.front().CopyTo(logits.data(), stream)); - REQUIRE(stream.Wait()); - - auto cls_id = std::max_element(logits.begin(), logits.end()) - logits.begin(); - std::cout << "class id = " << cls_id << "\n"; -} diff --git a/tests/test_csrc/net/test_net_module.cpp b/tests/test_csrc/net/test_net_module.cpp deleted file mode 100644 index ac7a7dff68..0000000000 --- a/tests/test_csrc/net/test_net_module.cpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include - -#include "catch.hpp" -#include "core/model.h" -#include "core/module.h" -#include "core/registry.h" -#include "net/net_module.h" - -using namespace mmdeploy; - -TEST_CASE("test net module", "[net]") { - auto creator = Registry::Get().GetCreator("Net"); - REQUIRE(creator); - - Device device("cpu"); - auto stream = Stream::GetDefault(device); - REQUIRE(stream); - - Model model("../../resnet50"); - REQUIRE(model); - - auto net = - creator->Create({{"name", "resnet50"}, - {"context", {{"device", device}, {"stream", stream}, {"model", model}}}}); - REQUIRE(net); - - std::vector img(3 * 224 * 224); - { - std::ifstream ifs("../../sea_lion.bin", std::ios::binary | std::ios::in); - REQUIRE(ifs.is_open()); - ifs.read((char*)img.data(), img.size() * sizeof(float)); - } - - Tensor input{TensorDesc{ - .device = device, .data_type = DataType::kFLOAT, .shape = {1, 3, 224, 224}, .name = "input"}}; - - REQUIRE(input.CopyFrom(img.data(), stream)); - - auto result = net->Process({{{"input", input}}}); - REQUIRE(result); - - auto& output = result.value(); - - std::vector probs(1000); - REQUIRE(output[0]["probs"].get().CopyTo(probs.data(), stream)); - - REQUIRE(stream.Wait()); - - auto cls_id = max_element(begin(probs), end(probs)) - begin(probs); - - std::cout << "cls_id: " << cls_id << ", prob: " << probs[cls_id] << "\n"; - REQUIRE(cls_id == 150); -} diff --git a/tests/test_csrc/net/test_openvino_net.cpp b/tests/test_csrc/net/test_openvino_net.cpp new file mode 100644 index 0000000000..f4a2f683f3 --- /dev/null +++ b/tests/test_csrc/net/test_openvino_net.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "core/model.h" +#include "core/net.h" +#include "test_resource.h" + +using namespace mmdeploy; + +TEST_CASE("test openvino net", "[openvino_net]") { + auto& gResource = MMDeployTestResources::Get(); + auto model_list = gResource.LocateModelResources("mmcls/openvino"); + REQUIRE(!model_list.empty()); + + Model model(model_list.front()); + REQUIRE(model); + + auto backend("openvino"); + auto creator = Registry::Get().GetCreator(backend); + REQUIRE(creator); + + Device device{"cpu"}; + auto stream = Stream::GetDefault(device); + Value net_config{{"context", {{"device", device}, {"model", model}, {"stream", stream}}}, + {"name", model.meta().models[0].name}}; + auto net = creator->Create(net_config); + REQUIRE(net); +} diff --git a/tests/test_csrc/net/test_ort_net.cpp b/tests/test_csrc/net/test_ort_net.cpp new file mode 100644 index 0000000000..506fbaf199 --- /dev/null +++ b/tests/test_csrc/net/test_ort_net.cpp @@ -0,0 +1,31 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "core/model.h" +#include "core/net.h" +#include "test_resource.h" + +using namespace mmdeploy; + +TEST_CASE("test ort net", "[ort_net]") { + auto& gResource = MMDeployTestResources::Get(); + auto model_list = gResource.LocateModelResources("mmcls/ort"); + REQUIRE(!model_list.empty()); + + Model model(model_list.front()); + REQUIRE(model); + + auto backend("onnxruntime"); + auto creator = Registry::Get().GetCreator(backend); + REQUIRE(creator); + + Device device{"cpu"}; + auto stream = Stream::GetDefault(device); + Value net_config{{"context", {{"device", device}, {"model", model}, {"stream", stream}}}, + {"name", model.meta().models[0].name}}; + auto net = creator->Create(net_config); + REQUIRE(net); +} diff --git a/tests/test_csrc/net/test_ppl_net.cpp b/tests/test_csrc/net/test_ppl_net.cpp new file mode 100644 index 0000000000..64a6a478a1 --- /dev/null +++ b/tests/test_csrc/net/test_ppl_net.cpp @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "core/model.h" +#include "core/net.h" +#include "test_resource.h" + +using namespace mmdeploy; + +TEST_CASE("test pplnn net", "[ppl_net]") { + auto& gResource = MMDeployTestResources::Get(); + auto model_list = gResource.LocateModelResources("mmcls/pplnn"); + REQUIRE(!model_list.empty()); + + Model model(model_list.front()); + REQUIRE(model); + + auto backend = "pplnn"; + auto creator = Registry::Get().GetCreator(backend); + REQUIRE(creator); + + Device device{"cpu"}; + auto stream = Stream::GetDefault(device); + // clang-format off + Value net_config{ + {"context", { + {"device", device}, + {"model", model}, + {"stream", stream} + } + }, + {"name", model.meta().models[0].name} + }; +} diff --git a/tests/test_csrc/net/test_trt_net.cpp b/tests/test_csrc/net/test_trt_net.cpp index 1835d496ec..2b2841d772 100644 --- a/tests/test_csrc/net/test_trt_net.cpp +++ b/tests/test_csrc/net/test_trt_net.cpp @@ -1,21 +1,31 @@ // Copyright (c) OpenMMLab. All rights reserved. +// clang-format off #include "catch.hpp" +// clang-format on + #include "core/model.h" #include "core/net.h" +#include "test_resource.h" using namespace mmdeploy; TEST_CASE("test trt net", "[trt_net]") { - Model model("../../config/detector/retinanet_t4-cuda11.1-trt7.2-fp32"); - auto backend("trt"); + auto& gResource = MMDeployTestResources::Get(); + auto model_list = gResource.LocateModelResources("mmcls/trt"); + REQUIRE(!model_list.empty()); + + Model model(model_list.front()); + REQUIRE(model); + + auto backend("tensorrt"); auto creator = Registry::Get().GetCreator(backend); + REQUIRE(creator); Device device{"cuda"}; auto stream = Stream::GetDefault(device); Value net_config{{"context", {{"device", device}, {"model", model}, {"stream", stream}}}, - {"name", "retinanet"}}; - + {"name", model.meta().models[0].name}}; auto net = creator->Create(net_config); REQUIRE(net); } diff --git a/tests/test_csrc/preprocess/transform/test_collect.cpp b/tests/test_csrc/preprocess/test_collect.cpp similarity index 79% rename from tests/test_csrc/preprocess/transform/test_collect.cpp rename to tests/test_csrc/preprocess/test_collect.cpp index 06b2254608..0ac7e7091a 100644 --- a/tests/test_csrc/preprocess/transform/test_collect.cpp +++ b/tests/test_csrc/preprocess/test_collect.cpp @@ -12,22 +12,10 @@ TEST_CASE("test collect constructor", "[collect]") { auto creator = Registry::Get().GetCreator(transform_type, 1); REQUIRE(creator != nullptr); - SECTION("empty args") { - try { - auto module = creator->Create({}); - } catch (std::exception& e) { - REQUIRE(true); - INFO("expected exception: {}", e.what()); - } - } + REQUIRE_THROWS(creator->Create({})); SECTION("args with 'keys' which is not an array") { - try { - auto module = creator->Create({{"keys", "img"}}); - } catch (std::exception& e) { - REQUIRE(true); - INFO("expected exception: {}", e.what()); - } + REQUIRE_THROWS(creator->Create({{"keys", "img"}})); } SECTION("args with keys in array") { @@ -36,12 +24,7 @@ TEST_CASE("test collect constructor", "[collect]") { } SECTION("args with meta_keys that is not an array") { - try { - auto module = creator->Create({{"keys", {"img"}}, {"meta_keys", "ori_img"}}); - } catch (std::exception& e) { - REQUIRE(true); - INFO("expected exception: {}", e.what()); - } + REQUIRE_THROWS(creator->Create({{"keys", {"img"}}, {"meta_keys", "ori_img"}})); } SECTION("args with meta_keys in array") { auto module = creator->Create({{"keys", {"img"}}, {"meta_keys", {"ori_img"}}}); @@ -87,7 +70,7 @@ TEST_CASE("test collect", "[collect]") { Tensor tensor; Value input{{"img", tensor}, {"filename", "test.jpg"}, - {"ori_filename", "../tests/preprocess/data/test.jpg"}, + {"ori_filename", "/the/path/of/test.jpg"}, {"ori_shape", {1000, 1000, 3}}, {"img_shape", {1, 3, 224, 224}}, {"flip", "false"}, diff --git a/tests/test_csrc/preprocess/test_compose.cpp b/tests/test_csrc/preprocess/test_compose.cpp new file mode 100644 index 0000000000..9b7cd4d8d1 --- /dev/null +++ b/tests/test_csrc/preprocess/test_compose.cpp @@ -0,0 +1,100 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +// clang-format off +#include "catch.hpp" +// clang-format on + +#include "archive/json_archive.h" +#include "core/mat.h" +#include "core/registry.h" +#include "core/utils/formatter.h" +#include "json.hpp" +#include "preprocess/cpu/opencv_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace mmdeploy::test; +using namespace std; +using nlohmann::json; + +static constexpr const char *gPipelineConfig = R"( +[{ + "type": "LoadImageFromFile" + }, + { + "type": "Resize", + "size": [ + 256, -1 + ] + }, + { + "type": "CenterCrop", + "crop_size": 224 + }, + { + "type": "Normalize", + "mean": [ + 123.675, + 116.28, + 103.53 + ], + "std": [ + 58.395, + 57.12, + 57.375 + ], + "to_rgb": true + }, + { + "type": "ImageToTensor", + "keys": [ + "img" + ] + }, + { + "type": "Collect", + "keys": [ + "img" + ] + } +] +)"; + +TEST_CASE("transform Compose exceptional case", "[compose]") { + Value compose_cfg; + SECTION("wrong transform type") { + compose_cfg = {{"type", "Compose"}, {"transforms", {{{"type", "collect"}}}}}; + } + + SECTION("wrong transform parameter") { + compose_cfg = {{"type", "Compose"}, {"transforms", {{{"type", "Collect"}}}}}; + } + const Device kHost{"cpu"}; + Stream stream{kHost}; + REQUIRE(CreateTransform(compose_cfg, kHost, stream) == nullptr); +} + +TEST_CASE("transform Compose", "[compose]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + auto src_mat = cpu::CVMat2Mat(bgr_mat, PixelFormat::kBGR); + Value input{{"ori_img", src_mat}}; + + auto json = json::parse(gPipelineConfig); + auto cfg = ::mmdeploy::from_json(json); + Value compose_cfg{{"type", "Compose"}, {"transforms", cfg}}; + + const Device kHost{"cpu"}; + Stream stream{kHost}; + auto transform = CreateTransform(compose_cfg, kHost, stream); + REQUIRE(transform != nullptr); + auto res = transform->Process({{"ori_img", src_mat}}); + REQUIRE(!res.has_error()); +} diff --git a/tests/test_csrc/preprocess/test_crop.cpp b/tests/test_csrc/preprocess/test_crop.cpp new file mode 100644 index 0000000000..f963282cd0 --- /dev/null +++ b/tests/test_csrc/preprocess/test_crop.cpp @@ -0,0 +1,109 @@ + +// Copyright (c) OpenMMLab. All rights reserved. + +#include "catch.hpp" +#include "core/mat.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace std; +using namespace mmdeploy::test; + +tuple CenterCropArea(const cv::Mat& mat, int crop_height, int crop_width) { + auto img_height = mat.rows; + auto img_width = mat.cols; + auto y1 = max(0, int(round((img_height - crop_height) / 2.))); + auto x1 = max(0, int(round((img_width - crop_width) / 2.))); + auto y2 = min(img_height, y1 + crop_height) - 1; + auto x2 = min(img_width, x1 + crop_width) - 1; + return {y1, x1, y2, x2}; +} + +void TestCenterCrop(const Value& cfg, const cv::Mat& mat, int crop_height, int crop_width) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto [top, left, bottom, right] = CenterCropArea(mat, crop_height, crop_width); + auto ref_mat = mmdeploy::cpu::Crop(mat, top, left, bottom, right); + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "img_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + } +} + +TEST_CASE("transform CenterCrop", "[crop]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + cv::Mat bgr_float_mat; + cv::Mat gray_float_mat; + bgr_mat.convertTo(bgr_float_mat, CV_32FC3); + gray_mat.convertTo(gray_float_mat, CV_32FC1); + + vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; + + SECTION("crop_size: int; small size") { + constexpr int crop_size = 224; + Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; + for (auto& mat : mats) { + TestCenterCrop(cfg, mat, crop_size, crop_size); + } + } + + SECTION("crop_size: int; oversize") { + constexpr int crop_size = 800; + Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; + for (auto& mat : mats) { + TestCenterCrop(cfg, mat, crop_size, crop_size); + } + } + + SECTION("crop_size: tuple") { + constexpr int crop_height = 224; + constexpr int crop_width = 224; + Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; + for (auto& mat : mats) { + TestCenterCrop(cfg, mat, crop_height, crop_width); + } + } + + SECTION("crop_size: tuple;oversize in height") { + constexpr int crop_height = 640; + constexpr int crop_width = 224; + Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; + for (auto& mat : mats) { + TestCenterCrop(cfg, mat, crop_height, crop_width); + } + } + + SECTION("crop_size: tuple;oversize in width") { + constexpr int crop_height = 224; + constexpr int crop_width = 800; + Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; + for (auto& mat : mats) { + TestCenterCrop(cfg, mat, crop_height, crop_width); + } + } +} diff --git a/tests/test_csrc/preprocess/test_image2tensor.cpp b/tests/test_csrc/preprocess/test_image2tensor.cpp new file mode 100644 index 0000000000..1e1d5d0232 --- /dev/null +++ b/tests/test_csrc/preprocess/test_image2tensor.cpp @@ -0,0 +1,68 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "catch.hpp" +#include "core/tensor.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace mmdeploy::test; +using namespace std; + +void TestImage2Tensor(const Value& cfg, const cv::Mat& mat) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + vector channel_mats(mat.channels()); + for (auto i = 0; i < mat.channels(); ++i) { + cv::extractChannel(mat, channel_mats[i], i); + } + + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + auto shape = res_tensor.desc().shape; + REQUIRE(shape == std::vector{1, mat.channels(), mat.rows, mat.cols}); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + // mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w} + // compare each channel between `res_tensor` and `mat` + auto step = shape[2] * shape[3] * mat.elemSize1(); + auto data = host_tensor.value().data(); + for (auto i = 0; i < mat.channels(); ++i) { + cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(mat.depth(), 1), data}; + REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat)); + data += step; + } + } +} + +TEST_CASE("transform ImageToTensor", "[img2tensor]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + cv::Mat bgr_float_mat; + cv::Mat gray_float_mat; + bgr_mat.convertTo(bgr_float_mat, CV_32FC3); + gray_mat.convertTo(gray_float_mat, CV_32FC1); + + Value cfg{{"type", "ImageToTensor"}, {"keys", {"img"}}}; + vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; + for (auto& mat : mats) { + TestImage2Tensor(cfg, mat); + } +} diff --git a/tests/test_csrc/preprocess/test_load.cpp b/tests/test_csrc/preprocess/test_load.cpp new file mode 100644 index 0000000000..3f9b4068e0 --- /dev/null +++ b/tests/test_csrc/preprocess/test_load.cpp @@ -0,0 +1,81 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "catch.hpp" +#include "core/mat.h" +#include "core/tensor.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace std; +using namespace mmdeploy::test; + +void TestLoad(const Value& cfg, const cv::Mat& mat, PixelFormat src_format, + PixelFormat dst_format) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto ref_mat = mmdeploy::cpu::ColorTransfer(mat, src_format, dst_format); + + auto res = transform->Process({{"ori_img", cpu::CVMat2Mat(mat, PixelFormat(src_format))}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "img_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + REQUIRE(Shape(res.value(), "ori_shape") == + vector{1, mat.rows, mat.cols, mat.channels()}); + REQUIRE(res.value().contains("img_fields")); + REQUIRE(res.value()["img_fields"].is_array()); + REQUIRE(res.value()["img_fields"].size() == 1); + REQUIRE(res.value()["img_fields"][0].get() == "img"); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + } +} + +TEST_CASE("prepare image, that is LoadImageFromFile transform", "[load]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + cv::Mat rgb_mat; + cv::Mat bgra_mat; + // TODO: make up yuv nv12/nv21 mat + + cv::cvtColor(bgr_mat, rgb_mat, cv::COLOR_BGR2RGB); + cv::cvtColor(bgr_mat, bgra_mat, cv::COLOR_BGR2BGRA); + + vector> mats{{bgr_mat, PixelFormat::kBGR}, + {rgb_mat, PixelFormat::kRGB}, + {gray_mat, PixelFormat::kGRAYSCALE}, + {bgra_mat, PixelFormat::kBGRA}}; + // pair is + vector> conditions{ + {"color", true}, {"color", false}, {"grayscale", true}, {"grayscale", false}}; + + for (auto& condition : conditions) { + Value cfg{{"type", "LoadImageFromFile"}, + {"to_float32", condition.second}, + {"color_type", condition.first}}; + for (auto& mat : mats) { + TestLoad(cfg, mat.first, mat.second, + condition.first == "color" ? PixelFormat::kBGR : PixelFormat::kGRAYSCALE); + } + } +} diff --git a/tests/test_csrc/preprocess/test_normalize.cpp b/tests/test_csrc/preprocess/test_normalize.cpp new file mode 100644 index 0000000000..8fcac13d37 --- /dev/null +++ b/tests/test_csrc/preprocess/test_normalize.cpp @@ -0,0 +1,101 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "catch.hpp" +#include "core/mat.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace mmdeploy::test; +using namespace std; + +void TestNormalize(const Value &cfg, const cv::Mat &mat) { + auto gResource = MMDeployTestResources::Get(); + for (auto const &device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + vector mean; + vector std; + for (auto &v : cfg["mean"]) { + mean.push_back(v.get()); + } + for (auto &v : cfg["std"]) { + std.push_back(v.get()); + } + bool to_rgb = cfg.value("to_rgb", false); + + auto _mat = mat.clone(); + auto ref_mat = mmdeploy::cpu::Normalize(_mat, mean, std, to_rgb); + + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(res_tensor.desc().data_type == DataType::kFLOAT); + REQUIRE(ImageNormCfg(res.value(), "mean") == mean); + REQUIRE(ImageNormCfg(res.value(), "std") == std); + + Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + } +} + +TEST_CASE("transform Normalize", "[normalize]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path); + cv::Mat gray_mat; + cv::Mat float_bgr_mat; + cv::Mat float_gray_mat; + + cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); + bgr_mat.convertTo(float_bgr_mat, CV_32FC3); + gray_mat.convertTo(float_gray_mat, CV_32FC1); + + SECTION("cpu vs gpu: 3 channel mat") { + bool to_rgb = true; + Value cfg{{"type", "Normalize"}, + {"mean", {123.675, 116.28, 103.53}}, + {"std", {58.395, 57.12, 57.375}}, + {"to_rgb", to_rgb}}; + vector mats{bgr_mat, float_bgr_mat}; + for (auto &mat : mats) { + TestNormalize(cfg, mat); + } + } + + SECTION("cpu vs gpu: 3 channel mat, to_rgb false") { + bool to_rgb = false; + Value cfg{{"type", "Normalize"}, + {"mean", {123.675, 116.28, 103.53}}, + {"std", {58.395, 57.12, 57.375}}, + {"to_rgb", to_rgb}}; + + vector mats{bgr_mat, float_bgr_mat}; + for (auto &mat : mats) { + TestNormalize(cfg, mat); + } + } + + SECTION("cpu vs gpu: 1 channel mat") { + bool to_rgb = true; + Value cfg{{"type", "Normalize"}, {"mean", {123.675}}, {"std", {58.395}}, {"to_rgb", to_rgb}}; + + vector mats{gray_mat, float_gray_mat}; + for (auto &mat : mats) { + TestNormalize(cfg, mat); + } + } +} diff --git a/tests/test_csrc/preprocess/test_pad.cpp b/tests/test_csrc/preprocess/test_pad.cpp new file mode 100644 index 0000000000..49ffd4af62 --- /dev/null +++ b/tests/test_csrc/preprocess/test_pad.cpp @@ -0,0 +1,117 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "catch.hpp" +#include "core/mat.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace std; +using namespace mmdeploy::test; + +// left, top, right, bottom +tuple GetPadSize(const cv::Mat& mat, int dst_height, int dst_width) { + return {0, 0, dst_width - mat.cols, dst_height - mat.rows}; +} + +tuple GetPadSize(const cv::Mat& mat, bool square = true) { + int size = std::max(mat.rows, mat.cols); + return GetPadSize(mat, size, size); +} + +tuple GetPadSize(const cv::Mat& mat, int divisor) { + auto pad_h = int(ceil(mat.rows * 1.0 / divisor)) * divisor; + auto pad_w = int(ceil(mat.cols * 1.0 / divisor)) * divisor; + return GetPadSize(mat, pad_h, pad_w); +} + +void TestPad(const Value& cfg, const cv::Mat& mat, int top, int left, int bottom, int right, + int border_type, float val) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto ref_mat = mmdeploy::cpu::Pad(mat, top, left, bottom, right, border_type, val); + + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "pad_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + REQUIRE(Shape(res.value(), "pad_fixed_size") == + std::vector{ref_mat.rows, ref_mat.cols}); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + } +} + +TEST_CASE("transform 'Pad'", "[pad]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat; + cv::Mat float_bgr_mat; + cv::Mat float_gray_mat; + cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); + bgr_mat.convertTo(float_bgr_mat, CV_32FC3); + gray_mat.convertTo(float_gray_mat, CV_32FC1); + + vector mats{bgr_mat, gray_mat, float_bgr_mat, float_gray_mat}; + vector modes{"constant", "edge", "reflect", "symmetric"}; + map border_map{{"constant", cv::BORDER_CONSTANT}, + {"edge", cv::BORDER_REPLICATE}, + {"reflect", cv::BORDER_REFLECT_101}, + {"symmetric", cv::BORDER_REFLECT}}; + SECTION("pad to square") { + bool square{true}; + float val = 255.0f; + for (auto& mat : mats) { + for (auto& mode : modes) { + Value cfg{ + {"type", "Pad"}, {"pad_to_square", square}, {"padding_mode", mode}, {"pad_val", val}}; + auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, square); + TestPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); + } + } + } + + SECTION("pad with size_divisor") { + constexpr int divisor = 32; + float val = 255.0f; + for (auto& mat : mats) { + for (auto& mode : modes) { + Value cfg{ + {"type", "Pad"}, {"size_divisor", divisor}, {"padding_mode", mode}, {"pad_val", val}}; + auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, divisor); + TestPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); + } + } + } + + SECTION("pad with size") { + constexpr int height = 600; + constexpr int width = 800; + for (auto& mat : mats) { + for (auto& mode : modes) { + Value cfg{{"type", "Pad"}, {"size", {height, width}}, {"padding_mode", mode}}; + auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, height, width); + TestPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 0); + } + } + } +} diff --git a/tests/test_csrc/preprocess/test_resize.cpp b/tests/test_csrc/preprocess/test_resize.cpp new file mode 100644 index 0000000000..abdec3e5cb --- /dev/null +++ b/tests/test_csrc/preprocess/test_resize.cpp @@ -0,0 +1,300 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "catch.hpp" +#include "core/mat.h" +#include "preprocess/cpu/opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "preprocess/transform/transform_utils.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace std; +using namespace mmdeploy::test; + +// return {target_height, target_width} +tuple GetTargetSize(const cv::Mat& src, int size0, int size1) { + assert(size0 > 0); + if (size1 > 0) { + return {size0, size1}; + } else { + if (src.rows < src.cols) { + return {size0, size0 * src.cols / src.rows}; + } else { + return {size0 * src.rows / src.cols, size0}; + } + } +} + +// return {target_height, target_width} +tuple GetTargetSize(const cv::Mat& src, int scale0, int scale1, bool keep_ratio) { + auto w = src.cols; + auto h = src.rows; + auto max_long_edge = max(scale0, scale1); + auto max_short_edge = min(scale0, scale1); + if (keep_ratio) { + auto scale_factor = + std::min(max_long_edge * 1.0 / std::max(h, w), max_short_edge * 1.0 / std::min(h, w)); + return {int(h * scale_factor + 0.5f), int(w * scale_factor + 0.5f)}; + } else { + return {scale0, scale1}; + } +} + +void TestResize(const Value& cfg, const std::string& device_name, const cv::Mat& mat, + int dst_height, int dst_width) { + if (MMDeployTestResources::Get().HasDevice(device_name)) { + Device device{device_name.c_str()}; + Stream stream{device}; + + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto interpolation = cfg["interpolation"].get(); + auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); + + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device().device_id() == device.device_id()); + REQUIRE(res_tensor.device().platform_id() == device.platform_id()); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "img_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + cv::imwrite("ref.bmp", ref_mat); + cv::imwrite("res.bmp", res_mat); + } +} + +void TestResizeWithScale(const Value& cfg, const std::string& device_name, const cv::Mat& mat, + int scale0, int scale1, bool keep_ratio) { + if (MMDeployTestResources::Get().HasDevice(device_name)) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto [dst_height, dst_width] = GetTargetSize(mat, scale0, scale1, keep_ratio); + auto interpolation = cfg["interpolation"].get(); + auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); + + Value input{{"img", cpu::CVMat2Tensor(mat)}, {"scale", {scale0, scale1}}}; + auto res = transform->Process(input); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "img_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + // cv::imwrite("ref.bmp", ref_mat); + // cv::imwrite("res.bmp", res_mat); + } +} + +void TestResizeWithScaleFactor(const Value& cfg, const std::string& device_name, const cv::Mat& mat, + float scale_factor) { + if (MMDeployTestResources::Get().HasDevice(device_name)) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + auto [dst_height, dst_width] = make_tuple(mat.rows * scale_factor, mat.cols * scale_factor); + auto interpolation = cfg["interpolation"].get(); + auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); + + Value input{{"img", cpu::CVMat2Tensor(mat)}, {"scale_factor", scale_factor}}; + auto res = transform->Process(input); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + REQUIRE(Shape(res.value(), "img_shape") == + vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); + REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); + REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); + // cv::imwrite("ref.bmp", ref_mat); + // cv::imwrite("res.bmp", res_mat); + } +} + +TEST_CASE("resize transform: size", "[resize]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + cv::Mat bgr_float_mat; + cv::Mat gray_float_mat; + bgr_mat.convertTo(bgr_float_mat, CV_32FC3); + gray_mat.convertTo(gray_float_mat, CV_32FC1); + + vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; + vector interpolations{"bilinear", "nearest", "area", "bicubic", "lanczos"}; + set cuda_interpolations{"bilinear", "nearest"}; + constexpr const char* kHost = "cpu"; + SECTION("tuple size with -1") { + for (auto& mat : mats) { + auto size = std::max(mat.rows, mat.cols) + 10; + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {size, -1}}, + {"keep_ratio", false}, + {"interpolation", interp}}; + auto [dst_height, dst_width] = GetTargetSize(mat, size, -1); + TestResize(cfg, kHost, mat, dst_height, dst_width); + if (cuda_interpolations.find(interp) != cuda_interpolations.end()) { + TestResize(cfg, "cuda", mat, dst_height, dst_width); + } + } + } + } + + SECTION("no need to resize") { + for (auto& mat : mats) { + auto size = std::min(mat.rows, mat.cols); + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {size, -1}}, + {"keep_ratio", false}, + {"interpolation", interp}}; + auto [dst_height, dst_width] = GetTargetSize(mat, size, -1); + TestResize(cfg, kHost, mat, dst_height, dst_width); + } + } + } + + SECTION("fixed integer size") { + for (auto& mat : mats) { + constexpr int size = 224; + for (auto& interp : interpolations) { + Value cfg{ + {"type", "Resize"}, {"size", size}, {"keep_ratio", false}, {"interpolation", interp}}; + TestResize(cfg, kHost, mat, size, size); + if (cuda_interpolations.find(interp) != cuda_interpolations.end()) { + TestResize(cfg, "cuda", mat, size, size); + } + } + } + } + + SECTION("fixed size: [1333, 800]. keep_ratio: true") { + constexpr int max_long_edge = 1333; + constexpr int max_short_edge = 800; + bool keep_ratio = true; + for (auto& mat : mats) { + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {max_long_edge, max_short_edge}}, + {"keep_ratio", keep_ratio}, + {"interpolation", interp}}; + auto [dst_height, dst_width] = + GetTargetSize(mat, max_long_edge, max_short_edge, keep_ratio); + TestResize(cfg, kHost, mat, dst_height, dst_width); + if (cuda_interpolations.find(interp) != cuda_interpolations.end()) { + TestResize(cfg, "cuda", mat, dst_height, dst_width); + } + } + } + } + + SECTION("fixed size: [1333, 800]. keep_ratio: false") { + constexpr int dst_height = 800; + constexpr int dst_width = 1333; + bool keep_ratio = false; + for (auto& mat : mats) { + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {dst_height, dst_width}}, + {"keep_ratio", keep_ratio}, + {"interpolation", interp}}; + TestResize(cfg, kHost, mat, dst_height, dst_width); + if (cuda_interpolations.find(interp) != cuda_interpolations.end()) { + TestResize(cfg, "cuda", mat, dst_height, dst_width); + } + } + } + } + + SECTION("fixed size: [800, 1333]. keep_ratio: true") { + constexpr int dst_height = 800; + constexpr int dst_width = 1333; + bool keep_ratio = true; + for (auto& mat : mats) { + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {dst_height, dst_width}}, + {"keep_ratio", keep_ratio}, + {"interpolation", interp}}; + TestResizeWithScale(cfg, kHost, mat, dst_height, dst_width, keep_ratio); + } + } + } + + SECTION("img_scale: [800, 1333]. keep_ratio: false") { + constexpr int dst_height = 800; + constexpr int dst_width = 1333; + bool keep_ratio = false; + for (auto& mat : mats) { + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {dst_height, dst_width}}, + {"keep_ratio", keep_ratio}, + {"interpolation", interp}}; + TestResizeWithScale(cfg, kHost, mat, dst_height, dst_width, keep_ratio); + } + } + } + + SECTION("scale_factor: 0.5") { + float scale_factor = 0.5; + bool keep_ratio = true; + for (auto& mat : mats) { + for (auto& interp : interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {600, 800}}, + {"keep_ratio", keep_ratio}, + {"interpolation", interp}}; + TestResizeWithScaleFactor(cfg, kHost, mat, scale_factor); + } + } + } + + SECTION("resize 4 channel image") { + cv::Mat mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat bgra_mat; + cv::cvtColor(bgr_mat, bgra_mat, cv::COLOR_BGR2BGRA); + assert(bgra_mat.channels() == 4); + constexpr int size = 256; + auto [dst_height, dst_width] = GetTargetSize(bgra_mat, size, -1); + for (auto& device_name : gResource.device_names()) { + for (auto& interp : cuda_interpolations) { + Value cfg{{"type", "Resize"}, + {"size", {size, -1}}, + {"keep_ratio", false}, + {"interpolation", interp}}; + TestResize(cfg, device_name, bgra_mat, dst_height, dst_width); + } + } + } +} diff --git a/tests/test_csrc/preprocess/transform/test_utils.cpp b/tests/test_csrc/preprocess/test_utils.cpp similarity index 100% rename from tests/test_csrc/preprocess/transform/test_utils.cpp rename to tests/test_csrc/preprocess/test_utils.cpp diff --git a/tests/test_csrc/preprocess/transform/test_utils.h b/tests/test_csrc/preprocess/test_utils.h similarity index 100% rename from tests/test_csrc/preprocess/transform/test_utils.h rename to tests/test_csrc/preprocess/test_utils.h diff --git a/tests/test_csrc/preprocess/transform/test_compose.cpp b/tests/test_csrc/preprocess/transform/test_compose.cpp deleted file mode 100644 index 913b3f659c..0000000000 --- a/tests/test_csrc/preprocess/transform/test_compose.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include - -// clang-format off -#include "catch.hpp" -// clang-format on - -#include "archive/json_archive.h" -#include "core/mat.h" -#include "core/registry.h" -#include "core/utils/formatter.h" -#include "json.hpp" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace mmdeploy::test; -using namespace std; -using nlohmann::json; - -void TestCpuCompose(const Value& cfg, const cv::Mat& mat) { - Device device{"cpu"}; - Stream stream{device}; - - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); -} - -void TestCudaCompose(const Value& cfg, const cv::Mat& mat) { - Device device{"cuda"}; - Stream stream{device}; - - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); -} - -TEST_CASE("compose", "[compose]") { - const char* img_path = "../../tests/data/images/ocr.jpg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - auto src_mat = cpu::CVMat2Mat(bgr_mat, PixelFormat::kBGR); - Value input{{"ori_img", src_mat}}; - - auto config_path{"../../config/text-detector/dbnet18_t4-cuda11.1-trt7.2-fp16/pipeline.json"}; - ifstream ifs(config_path); - std::string config(istreambuf_iterator{ifs}, istreambuf_iterator{}); - auto json = json::parse(config); - auto transform_json = json["pipeline"]["tasks"][0]["transforms"]; - auto cfg = ::mmdeploy::from_json(transform_json); - Value compose_cfg{{"type", "Compose"}, {"transforms", cfg}}; - INFO("cfg: {}", compose_cfg); - - Device cpu_device{"cpu"}; - Stream cpu_stream{cpu_device}; - - auto cpu_transform = CreateTransform(compose_cfg, cpu_device, cpu_stream); - REQUIRE(cpu_transform != nullptr); - - auto cpu_result = cpu_transform->Process({{"ori_img", src_mat}}); - REQUIRE(!cpu_result.has_error()); - - auto _cpu_result = cpu_result.value(); - auto cpu_tensor = _cpu_result["img"].get(); - INFO("cpu_tensor.shape: {}", cpu_tensor.shape()); - - cpu_tensor.Reshape( - {cpu_tensor.shape(0), cpu_tensor.shape(2), cpu_tensor.shape(3), cpu_tensor.shape(1)}); - auto ref_mat = mmdeploy::cpu::Tensor2CVMat(cpu_tensor); - INFO("ref_mat, h:{}, w:{}, c:{}", ref_mat.rows, ref_mat.cols, ref_mat.channels()); - - Device cuda_device{"cuda"}; - Stream cuda_stream{cuda_device}; - auto gpu_transform = CreateTransform(compose_cfg, cuda_device, cuda_stream); - REQUIRE(gpu_transform != nullptr); - auto gpu_result = gpu_transform->Process({{"ori_img", src_mat}}); - REQUIRE(!gpu_result.has_error()); - auto _gpu_result = gpu_result.value(); - auto gpu_tensor = _gpu_result["img"].get(); - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(gpu_tensor, _device, cuda_stream).value(); - REQUIRE(cuda_stream.Wait()); - INFO("host_tensor.shape: {}", host_tensor.shape()); - host_tensor.Reshape( - {host_tensor.shape(0), host_tensor.shape(2), host_tensor.shape(3), host_tensor.shape(1)}); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor); - INFO("res_mat, h:{}, w:{}, c:{}", res_mat.rows, res_mat.cols, res_mat.channels()); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); -} diff --git a/tests/test_csrc/preprocess/transform/test_crop.cpp b/tests/test_csrc/preprocess/transform/test_crop.cpp deleted file mode 100644 index 6e54b5cd00..0000000000 --- a/tests/test_csrc/preprocess/transform/test_crop.cpp +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "catch.hpp" -#include "core/mat.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace std; -using namespace mmdeploy::test; - -tuple CenterCropArea(const cv::Mat& mat, int crop_height, int crop_width) { - auto img_height = mat.rows; - auto img_width = mat.cols; - auto y1 = max(0, int(round((img_height - crop_height) / 2.))); - auto x1 = max(0, int(round((img_width - crop_width) / 2.))); - auto y2 = min(img_height, y1 + crop_height) - 1; - auto x2 = min(img_width, x1 + crop_width) - 1; - return {y1, x1, y2, x2}; -} - -void TestCpuCenterCrop(const Value& cfg, const cv::Mat& mat, int crop_height, int crop_width) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto [top, left, bottom, right] = CenterCropArea(mat, crop_height, crop_width); - auto ref_mat = mmdeploy::cpu::Crop(mat, top, left, bottom, right); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res.value()["img"].get()); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); -} - -void TestCudaCenterCrop(const Value& cfg, const cv::Mat& mat, int crop_height, int crop_width) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - if (transform == nullptr) { - return; - } - - auto [top, left, bottom, right] = CenterCropArea(mat, crop_height, crop_width); - auto ref_mat = mmdeploy::cpu::Crop(mat, top, left, bottom, right); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); - // cv::imwrite("ref.jpg",ref_mat); - // cv::imwrite("res.jpg", res_mat); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); -} - -TEST_CASE("test transform crop (cpu) process", "[crop]") { - std::string transform_type("CenterCrop"); - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - - SECTION("crop_size: int; small size") { - constexpr int crop_size = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_size, crop_size); - } - } - - SECTION("crop_size: int; oversize") { - constexpr int crop_size = 800; - Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_size, crop_size); - } - } - - SECTION("crop_size: tuple") { - constexpr int crop_height = 224; - constexpr int crop_width = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_height, crop_width); - } - } - - SECTION("crop_size: tuple;oversize in height") { - constexpr int crop_height = 640; - constexpr int crop_width = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_height, crop_width); - } - } - - SECTION("crop_size: tuple;oversize in width") { - constexpr int crop_height = 224; - constexpr int crop_width = 800; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_height, crop_width); - } - } -} - -TEST_CASE("test transform crop (gpu) process", "[crop]") { - std::string transform_type("CenterCrop"); - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - - SECTION("crop_size: int; small size") { - constexpr int crop_size = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; - for (auto& mat : mats) { - TestCudaCenterCrop(cfg, mat, crop_size, crop_size); - } - } - - SECTION("crop_size: int; oversize") { - constexpr int crop_size = 800; - Value cfg{{"type", "CenterCrop"}, {"crop_size", crop_size}}; - for (auto& mat : mats) { - TestCudaCenterCrop(cfg, mat, crop_size, crop_size); - } - } - - SECTION("crop_size: tuple") { - constexpr int crop_height = 224; - constexpr int crop_width = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCudaCenterCrop(cfg, mat, crop_height, crop_width); - } - } - - SECTION("crop_size: tuple;oversize in height") { - constexpr int crop_height = 640; - constexpr int crop_width = 224; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCpuCenterCrop(cfg, mat, crop_height, crop_width); - } - } - - SECTION("crop_size: tuple;oversize in width") { - constexpr int crop_height = 224; - constexpr int crop_width = 800; - Value cfg{{"type", "CenterCrop"}, {"crop_size", {crop_height, crop_width}}}; - for (auto& mat : mats) { - TestCudaCenterCrop(cfg, mat, crop_height, crop_width); - } - } -} diff --git a/tests/test_csrc/preprocess/transform/test_image2tensor.cpp b/tests/test_csrc/preprocess/transform/test_image2tensor.cpp deleted file mode 100644 index 2d429d1618..0000000000 --- a/tests/test_csrc/preprocess/transform/test_image2tensor.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. -#include "catch.hpp" -#include "core/tensor.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace mmdeploy::test; -using namespace std; - -void TestCpuImage2Tensor(const Value& cfg, const cv::Mat& mat) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - vector channel_mats(mat.channels()); - for (auto i = 0; i < mat.channels(); ++i) { - cv::extractChannel(mat, channel_mats[i], i); - } - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto shape = res_tensor.desc().shape; - REQUIRE(shape == std::vector{1, mat.channels(), mat.rows, mat.cols}); - - // mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w} - // compare each channel between `res_tensor` and `mat` - auto step = shape[2] * shape[3] * mat.elemSize1(); - uint8_t* data = res_tensor.data(); - for (auto i = 0; i < mat.channels(); ++i) { - cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(mat.depth(), 1), data}; - REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat)); - data += step; - } -} - -void TestCudaImage2Tensor(const Value& cfg, const cv::Mat& mat) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - vector channel_mats(mat.channels()); - for (auto i = 0; i < mat.channels(); ++i) { - cv::extractChannel(mat, channel_mats[i], i); - } - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - - auto shape = host_tensor.value().shape(); - REQUIRE(shape == std::vector{1, mat.channels(), mat.rows, mat.cols}); - - // mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w} - // compare each channel between `res_tensor` and `mat` - auto step = shape[2] * shape[3] * mat.elemSize1(); - uint8_t* data = host_tensor.value().data(); - for (auto i = 0; i < mat.channels(); ++i) { - cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(mat.depth(), 1), data}; - REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat)); - data += step; - } -} - -TEST_CASE("test cpu ImageToTensor", "[img2tensor]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - Value cfg{{"type", "ImageToTensor"}, {"keys", {"img"}}}; - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - for (auto& mat : mats) { - TestCpuImage2Tensor(cfg, mat); - } -} - -TEST_CASE("test gpu ImageToTensor", "[img2tensor]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - Value cfg{{"type", "ImageToTensor"}, {"keys", {"img"}}}; - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - for (auto& mat : mats) { - TestCudaImage2Tensor(cfg, mat); - } -} diff --git a/tests/test_csrc/preprocess/transform/test_load.cpp b/tests/test_csrc/preprocess/transform/test_load.cpp deleted file mode 100644 index 6cf1d9b98d..0000000000 --- a/tests/test_csrc/preprocess/transform/test_load.cpp +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "catch.hpp" -#include "core/mat.h" -#include "core/tensor.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace std; -using namespace mmdeploy::test; - -void TestCpuLoad(const Value& cfg, const cv::Mat& mat, PixelFormat src_format, - PixelFormat dst_format) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto ref_mat = mmdeploy::cpu::ColorTransfer(mat, src_format, dst_format); - - auto res = transform->Process({{"ori_img", cpu::CVMat2Mat(mat, PixelFormat(src_format))}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - cv::imwrite("ref.bmp", ref_mat); - cv::imwrite("res.bmp", res_mat); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "ori_shape") == - vector{1, mat.rows, mat.cols, mat.channels()}); - REQUIRE(res.value().contains("img_fields")); - REQUIRE(res.value()["img_fields"].is_array()); - REQUIRE(res.value()["img_fields"].size() == 1); - REQUIRE(res.value()["img_fields"][0].get() == "img"); -} - -void TestCudaLoad(const Value& cfg, const cv::Mat& mat, PixelFormat src_format, - PixelFormat dst_format) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto ref_mat = mmdeploy::cpu::ColorTransfer(mat, src_format, dst_format); - - auto src_mat = cpu::CVMat2Mat(mat, PixelFormat(src_format)); - auto res = transform->Process({{"ori_img", src_mat}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "ori_shape") == - vector{1, mat.rows, mat.cols, mat.channels()}); - REQUIRE(res.value().contains("img_fields")); - REQUIRE(res.value()["img_fields"].is_array()); - REQUIRE(res.value()["img_fields"].size() == 1); - REQUIRE(res.value()["img_fields"][0].get() == "img"); -} - -TEST_CASE("prepare image, that is LoadImageFromFile transform", "[load]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat rgb_mat; - cv::Mat bgra_mat; - // TODO(lvhan): make up yuv nv12/nv21 mat - // cv::Mat nv12_mat; - // cv::Mat nv21_mat; - - cv::cvtColor(bgr_mat, rgb_mat, cv::COLOR_BGR2RGB); - cv::cvtColor(bgr_mat, bgra_mat, cv::COLOR_BGR2BGRA); - - vector> mats{{bgr_mat, PixelFormat::kBGR}, - {rgb_mat, PixelFormat::kRGB}, - {gray_mat, PixelFormat::kGRAYSCALE}, - {bgra_mat, PixelFormat::kBGRA}}; - // pair is - vector> conditions{ - {"color", true}, {"color", false}, {"grayscale", true}, {"grayscale", false}}; - - for (auto& condition : conditions) { - Value cfg{{"type", "LoadImageFromFile"}, - {"to_float32", condition.second}, - {"color_type", condition.first}}; - for (auto& mat : mats) { - TestCpuLoad(cfg, mat.first, mat.second, - condition.first == "color" ? PixelFormat::kBGR : PixelFormat::kGRAYSCALE); - TestCudaLoad(cfg, mat.first, mat.second, - condition.first == "color" ? PixelFormat::kBGR : PixelFormat::kGRAYSCALE); - } - } -} diff --git a/tests/test_csrc/preprocess/transform/test_normalize.cpp b/tests/test_csrc/preprocess/transform/test_normalize.cpp deleted file mode 100644 index 642272e33f..0000000000 --- a/tests/test_csrc/preprocess/transform/test_normalize.cpp +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "catch.hpp" -#include "core/mat.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace mmdeploy::test; -using namespace std; - -void TestCpuNormalize(const Value& cfg, const cv::Mat& mat) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - vector mean; - vector std; - for (auto& v : cfg["mean"]) { - mean.push_back(v.get()); - } - for (auto& v : cfg["std"]) { - std.push_back(v.get()); - } - bool to_rgb = cfg.value("to_rgb", false); - - auto _mat = mat.clone(); - auto ref_mat = mmdeploy::cpu::Normalize(_mat, mean, std, to_rgb); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - - REQUIRE(res_tensor.desc().data_type == DataType::kFLOAT); - REQUIRE(ImageNormCfg(res.value(), "mean") == mean); - REQUIRE(ImageNormCfg(res.value(), "std") == std); -} - -void TestCudaNormalize(const Value& cfg, const cv::Mat& mat) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - vector mean; - vector std; - for (auto& v : cfg["mean"]) { - mean.push_back(v.get()); - } - for (auto& v : cfg["std"]) { - std.push_back(v.get()); - } - bool to_rgb = cfg.value("to_rgb", false); - - auto _mat = mat.clone(); - auto ref_mat = mmdeploy::cpu::Normalize(_mat, mean, std, to_rgb); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); - - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(res_tensor.desc().data_type == DataType::kFLOAT); - REQUIRE(ImageNormCfg(res.value(), "mean") == mean); - REQUIRE(ImageNormCfg(res.value(), "std") == std); -} - -TEST_CASE("cpu normalize", "[normalize]") { - cv::Mat bgr_mat = cv::imread("../../tests/preprocess/data/test.jpg"); - cv::Mat gray_mat; - cv::Mat float_bgr_mat; - cv::Mat float_gray_mat; - - cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); - bgr_mat.convertTo(float_bgr_mat, CV_32FC3); - gray_mat.convertTo(float_gray_mat, CV_32FC1); - - SECTION("cpu vs gpu: 3 channel mat") { - bool to_rgb = true; - Value cfg{{"type", "Normalize"}, - {"mean", {123.675, 116.28, 103.53}}, - {"std", {58.395, 57.12, 57.375}}, - {"to_rgb", to_rgb}}; - vector mats{bgr_mat, float_bgr_mat}; - for (auto& mat : mats) { - TestCpuNormalize(cfg, mat); - } - } - - SECTION("cpu vs gpu: 3 channel mat, to_rgb false") { - bool to_rgb = false; - Value cfg{{"type", "Normalize"}, - {"mean", {123.675, 116.28, 103.53}}, - {"std", {58.395, 57.12, 57.375}}, - {"to_rgb", to_rgb}}; - - vector mats{bgr_mat, float_bgr_mat}; - for (auto& mat : mats) { - TestCpuNormalize(cfg, mat); - } - } - - SECTION("cpu vs gpu: 1 channel mat") { - bool to_rgb = true; - Value cfg{{"type", "Normalize"}, {"mean", {123.675}}, {"std", {58.395}}, {"to_rgb", to_rgb}}; - - vector mats{gray_mat, float_gray_mat}; - for (auto& mat : mats) { - TestCpuNormalize(cfg, mat); - } - } -} - -TEST_CASE("gpu normalize", "[normalize]") { - cv::Mat bgr_mat = cv::imread("../../tests/preprocess/data/test.jpg"); - cv::Mat gray_mat; - cv::Mat float_bgr_mat; - cv::Mat float_gray_mat; - - cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); - bgr_mat.convertTo(float_bgr_mat, CV_32FC3); - gray_mat.convertTo(float_gray_mat, CV_32FC1); - - SECTION("cpu vs gpu: 3 channel mat") { - bool to_rgb = true; - Value cfg{{"type", "Normalize"}, - {"mean", {123.675, 116.28, 103.53}}, - {"std", {58.395, 57.12, 57.375}}, - {"to_rgb", to_rgb}}; - vector mats{bgr_mat, float_bgr_mat}; - for (auto& mat : mats) { - TestCudaNormalize(cfg, mat); - } - } - - SECTION("cpu vs gpu: 3 channel mat, to_rgb false") { - bool to_rgb = false; - Value cfg{{"type", "Normalize"}, - {"mean", {123.675, 116.28, 103.53}}, - {"std", {58.395, 57.12, 57.375}}, - {"to_rgb", to_rgb}}; - - vector mats{bgr_mat, float_bgr_mat}; - for (auto& mat : mats) { - TestCudaNormalize(cfg, mat); - } - } - - SECTION("cpu vs gpu: 1 channel mat") { - bool to_rgb = true; - Value cfg{{"type", "Normalize"}, {"mean", {123.675}}, {"std", {58.395}}, {"to_rgb", to_rgb}}; - - vector mats{gray_mat, float_gray_mat}; - for (auto& mat : mats) { - TestCudaNormalize(cfg, mat); - } - } -} diff --git a/tests/test_csrc/preprocess/transform/test_pad.cpp b/tests/test_csrc/preprocess/transform/test_pad.cpp deleted file mode 100644 index 8d83af354c..0000000000 --- a/tests/test_csrc/preprocess/transform/test_pad.cpp +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "catch.hpp" -#include "core/mat.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace std; -using namespace mmdeploy::test; - -// left, top, right, bottom -tuple GetPadSize(const cv::Mat& mat, int dst_height, int dst_width) { - return {0, 0, dst_width - mat.cols, dst_height - mat.rows}; -} - -tuple GetPadSize(const cv::Mat& mat, bool square = true) { - int size = std::max(mat.rows, mat.cols); - return GetPadSize(mat, size, size); -} - -tuple GetPadSize(const cv::Mat& mat, int divisor) { - auto pad_h = int(ceil(mat.rows * 1.0 / divisor)) * divisor; - auto pad_w = int(ceil(mat.cols * 1.0 / divisor)) * divisor; - return GetPadSize(mat, pad_h, pad_w); -} - -void TestCpuPad(const Value& cfg, const cv::Mat& mat, int top, int left, int bottom, int right, - int border_type, float val) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto ref_mat = mmdeploy::cpu::Pad(mat, top, left, bottom, right, border_type, val); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "pad_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "pad_fixed_size") == std::vector{ref_mat.rows, ref_mat.cols}); -} - -void TestCudaPad(const Value& cfg, const cv::Mat& mat, int top, int left, int bottom, int right, - int border_type, float val) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto ref_mat = mmdeploy::cpu::Pad(mat, top, left, bottom, right, border_type, val); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "pad_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "pad_fixed_size") == std::vector{ref_mat.rows, ref_mat.cols}); -} - -TEST_CASE("cpu Pad", "[pad]") { - auto img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat; - cv::Mat float_bgr_mat; - cv::Mat float_gray_mat; - cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); - bgr_mat.convertTo(float_bgr_mat, CV_32FC3); - gray_mat.convertTo(float_gray_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, float_bgr_mat, float_gray_mat}; - vector modes{"constant", "edge", "reflect", "symmetric"}; - map border_map{{"constant", cv::BORDER_CONSTANT}, - {"edge", cv::BORDER_REPLICATE}, - {"reflect", cv::BORDER_REFLECT_101}, - {"symmetric", cv::BORDER_REFLECT}}; - SECTION("pad to square") { - bool square{true}; - float val = 255.0f; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{ - {"type", "Pad"}, {"pad_to_square", square}, {"padding_mode", mode}, {"pad_val", val}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, square); - TestCpuPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); - } - } - } - - SECTION("pad with size_divisor") { - constexpr int divisor = 32; - float val = 255.0f; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{ - {"type", "Pad"}, {"size_divisor", divisor}, {"padding_mode", mode}, {"pad_val", val}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, divisor); - TestCpuPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); - } - } - } - - SECTION("pad with size") { - constexpr int height = 600; - constexpr int width = 800; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{{"type", "Pad"}, {"size", {height, width}}, {"padding_mode", mode}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, height, width); - TestCpuPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 0); - } - } - } -} - -TEST_CASE("gpu Pad", "[pad]") { - auto img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat; - cv::Mat float_bgr_mat; - cv::Mat float_gray_mat; - cv::cvtColor(bgr_mat, gray_mat, cv::COLOR_BGR2GRAY); - bgr_mat.convertTo(float_bgr_mat, CV_32FC3); - gray_mat.convertTo(float_gray_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, float_bgr_mat, float_gray_mat}; - vector modes{"constant", "edge", "reflect", "symmetric"}; - map border_map{{"constant", cv::BORDER_CONSTANT}, - {"edge", cv::BORDER_REPLICATE}, - {"reflect", cv::BORDER_REFLECT_101}, - {"symmetric", cv::BORDER_REFLECT}}; - SECTION("pad to square") { - bool square{true}; - float val = 255.0f; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{ - {"type", "Pad"}, {"pad_to_square", square}, {"padding_mode", mode}, {"pad_val", val}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, square); - TestCudaPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); - } - } - } - - SECTION("pad with size_divisor") { - constexpr int divisor = 32; - float val = 255.0f; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{ - {"type", "Pad"}, {"size_divisor", divisor}, {"padding_mode", mode}, {"pad_val", val}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, divisor); - TestCudaPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 255); - } - } - } - - SECTION("pad with size") { - constexpr int height = 600; - constexpr int width = 800; - for (auto& mat : mats) { - for (auto& mode : modes) { - Value cfg{{"type", "Pad"}, {"size", {height, width}}, {"padding_mode", mode}}; - auto [pad_left, pad_top, pad_right, pad_bottom] = GetPadSize(mat, height, width); - TestCudaPad(cfg, mat, pad_top, pad_left, pad_bottom, pad_right, border_map[mode], 0); - } - } - } -} diff --git a/tests/test_csrc/preprocess/transform/test_resize.cpp b/tests/test_csrc/preprocess/transform/test_resize.cpp deleted file mode 100644 index 02d2d0755a..0000000000 --- a/tests/test_csrc/preprocess/transform/test_resize.cpp +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "catch.hpp" -#include "core/mat.h" -#include "preprocess/cpu/opencv_utils.h" -#include "preprocess/transform/transform.h" -#include "preprocess/transform/transform_utils.h" -#include "test_utils.h" - -using namespace mmdeploy; -using namespace std; -using namespace mmdeploy::test; - -// return {target_height, target_width} -tuple GetTargetSize(const cv::Mat& src, int size0, int size1) { - assert(size0 > 0); - if (size1 > 0) { - return {size0, size1}; - } else { - if (src.rows < src.cols) { - return {size0, size0 * src.cols / src.rows}; - } else { - return {size0 * src.rows / src.cols, size0}; - } - } -} - -// return {target_height, target_width} -tuple GetTargetSize(const cv::Mat& src, int scale0, int scale1, bool keep_ratio) { - auto w = src.cols; - auto h = src.rows; - auto max_long_edge = max(scale0, scale1); - auto max_short_edge = min(scale0, scale1); - if (keep_ratio) { - auto scale_factor = - std::min(max_long_edge * 1.0 / std::max(h, w), max_short_edge * 1.0 / std::min(h, w)); - return {int(h * scale_factor + 0.5f), int(w * scale_factor + 0.5f)}; - } else { - return {scale0, scale1}; - } -} - -void TestCpuResize(const Value& cfg, const cv::Mat& mat, int dst_height, int dst_width) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto interpolation = cfg["interpolation"].get(); - auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); -} - -void TestCudaResize(const Value& cfg, const cv::Mat& mat, int dst_height, int dst_width) { - Device device{"cuda"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto interpolation = cfg["interpolation"].get(); - auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); - - auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - REQUIRE(res_tensor.device().is_device()); - - Device _device{"cpu"}; - auto host_tensor = MakeAvailableOnDevice(res_tensor, _device, stream); - REQUIRE(stream.Wait()); - - auto res_mat = mmdeploy::cpu::Tensor2CVMat(host_tensor.value()); - cv::imwrite("ref.bmp", ref_mat); - cv::imwrite("res.bmp", res_mat); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); -} - -void TestCpuResizeWithScale(const Value& cfg, const cv::Mat& mat, int scale0, int scale1, - bool keep_ratio) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto [dst_height, dst_width] = GetTargetSize(mat, scale0, scale1, keep_ratio); - auto interpolation = cfg["interpolation"].get(); - auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); - - Value input{{"img", cpu::CVMat2Tensor(mat)}, {"scale", {scale0, scale1}}}; - auto res = transform->Process(input); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); -} - -void TestCpuResizeWithScaleFactor(const Value& cfg, const cv::Mat& mat, float scale_factor) { - Device device{"cpu"}; - Stream stream{device}; - auto transform = CreateTransform(cfg, device, stream); - REQUIRE(transform != nullptr); - - auto [dst_height, dst_width] = make_tuple(mat.rows * scale_factor, mat.cols * scale_factor); - auto interpolation = cfg["interpolation"].get(); - auto ref_mat = mmdeploy::cpu::Resize(mat, dst_height, dst_width, interpolation); - - Value input{{"img", cpu::CVMat2Tensor(mat)}, {"scale_factor", scale_factor}}; - auto res = transform->Process(input); - REQUIRE(!res.has_error()); - auto res_tensor = res.value()["img"].get(); - auto res_mat = mmdeploy::cpu::Tensor2CVMat(res_tensor); - // cv::imwrite("ref.bmp", ref_mat); - // cv::imwrite("res.bmp", res_mat); - REQUIRE(Shape(res.value(), "img_shape") == - vector{1, ref_mat.rows, ref_mat.cols, ref_mat.channels()}); - REQUIRE(Shape(res.value(), "img_shape") == res_tensor.desc().shape); - REQUIRE(mmdeploy::cpu::Compare(ref_mat, res_mat)); -} - -TEST_CASE("mmclassification resize(cpu)", "[resize]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - vector interpolations{"bilinear", "nearest", "area", "bicubic", "lanczos"}; - SECTION("size: {256, -1}") { - constexpr int size = 256; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {size, -1}}, - {"keep_ratio", false}, - {"interpolation", interp}}; - auto [dst_height, dst_width] = GetTargetSize(mat, size, -1); - TestCpuResize(cfg, mat, dst_height, dst_width); - } - } - } - - SECTION("size: {300, -1}. It shouldn't be resized") { - constexpr int size = 300; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {size, -1}}, - {"keep_ratio", false}, - {"interpolation", interp}}; - auto [dst_height, dst_width] = GetTargetSize(mat, size, -1); - TestCpuResize(cfg, mat, dst_height, dst_width); - } - } - } - - SECTION("size: 384") { - constexpr int size = 384; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{ - {"type", "Resize"}, {"size", size}, {"keep_ratio", false}, {"interpolation", interp}}; - TestCpuResize(cfg, mat, size, size); - } - } - } -} - -TEST_CASE("mmclassification resize(gpu)", "[resize]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - vector interpolations{"bilinear", "nearest"}; - - SECTION("size: {256, -1}") { - constexpr int size = 256; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {size, -1}}, - {"keep_ratio", false}, - {"interpolation", interp}}; - auto [dst_height, dst_width] = GetTargetSize(mat, size, -1); - TestCudaResize(cfg, mat, dst_height, dst_width); - } - } - } - - SECTION("size: 384") { - constexpr int size = 384; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{ - {"type", "Resize"}, {"size", size}, {"keep_ratio", false}, {"interpolation", interp}}; - TestCudaResize(cfg, mat, size, size); - } - } - } -} - -TEST_CASE("mmdetection resize (cpu)", "[resize]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); - cv::Mat bgr_float_mat; - cv::Mat gray_float_mat; - bgr_mat.convertTo(bgr_float_mat, CV_32FC3); - gray_mat.convertTo(gray_float_mat, CV_32FC1); - - vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; - vector interpolations{"bilinear", "nearest", "area", "bicubic", "lanczos"}; - SECTION("img_scale: [1333, 800]. keep_ratio: true") { - constexpr int max_long_edge = 1333; - constexpr int max_short_edge = 800; - bool keep_ratio = true; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {max_long_edge, max_short_edge}}, - {"keep_ratio", keep_ratio}, - {"interpolation", interp}}; - auto [dst_height, dst_width] = - GetTargetSize(mat, max_long_edge, max_short_edge, keep_ratio); - TestCpuResize(cfg, mat, dst_height, dst_width); - } - } - } - - SECTION("img_scale: [1333, 800]. keep_ratio: false") { - constexpr int dst_height = 800; - constexpr int dst_width = 1333; - bool keep_ratio = false; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {dst_height, dst_width}}, - {"keep_ratio", keep_ratio}, - {"interpolation", interp}}; - - TestCpuResize(cfg, mat, dst_height, dst_width); - } - } - } - - SECTION("scale: [1333, 800]") { - constexpr int max_long_edge = 1333; - constexpr int max_short_edge = 800; - bool keep_ratio = true; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {max_long_edge, max_short_edge}}, - {"keep_ratio", keep_ratio}, - {"interpolation", interp}}; - TestCpuResizeWithScale(cfg, mat, max_long_edge, max_short_edge, keep_ratio); - } - } - } - - SECTION("scale: [1333, 800]") { - constexpr int max_long_edge = 1333; - constexpr int max_short_edge = 800; - bool keep_ratio = false; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {max_long_edge, max_short_edge}}, - {"keep_ratio", keep_ratio}, - {"interpolation", interp}}; - TestCpuResizeWithScale(cfg, mat, max_long_edge, max_short_edge, keep_ratio); - } - } - } - - SECTION("scale_factor: 0.5") { - float scale_factor = 0.5; - bool keep_ratio = true; - for (auto& mat : mats) { - for (auto& interp : interpolations) { - Value cfg{{"type", "Resize"}, - {"size", {600, 800}}, - {"keep_ratio", keep_ratio}, - {"interpolation", interp}}; - TestCpuResizeWithScaleFactor(cfg, mat, scale_factor); - } - } - } -} - -TEST_CASE("4 channel resize", "[resize]") { - const char* img_path = "../../tests/preprocess/data/imagenet_banner.jpeg"; - cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); - cv::Mat bgra_mat; - cv::cvtColor(bgr_mat, bgra_mat, cv::COLOR_BGR2BGRA); - - Value cfg{ - {"type", "Resize"}, {"size", {256, -1}}, {"keep_ratio", false}, {"interpolation", "nearest"}}; - auto [dst_height, dst_width] = GetTargetSize(bgra_mat, 256, -1); - TestCpuResize(cfg, bgra_mat, dst_height, dst_width); - TestCudaResize(cfg, bgra_mat, dst_height, dst_width); -} diff --git a/tests/test_csrc/test_resource.h b/tests/test_csrc/test_resource.h new file mode 100644 index 0000000000..11fbd034e2 --- /dev/null +++ b/tests/test_csrc/test_resource.h @@ -0,0 +1,149 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef MMDEPLOY_TEST_RESOURCE_H +#define MMDEPLOY_TEST_RESOURCE_H +#include +#include +#include +#include +#include +#include + +#include "test_define.h" + +#if __GNUC__ >= 8 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif + +using namespace std; + +class MMDeployTestResources { + public: + static MMDeployTestResources &Get() { + static MMDeployTestResources resource; + return resource; + } + + const std::vector &device_names() const { return devices_; } + const std::vector &device_names(const std::string &backend) const { + return backend_devices_.at(backend); + } + const std::vector &backends() const { return backends_; } + const std::vector &codebases() const { return codebases_; } + const std::string &resource_root_path() const { return resource_root_path_; } + + bool HasDevice(const std::string &name) const { + return std::any_of(devices_.begin(), devices_.end(), + [&](const std::string &device_name) { return device_name == name; }); + } + + bool IsDir(const std::string &dir_name) const { + fs::path path{resource_root_path_ + "/" + dir_name}; + return fs::is_directory(path); + } + + bool IsFile(const std::string &file_name) const { + fs::path path{resource_root_path_ + "/" + file_name}; + return fs::is_regular_file(path); + } + + public: + std::vector LocateModelResources(const std::string &sdk_model_zoo_dir) { + std::vector sdk_model_list; + if (resource_root_path_.empty()) { + return sdk_model_list; + } + + fs::path path{resource_root_path_ + "/" + sdk_model_zoo_dir}; + if (!fs::is_directory(path)) { + return sdk_model_list; + } + for (auto const &dir_entry : fs::directory_iterator{path}) { + fs::directory_entry entry{dir_entry.path()}; + if (auto const &_path = dir_entry.path(); fs::is_directory(_path)) { + sdk_model_list.push_back(dir_entry.path()); + } + } + return sdk_model_list; + } + + std::vector LocateImageResources(const std::string &img_dir) { + std::vector img_list; + + if (resource_root_path_.empty()) { + return img_list; + } + + fs::path path{resource_root_path_ + "/" + img_dir}; + if (!fs::is_directory(path)) { + return img_list; + } + + set extensions{".png", ".jpg", ".jpeg", ".bmp"}; + for (auto const &dir_entry : fs::directory_iterator{path}) { + if (!fs::is_regular_file(dir_entry.path())) { + std::cout << dir_entry.path().string() << std::endl; + continue; + } + auto const &_path = dir_entry.path(); + auto ext = _path.extension().string(); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + if (extensions.find(ext) != extensions.end()) { + img_list.push_back(_path.string()); + } + } + return img_list; + } + + private: + MMDeployTestResources() { + devices_ = Split(kDevices); + backends_ = Split(kBackends); + codebases_ = Split(kCodebases); + backend_devices_["pplnn"] = {"cpu", "cuda"}; + backend_devices_["trt"] = {"cuda"}; + backend_devices_["ort"] = {"cpu"}; + backend_devices_["ncnn"] = {"cpu"}; + backend_devices_["openvino"] = {"cpu"}; + resource_root_path_ = LocateResourceRootPath(fs::current_path(), 8); + } + + static std::vector Split(const std::string &text, char delimiter = ';') { + std::vector result; + std::istringstream ss(text); + for (std::string word; std::getline(ss, word, delimiter);) { + result.emplace_back(word); + } + return result; + } + + std::string LocateResourceRootPath(const fs::path &cur_path, int max_depth) { + if (max_depth < 0) { + return ""; + } + for (auto const &dir_entry : fs::directory_iterator{cur_path}) { + fs::directory_entry entry{dir_entry.path()}; + auto const &_path = dir_entry.path(); + if (fs::is_directory(_path) && _path.filename() == "mmdeploy_test_resources") { + return _path.string(); + } + } + // Didn't find 'mmdeploy_test_resources' in current directory. + // Move to its parent directory and keep looking for it + return LocateResourceRootPath(cur_path.parent_path(), max_depth - 1); + } + + private: + std::vector devices_; + std::vector backends_; + std::vector codebases_; + std::map> backend_devices_; + std::string resource_root_path_; +}; + +#endif // MMDEPLOY_TEST_RESOURCE_H