 
 package org.apache.spark.mllib.regression
 
-import java.io.File
-import java.nio.charset.Charset
-
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.Files
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.TestSuiteBase
+
+class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
-class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+  // use longer wait time to ensure job completion
+  override def maxWaitTimeMillis = 20000
 
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
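Note: the core of this change is swapping the ad-hoc file-stream harness (LocalSparkContext plus textFileStream plus Thread.sleep) for Spark Streaming's TestSuiteBase, which feeds in-memory batches deterministically. A minimal sketch of the pattern, assuming the setupStreams/runStreams signatures from Spark's streaming test sources (each inner Seq becomes one batch; runStreams returns the per-batch output):

```scala
import org.scalatest.FunSuite
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream

// Sketch only: assumes TestSuiteBase exposes
//   setupStreams[U, V](input: Seq[Seq[U]], operation: DStream[U] => DStream[V]): StreamingContext
//   runStreams[V](ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int): Seq[Seq[V]]
class ExampleSuite extends FunSuite with TestSuiteBase {
  test("doubling a stream") {
    // each inner Seq becomes one streaming batch
    val input: Seq[Seq[Int]] = Seq(Seq(1, 2), Seq(3, 4))
    // wire the operation to an in-memory input stream
    val ssc = setupStreams(input, (ds: DStream[Int]) => ds.map(_ * 2))
    // run exactly two batches and collect their output
    val output: Seq[Seq[Int]] = runStreams(ssc, 2, 2)
    assert(output.flatten.toSet == Set(2, 4, 6, 8))
  }
}
```

This is why the Thread.sleep calls and temp-directory cleanup below disappear: batch timing is controlled by the harness rather than the wall clock.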
@@ -51,32 +50,24 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
   test("parameter accuracy") {
 
-    val testDir = Files.createTempDir()
-    val numBatches = 10
-    val batchDuration = Milliseconds(1000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)
 
-    ssc.start()
-
-    // write data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
-
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
     // check accuracy of final parameter estimates
     assertEqual(model.latestModel().intercept, 0.0, 0.1)
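The body of assertEqual is cut off by the hunk boundaries (only its signature appears in the context above); from calls like assertEqual(model.latestModel().intercept, 0.0, 0.1) it presumably reduces to an absolute-difference check. A hypothetical body consistent with that usage:

```scala
// Hypothetical: the real body is elided by the hunk boundary above.
// Passes when v1 and v2 differ by no more than epsilon.
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
  assert(math.abs(v1 - v2) <= epsilon, s"$v1 was not within $epsilon of $v2")
}
```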
@@ -92,36 +83,31 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
   test("parameter convergence") {
 
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(2000)
-    val ssc = new StreamingContext(sc, batchDuration)
-    val numBatches = 5
-    val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse)
+    // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
       .setStepSize(0.1)
-      .setNumIterations(50)
-
-    model.trainOn(data)
+      .setNumIterations(25)
 
-    ssc.start()
-
-    // write data to a file stream
-    val history = new ArrayBuffer[Double](numBatches)
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-      // wait an extra few seconds to make sure the update finishes before new data arrive
-      Thread.sleep(4000)
-      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    // generate sequence of simulated data
+    val numBatches = 10
+    val input = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
     }
 
-    ssc.stop(stopSparkContext=false)
+    // create buffer to store intermediate fits
+    val history = new ArrayBuffer[Double](numBatches)
 
-    Utils.deleteRecursively(testDir)
+    // apply model training to input stream, storing the intermediate results
+    // (we add a count to ensure the result is a DStream)
+    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
 
+    // compute change in error
     val deltas = history.drop(1).zip(history.dropRight(1))
     // check error stability (it always either shrinks, or increases with small tol)
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
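For readers skimming the convergence check: history.drop(1).zip(history.dropRight(1)) pairs each batch's error with the previous batch's error, so the forall bounds how much the error may grow from one batch to the next. A standalone illustration with made-up numbers:

```scala
// hypothetical per-batch errors, mostly shrinking over time
val history = Seq(5.0, 3.0, 2.0, 1.5, 1.55)
// drop(1) is (e1..e4), dropRight(1) is (e0..e3); zipped as (current, previous)
val deltas = history.drop(1).zip(history.dropRight(1))
// deltas == Seq((3.0, 5.0), (2.0, 3.0), (1.5, 2.0), (1.55, 1.5))
// each step either shrinks, or grows by at most the 0.1 tolerance
assert(deltas.forall { case (curr, prev) => curr - prev <= 0.1 })
```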
@@ -133,63 +119,30 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
   // Test predictions on a stream
   test("predictions") {
 
-    val trainDir = Files.createTempDir()
-    val testDir = Files.createTempDir()
-    val batchDuration = Milliseconds(1000)
-    val numBatches = 10
-    val nPoints = 100
-
-    val ssc = new StreamingContext(sc, batchDuration)
-    val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
+    // create model initialized with true weights
     val model = new StreamingLinearRegressionWithSGD()
-      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setInitialWeights(Vectors.dense(10.0, 10.0))
      .setStepSize(0.1)
-      .setNumIterations(50)
+      .setNumIterations(25)
 
-    model.trainOn(data)
-
-    ssc.start()
-
-    // write training data to a file stream
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(trainDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
-    }
-
-    ssc.stop(stopSparkContext=false)
-
-    Utils.deleteRecursively(trainDir)
-
-    print(model.latestModel().weights.toArray.mkString(" "))
-    print(model.latestModel().intercept)
-
-    val ssc2 = new StreamingContext(sc, batchDuration)
-    val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)
-
-    val history = new ArrayBuffer[Double](numBatches)
-    val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
-    val errors = predictions.map(x => math.abs(x._1 - x._2))
-    errors.foreachRDD(rdd => history.append(rdd.reduce(_+_) / nPoints.toDouble))
-
-    ssc2.start()
-
-    // write test data to a file stream
-
-    // make a function
-    for (i <- 0 until numBatches) {
-      val samples = LinearDataGenerator.generateLinearInput(
-        0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
-      val file = new File(testDir, i.toString)
-      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
-      Thread.sleep(batchDuration.milliseconds)
+    // generate sequence of simulated data for testing
+    val numBatches = 10
+    val nPoints = 100
+    val testInput = (0 until numBatches).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
     }
 
-    println(history)
-
-    ssc2.stop(stopSparkContext=false)
+    // apply model predictions to test stream
+    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+      model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+    })
+    // collect the output as (true, estimated) tuples
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+
+    // compute the mean absolute error and check that it's always less than 0.1
+    val errors = output.map(batch => batch.map(
+      p => math.abs(p._1 - p._2)).reduce(_+_) / nPoints.toDouble)
+    assert(errors.forall(x => x <= 0.1))
 
   }
 
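Because the model is initialized at the true weights (10.0, 10.0), each batch's predictions should already track the labels, and the final block just reduces runStreams' per-batch (label, prediction) pairs to a mean absolute error. A self-contained sketch of that reduction with made-up values (the test divides by nPoints, which equals the batch size here since every batch carries nPoints examples):

```scala
// shape returned by runStreams: one inner Seq of (label, prediction) per batch
val output: Seq[Seq[(Double, Double)]] = Seq(
  Seq((9.8, 9.85), (10.2, 10.1)),
  Seq((10.0, 10.02), (9.9, 9.95)))
// mean absolute error per batch
val errors = output.map { batch =>
  batch.map { case (label, pred) => math.abs(label - pred) }.sum / batch.size
}
assert(errors.forall(_ <= 0.1))
```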