ReceiverInputDStream.scala
@@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
    */
   override protected[streaming] val rateController: Option[RateController] = {
     if (RateController.isBackPressureEnabled(ssc.conf)) {
-      RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+      Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration)))
     } else {
       None
     }
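
With this change, enabling backpressure always yields a rate controller: RateEstimator.create now returns a RateEstimator directly (falling back to the PID estimator by default) instead of an Option. A minimal sketch of the new call shape, assuming package-internal access to an initialized StreamingContext named ssc (hypothetical snippet, not part of the diff):

    // Old: create(conf) returned Option[RateEstimator], so the controller could be absent.
    // New: an estimator is always constructed for the configured batch duration.
    val estimator: RateEstimator = RateEstimator.create(ssc.conf, ssc.graph.batchDuration)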
PIDRateEstimator.scala (new file)
@@ -0,0 +1,124 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.streaming.scheduler.rate

/**
 * Implements a proportional-integral-derivative (PID) controller which acts on
 * the speed of ingestion of elements into Spark Streaming. A PID controller works
 * by calculating an '''error''' between a measured output and a desired value. In the
 * case of Spark Streaming the error is the difference between the measured processing
 * rate (number of elements/processing delay) and the previous rate.
 *
 * @see https://en.wikipedia.org/wiki/PID_controller
 *
 * @param batchIntervalMillis the batch duration, in milliseconds
 * @param proportional how much the correction should depend on the current
 *        error. This term usually provides the bulk of correction and should be positive or zero.
 *        A value too large would make the controller overshoot the setpoint, while a small value
 *        would make the controller too insensitive. The default value is 1.
 * @param integral how much the correction should depend on the accumulation
 *        of past errors. This value should be positive or 0. This term accelerates the movement
 *        towards the desired value, but a large value may lead to overshooting. The default value
 *        is 0.2.
 * @param derivative how much the correction should depend on a prediction
 *        of future errors, based on the current rate of change. This value should be positive or 0.
 *        This term is not used very often, as it impacts the stability of the system. The default
 *        value is 0.
 */
private[streaming] class PIDRateEstimator(
    batchIntervalMillis: Long,
    proportional: Double = 1D,
Contributor: nit: I don't particularly mind, but is there any particular reason for having D at the end everywhere? Looks pretty weird.
    integral: Double = .2D,
    derivative: Double = 0D)
  extends RateEstimator {

  private var firstRun: Boolean = true
  private var latestTime: Long = -1L
  private var latestRate: Double = -1D
  private var latestError: Double = -1L
Contributor: Should verify the values of the parameters.

  require(
    batchIntervalMillis > 0,
    s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.")
  require(
    proportional >= 0,
    s"Proportional term $proportional in PIDRateEstimator should be >= 0.")
Contributor: Did you see my earlier comment about > 1 not being practically feasible?

  require(
    integral >= 0,
    s"Integral term $integral in PIDRateEstimator should be >= 0.")
  require(
    derivative >= 0,
    s"Derivative term $derivative in PIDRateEstimator should be >= 0.")


  def compute(time: Long, // in milliseconds
      numElements: Long,
      processingDelay: Long, // in milliseconds
      schedulingDelay: Long // in milliseconds
    ): Option[Double] = {

    this.synchronized {
      if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) {

        // in seconds, should be close to batchDuration
        val delaySinceUpdate = (time - latestTime).toDouble / 1000

        // in elements/second
        val processingRate = numElements.toDouble / processingDelay * 1000

        // In our system `error` is the difference between the desired rate and the measured
        // rate based on the latest batch information. We consider the desired rate to be the
        // latest rate, which is what this estimator calculated for the previous batch.
        // in elements/second
        val error = latestRate - processingRate
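        // Worked example (hypothetical numbers): if the previous estimate was
        // latestRate = 1200 elements/s but the last batch only sustained
        // processingRate = 1000 elements/s, then error = 1200 - 1000 = 200 elements/s,
        // and the proportional term will pull the new rate down by proportional * 200.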
Contributor Author: Carrying over the conversation from a previous thread that got lost due to a rebase: "Could you make the names more semantically meaningful? How about: error --> changeInRate?"

tdas: Why is the latestRate considered as the set point (that's my assumption, since the error is calculated between the observed value and the set point, according to PID theory)? @huitseeker

dragos: Since @huitseeker seems to be away, I'll answer this. The latestRate is what we considered the desired value at the previous batch update. With the new information we got for the last batch interval, we compute a current rate and compare it to what we asked for; that constitutes our error that needs correction.

Contributor Author: Here I'd prefer to keep this as error, as I think most people reading this code would have more trouble mapping things to PID terminology than to Spark Streaming terminology, and all PID docs will mention error and correction.

Contributor: But that is exactly the problem: people who are reading the streaming code (like me) are more likely to know streaming than PID concepts, and if the code does not make it clear in terms of streaming, it is super hard to relate to. So I am fine with keeping this as error as long as there is an explanation in streaming terms along with it, just like what you added for historicalError.

Contributor Author: I'll add the explanation.

        // The error integral, based on schedulingDelay as an indicator for accumulated errors.
        // A scheduling delay s corresponds to s * processingRate overflowing elements. Those
        // are elements that couldn't be processed in previous batches, leading to this delay.
        // In the following, we assume the processingRate didn't change too much.
        // From the number of overflowing elements we can calculate the rate at which they would
        // be processed by dividing it by the batch interval. This rate is our "historical" error,
        // or integral part, since if we subtracted this rate from the previous "calculated rate",
        // there wouldn't have been any overflowing elements, and the scheduling delay would have
        // been zero.
        // (in elements/second)
        val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis
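        // Worked example continued (hypothetical numbers): schedulingDelay = 100 ms at
        // processingRate = 1000 elements/s means roughly 100 overflowing elements; spread
        // over a 200 ms batch interval, historicalError = 100 * 1000 / 200 = 500 elements/s.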

        // in elements/(second ^ 2)
        val dError = (error - latestError) / delaySinceUpdate

        val newRate = (latestRate - proportional * error -
                                    integral * historicalError -
                                    derivative * dError).max(0.0)
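        // Worked example concluded (hypothetical gains 1.0 / 0.2 / 0.0 and the numbers
        // above): newRate = 1200 - 1.0 * 200 - 0.2 * 500 - 0 = 900 elements/s, i.e. the
        // estimate backs off toward what the pipeline actually sustained.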
        latestTime = time
        if (firstRun) {
          latestRate = processingRate
          latestError = 0D
          firstRun = false

          None
        } else {
          latestRate = newRate
          latestError = error

          Some(newRate)
        }
      } else None
    }
  }
}
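
A minimal usage sketch for the estimator above (hypothetical values; the default gains are assumed, i.e. proportional = 1, integral = 0.2, derivative = 0):

    val estimator = new PIDRateEstimator(batchIntervalMillis = 1000)
    // The first call only seeds internal state and produces no bound.
    assert(estimator.compute(1000, 500, 500, 0).isEmpty)
    // Second call: 500 elements over 500 ms of processing = 1000 elements/s; with zero
    // error and no scheduling delay, the estimate equals the measured rate.
    assert(estimator.compute(2000, 500, 500, 0) == Some(1000.0))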
RateEstimator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate

 import org.apache.spark.SparkConf
 import org.apache.spark.SparkException
+import org.apache.spark.streaming.Duration

 /**
  * A component that estimates the rate at which an InputDStream should ingest
@@ -48,12 +49,21 @@ object RateEstimator {
   /**
    * Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`.
    *
-   * @return None if there is no configured estimator, otherwise an instance of RateEstimator
+   * The only known estimator right now is `pid`.
+   *
+   * @return An instance of RateEstimator
    * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
    *         known estimators.
    */
-  def create(conf: SparkConf): Option[RateEstimator] =
-    conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
-      throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
+  def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
+    conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
+      case "pid" =>
+        val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
+        val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
+        val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
+        new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)
+
+      case estimator =>
+        throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
     }
}
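
A hedged configuration sketch for create above: the pid.* keys and their values are just the defaults read in the pid branch made explicit, and the enabled flag is assumed from RateController.isBackPressureEnabled (not shown in this diff):

    val conf = new SparkConf()
      .set("spark.streaming.backpressure.enabled", "true")
      .set("spark.streaming.backpressure.rateEstimator", "pid")
      .set("spark.streaming.backpressure.pid.proportional", "1.0")
      .set("spark.streaming.backpressure.pid.integral", "0.2")
      .set("spark.streaming.backpressure.pid.derived", "0.0")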
PIDRateEstimatorSuite.scala (new file)
@@ -0,0 +1,137 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.streaming.scheduler.rate

import scala.util.Random

import org.scalatest.Inspectors.forAll
import org.scalatest.Matchers

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.Seconds

class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {

  test("the right estimator is created") {
    val conf = new SparkConf
    conf.set("spark.streaming.backpressure.rateEstimator", "pid")
    val pid = RateEstimator.create(conf, Seconds(1))
    pid.getClass should equal(classOf[PIDRateEstimator])
  }
Contributor: Test for the spark.streaming.backpressure.rateEstimator.pid.* params as well.


test("estimator checks ranges") {
intercept[IllegalArgumentException] {
new PIDRateEstimator(0, 1, 2, 3)
}
intercept[IllegalArgumentException] {
new PIDRateEstimator(100, -1, 2, 3)
}
intercept[IllegalArgumentException] {
new PIDRateEstimator(100, 0, -1, 3)
}
intercept[IllegalArgumentException] {
new PIDRateEstimator(100, 0, 0, -1)
}
}

private def createDefaultEstimator: PIDRateEstimator = {
new PIDRateEstimator(20, 1D, 0D, 0D)
}

test("first bound is None") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add an unit test that is going to verify whether setting the appropriate params in the SparkConf create the right estimator and the right estimator parameters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add unit test to check whether setting the parameters to incorrect ranges throws errors or not (use intercept[ ] to catch errors)

val p = createDefaultEstimator
p.compute(0, 10, 10, 0) should equal(None)
}

test("second bound is rate") {
val p = createDefaultEstimator
p.compute(0, 10, 10, 0)
// 1000 elements / s
p.compute(10, 10, 10, 0) should equal(Some(1000))
}

test("works even with no time between updates") {
val p = createDefaultEstimator
p.compute(0, 10, 10, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is hard to understand. Basically the internal state is being mutated in these series of calls. So could you put asserts on the internal state so that I can understand what are the expected changes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test isn't really about mutating state, but checking that two consecutive updates for the same time doesn't trip the estimator (for instance, a division by zero).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then it will easier to get if inline comments said so .. // same time as previous update

p.compute(10, 10, 10, 0)
p.compute(10, 10, 10, 0) should equal(None)
}

test("bound is never negative") {
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
// prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing
// this might point the estimator to try and decrease the bound, but we test it never
// goes below zero, which would be nonsensical.
val times = List.tabulate(50)(x => x * 20) // every 20ms
val elements = List.fill(50)(0) // no processing
val proc = List.fill(50)(20) // 20ms of processing
val sched = List.fill(50)(100) // strictly positive accumulation
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
res.head should equal(None)
res.tail should equal(List.fill(49)(Some(0D)))
}

test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
// prepare a series of batch updates, one every 20ms with an increasing number of processed
// elements in each batch, but constant processing time, and no accumulated error. Even though
// the integral part is non-zero, the estimated rate should follow only the proportional term
val times = List.tabulate(50)(x => x * 20) // every 20ms
val elements = List.tabulate(50)(x => x * 20) // increasing
val proc = List.fill(50)(20) // 20ms of processing
val sched = List.fill(50)(0)
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
res.head should equal(None)
res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail)
}

test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
val p = new PIDRateEstimator(20, 1D, 1D, 0D)
// prepare a series of batch updates, one every 20ms with an decreasing number of processed
// elements in each batch, but constant processing time, and no accumulated error. Even though
// the integral part is non-zero, the estimated rate should follow only the proportional term,
// asking for less and less elements
val times = List.tabulate(50)(x => x * 20) // every 20ms
val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing
val proc = List.fill(50)(20) // 20ms of processing
val sched = List.fill(50)(0)
val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
res.head should equal(None)
res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail)
}

test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
val p = new PIDRateEstimator(20, 1D, .01D, 0D)
val times = List.tabulate(50)(x => x * 20) // every 20ms
val rng = new Random()
val elements = List.tabulate(50)(x => rng.nextInt(1000))
val procDelayMs = 20
val proc = List.fill(50)(procDelayMs) // 20ms of processing
val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait
val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000)

val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
res.head should equal(None)
forAll(List.range(1, 50)) { (n) =>
res(n) should not be None
if (res(n).get > 0 && sched(n) > 0) {
res(n).get should be < speeds(n)
}
}
}
}