From ff339a55c9dc5428d4e5a20af72273107b6d1721 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Mon, 25 Aug 2014 12:12:24 -0700
Subject: [PATCH 1/3] add EscapedTextInputFormat

---
 .../spark/input/EscapedTextInputFormat.scala  | 229 ++++++++++++++++++
 .../input/EscapedTextInputFormatSuite.scala   | 118 +++++++++
 2 files changed, 347 insertions(+)
 create mode 100644 core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
 create mode 100644 core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala b/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
new file mode 100644
index 0000000000000..2649cb883d18a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{BufferedReader, IOException, InputStreamReader}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
+
+/**
+ * Input format for text records saved with in-record delimiter and newline characters escaped.
+ *
+ * For example, a record containing two fields: `"a\n"` and `"|b\\"` saved with delimiter `|`
+ * should be the following:
+ * {{{
+ * a\\\n|\\|b\\\\\n
+ * }}},
+ * where the in-record `|`, `\n`, and `\\` characters are escaped by `\\`.
+ * Users can configure the delimiter via [[EscapedTextInputFormat$#KEY_DELIMITER]].
+ * Its default value [[EscapedTextInputFormat$#DEFAULT_DELIMITER]] is set to match Redshift's
+ * UNLOAD with the ESCAPE option:
+ * {{{
+ *   UNLOAD ('select_statement')
+ *   TO 's3://object_path_prefix'
+ *   ESCAPE
+ * }}}
+ *
+ * @see org.apache.spark.SparkContext#newAPIHadoopFile
+ */
+class EscapedTextInputFormat extends FileInputFormat[Long, Array[String]] {
+
+  override def createRecordReader(
+      split: InputSplit,
+      context: TaskAttemptContext): RecordReader[Long, Array[String]] = {
+    new EscapedTextRecordReader
+  }
+}
+
+object EscapedTextInputFormat {
+
+  /** configuration key for delimiter */
+  val KEY_DELIMITER = "spark.input.escapedText.delimiter"
+
+  /** default delimiter */
+  val DEFAULT_DELIMITER = '|'
+
+  /** Gets the delimiter char from conf or the default. */
+  private[input] def getDelimiterOrDefault(conf: Configuration): Char = {
+    val c = conf.get(KEY_DELIMITER, DEFAULT_DELIMITER.toString)
+    if (c.length != 1) {
+      throw new IllegalArgumentException(s"Expect delimiter be a single character but got '$c'.")
+    } else {
+      c.charAt(0)
+    }
+  }
+}
+
+private[input] class EscapedTextRecordReader extends RecordReader[Long, Array[String]] {
+
+  private var reader: BufferedReader = _
+
+  private var key: Long = _
+  private var value: Array[String] = _
+
+  private var start: Long = _
+  private var end: Long = _
+  private var cur: Long = _
+
+  private var delimiter: Char = _
+  @inline private[this] final val escapeChar = '\\'
+  @inline private[this] final val newline = '\n'
+
+  @inline private[this] final val defaultBufferSize = 64 * 1024
+
+  override def initialize(inputSplit: InputSplit, context: TaskAttemptContext): Unit = {
+    val split = inputSplit.asInstanceOf[FileSplit]
+    val file = split.getPath
+    val conf = context.getConfiguration
+    delimiter = EscapedTextInputFormat.getDelimiterOrDefault(conf)
+    require(delimiter != escapeChar,
+      s"The delimiter and the escape char cannot be the same but found $delimiter.")
+    require(delimiter != newline, "The delimiter cannot be the newline character.")
+    val compressionCodecs = new CompressionCodecFactory(conf)
+    val codec = compressionCodecs.getCodec(file)
+    if (codec != null) {
+      throw new IOException(s"Do not support compressed files but found $file.")
+    }
+    val fs = file.getFileSystem(conf)
+    val in = fs.open(file)
+    start = findNext(in, split.getStart)
+    end = findNext(in, split.getStart + split.getLength)
+    cur = start
+    in.seek(cur)
+    reader = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
+  }
+
+  override def getProgress: Float = {
+    if (start >= end) {
+      1.0f
+    } else {
+      math.min((cur - start).toFloat / (end - start), 1.0f)
+    }
+  }
+
+  override def nextKeyValue(): Boolean = {
+    if (cur < end) {
+      key = cur
+      value = nextValue()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def getCurrentValue: Array[String] = value
+
+  override def getCurrentKey: Long = key
+
+  override def close(): Unit = {
+    if (reader != null) {
+      reader.close()
+    }
+  }
+
+  /**
+   * Finds the start of the next record.
+   * Because we don't know whether the first char is escaped or not, we need to first find a
+   * position that is not escaped.
+   * @return the start position of the next record
+   */
+  private def findNext(in: FSDataInputStream, start: Long): Long = {
+    if (start == 0L) return 0L
+    var pos = start
+    in.seek(pos)
+    val br = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
+    var escaped = true
+    var eof = false
+    while (escaped && !eof) {
+      val v = br.read()
+      if (v < 0) {
+        eof = true
+      } else {
+        pos += 1
+        if (v != escapeChar) {
+          escaped = false
+        }
+      }
+    }
+    var newline = false
+    while ((escaped || !newline) && !eof) {
+      val v = br.read()
+      if (v < 0) {
+        eof = true
+      } else {
+        pos += 1
+        if (v == escapeChar) {
+          escaped = true
+        } else {
+          if (!escaped) {
+            newline = v == '\n'
+          } else {
+            escaped = false
+          }
+        }
+      }
+    }
+    pos
+  }
+
+  private def nextValue(): Array[String] = {
+    var escaped = false
+    val fields = ArrayBuffer.empty[String]
+    var endOfRecord = false
+    var eof = false
+    while (!endOfRecord && !eof) {
+      var endOfField = false
+      val sb = new StringBuilder
+      while (!endOfField && !endOfRecord && !eof) {
+        val v = reader.read()
+        if (v < 0) {
+          eof = true
+        } else {
+          cur += 1
+          if (escaped) {
+            if (v != escapeChar && v != delimiter && v != newline) {
+              throw new IllegalStateException(s"Found ${v.asInstanceOf[Char]} after $escapeChar.")
+            }
+            sb.append(v.asInstanceOf[Char])
+            escaped = false
+          } else {
+            if (v == escapeChar) {
+              escaped = true
+            } else if (v == delimiter) {
+              endOfField = true
+            } else if (v == newline) {
+              endOfRecord = true
+            } else {
+              sb.append(v.asInstanceOf[Char])
+            }
+          }
+        }
+      }
+      fields.append(sb.toString())
+    }
+    if (escaped) {
+      throw new IllegalStateException(s"Found hanging escape char.")
+    }
+    fields.toArray
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala
new file mode 100644
index 0000000000000..c5e14f0c51170
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{DataOutputStream, File, FileOutputStream}
+
+import scala.language.implicitConversions
+
+import com.google.common.io.Files
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.input.EscapedTextInputFormat._
+import org.apache.spark.util.Utils
+
+class EscapedTextInputFormatSuite extends FunSuite with BeforeAndAfterAll with Logging {
+
+  import EscapedTextInputFormatSuite._
+
+  private var sc: SparkContext = _
+
+  override def beforeAll() {
+    sc = new SparkContext("local", "test")
+
+    // Set the block size of local file system to test whether files are split right or not.
+    sc.hadoopConfiguration.setLong("fs.local.block.size", 4)
+  }
+
+  override def afterAll() {
+    sc.stop()
+  }
+
+  private def writeToFile(contents: String, file: File) = {
+    val bytes = contents.getBytes
+    val out = new DataOutputStream(new FileOutputStream(file))
+    out.write(bytes, 0, bytes.length)
+    out.close()
+  }
+
+  private def escape(records: Set[Seq[String]], delimiter: Char): String = {
+    require(delimiter != '\\' && delimiter != '\n')
+    records.map { r =>
+      r.map { f =>
+        f.replace("\\", "\\\\")
+          .replace("\n", "\\\n")
+          .replace(delimiter, "\\" + delimiter)
+      }.mkString(delimiter)
+    }.mkString("", "\n", "\n")
+  }
+
+  private final val TAB = '\t'
+
+  private val records = Set(
+    Seq("a\n", DEFAULT_DELIMITER + "b\\"),
+    Seq("c", TAB + "d"),
+    Seq("\ne", "\\\\f"))
+
+  private def withTempDir(func: File => Unit): Unit = {
+    val dir = Files.createTempDir()
+    dir.deleteOnExit()
+    logDebug(s"dir: $dir")
+    func(dir)
+    Utils.deleteRecursively(dir)
+  }
+
+  test("default delimiter") {
+    withTempDir { dir =>
+      val escaped = escape(records, DEFAULT_DELIMITER)
+      writeToFile(escaped, new File(dir, "part-00000"))
+
+      val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
+        classOf[Long], classOf[Array[String]])
+      assert(rdd.partitions.size > 3) // so there will be empty partitions
+
+      val actual = rdd.values.map(_.toSeq).collect().toSet
+      assert(actual === records)
+    }
+  }
+
+  test("customized delimiter") {
+    withTempDir { dir =>
+      val escaped = escape(records, TAB)
+      writeToFile(escaped, new File(dir, "part-00000"))
+
+      val conf = new Configuration
+      conf.set(KEY_DELIMITER, TAB)
+
+      val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
+        classOf[Long], classOf[Array[String]], conf)
+      assert(rdd.partitions.size > 3) // so their will be empty partitions
+
+      val actual = rdd.values.map(_.toSeq).collect().toSet
+      assert(actual === records)
+    }
+  }
+}
+
+object EscapedTextInputFormatSuite {
+
+  implicit def charToString(c: Char): String = c.toString
+}

From e35a366b11bd7c8f92775b056b5fb7e62eda5df4 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Mon, 25 Aug 2014 14:13:39 -0700
Subject: [PATCH 2/3] use LocalSparkContext

---
 .../input/EscapedTextInputFormatSuite.scala   | 43 +++++++++----------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala b/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala
index c5e14f0c51170..daa64e2403f84 100644
--- a/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/EscapedTextInputFormatSuite.scala
@@ -23,30 +23,17 @@ import scala.language.implicitConversions
 
 import com.google.common.io.Files
 import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.FunSuite
 
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.{SparkContext, LocalSparkContext, Logging}
 import org.apache.spark.SparkContext._
 import org.apache.spark.input.EscapedTextInputFormat._
 import org.apache.spark.util.Utils
 
-class EscapedTextInputFormatSuite extends FunSuite with BeforeAndAfterAll with Logging {
+class EscapedTextInputFormatSuite extends FunSuite with LocalSparkContext with Logging {
 
   import EscapedTextInputFormatSuite._
 
-  private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-
-    // Set the block size of local file system to test whether files are split right or not.
-    sc.hadoopConfiguration.setLong("fs.local.block.size", 4)
-  }
-
-  override def afterAll() {
-    sc.stop()
-  }
-
   private def writeToFile(contents: String, file: File) = {
     val bytes = contents.getBytes
     val out = new DataOutputStream(new FileOutputStream(file))
@@ -65,6 +52,8 @@ class EscapedTextInputFormatSuite extends FunSuite with BeforeAndAfterAll with L
     }.mkString("", "\n", "\n")
   }
 
+  private final val KEY_BLOCK_SIZE = "fs.local.block.size"
+
   private final val TAB = '\t'
 
   private val records = Set(
@@ -81,33 +70,41 @@ class EscapedTextInputFormatSuite extends FunSuite with BeforeAndAfterAll with L
   }
 
   test("default delimiter") {
+    sc = new SparkContext("local", "test default delimiter")
     withTempDir { dir =>
       val escaped = escape(records, DEFAULT_DELIMITER)
       writeToFile(escaped, new File(dir, "part-00000"))
 
+      val conf = new Configuration
+      conf.setLong(KEY_BLOCK_SIZE, 4)
+
       val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
-        classOf[Long], classOf[Array[String]])
-      assert(rdd.partitions.size > 3) // so there will be empty partitions
+        classOf[Long], classOf[Array[String]], conf)
+      assert(rdd.partitions.size > records.size) // so there exist at least one empty partition
 
-      val actual = rdd.values.map(_.toSeq).collect().toSet
-      assert(actual === records)
+      val actual = rdd.values.map(_.toSeq).collect()
+      assert(actual.size === records.size)
+      assert(actual.toSet === records)
     }
   }
 
   test("customized delimiter") {
+    sc = new SparkContext("local", "test customized delimiter")
     withTempDir { dir =>
      val escaped = escape(records, TAB)
      writeToFile(escaped, new File(dir, "part-00000"))
 
       val conf = new Configuration
+      conf.setLong(KEY_BLOCK_SIZE, 4)
       conf.set(KEY_DELIMITER, TAB)
 
       val rdd = sc.newAPIHadoopFile(dir.toString, classOf[EscapedTextInputFormat],
         classOf[Long], classOf[Array[String]], conf)
-      assert(rdd.partitions.size > 3) // so their will be empty partitions
+      assert(rdd.partitions.size > records.size) // so there exist at least one empty partitions
 
-      val actual = rdd.values.map(_.toSeq).collect().toSet
-      assert(actual === records)
+      val actual = rdd.values.map(_.toSeq).collect()
+      assert(actual.size === records.size)
+      assert(actual.toSet === records)
     }
   }
 }

From f8d0191ce20d5a49ae32ecad85322c5db065fbb5 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Mon, 25 Aug 2014 16:09:32 -0700
Subject: [PATCH 3/3] avoid seeking beyond eof

---
 .../spark/input/EscapedTextInputFormat.scala  | 23 ++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala b/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
index 2649cb883d18a..747c70d6af2d5 100644
--- a/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/EscapedTextInputFormat.scala
@@ -22,7 +22,7 @@ import java.io.{BufferedReader, IOException, InputStreamReader}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.io.compress.CompressionCodecFactory
 import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
@@ -105,11 +105,15 @@ private[input] class EscapedTextRecordReader extends RecordReader[Long, Array[St
       throw new IOException(s"Do not support compressed files but found $file.")
     }
     val fs = file.getFileSystem(conf)
-    val in = fs.open(file)
-    start = findNext(in, split.getStart)
-    end = findNext(in, split.getStart + split.getLength)
+    val size = fs.getFileStatus(file).getLen
+    start = findNext(fs, file, size, split.getStart)
+    end = findNext(fs, file, size, split.getStart + split.getLength)
     cur = start
-    in.seek(cur)
+    val in = fs.open(file)
+    if (cur > 0L) {
+      in.seek(cur - 1L)
+      in.read()
+    }
     reader = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
   }
 
@@ -147,9 +151,11 @@ private[input] class EscapedTextRecordReader extends RecordReader[Long, Array[St
    * position that is not escaped.
    * @return the start position of the next record
    */
-  private def findNext(in: FSDataInputStream, start: Long): Long = {
-    if (start == 0L) return 0L
-    var pos = start
+  private def findNext(fs: FileSystem, file: Path, size: Long, offset: Long): Long = {
+    if (offset == 0L) return 0L
+    if (offset >= size) return size
+    var pos = offset
+    val in = fs.open(file)
     in.seek(pos)
     val br = new BufferedReader(new InputStreamReader(in), defaultBufferSize)
     var escaped = true
@@ -183,6 +189,7 @@ private[input] class EscapedTextRecordReader extends RecordReader[Long, Array[St
         }
       }
     }
+    in.close()
     pos
   }
 
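
Not part of the patches themselves: a minimal usage sketch, modeled on the two tests above, of how a driver program might read files unloaded by Redshift's UNLOAD ... ESCAPE with the new input format. The master URL, application name, object name, and input path below are placeholders, not anything defined by this change.

```scala
import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // implicit conversions for pair-RDD operations such as .values
import org.apache.spark.input.EscapedTextInputFormat

// Hypothetical example object; only EscapedTextInputFormat and its KEY_DELIMITER come from the patch.
object EscapedTextExample {
  def main(args: Array[String]): Unit = {
    // Placeholder master and application name.
    val sc = new SparkContext("local", "escaped-text-example")

    // Optional: override the default '|' delimiter (here: tab), mirroring the
    // "customized delimiter" test above.
    val conf = new Configuration()
    conf.set(EscapedTextInputFormat.KEY_DELIMITER, "\t")

    // Keys are byte offsets of record starts; each value is one record already
    // split into unescaped fields.
    val rdd = sc.newAPIHadoopFile(
      "/path/to/unloaded/files", // placeholder input path
      classOf[EscapedTextInputFormat],
      classOf[Long],
      classOf[Array[String]],
      conf)

    // Print a few records, re-joining fields with tabs for display.
    rdd.values.take(5).foreach(fields => println(fields.mkString("\t")))
    sc.stop()
  }
}
```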