Commit 855ed44

zsxwing authored and marmbrus committed
[SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period
## What changes were proposed in this pull request?

Add a processing time trigger to control the batch processing speed.

## How was this patch tested?

Unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #11976 from zsxwing/trigger.
1 parent 89f3bef commit 855ed44
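
Before this change, `StreamExecution` simply slept 10 ms between batches; after it, the batch period is a user-facing knob. Conceptually (a hedged sketch, where `writer` stands for any `DataFrameWriter` on a streaming `DataFrame`):

```scala
import org.apache.spark.sql.ProcessingTime

// ProcessingTime(0) is the default: micro-batches run back to back.
writer.trigger(ProcessingTime(0))

// A positive interval starts each batch on a fixed 10-second period.
writer.trigger(ProcessingTime("10 seconds"))
```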

File tree

9 files changed: +413 −13 lines changed

sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala

Lines changed: 9 additions & 2 deletions
@@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
       name: String,
       checkpointLocation: String,
       df: DataFrame,
-      sink: Sink): ContinuousQuery = {
+      sink: Sink,
+      trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
     activeQueriesLock.synchronized {
       if (activeQueries.contains(name)) {
         throw new IllegalArgumentException(
           s"Cannot start query with name $name as a query with that name is already active")
       }
-      val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+      val query = new StreamExecution(
+        sqlContext,
+        name,
+        checkpointLocation,
+        df.logicalPlan,
+        sink,
+        trigger)
       query.start()
       activeQueries.put(name, query)
       query
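
Because `trigger` defaults to `ProcessingTime(0)`, existing callers of `startQuery` keep the old as-fast-as-possible behavior. A minimal sketch of a caller supplying an explicit trigger (the query name, checkpoint path, `df`, and `sink` values are hypothetical):

```scala
// Hypothetical caller; the name, path, df, and sink are illustrative only.
val query = sqlContext.streams.startQuery(
  "myQuery",
  "/tmp/checkpoints/myQuery",
  df,
  sink,
  trigger = ProcessingTime(10000)) // start a batch every 10 seconds
```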

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala

Lines changed: 33 additions & 1 deletion
@@ -77,6 +77,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     this
   }
 
+  /**
+   * :: Experimental ::
+   * Set the trigger for the stream query. The default value is `ProcessingTime(0)`, which runs
+   * the query as fast as possible.
+   *
+   * Scala Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   *
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   *
+   * Java Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   *
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   *
+   * @since 2.0.0
+   */
+  @Experimental
+  def trigger(trigger: Trigger): DataFrameWriter = {
+    this.trigger = trigger
+    this
+  }
+
   /**
    * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
    *
@@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       queryName,
       checkpointLocation,
       df,
-      dataSource.createSink())
+      dataSource.createSink(),
+      trigger)
   }
 
   /**
@@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
 
   private var mode: SaveMode = SaveMode.ErrorIfExists
 
+  private var trigger: Trigger = ProcessingTime(0L)
+
   private var extraOptions = new scala.collection.mutable.HashMap[String, String]
 
   private var partitioningColumns: Option[Seq[String]] = None
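
End to end, the writer stores the trigger in a private field and hands it to `startQuery` when the stream starts. A hedged sketch of the fluent usage (the format, option names, and the stream-starting call are assumptions about the surrounding API at the time of this commit, not part of this diff):

```scala
// Sketch only: everything except trigger(...) is assumed context.
val query = df.write
  .format("parquet")
  .option("checkpointLocation", "/tmp/checkpoints/demo")
  .trigger(ProcessingTime("10 seconds")) // introduced by this commit
  .startStream("/tmp/output/demo")
```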
sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.Duration
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * :: Experimental ::
+ * Used to indicate how often results should be produced by a [[ContinuousQuery]].
+ */
+@Experimental
+sealed trait Trigger {}
+
+/**
+ * :: Experimental ::
+ * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
+ * the query will run as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ *   df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ *   import scala.concurrent.duration._
+ *   df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ *   df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ *   import java.util.concurrent.TimeUnit
+ *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+@Experimental
+case class ProcessingTime(intervalMs: Long) extends Trigger {
+  require(intervalMs >= 0, "the interval of trigger should not be negative")
+}
+
+/**
+ * :: Experimental ::
+ * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
+ */
+@Experimental
+object ProcessingTime {
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime("10 seconds"))
+   * }}}
+   */
+  def apply(interval: String): ProcessingTime = {
+    if (StringUtils.isBlank(interval)) {
+      throw new IllegalArgumentException(
+        "interval cannot be null or blank.")
+    }
+    val cal = if (interval.startsWith("interval")) {
+      CalendarInterval.fromString(interval)
+    } else {
+      CalendarInterval.fromString("interval " + interval)
+    }
+    if (cal == null) {
+      throw new IllegalArgumentException(s"Invalid interval: $interval")
+    }
+    if (cal.months > 0) {
+      throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
+    }
+    new ProcessingTime(cal.microseconds / 1000)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   import scala.concurrent.duration._
+   *   df.write.trigger(ProcessingTime(10.seconds))
+   * }}}
+   */
+  def apply(interval: Duration): ProcessingTime = {
+    new ProcessingTime(interval.toMillis)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   df.write.trigger(ProcessingTime.create("10 seconds"))
+   * }}}
+   */
+  def create(interval: String): ProcessingTime = {
+    apply(interval)
+  }
+
+  /**
+   * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+   *
+   * Example:
+   * {{{
+   *   import java.util.concurrent.TimeUnit
+   *   df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+   * }}}
+   */
+  def create(interval: Long, unit: TimeUnit): ProcessingTime = {
+    new ProcessingTime(unit.toMillis(interval))
+  }
+}
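
All four factory paths normalize to a millisecond interval, so the following forms are equivalent (plain Scala, grounded in this file and the test suite below):

```scala
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._

// Four ways to express the same 10-second batch period.
assert(ProcessingTime("10 seconds").intervalMs == 10000L)
assert(ProcessingTime("interval 10 seconds").intervalMs == 10000L)
assert(ProcessingTime(10.seconds).intervalMs == 10000L)
assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs == 10000L)

// Month and year intervals are rejected: their length in milliseconds
// is not fixed, so cal.months > 0 throws IllegalArgumentException.
// ProcessingTime("1 month")
```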

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 15 additions & 9 deletions
@@ -47,16 +47,14 @@ class StreamExecution(
     override val name: String,
     val checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
-    val sink: Sink) extends ContinuousQuery with Logging {
+    val sink: Sink,
+    val trigger: Trigger) extends ContinuousQuery with Logging {
 
   /** A monitor used to wait/notify when batches complete. */
   private val awaitBatchLock = new Object
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
-  /** Minimum amount of time in between the start of each batch. */
-  private val minBatchTime = 10
-
   /**
    * Tracks how much data we have processed and committed to the sink or state store from each
    * input source.
@@ -79,6 +77,10 @@ class StreamExecution(
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
 
+  private val triggerExecutor = trigger match {
+    case t: ProcessingTime => ProcessingTimeExecutor(t)
+  }
+
   /** Defines the internal state of execution */
   @volatile
   private var state: State = INITIALIZED
@@ -154,11 +156,15 @@ class StreamExecution(
       SQLContext.setActive(sqlContext)
       populateStartOffsets()
       logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-      while (isActive) {
-        if (dataAvailable) runBatch()
-        commitAndConstructNextBatch()
-        Thread.sleep(minBatchTime) // TODO: Could be tighter
-      }
+      triggerExecutor.execute(() => {
+        if (isActive) {
+          if (dataAvailable) runBatch()
+          commitAndConstructNextBatch()
+          true
+        } else {
+          false
+        }
+      })
     } catch {
       case _: InterruptedException if state == TERMINATED => // interrupted by stop()
       case NonFatal(e) =>
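
The callback contract is what lets the trigger own the pacing: returning `true` asks the executor to schedule another batch at the next interval boundary, returning `false` ends the loop. A standalone sketch of the same contract (not Spark code; the counter is illustrative):

```scala
// Runs three "batches" spaced one second apart, then terminates.
val executor = ProcessingTimeExecutor(ProcessingTime(1000))
var remaining = 3
executor.execute { () =>
  println(s"batch ran, $remaining remaining")
  remaining -= 1
  remaining > 0 // false after the third batch stops the executor
}
```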
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.{Clock, SystemClock}
+
+trait TriggerExecutor {
+
+  /**
+   * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the
+   * execution.
+   */
+  def execute(batchRunner: () => Boolean): Unit
+}
+
+/**
+ * A trigger executor that runs a batch every `intervalMs` milliseconds.
+ */
+case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
+  extends TriggerExecutor with Logging {
+
+  private val intervalMs = processingTime.intervalMs
+
+  override def execute(batchRunner: () => Boolean): Unit = {
+    while (true) {
+      val batchStartTimeMs = clock.getTimeMillis()
+      val terminated = !batchRunner()
+      if (intervalMs > 0) {
+        val batchEndTimeMs = clock.getTimeMillis()
+        val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+        if (batchElapsedTimeMs > intervalMs) {
+          notifyBatchFallingBehind(batchElapsedTimeMs)
+        }
+        if (terminated) {
+          return
+        }
+        clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+      } else {
+        if (terminated) {
+          return
+        }
+      }
+    }
+  }
+
+  /** Called when a batch falls behind. Exposed for testing only. */
+  def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+    logWarning("Current batch is falling behind. The trigger interval is " +
+      s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
+  }
+
+  /** Return the next multiple of intervalMs. */
+  def nextBatchTime(now: Long): Long = {
+    (now - 1) / intervalMs * intervalMs + intervalMs
+  }
+}
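
`nextBatchTime` rounds up to the next multiple of the interval; the `now - 1` offset makes a timestamp that already sits on a boundary map to itself instead of the boundary after it. A worked illustration (plain Scala against this commit's API):

```scala
val exec = ProcessingTimeExecutor(ProcessingTime(100))

assert(exec.nextBatchTime(1) == 100)   // mid-interval rounds up
assert(exec.nextBatchTime(99) == 100)
assert(exec.nextBatchTime(100) == 100) // exact boundary maps to itself
assert(exec.nextBatchTime(101) == 200) // (101 - 1) / 100 * 100 + 100
```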
sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkFunSuite
+
+class ProcessingTimeSuite extends SparkFunSuite {
+
+  test("create") {
+    assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
+    assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
+    assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
+    assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
+
+    intercept[IllegalArgumentException] { ProcessingTime(null: String) }
+    intercept[IllegalArgumentException] { ProcessingTime("") }
+    intercept[IllegalArgumentException] { ProcessingTime("invalid") }
+    intercept[IllegalArgumentException] { ProcessingTime("1 month") }
+    intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+  }
+}

sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala

Lines changed: 5 additions & 1 deletion
@@ -288,7 +288,11 @@ trait StreamTest extends QueryTest with Timeouts {
         currentStream =
           sqlContext
             .streams
-            .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+            .startQuery(
+              StreamExecution.nextName,
+              metadataRoot,
+              stream,
+              sink)
             .asInstanceOf[StreamExecution]
         currentStream.microBatchThread.setUncaughtExceptionHandler(
           new UncaughtExceptionHandler {
