Closed

Changes from all commits (44 commits)
f825352
Wrote Python API and example for DecisionTree. Also added toString, …
jkbradley Jul 30, 2014
5f920a1
Demonstration of bug before submitting fix: Updated DecisionTreeSuite…
jkbradley Jul 30, 2014
73fbea2
Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
jkbradley Jul 30, 2014
2283df8
2 bug fixes.
jkbradley Jul 30, 2014
5fe44ed
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Jul 30, 2014
8a758db
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Jul 30, 2014
8ea8750
Bug fix: Off-by-1 when finding thresholds for splits for continuous f…
jkbradley Jul 31, 2014
cd1d933
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Jul 31, 2014
8e227ea
Changed Strategy so it only requires numClassesForClassification >= 2…
jkbradley Jul 31, 2014
da50db7
Added one more test to DecisionTreeSuite: stump with 2 continuous var…
jkbradley Jul 31, 2014
f5a036c
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Jul 31, 2014
52e17c5
Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
jkbradley Jul 31, 2014
59750f8
* Updated Strategy to check numClassesForClassification only if algo=…
jkbradley Jul 31, 2014
bab3f19
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Jul 31, 2014
e06e423
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Jul 31, 2014
376dca2
Updated meaning of maxDepth by 1 to fit scikit-learn and rpart.
jkbradley Jul 31, 2014
6eed482
In DecisionTree: Changed from using procedural syntax for functions r…
jkbradley Jul 31, 2014
978cfcf
Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
jkbradley Jul 31, 2014
8bb8aa0
Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
jkbradley Jul 31, 2014
dab0b67
Added documentation for DecisionTree internals
jkbradley Jul 31, 2014
584449a
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Jul 31, 2014
1b29c13
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Jul 31, 2014
2b20c61
Small doc and style updates
jkbradley Jul 31, 2014
b8fac57
Finished Python DecisionTree API and example but need to test a bit m…
jkbradley Aug 1, 2014
6622247
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 1, 2014
188cb0d
Merge branch 'decisiontree-bugfix' into decisiontree-python-new
jkbradley Aug 1, 2014
665ba78
Small updates towards Python DecisionTree API
jkbradley Aug 1, 2014
4562c08
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 1, 2014
6df89a9
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 1, 2014
93953f1
Likely done with Python API.
jkbradley Aug 1, 2014
225822f
Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategor…
jkbradley Aug 1, 2014
6873fa9
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 1, 2014
db0eab2
Merge branch 'decisiontree-bugfix2' into decisiontree-python-new
jkbradley Aug 1, 2014
4801b40
Small style update to DecisionTreeSuite
jkbradley Aug 1, 2014
e34c263
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 1, 2014
7968692
small braces typo fix
jkbradley Aug 1, 2014
fa10ea7
Small style update
jkbradley Aug 1, 2014
bf21be4
removed old run() func from DecisionTree
jkbradley Aug 1, 2014
aa29873
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 2, 2014
cf46ad7
Python DecisionTreeModel
jkbradley Aug 2, 2014
67a29bc
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 2, 2014
affceb9
* Fixed bug in doc tests in pyspark/mllib/util.py caused by change in…
jkbradley Aug 2, 2014
6b86a9d
Merge remote-tracking branch 'upstream/master' into decisiontree-pyth…
jkbradley Aug 2, 2014
3744488
Renamed test tree.py to decision_tree_runner.py
jkbradley Aug 2, 2014
133 changes: 133 additions & 0 deletions examples/src/main/python/mllib/decision_tree_runner.py
@@ -0,0 +1,133 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Decision tree classification and regression using MLlib.
"""

import numpy, os, sys

from operator import add

from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils


def getAccuracy(dtModel, data):
    """
    Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint].
    """
    seqOp = (lambda acc, x: acc + (x[0] == x[1]))
    predictions = dtModel.predict(data.map(lambda x: x.features))
    truth = data.map(lambda p: p.label)
    trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add)
    if data.count() == 0:
        return 0
    return trainCorrect / (0.0 + data.count())


def getMSE(dtModel, data):
    """
    Return mean squared error (MSE) of DecisionTreeModel on the given
    RDD[LabeledPoint].
    """
    seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1]))
    predictions = dtModel.predict(data.map(lambda x: x.features))
    truth = data.map(lambda p: p.label)
    trainMSE = predictions.zip(truth).aggregate(0, seqOp, add)
    if data.count() == 0:
        return 0
    return trainMSE / (0.0 + data.count())


def reindexClassLabels(data):
    """
    Re-index class labels in a dataset to the range {0,...,numClasses-1}.
    If all labels in that range already appear at least once,
    then the returned RDD is the same one (without a mapping).
    Note: If a label simply does not appear in the data,
          the index will not include it.
          Be aware of this when reindexing subsampled data.
    :param data: RDD of LabeledPoint where labels are integer values
                 denoting labels for a classification problem.
    :return: Pair (reindexedData, origToNewLabels) where
             reindexedData is an RDD of LabeledPoint with labels in
             the range {0,...,numClasses-1}, and
             origToNewLabels is a dictionary mapping original labels
             to new labels.
    """
    # classCounts: class --> # examples in class
    classCounts = data.map(lambda x: x.label).countByValue()
    numExamples = sum(classCounts.values())
    sortedClasses = sorted(classCounts.keys())
    numClasses = len(classCounts)
    # origToNewLabels: class --> index in 0,...,numClasses-1
    if (numClasses < 2):
        print >> sys.stderr, \
            "Dataset for classification should have at least 2 classes." + \
            " The given dataset had only %d classes." % numClasses
        exit(1)
    origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)])

    print "numClasses = %d" % numClasses
    print "Per-class example fractions, counts:"
    print "Class\tFrac\tCount"
    for c in sortedClasses:
        frac = classCounts[c] / (numExamples + 0.0)
        print "%g\t%g\t%d" % (c, frac, classCounts[c])

    if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1):
        return (data, origToNewLabels)
    else:
        reindexedData = \
            data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features))
        return (reindexedData, origToNewLabels)


def usage():
    print >> sys.stderr, \
        "Usage: decision_tree_runner [libsvm format data filepath]\n" + \
        "  Note: This only supports binary classification."
    exit(1)


if __name__ == "__main__":
    if len(sys.argv) > 2:
        usage()
    sc = SparkContext(appName="PythonDT")

    # Load data.
    dataPath = 'data/mllib/sample_libsvm_data.txt'
    if len(sys.argv) == 2:
        dataPath = sys.argv[1]
    if not os.path.isfile(dataPath):
        usage()
    points = MLUtils.loadLibSVMFile(sc, dataPath)

    # Re-index class labels if needed.
    (reindexedData, origToNewLabels) = reindexClassLabels(points)

    # Train a classifier.
    model = DecisionTree.trainClassifier(reindexedData, numClasses=2)
    # Print learned tree and stats.
    print "Trained DecisionTree for classification:"
    print "  Model numNodes: %d\n" % model.numNodes()
    print "  Model depth: %d\n" % model.depth()
    print "  Training accuracy: %g\n" % getAccuracy(model, reindexedData)
    print model
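For reference, the new Python API this example exercises can be driven in a few lines. A minimal sketch, assuming only what the example above already uses (trainClassifier with numClasses, and the single-point and bulk predict forms backed by the Java stubs further down):

from pyspark import SparkContext
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.util import MLUtils

sc = SparkContext(appName="DTQuickStart")
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

# Labels are assumed to already be in {0, 1}; otherwise reindex as above.
model = DecisionTree.trainClassifier(data, numClasses=2)

# Predict for a single feature vector, or for a whole RDD of them.
single = model.predict(data.first().features)
bulk = model.predict(data.map(lambda p: p.features))
print model   # prints the learned tree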
4 changes: 3 additions & 1 deletion examples/src/main/python/mllib/logistic_regression.py
@@ -30,8 +30,10 @@
 from pyspark.mllib.classification import LogisticRegressionWithSGD
 
 
-# Parse a line of text into an MLlib LabeledPoint object
 def parsePoint(line):
+    """
+    Parse a line of text into an MLlib LabeledPoint object.
+    """
     values = [float(s) for s in line.split(' ')]
     if values[0] == -1:   # Convert -1 labels to 0 for MLlib
         values[0] = 0
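A quick illustration of parsePoint's contract (hypothetical input line; the space-separated "label feature1 feature2 ..." format is inferred from the parser itself):

point = parsePoint("-1 2.0 3.5")
# point.label is 0.0 (the -1 label was converted for MLlib),
# and point.features holds the values [2.0, 3.5].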
mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python

 import java.nio.{ByteBuffer, ByteOrder}
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.mllib.classification._
@@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
 import org.apache.spark.mllib.recommendation._
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
 import org.apache.spark.mllib.util.MLUtils
@@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable {
     ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
   }
 
+  /**
+   * Java stub for Python mllib DecisionTree.train().
+   * This stub returns a handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+   * see the Py4J documentation.
+   * @param dataBytesJRDD  Training data
+   * @param categoricalFeaturesInfoJMap  Categorical features info, as Java map
+   */
+  def trainDecisionTreeModel(
+      dataBytesJRDD: JavaRDD[Array[Byte]],
+      algoStr: String,
+      numClasses: Int,
+      categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+      impurityStr: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+
+    val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
+
+    val algo: Algo = algoStr match {
+      case "classification" => Classification
+      case "regression" => Regression
+      case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
+    }
+    val impurity: Impurity = impurityStr match {
+      case "gini" => Gini
+      case "entropy" => Entropy
+      case "variance" => Variance
+      case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
+    }
+
+    val strategy = new Strategy(
+      algo = algo,
+      impurity = impurity,
+      maxDepth = maxDepth,
+      numClassesForClassification = numClasses,
+      maxBins = maxBins,
+      categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap)
+
+    DecisionTree.train(data, strategy)
+  }
+
+  /**
+   * Predict the label of the given data point.
+   * This is a Java stub for python DecisionTreeModel.predict()
+   *
+   * @param featuresBytes  Serialized feature vector for data point
+   * @return  predicted label
+   */
+  def predictDecisionTreeModel(
+      model: DecisionTreeModel,
+      featuresBytes: Array[Byte]): Double = {
+    val features: Vector = deserializeDoubleVector(featuresBytes)
+    model.predict(features)
+  }
+
+  /**
+   * Predict the labels of the given data points.
+   * This is a Java stub for python DecisionTreeModel.predict()
+   *
+   * @param dataJRDD  A JavaRDD with serialized feature vectors
+   * @return  JavaRDD of serialized predictions
+   */
+  def predictDecisionTreeModel(
+      model: DecisionTreeModel,
+      dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+    val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    model.predict(data).map(serializeDouble)
+  }
+
   /**
    * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String).
    * Returns the correlation matrix serialized into a byte array understood by deserializers in
@@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable {
     val s = getSeedOrDefault(seed)
     RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
   }
+
 }
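To make the Python/JVM round trip concrete, here is a rough sketch of how a Python wrapper can drive these stubs over Py4J. The real wrapper is python/pyspark/mllib/tree.py in this PR; the helper names below (_get_unmangled_labeled_point_rdd, sc._jvm.PythonMLLibAPI) follow the conventions visible in _common.py, but the exact plumbing, including the dict to java.util.Map conversion, is an assumption rather than a quote from the patch.

from pyspark.mllib._common import _get_unmangled_labeled_point_rdd

def _train_sketch(data, algoStr, numClasses, categoricalFeaturesInfo,
                  impurityStr, maxDepth, maxBins):
    # Hypothetical sketch, not the PR's actual tree.py.
    sc = data.context
    dataBytes = _get_unmangled_labeled_point_rdd(data)  # serialized and cached
    # Py4J needs categoricalFeaturesInfo as a java.util.Map[Int, Int];
    # the explicit conversion is glossed over here.
    javaModel = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
        dataBytes._jrdd, algoStr, numClasses, categoricalFeaturesInfo,
        impurityStr, maxDepth, maxBins)
    dataBytes.unpersist()  # heed the WARNING documented in _common.py below
    return javaModel       # wrapped in a Python DecisionTreeModel handle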
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -56,7 +56,8 @@ class Strategy (
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
   }
-  val isMulticlassClassification = numClassesForClassification > 2
+  val isMulticlassClassification =
+    algo == Classification && numClassesForClassification > 2
   val isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
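Without the added algo == Classification conjunct, a regression Strategy whose numClassesForClassification (meaningless for regression) happened to exceed 2 would have been flagged as multiclass classification and sent down the multiclass code paths; the guard keeps both derived flags false whenever algo is Regression.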
mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       requiredMSE: Double) {
     val predictions = input.map(x => model.predict(x.features))
     val squaredError = predictions.zip(input).map { case (prediction, expected) =>
-      (prediction - expected.label) * (prediction - expected.label)
+      val err = prediction - expected.label
+      err * err
     }.sum
     val mse = squaredError / input.length
     assert(mse <= requiredMSE)
33 changes: 23 additions & 10 deletions python/pyspark/mllib/_common.py
@@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype):
     temp_array[...] = array
 
 
-def _get_unmangled_rdd(data, serializer):
+def _get_unmangled_rdd(data, serializer, cache=True):
+    """
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
     dataBytes = data.map(serializer)
     dataBytes._bypass_serializer = True
-    dataBytes.cache()  # TODO: users should unpersist() this later!
+    if cache:
+        dataBytes.cache()
     return dataBytes
 
 
-# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
-# _serialized_double_vectors
-def _get_unmangled_double_vector_rdd(data):
-    return _get_unmangled_rdd(data, _serialize_double_vector)
+def _get_unmangled_double_vector_rdd(data, cache=True):
+    """
+    Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of
+    _serialized_double_vectors.
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
+    return _get_unmangled_rdd(data, _serialize_double_vector, cache)
 
 
-# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points
-def _get_unmangled_labeled_point_rdd(data):
-    return _get_unmangled_rdd(data, _serialize_labeled_point)
+def _get_unmangled_labeled_point_rdd(data, cache=True):
+    """
+    Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points.
+    :param cache: If True, the serialized RDD is cached. (default = True)
+                  WARNING: Users should unpersist() this later!
+    """
+    return _get_unmangled_rdd(data, _serialize_labeled_point, cache)
 
 
 # Common functions for dealing with and training linear models
@@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs):
         if x.size != coeffs.shape[0]:
             raise RuntimeError("Got sparse vector of size %d; wanted %d" % (
                 x.size, coeffs.shape[0]))
-    elif (type(x) == RDD):
+    elif isinstance(x, RDD):
         raise RuntimeError("Bulk predict not yet supported.")
     else:
         raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
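Since the serialized RDD is now cached only on request, callers own its lifecycle. A minimal sketch of the intended pattern (hypothetical caller code, not part of this patch; trainingData is an RDD of LabeledPoint):

from pyspark.mllib._common import _get_unmangled_labeled_point_rdd

# Cache while the JVM side consumes the serialized bytes, then release them.
dataBytes = _get_unmangled_labeled_point_rdd(trainingData)
try:
    javaRDD = dataBytes._jrdd   # handed to a Java training stub
finally:
    dataBytes.unpersist()       # as the docstrings above warn

# When the serialized RDD is consumed exactly once, skip caching entirely:
oneShot = _get_unmangled_labeled_point_rdd(trainingData, cache=False)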