2 changes: 1 addition & 1 deletion R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2684,7 +2684,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
# It makes sure that we can omit path argument in read.df API and then it calls
# DataFrameWriter.load() without path.
expect_error(read.df(source = "json"),
paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
paste("Error in loadDF : analysis error - Unable to infer schema for JSON.",
"It must be specified manually"))
expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist")

DataSource.scala
@@ -61,8 +61,12 @@ import org.apache.spark.util.Utils
* qualified. This option only works when reading from a [[FileFormat]].
* @param userSpecifiedSchema An optional specification of the schema of the data. When present
* we skip attempting to infer the schema.
* @param partitionColumns A list of column names that the relation is partitioned by. When this
* list is empty, the relation is unpartitioned.
* @param partitionColumns A list of column names that the relation is partitioned by. This list is
* generally empty during the read path, unless this DataSource is managed
* by Hive. In these cases, during `resolveRelation`, we will call
* `getOrInferFileFormatSchema` for file based DataSources to infer the
* partitioning. In other cases, if this list is empty, then this table
* is unpartitioned.
* @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data.
* @param catalogTable Optional catalog table reference that can be used to push down operations
* over the datasource to the catalog service.
@@ -84,30 +88,106 @@ case class DataSource(
private val caseInsensitiveOptions = new CaseInsensitiveMap(options)

/**
* Infer the schema of the given FileFormat, returns a pair of schema and partition column names.
* Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
* it. In the read path, only tables managed by Hive provide the partition columns properly when
* initializing this class. All other file based data sources will try to infer the partitioning,
* and then cast the inferred types to user specified dataTypes if the partition columns exist
* inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510.
* This method will try to skip file scanning when `userSpecifiedSchema` and
* `partitionColumns` are provided. Here are some code paths that use this method:
Review comment (Contributor): Can you document what "least amount of work" is? That is, it will skip file scanning if .....

* 1. `spark.read` (no schema): Most amount of work. Infer both schema and partitioning columns
* 2. `spark.read.schema(userSpecifiedSchema)`: Parse partitioning columns, cast them to the
* dataTypes provided in `userSpecifiedSchema` if they exist or fallback to inferred
* dataType if they don't.
* 3. `spark.readStream.schema(userSpecifiedSchema)`: For streaming use cases, users have to
* provide the schema. Here, we also perform partition inference like 2, and try to use
* dataTypes in `userSpecifiedSchema`. All subsequent triggers for this stream will re-use
* this information, therefore calls to this method should be very cheap, i.e. there won't
* be any further inference in any triggers.
* 4. `df.saveAsTable(tableThatExisted)`: In this case, we call this method to resolve the
* existing table's partitioning scheme. This is achieved by not providing
* `userSpecifiedSchema`. For this case, we add the boolean `justPartitioning` for an early
* exit, if we don't care about the schema of the original table.
*
* @param format the file format object for this DataSource
* @param justPartitioning Whether to exit early and provide just the partition schema.
* @return A pair of the data schema (excluding partition columns) and the schema of the partition
* columns. If `justPartitioning` is `true`, then the dataSchema will be provided as
* `null`.
*/
private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = {
userSpecifiedSchema.map(_ -> partitionColumns).orElse {
val allPaths = caseInsensitiveOptions.get("path")
private def getOrInferFileFormatSchema(
Review comment (Contributor): Can you add param docs to define what justPartitioning means?

Review comment (Contributor): This ain't right. You can't return something incorrect. Rather return null for the first schema. Also, the docs are confusing in this way. Please add @param and @return after the params to clarify what gets returned in both cases.

Review comment (Contributor): I think it would be clearer if we can split this method into two: one for partition schema and the other for data schema. In this way, we can also remove the justPartitioning argument by calling the method you need at the right place.

Review comment (Contributor): Well, just realized that it might be hard to split because of the temporary InMemoryFileIndex.
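
For concreteness, here is a rough sketch of the four code paths the method doc above enumerates, seen from the user-facing API. This is illustrative only and not part of the patch; the paths, schema, and table name are made up.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val userSchema = new StructType().add("data", StringType).add("part", StringType)

// 1. No schema: the most work -- infer both the data schema and the partition columns.
val df1 = spark.read.parquet("/tmp/part.parquet")

// 2. User-specified schema: only the partition columns still need to be discovered,
//    and their types are taken from userSchema when they appear there.
val df2 = spark.read.schema(userSchema).parquet("/tmp/part.parquet")

// 3. Streaming: the schema is mandatory; partition discovery runs once when the source
//    is created and is reused by every trigger.
val sdf = spark.readStream.schema(userSchema).format("parquet").load("/tmp/part.parquet")

// 4. Appending to a table that already exists: only its partitioning scheme is needed.
df1.write.mode("append").saveAsTable("existing_table")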

format: FileFormat,
justPartitioning: Boolean = false): (StructType, StructType) = {
// the operations below are expensive therefore try not to do them if we don't need to
lazy val tempFileCatalog = {
Review comment (Contributor): Please add docs on why this is lazy. It took me half-a-minute to trace down why this should be lazy.

Review comment (Contributor): Nit: tempFileIndex

val allPaths = caseInsensitiveOptions.get("path") ++ paths
val hadoopConf = sparkSession.sessionState.newHadoopConf()
val globbedPaths = allPaths.toSeq.flatMap { path =>
val hdfsPath = new Path(path)
val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
val fs = hdfsPath.getFileSystem(hadoopConf)
val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray
val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None)
val partitionSchema = fileCatalog.partitionSpec().partitionColumns
val inferred = format.inferSchema(
new InMemoryFileIndex(sparkSession, globbedPaths, options, None)
}
val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) {
// Try to infer partitioning, because no DataSource in the read path provides the partitioning
// columns properly unless it is a Hive DataSource
val resolved = tempFileCatalog.partitionSchema.map { partitionField =>
val equality = sparkSession.sessionState.conf.resolver
// SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
partitionField)
}
StructType(resolved)
} else {
// in streaming mode, we have already inferred and registered partition columns, so we will
// never have to materialize the lazy val below
lazy val inferredPartitions = tempFileCatalog.partitionSchema
Review comment (Contributor): Every call of this method will re-compute tempFileCatalog, which as per your comment is an expensive op. This kinda contradicts the 3rd point in the method doc, which says that subsequent invocations of this method would be cheap. Did I miss anything?

Review comment (Contributor, PR author): Since everything is already defined in userSpecifiedSchema once we create the FileStreamSource, we will never have to materialize this variable, because we will not use it.
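
The author's point can be reduced to a tiny standalone Scala sketch (purely illustrative, names invented): a lazy val is never evaluated unless something forces it, so when every partition column resolves from the user-specified schema, the expensive listing never runs.

// Hypothetical stand-in for the expensive file listing.
lazy val expensiveListing: Seq[String] = {
  println("scanning files...") // would only print if the lazy val were forced
  Seq("part")
}

val userSpecified = Map("part" -> "string")
val partitionCols = Seq("part")

// Every column resolves from the user-specified schema, so the default branch of
// getOrElse is never evaluated and expensiveListing stays unevaluated.
val resolved = partitionCols.map { c =>
  userSpecified.get(c).getOrElse(expensiveListing.find(_ == c).orNull)
}
// "scanning files..." is never printed.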

// maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty, use the inferred
// partitioning
if (userSpecifiedSchema.isEmpty) {
inferredPartitions
} else {
val partitionFields = partitionColumns.map { partitionColumn =>
userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse {
Review comment (Contributor): Also need to use the resolver to handle case sensitivity here.

val inferredOpt = inferredPartitions.find(_.name == partitionColumn)
if (inferredOpt.isDefined) {
logDebug(
s"""Type of partition column: $partitionColumn not found in specified schema
|for $format.
|User Specified Schema
|=====================
|${userSpecifiedSchema.orNull}
|
|Falling back to inferred dataType if it exists.
""".stripMargin)
}
inferredPartitions.find(_.name == partitionColumn)
Review comment (Contributor): Duplicated code?

}.getOrElse {
throw new AnalysisException(s"Failed to resolve the schema for $format for " +
s"the partition column: $partitionColumn. It must be specified manually.")
}
}
StructType(partitionFields)
}
}
if (justPartitioning) {
return (null, partitionSchema)
}
val dataSchema = userSpecifiedSchema.map { schema =>
val equality = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name))))
}.orElse {
format.inferSchema(
Review comment (Contributor): Just to confirm, inferSchema returns the schema without the partition columns?

Review comment (Contributor): Yes.
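
A quick way to see why: in a partitioned layout the partition values live in the directory names, not in the data files, so a schema inferred from the files alone cannot contain the partition columns. The snippet below is illustrative only; it assumes an active SparkSession named spark and a dataset written with .partitionBy("part"), and the paths are made up.

// /tmp/part.parquet/part=0/part-00000-<...>.snappy.parquet holds only the data columns.
val singleDir = spark.read.parquet("/tmp/part.parquet/part=0")
singleDir.printSchema() // data columns only; no partition discovery was run

val wholeTable = spark.read.parquet("/tmp/part.parquet")
wholeTable.printSchema() // partition discovery adds "part" back from the directory names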

sparkSession,
caseInsensitiveOptions,
fileCatalog.allFiles())

inferred.map { inferredSchema =>
StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name)
}
tempFileCatalog.allFiles())
}.getOrElse {
throw new AnalysisException("Unable to infer schema. It must be specified manually.")
throw new AnalysisException(
s"Unable to infer schema for $format. It must be specified manually.")
}
(dataSchema, partitionSchema)
}

/** Returns the name and schema of the source that can be used to continually read data. */
@@ -144,8 +224,8 @@ case class DataSource(
"you may be able to create a static DataFrame on that directory with " +
"'spark.read.load(directory)' and infer schema from it.")
}
val (schema, partCols) = inferFileFormatSchema(format)
SourceInfo(s"FileSource[$path]", schema, partCols)
val (schema, partCols) = getOrInferFileFormatSchema(format)
Review comment (Contributor): Rename to dataSchema and partitionSchema.

SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames)

case _ =>
throw new UnsupportedOperationException(
@@ -272,7 +352,7 @@ case class DataSource(

HadoopFsRelation(
fileCatalog,
partitionSchema = fileCatalog.partitionSpec().partitionColumns,
partitionSchema = fileCatalog.partitionSchema,
dataSchema = dataSchema,
bucketSpec = None,
format,
@@ -281,33 +361,25 @@ case class DataSource(
// This is a non-streaming file based datasource.
case (format: FileFormat, _) =>
val allPaths = caseInsensitiveOptions.get("path") ++ paths
val hadoopConf = sparkSession.sessionState.newHadoopConf()
val globbedPaths = allPaths.flatMap { path =>
Review comment (Contributor): Just to confirm, for paths with globs, this would expand here ONCE, and then expand them AGAIN in getOrInferFileFormatSchema, right? If so, we don't have to fix it in this PR, but we should document this in a JIRA or something for fixing later.

val hdfsPath = new Path(path)
val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
val fs = hdfsPath.getFileSystem(hadoopConf)
Review comment (Contributor): Any need for this change? hadoopConf is not reused anywhere else.

Review comment (Member): This is to avoid creating a new hadoopConf for each path.

val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified)

if (globPath.isEmpty) {
throw new AnalysisException(s"Path does not exist: $qualified")
}
// Sufficient to check head of the globPath seq for non-glob scenario
// Don't need to check once again if files exist in streaming mode
if (checkFilesExist && !fs.exists(globPath.head)) {
throw new AnalysisException(s"Path does not exist: ${globPath.head}")
}
globPath
}.toArray

// If they gave a schema, then we try and figure out the types of the partition columns
// from that schema.
val partitionSchema = userSpecifiedSchema.map { schema =>
StructType(
partitionColumns.map { c =>
// TODO: Case sensitivity.
schema
.find(_.name.toLowerCase() == c.toLowerCase())
.getOrElse(throw new AnalysisException(s"Invalid partition column '$c'"))
})
}
val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format)
Review comment (Contributor): Nit: this may not be inferred, right? So just partitionSchema would be a better name.


val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions &&
catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) {
@@ -316,27 +388,12 @@ case class DataSource(
catalogTable.get,
catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L))
} else {
new InMemoryFileIndex(
sparkSession, globbedPaths, options, partitionSchema)
}

val dataSchema = userSpecifiedSchema.map { schema =>
val equality = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
}.orElse {
format.inferSchema(
sparkSession,
caseInsensitiveOptions,
fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles())
}.getOrElse {
throw new AnalysisException(
s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " +
"It must be specified manually")
new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema))
}

HadoopFsRelation(
fileCatalog,
partitionSchema = fileCatalog.partitionSchema,
partitionSchema = inferredPartitionSchema,
dataSchema = dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
@@ -384,11 +441,7 @@ case class DataSource(
// up. If we fail to load the table for whatever reason, ignore the check.
if (mode == SaveMode.Append) {
val existingPartitionColumns = Try {
resolveRelation()
.asInstanceOf[HadoopFsRelation]
.partitionSchema
.fieldNames
.toSeq
getOrInferFileFormatSchema(format, justPartitioning = true)._2.fieldNames.toList
}.getOrElse(Seq.empty[String])
// TODO: Case sensitivity.
val sameColumns =

DDLSuite.scala
@@ -274,7 +274,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
pathToPartitionedTable,
userSpecifiedSchema = Option("num int, str string"),
userSpecifiedPartitionCols = partitionCols,
expectedSchema = new StructType().add("num", IntegerType).add("str", StringType),
expectedSchema = new StructType().add("str", StringType).add("num", IntegerType),
Review comment (Contributor): So in this PR, for some cases, the order of fields in the schema created after resolveRelation is changing?

Review comment (Contributor): I believe the original test case was incorrect. Although the schema check passes, if you really read rows out of the Dataset, you'll hit an exception, as shown in the following Spark shell session:

import org.apache.spark.sql.types._

val df0 = spark.range(10).select(
  ('id % 4) cast StringType as "part",
  'id cast StringType as "data"
)

val path = "/tmp/part.parquet"
df0.write.mode("overwrite").partitionBy("part").parquet(path)

val df1 = spark.read.schema(
  new StructType()
    .add("part", StringType, nullable = true)
    .add("data", StringType, nullable = true)
).parquet(path)

df1.printSchema()
// root
//  |-- part: string (nullable = true)
//  |-- data: string (nullable = true)

df1.show()
// 16/11/22 22:52:21 ERROR Executor: Exception in task 0.0 in stage 10.0 (TID 34)
// java.lang.NullPointerException
//         at org.apache.spark.sql.execution.vectorized.OnHeapColumnVector.getArrayLength(OnHeapColumnVector.java:375)
//         at org.apache.spark.sql.execution.vectorized.ColumnVector.getArray(ColumnVector.java:554)
//         at org.apache.spark.sql.execution.vectorized.ColumnVector.getByteArray(ColumnVector.java:576)
//         [...]
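
For what it's worth, my reading of the fix is that re-running the tail of that session on top of this patch keeps the user-specified string type for part but moves it to the end of the resolved schema, so the rows read back cleanly. The continuation below is a sketch of the expected behavior, not captured output:

val df2 = spark.read.schema(
  new StructType()
    .add("part", StringType, nullable = true)
    .add("data", StringType, nullable = true)
).parquet(path)

df2.printSchema()
// root
//  |-- data: string (nullable = true)
//  |-- part: string (nullable = true)

df2.show() // no NullPointerException; partition values come back as strings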

expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String]))
}
}

FileStreamSourceSuite.scala
@@ -280,7 +280,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
createFileStreamSourceAndGetSchema(
format = Some("json"), path = Some(src.getCanonicalPath), schema = None)
}
assert("Unable to infer schema. It must be specified manually.;" === e.getMessage)
assert("Unable to infer schema for JSON. It must be specified manually.;" === e.getMessage)
}
}
}

DataStreamReaderWriterSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

object LastOptions {
@@ -532,4 +532,47 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
assert(e.getMessage.contains("does not support recovering"))
assert(e.getMessage.contains("checkpoint location"))
}

test("SPARK-18510: use user specified types for partition columns in file sources") {
import org.apache.spark.sql.functions.udf
import testImplicits._
withTempDir { src =>
val createArray = udf { (length: Long) =>
for (i <- 1 to length.toInt) yield i.toString
}
spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write
.partitionBy("part", "id")
.mode("overwrite")
.parquet(src.toString)
// Specify a random ordering of the schema, partition column in the middle, etc.
// Also let's say that the partition columns are Strings instead of Longs.
// partition columns should go to the end
val schema = new StructType()
.add("id", StringType)
.add("ex", ArrayType(StringType))

val sdf = spark.readStream
.schema(schema)
.format("parquet")
.load(src.toString)

assert(sdf.schema.toList === List(
StructField("ex", ArrayType(StringType)),
StructField("part", IntegerType), // inferred partitionColumn dataType
StructField("id", StringType))) // used user provided partitionColumn dataType

val sq = sdf.writeStream
.queryName("corruption_test")
.format("memory")
.start()
sq.processAllAvailable()
checkAnswer(
spark.table("corruption_test"),
// notice how `part` is ordered before `id`
Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") ::
Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil
)
sq.stop()
}
}
}

DataFrameReaderWriterSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils


@@ -573,4 +573,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
}
}
}

test("SPARK-18510: use user specified types for partition columns in file sources") {
import org.apache.spark.sql.functions.udf
import testImplicits._
withTempDir { src =>
val createArray = udf { (length: Long) =>
for (i <- 1 to length.toInt) yield i.toString
}
spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write
.partitionBy("part", "id")
.mode("overwrite")
.parquet(src.toString)
// Specify a random ordering of the schema, partition column in the middle, etc.
// Also let's say that the partition columns are Strings instead of Longs.
// partition columns should go to the end
val schema = new StructType()
.add("id", StringType)
.add("ex", ArrayType(StringType))
val df = spark.read
.schema(schema)
.format("parquet")
.load(src.toString)

assert(df.schema.toList === List(
StructField("ex", ArrayType(StringType)),
StructField("part", IntegerType), // inferred partitionColumn dataType
StructField("id", StringType))) // used user provided partitionColumn dataType

checkAnswer(
df,
// notice how `part` is ordered before `id`
Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") ::
Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil
)
}
}
}