Skip to content

Commit ff78cdb

Browse files
committed
Update
1 parent 3032eec commit ff78cdb

File tree

5 files changed

+60
-66
lines changed

5 files changed

+60
-66
lines changed

src/capi_frontend/server_settings.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct EmbeddingsGraphSettingsImpl {
130130
};
131131

132132
struct RerankGraphSettingsImpl {
133-
RerankGraphSettingsImpl::RerankGraphSettingsImpl() :
133+
RerankGraphSettingsImpl() :
134134
pluginConfig{std::nullopt, std::nullopt, std::nullopt, 1} {}
135135
std::string modelPath = "./";
136136
std::string targetDevice = "CPU";

src/graph_export/graph_export.cpp

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ std::string GraphExport::getDraftModelDirectoryPath(const std::string& directory
8282
std::string fullPath = FileSystem::joinPath({directoryPath, GraphExport::getDraftModelDirectoryName(draftModel)});
8383
return fullPath;
8484
}
85+
#define GET_PLUGIN_CONFIG_OPT_OR_FAIL_AND_RETURN(PLUGIN_SETTINGS_IMPL, EXPORT_SETTINGS) \
86+
auto pluginConfigOrStatus = GraphExport::createPluginString(PLUGIN_SETTINGS_IMPL, EXPORT_SETTINGS); \
87+
if (std::holds_alternative<Status>(pluginConfigOrStatus)) { \
88+
auto status = std::get<Status>(pluginConfigOrStatus); \
89+
SPDLOG_ERROR("Failed to create plugin config: {}", status.string()); \
90+
return status; \
91+
} \
92+
auto pluginConfigOpt = std::get<std::optional<std::string>>(pluginConfigOrStatus)
8593

8694
static Status createTextGenerationGraphTemplate(const std::string& directoryPath, const HFSettingsImpl& hfSettings) {
8795
if (!std::holds_alternative<TextGenGraphSettingsImpl>(hfSettings.graphSettings)) {
@@ -96,12 +104,7 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
96104
oss << OVMS_VERSION_GRAPH_LINE;
97105
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
98106
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
99-
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
100-
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
101-
auto status = std::get<Status>(pluginConfigOrStatus);
102-
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
103-
return status;
104-
}
107+
GET_PLUGIN_CONFIG_OPT_OR_FAIL_AND_RETURN(graphSettings.pluginConfig, exportSettings);
105108
// clang-format off
106109
oss << R"(
107110
input_stream: "HTTP_REQUEST_PAYLOAD:input"
@@ -126,9 +129,13 @@ static Status createTextGenerationGraphTemplate(const std::string& directoryPath
126129
<< graphSettings.targetDevice << R"(",
127130
models_path: ")"
128131
<< modelsPath << R"(",
129-
plugin_config: ')"
130-
<< std::get<std::string>(pluginConfigOrStatus) << R"(',
131-
enable_prefix_caching: )"
132+
)";
133+
if (pluginConfigOpt.has_value()) {
134+
oss << R"(plugin_config: ')"
135+
<< pluginConfigOpt.value() << R"(',
136+
)";
137+
}
138+
oss << R"(enable_prefix_caching: )"
132139
<< graphSettings.enablePrefixCaching << R"(,
133140
cache_size: )"
134141
<< graphSettings.cacheSize << R"(,)";
@@ -204,12 +211,7 @@ static Status createRerankGraphTemplate(const std::string& directoryPath, const
204211
// Windows path creation - graph parser needs forward slashes in paths
205212
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
206213
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
207-
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
208-
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
209-
auto status = std::get<Status>(pluginConfigOrStatus);
210-
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
211-
return status;
212-
}
214+
GET_PLUGIN_CONFIG_OPT_OR_FAIL_AND_RETURN(graphSettings.pluginConfig, exportSettings);
213215
// clang-format off
214216
oss << R"(
215217
input_stream: "REQUEST_PAYLOAD:input"
@@ -228,7 +230,11 @@ node {
228230
max_allowed_chunks: )"
229231
<< graphSettings.maxAllowedChunks << R"(,
230232
target_device: ")" << graphSettings.targetDevice << R"(",
231-
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(',
233+
)";
234+
if (pluginConfigOpt.has_value()) {
235+
oss << R"(plugin_config: ')" << pluginConfigOpt.value() << R"(',)";
236+
}
237+
oss << R"(
232238
}
233239
}
234240
})";
@@ -259,12 +265,7 @@ static Status createEmbeddingsGraphTemplate(const std::string& directoryPath, co
259265
oss << OVMS_VERSION_GRAPH_LINE;
260266
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
261267
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
262-
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
263-
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
264-
auto status = std::get<Status>(pluginConfigOrStatus);
265-
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
266-
return status;
267-
}
268+
GET_PLUGIN_CONFIG_OPT_OR_FAIL_AND_RETURN(graphSettings.pluginConfig, exportSettings);
268269
// clang-format off
269270
oss << R"(
270271
input_stream: "REQUEST_PAYLOAD:input"
@@ -287,8 +288,12 @@ node {
287288
pooling: )"
288289
<< graphSettings.pooling << R"(,
289290
target_device: ")" << graphSettings.targetDevice << R"(",
290-
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(',
291-
}
291+
)";
292+
if (pluginConfigOpt.has_value()) {
293+
oss << R"(plugin_config: ')" << pluginConfigOpt.value() << R"(',
294+
)";
295+
}
296+
oss << R"(}
292297
}
293298
})";
294299

