This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-979] Support reading parquet with case sensitive #980

Merged

@@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.datasources.v2.arrow
 import java.util.Objects
 import java.util.TimeZone

-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo.{Field, Schema}

 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ArrowUtils

 object SparkSchemaUtils {
@@ -36,6 +36,11 @@ object SparkSchemaUtils {
     ArrowUtils.toArrowSchema(schema, timeZoneId)
   }

+  def toArrowField(
+      name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
+    ArrowUtils.toArrowField(name, dt, nullable, timeZoneId)
+  }
+
   @deprecated // experimental
   def getGandivaCompatibleTimeZoneID(): String = {
     val zone = SQLConf.get.sessionLocalTimeZone
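The new SparkSchemaUtils.toArrowField is a thin forwarder to Spark's internal org.apache.spark.sql.util.ArrowUtils, exposing single-field conversion alongside the existing whole-schema toArrowSchema. A minimal sketch of a call site, assuming this module and Spark are on the classpath (the field name and type are illustrative, not taken from the PR):

    import org.apache.arrow.vector.types.pojo.Field
    import org.apache.spark.sql.execution.datasources.v2.arrow.SparkSchemaUtils
    import org.apache.spark.sql.types.{LongType, StructField}

    object ToArrowFieldSketch {
      def main(args: Array[String]): Unit = {
        // Convert one Spark field to its Arrow counterpart using the session-local time zone,
        // mirroring the ArrowUtils.toArrowField(t: StructField) helper added later in this PR.
        val sparkField = StructField("Id", LongType, nullable = true)
        val arrowField: Field = SparkSchemaUtils.toArrowField(
          sparkField.name, sparkField.dataType, sparkField.nullable,
          SparkSchemaUtils.getLocalTimezoneID())
        println(arrowField)
      }
    }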
@@ -20,13 +20,15 @@ package com.intel.oap.spark.sql.execution.datasources.arrow
 import java.net.URLDecoder

 import scala.collection.JavaConverters._
+import scala.collection.mutable

 import com.intel.oap.spark.sql.ArrowWriteExtension.FakeRow
 import com.intel.oap.spark.sql.ArrowWriteQueue
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions, ArrowUtils}
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
 import com.intel.oap.vectorized.ArrowWritableColumnVector
 import org.apache.arrow.dataset.scanner.ScanOptions
+import org.apache.arrow.vector.types.pojo.Schema
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.Job
@@ -117,6 +119,7 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializable
     val sqlConf = sparkSession.sessionState.conf;
     val batchSize = sqlConf.parquetVectorizedReaderBatchSize
     val enableFilterPushDown = sqlConf.arrowFilterPushDown
+    val caseSensitive = sqlConf.caseSensitiveAnalysis

     (file: PartitionedFile) => {
       val factory = ArrowUtils.makeArrowDiscovery(
@@ -126,16 +129,34 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializable
           options.asJava).asScala.toMap))

       // todo predicate validation / pushdown
-      val dataset = factory.finish(ArrowUtils.toArrowSchema(requiredSchema));
+      val parquetFileFields = factory.inspect().getFields.asScala
+      val caseInsensitiveFieldMap = mutable.Map[String, String]()
+      val requiredFields = if (caseSensitive) {
+        new Schema(requiredSchema.map { field =>
+          parquetFileFields.find(_.getName.equals(field.name))
+            .getOrElse(ArrowUtils.toArrowField(field))
+        }.asJava)
+      } else {
+        new Schema(requiredSchema.map { readField =>
+          parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
+            .map { field =>
+              caseInsensitiveFieldMap += (readField.name -> field.getName)
+              field
+            }.getOrElse(ArrowUtils.toArrowField(readField))
+        }.asJava)
+      }
+      val dataset = factory.finish(requiredFields)

       val filter = if (enableFilterPushDown) {
-        ArrowFilters.translateFilters(filters)
+        ArrowFilters.translateFilters(filters, caseInsensitiveFieldMap.toMap)
       } else {
         org.apache.arrow.dataset.filter.Filter.EMPTY
       }

-      val scanOptions = new ScanOptions(requiredSchema.map(f => f.name).toArray,
-        filter, batchSize)
+      val scanOptions = new ScanOptions(
+        requiredFields.getFields.asScala.map(f => f.getName).toArray,
+        filter,
+        batchSize)
       val scanner = dataset.newScan(scanOptions)

       val taskList = scanner
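The schema-matching block above is the heart of the change: each required Spark field is looked up among the Arrow fields reported by the Parquet file, exactly when case-sensitive analysis is on and case-insensitively otherwise, and the resolved read-name-to-file-name pairs are recorded so that pushed filters can be rewritten later. A minimal, self-contained sketch of that matching step, with plain Scala case classes standing in for the Arrow and Spark schema types (all names are illustrative):

    object CaseInsensitiveMatchSketch {
      // Stand-ins for Arrow Field and Spark StructField; only the names matter for matching.
      final case class FileField(name: String)
      final case class ReadField(name: String)

      /** Resolve read-schema names against file names; also return a read-name -> file-name map. */
      def resolve(
          readSchema: Seq[ReadField],
          fileFields: Seq[FileField],
          caseSensitive: Boolean): (Seq[String], Map[String, String]) = {
        val mapping = scala.collection.mutable.Map[String, String]()
        val resolved = readSchema.map { readField =>
          val hit =
            if (caseSensitive) fileFields.find(_.name == readField.name)
            else fileFields.find(_.name.equalsIgnoreCase(readField.name))
          hit match {
            case Some(f) =>
              mapping += (readField.name -> f.name)
              f.name
            case None =>
              // Keep the requested name when the file has no such column,
              // like the getOrElse(ArrowUtils.toArrowField(...)) branch above.
              readField.name
          }
        }
        (resolved, mapping.toMap)
      }

      def main(args: Array[String]): Unit = {
        val (names, map) = resolve(Seq(ReadField("id")), Seq(FileField("Id")), caseSensitive = false)
        println(names) // List(Id)
        println(map)   // Map(id -> Id)
      }
    }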
@@ -20,6 +20,7 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
 import org.apache.arrow.dataset.DatasetTypes
 import org.apache.arrow.dataset.DatasetTypes.TreeNode
 import org.apache.arrow.dataset.filter.FilterImpl
+import org.apache.arrow.vector.types.pojo.Field

 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -56,9 +57,11 @@ object ArrowFilters {
       false
     }

-  def translateFilters(pushedFilters: Seq[Filter]): org.apache.arrow.dataset.filter.Filter = {
+  def translateFilters(
+      pushedFilters: Seq[Filter],
+      caseInsensitiveFieldMap: Map[String, String]): org.apache.arrow.dataset.filter.Filter = {
     val node = pushedFilters
-      .flatMap(translateFilter)
+      .flatMap(filter => translateFilter(filter, caseInsensitiveFieldMap))
       .reduceOption((t1: TreeNode, t2: TreeNode) => {
         DatasetTypes.TreeNode.newBuilder.setAndNode(
           DatasetTypes.AndNode.newBuilder()
@@ -100,28 +103,35 @@
     }
   }

-  private def translateFilter(pushedFilter: Filter): Option[TreeNode] = {
+  private def translateFilter(
+      pushedFilter: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
     pushedFilter match {
       case EqualTo(attribute, value) =>
-        createComparisonNode("equal", attribute, value)
+        createComparisonNode(
+          "equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case GreaterThan(attribute, value) =>
-        createComparisonNode("greater", attribute, value)
+        createComparisonNode(
+          "greater", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case GreaterThanOrEqual(attribute, value) =>
-        createComparisonNode("greater_equal", attribute, value)
+        createComparisonNode(
+          "greater_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case LessThan(attribute, value) =>
-        createComparisonNode("less", attribute, value)
+        createComparisonNode(
+          "less", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case LessThanOrEqual(attribute, value) =>
-        createComparisonNode("less_equal", attribute, value)
+        createComparisonNode(
+          "less_equal", caseInsensitiveFieldMap.getOrElse(attribute, attribute), value)
       case Not(child) =>
-        createNotNode(child)
+        createNotNode(child, caseInsensitiveFieldMap)
       case And(left, right) =>
-        createAndNode(left, right)
+        createAndNode(left, right, caseInsensitiveFieldMap)
       case Or(left, right) =>
-        createOrNode(left, right)
+        createOrNode(left, right, caseInsensitiveFieldMap)
       case IsNotNull(attribute) =>
-        createIsNotNullNode(attribute)
+        createIsNotNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
       case IsNull(attribute) =>
-        createIsNullNode(attribute)
+        createIsNullNode(caseInsensitiveFieldMap.getOrElse(attribute, attribute))
       case _ => None // fixme complete this
     }
   }
@@ -145,8 +155,10 @@
     }
   }

-  def createNotNode(child: Filter): Option[TreeNode] = {
-    val translatedChild = translateFilter(child)
+  def createNotNode(
+      child: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedChild = translateFilter(child, caseInsensitiveFieldMap)
     if (translatedChild.isEmpty) {
       return None
     }
@@ -176,9 +188,12 @@
           .build()).build()).build()).build())
   }

-  def createAndNode(left: Filter, right: Filter): Option[TreeNode] = {
-    val translatedLeft = translateFilter(left)
-    val translatedRight = translateFilter(right)
+  def createAndNode(
+      left: Filter,
+      right: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
+    val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
     if (translatedLeft.isEmpty || translatedRight.isEmpty) {
       return None
     }
@@ -190,9 +205,12 @@
         .build())
   }

-  def createOrNode(left: Filter, right: Filter): Option[TreeNode] = {
-    val translatedLeft = translateFilter(left)
-    val translatedRight = translateFilter(right)
+  def createOrNode(
+      left: Filter,
+      right: Filter,
+      caseInsensitiveFieldMap: Map[String, String]): Option[TreeNode] = {
+    val translatedLeft = translateFilter(left, caseInsensitiveFieldMap)
+    val translatedRight = translateFilter(right, caseInsensitiveFieldMap)
     if (translatedLeft.isEmpty || translatedRight.isEmpty) {
       return None
     }
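With the new caseInsensitiveFieldMap parameter, every attribute referenced by a pushed-down filter is rewritten to the column name that physically exists in the file before the Arrow filter node is built; names absent from the map fall back to their original spelling through getOrElse. The sketch below isolates that renaming pass, with a tiny filter ADT standing in for Spark's org.apache.spark.sql.sources.Filter hierarchy (illustrative only):

    object FilterRenameSketch {
      // Minimal stand-in for the parts of the Spark filter hierarchy the translator handles.
      sealed trait SimpleFilter
      final case class EqualTo(attribute: String, value: Any) extends SimpleFilter
      final case class IsNotNull(attribute: String) extends SimpleFilter
      final case class And(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter

      /** Rewrite attribute names to file-side names, leaving unknown names untouched. */
      def remap(filter: SimpleFilter, fieldMap: Map[String, String]): SimpleFilter = filter match {
        case EqualTo(attr, value) => EqualTo(fieldMap.getOrElse(attr, attr), value)
        case IsNotNull(attr)      => IsNotNull(fieldMap.getOrElse(attr, attr))
        case And(l, r)            => And(remap(l, fieldMap), remap(r, fieldMap))
      }

      def main(args: Array[String]): Unit = {
        val mapped = remap(And(IsNotNull("id"), EqualTo("id", 2)), Map("id" -> "Id"))
        println(mapped) // And(IsNotNull(Id),EqualTo(Id,2))
      }
    }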
@@ -19,10 +19,12 @@ package com.intel.oap.spark.sql.execution.datasources.v2.arrow
 import java.net.URLDecoder

 import scala.collection.JavaConverters._
+import scala.collection.mutable

 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowPartitionReaderFactory.ColumnarBatchRetainer
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
 import org.apache.arrow.dataset.scanner.ScanOptions
+import org.apache.arrow.vector.types.pojo.Schema

 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
@@ -59,9 +61,27 @@ case class ArrowPartitionReaderFactory(
     val path = partitionedFile.filePath
     val factory = ArrowUtils.makeArrowDiscovery(URLDecoder.decode(path, "UTF-8"),
       partitionedFile.start, partitionedFile.length, options)
-    val dataset = factory.finish(ArrowUtils.toArrowSchema(readDataSchema))
+    val parquetFileFields = factory.inspect().getFields.asScala
+    val caseInsensitiveFieldMap = mutable.Map[String, String]()
+    val requiredFields = if (sqlConf.caseSensitiveAnalysis) {
+      new Schema(readDataSchema.map { field =>
+        parquetFileFields.find(_.getName.equals(field.name))
+          .getOrElse(ArrowUtils.toArrowField(field))
+      }.asJava)
+    } else {
+      new Schema(readDataSchema.map { readField =>
+        parquetFileFields.find(_.getName.equalsIgnoreCase(readField.name))
+          .map { field =>
+            caseInsensitiveFieldMap += (readField.name -> field.getName)
+            field
+          }.getOrElse(ArrowUtils.toArrowField(readField))
+      }.asJava)
+    }
+    val dataset = factory.finish(requiredFields)
     val filter = if (enableFilterPushDown) {
-      ArrowFilters.translateFilters(ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema))
+      ArrowFilters.translateFilters(
+        ArrowFilters.pruneWithSchema(pushedFilters, readDataSchema),
+        caseInsensitiveFieldMap.toMap)
     } else {
       org.apache.arrow.dataset.filter.Filter.EMPTY
     }
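ArrowPartitionReaderFactory is the DataSource V2 read path and applies the same resolution as the V1 ArrowFileFormat above (see the sketch after that file): the read schema is matched against the fields returned by factory.inspect(), and the resulting name map is handed to ArrowFilters.translateFilters along with the pruned filters.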
@@ -26,18 +26,17 @@ import scala.collection.JavaConverters._
 import com.intel.oap.vectorized.{ArrowColumnVectorUtils, ArrowWritableColumnVector}
 import org.apache.arrow.dataset.file.FileSystemDatasetFactory
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
-import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.types.pojo.{Field, Schema}
 import org.apache.hadoop.fs.FileStatus

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkSchemaUtils}
 import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

 object ArrowUtils {

@@ -89,6 +88,11 @@ object ArrowUtils {
     SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID())
   }

+  def toArrowField(t: StructField): Field = {
+    SparkSchemaUtils.toArrowField(
+      t.name, t.dataType, t.nullable, SparkSchemaUtils.getLocalTimezoneID())
+  }
+
   def loadBatch(input: ArrowRecordBatch, partitionValues: InternalRow,
       partitionSchema: StructType, dataSchema: StructType): ColumnarBatch = {
     val rowCount: Int = input.getLength
@@ -286,6 +286,36 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {

   }

+  test("read and write with case sensitive or insensitive") {
+    val caseSensitiveAnalysisEnabled = Seq[Boolean](true, false)
+    val v1SourceList = Seq[String]("", "arrow")
+    caseSensitiveAnalysisEnabled.foreach { caseSensitiveAnalysis =>
+      v1SourceList.foreach { v1Source =>
+        withSQLConf(
+          SQLConf.CASE_SENSITIVE.key -> caseSensitiveAnalysis.toString,
+          SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) {
+          withTempPath { tempPath =>
+            spark.range(0, 100)
+              .withColumnRenamed("id", "Id")
+              .write
+              .mode("overwrite")
+              .arrow(tempPath.getPath)
+            val selectColName = if (caseSensitiveAnalysis) {
+              "Id"
+            } else {
+              "id"
+            }
+            val df = spark.read
+              .schema(s"$selectColName long")
+              .arrow(tempPath.getPath)
+              .filter(s"$selectColName <= 2")
+            checkAnswer(df, Row(0) :: Row(1) :: Row(2) :: Nil)
+          }
+        }
+      }
+    }
+  }
+
   test("file descriptor leak") {
     val path = ArrowDataSourceTest.locateResourcePath(parquetFile1)
     val frame = spark.read
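The test toggles both spark.sql.caseSensitive and spark.sql.sources.useV1SourceList, so the V1 FileFormat and the DataSource V2 reader are each exercised with case-sensitive and case-insensitive analysis. From a user's perspective, the change lets a lower-case query read and filter a file whose column is physically named Id whenever case-insensitive analysis is enabled. An end-to-end sketch of that behaviour, assuming the OAP Arrow data source is on the classpath; the "arrow" format name and the local path are illustrative assumptions, not taken from the test:

    import org.apache.spark.sql.SparkSession

    object CaseInsensitiveReadSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        spark.conf.set("spark.sql.caseSensitive", "false")
        // Write a file whose column is physically named "Id"...
        spark.range(0, 100).withColumnRenamed("id", "Id")
          .write.mode("overwrite").format("arrow").save("/tmp/case_insensitive_demo")
        // ...then read it back with a lower-case schema and a pushed-down filter.
        val df = spark.read
          .schema("id long")
          .format("arrow")
          .load("/tmp/case_insensitive_demo")
          .filter("id <= 2")
        df.show() // expected rows: 0, 1, 2
        spark.stop()
      }
    }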