11/* *
2- * Copyright 2023 by XGBoost Contributors
2+ * Copyright 2023-2025, XGBoost Contributors
33 */
44#include " xgboost/multi_target_tree_model.h"
55
6- #include < algorithm> // for copy_n
7- #include < cstddef> // for size_t
8- #include < cstdint> // for int32_t, uint8_t
9- #include < limits> // for numeric_limits
10- #include < string_view> // for string_view
11- #include < utility> // for move
12- #include < vector> // for vector
13-
14- #include " io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15- #include " xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16- #include " xgboost/json.h" // for Json, get, Object, Number, Integer, ...
6+ #include < algorithm> // for copy_n
7+ #include < cstddef> // for size_t
8+ #include < cstdint> // for int32_t, uint8_t
9+ #include < limits> // for numeric_limits
10+ #include < string_view> // for string_view
11+ #include < utility> // for move
12+ #include < vector> // for vector
13+
14+ #include " io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
15+ #include " xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
16+ #include " xgboost/json.h" // for Json, get, Object, Number, Integer, ...
1717#include " xgboost/logging.h"
1818#include " xgboost/tree_model.h" // for TreeParam
1919
@@ -30,27 +30,47 @@ MultiTargetTree::MultiTargetTree(TreeParam const* param)
3030 CHECK_GT (param_->size_leaf_vector , 1 );
3131}
3232
33+ MultiTargetTree::MultiTargetTree (MultiTargetTree const & that)
34+ : param_{that.param_ },
35+ left_ (that.left_.Size(), 0 , that.left_.Device()),
36+ right_ (that.right_.Size(), 0 , that.right_.Device()),
37+ parent_ (that.parent_.Size(), 0 , that.parent_.Device()),
38+ split_index_ (that.split_index_.Size(), 0 , that.split_index_.Device()),
39+ default_left_ (that.default_left_.Size(), 0 , that.default_left_.Device()),
40+ split_conds_ (that.split_conds_.Size(), 0 , that.split_conds_.Device()),
41+ weights_ (that.weights_.Size(), 0 , that.weights_.Device()) {
42+ this ->left_ .Copy (that.left_ );
43+ this ->right_ .Copy (that.right_ );
44+ this ->parent_ .Copy (that.parent_ );
45+ this ->split_index_ .Copy (that.split_index_ );
46+ this ->default_left_ .Copy (that.default_left_ );
47+ this ->split_conds_ .Copy (that.split_conds_ );
48+ this ->weights_ .Copy (that.weights_ );
49+ }
50+
3351template <bool typed, bool feature_is_64>
34- void LoadModelImpl (Json const & in, std::vector<float >* p_weights, std::vector<bst_node_t >* p_lefts,
35- std::vector<bst_node_t >* p_rights, std::vector<bst_node_t >* p_parents,
36- std::vector<float >* p_conds, std::vector<bst_feature_t >* p_fidx,
37- std::vector<std::uint8_t >* p_dft_left) {
52+ void LoadModelImpl (Json const & in, HostDeviceVector<float >* p_weights,
53+ HostDeviceVector<bst_node_t >* p_lefts, HostDeviceVector<bst_node_t >* p_rights,
54+ HostDeviceVector<bst_node_t >* p_parents, HostDeviceVector<float >* p_conds,
55+ HostDeviceVector<bst_feature_t >* p_fidx,
56+ HostDeviceVector<std::uint8_t >* p_dft_left) {
3857 namespace tf = tree_field;
3958
40- auto get_float = [&](std::string_view name, std::vector <float >* p_out) {
59+ auto get_float = [&](std::string_view name, HostDeviceVector <float >* p_out) {
4160 auto & values = get<FloatArrayT<typed>>(get<Object const >(in).find (name)->second );
4261 auto & out = *p_out;
43- out.resize (values.size ());
62+ out.Resize (values.size ());
63+ auto & h_out = out.HostVector ();
4464 for (std::size_t i = 0 ; i < values.size (); ++i) {
45- out [i] = GetElem<Number>(values, i);
65+ h_out [i] = GetElem<Number>(values, i);
4666 }
4767 };
4868 get_float (tf::kBaseWeight , p_weights);
4969 get_float (tf::kSplitCond , p_conds);
5070
51- auto get_nidx = [&](std::string_view name, std::vector <bst_node_t >* p_nidx) {
71+ auto get_nidx = [&](std::string_view name, HostDeviceVector <bst_node_t >* p_nidx) {
5272 auto & nidx = get<I32ArrayT<typed>>(get<Object const >(in).find (name)->second );
53- auto & out_nidx = * p_nidx;
73+ auto & out_nidx = p_nidx-> HostVector () ;
5474 out_nidx.resize (nidx.size ());
5575 for (std::size_t i = 0 ; i < nidx.size (); ++i) {
5676 out_nidx[i] = GetElem<Integer>(nidx, i);
@@ -61,15 +81,15 @@ void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bs
6181 get_nidx (tf::kParent , p_parents);
6282
6383 auto const & splits = get<IndexArrayT<typed, feature_is_64> const >(in[tf::kSplitIdx ]);
64- p_fidx->resize (splits.size ());
65- auto & out_fidx = * p_fidx;
84+ p_fidx->Resize (splits.size ());
85+ auto & out_fidx = p_fidx-> HostVector () ;
6686 for (std::size_t i = 0 ; i < splits.size (); ++i) {
6787 out_fidx[i] = GetElem<Integer>(splits, i);
6888 }
6989
7090 auto const & dft_left = get<U8ArrayT<typed> const >(in[tf::kDftLeft ]);
71- auto & out_dft_l = * p_dft_left;
72- out_dft_l. resize (dft_left. size () );
91+ p_dft_left-> Resize (dft_left. size ()) ;
92+ auto & out_dft_l = p_dft_left-> HostVector ( );
7393 for (std::size_t i = 0 ; i < dft_left.size (); ++i) {
7494 out_dft_l[i] = GetElem<Boolean>(dft_left, i);
7595 }
@@ -109,19 +129,25 @@ void MultiTargetTree::SaveModel(Json* p_out) const {
109129 U8Array default_left (n_nodes);
110130 F32Array weights (n_nodes * this ->NumTarget ());
111131
132+ auto const & h_left = this ->left_ .ConstHostVector ();
133+ auto const & h_right = this ->right_ .ConstHostVector ();
134+ auto const & h_parent = this ->parent_ .ConstHostVector ();
135+ auto const & h_split_index = this ->split_index_ .ConstHostVector ();
136+ auto const & h_split_conds = this ->split_conds_ .ConstHostVector ();
137+ auto const & h_default_left = this ->default_left_ .ConstHostVector ();
112138 auto save_tree = [&](auto * p_indices_array) {
113139 auto & indices_array = *p_indices_array;
114140 for (bst_node_t nidx = 0 ; nidx < n_nodes; ++nidx) {
115- CHECK_LT (nidx, left_.size ());
116- lefts.Set (nidx, left_ [nidx]);
117- CHECK_LT (nidx, right_.size ());
118- rights.Set (nidx, right_ [nidx]);
119- CHECK_LT (nidx, parent_.size ());
120- parents.Set (nidx, parent_ [nidx]);
121- CHECK_LT (nidx, split_index_.size ());
122- indices_array.Set (nidx, split_index_ [nidx]);
123- conds.Set (nidx, split_conds_ [nidx]);
124- default_left.Set (nidx, default_left_ [nidx]);
141+ CHECK_LT (nidx, left_.Size ());
142+ lefts.Set (nidx, h_left [nidx]);
143+ CHECK_LT (nidx, right_.Size ());
144+ rights.Set (nidx, h_right [nidx]);
145+ CHECK_LT (nidx, parent_.Size ());
146+ parents.Set (nidx, h_parent [nidx]);
147+ CHECK_LT (nidx, split_index_.Size ());
148+ indices_array.Set (nidx, h_split_index [nidx]);
149+ conds.Set (nidx, h_split_conds [nidx]);
150+ default_left.Set (nidx, h_default_left [nidx]);
125151
126152 auto in_weight = this ->NodeWeight (nidx);
127153 auto weight_out = common::Span<float >(weights.GetArray ())
@@ -157,8 +183,8 @@ void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> w
157183 CHECK (this ->IsLeaf (nidx)) << " Collapsing a split node to leaf " << MTNotImplemented ();
158184 auto const next_nidx = nidx + 1 ;
159185 CHECK_EQ (weight.Size (), this ->NumTarget ());
160- CHECK_GE (weights_.size (), next_nidx * weight.Size ());
161- auto out_weight = common::Span< float >( weights_).subspan (nidx * weight.Size (), weight.Size ());
186+ CHECK_GE (weights_.Size (), next_nidx * weight.Size ());
187+ auto out_weight = weights_. HostSpan ( ).subspan (nidx * weight.Size (), weight.Size ());
162188 for (std::size_t i = 0 ; i < weight.Size (); ++i) {
163189 out_weight[i] = weight (i);
164190 }
@@ -169,39 +195,40 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
169195 linalg::VectorView<float const > left_weight,
170196 linalg::VectorView<float const > right_weight) {
171197 CHECK (this ->IsLeaf (nidx));
172- CHECK_GE (parent_.size (), 1 );
173- CHECK_EQ (parent_.size (), left_.size ());
174- CHECK_EQ (left_.size (), right_.size ());
198+ CHECK_GE (parent_.Size (), 1 );
199+ CHECK_EQ (parent_.Size (), left_.Size ());
200+ CHECK_EQ (left_.Size (), right_.Size ());
175201
176202 std::size_t n = param_->num_nodes + 2 ;
177203 CHECK_LT (split_idx, this ->param_ ->num_feature );
178- left_.resize (n, InvalidNodeId ());
179- right_.resize (n, InvalidNodeId ());
180- parent_.resize (n, InvalidNodeId ());
204+ left_.Resize (n, InvalidNodeId ());
205+ right_.Resize (n, InvalidNodeId ());
206+ parent_.Resize (n, InvalidNodeId ());
181207
182- auto left_child = parent_.size () - 2 ;
183- auto right_child = parent_.size () - 1 ;
208+ auto left_child = parent_.Size () - 2 ;
209+ auto right_child = parent_.Size () - 1 ;
184210
185- left_[nidx] = left_child;
186- right_[nidx] = right_child;
211+ left_. HostVector () [nidx] = left_child;
212+ right_. HostVector () [nidx] = right_child;
187213
214+ auto & h_parent = parent_.HostVector ();
188215 if (nidx != 0 ) {
189- CHECK_NE (parent_ [nidx], InvalidNodeId ());
216+ CHECK_NE (h_parent [nidx], InvalidNodeId ());
190217 }
191218
192- parent_ [left_child] = nidx;
193- parent_ [right_child] = nidx;
219+ h_parent [left_child] = nidx;
220+ h_parent [right_child] = nidx;
194221
195- split_index_.resize (n);
196- split_index_[nidx] = split_idx;
222+ split_index_.Resize (n);
223+ split_index_. HostVector () [nidx] = split_idx;
197224
198- split_conds_.resize (n, std::numeric_limits<float >::quiet_NaN ());
199- split_conds_[nidx] = split_cond;
225+ split_conds_.Resize (n, std::numeric_limits<float >::quiet_NaN ());
226+ split_conds_. HostVector () [nidx] = split_cond;
200227
201- default_left_.resize (n);
202- default_left_[nidx] = static_cast <std::uint8_t >(default_left);
228+ default_left_.Resize (n);
229+ default_left_. HostVector () [nidx] = static_cast <std::uint8_t >(default_left);
203230
204- weights_.resize (n * this ->NumTarget ());
231+ weights_.Resize (n * this ->NumTarget ());
205232 auto p_weight = this ->NodeWeight (nidx);
206233 CHECK_EQ (p_weight.Size (), base_weight.Size ());
207234 auto l_weight = this ->NodeWeight (left_child);
@@ -217,5 +244,5 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
217244}
218245
219246bst_target_t MultiTargetTree::NumTarget () const { return param_->size_leaf_vector ; }
220- std::size_t MultiTargetTree::Size () const { return parent_.size (); }
247+ std::size_t MultiTargetTree::Size () const { return parent_.Size (); }
221248} // namespace xgboost
0 commit comments