@@ -315,13 +320,7 @@ static Status createImageGenerationGraphTemplate(const std::string& directoryPat
315320
auto& ggufFilename = hfSettings.ggufFilename;
316321
std::string modelsPath = constructModelsPath(graphSettings.modelPath, ggufFilename);
317322
SPDLOG_TRACE("modelsPath: {}, directoryPath: {}, ggufFilename: {}", modelsPath, directoryPath, ggufFilename.value_or("std::nullopt"));
318-
auto pluginConfigOrStatus = GraphExport::createPluginString(graphSettings.pluginConfig, exportSettings);
319-
if (std::holds_alternative<Status>(pluginConfigOrStatus)) {
320-
auto status = std::get<Status>(pluginConfigOrStatus);
321-
SPDLOG_ERROR("Failed to create plugin config: {}", status.string());
322-
return status;
323-
}
324-
const std::string pluginConfig = std::get<std::string>(pluginConfigOrStatus);
323+
GET_PLUGIN_CONFIG_OPT_OR_FAIL_AND_RETURN(graphSettings.pluginConfig, exportSettings);
325324

326325
std::ostringstream oss;
327326
oss << OVMS_VERSION_GRAPH_LINE;
@@ -340,10 +339,9 @@ node: {
340339
[type.googleapis.com / mediapipe.ImageGenCalculatorOptions]: {
341340
models_path: ")" << graphSettings.modelPath << R"("
342341
device: ")" << graphSettings.targetDevice << R"(")";
343-
// TODO by default our utility generates empty plugin config which may differ in behavior from not setting it at all
344-
if (pluginConfig.size() > 4) {
342+
if (pluginConfigOpt.has_value()) {
345343
oss << R"(
346-
plugin_config: ')" << std::get<std::string>(pluginConfigOrStatus) << R"(')";
344+
plugin_config: ')" << pluginConfigOpt.value() << R"(')";
347345
}
348346

349347
if (graphSettings.resolution.size()) {
@@ -436,7 +434,7 @@ Status GraphExport::createServableConfig(const std::string& directoryPath, const
436434
return StatusCode::INTERNAL_ERROR;
437435
}
438436

439-
std::variant<std::string, Status> GraphExport::createPluginString(const PluginConfigSettingsImpl& pluginConfig, const ExportSettings& exportSettings) {
437+
std::variant<std::optional<std::string>, Status> GraphExport::createPluginString(const PluginConfigSettingsImpl& pluginConfig, const ExportSettings& exportSettings) {
440438
auto& stringPluginConfig = exportSettings.pluginConfig;
441439
rapidjson::Document d;
442440
d.SetObject();
@@ -446,7 +444,6 @@ std::variant<std::string, Status> GraphExport::createPluginString(const PluginCo
446444
}
447445
}
448446
bool configNotEmpty = false;
449-
450447
if (pluginConfig.kvCachePrecision.has_value()) {
451448
rapidjson::Value name;
452449
name.SetString(pluginConfig.kvCachePrecision.value().c_str(), d.GetAllocator());
@@ -457,7 +454,6 @@ std::variant<std::string, Status> GraphExport::createPluginString(const PluginCo
457454
d.AddMember("KV_CACHE_PRECISION", name, d.GetAllocator());
458455
configNotEmpty = true;
459456
}
460-
461457
if (pluginConfig.maxPromptLength.has_value()) {
462458
rapidjson::Value value;
463459
value.SetUint(pluginConfig.maxPromptLength.value());
@@ -468,7 +464,6 @@ std::variant<std::string, Status> GraphExport::createPluginString(const PluginCo
468464
d.AddMember("MAX_PROMPT_LEN", value, d.GetAllocator());
469465
configNotEmpty = true;
470466
}
471-
472467
if (pluginConfig.modelDistributionPolicy.has_value()) {
473468
rapidjson::Value value;
474469
value.SetString(pluginConfig.modelDistributionPolicy.value().c_str(), d.GetAllocator());
@@ -479,6 +474,16 @@ std::variant<std::string, Status> GraphExport::createPluginString(const PluginCo
479474
d.AddMember("MODEL_DISTRIBUTION_POLICY", value, d.GetAllocator());
480475
configNotEmpty = true;
481476
}
477+
if (pluginConfig.numStreams.has_value()) {
478+
rapidjson::Value value;
479+
value.SetUint(pluginConfig.numStreams.value());
480+
auto itr = d.FindMember("NUM_STREAMS");
481+
if (itr != d.MemberEnd()) {
482+
return Status(StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS, "Doubled NUM_STREAMS parameter in plugin config.");
483+
}
484+
d.AddMember("NUM_STREAMS", value, d.GetAllocator());
485+
configNotEmpty = true;
486+
}
482487
if (exportSettings.cacheDir.has_value()) {
483488
rapidjson::Value value;
484489
value.SetString(exportSettings.cacheDir.value().c_str(), d.GetAllocator());
@@ -489,20 +494,17 @@ std::variant<std::string, Status> GraphExport::createPluginString(const PluginCo
489494
d.AddMember("CACHE_DIR", value, d.GetAllocator());
490495
configNotEmpty = true;
491496
}
492-
493-
std::string pluginString = "{ }";
494-
495497
if (configNotEmpty) {
496498
// Serialize the document to a JSON string
497499
rapidjson::StringBuffer buffer;
498500
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
499501
d.Accept(writer);
500502

501503
// Output the JSON string
502-
pluginString = buffer.GetString();
504+
return buffer.GetString();
505+
} else {
506+
return std::nullopt;
503507
}
504-
505-
return pluginString;
506508
}
507509

508510
} // namespace ovms

src/graph_export/graph_export.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class GraphExport {
2828
public:
2929
GraphExport();
3030
Status createServableConfig(const std::string& directoryPath, const HFSettingsImpl& graphSettings);
31-
static std::variant<std::string, Status> createPluginString(const PluginConfigSettingsImpl& pluginConfig, const ExportSettings& exportSettings);
31+
static std::variant<std::optional<std::string>, Status> createPluginString(const PluginConfigSettingsImpl& pluginConfig, const ExportSettings& exportSettings);
3232
static std::string getDraftModelDirectoryName(std::string draftModel);
3333
static std::string getDraftModelDirectoryPath(const std::string& directoryPath, const std::string& draftModel);
3434
};

src/test/graph_export_test.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ const std::string expectedGraphContentsWithResponseParser = R"(
123123
max_num_seqs:256,
124124
device: "CPU",
125125
models_path: "./",
126-
plugin_config: '{ }',
127126
enable_prefix_caching: true,
128127
cache_size: 10,
129128
reasoning_parser: "REASONING_PARSER",
@@ -164,7 +163,6 @@ const std::string expectedDefaultGraphContents = R"(
164163
max_num_seqs:256,
165164
device: "CPU",
166165
models_path: "./",
167-
plugin_config: '{ }',
168166
enable_prefix_caching: true,
169167
cache_size: 10,
170168
}
@@ -202,7 +200,6 @@ const std::string expectedDraftAndFuseGraphContents = R"(
202200
max_num_seqs:256,
203201
device: "CPU",
204202
models_path: "./",
205-
plugin_config: '{ }',
206203
enable_prefix_caching: true,
207204
cache_size: 10,
208205
dynamic_split_fuse: false,
@@ -243,7 +240,6 @@ const std::string expectedGGUFGraphContents = R"(
243240
max_num_seqs:256,
244241
device: "CPU",
245242
models_path: "./PRETTY_GOOD_GGUF_MODEL.gguf",
246-
plugin_config: '{ }',
247243
enable_prefix_caching: true,
248244
cache_size: 10,
249245
}
@@ -281,7 +277,6 @@ const std::string expectedGGUFGraphContents2 = R"(
281277
max_num_seqs:256,
282278
device: "CPU",
283279
models_path: "./PRETTY_GOOD_GGUF_MODEL_Q8-00001-of-20000.gguf",
284-
plugin_config: '{ }',
285280
enable_prefix_caching: true,
286281
cache_size: 10,
287282
}
@@ -313,7 +308,7 @@ node {
313308
models_path: "/some/path",
314309
max_allowed_chunks: 18,
315310
target_device: "GPU",
316-
plugin_config: '{ "NUM_STREAMS": "2"}',
311+
plugin_config: '{"NUM_STREAMS":2}',
317312
}
318313
}
319314
}
@@ -333,7 +328,7 @@ node {
333328
models_path: "./",
334329
max_allowed_chunks: 10000,
335330
target_device: "CPU",
336-
plugin_config: '{ "NUM_STREAMS": "1"}',
331+
plugin_config: '{"NUM_STREAMS":1}',
337332
}
338333
}
339334
}
@@ -355,7 +350,7 @@ node {
355350
truncate: true,
356351
pooling: LAST,
357352
target_device: "GPU",
358-
plugin_config: '{ "NUM_STREAMS": "2"}',
353+
plugin_config: '{"NUM_STREAMS":2}',
359354
}
360355
}
361356
}
@@ -377,7 +372,7 @@ node {
377372
truncate: false,
378373
pooling: CLS,
379374
target_device: "CPU",
380-
plugin_config: '{ "NUM_STREAMS": "1"}',
375+
plugin_config: '{"NUM_STREAMS":1}',
381376
}
382377
}
383378
}
@@ -818,12 +813,11 @@ TEST_F(GraphCreationTest, pluginConfigAsString) {
818813
pluginConfig.maxPromptLength = 256;
819814
pluginConfig.modelDistributionPolicy = "TENSOR_PARALLEL";
820815
ovms::ExportSettings exportSettings;
821-
exportSettings.pluginConfig = "{\"NUM_STREAMS\":\"4\"}";
816+
exportSettings.pluginConfig = "{\"NUM_STREAMS\":4}";
822817
auto res = ovms::GraphExport::createPluginString(pluginConfig, exportSettings);
823-
ASSERT_TRUE(std::holds_alternative<std::string>(res));
824-
ASSERT_EQ(std::get<std::string>(res),
825-
"{\"NUM_STREAMS\":\"4\",\"KV_CACHE_PRECISION\":\"u8\",\"MAX_PROMPT_LEN\":256,\"MODEL_DISTRIBUTION_POLICY\":\"TENSOR_PARALLEL\"}");
826-
// ovms::Model
818+
ASSERT_TRUE(std::holds_alternative<std::optional<std::string>>(res));
819+
ASSERT_EQ(std::get<std::optional<std::string>>(res).value(),
820+
"{\"NUM_STREAMS\":4,\"KV_CACHE_PRECISION\":\"u8\",\"MAX_PROMPT_LEN\":256,\"MODEL_DISTRIBUTION_POLICY\":\"TENSOR_PARALLEL\"}");
827821
}
828822
TEST_F(GraphCreationTest, pluginConfigNegative) {
829823
using ovms::Status;
@@ -837,20 +831,20 @@ TEST_F(GraphCreationTest, pluginConfigNegative) {
837831
exportSettings.pluginConfig = "{\"KV_CACHE_PRECISION\":\"fp16\"}";
838832
exportSettings.cacheDir = "/cache";
839833
auto res = ovms::GraphExport::createPluginString(pluginConfig, exportSettings);
840-
ASSERT_FALSE(std::holds_alternative<std::string>(res));
834+
ASSERT_TRUE(std::holds_alternative<ovms::Status>(res));
841835
ASSERT_EQ(std::get<Status>(res), ovms::StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS);
842836

843837
exportSettings.pluginConfig = "{\"MAX_PROMPT_LEN\":512}";
844838
res = ovms::GraphExport::createPluginString(pluginConfig, exportSettings);
845-
ASSERT_FALSE(std::holds_alternative<std::string>(res));
839+
ASSERT_TRUE(std::holds_alternative<ovms::Status>(res));
846840
ASSERT_EQ(std::get<Status>(res), ovms::StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS);
847841

848842
exportSettings.pluginConfig = "{\"CACHE_DIR\":\"/cache\"}";
849843
res = ovms::GraphExport::createPluginString(pluginConfig, exportSettings);
850-
ASSERT_FALSE(std::holds_alternative<std::string>(res));
844+
ASSERT_TRUE(std::holds_alternative<ovms::Status>(res));
851845
ASSERT_EQ(std::get<Status>(res), ovms::StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS);
852846
exportSettings.pluginConfig = "{\"MODEL_DISTRIBUTION_POLICY\":\"PIPELINE_PARALLEL\"}";
853847
res = ovms::GraphExport::createPluginString(pluginConfig, exportSettings);
854-
ASSERT_FALSE(std::holds_alternative<std::string>(res));
848+
ASSERT_TRUE(std::holds_alternative<ovms::Status>(res));
855849
ASSERT_EQ(std::get<Status>(res), ovms::StatusCode::PLUGIN_CONFIG_CONFLICTING_PARAMETERS);
856850
}

src/test/pull_hf_model_test.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ const std::string expectedGraphContents = R"(
9191
max_num_seqs:256,
9292
device: "CPU",
9393
models_path: "./",
94-
plugin_config: '{ }',
9594
enable_prefix_caching: true,
9695
cache_size: 10,
9796
}
@@ -129,7 +128,6 @@ const std::string expectedGraphContentsDraft = R"(
129128
max_num_seqs:256,
130129
device: "CPU",
131130
models_path: "./",
132-
plugin_config: '{ }',
133131
enable_prefix_caching: true,
134132
cache_size: 10,
135133
# Speculative decoding configuration

0 commit comments

Comments
 (0)