diff --git a/arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala b/arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
index d03ab27b0..f169a2361 100644
--- a/arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
+++ b/arrow-data-source/common/src/main/scala/com/intel/oap/spark/sql/ArrowWriteQueue.scala
@@ -23,6 +23,7 @@ import java.util.Collections
 import java.util.UUID
 import java.util.concurrent.ArrayBlockingQueue
 import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
 import java.util.regex.Pattern
 
 import com.intel.oap.spark.sql.ArrowWriteQueue.EOS_BATCH
@@ -34,9 +35,12 @@ import org.apache.arrow.dataset.scanner.ScanTask
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
 import org.apache.arrow.vector.types.pojo.Schema
 
+import org.apache.spark.internal.Logging
+
 class ArrowWriteQueue(schema: Schema, fileFormat: FileFormat, outputFileURI: String)
-    extends AutoCloseable {
+    extends AutoCloseable with Logging {
   private val scanner = new ScannerImpl(schema)
+  private val writeException = new AtomicReference[Throwable]
 
   private val writeThread = new Thread(() => {
     URI.create(outputFileURI) // validate uri
@@ -47,18 +51,34 @@ class ArrowWriteQueue(schema: Schema, fileFormat: FileFormat, outputFileURI: Str
     val dirURI = matcher.group(1)
     val fileName = matcher.group(2)
 
-    DatasetFileWriter.write(scanner, fileFormat, dirURI, Array(), 1, fileName)
+    try {
+      DatasetFileWriter.write(scanner, fileFormat, dirURI, Array(), 1, fileName)
+    } catch {
+      case e: Throwable =>
+        writeException.set(e)
+    }
   }, "ArrowWriteQueue - " + UUID.randomUUID().toString)
   writeThread.start()
 
+  private def checkWriteException(): Unit = {
+    // check whether the background write thread has failed
+    val exception = writeException.get()
+    if (exception != null) {
+      logWarning("Failed to write Arrow data.", exception)
+      throw exception
+    }
+  }
+
   def enqueue(batch: ArrowRecordBatch): Unit = {
     scanner.enqueue(batch)
+    checkWriteException()
   }
 
   override def close(): Unit = {
     scanner.enqueue(EOS_BATCH)
     writeThread.join()
+    checkWriteException()
   }
 }
 
 object ArrowWriteQueue {
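
Note: the patch above follows a common idiom for surfacing failures from a background writer thread. The thread records any exception in an AtomicReference, and the producer-facing methods (enqueue, close) rethrow it on the caller's thread, so a dead writer fails fast instead of hanging or silently dropping batches. The sketch below illustrates the idiom in isolation; it is a minimal, hypothetical analogue (the Worker class, its submit/close methods, and the String-based queue are invented for illustration and are not part of this patch):

    import java.util.concurrent.ArrayBlockingQueue
    import java.util.concurrent.atomic.AtomicReference

    // Hypothetical stand-in for the pattern in ArrowWriteQueue: a background
    // thread drains a bounded queue; any failure it hits is recorded in an
    // AtomicReference and rethrown on the caller's thread.
    class Worker(process: String => Unit) extends AutoCloseable {
      private val EOS = ""                       // end-of-stream sentinel
      private val queue = new ArrayBlockingQueue[String](64)
      private val failure = new AtomicReference[Throwable]

      private val thread = new Thread(() => {
        try {
          Iterator.continually(queue.take()).takeWhile(_ != EOS).foreach(process)
        } catch {
          case t: Throwable => failure.set(t)    // remember the failure, don't swallow it
        }
      }, "Worker - background drain")
      thread.start()

      private def checkFailure(): Unit = {
        val t = failure.get()
        if (t != null) {
          throw t                                // surface on the caller's thread
        }
      }

      def submit(item: String): Unit = {
        queue.put(item)
        checkFailure()                           // fail fast if the worker already died
      }

      override def close(): Unit = {
        queue.put(EOS)
        thread.join()
        checkFailure()                           // last chance to report a failure
      }
    }

    object WorkerDemo extends App {
      val w = new Worker(s => if (s == "boom") sys.error("disk full") else println(s"wrote $s"))
      w.submit("a")
      w.submit("boom")
      try w.close()
      catch { case t: Throwable => println(s"propagated: ${t.getMessage}") }
    }

The check in submit() is best-effort: a failure that lands after the check is only observed on the next call, so the rethrow in close(), after join(), is what guarantees the error is never lost. The same reasoning applies to enqueue() and close() in the patch above.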