diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala new file mode 100644 index 000000000000..b50a76a49655 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVFilters.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import scala.util.Try + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources +import org.apache.spark.sql.types.{BooleanType, StructType} + +/** + * An instance of the class compiles filters to predicates and allows applying + * the predicates to an internal row with partially initialized values + * converted from parsed CSV fields. + * + * @param filters The filters pushed down to the CSV datasource. + * @param requiredSchema The schema with only fields requested by the upper layer. + */ +class CSVFilters(filters: Seq[sources.Filter], requiredSchema: StructType) { + /** + * The filters converted to predicates and grouped by the maximum field index + * in the read schema. For example, if a filter refers to 2 attributes + * attrA with field index 5 and attrB with field index 10 in the read schema: + * 0 === $"attrA" or $"attrB" < 100 + * the filter is compiled to a predicate, and placed in the `predicates` + * array at position 10. In this way, if there is a row with initialized + * fields at indexes from 0 to 10, the predicate can be applied to the row + * to check whether the row should be skipped or not. + * Multiple predicates with the same maximum reference index are combined + * by the `And` expression. + */ + private val predicates: Array[BasePredicate] = { + val len = requiredSchema.fields.length + val groupedPredicates = Array.fill[BasePredicate](len)(null) + if (SQLConf.get.csvFilterPushDown) { + val groupedFilters = Array.fill(len)(Seq.empty[sources.Filter]) + for (filter <- filters) { + val refs = filter.references + val index = if (refs.isEmpty) { + // For example, AlwaysTrue and AlwaysFalse don't have any references. + // Filters w/o refs always return the same result. Taking into account + // that predicates are combined via And, we can apply such filters only + // once at position 0. + 0 + } else { + // readSchema must contain attributes of all filters. + // Accordingly, fieldIndex() always returns a valid index. + refs.map(requiredSchema.fieldIndex).max + } + groupedFilters(index) :+= filter + } + if (len > 0 && !groupedFilters(0).isEmpty) { + // We assume that filters w/o refs like AlwaysTrue and AlwaysFalse + // can be evaluated faster than others.
We put them in front of others. + val (literals, others) = groupedFilters(0).partition(_.references.isEmpty) + groupedFilters(0) = literals ++ others + } + for (i <- 0 until len) { + if (!groupedFilters(i).isEmpty) { + val reducedExpr = groupedFilters(i) + .flatMap(CSVFilters.filterToExpression(_, toRef)) + .reduce(And) + groupedPredicates(i) = Predicate.create(reducedExpr) + } + } + } + groupedPredicates + } + + /** + * Applies all filters that refer to row fields at positions from 0 to index. + * @param row The internal row to check. + * @param index Maximum field index. The function assumes that all fields + * from 0 to the index position are set. + * @return false iff the row fields at positions from 0 to index pass the filters + * or there are no applicable filters, + * otherwise true if at least one of the filters returns false. + */ + def skipRow(row: InternalRow, index: Int): Boolean = { + val predicate = predicates(index) + predicate != null && !predicate.eval(row) + } + + // Finds a filter attribute in the read schema and converts it to a `BoundReference` + private def toRef(attr: String): Option[BoundReference] = { + requiredSchema.getFieldIndex(attr).map { index => + val field = requiredSchema(index) + BoundReference(index, field.dataType, field.nullable) + } + } +} + +object CSVFilters { + private def checkFilterRefs(filter: sources.Filter, schema: StructType): Boolean = { + val fieldNames = schema.fields.map(_.name).toSet + filter.references.forall(fieldNames.contains(_)) + } + + /** + * Returns the filters currently supported by the CSV datasource. + * @param filters The filters pushed down to the CSV datasource. + * @param schema The data schema of CSV files. + * @return a subset of `filters` that can be handled by the CSV datasource. + */ + def pushedFilters(filters: Array[sources.Filter], schema: StructType): Array[sources.Filter] = { + filters.filter(checkFilterRefs(_, schema)) + } + + private def zip[A, B](a: Option[A], b: Option[B]): Option[(A, B)] = { + a.zip(b).headOption + } + + private def toLiteral(value: Any): Option[Literal] = { + Try(Literal(value)).toOption + } + + /** + * Converts a filter to an expression and binds it to row positions. + * + * @param filter The filter to convert. + * @param toRef The function that converts a filter attribute to a bound reference. + * @return Some expression with resolved attributes or None if the conversion + * of the given filter to an expression is impossible.
+ */ + def filterToExpression( + filter: sources.Filter, + toRef: String => Option[BoundReference]): Option[Expression] = { + def zipAttributeAndValue(name: String, value: Any): Option[(BoundReference, Literal)] = { + zip(toRef(name), toLiteral(value)) + } + def translate(filter: sources.Filter): Option[Expression] = filter match { + case sources.And(left, right) => + zip(translate(left), translate(right)).map(And.tupled) + case sources.Or(left, right) => + zip(translate(left), translate(right)).map(Or.tupled) + case sources.Not(child) => + translate(child).map(Not) + case sources.EqualTo(attribute, value) => + zipAttributeAndValue(attribute, value).map(EqualTo.tupled) + case sources.EqualNullSafe(attribute, value) => + zipAttributeAndValue(attribute, value).map(EqualNullSafe.tupled) + case sources.IsNull(attribute) => + toRef(attribute).map(IsNull) + case sources.IsNotNull(attribute) => + toRef(attribute).map(IsNotNull) + case sources.In(attribute, values) => + val literals = values.toSeq.flatMap(toLiteral) + if (literals.length == values.length) { + toRef(attribute).map(In(_, literals)) + } else { + None + } + case sources.GreaterThan(attribute, value) => + zipAttributeAndValue(attribute, value).map(GreaterThan.tupled) + case sources.GreaterThanOrEqual(attribute, value) => + zipAttributeAndValue(attribute, value).map(GreaterThanOrEqual.tupled) + case sources.LessThan(attribute, value) => + zipAttributeAndValue(attribute, value).map(LessThan.tupled) + case sources.LessThanOrEqual(attribute, value) => + zipAttributeAndValue(attribute, value).map(LessThanOrEqual.tupled) + case sources.StringContains(attribute, value) => + zipAttributeAndValue(attribute, value).map(Contains.tupled) + case sources.StringStartsWith(attribute, value) => + zipAttributeAndValue(attribute, value).map(StartsWith.tupled) + case sources.StringEndsWith(attribute, value) => + zipAttributeAndValue(attribute, value).map(EndsWith.tupled) + case sources.AlwaysTrue() => + Some(Literal(true, BooleanType)) + case sources.AlwaysFalse() => + Some(Literal(false, BooleanType)) + } + translate(filter) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 661525a65294..288179fc480d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -39,15 +40,20 @@ import org.apache.spark.unsafe.types.UTF8String * @param requiredSchema The schema of the data that should be output for each row. This should be a * subset of the columns in dataSchema. * @param options Configuration options for a CSV parser. + * @param filters The pushdown filters that should be applied to converted values. 
*/ class UnivocityParser( dataSchema: StructType, requiredSchema: StructType, - val options: CSVOptions) extends Logging { + val options: CSVOptions, + filters: Seq[Filter]) extends Logging { require(requiredSchema.toSet.subsetOf(dataSchema.toSet), s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " + s"dataSchema (${dataSchema.catalogString}).") + def this(dataSchema: StructType, requiredSchema: StructType, options: CSVOptions) = { + this(dataSchema, requiredSchema, options, Seq.empty) + } def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) // A `ValueConverter` is responsible for converting the given value to a desired type. @@ -72,7 +78,11 @@ class UnivocityParser( new CsvParser(parserSetting) } - private val row = new GenericInternalRow(requiredSchema.length) + // Pre-allocated Seq to avoid the overhead of the seq builder. + private val requiredRow = Seq(new GenericInternalRow(requiredSchema.length)) + // Pre-allocated empty sequence returned when the parsed row cannot pass filters. + // We preallocate it to avoid unnecessary invocations of the seq builder. + private val noRows = Seq.empty[InternalRow] private val timestampFormatter = TimestampFormatter( options.timestampFormat, @@ -83,6 +93,8 @@ class UnivocityParser( options.zoneId, options.locale) + private val csvFilters = new CSVFilters(filters, requiredSchema) + // Retrieve the raw record string. private def getCurrentInput: UTF8String = { UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd) @@ -194,7 +206,7 @@ class UnivocityParser( private val doParse = if (options.columnPruning && requiredSchema.isEmpty) { // If `columnPruning` enabled and partition attributes scanned only, // `schema` gets empty. - (_: String) => InternalRow.empty + (_: String) => Seq(InternalRow.empty) } else { // parse if the columnPruning is disabled or requiredSchema is nonEmpty (input: String) => convert(tokenizer.parseLine(input)) @@ -204,7 +216,7 @@ * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). */ - def parse(input: String): InternalRow = doParse(input) + def parse(input: String): Seq[InternalRow] = doParse(input) private val getToken = if (options.columnPruning) { (tokens: Array[String], index: Int) => tokens(index) @@ -212,7 +224,7 @@ (tokens: Array[String], index: Int) => tokens(tokenIndexArr(index)) } - private def convert(tokens: Array[String]): InternalRow = { + private def convert(tokens: Array[String]): Seq[InternalRow] = { if (tokens == null) { throw BadRecordException( () => getCurrentInput, @@ -229,7 +241,7 @@ } def getPartialResult(): Option[InternalRow] = { try { - Some(convert(checkedTokens)) + convert(checkedTokens).headOption } catch { case _: BadRecordException => None } @@ -242,12 +254,24 @@ new RuntimeException("Malformed CSV record")) } else { // When the length of the returned tokens is identical to the length of the parsed schema, - // we just need to convert the tokens that correspond to the required columns. - var badRecordException: Option[Throwable] = None + // we just need to: + // 1. Convert the tokens that correspond to the required schema. + // 2. Apply the pushdown filters to `requiredRow`.
var i = 0 + val row = requiredRow.head + var skipRow = false + var badRecordException: Option[Throwable] = None while (i < requiredSchema.length) { try { - row(i) = valueConverters(i).apply(getToken(tokens, i)) + if (!skipRow) { + row(i) = valueConverters(i).apply(getToken(tokens, i)) + if (csvFilters.skipRow(row, i)) { + skipRow = true + } + } + if (skipRow) { + row.setNullAt(i) + } } catch { case NonFatal(e) => badRecordException = badRecordException.orElse(Some(e)) @@ -255,11 +279,15 @@ class UnivocityParser( } i += 1 } - - if (badRecordException.isEmpty) { - row + if (skipRow) { + noRows } else { - throw BadRecordException(() => getCurrentInput, () => Some(row), badRecordException.get) + if (badRecordException.isDefined) { + throw BadRecordException( + () => getCurrentInput, () => requiredRow.headOption, badRecordException.get) + } else { + requiredRow + } } } } @@ -291,7 +319,7 @@ private[sql] object UnivocityParser { schema: StructType): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( - input => Seq(parser.convert(input)), + input => parser.convert(input), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) @@ -344,7 +372,7 @@ private[sql] object UnivocityParser { val filteredLines: Iterator[String] = CSVExprUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( - input => Seq(parser.parse(input)), + input => parser.parse(input), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 73d329b4f582..54af314fe417 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -114,7 +114,7 @@ case class CsvToStructs( StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) new FailureSafeParser[String]( - input => Seq(rawParser.parse(input)), + input => rawParser.parse(input), mode, nullableSchema, parsedOptions.columnNameOfCorruptRecord) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6d45d30a787d..e9d5cb58b612 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2153,6 +2153,11 @@ object SQLConf { .booleanConf .createWithDefault(false) + val CSV_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.csv.filterPushdown.enabled") + .doc("When true, enable filter pushdown to CSV datasource.") + .booleanConf + .createWithDefault(true) + /** * Holds information about keys that have been deprecated. * @@ -2722,6 +2727,8 @@ class SQLConf extends Serializable with Logging { def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY) + def csvFilterPushDown: Boolean = getConf(CSV_FILTER_PUSHDOWN_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala new file mode 100644 index 000000000000..499bbaf452ae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVFiltersSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, Filter} +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class CSVFiltersSuite extends SparkFunSuite { + test("filter to expression conversion") { + val ref = BoundReference(0, IntegerType, true) + def check(f: Filter, expr: Expression): Unit = { + assert(CSVFilters.filterToExpression(f, _ => Some(ref)).get === expr) + } + + check(sources.AlwaysTrue, Literal(true)) + check(sources.AlwaysFalse, Literal(false)) + check(sources.IsNull("a"), IsNull(ref)) + check(sources.Not(sources.IsNull("a")), Not(IsNull(ref))) + check(sources.IsNotNull("a"), IsNotNull(ref)) + check(sources.EqualTo("a", "b"), EqualTo(ref, Literal("b"))) + check(sources.EqualNullSafe("a", "b"), EqualNullSafe(ref, Literal("b"))) + check(sources.StringStartsWith("a", "b"), StartsWith(ref, Literal("b"))) + check(sources.StringEndsWith("a", "b"), EndsWith(ref, Literal("b"))) + check(sources.StringContains("a", "b"), Contains(ref, Literal("b"))) + check(sources.LessThanOrEqual("a", 1), LessThanOrEqual(ref, Literal(1))) + check(sources.LessThan("a", 1), LessThan(ref, Literal(1))) + check(sources.GreaterThanOrEqual("a", 1), GreaterThanOrEqual(ref, Literal(1))) + check(sources.GreaterThan("a", 1), GreaterThan(ref, Literal(1))) + check(sources.And(sources.AlwaysTrue, sources.AlwaysTrue), And(Literal(true), Literal(true))) + check(sources.Or(sources.AlwaysTrue, sources.AlwaysTrue), Or(Literal(true), Literal(true))) + check(sources.In("a", Array(1)), In(ref, Seq(Literal(1)))) + } + + private def getSchema(str: String): StructType = str match { + case "" => new StructType() + case _ => StructType.fromDDL(str) + } + + test("skipping rows") { + def check( + requiredSchema: String = "i INTEGER, d DOUBLE", + filters: Seq[Filter], + row: InternalRow, + pos: Int, + skip: Boolean): Unit = { + val csvFilters = new CSVFilters(filters, getSchema(requiredSchema)) + assert(csvFilters.skipRow(row, pos) === skip) + } + + check(filters = Seq(), row = InternalRow(3.14), pos = 0, skip = false) + check(filters = Seq(AlwaysTrue), row = InternalRow(1), pos = 0, 
skip = false) + check(filters = Seq(AlwaysFalse), row = InternalRow(1), pos = 0, skip = true) + check( + filters = Seq(sources.EqualTo("i", 1), sources.LessThan("d", 10), sources.AlwaysFalse), + row = InternalRow(1, 3.14), + pos = 0, + skip = true) + check( + filters = Seq(sources.EqualTo("i", 10)), + row = InternalRow(10, 3.14), + pos = 0, + skip = false) + check( + filters = Seq(sources.IsNotNull("d"), sources.GreaterThanOrEqual("d", 2.96)), + row = InternalRow(3.14), + pos = 0, + skip = false) + check( + filters = Seq(sources.In("i", Array(10, 20)), sources.LessThanOrEqual("d", 2.96)), + row = InternalRow(10, 3.14), + pos = 1, + skip = true) + val filters1 = Seq( + sources.Or( + sources.AlwaysTrue, + sources.And( + sources.Not(sources.IsNull("i")), + sources.Not( + sources.And( + sources.StringEndsWith("s", "ab"), + sources.StringEndsWith("s", "cd") + ) + ) + ) + ), + sources.GreaterThan("d", 0), + sources.LessThan("i", 500) + ) + val filters2 = Seq( + sources.And( + sources.StringContains("s", "abc"), + sources.And( + sources.Not(sources.IsNull("i")), + sources.And( + sources.StringEndsWith("s", "ab"), + sources.StringEndsWith("s", "bc") + ) + ) + ), + sources.GreaterThan("d", 100), + sources.LessThan("i", 0) + ) + Seq(filters1 -> false, filters2 -> true).foreach { case (filters, skip) => + for (p <- 0 until 3) { + check( + requiredSchema = "i INTEGER, d DOUBLE, s STRING", + filters = filters, + row = InternalRow(10, 3.14, UTF8String.fromString("abc")), + pos = p, + skip = skip) + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 31601f787f1a..bd4b2529f8b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -24,9 +24,11 @@ import java.util.{Locale, TimeZone} import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.sources.{EqualTo, Filter, StringStartsWith} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -267,4 +269,52 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { assert(convertedValue.isInstanceOf[UTF8String]) assert(convertedValue == expected) } + + test("skipping rows using pushdown filters") { + def check( + input: String = "1,a", + dataSchema: StructType = StructType.fromDDL("i INTEGER, s STRING"), + requiredSchema: StructType = StructType.fromDDL("i INTEGER"), + filters: Seq[Filter], + expected: Seq[InternalRow]): Unit = { + Seq(false, true).foreach { columnPruning => + val options = new CSVOptions(Map.empty[String, String], columnPruning, "GMT") + val parser = new UnivocityParser(dataSchema, requiredSchema, options, filters) + val actual = parser.parse(input) + assert(actual === expected) + } + } + + check(filters = Seq(), expected = Seq(InternalRow(1))) + check(filters = Seq(EqualTo("i", 1)), expected = Seq(InternalRow(1))) + check(filters = Seq(EqualTo("i", 2)), expected = Seq()) + check( + requiredSchema = StructType.fromDDL("s STRING"), + filters = Seq(StringStartsWith("s", "b")), + expected = Seq()) + check( + requiredSchema = 
StructType.fromDDL("i INTEGER, s STRING"), + filters = Seq(StringStartsWith("s", "a")), + expected = Seq(InternalRow(1, UTF8String.fromString("a")))) + check( + input = "1,a,3.14", + dataSchema = StructType.fromDDL("i INTEGER, s STRING, d DOUBLE"), + requiredSchema = StructType.fromDDL("i INTEGER, d DOUBLE"), + filters = Seq(EqualTo("d", 3.14)), + expected = Seq(InternalRow(1, 3.14))) + + val errMsg = intercept[IllegalArgumentException] { + check(filters = Seq(EqualTo("invalid attr", 1)), expected = Seq()) + }.getMessage + assert(errMsg.contains("invalid attr does not exist")) + + val errMsg2 = intercept[IllegalArgumentException] { + check( + dataSchema = new StructType(), + requiredSchema = new StructType(), + filters = Seq(EqualTo("i", 1)), + expected = Seq(InternalRow.empty)) + }.getMessage + assert(errMsg2.contains("i does not exist")) + } } diff --git a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt index 2d24a273f757..d8071e7bbdb3 100644 --- a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt @@ -56,4 +56,12 @@ parse dates from Dataset[String] 51026 51447 5 from_csv(timestamp) 60738 61818 936 0.2 6073.8 0.0X from_csv(date) 46012 46278 370 0.2 4601.2 0.1X +OpenJDK 64-Bit Server VM 11.0.5+10 on Mac OS X 10.15.2 +Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz +Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +w/o filters 11889 11945 52 0.0 118893.1 1.0X +pushdown disabled 11790 11860 115 0.0 117902.3 1.0X +w/ filters 1240 1278 33 0.1 12400.8 9.6X + diff --git a/sql/core/benchmarks/CSVBenchmark-results.txt b/sql/core/benchmarks/CSVBenchmark-results.txt index 0777549efc5f..b3ba69c9eb6b 100644 --- a/sql/core/benchmarks/CSVBenchmark-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-results.txt @@ -56,4 +56,12 @@ parse dates from Dataset[String] 48728 49071 3 from_csv(timestamp) 62294 62493 260 0.2 6229.4 0.0X from_csv(date) 44581 44665 117 0.2 4458.1 0.1X +Java HotSpot(TM) 64-Bit Server VM 1.8.0_231-b11 on Mac OS X 10.15.2 +Intel(R) Core(TM) i7-4850HQ CPU @ 2.30GHz +Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +w/o filters 12557 12634 78 0.0 125572.9 1.0X +pushdown disabled 12449 12509 65 0.0 124486.4 1.0X +w/ filters 1372 1393 18 0.1 13724.8 9.1X + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a8b352407be8..1af4931c553e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -557,7 +557,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val parsed = linesWithoutHeader.mapPartitions { iter => val rawParser = new UnivocityParser(actualSchema, parsedOptions) val parser = new FailureSafeParser[String]( - input => Seq(rawParser.parse(input)), + input => rawParser.parse(input), parsedOptions.parseMode, schema, parsedOptions.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 8abc6fcacd4c..cbf9d2bac7ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.execution.datasources.csv -import java.nio.charset.Charset - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ -import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ @@ -134,7 +131,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val actualRequiredSchema = StructType( requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) - val parser = new UnivocityParser(actualDataSchema, actualRequiredSchema, parsedOptions) + val parser = new UnivocityParser( + actualDataSchema, + actualRequiredSchema, + parsedOptions, + filters) val schema = if (columnPruning) actualRequiredSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index a20b0f1560a1..31d31bd43f45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -43,7 +44,8 @@ case class CSVPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - parsedOptions: CSVOptions) extends FilePartitionReaderFactory { + parsedOptions: CSVOptions, + filters: Seq[Filter]) extends FilePartitionReaderFactory { private val columnPruning = sqlConf.csvColumnPruning override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { @@ -55,7 +57,8 @@ case class CSVPartitionReaderFactory( val parser = new UnivocityParser( actualDataSchema, actualReadDataSchema, - parsedOptions) + parsedOptions, + filters) val schema = if (columnPruning) actualReadDataSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 78b04aa811e0..690d66908e61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -38,6 +39,7 @@ case class CSVScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty) extends TextBasedFileScan(sparkSession, options) { @@ -86,17 +88,21 @@ case class CSVScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, parsedOptions) + dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } override def withPartitionFilters(partitionFilters: Seq[Expression]): FileScan = this.copy(partitionFilters = partitionFilters) override def equals(obj: Any): Boolean = obj match { - case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options - + case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options && + equivalentFilters(pushedFilters, c.pushedFilters) case _ => false } override def hashCode(): Int = super.hashCode() + + override def description(): String = { + super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index 8b486d034450..81a234e25400 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan +import org.apache.spark.sql.catalyst.csv.CSVFilters +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -30,9 +32,27 @@ case class CSVScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { override def build(): Scan = { - CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), 
readPartitionSchema(), options) + CSVScan( + sparkSession, + fileIndex, + dataSchema, + readDataSchema(), + readPartitionSchema(), + options, + pushedFilters()) } + + private var _pushedFilters: Array[Filter] = Array.empty + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + if (sparkSession.sessionState.conf.csvFilterPushDown) { + _pushedFilters = CSVFilters.pushedFilters(filters, dataSchema) + } + filters + } + + override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala index ad80afa441de..e2abb39c986a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmark.scala @@ -23,6 +23,7 @@ import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{Column, Dataset, Row} import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -293,6 +294,50 @@ object CSVBenchmark extends SqlBasedBenchmark { } } + private def filtersPushdownBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Filters pushdown", rowsNum, output = output) + val colsNum = 100 + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", TimestampType)) + val schema = StructType(StructField("key", IntegerType) +: fields) + def columns(): Seq[Column] = { + val ts = Seq.tabulate(colsNum) { i => + lit(Instant.ofEpochSecond(i * 12345678)).as(s"col$i") + } + ($"id" % 1000).as("key") +: ts + } + withTempPath { path => + spark.range(rowsNum).select(columns(): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + def readback = { + spark.read + .option("header", true) + .schema(schema) + .csv(path.getAbsolutePath) + } + + benchmark.addCase(s"w/o filters", numIters) { _ => + readback.noop() + } + + def withFilter(configEnabled: Boolean): Unit = { + withSQLConf(SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.key -> configEnabled.toString()) { + readback.filter($"key" === 0).noop() + } + } + + benchmark.addCase(s"pushdown disabled", numIters) { _ => + withFilter(configEnabled = false) + } + + benchmark.addCase(s"w/ filters", numIters) { _ => + withFilter(configEnabled = true) + } + + benchmark.run() + } + } + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("Benchmark to measure CSV read/write performance") { val numIters = 3 @@ -300,6 +345,7 @@ object CSVBenchmark extends SqlBasedBenchmark { multiColumnsBenchmark(rowsNum = 1000 * 1000, numIters) countBenchmark(rowsNum = 10 * 1000 * 1000, numIters) datetimeBenchmark(rowsNum = 10 * 1000 * 1000, numIters) + filtersPushdownBenchmark(rowsNum = 100 * 1000, numIters) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index ae9aaf15aae9..846b5c594d42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.{SparkException, 
TestUtils} -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -2195,4 +2195,79 @@ class CSVSuite extends QueryTest with SharedSparkSession with TestCsvData { checkAnswer(resultDF, Row("a", 2, "e", "c")) } } + + test("filters push down") { + Seq(true, false).foreach { filterPushdown => + Seq(true, false).foreach { columnPruning => + withSQLConf( + SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.key -> filterPushdown.toString, + SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> columnPruning.toString) { + + withTempPath { path => + val t = "2019-12-17 00:01:02" + Seq( + "c0,c1,c2", + "abc,1,2019-11-14 20:35:30", + s"def,2,$t").toDF("data") + .repartition(1) + .write.text(path.getAbsolutePath) + Seq(true, false).foreach { multiLine => + Seq("PERMISSIVE", "DROPMALFORMED", "FAILFAST").foreach { mode => + val readback = spark.read + .option("mode", mode) + .option("header", true) + .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .option("multiLine", multiLine) + .schema("c0 string, c1 integer, c2 timestamp") + .csv(path.getAbsolutePath) + .where($"c1" === 2) + .select($"c2") + // count() pushes down an empty schema. This checks the handling of a filter + // that refers to a non-existing field. + assert(readback.count() === 1) + checkAnswer(readback, Row(Timestamp.valueOf(t))) + } + } + } + } + } + } + } + + test("filters push down - malformed input in PERMISSIVE mode") { + val invalidTs = "2019-123-14 20:35:30" + val invalidRow = s"0,$invalidTs,999" + val validTs = "2019-12-14 20:35:30" + Seq(true, false).foreach { filterPushdown => + withSQLConf(SQLConf.CSV_FILTER_PUSHDOWN_ENABLED.key -> filterPushdown.toString) { + withTempPath { path => + Seq( + "c0,c1,c2", + invalidRow, + s"1,$validTs,999").toDF("data") + .repartition(1) + .write.text(path.getAbsolutePath) + def checkReadback(condition: Column, expected: Seq[Row]): Unit = { + val readback = spark.read + .option("mode", "PERMISSIVE") + .option("columnNameOfCorruptRecord", "c3") + .option("header", true) + .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .schema("c0 integer, c1 timestamp, c2 integer, c3 string") + .csv(path.getAbsolutePath) + .where(condition) + .select($"c0", $"c1", $"c3") + checkAnswer(readback, expected) + } + + checkReadback( + condition = $"c2" === 999, + expected = Seq(Row(0, null, invalidRow), Row(1, Timestamp.valueOf(validTs), null))) + checkReadback( + condition = $"c2" === 999 && $"c1" > "1970-01-01 00:00:00", + expected = Seq(Row(1, Timestamp.valueOf(validTs), null))) + } + } + } + } }
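
Reviewer note: the predicate grouping introduced in CSVFilters can be exercised outside of CSVFiltersSuite and UnivocityParserSuite with a few lines of Scala. This is a minimal sketch, assuming the default spark.sql.csv.filterPushdown.enabled=true and no active SparkSession (as in the catalyst suites); the schema, filter values, and row values are invented for illustration:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.csv.CSVFilters
    import org.apache.spark.sql.sources
    import org.apache.spark.sql.types.StructType

    // GreaterThan("i", 0) references only field 0, so it lands in predicates(0);
    // LessThan("d", 10.0) references field 1, so it lands in predicates(1).
    val schema = StructType.fromDDL("i INTEGER, d DOUBLE")
    val filters = Seq(sources.GreaterThan("i", 0), sources.LessThan("d", 10.0))
    val csvFilters = new CSVFilters(filters, schema)

    // After converting only field 0, the predicate bound to index 0 can already veto the row,
    // which is why UnivocityParser.convert() nulls out the remaining fields instead of
    // converting them.
    assert(csvFilters.skipRow(InternalRow(-1, null), 0))  // i > 0 fails => skip
    assert(!csvFilters.skipRow(InternalRow(5, null), 0))  // i > 0 passes => keep converting
    assert(csvFilters.skipRow(InternalRow(5, 20.0), 1))   // d < 10.0 fails => skip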
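
From the user-facing side, the new spark.sql.csv.filterPushdown.enabled flag (default true) only changes where filtering happens, not the query result. A rough usage sketch, with a hypothetical input path and column names:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // false falls back to post-scan filtering; true (the default) lets UnivocityParser skip rows early.
    spark.conf.set("spark.sql.csv.filterPushdown.enabled", "true")

    val df = spark.read
      .option("header", true)
      .schema("c0 STRING, c1 INT, c2 TIMESTAMP")
      .csv("/tmp/csv-pushdown-demo")  // hypothetical path
      .where($"c1" === 2)

    // Depending on whether the v1 or v2 CSV reader is active, the scan node in the plan
    // should report the pushed filters (on the v2 path via the "PushedFilters: [...]"
    // section added to CSVScan.description() above).
    df.explain()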