Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check the shape for mat, csr and csc in prediction #2464

Merged
merged 12 commits into from
Oct 3, 2019

Conversation

guolinke
Copy link
Collaborator

@guolinke guolinke commented Sep 27, 2019

partially improve #812

@guolinke guolinke requested a review from chivee as a code owner September 27, 2019 12:02
@guolinke guolinke changed the title check the shape for mat, csr and csc check the shape for mat, csr and csc in prediction Sep 27, 2019
@guolinke guolinke requested a review from StrikerRUS September 27, 2019 12:08
@StrikerRUS
Copy link
Collaborator

Close-reopen for CI.

@StrikerRUS StrikerRUS closed this Sep 27, 2019
@StrikerRUS StrikerRUS reopened this Sep 27, 2019
Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can something similar be done for predictions for data from a file?

src/c_api.cpp Outdated Show resolved Hide resolved
src/c_api.cpp Outdated Show resolved Hide resolved
src/c_api.cpp Outdated Show resolved Hide resolved
src/c_api.cpp Show resolved Hide resolved
src/c_api.cpp Show resolved Hide resolved
@StrikerRUS
Copy link
Collaborator

Please add a simple test from the original issue:

import numpy as np
import lightgbm as lgb

x_data = np.random.rand(100, 10)
x_bad_data = np.random.rand(100, 11)
y_data =  np.random.rand(100) > .5
self.assertNotEqual(x_data.shape[-1], x_bad_data.shape[-1])
train_dataset = lgb.Dataset(x_data, y_data)
bst = lgb.train({'objective': 'binary'}, train_dataset)
with np.testing.assert_raises_regex(lgb.basic.LightGBMError,
                                    'The number of features in data*'):
    bst.predict(x_bad_data)

@guolinke
Copy link
Collaborator Author

There are more changes than I expected. So I think we should test all cases, including the the mat, libsvm file and CSR format.
@StrikerRUS could you help for the test cases?

@StrikerRUS
Copy link
Collaborator

@guolinke

could you help for the test cases?

Sure!

What do you think about changing the original type to avoid casting? #2464 (comment)

Does this PR already include that check?

@guolinke
Copy link
Collaborator Author

guolinke commented Sep 29, 2019

Does this PR already include that check?

Yeah, as the zero-based and one-based libsvm format have the different number of columns.

@StrikerRUS StrikerRUS force-pushed the predict_shape_check branch from 3dc9c08 to 4b1f0c0 Compare October 2, 2019 22:47
@StrikerRUS
Copy link
Collaborator

StrikerRUS commented Oct 2, 2019

@guolinke I added tests in the latest commit.

As we do not support a case when data for prediction is a list of arrays, I tried to check it for validation. And it seems that validation shape is not covered by this PR, right?

bad_valid_data = train_data.create_valid(bad_X_test, label=y_test)
bst.add_valid(bad_valid_data, "valid_bad")
bst.eval_valid()  # no error risen

@guolinke
Copy link
Collaborator Author

guolinke commented Oct 3, 2019

yeah, validation is not converted for now.

@guolinke
Copy link
Collaborator Author

guolinke commented Oct 3, 2019

due the need of pr #2485, I will merge this

@guolinke guolinke merged commit dee7215 into master Oct 3, 2019
@StrikerRUS StrikerRUS deleted the predict_shape_check branch October 3, 2019 18:27
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants