Skip to content

Commit

Permalink
Use double to simplify the code
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Jun 9, 2017
1 parent f471651 commit d5e7492
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,52 +199,13 @@ class RateStreamSource(
}

val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
val timeIntervalSizeMs = TimeUnit.SECONDS.toMillis(endSeconds - startSeconds)
val relativeMsPerValue =
TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)

val func =
if (timeIntervalSizeMs < rangeEnd - rangeStart) {
// Different rows may have the same timestamp
val valueSizePerMs = (rangeEnd - rangeStart) / timeIntervalSizeMs
val remainderValue = (rangeEnd - rangeStart) % timeIntervalSizeMs

(v: Long) => {
val relativeValue = v - rangeStart
val relativeMs = {
// Increase the timestamp per "valueSizePerMs + 1" values before
// "(valueSizePerMs + 1) * remainderValue", and increase the timestamp per
// "valueSizePerMs" values for remaining values.

// The following condition is the same as
// "relativeValue < (valueSizePerMs + 1) * remainderValue", just rewrite it to avoid
// overflow.
if (relativeValue - remainderValue < valueSizePerMs * remainderValue) {
relativeValue / (valueSizePerMs + 1)
} else {
(relativeValue - remainderValue) / valueSizePerMs
}
}
InternalRow(DateTimeUtils.fromMillis(relativeMs + localStartTimeMs), v)
}
} else {
// Different rows never have the same timestamp
val relativeMsPerValue = timeIntervalSizeMs / (rangeEnd - rangeStart)
val remainderMs = timeIntervalSizeMs % (rangeEnd - rangeStart)

(v: Long) => {
val relativeValue = v - rangeStart
// The interval size for the first "remainderMs" values will be "relativeMsPerValue + 1",
// and the interval size for remaining values will be "relativeMsPerValue".
val relativeMs =
if (relativeValue < remainderMs) {
relativeValue * (relativeMsPerValue + 1)
} else {
remainderMs + relativeValue * relativeMsPerValue
}
InternalRow(DateTimeUtils.fromMillis(relativeMs + localStartTimeMs), v)
}
}

val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map(func)
val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
val relative = math.round((v - rangeStart) * relativeMsPerValue)
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
}
sqlContext.internalCreateDataFrame(rdd, schema)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,17 @@ class RateSourceSuite extends StreamTest {
)
}

test("uniform distribution of event timestamps: rowsPerSecond > 1000") {
test("uniform distribution of event timestamps") {
val input = spark.readStream
.format("rate")
.option("rowsPerSecond", "1500")
.option("useManualClock", "true")
.load()
.as[(java.sql.Timestamp, Long)]
.map(v => (v._1.getTime, v._2))
val expectedAnswer =
(0 until 1000).map(v => (v / 2, v)) ++ // Two values share the same timestamp.
((1000 until 1500).map(v => (v - 500, v))) // Each value has one timestamp
testStream(input)(
AdvanceRateManualClock(seconds = 1),
CheckLastBatch(expectedAnswer: _*)
)
}

test("uniform distribution of event timestamps: rowsPerSecond < 1000") {
val input = spark.readStream
.format("rate")
.option("rowsPerSecond", "400")
.option("useManualClock", "true")
.load()
.as[(java.sql.Timestamp, Long)]
.map(v => (v._1.getTime, v._2))
val expectedAnswer = (0 until 200).map(v => (v * 3, v)) ++
((200 until 400).map(v => (600 + (v - 200) * 2, v)))
val expectedAnswer = (0 until 1500).map { v =>
(math.round(v * (1000.0 / 1500)), v)
}
testStream(input)(
AdvanceRateManualClock(seconds = 1),
CheckLastBatch(expectedAnswer: _*)
Expand Down Expand Up @@ -121,7 +105,7 @@ class RateSourceSuite extends StreamTest {
CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
AdvanceRateManualClock(seconds = 1),
CheckLastBatch({
Seq(2000 -> 6, 2167 -> 7, 2334 -> 8, 2501 -> 9, 2668 -> 10, 2834 -> 11)
Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
}: _*), // speed = 6
AdvanceRateManualClock(seconds = 1),
CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
Expand Down

0 comments on commit d5e7492

Please sign in to comment.