[RFC] Unifying prediction API. #6632
Comments
Have you managed to figure this out?
Yup, I wrote a test for predict leaf.
So what do you think of the RFC?
@trivialfis Is this RFC connected to #6091?
No, this is just for the predict function; it does not affect feature importance.
Would these proposed functions be thread safe, or is that argument dependent?
The inplace prediction function is thread safe and lock free. Normal prediction is currently argument dependent, but it should not be too difficult to make it fully thread safe.
Ok. I'm in favour of making it completely thread safe, as that's one of the frustrating things about the current XGBoost4J API. It would be nice to get rid of …. If it's possible to help with that without having to understand all the internals of the library, then I'm happy to help, but the last time I looked, figuring out the thread safety of a particular bit of code required understanding a lot of the internals.
Background
XGBoost exposes a number of prediction functions through the C API and the various language bindings, including prediction on a DMatrix and inplace prediction. Within these prediction functions there are also several prediction types: `value`, `margin`, `leaf`, `contribs` and `interaction`. Their outputs have different meanings and shapes. Right now the language bindings are responsible for figuring that out, which has become a burden since we introduced the dask interface on top of Python (#6614). The output shape is also quite complicated; I have had a difficult time figuring out how to slice up the output array from `pred_leaf`. Aside from these, there are also different prediction parameters, including `ntree_limit`, `n_layers`, `is_training`, and a never-used parameter `tree_begin`. Lastly, if the prediction is carried out inplace, some more information such as `missing` and `base_margin` needs to be passed into the implementation.
Requirements
We should unify the prediction functions of the C API in a consistent manner. The new prediction functions should be able to figure out the output shape for the language bindings, and should be extensible to future feature additions. At the same time, we need to identify the parameters we want to drop, such as `ntree_limit`. Since this is being designed at the C API level, we should follow common C API design practices. Also, we are not near the next major release (2.0), so the old API should be kept for compatibility.
Proposal
The functions should write the correct shape to the `out_shape` parameter, and `PredictParam` will be responsible for future extensibility. Additionally, we can incorporate more information into the input and output, such as the device ordinal and data slicing. This RFC is to decide whether we should carry out this refactor.
Brief notes
Some more notes on the prediction function:
@hcho3 @RAMitchell