
[BUG] Getting error from XGB model when loading the model back and passing the booster arg to the constructor #651

Closed
rnyak opened this issue Aug 15, 2022 · 4 comments
Labels
bug Something isn't working P1

@rnyak
Contributor

rnyak commented Aug 15, 2022

Bug description

I get the following error when I load back the saved XGB model and then pass the booster argument to the constructor via XGBoost(schema, booster=reloaded_model):

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [10], in <cell line: 7>()
      4 model.booster.save_model('mymodel.xgb')
      5 bst.load_model('mymodel.xgb')  # load model
----> 7 XGBoost(schema, booster=bst)

File /usr/local/lib/python3.8/dist-packages/merlin/models/xgb/__init__.py:66, in XGBoost.__init__(self, schema, target_columns, qid_column, objective, booster, **params)
     64 if isinstance(target_columns, str):
     65     target_columns = [target_columns]
---> 66 self.target_columns = target_columns or get_targets(schema, target_tag)
     67 self.feature_columns = get_features(schema, self.target_columns)
     69 if objective.startswith("rank") and qid_column is None:

File /usr/local/lib/python3.8/dist-packages/merlin/models/xgb/__init__.py:248, in get_targets(schema, target_tag)
    246 if len(targets) >= 1:
    247     return targets.column_names
--> 248 raise ValueError(
    249     f"No target columns in the dataset schema with tags TARGET and {target_tag.name}"
    250 )

ValueError: No target columns in the dataset schema with tags TARGET and REGRESSION

Steps/Code to reproduce bug

You can reproduce the error by running the code below:


from merlin.core.utils import Distributed
from merlin.datasets.entertainment import get_movielens
from merlin.models.xgb import XGBoost
import xgboost as xgb

train, valid = get_movielens(variant='ml-100k')
# remove cols from schema
schema = train.schema.without(['title', 'rating'])
xgb_booster_params = {
    'objective':'binary:logistic',
    'tree_method':'gpu_hist',
}

xgb_train_params = {
    'num_boost_round': 100,
    'verbose_eval': 20,
    'early_stopping_rounds': 10,
}

with Distributed():
    model = XGBoost(schema, **xgb_booster_params)
    model.fit(
        train,
        evals=[(valid, 'validation_set'),],
        **xgb_train_params
    )
    metrics = model.evaluate(valid)


bst = xgb.Booster()  # init an empty booster
model.booster.save_model('mymodel.xgb')  # save the trained model
bst.load_model('mymodel.xgb')  # load it back

XGBoost(schema, booster=bst)

Expected behavior

I should be able to load back the saved XGB model and do offline inference with model.predict()

Environment details

  • Merlin version: merlin-tensorflow:22.07 (with the latest main branches pulled)
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):

Additional context

@rnyak rnyak added bug Something isn't working status/needs-triage labels Aug 15, 2022
@rnyak rnyak added the P1 label Aug 15, 2022
@rnyak
Contributor Author

rnyak commented Aug 15, 2022

@radekosmulski fyi.

@radekosmulski
Contributor

Would be great to have this in 22.08, but not sure that is still feasible 🙂

@oliverholworthy
Member

For this reloading to work this way, the booster params need to be passed through as well on the last line:

XGBoost(schema, booster=bst, **xgb_booster_params)

@oliverholworthy
Member

oliverholworthy commented Aug 17, 2022

Closing this as resolved by the above comment. Alternatively, use the new save/load methods from #656.
