From 5095561ed3b89244f4344d4a3a4902c8702c2ecf Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 17 Mar 2019 16:05:32 +0800 Subject: [PATCH 1/4] Copy code for upgrade built-in Hive --- sql/core/pom.xml | 14 + .../datasources/orc/OrcColumnVector.java | 0 .../orc/OrcColumnarBatchReader.java | 0 .../datasources/orc/OrcDeserializer.scala | 0 .../datasources/orc/OrcFilters.scala | 0 .../datasources/orc/OrcSerializer.scala | 0 .../datasources/orc/OrcFilterSuite.scala | 0 .../datasources/orc/OrcV1FilterSuite.scala | 0 .../datasources/orc/OrcColumnVector.java | 193 ++++++++ .../orc/OrcColumnarBatchReader.java | 210 +++++++++ .../datasources/orc/OrcDeserializer.scala | 251 ++++++++++ .../datasources/orc/OrcFilters.scala | 276 +++++++++++ .../datasources/orc/OrcSerializer.scala | 228 ++++++++++ .../datasources/orc/OrcFilterSuite.scala | 429 ++++++++++++++++++ .../datasources/orc/OrcV1FilterSuite.scala | 107 +++++ 15 files changed, 1708 insertions(+) rename sql/core/{ => v1.2.1}/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java (100%) rename sql/core/{ => v1.2.1}/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java (100%) rename sql/core/{ => v1.2.1}/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala (100%) rename sql/core/{ => v1.2.1}/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala (100%) rename sql/core/{ => v1.2.1}/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala (100%) rename sql/core/{ => v1.2.1}/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala (100%) rename sql/core/{ => v1.2.1}/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala (100%) create mode 100644 sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java create mode 100644 sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java create mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala create mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala create mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala create mode 100644 sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala create mode 100644 sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ac5f1fc923e7..5ddfb02f0de3 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -189,6 +189,19 @@ org.codehaus.mojo build-helper-maven-plugin + + add-sources + generate-sources + + add-source + + + + v${hive.version.short}/src/main/scala + v${hive.version.short}/src/main/java + + + add-scala-test-sources generate-test-sources @@ -197,6 +210,7 @@ + v${hive.version.short}/src/test/scala src/test/gen-java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java rename to sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java rename to sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala rename to sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala rename to sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala rename to sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala similarity index 100% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala rename to sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala similarity index 100% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala rename to sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala diff --git a/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java new file mode 100644 index 000000000000..9bfad1e83ee7 --- /dev/null +++ b/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.math.BigDecimal; + +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts + * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with + * Spark ColumnarVector. + */ +public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector { + private ColumnVector baseData; + private LongColumnVector longData; + private DoubleColumnVector doubleData; + private BytesColumnVector bytesData; + private DecimalColumnVector decimalData; + private TimestampColumnVector timestampData; + private final boolean isTimestamp; + + private int batchSize; + + OrcColumnVector(DataType type, ColumnVector vector) { + super(type); + + if (type instanceof TimestampType) { + isTimestamp = true; + } else { + isTimestamp = false; + } + + baseData = vector; + if (vector instanceof LongColumnVector) { + longData = (LongColumnVector) vector; + } else if (vector instanceof DoubleColumnVector) { + doubleData = (DoubleColumnVector) vector; + } else if (vector instanceof BytesColumnVector) { + bytesData = (BytesColumnVector) vector; + } else if (vector instanceof DecimalColumnVector) { + decimalData = (DecimalColumnVector) vector; + } else if (vector instanceof TimestampColumnVector) { + timestampData = (TimestampColumnVector) vector; + } else { + throw new UnsupportedOperationException(); + } + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public void close() { + + } + + @Override + public boolean hasNull() { + return !baseData.noNulls; + } + + @Override + public int numNulls() { + if (baseData.isRepeating) { + if (baseData.isNull[0]) { + return batchSize; + } else { + return 0; + } + } else if (baseData.noNulls) { + return 0; + } else { + int count = 0; + for (int i = 0; i < batchSize; i++) { + if (baseData.isNull[i]) count++; + } + return count; + } + } + + /* A helper method to get the row index in a column. */ + private int getRowIndex(int rowId) { + return baseData.isRepeating ? 
0 : rowId; + } + + @Override + public boolean isNullAt(int rowId) { + return baseData.isNull[getRowIndex(rowId)]; + } + + @Override + public boolean getBoolean(int rowId) { + return longData.vector[getRowIndex(rowId)] == 1; + } + + @Override + public byte getByte(int rowId) { + return (byte) longData.vector[getRowIndex(rowId)]; + } + + @Override + public short getShort(int rowId) { + return (short) longData.vector[getRowIndex(rowId)]; + } + + @Override + public int getInt(int rowId) { + return (int) longData.vector[getRowIndex(rowId)]; + } + + @Override + public long getLong(int rowId) { + int index = getRowIndex(rowId); + if (isTimestamp) { + return timestampData.time[index] * 1000 + timestampData.nanos[index] / 1000 % 1000; + } else { + return longData.vector[index]; + } + } + + @Override + public float getFloat(int rowId) { + return (float) doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public double getDouble(int rowId) { + return doubleData.vector[getRowIndex(rowId)]; + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue(); + return Decimal.apply(data, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + BytesColumnVector col = bytesData; + return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + int index = getRowIndex(rowId); + byte[] binary = new byte[bytesData.length[index]]; + System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length); + return binary; + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java new file mode 100644 index 000000000000..efca96e9ce58 --- /dev/null +++ b/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.orc; + +import java.io.IOException; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.orc.OrcConf; +import org.apache.orc.OrcFile; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.orc.mapred.OrcInputFormat; +import org.apache.orc.storage.ql.exec.vector.*; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +/** + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. + * After creating, `initialize` and `initBatch` should be called sequentially. + */ +public class OrcColumnarBatchReader extends RecordReader { + + // The capacity of vectorized batch. + private int capacity; + + // Vectorized ORC Row Batch + private VectorizedRowBatch batch; + + /** + * The column IDs of the physical ORC file schema which are required by this reader. + * -1 means this required column is partition column, or it doesn't exist in the ORC file. + * Ideally partition column should never appear in the physical file, and should only appear + * in the directory name. However, Spark allows partition columns inside physical file, + * but Spark will discard the values from the file, and use the partition value got from + * directory name. The column order will be reserved though. + */ + @VisibleForTesting + public int[] requestedDataColIds; + + // Record reader from ORC row batch. + private org.apache.orc.RecordReader recordReader; + + private StructField[] requiredFields; + + // The result columnar batch for vectorized execution by whole-stage codegen. + @VisibleForTesting + public ColumnarBatch columnarBatch; + + // The wrapped ORC column vectors. + private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; + + public OrcColumnarBatchReader(int capacity) { + this.capacity = capacity; + } + + + @Override + public Void getCurrentKey() { + return null; + } + + @Override + public ColumnarBatch getCurrentValue() { + return columnarBatch; + } + + @Override + public float getProgress() throws IOException { + return recordReader.getProgress(); + } + + @Override + public boolean nextKeyValue() throws IOException { + return nextBatch(); + } + + @Override + public void close() throws IOException { + if (columnarBatch != null) { + columnarBatch.close(); + columnarBatch = null; + } + if (recordReader != null) { + recordReader.close(); + recordReader = null; + } + } + + /** + * Initialize ORC file reader and batch record reader. + * Please note that `initBatch` is needed to be called after this. 
+ */ + @Override + public void initialize( + InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { + FileSplit fileSplit = (FileSplit)inputSplit; + Configuration conf = taskAttemptContext.getConfiguration(); + Reader reader = OrcFile.createReader( + fileSplit.getPath(), + OrcFile.readerOptions(conf) + .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) + .filesystem(fileSplit.getPath().getFileSystem(conf))); + Reader.Options options = + OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); + recordReader = reader.rows(options); + } + + /** + * Initialize columnar batch by setting required schema and partition information. + * With this information, this creates ColumnarBatch with the full schema. + * + * @param orcSchema Schema from ORC file reader. + * @param requiredFields All the fields that are required to return, including partition fields. + * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. + * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. + * @param partitionValues Values of partition columns. + */ + public void initBatch( + TypeDescription orcSchema, + StructField[] requiredFields, + int[] requestedDataColIds, + int[] requestedPartitionColIds, + InternalRow partitionValues) { + batch = orcSchema.createRowBatch(capacity); + assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. + assert(requiredFields.length == requestedDataColIds.length); + assert(requiredFields.length == requestedPartitionColIds.length); + // If a required column is also partition column, use partition value and don't read from file. + for (int i = 0; i < requiredFields.length; i++) { + if (requestedPartitionColIds[i] != -1) { + requestedDataColIds[i] = -1; + } + } + this.requiredFields = requiredFields; + this.requestedDataColIds = requestedDataColIds; + + StructType resultSchema = new StructType(requiredFields); + + // Just wrap the ORC column vector instead of copying it to Spark column vector. + orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; + + for (int i = 0; i < requiredFields.length; i++) { + DataType dt = requiredFields[i].dataType(); + if (requestedPartitionColIds[i] != -1) { + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); + partitionCol.setIsConstant(); + orcVectorWrappers[i] = partitionCol; + } else { + int colId = requestedDataColIds[i]; + // Initialize the missing columns once. + if (colId == -1) { + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + missingCol.putNulls(0, capacity); + missingCol.setIsConstant(); + orcVectorWrappers[i] = missingCol; + } else { + orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); + } + } + } + + columnarBatch = new ColumnarBatch(orcVectorWrappers); + } + + /** + * Return true if there exists more data in the next batch. If exists, prepare the next batch + * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. 
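
The reader's Javadoc only states that `initialize` and `initBatch` must be called in that order after construction. Below is a hedged Scala sketch of that call sequence, not part of the patch: the path `/tmp/example.orc`, its assumed file schema `struct<a:int,b:string>`, and the batch capacity are made up for illustration; only the reader API added in this patch is used.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.orc.OrcFile
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.orc.OrcColumnarBatchReader
import org.apache.spark.sql.types._

val conf = new Configuration()
val path = new Path("/tmp/example.orc")   // hypothetical ORC file with schema struct<a:int,b:string>
val len  = path.getFileSystem(conf).getFileStatus(path).getLen
val orcSchema = OrcFile.createReader(path, OrcFile.readerOptions(conf)).getSchema

val reader = new OrcColumnarBatchReader(4096)   // capacity = rows per batch
reader.initialize(
  new FileSplit(path, 0, len, Array.empty[String]),
  new TaskAttemptContextImpl(conf, new TaskAttemptID()))
reader.initBatch(
  orcSchema,
  Array(StructField("a", IntegerType), StructField("b", StringType)),
  Array(0, 1),     // requestedDataColIds: positions of a and b in the file schema
  Array(-1, -1),   // requestedPartitionColIds: neither column comes from partition values
  InternalRow.empty)
while (reader.nextKeyValue()) {
  val batch = reader.getCurrentValue   // ColumnarBatch wrapping the ORC column vectors
  println(batch.numRows())
}
reader.close()
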
+ */ + private boolean nextBatch() throws IOException { + recordReader.nextBatch(batch); + int batchSize = batch.size; + if (batchSize == 0) { + return false; + } + columnarBatch.setNumRows(batchSize); + + for (int i = 0; i < requiredFields.length; i++) { + if (requestedDataColIds[i] != -1) { + ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); + } + } + return true; + } +} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala new file mode 100644 index 000000000000..62e16707a8e3 --- /dev/null +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A deserializer to deserialize ORC structs to Spark rows. + */ +class OrcDeserializer( + dataSchema: StructType, + requiredSchema: StructType, + requestedColIds: Array[Int]) { + + private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + + // `fieldWriters(index)` is + // - null if the respective source column is missing, since the output value + // is always null in this case + // - a function that updates target column `index` otherwise. + private val fieldWriters: Array[WritableComparable[_] => Unit] = { + requiredSchema.zipWithIndex + .map { case (f, index) => + if (requestedColIds(index) == -1) { + null + } else { + val writer = newWriter(f.dataType, new RowUpdater(resultRow)) + (value: WritableComparable[_]) => writer(index, value) + } + }.toArray + } + + def deserialize(orcStruct: OrcStruct): InternalRow = { + var targetColumnIndex = 0 + while (targetColumnIndex < fieldWriters.length) { + if (fieldWriters(targetColumnIndex) != null) { + val value = orcStruct.getFieldValue(requestedColIds(targetColumnIndex)) + if (value == null) { + resultRow.setNullAt(targetColumnIndex) + } else { + fieldWriters(targetColumnIndex)(value) + } + } + targetColumnIndex += 1 + } + resultRow + } + + /** + * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. 
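
To make the `requestedColIds` contract of `deserialize` concrete, here is a small hedged sketch, not part of the patch: an `OrcStruct` with file schema `struct<a:int,b:string>` is built by hand (the field values 7 and "x" are arbitrary), and only column `b` is requested.

import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.orc.TypeDescription
import org.apache.orc.mapred.OrcStruct
import org.apache.spark.sql.execution.datasources.orc.OrcDeserializer
import org.apache.spark.sql.types._

val dataSchema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
val requiredSchema = StructType(Seq(StructField("b", StringType)))

// Build an ORC struct by hand instead of reading it from a file.
val struct = OrcStruct.createValue(TypeDescription.fromString("struct<a:int,b:string>"))
  .asInstanceOf[OrcStruct]
struct.setFieldValue(0, new IntWritable(7))
struct.setFieldValue(1, new Text("x"))

// requestedColIds(0) = 1: the required column `b` lives at position 1 of the file schema.
val deserializer = new OrcDeserializer(dataSchema, requiredSchema, Array(1))
val row = deserializer.deserialize(struct)
println(row.getUTF8String(0))   // x
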
+ */ + private def newWriter( + dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit = + dataType match { + case NullType => (ordinal, _) => + updater.setNullAt(ordinal) + + case BooleanType => (ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) + + case ByteType => (ordinal, value) => + updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get) + + case ShortType => (ordinal, value) => + updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get) + + case IntegerType => (ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[IntWritable].get) + + case LongType => (ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[LongWritable].get) + + case FloatType => (ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) + + case DoubleType => (ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) + + case StringType => (ordinal, value) => + updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) + + case BinaryType => (ordinal, value) => + val binary = value.asInstanceOf[BytesWritable] + val bytes = new Array[Byte](binary.getLength) + System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) + updater.set(ordinal, bytes) + + case DateType => (ordinal, value) => + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + + case TimestampType => (ordinal, value) => + updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) + + case DecimalType.Fixed(precision, scale) => (ordinal, value) => + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + v.changePrecision(precision, scale) + updater.set(ordinal, v) + + case st: StructType => (ordinal, value) => + val result = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(result) + val fieldConverters = st.map(_.dataType).map { dt => + newWriter(dt, fieldUpdater) + }.toArray + val orcStruct = value.asInstanceOf[OrcStruct] + + var i = 0 + while (i < st.length) { + val value = orcStruct.getFieldValue(i) + if (value == null) { + result.setNullAt(i) + } else { + fieldConverters(i)(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case ArrayType(elementType, _) => (ordinal, value) => + val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]] + val length = orcArray.size() + val result = createArrayData(elementType, length) + val elementUpdater = new ArrayDataUpdater(result) + val elementConverter = newWriter(elementType, elementUpdater) + + var i = 0 + while (i < length) { + val value = orcArray.get(i) + if (value == null) { + result.setNullAt(i) + } else { + elementConverter(i, value) + } + i += 1 + } + + updater.set(ordinal, result) + + case MapType(keyType, valueType, _) => (ordinal, value) => + val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + val length = orcMap.size() + val keyArray = createArrayData(keyType, length) + val keyUpdater = new ArrayDataUpdater(keyArray) + val keyConverter = newWriter(keyType, keyUpdater) + val valueArray = createArrayData(valueType, length) + val valueUpdater = new ArrayDataUpdater(valueArray) + val valueConverter = newWriter(valueType, valueUpdater) + + var i = 0 + val it = orcMap.entrySet().iterator() + while (it.hasNext) { + val entry = it.next() + keyConverter(i, entry.getKey) + val 
value = entry.getValue + if (value == null) { + valueArray.setNullAt(i) + } else { + valueConverter(i, value) + } + i += 1 + } + + // The ORC map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. + */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, 
value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + } +} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala new file mode 100644 index 000000000000..cd2a68a53bab --- /dev/null +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} +import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder +import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder +import org.apache.orc.storage.serde2.io.HiveDecimalWritable + +import org.apache.spark.sql.sources.{And, Filter} +import org.apache.spark.sql.types._ + +/** + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to limitation of ORC `SearchArgument` builder, we had to end up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that, `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. 
Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. + */ +private[sql] object OrcFilters { + private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { + filters match { + case Seq() => None + case Seq(filter) => Some(filter) + case Seq(filter1, filter2) => Some(And(filter1, filter2)) + case _ => // length > 2 + val (left, right) = filters.splitAt(filters.length / 2) + Some(And(buildTree(left).get, buildTree(right).get)) + } + } + + // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters + // in order to distinguish predicate pushdown for nested columns. + private def quoteAttributeNameIfNeeded(name: String) : String = { + if (!name.contains("`") && name.contains(".")) { + s"`$name`" + } else { + name + } + } + + /** + * Create ORC filter as a SearchArgument instance. + */ + def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap + for { + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- buildTree(convertibleFilters(schema, dataTypeMap, filters)) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(dataTypeMap, conjunction, newBuilder) + } yield builder.build() + } + + def convertibleFilters( + schema: StructType, + dataTypeMap: Map[String, DataType], + filters: Seq[Filter]): Seq[Filter] = { + for { + filter <- filters + _ <- buildSearchArgument(dataTypeMap, filter, newBuilder()) + } yield filter + } + + /** + * Return true if this is a searchable type in ORC. + * Both CharType and VarcharType are cleaned at AstBuilder. + */ + private def isSearchableType(dataType: DataType) = dataType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + /** + * Get PredicateLeafType which is corresponding to the given DataType. + */ + private def getPredicateLeafType(dataType: DataType) = dataType match { + case BooleanType => PredicateLeaf.Type.BOOLEAN + case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG + case FloatType | DoubleType => PredicateLeaf.Type.FLOAT + case StringType => PredicateLeaf.Type.STRING + case DateType => PredicateLeaf.Type.DATE + case TimestampType => PredicateLeaf.Type.TIMESTAMP + case _: DecimalType => PredicateLeaf.Type.DECIMAL + case _ => throw new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") + } + + /** + * Cast literal values for filters. + * + * We need to cast to long because ORC raises exceptions + * at 'checkLiteralType' of SearchArgumentImpl.java. + */ + private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { + case ByteType | ShortType | IntegerType | LongType => + value.asInstanceOf[Number].longValue + case FloatType | DoubleType => + value.asInstanceOf[Number].doubleValue() + case _: DecimalType => + val decimal = value.asInstanceOf[java.math.BigDecimal] + val decimalWritable = new HiveDecimalWritable(decimal.longValue) + decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) + decimalWritable + case _ => value + } + + /** + * Build a SearchArgument and return the builder so far. 
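
A quick hedged illustration of `createFilter`, not part of the patch (the schema and predicates are made up, and the snippet assumes it runs where `private[sql] object OrcFilters` is visible): two convertible predicates are combined by `buildTree` into a single conjunctive `SearchArgument`.

import org.apache.spark.sql.sources.{EqualTo, GreaterThan}
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
val sarg = OrcFilters.createFilter(schema, Seq(GreaterThan("a", 1), EqualTo("b", "x")))
println(sarg.get)
// roughly: leaf-0 = (LESS_THAN_EQUALS a 1), leaf-1 = (EQUALS b x), expr = (and (not leaf-0) leaf-1)
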
+ */ + private def buildSearchArgument( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder): Option[Builder] = { + createBuilder(dataTypeMap, expression, builder, canPartialPushDownConjuncts = true) + } + + /** + * @param dataTypeMap a map from the attribute name to its data type. + * @param expression the input filter predicates. + * @param builder the input SearchArgument.Builder. + * @param canPartialPushDownConjuncts whether a subset of conjuncts of predicates can be pushed + * down safely. Pushing ONLY one side of AND down is safe to + * do at the top level or none of its ancestors is NOT and OR. + * @return the builder so far. + */ + private def createBuilder( + dataTypeMap: Map[String, DataType], + expression: Filter, + builder: Builder, + canPartialPushDownConjuncts: Boolean): Option[Builder] = { + def getType(attribute: String): PredicateLeaf.Type = + getPredicateLeafType(dataTypeMap(attribute)) + + import org.apache.spark.sql.sources._ + + expression match { + case And(left, right) => + // At here, it is not safe to just convert one side and remove the other side + // if we do not understand what the parent filters are. + // + // Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // + // Pushing one side of AND down is only safe to do at the top level or in the child + // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate + // can be safely removed. + val leftBuilderOption = + createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts) + val rightBuilderOption = + createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts) + (leftBuilderOption, rightBuilderOption) match { + case (Some(_), Some(_)) => + for { + lhs <- createBuilder(dataTypeMap, left, + builder.startAnd(), canPartialPushDownConjuncts) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts) + } yield rhs.end() + + case (Some(_), None) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, left, builder, canPartialPushDownConjuncts) + + case (None, Some(_)) if canPartialPushDownConjuncts => + createBuilder(dataTypeMap, right, builder, canPartialPushDownConjuncts) + + case _ => None + } + + case Or(left, right) => + for { + _ <- createBuilder(dataTypeMap, left, newBuilder, canPartialPushDownConjuncts = false) + _ <- createBuilder(dataTypeMap, right, newBuilder, canPartialPushDownConjuncts = false) + lhs <- createBuilder(dataTypeMap, left, + builder.startOr(), canPartialPushDownConjuncts = false) + rhs <- createBuilder(dataTypeMap, right, lhs, canPartialPushDownConjuncts = false) + } yield rhs.end() + + case Not(child) => + for { + _ <- createBuilder(dataTypeMap, child, newBuilder, canPartialPushDownConjuncts = false) + negate <- createBuilder(dataTypeMap, + child, builder.startNot(), canPartialPushDownConjuncts = false) + } yield negate.end() + + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). 
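
The `canPartialPushDownConjuncts` rule implemented above can be observed end-to-end through `createFilter`. In this hedged sketch (the column names and the deliberately unsupported `StringContains` predicate are illustrative; same `private[sql]` visibility caveat as in the previous sketch), one side of a top-level `And` is still pushed, while the same `And` under `Not` is not pushed at all.

import org.apache.spark.sql.sources.{And, EqualTo, Not, StringContains}
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
val unsupported = StringContains("b", "1")   // no SearchArgument equivalent in this file

// Top level: safe to keep only the convertible conjunct, `a = 2`.
assert(OrcFilters.createFilter(schema, Seq(And(EqualTo("a", 2), unsupported))).isDefined)

// Under NOT: pushing only `a = 2` would evaluate NOT(a = 2), so nothing is pushed.
assert(OrcFilters.createFilter(schema, Seq(Not(And(EqualTo("a", 2), unsupported)))).isEmpty)
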
+ + case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) + + case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) + + case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) + + case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + + case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) + + case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) + + case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) + + case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) + + case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + val quotedName = quoteAttributeNameIfNeeded(attribute) + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + Some(builder.startAnd().in(quotedName, getType(attribute), + castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) + + case _ => None + } + } +} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala new file mode 100644 index 000000000000..90d126802809 --- /dev/null +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.hadoop.io._ +import org.apache.orc.TypeDescription +import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize Spark rows to ORC structs. + */ +class OrcSerializer(dataSchema: StructType) { + + private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct] + private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray + + def serialize(row: InternalRow): OrcStruct = { + var i = 0 + while (i < converters.length) { + if (row.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, converters(i)(row, i)) + } + i += 1 + } + result + } + + private type Converter = (SpecializedGetters, Int) => WritableComparable[_] + + /** + * Creates a converter to convert Catalyst data at the given ordinal to ORC values. + */ + private def newConverter( + dataType: DataType, + reuseObj: Boolean = true): Converter = dataType match { + case NullType => (getter, ordinal) => null + + case BooleanType => + if (reuseObj) { + val result = new BooleanWritable() + (getter, ordinal) => + result.set(getter.getBoolean(ordinal)) + result + } else { + (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal)) + } + + case ByteType => + if (reuseObj) { + val result = new ByteWritable() + (getter, ordinal) => + result.set(getter.getByte(ordinal)) + result + } else { + (getter, ordinal) => new ByteWritable(getter.getByte(ordinal)) + } + + case ShortType => + if (reuseObj) { + val result = new ShortWritable() + (getter, ordinal) => + result.set(getter.getShort(ordinal)) + result + } else { + (getter, ordinal) => new ShortWritable(getter.getShort(ordinal)) + } + + case IntegerType => + if (reuseObj) { + val result = new IntWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new IntWritable(getter.getInt(ordinal)) + } + + + case LongType => + if (reuseObj) { + val result = new LongWritable() + (getter, ordinal) => + result.set(getter.getLong(ordinal)) + result + } else { + (getter, ordinal) => new LongWritable(getter.getLong(ordinal)) + } + + case FloatType => + if (reuseObj) { + val result = new FloatWritable() + (getter, ordinal) => + result.set(getter.getFloat(ordinal)) + result + } else { + (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal)) + } + + case DoubleType => + if (reuseObj) { + val result = new DoubleWritable() + (getter, ordinal) => + result.set(getter.getDouble(ordinal)) + result + } else { + (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal)) + } + + + // Don't reuse the result object for string and binary as it would cause extra data copy. 
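
The `reuseObj` pattern above trades allocation for mutation, which is only safe when the caller consumes each converted value before the next conversion. A standalone hedged sketch with plain Hadoop Writables (nothing from this patch) shows why the array and map element converters further below are created with `reuseObj = false`.

import org.apache.hadoop.io.IntWritable

val reused = new IntWritable()
val aliased = Seq(1, 2, 3).map { i => reused.set(i); reused }   // every element is the same object
val fresh   = Seq(1, 2, 3).map(new IntWritable(_))              // one object per element

println(aliased.map(_.get))   // List(3, 3, 3) -- earlier values were overwritten
println(fresh.map(_.get))     // List(1, 2, 3)
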
+ case StringType => (getter, ordinal) => + new Text(getter.getUTF8String(ordinal).getBytes) + + case BinaryType => (getter, ordinal) => + new BytesWritable(getter.getBinary(ordinal)) + + case DateType => + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) + } + + // The following cases are already expensive, reusing object or not doesn't matter. + + case TimestampType => (getter, ordinal) => + val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) + val result = new OrcTimestamp(ts.getTime) + result.setNanos(ts.getNanos) + result + + case DecimalType.Fixed(precision, scale) => (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + + case st: StructType => (getter, ordinal) => + val result = createOrcValue(st).asInstanceOf[OrcStruct] + val fieldConverters = st.map(_.dataType).map(newConverter(_)) + val numFields = st.length + val struct = getter.getStruct(ordinal, numFields) + var i = 0 + while (i < numFields) { + if (struct.isNullAt(i)) { + result.setFieldValue(i, null) + } else { + result.setFieldValue(i, fieldConverters(i)(struct, i)) + } + i += 1 + } + result + + case ArrayType(elementType, _) => (getter, ordinal) => + val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val elementConverter = newConverter(elementType, reuseObj = false) + val array = getter.getArray(ordinal) + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + result.add(null) + } else { + result.add(elementConverter(array, i)) + } + i += 1 + } + result + + case MapType(keyType, valueType, _) => (getter, ordinal) => + val result = createOrcValue(dataType) + .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] + // Need to put all converted values to a list, can't reuse object. + val keyConverter = newConverter(keyType, reuseObj = false) + val valueConverter = newConverter(valueType, reuseObj = false) + val map = getter.getMap(ordinal) + val keyArray = map.keyArray() + val valueArray = map.valueArray() + var i = 0 + while (i < map.numElements()) { + val key = keyConverter(keyArray, i) + if (valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case udt: UserDefinedType[_] => newConverter(udt.sqlType) + + case _ => + throw new UnsupportedOperationException(s"$dataType is not supported yet.") + } + + /** + * Return a Orc value object for the given Spark schema. + */ + private def createOrcValue(dataType: DataType) = { + OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType))) + } +} diff --git a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala new file mode 100644 index 000000000000..034454d21d7a --- /dev/null +++ b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.sql.{AnalysisException, Column, DataFrame} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Apache ORC filter API based filter pushdown optimization. + * OrcFilterSuite and HiveOrcFilterSuite is logically duplicated to provide the same test coverage. + * The difference are the packages containing 'Predicate' and 'SearchArgument' classes. + * - OrcFilterSuite uses 'org.apache.orc.storage.ql.io.sarg' package. + * - HiveOrcFilterSuite uses 'org.apache.hadoop.hive.ql.io.sarg' package. 
+ */ +class OrcFilterSuite extends OrcTest with SharedSQLContext { + + protected def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + query.queryExecution.optimizedPlan match { + case PhysicalOperation(_, filters, + DataSourceV2Relation(orcTable: OrcTable, _, options)) => + assert(filters.nonEmpty, "No filter is analyzed from the given query") + val scanBuilder = orcTable.newScanBuilder(options) + scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) + val pushedFilters = scanBuilder.pushedFilters() + assert(pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pushedFilters") + checker(maybeFilter.get) + + case _ => + throw new AnalysisException("Can not match OrcTable in the query.") + } + } + + protected def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala + assert(operator.map(_.getOperator).contains(filterOperator)) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + protected def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + protected def checkNoFilterPredicate + (predicate: Predicate, noneSupported: Boolean = false) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + query.queryExecution.optimizedPlan match { + case PhysicalOperation(_, filters, + DataSourceV2Relation(orcTable: OrcTable, _, options)) => + assert(filters.nonEmpty, "No filter is analyzed from the given query") + val scanBuilder = orcTable.newScanBuilder(options) + scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) + val pushedFilters = scanBuilder.pushedFilters() + if (noneSupported) { + assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") + } else { + assert(pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) + assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") + } + + case _ => + throw new AnalysisException("Can not match OrcTable in the query.") + } + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, 
PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, 
PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= false, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(false) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - decimal") { + withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(2)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + 
Literal(BigDecimal.valueOf(3)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(1)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(BigDecimal.valueOf(4)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - timestamp") { + val timeString = "2015-08-20 14:57:00" + val timestamps = (1 to 4).map { i => + val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 + new Timestamp(milliseconds) + } + withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(timestamps(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(timestamps(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - combinations with logical operators") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate( + '_1.isNotNull, + "leaf-0 = (IS_NULL _1), expr = (not leaf-0)" + ) + checkFilterPredicate( + '_1 =!= 1, + "leaf-0 = (IS_NULL _1), leaf-1 = (EQUALS _1 1), expr = (and (not leaf-0) (not leaf-1))" + ) + checkFilterPredicate( + !('_1 < 4), + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 4), expr = (and (not leaf-0) (not leaf-1))" + ) + checkFilterPredicate( + '_1 < 2 || '_1 > 3, + "leaf-0 = (LESS_THAN _1 2), leaf-1 = (LESS_THAN_EQUALS _1 3), " + + "expr = (or leaf-0 (not leaf-1))" + ) + checkFilterPredicate( + '_1 < 2 && '_1 > 3, + "leaf-0 = (IS_NULL _1), leaf-1 = (LESS_THAN _1 2), leaf-2 = (LESS_THAN_EQUALS _1 3), " + + "expr = (and (not leaf-0) leaf-1 (not leaf-2))" + ) + } + } + + test("filter pushdown - date") { + val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day => + Date.valueOf(day) + } + withOrcDataFrame(dates.map(Tuple1(_))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === dates(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < dates(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= dates(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(dates(0)) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(dates(0)) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(dates(1)) > '_1, PredicateLeaf.Operator.LESS_THAN) + 
checkFilterPredicate(Literal(dates(2)) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(dates(0)) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(dates(3)) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("no filter pushdown - non-supported types") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + // ArrayType + withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => + checkNoFilterPredicate('_1.isNull, noneSupported = true) + } + // BinaryType + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkNoFilterPredicate('_1 <=> 1.b, noneSupported = true) + } + // MapType + withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => + checkNoFilterPredicate('_1.isNotNull, noneSupported = true) + } + } + + test("SPARK-12218 and SPARK-25699 Converting conjunctions into ORC SearchArguments") { + import org.apache.spark.sql.sources._ + // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + + // Can not remove unsupported `StringContains` predicate since it is under `Or` operator. + assert(OrcFilters.createFilter(schema, Array( + Or( + LessThan("a", 10), + And( + StringContains("b", "prefix"), + GreaterThan("a", 1) + ) + ) + )).isEmpty) + + // Safely remove unsupported `StringContains` predicate and push down `LessThan` + assertResult("leaf-0 = (LESS_THAN a 10), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ) + )).get.toString + } + + // Safely remove unsupported `StringContains` predicate, push down `LessThan` and `GreaterThan`. + assertResult("leaf-0 = (LESS_THAN a 10), leaf-1 = (LESS_THAN_EQUALS a 1)," + + " expr = (and leaf-0 (not leaf-1))") { + OrcFilters.createFilter(schema, Array( + And( + And( + LessThan("a", 10), + StringContains("b", "prefix") + ), + GreaterThan("a", 1) + ) + )).get.toString + } + } +} + diff --git a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala new file mode 100644 index 000000000000..5a1bf9b43756 --- /dev/null +++ b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
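A minimal standalone sketch (not part of the patch) of the conversion the SPARK-12218/SPARK-25699 test above exercises; the object name is made up, and it sits in the same package only because OrcFilters is private[sql]:

package org.apache.spark.sql.execution.datasources.orc

import org.apache.spark.sql.sources.{And, LessThan, StringContains}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Sketch: only the convertible leg of the conjunction is pushed into the SearchArgument.
object OrcFilterSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Array(
      StructField("a", IntegerType, nullable = true),
      StructField("b", StringType, nullable = true)))
    // StringContains cannot be converted, but under And it can be dropped safely,
    // so only LessThan("a", 10) survives.
    val maybeSarg = OrcFilters.createFilter(schema, Array(
      And(LessThan("a", 10), StringContains("b", "prefix"))))
    // Expected, per the assertions above: "leaf-0 = (LESS_THAN a 10), expr = leaf-0"
    maybeSarg.foreach(println)
  }
}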
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.orc + +import scala.collection.JavaConverters._ + +import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.internal.SQLConf + +class OrcV1FilterSuite extends OrcFilterSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") + + override def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + override def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala + assert(operator.map(_.getOperator).contains(filterOperator)) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + override def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + override def checkNoFilterPredicate + (predicate: Predicate, noneSupported: Boolean = false) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(orcRelation) + filters + 
}.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) + assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") + } +} From 511ae8dd7c926dedbb6db630ed6a43df38fca348 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 19 Mar 2019 14:44:24 +0800 Subject: [PATCH 2/4] revert OrcDeserializer and OrcSerializer --- .../datasources/orc/OrcDeserializer.scala | 6 +- .../datasources/orc/OrcSerializer.scala | 16 +- .../execution/datasources/orc/OrcTest.scala | 34 ++- .../datasources/orc/OrcShimUtils.scala | 58 ++++ .../datasources/orc/OrcFilterSuite.scala | 28 -- .../datasources/orc/OrcDeserializer.scala | 251 ------------------ .../datasources/orc/OrcSerializer.scala | 228 ---------------- .../datasources/orc/OrcShimUtils.scala | 58 ++++ .../datasources/orc/OrcFilterSuite.scala | 28 -- 9 files changed, 154 insertions(+), 553 deletions(-) rename sql/core/{v1.2.1 => }/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala (96%) rename sql/core/{v1.2.1 => }/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala (92%) create mode 100644 sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala delete mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala delete mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala create mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala diff --git a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala similarity index 96% rename from sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 62e16707a8e3..6d52d40d6dd0 100644 --- a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.io._ import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} @@ -109,14 +108,13 @@ class OrcDeserializer( updater.set(ordinal, bytes) case DateType => (ordinal, value) => - updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(OrcShimUtils.getSqlDate(value))) case TimestampType => (ordinal, value) => updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) case DecimalType.Fixed(precision, scale) => (ordinal, value) => - val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - val v = 
Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + val v = OrcShimUtils.getDecimal(value) v.changePrecision(precision, scale) updater.set(ordinal, v) diff --git a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala similarity index 92% rename from sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 90d126802809..0b9cbecd0d32 100644 --- a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.datasources.orc import org.apache.hadoop.io._ import org.apache.orc.TypeDescription import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} -import org.apache.orc.storage.common.`type`.HiveDecimal -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -139,14 +137,7 @@ class OrcSerializer(dataSchema: StructType) { new BytesWritable(getter.getBinary(ordinal)) case DateType => - if (reuseObj) { - val result = new DateWritable() - (getter, ordinal) => - result.set(getter.getInt(ordinal)) - result - } else { - (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) - } + OrcShimUtils.getDateWritable(reuseObj) // The following cases are already expensive, reusing object or not doesn't matter. @@ -156,9 +147,8 @@ class OrcSerializer(dataSchema: StructType) { result.setNanos(ts.getNanos) result - case DecimalType.Fixed(precision, scale) => (getter, ordinal) => - val d = getter.getDecimal(ordinal, precision, scale) - new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + case DecimalType.Fixed(precision, scale) => + OrcShimUtils.getHiveDecimalWritable(precision, scale) case st: StructType => (getter, ordinal) => val result = createOrcValue(st).asInstanceOf[OrcStruct] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 411e632f95c1..adbd93dcb4fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -25,7 +25,11 @@ import scala.reflect.runtime.universe.TypeTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.FileBasedDataSourceTest +import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileBasedDataSourceTest} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION @@ -104,4 +108,32 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor assert(actual < numRows) } } + + protected def checkNoFilterPredicate + (predicate: 
Predicate, noneSupported: Boolean = false) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + query.queryExecution.optimizedPlan match { + case PhysicalOperation(_, filters, + DataSourceV2Relation(orcTable: OrcTable, _, options)) => + assert(filters.nonEmpty, "No filter is analyzed from the given query") + val scanBuilder = orcTable.newScanBuilder(options) + scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) + val pushedFilters = scanBuilder.pushedFilters() + if (noneSupported) { + assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") + } else { + assert(pushedFilters.nonEmpty, "No filter is pushed down") + val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) + assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") + } + + case _ => + throw new AnalysisException("Can not match OrcTable in the query.") + } + } } diff --git a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala new file mode 100644 index 000000000000..1dce379b9f53 --- /dev/null +++ b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.sql.Date + +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types.Decimal + +/** + * Helper functions for Orc serialize and deserialize. 
+ */ +private[spark] object OrcShimUtils { + + def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get + + def getDecimal(value: Any): Decimal = { + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + } + + def getDateWritable(reuseObj: Boolean): (SpecializedGetters, Int) => DateWritable = { + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter: SpecializedGetters, ordinal: Int) => + new DateWritable(getter.getInt(ordinal)) + } + } + + def getHiveDecimalWritable(precision: Int, scale: Int): + (SpecializedGetters, Int) => HiveDecimalWritable = { + (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + } +} diff --git a/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 034454d21d7a..981821371a0d 100644 --- a/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -89,34 +89,6 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { checkFilterPredicate(df, predicate, checkLogicalOperator) } - protected def checkNoFilterPredicate - (predicate: Predicate, noneSupported: Boolean = false) - (implicit df: DataFrame): Unit = { - val output = predicate.collect { case a: Attribute => a }.distinct - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, - DataSourceV2Relation(orcTable: OrcTable, _, options)) => - assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(options) - scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) - val pushedFilters = scanBuilder.pushedFilters() - if (noneSupported) { - assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") - } else { - assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") - } - - case _ => - throw new AnalysisException("Can not match OrcTable in the query.") - } - } - test("filter pushdown - integer") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala deleted file mode 100644 index 62e16707a8e3..000000000000 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
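A tiny hypothetical usage sketch of the shim added above (not part of the patch); the object name and the day count are made up, and it mirrors what the shared OrcSerializer/OrcDeserializer now do for DateType:

package org.apache.spark.sql.execution.datasources.orc

import org.apache.spark.sql.catalyst.InternalRow

// Sketch: DateType values are stored as days-since-epoch ints; the shim turns them
// into the Hive DateWritable that the ORC writer expects, and unwraps it on read.
object OrcShimUtilsSketch {
  def main(args: Array[String]): Unit = {
    val row = InternalRow(17897)                         // made-up day count
    val toDateWritable = OrcShimUtils.getDateWritable(reuseObj = false)
    val writable = toDateWritable(row, 0)
    println(writable.get)                                // the corresponding java.sql.Date

    // The read path goes the other way via getSqlDate/getDecimal.
    println(OrcShimUtils.getSqlDate(writable))
  }
}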
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.orc - -import org.apache.hadoop.io._ -import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A deserializer to deserialize ORC structs to Spark rows. - */ -class OrcDeserializer( - dataSchema: StructType, - requiredSchema: StructType, - requestedColIds: Array[Int]) { - - private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) - - // `fieldWriters(index)` is - // - null if the respective source column is missing, since the output value - // is always null in this case - // - a function that updates target column `index` otherwise. - private val fieldWriters: Array[WritableComparable[_] => Unit] = { - requiredSchema.zipWithIndex - .map { case (f, index) => - if (requestedColIds(index) == -1) { - null - } else { - val writer = newWriter(f.dataType, new RowUpdater(resultRow)) - (value: WritableComparable[_]) => writer(index, value) - } - }.toArray - } - - def deserialize(orcStruct: OrcStruct): InternalRow = { - var targetColumnIndex = 0 - while (targetColumnIndex < fieldWriters.length) { - if (fieldWriters(targetColumnIndex) != null) { - val value = orcStruct.getFieldValue(requestedColIds(targetColumnIndex)) - if (value == null) { - resultRow.setNullAt(targetColumnIndex) - } else { - fieldWriters(targetColumnIndex)(value) - } - } - targetColumnIndex += 1 - } - resultRow - } - - /** - * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. 
- */ - private def newWriter( - dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit = - dataType match { - case NullType => (ordinal, _) => - updater.setNullAt(ordinal) - - case BooleanType => (ordinal, value) => - updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get) - - case ByteType => (ordinal, value) => - updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get) - - case ShortType => (ordinal, value) => - updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get) - - case IntegerType => (ordinal, value) => - updater.setInt(ordinal, value.asInstanceOf[IntWritable].get) - - case LongType => (ordinal, value) => - updater.setLong(ordinal, value.asInstanceOf[LongWritable].get) - - case FloatType => (ordinal, value) => - updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get) - - case DoubleType => (ordinal, value) => - updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get) - - case StringType => (ordinal, value) => - updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes)) - - case BinaryType => (ordinal, value) => - val binary = value.asInstanceOf[BytesWritable] - val bytes = new Array[Byte](binary.getLength) - System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength) - updater.set(ordinal, bytes) - - case DateType => (ordinal, value) => - updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get)) - - case TimestampType => (ordinal, value) => - updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) - - case DecimalType.Fixed(precision, scale) => (ordinal, value) => - val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) - v.changePrecision(precision, scale) - updater.set(ordinal, v) - - case st: StructType => (ordinal, value) => - val result = new SpecificInternalRow(st) - val fieldUpdater = new RowUpdater(result) - val fieldConverters = st.map(_.dataType).map { dt => - newWriter(dt, fieldUpdater) - }.toArray - val orcStruct = value.asInstanceOf[OrcStruct] - - var i = 0 - while (i < st.length) { - val value = orcStruct.getFieldValue(i) - if (value == null) { - result.setNullAt(i) - } else { - fieldConverters(i)(i, value) - } - i += 1 - } - - updater.set(ordinal, result) - - case ArrayType(elementType, _) => (ordinal, value) => - val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]] - val length = orcArray.size() - val result = createArrayData(elementType, length) - val elementUpdater = new ArrayDataUpdater(result) - val elementConverter = newWriter(elementType, elementUpdater) - - var i = 0 - while (i < length) { - val value = orcArray.get(i) - if (value == null) { - result.setNullAt(i) - } else { - elementConverter(i, value) - } - i += 1 - } - - updater.set(ordinal, result) - - case MapType(keyType, valueType, _) => (ordinal, value) => - val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - val length = orcMap.size() - val keyArray = createArrayData(keyType, length) - val keyUpdater = new ArrayDataUpdater(keyArray) - val keyConverter = newWriter(keyType, keyUpdater) - val valueArray = createArrayData(valueType, length) - val valueUpdater = new ArrayDataUpdater(valueArray) - val valueConverter = newWriter(valueType, valueUpdater) - - var i = 0 - val it = orcMap.entrySet().iterator() - while (it.hasNext) { - val entry = it.next() - keyConverter(i, entry.getKey) - val 
value = entry.getValue - if (value == null) { - valueArray.setNullAt(i) - } else { - valueConverter(i, value) - } - i += 1 - } - - // The ORC map will never have null or duplicated map keys, it's safe to create a - // ArrayBasedMapData directly here. - updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) - - case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) - - case _ => - throw new UnsupportedOperationException(s"$dataType is not supported yet.") - } - - private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { - case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) - case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) - case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) - case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) - case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) - case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) - case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) - case _ => new GenericArrayData(new Array[Any](length)) - } - - /** - * A base interface for updating values inside catalyst data structure like `InternalRow` and - * `ArrayData`. - */ - sealed trait CatalystDataUpdater { - def set(ordinal: Int, value: Any): Unit - - def setNullAt(ordinal: Int): Unit = set(ordinal, null) - def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) - def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) - def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) - def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) - def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) - def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) - def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) - } - - final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { - override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) - override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) - - override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) - override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) - override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) - override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) - override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) - override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) - override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) - } - - final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { - override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) - override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) - - override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) - override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) - override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) - override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) - override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) - override def setDouble(ordinal: Int, 
value: Double): Unit = array.setDouble(ordinal, value) - override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) - } -} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala deleted file mode 100644 index 90d126802809..000000000000 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.orc - -import org.apache.hadoop.io._ -import org.apache.orc.TypeDescription -import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp} -import org.apache.orc.storage.common.`type`.HiveDecimal -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.types._ - -/** - * A serializer to serialize Spark rows to ORC structs. - */ -class OrcSerializer(dataSchema: StructType) { - - private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct] - private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray - - def serialize(row: InternalRow): OrcStruct = { - var i = 0 - while (i < converters.length) { - if (row.isNullAt(i)) { - result.setFieldValue(i, null) - } else { - result.setFieldValue(i, converters(i)(row, i)) - } - i += 1 - } - result - } - - private type Converter = (SpecializedGetters, Int) => WritableComparable[_] - - /** - * Creates a converter to convert Catalyst data at the given ordinal to ORC values. 
- */ - private def newConverter( - dataType: DataType, - reuseObj: Boolean = true): Converter = dataType match { - case NullType => (getter, ordinal) => null - - case BooleanType => - if (reuseObj) { - val result = new BooleanWritable() - (getter, ordinal) => - result.set(getter.getBoolean(ordinal)) - result - } else { - (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal)) - } - - case ByteType => - if (reuseObj) { - val result = new ByteWritable() - (getter, ordinal) => - result.set(getter.getByte(ordinal)) - result - } else { - (getter, ordinal) => new ByteWritable(getter.getByte(ordinal)) - } - - case ShortType => - if (reuseObj) { - val result = new ShortWritable() - (getter, ordinal) => - result.set(getter.getShort(ordinal)) - result - } else { - (getter, ordinal) => new ShortWritable(getter.getShort(ordinal)) - } - - case IntegerType => - if (reuseObj) { - val result = new IntWritable() - (getter, ordinal) => - result.set(getter.getInt(ordinal)) - result - } else { - (getter, ordinal) => new IntWritable(getter.getInt(ordinal)) - } - - - case LongType => - if (reuseObj) { - val result = new LongWritable() - (getter, ordinal) => - result.set(getter.getLong(ordinal)) - result - } else { - (getter, ordinal) => new LongWritable(getter.getLong(ordinal)) - } - - case FloatType => - if (reuseObj) { - val result = new FloatWritable() - (getter, ordinal) => - result.set(getter.getFloat(ordinal)) - result - } else { - (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal)) - } - - case DoubleType => - if (reuseObj) { - val result = new DoubleWritable() - (getter, ordinal) => - result.set(getter.getDouble(ordinal)) - result - } else { - (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal)) - } - - - // Don't reuse the result object for string and binary as it would cause extra data copy. - case StringType => (getter, ordinal) => - new Text(getter.getUTF8String(ordinal).getBytes) - - case BinaryType => (getter, ordinal) => - new BytesWritable(getter.getBinary(ordinal)) - - case DateType => - if (reuseObj) { - val result = new DateWritable() - (getter, ordinal) => - result.set(getter.getInt(ordinal)) - result - } else { - (getter, ordinal) => new DateWritable(getter.getInt(ordinal)) - } - - // The following cases are already expensive, reusing object or not doesn't matter. - - case TimestampType => (getter, ordinal) => - val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal)) - val result = new OrcTimestamp(ts.getTime) - result.setNanos(ts.getNanos) - result - - case DecimalType.Fixed(precision, scale) => (getter, ordinal) => - val d = getter.getDecimal(ordinal, precision, scale) - new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) - - case st: StructType => (getter, ordinal) => - val result = createOrcValue(st).asInstanceOf[OrcStruct] - val fieldConverters = st.map(_.dataType).map(newConverter(_)) - val numFields = st.length - val struct = getter.getStruct(ordinal, numFields) - var i = 0 - while (i < numFields) { - if (struct.isNullAt(i)) { - result.setFieldValue(i, null) - } else { - result.setFieldValue(i, fieldConverters(i)(struct, i)) - } - i += 1 - } - result - - case ArrayType(elementType, _) => (getter, ordinal) => - val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]] - // Need to put all converted values to a list, can't reuse object. 
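For orientation, a small hypothetical round trip through the OrcSerializer/OrcDeserializer pair that stays under sql/core/src (this copy is only being deleted); the object name, schema and row values are made up:

package org.apache.spark.sql.execution.datasources.orc

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Sketch: serialize a Catalyst row into an OrcStruct and read it back.
object OrcRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    val serializer = new OrcSerializer(schema)
    val orcStruct = serializer.serialize(InternalRow(1, UTF8String.fromString("spark")))
    println(orcStruct)                                   // IntWritable and Text fields

    // requestedColIds maps each required field to its position in the ORC struct.
    val deserializer = new OrcDeserializer(schema, schema, Array(0, 1))
    val back = deserializer.deserialize(orcStruct)
    println(s"${back.getInt(0)}, ${back.getUTF8String(1)}")
  }
}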
- val elementConverter = newConverter(elementType, reuseObj = false) - val array = getter.getArray(ordinal) - var i = 0 - while (i < array.numElements()) { - if (array.isNullAt(i)) { - result.add(null) - } else { - result.add(elementConverter(array, i)) - } - i += 1 - } - result - - case MapType(keyType, valueType, _) => (getter, ordinal) => - val result = createOrcValue(dataType) - .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]] - // Need to put all converted values to a list, can't reuse object. - val keyConverter = newConverter(keyType, reuseObj = false) - val valueConverter = newConverter(valueType, reuseObj = false) - val map = getter.getMap(ordinal) - val keyArray = map.keyArray() - val valueArray = map.valueArray() - var i = 0 - while (i < map.numElements()) { - val key = keyConverter(keyArray, i) - if (valueArray.isNullAt(i)) { - result.put(key, null) - } else { - result.put(key, valueConverter(valueArray, i)) - } - i += 1 - } - result - - case udt: UserDefinedType[_] => newConverter(udt.sqlType) - - case _ => - throw new UnsupportedOperationException(s"$dataType is not supported yet.") - } - - /** - * Return a Orc value object for the given Spark schema. - */ - private def createOrcValue(dataType: DataType) = { - OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType))) - } -} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala new file mode 100644 index 000000000000..1dce379b9f53 --- /dev/null +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import java.sql.Date + +import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types.Decimal + +/** + * Helper functions for Orc serialize and deserialize. 
+ */ +private[spark] object OrcShimUtils { + + def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get + + def getDecimal(value: Any): Decimal = { + val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() + Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) + } + + def getDateWritable(reuseObj: Boolean): (SpecializedGetters, Int) => DateWritable = { + if (reuseObj) { + val result = new DateWritable() + (getter, ordinal) => + result.set(getter.getInt(ordinal)) + result + } else { + (getter: SpecializedGetters, ordinal: Int) => + new DateWritable(getter.getInt(ordinal)) + } + } + + def getHiveDecimalWritable(precision: Int, scale: Int): + (SpecializedGetters, Int) => HiveDecimalWritable = { + (getter, ordinal) => + val d = getter.getDecimal(ordinal, precision, scale) + new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) + } +} diff --git a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 034454d21d7a..981821371a0d 100644 --- a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -89,34 +89,6 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { checkFilterPredicate(df, predicate, checkLogicalOperator) } - protected def checkNoFilterPredicate - (predicate: Predicate, noneSupported: Boolean = false) - (implicit df: DataFrame): Unit = { - val output = predicate.collect { case a: Attribute => a }.distinct - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - query.queryExecution.optimizedPlan match { - case PhysicalOperation(_, filters, - DataSourceV2Relation(orcTable: OrcTable, _, options)) => - assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(options) - scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) - val pushedFilters = scanBuilder.pushedFilters() - if (noneSupported) { - assert(pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") - } else { - assert(pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, pushedFilters) - assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for $pushedFilters") - } - - case _ => - throw new AnalysisException("Can not match OrcTable in the query.") - } - } - test("filter pushdown - integer") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) From 020d7e77d84a9af02167800e21fbe4185c8d2e2c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 20 Mar 2019 13:27:35 +0800 Subject: [PATCH 3/4] Merge master --- .../datasources/orc/OrcDeserializer.scala | 4 ++-- .../execution/datasources/orc/OrcSerializer.scala | 4 ++-- ...{OrcShimUtils.scala => OrcSerializeUtils.scala} | 2 +- .../sql/execution/datasources/orc/OrcFilters.scala | 6 ++---- ...{OrcShimUtils.scala => OrcSerializeUtils.scala} | 2 +- .../execution/datasources/orc/OrcFilterSuite.scala | 14 +++++++++++++- 6 files changed, 21 insertions(+), 11 deletions(-) rename sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/{OrcShimUtils.scala => OrcSerializeUtils.scala} (98%) rename 
sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/{OrcShimUtils.scala => OrcSerializeUtils.scala} (98%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 6d52d40d6dd0..69aee41a9e4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -108,13 +108,13 @@ class OrcDeserializer( updater.set(ordinal, bytes) case DateType => (ordinal, value) => - updater.setInt(ordinal, DateTimeUtils.fromJavaDate(OrcShimUtils.getSqlDate(value))) + updater.setInt(ordinal, DateTimeUtils.fromJavaDate(OrcSerializeUtils.getSqlDate(value))) case TimestampType => (ordinal, value) => updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp])) case DecimalType.Fixed(precision, scale) => (ordinal, value) => - val v = OrcShimUtils.getDecimal(value) + val v = OrcSerializeUtils.getDecimal(value) v.changePrecision(precision, scale) updater.set(ordinal, v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 0b9cbecd0d32..e3edb78fa12b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -137,7 +137,7 @@ class OrcSerializer(dataSchema: StructType) { new BytesWritable(getter.getBinary(ordinal)) case DateType => - OrcShimUtils.getDateWritable(reuseObj) + OrcSerializeUtils.getDateWritable(reuseObj) // The following cases are already expensive, reusing object or not doesn't matter. @@ -148,7 +148,7 @@ class OrcSerializer(dataSchema: StructType) { result case DecimalType.Fixed(precision, scale) => - OrcShimUtils.getHiveDecimalWritable(precision, scale) + OrcSerializeUtils.getHiveDecimalWritable(precision, scale) case st: StructType => (getter, ordinal) => val result = createOrcValue(st).asInstanceOf[OrcStruct] diff --git a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala similarity index 98% rename from sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala rename to sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala index 1dce379b9f53..7998655f4c7d 100644 --- a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala +++ b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.Decimal /** * Helper functions for Orc serialize and deserialize. 
*/ -private[spark] object OrcShimUtils { +private[spark] object OrcSerializeUtils { def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index cd2a68a53bab..98484003644a 100644 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc +import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder @@ -136,10 +137,7 @@ private[sql] object OrcFilters { case FloatType | DoubleType => value.asInstanceOf[Number].doubleValue() case _: DecimalType => - val decimal = value.asInstanceOf[java.math.BigDecimal] - val decimalWritable = new HiveDecimalWritable(decimal.longValue) - decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale) - decimalWritable + new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _ => value } diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala similarity index 98% rename from sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala rename to sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala index 1dce379b9f53..7998655f4c7d 100644 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.Decimal /** * Helper functions for Orc serialize and deserialize. 
*/ -private[spark] object OrcShimUtils { +private[spark] object OrcSerializeUtils { def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get diff --git a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 981821371a0d..e96c6fb7716c 100644 --- a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc +import java.math.MathContext import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} @@ -31,7 +32,6 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -397,5 +397,17 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { )).get.toString } } + + test("SPARK-27160: Fix casting of the DecimalType literal") { + import org.apache.spark.sql.sources._ + val schema = StructType(Array(StructField("a", DecimalType(3, 2)))) + assertResult("leaf-0 = (LESS_THAN a 3.14), expr = leaf-0") { + OrcFilters.createFilter(schema, Array( + LessThan( + "a", + new java.math.BigDecimal(3.14, MathContext.DECIMAL64).setScale(2))) + ).get.toString + } + } } From 11bc98284566ae93caffa7d947543c095de03c75 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 26 Mar 2019 09:59:27 +0800 Subject: [PATCH 4/4] Merge master --- .../orc/OrcColumnarBatchReader.java | 0 .../datasources/orc/OrcV1FilterSuite.scala | 0 .../datasources/orc/OrcShimUtils.scala | 0 .../orc/OrcColumnarBatchReader.java | 210 ------------------ .../datasources/orc/OrcFilters.scala | 34 +-- .../datasources/orc/OrcSerializeUtils.scala | 58 ----- .../datasources/orc/OrcShimUtils.scala} | 12 +- .../datasources/orc/OrcV1FilterSuite.scala | 107 --------- 8 files changed, 12 insertions(+), 409 deletions(-) rename sql/core/{v1.2.1 => }/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java (100%) rename sql/core/{v1.2.1 => }/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala (100%) rename sql/core/{ => v1.2.1}/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala (100%) delete mode 100644 sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java delete mode 100644 sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala rename sql/core/{v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala => v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala} (80%) delete mode 100644 sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala diff --git a/sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java similarity index 100% rename 
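To see why SPARK-27160 switches to HiveDecimal.create, a hypothetical side-by-side of the old and new conversions (not part of the patch), using the same 3.14 literal as the new test:

import java.math.{BigDecimal => JBigDecimal, MathContext}

import org.apache.orc.storage.common.`type`.HiveDecimal
import org.apache.orc.storage.serde2.io.HiveDecimalWritable

// Sketch: the old path goes through longValue and keeps only the integral part, so the
// pushed-down literal no longer equals 3.14; the new path builds from the full BigDecimal.
object DecimalLiteralSketch {
  def main(args: Array[String]): Unit = {
    val decimal = new JBigDecimal(3.14, MathContext.DECIMAL64).setScale(2)

    val old = new HiveDecimalWritable(decimal.longValue)          // integral part only: 3
    old.mutateEnforcePrecisionScale(decimal.precision, decimal.scale)

    val fixed = new HiveDecimalWritable(HiveDecimal.create(decimal))

    println(s"old = $old, fixed = $fixed")
  }
}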
from sql/core/v1.2.1/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java diff --git a/sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala similarity index 100% rename from sql/core/v1.2.1/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala b/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala rename to sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala diff --git a/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java deleted file mode 100644 index efca96e9ce58..000000000000 --- a/sql/core/v2.3.4/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.orc; - -import java.io.IOException; - -import com.google.common.annotations.VisibleForTesting; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.mapreduce.InputSplit; -import org.apache.hadoop.mapreduce.RecordReader; -import org.apache.hadoop.mapreduce.TaskAttemptContext; -import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.orc.OrcConf; -import org.apache.orc.OrcFile; -import org.apache.orc.Reader; -import org.apache.orc.TypeDescription; -import org.apache.orc.mapred.OrcInputFormat; -import org.apache.orc.storage.ql.exec.vector.*; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.types.*; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -/** - * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch. - * After creating, `initialize` and `initBatch` should be called sequentially. - */ -public class OrcColumnarBatchReader extends RecordReader { - - // The capacity of vectorized batch. 
- private int capacity; - - // Vectorized ORC Row Batch - private VectorizedRowBatch batch; - - /** - * The column IDs of the physical ORC file schema which are required by this reader. - * -1 means this required column is partition column, or it doesn't exist in the ORC file. - * Ideally partition column should never appear in the physical file, and should only appear - * in the directory name. However, Spark allows partition columns inside physical file, - * but Spark will discard the values from the file, and use the partition value got from - * directory name. The column order will be reserved though. - */ - @VisibleForTesting - public int[] requestedDataColIds; - - // Record reader from ORC row batch. - private org.apache.orc.RecordReader recordReader; - - private StructField[] requiredFields; - - // The result columnar batch for vectorized execution by whole-stage codegen. - @VisibleForTesting - public ColumnarBatch columnarBatch; - - // The wrapped ORC column vectors. - private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers; - - public OrcColumnarBatchReader(int capacity) { - this.capacity = capacity; - } - - - @Override - public Void getCurrentKey() { - return null; - } - - @Override - public ColumnarBatch getCurrentValue() { - return columnarBatch; - } - - @Override - public float getProgress() throws IOException { - return recordReader.getProgress(); - } - - @Override - public boolean nextKeyValue() throws IOException { - return nextBatch(); - } - - @Override - public void close() throws IOException { - if (columnarBatch != null) { - columnarBatch.close(); - columnarBatch = null; - } - if (recordReader != null) { - recordReader.close(); - recordReader = null; - } - } - - /** - * Initialize ORC file reader and batch record reader. - * Please note that `initBatch` is needed to be called after this. - */ - @Override - public void initialize( - InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException { - FileSplit fileSplit = (FileSplit)inputSplit; - Configuration conf = taskAttemptContext.getConfiguration(); - Reader reader = OrcFile.createReader( - fileSplit.getPath(), - OrcFile.readerOptions(conf) - .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)) - .filesystem(fileSplit.getPath().getFileSystem(conf))); - Reader.Options options = - OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength()); - recordReader = reader.rows(options); - } - - /** - * Initialize columnar batch by setting required schema and partition information. - * With this information, this creates ColumnarBatch with the full schema. - * - * @param orcSchema Schema from ORC file reader. - * @param requiredFields All the fields that are required to return, including partition fields. - * @param requestedDataColIds Requested column ids from orcSchema. -1 if not existed. - * @param requestedPartitionColIds Requested column ids from partition schema. -1 if not existed. - * @param partitionValues Values of partition columns. - */ - public void initBatch( - TypeDescription orcSchema, - StructField[] requiredFields, - int[] requestedDataColIds, - int[] requestedPartitionColIds, - InternalRow partitionValues) { - batch = orcSchema.createRowBatch(capacity); - assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`. 
- assert(requiredFields.length == requestedDataColIds.length); - assert(requiredFields.length == requestedPartitionColIds.length); - // If a required column is also partition column, use partition value and don't read from file. - for (int i = 0; i < requiredFields.length; i++) { - if (requestedPartitionColIds[i] != -1) { - requestedDataColIds[i] = -1; - } - } - this.requiredFields = requiredFields; - this.requestedDataColIds = requestedDataColIds; - - StructType resultSchema = new StructType(requiredFields); - - // Just wrap the ORC column vector instead of copying it to Spark column vector. - orcVectorWrappers = new org.apache.spark.sql.vectorized.ColumnVector[resultSchema.length()]; - - for (int i = 0; i < requiredFields.length; i++) { - DataType dt = requiredFields[i].dataType(); - if (requestedPartitionColIds[i] != -1) { - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); - ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); - partitionCol.setIsConstant(); - orcVectorWrappers[i] = partitionCol; - } else { - int colId = requestedDataColIds[i]; - // Initialize the missing columns once. - if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); - missingCol.putNulls(0, capacity); - missingCol.setIsConstant(); - orcVectorWrappers[i] = missingCol; - } else { - orcVectorWrappers[i] = new OrcColumnVector(dt, batch.cols[colId]); - } - } - } - - columnarBatch = new ColumnarBatch(orcVectorWrappers); - } - - /** - * Return true if there exists more data in the next batch. If exists, prepare the next batch - * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns. - */ - private boolean nextBatch() throws IOException { - recordReader.nextBatch(batch); - int batchSize = batch.size; - if (batchSize == 0) { - return false; - } - columnarBatch.setNumRows(batchSize); - - for (int i = 0; i < requiredFields.length; i++) { - if (requestedDataColIds[i] != -1) { - ((OrcColumnVector) orcVectorWrappers[i]).setBatchSize(batchSize); - } - } - return true; - } -} diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 98484003644a..112dcb2cb238 100644 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -23,7 +23,7 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable -import org.apache.spark.sql.sources.{And, Filter} +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ /** @@ -56,27 +56,7 @@ import org.apache.spark.sql.types._ * builder methods mentioned above can only be found in test code, where all tested filters are * known to be convertible. 
*/ -private[sql] object OrcFilters { - private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { - filters match { - case Seq() => None - case Seq(filter) => Some(filter) - case Seq(filter1, filter2) => Some(And(filter1, filter2)) - case _ => // length > 2 - val (left, right) = filters.splitAt(filters.length / 2) - Some(And(buildTree(left).get, buildTree(right).get)) - } - } - - // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters - // in order to distinguish predicate pushdown for nested columns. - private def quoteAttributeNameIfNeeded(name: String) : String = { - if (!name.contains("`") && name.contains(".")) { - s"`$name`" - } else { - name - } - } +private[sql] object OrcFilters extends OrcFiltersBase { /** * Create ORC filter as a SearchArgument instance. @@ -101,16 +81,6 @@ private[sql] object OrcFilters { } yield filter } - /** - * Return true if this is a searchable type in ORC. - * Both CharType and VarcharType are cleaned at AstBuilder. - */ - private def isSearchableType(dataType: DataType) = dataType match { - case BinaryType => false - case _: AtomicType => true - case _ => false - } - /** * Get PredicateLeafType which is corresponding to the given DataType. */ diff --git a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala deleted file mode 100644 index 7998655f4c7d..000000000000 --- a/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.orc - -import java.sql.Date - -import org.apache.orc.storage.common.`type`.HiveDecimal -import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} - -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters -import org.apache.spark.sql.types.Decimal - -/** - * Helper functions for Orc serialize and deserialize. 
- */ -private[spark] object OrcSerializeUtils { - - def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get - - def getDecimal(value: Any): Decimal = { - val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal() - Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale()) - } - - def getDateWritable(reuseObj: Boolean): (SpecializedGetters, Int) => DateWritable = { - if (reuseObj) { - val result = new DateWritable() - (getter, ordinal) => - result.set(getter.getInt(ordinal)) - result - } else { - (getter: SpecializedGetters, ordinal: Int) => - new DateWritable(getter.getInt(ordinal)) - } - } - - def getHiveDecimalWritable(precision: Int, scale: Int): - (SpecializedGetters, Int) => HiveDecimalWritable = { - (getter, ordinal) => - val d = getter.getDecimal(ordinal, precision, scale) - new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal)) - } -} diff --git a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala similarity index 80% rename from sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala rename to sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala index 7998655f4c7d..68503aba22b4 100644 --- a/sql/core/v1.2.1/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializeUtils.scala +++ b/sql/core/v2.3.4/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcShimUtils.scala @@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.orc import java.sql.Date import org.apache.orc.storage.common.`type`.HiveDecimal +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch +import org.apache.orc.storage.ql.io.sarg.{SearchArgument => OrcSearchArgument} +import org.apache.orc.storage.ql.io.sarg.PredicateLeaf.{Operator => OrcOperator} import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable} import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.types.Decimal /** - * Helper functions for Orc serialize and deserialize. + * Various utilities for ORC used to upgrade the built-in Hive. */ -private[spark] object OrcSerializeUtils { +private[sql] object OrcShimUtils { + + class VectorizedRowBatchWrap(val batch: VectorizedRowBatch) {} + + private[sql] type Operator = OrcOperator + private[sql] type SearchArgument = OrcSearchArgument def getSqlDate(value: Any): Date = value.asInstanceOf[DateWritable].get diff --git a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala deleted file mode 100644 index 5a1bf9b43756..000000000000 --- a/sql/core/v2.3.4/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.datasources.orc - -import scala.collection.JavaConverters._ - -import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} - -import org.apache.spark.SparkConf -import org.apache.spark.sql.{Column, DataFrame} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Predicate} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.internal.SQLConf - -class OrcV1FilterSuite extends OrcFilterSuite { - - override protected def sparkConf: SparkConf = - super - .sparkConf - .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") - .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") - - override def checkFilterPredicate( - df: DataFrame, - predicate: Predicate, - checker: (SearchArgument) => Unit): Unit = { - val output = predicate.collect { case a: Attribute => a }.distinct - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - var maybeRelation: Option[HadoopFsRelation] = None - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => - maybeRelation = Some(orcRelation) - filters - }.flatten.reduceLeftOption(_ && _) - assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") - - val (_, selectedFilters, _) = - DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) - assert(selectedFilters.nonEmpty, "No filter is pushed down") - - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) - assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") - checker(maybeFilter.get) - } - - override def checkFilterPredicate - (predicate: Predicate, filterOperator: PredicateLeaf.Operator) - (implicit df: DataFrame): Unit = { - def checkComparisonOperator(filter: SearchArgument) = { - val operator = filter.getLeaves.asScala - assert(operator.map(_.getOperator).contains(filterOperator)) - } - checkFilterPredicate(df, predicate, checkComparisonOperator) - } - - override def checkFilterPredicate - (predicate: Predicate, stringExpr: String) - (implicit df: DataFrame): Unit = { - def checkLogicalOperator(filter: SearchArgument) = { - assert(filter.toString == stringExpr) - } - checkFilterPredicate(df, predicate, checkLogicalOperator) - } - - override def checkNoFilterPredicate - (predicate: Predicate, noneSupported: Boolean = false) - (implicit df: DataFrame): Unit = { - val output = predicate.collect { case a: Attribute => a }.distinct - val query = df - .select(output.map(e => Column(e)): _*) - .where(Column(predicate)) - - var maybeRelation: Option[HadoopFsRelation] = None - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) => - maybeRelation = Some(orcRelation) - filters - 
}.flatten.reduceLeftOption(_ && _) - assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") - - val (_, selectedFilters, _) = - DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) - assert(selectedFilters.nonEmpty, "No filter is pushed down") - - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) - assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") - } -}
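
A note on the SPARK-27160 test added earlier in this series: Spark's data source filters carry decimal literals as java.math.BigDecimal, while an ORC DECIMAL predicate leaf expects its literal as a HiveDecimalWritable, so the v2.3.4 OrcFilters has to convert the pushed-down value before handing it to the SearchArgument builder. The sketch below only illustrates that conversion under the nohive org.apache.orc.storage packages used by this profile; the object and helper names are invented, not the ones in OrcFilters.

    import java.math.MathContext

    import org.apache.orc.storage.common.`type`.HiveDecimal
    import org.apache.orc.storage.serde2.io.HiveDecimalWritable

    object DecimalLiteralSketch {
      // Wrap a pushed-down decimal literal in the writable type that an ORC
      // DECIMAL predicate leaf expects.
      def toDecimalLiteral(value: Any): HiveDecimalWritable = value match {
        case d: java.math.BigDecimal => new HiveDecimalWritable(HiveDecimal.create(d))
        case other => throw new IllegalArgumentException(s"Not a decimal literal: $other")
      }

      def main(args: Array[String]): Unit = {
        val literal = new java.math.BigDecimal(3.14, MathContext.DECIMAL64).setScale(2)
        // Prints 3.14, i.e. the literal that appears in the expected
        // "leaf-0 = (LESS_THAN a 3.14)" string asserted by the new test.
        println(toDecimalLiteral(literal))
      }
    }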
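
The helpers removed from the v2.3.4 OrcFilters (buildTree, quoteAttributeNameIfNeeded, isSearchableType) are not gone: the object now extends OrcFiltersBase, which is meant to hold this Hive-version-independent logic once. As a reminder of what buildTree does, here is the same recursion exercised against Spark's public Filter API; the three example filters below are arbitrary.

    import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan, IsNotNull}

    object BuildTreeSketch {
      // Fold the pushed-down filters into one balanced And, so the whole
      // conjunction is converted to a single SearchArgument or rejected as a whole.
      def buildTree(filters: Seq[Filter]): Option[Filter] = filters match {
        case Seq() => None
        case Seq(filter) => Some(filter)
        case Seq(filter1, filter2) => Some(And(filter1, filter2))
        case _ => // length > 2
          val (left, right) = filters.splitAt(filters.length / 2)
          Some(And(buildTree(left).get, buildTree(right).get))
      }

      def main(args: Array[String]): Unit = {
        val filters = Seq(IsNotNull("a"), GreaterThan("a", 1), EqualTo("b", "x"))
        // Some(And(IsNotNull(a),And(GreaterThan(a,1),EqualTo(b,x))))
        println(buildTree(filters))
      }
    }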
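
Taken together, the last hunks give the v2.3.4 profile its own OrcShimUtils, renamed from the short-lived OrcSerializeUtils and extended with a VectorizedRowBatchWrap wrapper plus Operator/SearchArgument type aliases, while the original OrcShimUtils moves under v1.2.1. The point of the shim is that code shared by both profiles only deals in version-neutral types (java.sql.Date, Spark's Decimal, and the aliases) instead of naming the Writable and sarg classes of a particular Hive/ORC combination. A rough usage sketch, assuming it sits in the same org.apache.spark.sql.execution.datasources.orc package so the private[sql] object is visible; the object and method names below are invented for illustration.

    package org.apache.spark.sql.execution.datasources.orc

    import java.sql.Date

    import org.apache.spark.sql.types.Decimal

    // Compiled once, linked against whichever OrcShimUtils the active Hive
    // profile's version-specific source directory provides.
    object ShimUsageSketch {
      // `raw` is the date writable produced by the profile's ORC reader; the
      // shim hides which concrete DateWritable class that is.
      def sqlDateOf(raw: Any): Date = OrcShimUtils.getSqlDate(raw)

      // Likewise for decimals: the shim converts the profile's
      // HiveDecimalWritable into Spark's own Decimal.
      def decimalOf(raw: Any): Decimal = OrcShimUtils.getDecimal(raw)
    }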