Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GTIL] Allow inferencing with FP32 input + FP64 model #574

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions src/gtil/predict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), model.num_target, max_num_class);
for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
output_view(row_id, target_id, class_id) += leaf_view(target_id, class_id);
output_view(row_id, target_id, class_id)
+= static_cast<InputT>(leaf_view(target_id, class_id));
}
}
} else if (model.target_id[tree_id] == -1) {
Expand All @@ -193,7 +194,7 @@
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), model.num_target, 1);
auto const class_id = model.class_id[tree_id];
for (std::int32_t target_id = 0; target_id < model.num_target; ++target_id) {
output_view(row_id, target_id, class_id) += leaf_view(target_id, 0);
output_view(row_id, target_id, class_id) += static_cast<InputT>(leaf_view(target_id, 0));
}
} else if (model.class_id[tree_id] == -1) {
std::vector<std::int32_t> const expected_leaf_shape{1, max_num_class};
Expand All @@ -202,15 +203,15 @@
auto leaf_view = Array2DView<LeafOutputT>(leaf_out.data(), 1, max_num_class);
auto const target_id = model.target_id[tree_id];
for (std::int32_t class_id = 0; class_id < model.num_class[target_id]; ++class_id) {
output_view(row_id, target_id, class_id) += leaf_view(0, class_id);
output_view(row_id, target_id, class_id) += static_cast<InputT>(leaf_view(0, class_id));
}
} else {
std::vector<std::int32_t> const expected_leaf_shape{1, 1};
TREELITE_CHECK(model.leaf_vector_shape.AsVector() == expected_leaf_shape);

auto const target_id = model.target_id[tree_id];
auto const class_id = model.class_id[tree_id];
output_view(row_id, target_id, class_id) += leaf_out[0];
output_view(row_id, target_id, class_id) += static_cast<InputT>(leaf_out[0]);

Check warning on line 214 in src/gtil/predict.cc

View check run for this annotation

Codecov / codecov/patch

src/gtil/predict.cc#L214

Added line #L214 was not covered by tests
}
}

Expand All @@ -224,7 +225,7 @@
std::vector<std::int32_t> const expected_leaf_shape{1, 1};
TREELITE_CHECK(model.leaf_vector_shape.AsVector() == expected_leaf_shape);

output_view(row_id, target_id, class_id) += tree.LeafValue(leaf_id);
output_view(row_id, target_id, class_id) += static_cast<InputT>(tree.LeafValue(leaf_id));
}

