Skip to content

Commit

Permalink
PySpark XGBoost integration (#8020)
Browse files Browse the repository at this point in the history
Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
  • Loading branch information
3 people authored Jul 13, 2022
1 parent 8959622 commit 176fec8
Show file tree
Hide file tree
Showing 25 changed files with 3,650 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ jobs:
- name: Install Python packages
run: |
python -m pip install wheel setuptools
python -m pip install pylint cpplint numpy scipy scikit-learn
python -m pip install pylint cpplint numpy scipy scikit-learn pyspark pandas cloudpickle
- name: Run lint
run: |
make lint
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
python-tests-on-macos:
name: Test XGBoost Python package on ${{ matrix.config.os }}
runs-on: ${{ matrix.config.os }}
timeout-minutes: 90
strategy:
matrix:
config:
Expand Down
3 changes: 2 additions & 1 deletion python-package/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ def run(self) -> None:
'scikit-learn': ['scikit-learn'],
'dask': ['dask', 'pandas', 'distributed'],
'datatable': ['datatable'],
'plotting': ['graphviz', 'matplotlib']
'plotting': ['graphviz', 'matplotlib'],
"pyspark": ["pyspark", "scikit-learn", "cloudpickle"],
},
maintainer='Hyunsu Cho',
maintainer_email='chohyu01@cs.washington.edu',
Expand Down
22 changes: 22 additions & 0 deletions python-package/xgboost/spark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# type: ignore
"""PySpark XGBoost integration interface
"""

try:
import pyspark
except ImportError as e:
raise ImportError("pyspark package needs to be installed to use this module") from e

from .estimator import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRegressor,
SparkXGBRegressorModel,
)

__all__ = [
"SparkXGBClassifier",
"SparkXGBClassifierModel",
"SparkXGBRegressor",
"SparkXGBRegressorModel",
]
Loading

0 comments on commit 176fec8

Please sign in to comment.