Commit a23633d

Handle get bytesRead from different thread
Change-Id: I76c6ff84904211e3fae4dcd11772fb7fa5ec503c
1 parent 6c1dbd6 commit a23633d

4 files changed: +64 −8 lines
core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala

Lines changed: 11 additions & 6 deletions
@@ -21,6 +21,7 @@ import java.io.IOException
 import java.security.PrivilegedExceptionAction
 import java.text.DateFormat
 import java.util.{Arrays, Comparator, Date, Locale}
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
@@ -143,14 +144,18 @@ class SparkHadoopUtil extends Logging {
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
    * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
    * return the bytes read on r since t.
-   *
-   * @return None if the required method can't be found.
    */
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
-    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
-    val f = () => threadStats.map(_.getBytesRead).sum
-    val baselineBytesRead = f()
-    () => f() - baselineBytesRead
+    val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+    val baseline = (Thread.currentThread().getId, f())
+    val bytesReadMap = new ConcurrentHashMap[Long, Long]()
+
+    () => {
+      bytesReadMap.put(Thread.currentThread().getId, f())
+      bytesReadMap.asScala.map { case (k, v) =>
+        v - (if (k == baseline._1) baseline._2 else 0)
+      }.sum
+    }
   }
 
   /**
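The rewritten callback no longer assumes that all reads happen on the thread that registered it: every thread that invokes the callback records its own cumulative count in the ConcurrentHashMap, and the result is the sum of per-thread values, with the baseline subtracted only for the registering thread. Below is a minimal, self-contained sketch of the same aggregation pattern; perThreadCounter and read are hypothetical stand-ins for Hadoop's per-thread FileSystem statistics, not Spark code.

import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

object BytesReadCallbackSketch {
  // Hypothetical stand-in for Hadoop's FileSystem statistics: a per-thread cumulative counter.
  private val perThreadCounter = new ThreadLocal[Long] { override def initialValue(): Long = 0L }
  def read(n: Long): Unit = perThreadCounter.set(perThreadCounter.get + n)

  // Same shape as the new getFSBytesReadOnThreadCallback: remember a baseline for the
  // registering thread, then sum the latest value reported by every thread that has
  // invoked the callback, subtracting the baseline only for the registering thread.
  def makeCallback(): () => Long = {
    val f = () => perThreadCounter.get
    val baseline = (Thread.currentThread().getId, f())
    val bytesReadMap = new ConcurrentHashMap[Long, Long]()
    () => {
      bytesReadMap.put(Thread.currentThread().getId, f())
      bytesReadMap.asScala.map { case (k, v) =>
        v - (if (k == baseline._1) baseline._2 else 0)
      }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    val callback = makeCallback()
    read(10); callback()                            // "task" thread reads 10 bytes
    val helper = new Thread(new Runnable {
      def run(): Unit = { read(32); callback() }    // helper thread reads 32 bytes
    })
    helper.start(); helper.join()
    println(callback())                             // 42: both threads are accounted for
  }
}

Note that a thread's reads only become visible once that thread has invoked the callback at least once, which is why the RDD changes below keep calling updateBytesRead() from the reading code path and once more at task completion.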

core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala

Lines changed: 7 additions & 1 deletion
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
         null
       }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener{ context => closeIfNeeded() }
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        closeIfNeeded()
+      }
+
       private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
       private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
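The listener now flushes any bytes counted since the last periodic update into the task's input metrics before the reader is closed, so nothing recorded on this thread is dropped at task end. The sketch below shows the same flush-on-completion pattern at the user level; it is an illustration only, assuming a Spark 2.x-style API (the addTaskCompletionListener overload taking a function, sc.longAccumulator), and the input path and accumulator name are made up. The identical change is applied to NewHadoopRDD below.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object FlushOnCompletionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val linesSeen = sc.longAccumulator("linesSeen")

    sc.textFile("/tmp/input.txt", 2).mapPartitions { iter =>
      var pending = 0L
      // Like HadoopRDD's listener: push any not-yet-reported counts before the task finishes
      // and resources are torn down.
      TaskContext.get().addTaskCompletionListener { _ =>
        linesSeen.add(pending)
      }
      iter.map { line => pending += 1; line }
    }.count()

    println(s"lines seen: ${linesSeen.value}")
    sc.stop()
  }
}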

core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala

Lines changed: 7 additions & 1 deletion
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
       }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener(context => close())
+      context.addTaskCompletionListener { context =>
+        // Update the bytesRead before closing to make sure lingering bytesRead statistics in
+        // this thread get correctly added.
+        updateBytesRead()
+        close()
+      }
+
       private var havePair = false
       private var recordsSinceMetricsUpdate = 0

core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala

Lines changed: 39 additions & 0 deletions
@@ -319,6 +319,45 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     }
     assert(bytesRead >= tmpFile.length())
   }
+
+  test("input metrics with old Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        val thread = new Thread() {
+          override def run(): Unit = {
+            iter.flatMap(_.split(" ")).foreach(buf.append(_))
+          }
+        }
+        thread.start()
+        thread.join()
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead != 0)
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics with new Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        val thread = new Thread() {
+          override def run(): Unit = {
+            iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+          }
+        }
+        thread.start()
+        thread.join()
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead != 0)
+    assert(bytesRead >= tmpFile.length())
+  }
 }
 
 /**
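Both tests push the actual reading onto a helper thread because Hadoop keeps FileSystem statistics per thread: the task thread's getThreadStatistics never sees bytes read elsewhere, which is exactly what the callback rewrite above compensates for. A standalone sketch of that behaviour, using only the Hadoop APIs already referenced in the diff, is below; the local path is an example, and the exact counts depend on the local FileSystem implementation.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import scala.collection.JavaConverters._

object ThreadLocalStatsSketch {
  private def bytesReadOnThisThread(): Long =
    FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum

  def main(args: Array[String]): Unit = {
    val fs = FileSystem.getLocal(new Configuration())

    val reader = new Thread(new Runnable {
      def run(): Unit = {
        val in = fs.open(new Path("/tmp/input.txt"))
        try {
          val buf = new Array[Byte](4096)
          while (in.read(buf) != -1) {}   // all reads happen on this helper thread
          println(s"helper thread sees ${bytesReadOnThisThread()} bytes")
        } finally in.close()
      }
    })
    reader.start()
    reader.join()

    // Before the callback rewrite, only the registering thread's statistics were consulted,
    // so the helper thread's reads were lost.
    println(s"main thread sees ${bytesReadOnThisThread()} bytes")
  }
}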
