-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3128][MLLIB] Use streaming test suite for StreamingLR #2037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,20 +17,19 @@ | |
|
|
||
| package org.apache.spark.mllib.regression | ||
|
|
||
| import java.io.File | ||
| import java.nio.charset.Charset | ||
|
|
||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import com.google.common.io.Files | ||
| import org.scalatest.FunSuite | ||
|
|
||
| import org.apache.spark.mllib.linalg.Vectors | ||
| import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} | ||
| import org.apache.spark.streaming.{Milliseconds, StreamingContext} | ||
| import org.apache.spark.util.Utils | ||
| import org.apache.spark.mllib.util.LinearDataGenerator | ||
| import org.apache.spark.streaming.dstream.DStream | ||
| import org.apache.spark.streaming.TestSuiteBase | ||
|
|
||
| class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { | ||
|
|
||
| class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { | ||
| // use longer wait time to ensure job completion | ||
| override def maxWaitTimeMillis = 20000 | ||
|
|
||
| // Assert that two values are equal within tolerance epsilon | ||
| def assertEqual(v1: Double, v2: Double, epsilon: Double) { | ||
|
|
@@ -49,35 +48,26 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { | |
| } | ||
|
|
||
| // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data | ||
| test("streaming linear regression parameter accuracy") { | ||
| test("parameter accuracy") { | ||
|
|
||
| val testDir = Files.createTempDir() | ||
| val numBatches = 10 | ||
| val batchDuration = Milliseconds(1000) | ||
| val ssc = new StreamingContext(sc, batchDuration) | ||
| val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) | ||
| // create model | ||
| val model = new StreamingLinearRegressionWithSGD() | ||
| .setInitialWeights(Vectors.dense(0.0, 0.0)) | ||
| .setStepSize(0.1) | ||
| .setNumIterations(50) | ||
| .setNumIterations(25) | ||
|
|
||
| model.trainOn(data) | ||
|
|
||
| ssc.start() | ||
|
|
||
| // write data to a file stream | ||
| for (i <- 0 until numBatches) { | ||
| val samples = LinearDataGenerator.generateLinearInput( | ||
| 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) | ||
| val file = new File(testDir, i.toString) | ||
| Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) | ||
| Thread.sleep(batchDuration.milliseconds) | ||
| // generate sequence of simulated data | ||
| val numBatches = 10 | ||
| val input = (0 until numBatches).map { i => | ||
| LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) | ||
| } | ||
|
|
||
| ssc.stop(stopSparkContext=false) | ||
|
|
||
| System.clearProperty("spark.driver.port") | ||
| Utils.deleteRecursively(testDir) | ||
| // apply model training to input stream | ||
| val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { | ||
| model.trainOn(inputDStream) | ||
| inputDStream.count() | ||
| }) | ||
| runStreams(ssc, numBatches, numBatches) | ||
|
|
||
| // check accuracy of final parameter estimates | ||
| assertEqual(model.latestModel().intercept, 0.0, 0.1) | ||
|
|
@@ -91,39 +81,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { | |
| } | ||
|
|
||
| // Test that parameter estimates improve when learning Y = 10*X1 on streaming data | ||
| test("streaming linear regression parameter convergence") { | ||
| test("parameter convergence") { | ||
|
|
||
| val testDir = Files.createTempDir() | ||
| val batchDuration = Milliseconds(2000) | ||
| val ssc = new StreamingContext(sc, batchDuration) | ||
| val numBatches = 5 | ||
| val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) | ||
| // create model | ||
| val model = new StreamingLinearRegressionWithSGD() | ||
| .setInitialWeights(Vectors.dense(0.0)) | ||
| .setStepSize(0.1) | ||
| .setNumIterations(50) | ||
|
|
||
| model.trainOn(data) | ||
|
|
||
| ssc.start() | ||
| .setNumIterations(25) | ||
|
|
||
| // write data to a file stream | ||
| val history = new ArrayBuffer[Double](numBatches) | ||
| for (i <- 0 until numBatches) { | ||
| val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) | ||
| val file = new File(testDir, i.toString) | ||
| Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) | ||
| Thread.sleep(batchDuration.milliseconds) | ||
| // wait an extra few seconds to make sure the update finishes before new data arrive | ||
| Thread.sleep(4000) | ||
| history.append(math.abs(model.latestModel().weights(0) - 10.0)) | ||
| // generate sequence of simulated data | ||
| val numBatches = 10 | ||
| val input = (0 until numBatches).map { i => | ||
| LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) | ||
| } | ||
|
|
||
| ssc.stop(stopSparkContext=false) | ||
| // create buffer to store intermediate fits | ||
| val history = new ArrayBuffer[Double](numBatches) | ||
|
|
||
| System.clearProperty("spark.driver.port") | ||
| Utils.deleteRecursively(testDir) | ||
| // apply model training to input stream, storing the intermediate results | ||
| // (we add a count to ensure the result is a DStream) | ||
| val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { | ||
| model.trainOn(inputDStream) | ||
| inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) | ||
| inputDStream.count() | ||
| }) | ||
| runStreams(ssc, numBatches, numBatches) | ||
|
|
||
| // compute change in error | ||
| val deltas = history.drop(1).zip(history.dropRight(1)) | ||
| // check error stability (it always either shrinks, or increases with small tol) | ||
| assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) | ||
|
|
@@ -132,4 +116,33 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { | |
|
|
||
| } | ||
|
|
||
| // Test predictions on a stream | ||
| test("predictions") { | ||
|
|
||
| // create model initialized with true weights | ||
| val model = new StreamingLinearRegressionWithSGD() | ||
| .setInitialWeights(Vectors.dense(10.0, 10.0)) | ||
| .setStepSize(0.1) | ||
| .setNumIterations(25) | ||
|
|
||
| // generate sequence of simulated data for testing | ||
| val numBatches = 10 | ||
| val nPoints = 100 | ||
| val testInput = (0 until numBatches).map { i => | ||
| LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) | ||
| } | ||
|
|
||
| // apply model predictions to test stream | ||
| val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { | ||
| model.predictOnValues(inputDStream.map(x => (x.label, x.features))) | ||
| }) | ||
| // collect the output as (true, estimated) tuples | ||
| val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) | ||
|
|
||
| // compute the mean absolute error and check that it's always less than 0.1 | ||
| val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) | ||
| assert(errors.forall(x => x <= 0.1)) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: extra line |
||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: extra line |
||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: extra line.