Skip to content

Commit faac89e

Browse files
committed
Propagate plugin config handling to all MP based models
1 parent 2702650 commit faac89e

File tree

11 files changed

+107
-78
lines changed

11 files changed

+107
-78
lines changed

src/capi_frontend/server_settings.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
#include "../stringutils.hpp"
2121

2222
namespace ovms {
23+
EmbeddingsGraphSettingsImpl::EmbeddingsGraphSettingsImpl() :
24+
pluginConfig{std::nullopt, std::nullopt, std::nullopt, 1} {}
25+
RerankGraphSettingsImpl::RerankGraphSettingsImpl() :
26+
pluginConfig{std::nullopt, std::nullopt, std::nullopt, 1} {}
2327

2428
std::string enumToString(ConfigExportType type) {
2529
auto it = configExportTypeToString.find(type);

src/capi_frontend/server_settings.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,13 @@ struct PluginConfigSettingsImpl {
9191
std::optional<std::string> kvCachePrecision;
9292
std::optional<uint32_t> maxPromptLength;
9393
std::optional<std::string> modelDistributionPolicy;
94-
//std::optional<std::string> cacheDir;
94+
std::optional<uint32_t> numStreams;
95+
bool empty() const {
96+
return !kvCachePrecision.has_value() &&
97+
!maxPromptLength.has_value() &&
98+
!modelDistributionPolicy.has_value() &&
99+
!numStreams.has_value();
100+
}
95101
};
96102

97103
struct TextGenGraphSettingsImpl {
@@ -112,21 +118,23 @@ struct TextGenGraphSettingsImpl {
112118
};
113119

114120
struct EmbeddingsGraphSettingsImpl {
121+
EmbeddingsGraphSettingsImpl();
115122
std::string modelPath = "./";
116123
std::string targetDevice = "CPU";
117124
std::string modelName = "";
118-
uint32_t numStreams = 1;
119125
std::string normalize = "true";
120126
std::string truncate = "false";
121127
std::string pooling = "CLS";
128+
PluginConfigSettingsImpl pluginConfig;
122129
};
123130

124131
struct RerankGraphSettingsImpl {
132+
RerankGraphSettingsImpl();
125133
std::string modelPath = "./";
126134
std::string targetDevice = "CPU";
127135
std::string modelName = "";
128-
uint32_t numStreams = 1;
129136
uint64_t maxAllowedChunks = 10000;
137+
PluginConfigSettingsImpl pluginConfig;
130138
};
131139

132140
struct ImageGenerationGraphSettingsImpl {
@@ -141,7 +149,7 @@ struct ImageGenerationGraphSettingsImpl {
141149
std::optional<uint32_t> maxNumberImagesPerPrompt;
142150
std::optional<uint32_t> defaultNumInferenceSteps;
143151
std::optional<uint32_t> maxNumInferenceSteps;
144-
std::string pluginConfig;
152+
PluginConfigSettingsImpl pluginConfig;
145153
};
146154

147155
struct ExportSettings {

src/graph_export/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ ovms_cc_library(
8080
"@ovms//src:libovms_server_settings",
8181
"@ovms//src:ovms_exit_codes",
8282
"@com_github_jarro2783_cxxopts//:cxxopts",
83-
"@com_github_tencent_rapidjson//:rapidjson",
8483
],
8584
visibility = ["//visibility:public"],
8685
)

src/graph_export/embeddings_graph_cli_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ void EmbeddingsGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl
9393
throw std::logic_error("Tried to prepare server and model settings without graph parse result");
9494
}
9595
} else {
96-
embeddingsGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
96+
embeddingsGraphSettings.pluginConfig.numStreams = result->operator[]("num_streams").as<uint32_t>();
9797
embeddingsGraphSettings.normalize = result->operator[]("normalize").as<std::string>();
9898
embeddingsGraphSettings.truncate = result->operator[]("truncate").as<std::string>();
9999
embeddingsGraphSettings.pooling = result->operator[]("pooling").as<std::string>();

src/graph_export/graph_export.cpp

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,13 @@ std::string GraphExport::getDraftModelDirectoryPath(const std::string& directory
8585

8686
static Status createTextGenerationGraphTemplate(const std::string& directoryPath, const HFSettingsImpl& hfSettings) {
8787
if (!std::holds_alternative<TextGenGraphSettingsImpl>(hfSettings.graphSettings)) {
88+
SPDLOG_ERROR("Graph options not initialized for text generation.");
8889
return StatusCode::INTERNAL_ERROR;
8990
}
9091
auto& graphSettings = std::get<TextGenGraphSettingsImpl>(hfSettings.graphSettings);
9192
auto& ggufFilename = hfSettings.ggufFilename;
9293
auto& exportSettings = hfSettings.exportSettings;
94+
9395
std::ostringstream oss;
9496
oss << OVMS_VERSION_GRAPH_LINE;
9597
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
@@ -188,13 +190,25 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
188190
return FileSystem::createFileOverwrite(fullPath, oss.str());
189191
}
190192

191-
static Status createRerankGraphTemplate(const std::string& directoryPath, const RerankGraphSettingsImpl& graphSettings) {
193+
static Status createRerankGraphTemplate(const std::string& directoryPath, const HFSettingsImpl& hfSettings) {
194+
if (!std::holds_alternative<RerankGraphSettingsImpl>(hfSettings.graphSettings)) {
195+
SPDLOG_ERROR("Graph options not initialized for reranking.");
196+
return StatusCode::INTERNAL_ERROR;
197+
}
198+
auto& graphSettings = std::get<RerankGraphSettingsImpl>(hfSettings.graphSettings);
199+
auto& ggufFilename = hfSettings.ggufFilename;
200+
auto& exportSettings = hfSettings.exportSettings;
201+
192202
std::ostringstream oss;
193203
oss << OVMS_VERSION_GRAPH_LINE;
194204
// Windows path creation - graph parser needs forward slashes in paths
195-
std::string graphOkPath = graphSettings.modelPath;
196-
if (FileSystem::getOsSeparator() != "/") {
197-
std::replace(graphOkPath.begin(), graphOkPath.end(), '\\', '/');
205+
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
206+
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
207+
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
208+
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
209+
auto status = std::get<Status>(pluginConfigOrStatus);
210+
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
211+
return status;
198212
}
199213
// clang-format off
200214
oss << R"(
@@ -210,11 +224,11 @@ node {
210224
node_options: {
211225
[type.googleapis.com / mediapipe.RerankCalculatorOVOptions]: {
212226
models_path: ")"
213-
<< graphOkPath << R"(",
227+
<< modelsPath << R"(",
214228
max_allowed_chunks: )"
215229
<< graphSettings.maxAllowedChunks << R"(,
216230
target_device: ")" << graphSettings.targetDevice << R"(",
217-
plugin_config: '{ "NUM_STREAMS": ")" << graphSettings.numStreams << R"("}',
231+
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(',
218232
}
219233
}
220234
})";
@@ -232,15 +246,25 @@ node {
232246
return FileSystem::createFileOverwrite(fullPath, oss.str());
233247
}
234248

235-
static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, const EmbeddingsGraphSettingsImpl& graphSettings) {
249+
static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, const HFSettingsImpl& hfSettings) {
250+
if (!std::holds_alternative<EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings)) {
251+
SPDLOG_ERROR("Graph options not initialized for embeddings.");
252+
return StatusCode::INTERNAL_ERROR;
253+
}
254+
auto& graphSettings = std::get<EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings);
255+
auto& ggufFilename = hfSettings.ggufFilename;
256+
auto& exportSettings = hfSettings.exportSettings;
257+
236258
std::ostringstream oss;
237259
oss << OVMS_VERSION_GRAPH_LINE;
238-
// Windows path creation - graph parser needs forward slashes in paths
239-
std::string graphOkPath = graphSettings.modelPath;
240-
if (FileSystem::getOsSeparator() != "/") {
241-
std::replace(graphOkPath.begin(), graphOkPath.end(), '\\', '/');
260+
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
261+
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
262+
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
263+
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
264+
auto status = std::get<Status>(pluginConfigOrStatus);
265+
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
266+
return status;
242267
}
243-
244268
// clang-format off
245269
oss << R"(
246270
input_stream: "REQUEST_PAYLOAD:input"
@@ -255,15 +279,15 @@ node {
255279
node_options: {
256280
[type.googleapis.com / mediapipe.EmbeddingsCalculatorOVOptions]: {
257281
models_path: ")"
258-
<< graphOkPath << R"(",
282+
<< modelsPath << R"(",
259283
normalize_embeddings: )"
260284
<< graphSettings.normalize << R"(,
261285
truncate: )"
262286
<< graphSettings.truncate << R"(,
263287
pooling: )"
264288
<< graphSettings.pooling << R"(,
265289
target_device: ")" << graphSettings.targetDevice << R"(",
266-
plugin_config: '{ "NUM_STREAMS": ")" << graphSettings.numStreams << R"("}',
290+
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(',
267291
}
268292
}
269293
})";
@@ -281,7 +305,24 @@ node {
281305
return FileSystem::createFileOverwrite(fullPath, oss.str());
282306
}
283307

