Make lightgbm work with HyperbandSearchCV #838
Seems like `DaskLGBMRegressor` doesn't implement `partial_fit`, which `HyperbandSearchCV` requires.
Thanks for raising an issue @vecorro. I've transferred this issue over to the `dask-ml` repo. The docs for `HyperbandSearchCV` say the estimator needs to implement `partial_fit`. cc'ing @stsievert.
That's correct, `lightgbm`'s estimators (Dask and non-Dask) don't implement `partial_fit`.
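The requirement is easy to check. A minimal sketch (assuming `lightgbm` with its Dask extras and `scikit-learn` are installed; not from the original thread):

```python
import lightgbm as lgb
from sklearn.linear_model import SGDRegressor

# dask-ml's incremental searches (HyperbandSearchCV, IncrementalSearchCV)
# train candidates by repeatedly calling partial_fit, so the estimator
# must expose that method.
print(hasattr(lgb.LGBMRegressor(), "partial_fit"))      # False
print(hasattr(lgb.DaskLGBMRegressor(), "partial_fit"))  # False
print(hasattr(SGDRegressor(), "partial_fit"))           # True
```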
So maybe it's possible for the non-Dask `lgb.LGBMRegressor` to work here instead. Based on the examples in https://ml.dask.org/xgboost.html and https://ml.dask.org/hyper-parameter-search.html#drop-in-replacements-for-scikit-learn, my understanding is that the hyperparameter tuning stuff in `dask-ml` is meant to parallelize many single-machine training runs, not to tune estimators that are themselves distributed. So even if `DaskLGBMRegressor` implemented `partial_fit`, I don't think it's intended to be used with `HyperbandSearchCV`.
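For context, a minimal sketch of what the `lightgbm.dask` estimators are designed for: training one model on data spread across a cluster (hypothetical example, assuming a local `distributed` cluster; not from the original thread):

```python
from distributed import Client
from dask_ml import datasets
import lightgbm as lgb

if __name__ == "__main__":
    client = Client()  # start/connect to a local Dask cluster
    # Dask arrays: the data may be larger than any single worker's memory
    X, y = datasets.make_classification(chunks=50)
    # DaskLGBMRegressor spreads a single training job over the workers;
    # it is not meant to be cloned and refit by a hyperparameter search
    model = lgb.DaskLGBMRegressor(n_estimators=100)
    model.fit(X, y)
    print(model.predict(X).compute()[:5])
```

The searches in `dask_ml.model_selection`, by contrast, use the cluster to train many in-memory models in parallel.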
Thanks all. Question for @jameslamb: Are you suggesting that I should have used `lgb.LGBMClassifier` instead of `lgb.DaskLGBMClassifier`?
When I moved from incremental hyperparameter optimization with `HyperbandSearchCV` to the drop-in `RandomizedSearchCV`, which calls `fit` on each candidate instead of `partial_fit`, this example ran without error:

```python
from distributed import Client
from dask_ml.model_selection import RandomizedSearchCV
from dask_ml import datasets
import lightgbm as lgb

if __name__ == "__main__":
    X, y = datasets.make_classification(chunks=50)
    model = lgb.LGBMRegressor()
    param_space = {'n_estimators': range(100, 200, 50),
                   'max_depth': range(3, 6, 2)}

    client = Client()
    search = RandomizedSearchCV(model, param_space, n_iter=5)
    search.fit(X, y)
    print(search.best_score_)
```
Where have you seen that claim show up? That should be fixed I think.
That's my understanding too, even for the examples mentioned above.
Thanks @stsievert, this helps. I think I had to read the Dask documentation several times to understand the trade-offs that apply to integrations between Dask and 3rd-party libraries, especially when the dataset is larger than the system memory. I'll use this example you're providing. I'm closing the issue, as it looks like lightgbm is not designed to work in the way I was attempting. Thanks.
I presume you're talking about https://ml.dask.org/hyper-parameter-search.html. Why did you have to read that documentation several times? |
These libraries don't seem to work together. I think that supporting or claiming integration with any new ML library should include support for hyperparameter tuning; that's definitely an MVP.
Here's a code and error dump to back up my point:
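A sketch of the kind of script that produces this dump, reconstructed from the traceback: only the `HyperbandSearchCV(...)` construction and `search.fit(X, y)` appear verbatim below; the imports, data, model, and parameter space are assumptions (the grid is sized to give the 8 candidates the warnings report):

```python
from distributed import Client
from dask_ml.model_selection import HyperbandSearchCV
from dask_ml import datasets
import lightgbm as lgb

if __name__ == "__main__":
    client = Client()
    X, y = datasets.make_classification(chunks=50)
    model = lgb.DaskLGBMRegressor()
    # 4 x 2 = 8 candidate combinations, consistent with the
    # "total space of parameters 8" warnings below
    param_space = {'n_estimators': range(100, 500, 100),
                   'max_depth': range(3, 6, 2)}

    search = HyperbandSearchCV(model, param_space, random_state=0,
                               patience=True, verbose=True, test_size=0.05)
    search.fit(X, y)  # raises: DaskLGBMRegressor has no partial_fit
```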
And the error message:
```text
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=81. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=34. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=15. Running 8 iterations. For exhaustive searches, use GridSearchCV.
  warnings.warn(
[CV, bracket=0] For training there are between 47 and 47 examples in each chunk
[CV, bracket=1] For training there are between 47 and 47 examples in each chunk
AttributeError                            Traceback (most recent call last)
<ipython-input-...> in <module>
10
11 search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
---> 12 search.fit(X, y)
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
715 client = default_client()
716 if not client.asynchronous:
--> 717 return client.sync(self._fit, X, y, **fit_params)
718 return self._fit(X, y, **fit_params)
719
/opt/conda/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
849 return future
850 else:
--> 851 return sync(
852 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
853 )
/opt/conda/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
352 if error[0]:
353 typ, exc, tb = error[0]
--> 354 raise exc.with_traceback(tb)
355 else:
356 return result[0]
/opt/conda/lib/python3.8/site-packages/distributed/utils.py in f()
335 if callback_timeout is not None:
336 future = asyncio.wait_for(future, callback_timeout)
--> 337 result[0] = yield future
338 except Exception as exc:
339 error[0] = sys.exc_info()
/opt/conda/lib/python3.8/site-packages/tornado/gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_hyperband.py in _fit(self, X, y, **fit_params)
399 _brackets_ids = list(reversed(sorted(SHAs)))
400
--> 401 _SHAs = await asyncio.gather(
402 *[SHAs[b]._fit(X, y, **fit_params) for b in _brackets_ids]
403 )
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
661
662 with context:
--> 663 results = await fit(
664 self.estimator,
665 self._get_params(),
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
475 A history of all models scores over time
476 """
--> 477 return await _fit(
478 model,
479 params,
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
266 # async for future, result in seq:
267 for _i in itertools.count():
--> 268 metas = await client.gather(new_scores)
269
270 if log_delay and _i % int(log_delay) == 0:
/opt/conda/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1846 exc = CancelledError(key)
1847 else:
-> 1848 raise exception.with_traceback(traceback)
1849 raise exc
1850 if errors == "skip":
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _partial_fit()
101 if len(X):
102 model = deepcopy(model)
--> 103 model.partial_fit(X, y, **(fit_params or {}))
104
105 meta = dict(meta)
AttributeError: 'DaskLGBMRegressor' object has no attribute 'partial_fit'
```
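For contrast, a sketch of the pattern `HyperbandSearchCV` is designed for, adapted from the dask-ml docs: an estimator with `partial_fit`, such as scikit-learn's `SGDClassifier` (the parameter values here are illustrative, not from this thread):

```python
import numpy as np
from distributed import Client
from dask_ml import datasets
from dask_ml.model_selection import HyperbandSearchCV
from sklearn.linear_model import SGDClassifier

if __name__ == "__main__":
    client = Client()
    X, y = datasets.make_classification(chunks=50)
    # SGDClassifier implements partial_fit, so Hyperband can train each
    # candidate incrementally and drop low scorers early
    model = SGDClassifier(tol=1e-3, penalty="elasticnet")
    param_space = {"alpha": np.logspace(-4, 0, num=10),
                   "l1_ratio": np.linspace(0, 1, num=10)}

    search = HyperbandSearchCV(model, param_space, max_iter=9, random_state=0)
    # classes= is forwarded to partial_fit, which needs it on the first call
    search.fit(X, y, classes=[0, 1])
    print(search.best_score_)
```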