Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Sep 11, 2023
1 parent 82dcb34 commit 730e86e
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids, const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values, const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
};

template <typename InputType, typename ThresholdType, typename OutputType>
Expand Down Expand Up @@ -277,7 +281,9 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
// New tree.
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position = AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t root_position = AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
nodes_values, nodes_missing_value_tracks_true, updated_mapping,
tree_id, node_tree_ids);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
Expand Down Expand Up @@ -342,7 +348,12 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids, const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values, const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids) {
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const size_t i,
const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,

Check warning on line 352 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:352: Lines should be <= 120 characters long [whitespace/line_length] [2]
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,

Check warning on line 353 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:353: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,

Check warning on line 354 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:354: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,

Check warning on line 355 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:355: Lines should be <= 120 characters long [whitespace/line_length] [2]
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids) {

Check warning on line 356 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:356: Lines should be <= 120 characters long [whitespace/line_length] [2]
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -351,7 +362,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const
if (updated_mapping[i] != 0) {
// In theory we should not accept any cycles, however in practice LGBM conversion implements set membership via a
// series of "Equals" nodes, with the true branches directed at the same child node (a cycle).
// We may instead seek to formalise set membership in the future.
// We may instead seek to formalize set membership in the future.
return updated_mapping[i];
}

Expand All @@ -372,11 +383,14 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(const
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
size_t false_branch = AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t false_branch = AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);

Check warning on line 387 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:387: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", static_cast<int>(nodes_[node_pos].flags));
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);

Check warning on line 393 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:393: Lines should be <= 120 characters long [whitespace/line_length] [2]
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down

0 comments on commit 730e86e

Please sign in to comment.