template <typename InputT, typename MatrixAccessorT>
Expand Down Expand Up @@ -380,17 +381,6 @@
void PredictImpl(Model const& model, MatrixAccessorT accessor, std::uint64_t num_row,
InputT* output, Configuration const& config,
detail::threading_utils::ThreadConfig const& thread_config) {
TypeInfo leaf_output_type = model.GetLeafOutputType();
TypeInfo input_type = TypeInfoFromType<InputT>();
if (leaf_output_type != input_type) {
std::string expected = TypeInfoToString(leaf_output_type);
std::string got = TypeInfoToString(input_type);
if (got == "invalid") {
got = typeid(InputT).name();
}
TREELITE_LOG(FATAL) << "Incorrect input type passed to GTIL predict(). "
<< "Expected: " << expected << ", Got: " << got;
}
if (config.pred_kind == PredictKind::kPredictDefault) {
PredictRaw(model, accessor, num_row, output, thread_config);
ApplyPostProcessor(model, output, num_row, config, thread_config);
Expand Down
5 changes: 5 additions & 0 deletions src/json_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <treelite/enum/operator.h>
#include <treelite/enum/task_type.h>
#include <treelite/enum/tree_node_type.h>
#include <treelite/enum/typeinfo.h>
#include <treelite/logging.h>
#include <treelite/tree.h>

Expand Down Expand Up @@ -180,6 +181,10 @@ template <typename WriterType>
void DumpModelAsJSON(WriterType& writer, Model const& model) {
writer.StartObject();

writer.Key("threshold_type");
WriteString(writer, TypeInfoToString(model.GetThresholdType()));
writer.Key("leaf_output_type");
WriteString(writer, TypeInfoToString(model.GetLeafOutputType()));
writer.Key("num_feature");
writer.Int(model.num_feature);
writer.Key("task_type");
Expand Down
29 changes: 17 additions & 12 deletions tests/cpp/test_gtil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ TEST_P(ParametrizedTestSuite, LeafVectorRF) {
model_builder::PostProcessorFunc postprocessor{"identity_multiclass"};
std::vector<double> base_scores{100.0, 200.0, 300.0};
std::unique_ptr<model_builder::ModelBuilder> builder
= model_builder::GetModelBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32, metadata,
= model_builder::GetModelBuilder(TypeInfo::kFloat64, TypeInfo::kFloat64, metadata,
tree_annotation, postprocessor, base_scores);
auto make_tree_stump
= [&](std::vector<float> const& left_child_val, std::vector<float> const& right_child_val) {
= [&](std::vector<double> const& left_child_val, std::vector<double> const& right_child_val) {
builder->StartTree();
builder->StartNode(0);
builder->NumericalTest(0, 0.0, false, Operator::kLT, 1, 2);
Expand All @@ -141,8 +141,8 @@ TEST_P(ParametrizedTestSuite, LeafVectorRF) {
builder->EndNode();
builder->EndTree();
};
make_tree_stump({1.0f, 0.0f, 0.0f}, {0.0f, 0.5f, 0.5f});
make_tree_stump({1.0f, 0.0f, 0.0f}, {0.0f, 0.5f, 0.5f});
make_tree_stump({1.0, 0.0, 0.0}, {0.0, 0.5, 0.5});
make_tree_stump({1.0, 0.0, 0.0}, {0.0, 0.5, 0.5});

auto const predict_kind = GetParam();

Expand All @@ -154,28 +154,33 @@ TEST_P(ParametrizedTestSuite, LeafVectorRF) {
predict_kind));

std::vector<std::uint64_t> expected_output_shape;
std::vector<std::vector<float>> expected_output;
std::vector<float> expected_output_left_child;
std::vector<double> expected_output_right_child;
if (predict_kind == "raw" || predict_kind == "default") {
expected_output_shape = {1, 1, 3};
expected_output = {{100.0f, 200.5f, 300.5f}, {101.0f, 200.0f, 300.0f}};
expected_output_left_child = {100.0f, 200.5f, 300.5f};
expected_output_right_child = {101.0, 200.0, 300.0};
} else if (predict_kind == "leaf_id") {
expected_output_shape = {1, 2};
expected_output = {{2, 2}, {1, 1}};
expected_output_left_child = {2, 2};
expected_output_right_child = {1, 1};
}
auto output_shape = gtil::GetOutputShape(*model, 1, config);
EXPECT_EQ(output_shape, expected_output_shape);

std::vector<float> output(std::accumulate(
output_shape.begin(), output_shape.end(), std::uint64_t(1), std::multiplies<>()));
auto output_size = std::accumulate(
output_shape.begin(), output_shape.end(), std::uint64_t(1), std::multiplies<>());
{
std::vector<float> input{1.0f};
std::vector<float> output(output_size);
gtil::Predict(*model, input.data(), 1, output.data(), config);
EXPECT_EQ(output, expected_output[0]);
EXPECT_EQ(output, expected_output_left_child);
}
{
std::vector<float> input{-1.0f};
std::vector<double> input{-1.0};
std::vector<double> output(output_size);
gtil::Predict(*model, input.data(), 1, output.data(), config);
EXPECT_EQ(output, expected_output[1]);
EXPECT_EQ(output, expected_output_right_child);
}
}

Expand Down
24 changes: 20 additions & 4 deletions tests/cpp/test_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ void SerializerRoundTrip_TreeStump() {
/* Test correctness of JSON dump */
std::string expected_json_dump_str = fmt::format(R"JSON(
{{
"threshold_type": "{threshold_type}",
"leaf_output_type": "{leaf_output_type}",
"num_feature": 2,
"task_type": "kRegressor",
"average_tree_output": false,
Expand Down Expand Up @@ -145,7 +147,9 @@ void SerializerRoundTrip_TreeStump() {
)JSON",
"threshold"_a = static_cast<ThresholdType>(0),
"leaf_value0"_a = static_cast<LeafOutputType>(1),
"leaf_value1"_a = static_cast<LeafOutputType>(2));
"leaf_value1"_a = static_cast<LeafOutputType>(2),
"threshold_type"_a = TypeInfoToString(TypeInfoFromType<ThresholdType>()),
"leaf_output_type"_a = TypeInfoToString(TypeInfoFromType<LeafOutputType>()));

rapidjson::Document json_dump;
json_dump.Parse(model->DumpAsJSON(false).c_str());
Expand Down Expand Up @@ -193,6 +197,8 @@ void SerializerRoundTrip_TreeStumpLeafVec() {
/* Test correctness of JSON dump */
std::string expected_json_dump_str = fmt::format(R"JSON(
{{
"threshold_type": "{threshold_type}",
"leaf_output_type": "{leaf_output_type}",
"num_feature": 2,
"task_type": "kMultiClf",
"average_tree_output": true,
Expand Down Expand Up @@ -232,7 +238,9 @@ void SerializerRoundTrip_TreeStumpLeafVec() {
"leaf_value0"_a = static_cast<LeafOutputType>(1),
"leaf_value1"_a = static_cast<LeafOutputType>(2),
"leaf_value2"_a = static_cast<LeafOutputType>(2),
"leaf_value3"_a = static_cast<LeafOutputType>(1));
"leaf_value3"_a = static_cast<LeafOutputType>(1),
"threshold_type"_a = TypeInfoToString(TypeInfoFromType<ThresholdType>()),
"leaf_output_type"_a = TypeInfoToString(TypeInfoFromType<LeafOutputType>()));
rapidjson::Document json_dump;
json_dump.Parse(model->DumpAsJSON(false).c_str());

Expand Down Expand Up @@ -290,6 +298,8 @@ void SerializerRoundTrip_TreeStumpCategoricalSplit(
}
std::string expected_json_dump_str = fmt::format(R"JSON(
{{
"threshold_type": "{threshold_type}",
"leaf_output_type": "{leaf_output_type}",
"num_feature": 2,
"task_type": "kRegressor",
"average_tree_output": false,
Expand Down Expand Up @@ -326,7 +336,9 @@ void SerializerRoundTrip_TreeStumpCategoricalSplit(
}}
)JSON",
"leaf_value0"_a = static_cast<LeafOutputType>(2),
"leaf_value1"_a = static_cast<LeafOutputType>(3), "category_list"_a = category_list_str);
"leaf_value1"_a = static_cast<LeafOutputType>(3), "category_list"_a = category_list_str,
"threshold_type"_a = TypeInfoToString(TypeInfoFromType<ThresholdType>()),
"leaf_output_type"_a = TypeInfoToString(TypeInfoFromType<LeafOutputType>()));

rapidjson::Document json_dump;
json_dump.Parse(model->DumpAsJSON(false).c_str());
Expand Down Expand Up @@ -396,6 +408,8 @@ void SerializerRoundTrip_TreeDepth2() {

std::string expected_json_dump_str = fmt::format(R"JSON(
{{
"threshold_type": "{threshold_type}",
"leaf_output_type": "{leaf_output_type}",
"num_feature": 2,
"task_type": "kBinaryClf",
"average_tree_output": false,
Expand Down Expand Up @@ -553,7 +567,9 @@ void SerializerRoundTrip_TreeDepth2() {
"tree2_leaf3"_a = static_cast<LeafOutputType>(3 + 2),
"tree2_leaf4"_a = static_cast<LeafOutputType>(1 + 2),
"tree2_leaf5"_a = static_cast<LeafOutputType>(4 + 2),
"tree2_leaf6"_a = static_cast<LeafOutputType>(2 + 2));
"tree2_leaf6"_a = static_cast<LeafOutputType>(2 + 2),
"threshold_type"_a = TypeInfoToString(TypeInfoFromType<ThresholdType>()),
"leaf_output_type"_a = TypeInfoToString(TypeInfoFromType<LeafOutputType>()));

rapidjson::Document json_dump;
json_dump.Parse(model->DumpAsJSON(false).c_str());
Expand Down
Loading