Skip to content

Commit

Permalink
[jvm-packages] initial pyspark api
Browse files Browse the repository at this point in the history
  • Loading branch information
thesuperzapper authored and Mathew Wicks committed Oct 9, 2019
1 parent 8097718 commit 119282e
Show file tree
Hide file tree
Showing 10 changed files with 550 additions and 0 deletions.
Empty file.
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Copyright (c) 2019 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

from sparkxgb import xgboost

# Allows Pipeline()/PipelineModel() with XGBoost stages to be loaded from disk.
# Needed because they try to import Python objects from their Java location.
sys.modules['ml.dmlc.xgboost4j.scala.spark'] = xgboost
57 changes: 57 additions & 0 deletions jvm-packages/xgboost4j-spark/src/main/resources/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Copyright (c) 2019 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from codecs import open
from os import path
from setuptools import setup, find_packages

# Read the long description from README.MD
here = path.abspath(path.dirname(__file__))
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
long_description = f.read()

setup(
name='spark-xgboost',
version='0.90',
description='spark-xgboost is the PySpark package for XGBoost',

long_description=long_description,
long_description_content_type='text/markdown',
url='https://xgboost.ai/',
author='DMLC',
classifiers=[
# Project Maturity
'Development Status :: 5 - Production/Stable',

# Intended Users
'Intended Audience :: Developers',
'Topic :: Software Development :: Build Tools',

# License
'License :: OSI Approved :: Apache Software License',

# Supported Python Versions
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
],
keywords='development spark xgboost',

packages=find_packages(),
include_package_data=False
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Copyright (c) 2019 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from sparkxgb import xgboost
from sparkxgb.xgboost import XGBoostClassifier, XGBoostRegressor, XGBoostClassificationModel, XGBoostRegressionModel

__all__ = ["XGBoostClassifier", "XGBoostRegressor", "XGBoostClassificationModel", "XGBoostRegressionModel"]
__version__ = "0.90"
85 changes: 85 additions & 0 deletions jvm-packages/xgboost4j-spark/src/main/resources/sparkxgb/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#
# Copyright (c) 2019 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from pyspark.ml.param import Params
from pyspark.ml.util import JavaMLWritable
from pyspark.ml.wrapper import JavaModel, JavaEstimator

from sparkxgb.util import XGBoostReadable


class ParamGettersSetters(Params):
"""
Mixin class used to generate the setters/getters for all params.
"""

def _create_param_getters_and_setters(self):
for param in self.params:
param_name = param.name
fg_attr = "get" + re.sub(r"(?:^|_)(.)", lambda m: m.group(1).upper(), param_name)
fs_attr = "set" + re.sub(r"(?:^|_)(.)", lambda m: m.group(1).upper(), param_name)
# Generates getter and setter only if not exists
try:
getattr(self, fg_attr)
except AttributeError:
setattr(self, fg_attr, self._get_param_value(param_name))
try:
getattr(self, fs_attr)
except AttributeError:
setattr(self, fs_attr, self._set_param_value(param_name))

def _get_param_value(self, param_name):
def r():
try:
return self.getOrDefault(param_name)
except KeyError:
return None
return r

def _set_param_value(self, param_name):
def r(v):
self.set(self.getParam(param_name), v)
return self
return r


class XGboostEstimator(JavaEstimator, XGBoostReadable, JavaMLWritable, ParamGettersSetters):
"""
Mixin class for XGBoost estimators, like XGBoostClassifier and XGBoostRegressor.
"""

def __init__(self, classname):
super(XGboostEstimator, self).__init__()
self.__class__._java_class_name = classname
self._java_obj = self._new_java_obj(classname, self.uid)
self._create_params_from_java()
self._create_param_getters_and_setters()


class XGboostModel(JavaModel, XGBoostReadable, JavaMLWritable, ParamGettersSetters):
"""
Mixin class for XGBoost models, like XGBoostClassificationModel and XGBoostRegressionModel.
"""

def __init__(self, classname, java_model=None):
super(XGboostModel, self).__init__(java_model=java_model)
if classname and not java_model:
self.__class__._java_class_name = classname
self._java_obj = self._new_java_obj(classname, self.uid)
if java_model is not None:
self._transfer_params_from_java()
self._create_param_getters_and_setters()
40 changes: 40 additions & 0 deletions jvm-packages/xgboost4j-spark/src/main/resources/sparkxgb/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#
# Copyright (c) 2019 by Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.ml.util import JavaMLReadable, JavaMLReader


class XGBoostReadable(JavaMLReadable):
"""
Mixin class that provides a read() method for XGBoostReader.
"""

@classmethod
def read(cls):
"""Returns an XGBoostReader instance for this class."""
return XGBoostReader(cls)


class XGBoostReader(JavaMLReader):
"""
A reader mixin class for XGBoost objects.
"""

@classmethod
def _java_loader_class(cls, clazz):
if hasattr(clazz, '_java_class_name') and clazz._java_class_name is not None:
return clazz._java_class_name
else:
return JavaMLReader._java_loader_class(clazz)
Loading

0 comments on commit 119282e

Please sign in to comment.