-
Notifications
You must be signed in to change notification settings - Fork 3.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] Keep row names in output from predict
#4977
Conversation
Failing CI checks on mac are due to a dependency that fails to install and not related to the changes in this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much for this proposal.
I don't support adding this feature. I'm not convinced that the (in my opinion) slight convenience is worth the additional code complexity and maintenance burden.
I also want to be especially protective of predict()
code paths, since I expect user code creating predictions to be much more latency-sensitive than training code. (related conversation: #4909 (review)).
We can get some other opinions though. @mayer79 @StrikerRUS @Laurae2 if any of you are interested, could you give your opinion on this proposed change?
The latency added by this feature is around 30 microseconds in my setup (~40 when objects have names), which is very low compared to the time it takes to make a single-row prediction which can reach up to a few miliseconds depending on the choice of hyperparameters and arguments (and lower than the overhead added from syntax-related coding choices such as usage of R6). |
I'm sorry, I very far away from the R-lang world, so can't write thoughtful comment here. If keeping row names is something like a standard behavior for other R packages, then we can add this feature. If not, then I agree with @jameslamb . |
Row names are kept in base R and most recommended/core packages for decision trees (e.g. Did a bit of testing with popular non-core packages, and found that many decision-tree packages do not keep them (e.g. OTOH tried a few packages for linear models off the top of my mind, and all of them kept the row names ( |
Sorry for the delayed response, I recently returned from traveling. Thanks for checking other packages! Since projects like However, I'd only feel confident taking on this behavior if unit tests were added confirming that row names are handled correctly for every combination of the following:
If you have time and interest in adding such tests, I'd support this PR. If not, then I think this PR should be closed and we could convert it to a feature request to be added to #2302 and picked up at some point in the future. |
Added more tests, but left out CSV tests because predicting on them is currently broken. |
The failing checks are due to the vignette. It fails to make a prediction at L75: summary(predict(fit, X))
Probably coming from here: LightGBM/R-package/R/lgb.Predictor.R Line 151 in 6b56a90
It is not related to this PR since the error happens before reaching the code that was changed. |
Nevermind this comment, there was another change I didn't notice in the merge conflict. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much for adding tests and documenting the CSV issue in #5093.
The tests made me realize that for input data with row names, the value of reshape
isn't sufficient to determine whether or not the output should have row names.
reshape
's value only matters for multi-class classification...for regression and binary classification tasks, the number of rows in the output from predict()
should be the same as the number of rows in newdata
, so for those objectives row names should be preserved regardless of the value of reshape
.
Could you please do the following:
- move the code you've added into
predict.lgb.Booster()
intoPredictor$predict()
- ensure that for binary classification and regression, row names are always preserved
For reference, here's the relevant code in Predictor$predict()
that is figuring out whether or not the predictions have the same number of elements as the input data has rows.
LightGBM/R-package/R/lgb.Predictor.R
Lines 215 to 231 in 17d4e00
# Get number of cases per row | |
npred_per_case <- length(preds) / num_row | |
# Data reshaping | |
if (predleaf | predcontrib) { | |
# Predict leaves only, reshaping is mandatory | |
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) | |
} else if (reshape && npred_per_case > 1L) { | |
# Predict with data reshaping | |
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE) | |
} |
@@ -111,3 +113,117 @@ test_that("start_iteration works correctly", { | |||
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE) | |||
expect_equal(pred_leaf1, pred_leaf2) | |||
}) | |||
|
|||
test_that("predictions keep row names from the data", { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these tests look great, thank you! To speed up debugging if they break due to changes in future PRs, can you please break them into 3 test cases?
- "predict() keeps row names from data (binary classification)"
- "predict() keeps row names from data (multi-class classification)"
- "predict() keeps row names from data (regression)"
To avoid duplicating test helper code and to make it clear that those methods were written just for these tests, please move them out of the test case and rename them to .expect_has_row_names()
and .expect_doesnt_have_row_names()
.
Updated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks very much for the changes! But please see my suggestion about simplifying the test code. In its current state, the indirection with .expect_row_names_kept()
makes it difficult to understand what is being tested.
if (!multiclass || NROW(X) == NROW(pred)) { | ||
.expect_has_row_names(pred, X) | ||
} else { | ||
.expect_doesnt_have_row_names(pred) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm really struggling to understand these test cases, since it requires reasoning about what was passed in to .expect_row_names_kept()
from within each of the test_that(...)
calls, and since these if
statements rely on the fact that pred
here happens to reference the predictions produced with reshape = FALSE
and none of rawscore
/ predleaf
/ predcontrib
.
To make this easier to follow (both for reviewing and for future contributors), can you please remove the use of this .expect_row_names_kept()
function and instead explicitly include checks in each of the test_that(...)
calls? Those checks should not need to use NROW(pred)
at all. For test code, I'd prefer some code duplication in exchange for it being easier to understand what is being tested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what you want me to do there. That pred
does reference predictions with different combinations of reshape
/ rawscore
/ predcontrib
/ predleaf
(all predictions generated there are named pred
).
You asked me to add tests for different values of reshape
, but reshape
might get ignored depending on the rest of the parameters, and thus whether the output is meant to have row names or not, depends on the shape of what's actually returned in the end by the prediction function. If I were to adjust the tests by hard-coding the reshape logic in them, they then would test both the row names and the reshape logic together, which I don't think is ideal. The most logical way of testing the row names in a way that's independent of reshape
would be by checking the number of rows in the data and the predictions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
they then would test both the row names and the reshape logic together, which I don't think is ideal
"hard-coding the reshape logic", i.e. not changing the test expectation based on a condition like NROW(X) == NROW(pred)
, is exactly what I'm asking for.
We know, for example, that for reshape = FALSE
and a multi-class objective, the output of predict()
is expected to have NROW(X) * number_of_classes
elements, and therefore shouldn't have row names. It's desirable for the tests to break if that changes.
As a maintainer on this project, I believe I'm able to push commits onto branches from forks associated with open PRs. Would you like me to just directly push the testing changes I'm proposing to this branch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, would be better if you modify the tests yourself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok no prob! I'll push a commit in the next day or two.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice change, thanks for proposing it and for helping by adding such comprehensive tests!
@jmoralez could you review this whenever you have time? I think it's good to be merged, but since I pushed some commits to this PR, I don't want to click "merge" without another review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could the testing logic
# dense matrix with row names
pred <- predict(bst, X)
.expect_has_row_names(pred, X)
pred <- predict(bst, X, rawscore = TRUE)
.expect_has_row_names(pred, X)
pred <- predict(bst, X, predleaf = TRUE)
.expect_has_row_names(pred, X)
pred <- predict(bst, X, predcontrib = TRUE)
.expect_has_row_names(pred, X)
# dense matrix without row names
Xcopy <- X
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy)
.expect_doesnt_have_row_names(pred)
# sparse matrix with row names
Xcsc <- as(X, "CsparseMatrix")
pred <- predict(bst, Xcsc)
.expect_has_row_names(pred, Xcsc)
pred <- predict(bst, Xcsc, rawscore = TRUE)
.expect_has_row_names(pred, Xcsc)
pred <- predict(bst, Xcsc, predleaf = TRUE)
.expect_has_row_names(pred, Xcsc)
pred <- predict(bst, Xcsc, predcontrib = TRUE)
.expect_has_row_names(pred, Xcsc)
# sparse matrix without row names
Xcopy <- Xcsc
row.names(Xcopy) <- NULL
pred <- predict(bst, Xcopy)
.expect_doesnt_have_row_names(pred)
be extracted into a function?
Well, it was in a function before... |
@jmoralez I specifically wanted to NOT do that (#4977 (comment)) to make it more explicit exactly what the expectation is for each of these settings. For this many combinations of possibilities, I think having each test be self-contained is preferable to the indirection introduced by putting that logic into a function. |
I think with the |
Fair enough, you're right that with I'll push a commit right now moving it back into a function. I don't want to hold up this PR just for a style consideration like this that we can change in the future. |
This pull request has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this. |
ref #4968
Calling
predict
on a lightgbm model object will produce a result without row names. If the data on which predictions are being made has row names, chances are that the user will want to keep those in the prediction output. This PR keeps the row names in the prediction output when the input data has them.