You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
class SESR{
public:
SESR()=default;
int Init(const char* model_path);
std::vectorMNN::Express::VARP pre_process(const uint8_t * yuv_data, int width, int height, int channels);
std::vectorMNN::Express::VARP& processing(std::vectorMNN::Express::VARP& yuv_var);
uint8_t* post_process(std::vectorMNN::Express::VARP& yuv_var);
平台(如果交叉编译请再附上交叉编译目标平台):
Platform(Include target platform as well if cross-compiling):
macOS 11.7.10
代码如下:
C++
//
// Created by xxx on 19.9.24.
//
#include<cstdint>
#include<cstring>
#include<memory>
#include<string>
#include<thread>
#include<vector>
#include<MNN/Interpreter.hpp>
#include<MNN/ImageProcess.hpp>
#include<MNN/expr/Expr.hpp>
#include<MNN/expr/Executor.hpp>
#include<MNN/expr/ExprCreator.hpp>
#include<MNN/expr/Module.hpp>
#include<MNN/AutoTime.hpp>
class SESR{
public:
SESR()=default;
int Init(const char* model_path);
std::vectorMNN::Express::VARP pre_process(const uint8_t * yuv_data, int width, int height, int channels);
std::vectorMNN::Express::VARP& processing(std::vectorMNN::Express::VARP& yuv_var);
uint8_t* post_process(std::vectorMNN::Express::VARP& yuv_var);
private:
const float mean_vals[1] = {0.f};
const float normal_vals[1] = {1.f/255};
std::shared_ptrMNN::CV::ImageProcess pretreat;
std::shared_ptrMNN::Express::Module net = nullptr;
};
int SESR::Init(const char* model_path) {
net.reset(MNN::Express::Module::load(std::vectorstd::string{"img"}, std::vectorstd::string{"out"}, model_path));
if(nullptr == net){
MNN_ERROR("load model failed!\n");
}
MNN::ScheduleConfig sConfig;
// sConfig.type = MNN_FORWARD_CPU;
sConfig.type = MNN_FORWARD_OPENCL;
sConfig.backupType = MNN_FORWARD_VULKAN;
sConfig.numThread = std::thread::hardware_concurrency(); // cpu推理时的线程数
MNN::BackendConfig backendConfig;
backendConfig.memory = MNN::BackendConfig::Memory_Normal;
backendConfig.precision = MNN::BackendConfig::Precision_Normal;
sConfig.backendConfig = &backendConfig;
// std::shared_ptrMNN::Express::Executor::RuntimeManager rtmgr{
// MNN::Express::Executor::RuntimeManager::createRuntimeManager(sConfig)
// };
// if(nullptr == rtmgr){
// MNN_ERROR("Empty RuntimeManager\n");
// }
MNN::CV::ImageProcess::Config config;
config.sourceFormat = MNN::CV::YUV_I420; // 指定源格式YUV_I420
config.destFormat = MNN::CV::YUV_I420; // 转换成的目标格式YUV_I420
std::memcpy(config.mean, mean_vals, sizeof(mean_vals));
std::memcpy(config.normal, normal_vals, sizeof(normal_vals));
pretreat = std::shared_ptrMNN::CV::ImageProcess{MNN::CV::ImageProcess::create(config)};
}
std::vectorMNN::Express::VARP SESR::pre_process(const uint8_t* yuv_data, int width, int height, int channels) {
// MNN::Express::Variable::Info y_info{MNN::Express::NCHW, {1, channels, height, width},
// halide_type_of(), channelsheightwidth};
// MNN::Express::EXPRP y_expr = MNN::Express::Expr::create(std::move(y_info), yuv_data, MNN::Express::VARP::INPUT, MNN::Express::Expr::REF);
// MNN::Express::VARP y = MNN::Express::Variable::create(y_expr);
MNN::Express::VARP y = MNN::Express::_Input({1, channels, height, width}, MNN::Express::NCHW); // 模型输入格式NCHW, RGB
pretreat->convert(yuv_data, width, height, widthchannels, y->writeMap(),
width, height, channels, widthchannels, halide_type_of()); // y通道数据
}
std::vectorMNN::Express::VARP& SESR::processing(std::vectorMNN::Express::VARP& yuv_var) {
yuv_var[0] = net->onForward({yuv_var[0]})[0]; // 取出第一个输出
yuv_var[1] = MNN::Express::_Resize(yuv_var[1], 1.5, 1.5); // 对u通道数据插值缩放
yuv_var[2] = MNN::Express::_Resize(yuv_var[2], 1.5, 1.5); // 对v通道数据插值缩放
}
uint8_t* SESR::post_process(std::vectorMNN::Express::VARP& yuv_var) {
yuv_var[0] = MNN::Express::_Relu6(yuv_var[0], 0.0f, 1.0f); // 截断到[0, 255]
yuv_var[0] = yuv_var[0] * MNN::Express::_Scalar(255.0f); // 像素值缩放到[0, 255]之间
yuv_var[0] = MNN::Express::_Cast<uint8_t>(yuv_var[0]); // 像素值转换为整数
yuv_var[0] = MNN::Express::_Reshape(yuv_var[0], {-1});
yuv_var[1] = MNN::Express::_Reshape(yuv_var[1], {-1});
yuv_var[2] = MNN::Express::_Reshape(yuv_var[2], {-1});
}
/// Entry point: defines the input / output / model paths and the frame
/// geometry. The listing contains no further pipeline wiring.
int main(){
    // Own the strings directly instead of binding references to temporaries.
    const std::string file_path = "/Users/xxx/pycharm_projects/mnn_practice/720p_I420.yuv";
    const std::string dest_path = "/Users/xxx/pycharm_projects/mnn_practice/out_720p.yuv";
    const std::string model_path = "/Users/xxx/pycharm_projects/mnn_practice/sesr_c1.mnn";

    const uint32_t width = 1280;
    const uint32_t height = 720;
    const uint32_t channels = 1;

    // Size of one frame: width * height * channels * 3 / 2, i.e. 1.5 values
    // per pixel position. Computed in 64-bit integer arithmetic to avoid
    // floating point and intermediate overflow.
    const uint32_t num_pixels =
        static_cast<uint32_t>(static_cast<uint64_t>(width) * height * channels * 3 / 2);

    // The values above are not consumed yet; silence unused-variable warnings.
    (void)file_path;
    (void)dest_path;
    (void)model_path;
    (void)num_pixels;
    return 0;
}
报的异常截图如下:
模型文件:https://github.com/zhuzhu18/sesr/blob/main/sesr_c1.mnn
The text was updated successfully, but these errors were encountered: