Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support 3 types model #1354

Merged
merged 6 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions paddle_inference/paddle/include/paddle_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <dirent.h>
#include <pthread.h>
#include <fstream>
#include <map>
Expand Down Expand Up @@ -69,6 +70,30 @@ PrecisionType GetPrecision(const std::string& precision_data) {
return PrecisionType::kFloat32;
}

// Scan `path` (non-recursively) for the first regular file whose name ends
// with one of the suffixes in `suffixVector`.
//
// Returns the bare file name (not the full path) of the first match, in
// directory order, or an empty string when the directory cannot be opened
// or no entry matches.
//
// NOTE: matching is a true suffix comparison. The previous implementation
// used std::string::find, which matched the pattern anywhere in the name
// and could wrongly pick auxiliary files such as "model.pdiparams.info"
// when looking for ".pdiparams".
const std::string getFileBySuffix(
    const std::string& path, const std::vector<std::string>& suffixVector) {
  DIR* dp = opendir(path.c_str());
  if (dp == nullptr) {
    return "";
  }
  std::string fileName;
  struct dirent* dirp = nullptr;
  while (fileName.empty() && (dirp = readdir(dp)) != nullptr) {
    // Only consider regular files; skips ".", "..", subdirectories, etc.
    // (d_type may be DT_UNKNOWN on some filesystems; such entries are
    // skipped here, matching the original behavior.)
    if (dirp->d_type != DT_REG) continue;
    const std::string name(dirp->d_name);
    for (const std::string& suffix : suffixVector) {
      if (name.size() >= suffix.size() &&
          name.compare(name.size() - suffix.size(), suffix.size(), suffix) ==
              0) {
        fileName = name;
        break;
      }
    }
  }
  closedir(dp);
  return fileName;
}

// Engine Base
class EngineCore {
public:
Expand Down Expand Up @@ -131,9 +156,21 @@ class PaddleInferenceEngine : public EngineCore {
}

Config config;
// todo, auto config(zhangjun)
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
std::vector<std::string> suffixParaVector = {".pdiparams", "__params__"};
std::vector<std::string> suffixModelVector = {".pdmodel", "__model__"};
std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
std::string modelFileName = getFileBySuffix(model_path, suffixModelVector);

std::string encryParaPath = model_path + "/encrypt_model";
std::string encryModelPath = model_path + "/encrypt_params";
std::string encryKeyPath = model_path + "/key";

// encrypt model
if (access(encryParaPath.c_str(), F_OK) != -1 &&
access(encryModelPath.c_str(), F_OK) != -1 &&
access(encryKeyPath.c_str(), F_OK) != -1) {
// decrypt model

std::string model_buffer, params_buffer, key_buffer;
predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
Expand All @@ -147,16 +184,11 @@ class PaddleInferenceEngine : public EngineCore {
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
} else if (engine_conf.has_combined_model()) {
if (!engine_conf.combined_model()) {
config.SetModel(model_path);
} else {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
}
} else if (paraFileName.length() != 0 && modelFileName.length() != 0) {
config.SetParamsFile(model_path + "/" + paraFileName);
config.SetProgFile(model_path + "/" + modelFileName);
} else {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
config.SetModel(model_path);
}

config.SwitchSpecifyInputNames(true);
Expand Down
4 changes: 2 additions & 2 deletions python/paddle_serving_client/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,15 @@ def process_tensor(self, key, feed_dict, batch):
# 由于输入比较特殊,shape保持原feedvar中不变
data_value = []
data_value.append(feed_dict[key])
if isinstance(feed_dict[key], str):
if isinstance(feed_dict[key], (str, bytes)):
if self.feed_types_[key] != bytes_type:
raise ValueError(
"feedvar is not string-type,feed can`t be a single string."
)
else:
if self.feed_types_[key] == bytes_type:
raise ValueError(
"feedvar is string-type,feed, feed can`t be a single int or others."
"feedvar is string-type,feed can`t be a single int or others."
)
# 如果不压缩,那么不需要统计数据量。
if self.try_request_gzip:
Expand Down