@@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
@@ -960,7 +961,7 @@ private[python] class PythonMLLibAPI extends Serializable {
  def estimateKernelDensity(
      sample: JavaRDD[Double],
      bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
-    return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
+    new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
      points.asScala.toArray)
  }

@@ -979,6 +980,35 @@ private[python] class PythonMLLibAPI extends Serializable {
    List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
  }

+  /**
+   * Wrapper around the generateLinearInput method of LinearDataGenerator.
+   */
+  def generateLinearInputWrapper(
+      intercept: Double,
+      weights: JList[Double],
+      xMean: JList[Double],
+      xVariance: JList[Double],
+      nPoints: Int,
+      seed: Int,
+      eps: Double): Array[LabeledPoint] = {
+    LinearDataGenerator.generateLinearInput(
+      intercept, weights.asScala.toArray, xMean.asScala.toArray,
+      xVariance.asScala.toArray, nPoints, seed, eps).toArray
+  }
+
+  /**
+   * Wrapper around the generateLinearRDD method of LinearDataGenerator.
+   */
+  def generateLinearRDDWrapper(
+      sc: JavaSparkContext,
+      nexamples: Int,
+      nfeatures: Int,
+      eps: Double,
+      nparts: Int,
+      intercept: Double): JavaRDD[LabeledPoint] = {
+    LinearDataGenerator.generateLinearRDD(
+      sc, nexamples, nfeatures, eps, nparts, intercept)
+  }
}

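Note: these Scala wrappers are not invoked from Scala code; PySpark reaches them by name over the Py4J gateway. A minimal sketch of that dispatch, assuming the existing callMLlibFunc helper in pyspark.mllib.common and an active SparkContext (it mirrors the util.py change later in this diff):

    # A minimal sketch, not part of this patch: callMLlibFunc pickles the
    # arguments, invokes PythonMLLibAPI.generateLinearInputWrapper on the
    # JVM by name, and deserializes the returned LabeledPoints.
    from pyspark.mllib.common import callMLlibFunc

    points = callMLlibFunc(
        "generateLinearInputWrapper", 0.0, [1.0, -1.0], [0.0, 0.0],
        [1.0, 1.0], 10, 42, 0.1)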
python/pyspark/mllib/tests.py (20 additions, 2 deletions)
@@ -48,8 +48,8 @@
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
-from pyspark.mllib.feature import StandardScaler
-from pyspark.mllib.feature import ElementwiseProduct
+from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
+from pyspark.mllib.util import LinearDataGenerator
from pyspark.serializers import PickleSerializer
from pyspark.streaming import StreamingContext
from pyspark.sql import SQLContext
@@ -1011,6 +1011,24 @@ def collect(rdd):
        self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])


+class LinearDataGeneratorTests(MLlibTestCase):
+    def test_dim(self):
+        linear_data = LinearDataGenerator.generateLinearInput(
+            intercept=0.0, weights=[0.0, 0.0, 0.0],
+            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
+            nPoints=4, seed=0, eps=0.1)
+        self.assertEqual(len(linear_data), 4)
+        for point in linear_data:
+            self.assertEqual(len(point.features), 3)
+
+        linear_data = LinearDataGenerator.generateLinearRDD(
+            sc=sc, nexamples=6, nfeatures=2, eps=0.1,
+            nParts=2, intercept=0.0).collect()
+        self.assertEqual(len(linear_data), 6)
+        for point in linear_data:
+            self.assertEqual(len(point.features), 2)
+
+
if __name__ == "__main__":
    if not _have_scipy:
        print("NOTE: Skipping SciPy tests as it does not seem to be installed")
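Note: the new test only checks output dimensions. The same check can be run interactively; a sketch, assuming a PySpark shell where sc is the shell's SparkContext:

    # Interactive equivalent of test_dim (assumes sc is already defined).
    from pyspark.mllib.util import LinearDataGenerator

    rdd = LinearDataGenerator.generateLinearRDD(
        sc=sc, nexamples=6, nfeatures=2, eps=0.1, nParts=2, intercept=0.0)
    assert all(len(p.features) == 2 for p in rdd.collect())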
python/pyspark/mllib/util.py (35 additions, 0 deletions)
@@ -257,6 +257,41 @@ def load(cls, sc, path):
        return cls(java_model)


+class LinearDataGenerator(object):
+    """Utils for generating linear data."""
+
+    @staticmethod
+    def generateLinearInput(intercept, weights, xMean, xVariance,
+                            nPoints, seed, eps):
+        """
+        :param intercept: bias factor, the term c in X'w + c
+        :param weights: feature vector, the term w in X'w + c
+        :param xMean: point around which the data X is centered
+        :param xVariance: variance of the given data
+        :param seed: random seed
+        :param nPoints: number of points to be generated
+        :param eps: factor scaling the Gaussian noise; the larger
+                    eps is, the noisier the labels are
+        :return: a list of LabeledPoints of length nPoints
+        """
+        weights = [float(weight) for weight in weights]
+        xMean = [float(mean) for mean in xMean]
+        xVariance = [float(var) for var in xVariance]
+        return list(callMLlibFunc(
+            "generateLinearInputWrapper", float(intercept), weights, xMean,
+            xVariance, int(nPoints), int(seed), float(eps)))
+
+    @staticmethod
+    def generateLinearRDD(sc, nexamples, nfeatures, eps,
+                          nParts=2, intercept=0.0):
+        """
+        Generate an RDD of LabeledPoints.
+        """
+        return callMLlibFunc(
+            "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
+            float(eps), int(nParts), float(intercept))
+
+
def _test():
    import doctest
    from pyspark.context import SparkContext
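Note: in the underlying Scala generator, each label is the linear response plus scaled Gaussian noise, label = x . w + intercept + eps * N(0, 1), so a small eps keeps labels close to the noiseless X'w + c. A short usage sketch for the new Python API (assumes an active SparkContext, since callMLlibFunc goes through the JVM gateway):

    # Usage sketch: with eps=0.01 each label should stay close to
    # 1.0 + 2*x[0] - x[1] for the generated features x.
    from pyspark.mllib.util import LinearDataGenerator

    points = LinearDataGenerator.generateLinearInput(
        intercept=1.0, weights=[2.0, -1.0], xMean=[0.0, 0.0],
        xVariance=[1.0, 1.0], nPoints=5, seed=7, eps=0.01)
    for lp in points:
        print(lp.label, lp.features)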