284-
static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const ImageGenerationGraphSettingsImpl& graphSettings) {
308+
static Status createImageGenerationGraphTemplate(const std::string& directoryPath, const HFSettingsImpl& hfSettings) {
309+
if (!std::holds_alternative<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings)) {
310+
SPDLOG_ERROR("Graph options not initialized for image generation.");
311+
return StatusCode::INTERNAL_ERROR;
312+
}
313+
auto& graphSettings = std::get<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings);
314+
auto& exportSettings = hfSettings.exportSettings;
315+
auto& ggufFilename = hfSettings.ggufFilename;
316+
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
317+
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
318+
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
319+
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
320+
auto status = std::get<Status>(pluginConfigOrStatus);
321+
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
322+
return status;
323+
}
324+
const std::string pluginConfig = std::get<std::string>(pluginConfigOrStatus);
325+
285326
std::ostringstream oss;
286327
oss << OVMS_VERSION_GRAPH_LINE;
287328
// clang-format off
@@ -299,10 +340,10 @@ node: {
299340
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
300341
models_path: ")" << graphSettings.modelPath << R"("
301342
device: ")" << graphSettings.targetDevice << R"(")";
302-
303-
if (graphSettings.pluginConfig.size()) {
343+
// TODO by default our utility generates empty plugin config which may differ in behavior from not setting it at all
344+
if (pluginConfig.size() > 4) {
304345
oss << R"(
305-
plugin_config: ')" << graphSettings.pluginConfig << R"(')";
346+
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(')";
306347
}
307348

308349
if (graphSettings.resolution.size()) {
@@ -383,26 +424,11 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const
383424
if (hfSettings.task == TEXT_GENERATION_GRAPH) {
384425
return createTextGenerationGraphTemplate(directoryPath, hfSettings);
385426
} else if (hfSettings.task == EMBEDDINGS_GRAPH) {
386-
if (std::holds_alternative<EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings)) {
387-
return createEmbeddingsGraphTemplate(directoryPath, std::get<EmbeddingsGraphSettingsImpl>(hfSettings.graphSettings));
388-
} else {
389-
SPDLOG_ERROR("Graph options not initialized for embeddings.");
390-
return StatusCode::INTERNAL_ERROR;
391-
}
427+
return createEmbeddingsGraphTemplate(directoryPath, hfSettings);
392428
} else if (hfSettings.task == RERANK_GRAPH) {
393-
if (std::holds_alternative<RerankGraphSettingsImpl>(hfSettings.graphSettings)) {
394-
return createRerankGraphTemplate(directoryPath, std::get<RerankGraphSettingsImpl>(hfSettings.graphSettings));
395-
} else {
396-
SPDLOG_ERROR("Graph options not initialized for rerank.");
397-
return StatusCode::INTERNAL_ERROR;
398-
}
429+
return createRerankGraphTemplate(directoryPath, hfSettings);
399430
} else if (hfSettings.task == IMAGE_GENERATION_GRAPH) {
400-
if (std::holds_alternative<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings)) {
401-
return createImageGenerationGraphTemplate(directoryPath, std::get<ImageGenerationGraphSettingsImpl>(hfSettings.graphSettings));
402-
} else {
403-
SPDLOG_ERROR("Graph options not initialized for image generation.");
404-
return StatusCode::INTERNAL_ERROR;
405-
}
431+
return createImageGenerationGraphTemplate(directoryPath, hfSettings);
406432
} else if (hfSettings.task == UNKNOWN_GRAPH) {
407433
SPDLOG_ERROR("Graph options not initialized.");
408434
return StatusCode::INTERNAL_ERROR;

src/graph_export/image_generation_graph_cli_parser.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@
2424
#include <utility>
2525
#include <vector>
2626

27-
#pragma warning(push)
28-
#pragma warning(disable : 6313)
29-
#include <rapidjson/document.h>
30-
#include <rapidjson/istreamwrapper.h>
31-
#include <rapidjson/stringbuffer.h>
32-
#include <rapidjson/writer.h>
33-
#pragma warning(pop)
34-
3527
#include "../capi_frontend/server_settings.hpp"
3628
#include "../ovms_exit_codes.hpp"
3729
#include "../status.hpp"
@@ -159,25 +151,17 @@ void ImageGenerationGraphCLIParser::prepare(ServerSettingsImpl& serverSettings,
159151
}
160152

161153
if (result->count("num_streams") || serverSettings.cacheDir != "") {
162-
rapidjson::Document pluginConfigDoc;
163-
pluginConfigDoc.SetObject();
164-
rapidjson::Document::AllocatorType& allocator = pluginConfigDoc.GetAllocator();
165154
if (result->count("num_streams")) {
166155
uint32_t numStreams = result->operator[]("num_streams").as<uint32_t>();
167156
if (numStreams == 0) {
168157
throw std::invalid_argument("num_streams must be greater than 0");
169158
}
170-
pluginConfigDoc.AddMember("NUM_STREAMS", numStreams, allocator);
159+
imageGenerationGraphSettings.pluginConfig.numStreams = result->operator[]("num_streams").as<uint32_t>();
171160
}
172161

173162
if (!serverSettings.cacheDir.empty()) {
174-
pluginConfigDoc.AddMember("CACHE_DIR", rapidjson::Value(serverSettings.cacheDir.c_str(), allocator), allocator);
163+
hfSettings.exportSettings.cacheDir = serverSettings.cacheDir;
175164
}
176-
177-
rapidjson::StringBuffer buffer;
178-
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
179-
pluginConfigDoc.Accept(writer);
180-
imageGenerationGraphSettings.pluginConfig = buffer.GetString();
181165
}
182166
}
183167

