[WIP][PySpark] Add XGBoost PySpark API support #7709
Closed
Commits
bb17dc9 [PySpark] Add XGBoost PySpark API support (wbo4958)
c4d8f60 resolve lint error (wbo4958)
f541f0d add integration tests (wbo4958)
3a0eef1 add readme for integration tests (wbo4958)
473856f add regressor and fix bug (wbo4958)
314eb22 update the tests (wbo4958)
89dda8f fix couldn't find XGBoostClassifier issue (wbo4958)
7c1496f add tutorial for xgboost-pyspark (wbo4958)
bb07127 move pyspark doc to python packages (wbo4958)
92711c3 Formats. (trivialfis)
ca8a547 Sphinx (trivialfis)
a7cffcc Some pylint fixes. (trivialfis)
46e4853 Export documents. (trivialfis)
8a145b1 Extract iris example, generate module on the fly. (trivialfis)
e1f6c69 typing. (trivialfis)
5a2ef75 Fix CI script. (trivialfis)
5651246 Warning. (trivialfis)
5b58181 Pylint errors. (trivialfis)
5c0c1fc Action. (trivialfis)
9be06b4 Move. (trivialfis)
f14730a add more parameters (wbo4958)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# XGBoost4J PySpark API Integration Tests | ||
|
||
This integration test framework is modeled on [NVIDIA/spark-rapids/integration_tests](https://github.com/NVIDIA/spark-rapids/tree/branch-22.04/integration_tests). | ||
|
||
## Setting Up the Environment | ||
|
||
The tests use `pyspark` and `pytest` on Python 3. Only a small number of Python | ||
dependencies are needed, and they only have to be installed on the driver. You can | ||
install them on every node in the cluster, but that is not required. | ||
|
||
- install python dependencies | ||
|
||
``` bash | ||
pip install pytest numpy scipy | ||
``` | ||
|
||
- install xgboost python package | ||
|
||
The XGBoost4J PySpark APIs live in the xgboost Python package, so install it first: | ||
|
||
``` bash | ||
cd xgboost/python-package | ||
python setup.py install | ||
``` | ||
|
||
- compile xgboost jvm packages | ||
|
||
``` bash | ||
cd xgboost/jvm-packages | ||
mvn -Dmaven.test.skip=true -DskipTests clean package | ||
``` | ||
|
||
- run integration tests | ||
|
||
```bash | ||
cd xgboost/jvm-packages/integration-tests | ||
./run_pyspark_from_build.sh | ||
``` |
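Before running the suite, it can help to confirm that the driver-side Python dependencies from the steps above are importable. The following is a minimal sketch, not part of this PR: the module list is an assumption drawn from the README's install steps (plus `pyspark`), and `missing_modules` is a hypothetical helper name.

```python
import importlib.util

# Assumed list of driver-side modules, taken from the README's install steps.
REQUIRED_MODULES = ["pytest", "numpy", "scipy", "pyspark", "xgboost"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing Python packages: " + ", ".join(missing))
    else:
        print("All driver-side Python dependencies were found.")
```

The check only looks at the driver's interpreter, which matches the README's note that the dependencies are not required on the worker nodes.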
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
def pytest_addoption(parser): | ||
    """Pytest hook to define command line options for pytest""" | ||
    parser.addoption( | ||
        "--platform", action="store", default="cpu", help="optional values [ cpu, gpu ]" | ||
    ) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
; Copyright (c) 2022 by Contributors | ||
; | ||
; Licensed under the Apache License, Version 2.0 (the "License"); | ||
; you may not use this file except in compliance with the License. | ||
; You may obtain a copy of the License at | ||
; | ||
; http://www.apache.org/licenses/LICENSE-2.0 | ||
; | ||
; Unless required by applicable law or agreed to in writing, software | ||
; distributed under the License is distributed on an "AS IS" BASIS, | ||
; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
; See the License for the specific language governing permissions and | ||
; limitations under the License. | ||
|
||
[pytest] | ||
markers = | ||
    skip_by_platform(platform): skip test for the given platform |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def platform(request): | ||
    return request.config.getoption('platform') | ||
|
||
|
||
# https://stackoverflow.com/questions/28179026/how-to-skip-a-pytest-using-an-external-fixture | ||
@pytest.fixture(autouse=True) | ||
def skip_by_platform(request, platform): | ||
    marker = request.node.get_closest_marker('skip_by_platform') | ||
    if marker and marker.args[0] == platform: | ||
        pytest.skip('skipped on this platform: {}'.format(platform)) |
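The skip decision made by the autouse fixture can be illustrated without pytest. This is a minimal sketch under stated assumptions: `should_skip` is a hypothetical helper, and `marker_args` stands in for the `args` tuple of the marker that `pytest.ini` registers as `skip_by_platform` (with `None` meaning the test carries no such marker).

```python
def should_skip(marker_args, platform):
    """Return True when a test marked for `marker_args[0]` runs on that platform.

    `marker_args` mimics `marker.args`; None means the test has no marker.
    """
    if marker_args is None:
        return False
    return marker_args[0] == platform


# In a real test module the marker would be applied roughly like this
# (shown as a comment because it needs pytest and the conftest fixtures):
#
#   @pytest.mark.skip_by_platform('gpu')
#   def test_cpu_only_behaviour():
#       ...
#
# and the suite would be started with `pytest --platform gpu` or `--platform cpu`.

if __name__ == "__main__":
    for args, platform in [(None, "cpu"), (("gpu",), "gpu"), (("gpu",), "cpu")]:
        print(args, platform, should_skip(args, platform))
```

Only the first marker argument is compared, so a test can be excluded from one platform at a time with this scheme.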
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from xgboost.spark import XGBoostClassifier | ||
|
||
|
||
def test_xgboost_parameters_from_dictionary(): | ||
    xgb_params = {'objective': 'multi:softprob', | ||
                  'treeMethod': 'hist', | ||
                  'numWorkers': 1, | ||
                  'labelCol': 'classIndex', | ||
                  'featuresCol': 'features', | ||
                  'numRound': 100, | ||
                  'numClass': 3} | ||
    xgb = XGBoostClassifier(**xgb_params) | ||
    assert xgb.getObjective() == 'multi:softprob' | ||
    assert xgb.getTreeMethod() == 'hist' | ||
    assert xgb.getNumWorkers() == 1 | ||
    assert xgb.getLabelCol() == 'classIndex' | ||
    assert xgb.getFeaturesCol() == 'features' | ||
    assert xgb.getNumRound() == 100 | ||
    assert xgb.getNumClass() == 3 | ||
|
||
|
||
def test_xgboost_set_parameter(): | ||
    xgb = XGBoostClassifier() | ||
    xgb.setObjective('multi:softprob') | ||
    xgb.setTreeMethod('hist') | ||
    xgb.setNumWorkers(1) | ||
    xgb.setLabelCol('classIndex') | ||
    xgb.setFeaturesCol('features') | ||
    xgb.setNumRound(100) | ||
    xgb.setNumClass(3) | ||
    assert xgb.getObjective() == 'multi:softprob' | ||
    assert xgb.getTreeMethod() == 'hist' | ||
    assert xgb.getNumWorkers() == 1 | ||
    assert xgb.getLabelCol() == 'classIndex' | ||
    assert xgb.getFeaturesCol() == 'features' | ||
    assert xgb.getNumRound() == 100 | ||
    assert xgb.getNumClass() == 3 |
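Both tests above repeat the same seven assertions. A table-driven check is one way to express the naming contract they rely on, namely that a camelCase parameter such as `numRound` is read back through `getNumRound()`. The sketch below is illustrative only: `FakeEstimator` is a hypothetical stand-in so the example runs without Spark, and `getter_name` and `check_params` are assumed helper names that do not appear in the PR.

```python
def getter_name(param):
    """Map a camelCase parameter name to the getter the tests call."""
    return "get" + param[0].upper() + param[1:]


class FakeEstimator:
    """Hypothetical stand-in for XGBoostClassifier; it only stores keyword params."""

    def __init__(self, **params):
        self._params = dict(params)

    def __getattr__(self, name):
        # Resolve getFoo() to the stored 'foo' parameter.
        if name.startswith("get") and len(name) > 3:
            key = name[3].lower() + name[4:]
            if key in self._params:
                return lambda: self._params[key]
        raise AttributeError(name)


def check_params(estimator, expected):
    """Assert that every expected parameter is reported back by its getter."""
    for param, value in expected.items():
        assert getattr(estimator, getter_name(param))() == value, param


if __name__ == "__main__":
    expected = {"objective": "multi:softprob", "treeMethod": "hist", "numRound": 100}
    check_params(FakeEstimator(**expected), expected)
    print("all getters returned the expected values")
```

With the real estimator, `check_params(XGBoostClassifier(**xgb_params), xgb_params)` would replace the block of individual assertions, assuming every parameter follows the same getter naming.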
96 changes: 96 additions & 0 deletions
jvm-packages/integration-tests/python/spark_init_internal.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import os | ||
|
||
try: | ||
    import pyspark | ||
except ImportError as error: | ||
    import findspark | ||
    findspark.init() | ||
    import pyspark | ||
|
||
_DRIVER_ENV = 'PYSP_TEST_spark_driver_extraJavaOptions' | ||
|
||
def _spark__init(): | ||
    # Build the session from the PYSP_TEST_* environment variables. | ||
    # DO NOT SET ANY OTHER CONFIGS HERE!!! | ||
    # due to bugs in pyspark/pytest it looks like any configs set here | ||
    # can be reset in the middle of a test if specific operations are done (some types of cast etc) | ||
    _sb = pyspark.sql.SparkSession.builder | ||
|
||
    for key, value in os.environ.items(): | ||
        if key.startswith('PYSP_TEST_') and key != _DRIVER_ENV: | ||
            _sb.config(key[10:].replace('_', '.'), value) | ||
|
||
    driver_opts = os.environ.get(_DRIVER_ENV, "") | ||
|
||
    _sb.config('spark.driver.extraJavaOptions', driver_opts) | ||
    _handle_event_log_dir(_sb, 'gw0') | ||
|
||
    _s = _sb.appName('xgboost4j pyspark integration tests').getOrCreate() | ||
    # TODO catch the ClassNotFound error that happens if the classpath is not set up properly and | ||
    # make it a better error message | ||
    _s.sparkContext.setLogLevel("WARN") | ||
    return _s | ||
|
||
|
||
def _handle_event_log_dir(sb, wid): | ||
    if os.environ.get('SPARK_EVENTLOG_ENABLED', str(True)).lower() in [ | ||
        str(False).lower(), 'off', '0' | ||
    ]: | ||
        print('Automatic configuration for spark event log disabled') | ||
        return | ||
|
||
    spark_conf = pyspark.SparkConf() | ||
    master_url = os.environ.get('PYSP_TEST_spark_master', | ||
                                spark_conf.get("spark.master", 'local')) | ||
    event_log_config = os.environ.get('PYSP_TEST_spark_eventLog_enabled', | ||
                                      spark_conf.get('spark.eventLog.enabled', str(False).lower())) | ||
    event_log_codec = os.environ.get('PYSP_TEST_spark_eventLog_compression_codec', 'zstd') | ||
|
||
    if not master_url.startswith('local') or event_log_config != str(False).lower(): | ||
        print("SPARK_EVENTLOG_ENABLED is ignored for non-local Spark master and when " | ||
              "it's pre-configured by the user") | ||
        return | ||
    d = "./eventlog_{}".format(wid) | ||
    if not os.path.exists(d): | ||
        os.makedirs(d) | ||
|
||
    print('Spark event logs will appear under {}. Set the environment variable ' | ||
          'SPARK_EVENTLOG_ENABLED=false if you want to disable it'.format(d)) | ||
|
||
    sb\ | ||
        .config('spark.eventLog.dir', "file://{}".format(os.path.abspath(d))) \ | ||
        .config('spark.eventLog.compress', True) \ | ||
        .config('spark.eventLog.enabled', True) \ | ||
        .config('spark.eventLog.compression.codec', event_log_codec) | ||
|
||
_spark = _spark__init() | ||
|
||
|
||
def get_spark_i_know_what_i_am_doing(): | ||
    """ | ||
    Get the current SparkSession. | ||
    This should almost never be called directly; instead you should call | ||
    with_spark_session, with_cpu_session, or with_gpu_session for spark_session. | ||
    This is to guarantee that the session and its config are set up in a repeatable way. | ||
    """ | ||
    return _spark | ||
|
||
|
||
def spark_version(): | ||
    return _spark.version |
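The session builder above derives Spark settings from environment variables: the `PYSP_TEST_` prefix is stripped and every underscore becomes a dot, so `PYSP_TEST_spark_master` turns into `spark.master`. The sketch below reproduces just that mapping on a plain dictionary so it can be tried without Spark. `spark_confs_from_env` is a hypothetical helper name; the prefix and the driver-options variable are taken from the file above.

```python
PREFIX = "PYSP_TEST_"
DRIVER_ENV = "PYSP_TEST_spark_driver_extraJavaOptions"


def spark_confs_from_env(environ):
    """Translate PYSP_TEST_* variables into Spark config key/value pairs."""
    confs = {}
    for key, value in environ.items():
        if key.startswith(PREFIX) and key != DRIVER_ENV:
            # Strip the prefix, then turn every underscore into a dot.
            confs[key[len(PREFIX):].replace("_", ".")] = value
    # The driver options are always set, defaulting to an empty string.
    confs["spark.driver.extraJavaOptions"] = environ.get(DRIVER_ENV, "")
    return confs


if __name__ == "__main__":
    example_env = {
        "PYSP_TEST_spark_master": "local[2]",
        "PYSP_TEST_spark_eventLog_enabled": "false",
        "HOME": "/home/tester",  # ignored: no PYSP_TEST_ prefix
    }
    for name, value in sorted(spark_confs_from_env(example_env).items()):
        print(name, "=", value)
```

Because every underscore is replaced, a Spark setting whose name itself contains an underscore cannot be expressed through this scheme.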
63 changes: 63 additions & 0 deletions
jvm-packages/integration-tests/python/xgboost_classifier_test.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from pyspark.ml.feature import StringIndexer | ||
from pyspark.ml.linalg import Vectors | ||
from xgboost.spark import XGBoostClassifier, XGBoostClassificationModel | ||
|
||
from spark_init_internal import get_spark_i_know_what_i_am_doing | ||
|
||
|
||
def test_save_xgboost_classifier(): | ||
    params = { | ||
        'objective': 'binary:logistic', | ||
        'numRound': 5, | ||
        'numWorkers': 2, | ||
        'treeMethod': 'hist' | ||
    } | ||
    classifier = XGBoostClassifier(**params) | ||
    classifier.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-classifier") | ||
    classifier1 = XGBoostClassifier.load("/tmp/xgboost-integration-tests/xgboost-classifier") | ||
    assert classifier1.getObjective() == 'binary:logistic' | ||
    assert classifier1.getNumRound() == 5 | ||
    assert classifier1.getNumWorkers() == 2 | ||
    assert classifier1.getTreeMethod() == 'hist' | ||
|
||
|
||
def test_xgboost_classifier_training_without_error(): | ||
    spark = get_spark_i_know_what_i_am_doing() | ||
    df = spark.createDataFrame([ | ||
        ("a", Vectors.dense([1.0, 2.0, 3.0, 4.0, 5.0])), | ||
        ("b", Vectors.dense([5.0, 6.0, 7.0, 8.0, 9.0]))], | ||
        ["label", "features"]) | ||
    label_name = 'label_indexed' | ||
    string_indexer = StringIndexer(inputCol="label", outputCol=label_name).fit(df) | ||
    indexed_df = string_indexer.transform(df).select(label_name, 'features') | ||
    params = { | ||
        'objective': 'binary:logistic', | ||
        'numRound': 5, | ||
        'numWorkers': 1, | ||
        'treeMethod': 'hist' | ||
    } | ||
    classifier = XGBoostClassifier(**params) \ | ||
        .setLabelCol(label_name) \ | ||
        .setFeaturesCol('features') | ||
    classifier.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-classifier") | ||
    classifier1 = XGBoostClassifier.load("/tmp/xgboost-integration-tests/xgboost-classifier") | ||
    model = classifier1.fit(indexed_df) | ||
    model.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-classifier-model") | ||
    model1 = XGBoostClassificationModel.load("/tmp/xgboost-integration-tests/xgboost-classifier-model") | ||
    model1.transform(df).show() |
58 changes: 58 additions & 0 deletions
jvm-packages/integration-tests/python/xgboost_regressor_test.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# | ||
# Copyright (c) 2022 by Contributors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from pyspark.ml.linalg import Vectors | ||
from xgboost.spark import XGBoostRegressor, XGBoostRegressionModel | ||
|
||
from python.spark_init_internal import get_spark_i_know_what_i_am_doing | ||
|
||
|
||
def test_save_xgboost_regressor(): | ||
    params = { | ||
        'objective': 'reg:squarederror', | ||
        'numRound': 5, | ||
        'numWorkers': 2, | ||
        'treeMethod': 'hist' | ||
    } | ||
    regressor = XGBoostRegressor(**params) | ||
    regressor.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-regressor") | ||
    regressor1 = XGBoostRegressor.load("/tmp/xgboost-integration-tests/xgboost-regressor") | ||
    assert regressor1.getObjective() == 'reg:squarederror' | ||
    assert regressor1.getNumRound() == 5 | ||
    assert regressor1.getNumWorkers() == 2 | ||
    assert regressor1.getTreeMethod() == 'hist' | ||
|
||
|
||
def test_xgboost_regressor_training_without_error(): | ||
    spark = get_spark_i_know_what_i_am_doing() | ||
    df = spark.createDataFrame([ | ||
        (1.0, Vectors.dense(1.0)), | ||
        (0.0, Vectors.dense(2.0))], ["label", "features"]) | ||
    params = { | ||
        'objective': 'reg:squarederror', | ||
        'numRound': 5, | ||
        'numWorkers': 1, | ||
        'treeMethod': 'hist' | ||
    } | ||
    regressor = XGBoostRegressor(**params) \ | ||
        .setLabelCol('label') \ | ||
        .setFeaturesCol('features') | ||
    regressor.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-regressor") | ||
    regressor1 = XGBoostRegressor.load("/tmp/xgboost-integration-tests/xgboost-regressor") | ||
    model = regressor1.fit(df) | ||
    model.write().overwrite().save("/tmp/xgboost-integration-tests/xgboost-regressor-model") | ||
    model1 = XGBoostRegressionModel.load("/tmp/xgboost-integration-tests/xgboost-regressor-model") | ||
    model1.transform(df).show() |
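The save/load tests write to fixed locations under `/tmp/xgboost-integration-tests`, and two tests in each file share the same estimator path. One possible alternative, sketched here as an assumption and not as what the PR does, is to give each test run its own temporary directory. `save_path` is a hypothetical helper, and the approach presumes a local Spark master so that driver-local paths are visible to the job.

```python
import os
import tempfile


def save_path(base_dir, name):
    """Build a save location for an estimator or model under `base_dir`."""
    return os.path.join(base_dir, name)


with tempfile.TemporaryDirectory(prefix="xgboost-integration-tests-") as base_dir:
    estimator_path = save_path(base_dir, "xgboost-regressor")
    model_path = save_path(base_dir, "xgboost-regressor-model")
    # A real test would call regressor.write().overwrite().save(estimator_path)
    # and model.write().overwrite().save(model_path) at this point.
    print(estimator_path)
    print(model_path)
# The directory and everything under it is removed when the block exits.
```

Inside pytest, the built-in `tmp_path` fixture offers the same isolation without the explicit context manager.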
So, what's the specific case that we should use this instead of with cpu session? Could you please document it?
The API was copied from spark-rapids, and I've now changed it to fit the needs of XGBoost.
@wbo4958 Okay... but I'm not sure how that's relevant to the question?