src/graph_export/rerank_graph_cli_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ void RerankGraphCLIParser::prepare(OvmsServerMode serverMode, HFSettingsImpl& hf
8686
throw std::logic_error("Tried to prepare server and model settings without graph parse result");
8787
}
8888
} else {
89-
rerankGraphSettings.numStreams = result->operator[]("num_streams").as<uint32_t>();
89+
rerankGraphSettings.pluginConfig.numStreams = result->operator[]("num_streams").as<uint32_t>();
9090
rerankGraphSettings.maxAllowedChunks = result->operator[]("max_allowed_chunks").as<uint64_t>();
9191
}
9292

src/mediapipe_internal/mediapipegraphdefinition.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Status MediapipeGraphDefinition::validateForConfigLoadableness() {
106106
SPDLOG_LOGGER_ERROR(modelmanager_logger, "Trying to parse mediapipe graph definition: {} failed", this->getName(), this->chosenConfig);
107107
return StatusCode::MEDIAPIPE_GRAPH_CONFIG_FILE_INVALID;
108108
}
109+
SPDLOG_TRACE("Will try to load pbtxt config: {}", this->chosenConfig);
109110
return StatusCode::OK;
110111
}
111112

src/test/graph_export_test.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ TEST_F(GraphCreationTest, rerankPositiveNonDefault) {
571571
rerankGraphSettings.targetDevice = "GPU";
572572
rerankGraphSettings.modelName = "myModel";
573573
rerankGraphSettings.modelPath = "/some/path";
574-
rerankGraphSettings.numStreams = 2;
574+
rerankGraphSettings.pluginConfig.numStreams = 2;
575575
rerankGraphSettings.maxAllowedChunks = 18;
576576
hfSettings.graphSettings = std::move(rerankGraphSettings);
577577

@@ -605,7 +605,7 @@ TEST_F(GraphCreationTest, rerankCreatedPbtxtInvalid) {
605605
ovms::RerankGraphSettingsImpl rerankGraphSettings;
606606
rerankGraphSettings.targetDevice = "GPU";
607607
rerankGraphSettings.modelName = "myModel\"";
608-
rerankGraphSettings.numStreams = 2;
608+
rerankGraphSettings.pluginConfig.numStreams = 2;
609609
hfSettings.graphSettings = std::move(rerankGraphSettings);
610610
std::string graphPath = ovms::FileSystem::appendSlash(this->directoryPath) + "graph.pbtxt";
611611
std::unique_ptr<ovms::GraphExport> graphExporter = std::make_unique<ovms::GraphExport>();
@@ -624,7 +624,7 @@ TEST_F(GraphCreationTest, embeddingsPositiveNonDefault) {
624624
embeddingsGraphSettings.targetDevice = "GPU";
625625
embeddingsGraphSettings.modelName = "myModel";
626626
embeddingsGraphSettings.modelPath = "/model1/path";
627-
embeddingsGraphSettings.numStreams = 2;
627+
embeddingsGraphSettings.pluginConfig.numStreams = 2;
628628
embeddingsGraphSettings.normalize = "false";
629629
embeddingsGraphSettings.truncate = "true";
630630
embeddingsGraphSettings.pooling = "LAST";
@@ -658,7 +658,7 @@ TEST_F(GraphCreationTest, embeddingsCreatedPbtxtInvalid) {
658658
ovms::EmbeddingsGraphSettingsImpl embeddingsGraphSettings;
659659
embeddingsGraphSettings.targetDevice = "GPU";
660660
embeddingsGraphSettings.modelName = "myModel\"";
661-
embeddingsGraphSettings.numStreams = 2;
661+
embeddingsGraphSettings.pluginConfig.numStreams = 2;
662662
embeddingsGraphSettings.normalize = "true";
663663
embeddingsGraphSettings.pooling = "CLS";
664664
hfSettings.graphSettings = std::move(embeddingsGraphSettings);
@@ -794,7 +794,8 @@ TEST_F(GraphCreationTest, imageGenerationPositiveFull) {
794794
ovms::HFSettingsImpl hfSettings;
795795
hfSettings.task = ovms::IMAGE_GENERATION_GRAPH;
796796
ovms::ImageGenerationGraphSettingsImpl imageGenerationGraphSettings;
797-
imageGenerationGraphSettings.pluginConfig = "{\"NUM_STREAMS\":14,\"CACHE_DIR\":\"/cache\"}";
797+
imageGenerationGraphSettings.pluginConfig.numStreams = 14;
798+
hfSettings.exportSettings.cacheDir = "/cache";
798799
imageGenerationGraphSettings.targetDevice = "GPU";
799800
imageGenerationGraphSettings.defaultResolution = "300x400";
800801
imageGenerationGraphSettings.maxResolution = "3000x4000";

0 commit comments

Comments
 (0)