@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.datasources.FileFormatWriter
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

/**
@@ -41,8 +42,12 @@ trait DataWritingCommand extends Command {

override final def children: Seq[LogicalPlan] = query :: Nil

// Output columns of the analyzed input query plan
def outputColumns: Seq[Attribute]
// Output column names of the analyzed input query plan.
def outputColumnNames: Seq[String]

// Output columns of the analyzed input query plan.
def outputColumns: Seq[Attribute] =
DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)

lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics

@@ -53,3 +58,35 @@ trait DataWritingCommand extends Command {

def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
}

object DataWritingCommand {
/**
* Returns output attributes with provided names.
* The length of provided names should be the same as the length of [[LogicalPlan.output]].
*/
def logicalPlanOutputWithNames(
query: LogicalPlan,
names: Seq[String]): Seq[Attribute] = {
// Save the output attributes to a variable to avoid duplicated function calls.
val outputAttributes = query.output
[Member] query: LogicalPlan -> outputAttributes: Seq[Attribute] in the function argument, then drop the line above?

[Member, author — @gengliangwang, Sep 3, 2018] I think both are OK. The current way makes it easier to call this util function, and it is easier to understand what the parameter should be, while the way you suggest makes the argument carry minimal information.

assert(outputAttributes.length == names.length,
"The length of provided names doesn't match the length of output attributes.")
outputAttributes.zip(names).map { case (attr, outputName) =>
attr.withName(outputName)
}
}

/**
* Returns schema of logical plan with provided names.
* The length of provided names should be the same as the length of [[LogicalPlan.schema]].
*/
def logicalPlanSchemaWithNames(
query: LogicalPlan,
names: Seq[String]): StructType = {
assert(query.schema.length == names.length,
"The length of provided names doesn't match the length of query schema.")
StructType(query.schema.zip(names).map { case (structField, outputName) =>
structField.copy(name = outputName)
})
}
}
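A rough, self-contained sketch of what logicalPlanOutputWithNames does (the Attr case class below is a stand-in for Catalyst's Attribute, not Spark code): zip the possibly case-rewritten output of the optimized plan with the analyzed names and rename each attribute, asserting the lengths match.

// Simplified stand-in for Attribute; only the renaming behavior is modeled.
final case class Attr(name: String, dataType: String) {
  def withName(newName: String): Attr = copy(name = newName)
}

object OutputWithNamesSketch {
  def outputWithNames(output: Seq[Attr], names: Seq[String]): Seq[Attr] = {
    assert(output.length == names.length,
      "The length of provided names doesn't match the length of output attributes.")
    output.zip(names).map { case (attr, outputName) => attr.withName(outputName) }
  }

  def main(args: Array[String]): Unit = {
    // The optimizer produced lower-case "id"; the analyzed name was "ID".
    val optimizedOutput = Seq(Attr("id", "bigint"))
    println(outputWithNames(optimizedOutput, Seq("ID"))) // List(Attr(ID,bigint))
  }
}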
@@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
mode: SaveMode,
query: LogicalPlan,
outputColumns: Seq[Attribute])
outputColumnNames: Seq[String])
extends DataWritingCommand {

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand(
catalogTable = if (tableExists) Some(table) else None)

try {
dataSource.writeAndRead(mode, query, outputColumns, physicalPlan)
dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -450,7 +451,7 @@ case class DataSource(
mode = mode,
catalogTable = catalogTable,
fileIndex = fileIndex,
outputColumns = data.output)
outputColumnNames = data.output.map(_.name))
}

/**
@@ -460,9 +461,9 @@ case class DataSource(
* @param mode The save mode for this writing.
* @param data The input query plan that produces the data to be written. Note that this plan
* is analyzed and optimized.
* @param outputColumns The original output columns of the input query plan. The optimizer may not
* preserve the output column's names' case, so we need this parameter
* instead of `data.output`.
* @param outputColumnNames The original output column names of the input query plan. The
* optimizer may not preserve the output column's names' case, so we need
* this parameter instead of `data.output`.
* @param physicalPlan The physical plan of the input query plan. We should run the writing
* command with this physical plan instead of creating a new physical plan,
* so that the metrics can be correctly linked to the given physical plan and
@@ -471,8 +472,9 @@
def writeAndRead(
mode: SaveMode,
data: LogicalPlan,
outputColumns: Seq[Attribute],
outputColumnNames: Seq[String],
physicalPlan: SparkPlan): BaseRelation = {
val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
throw new AnalysisException("Cannot save interval data type into external storage.")
}
@@ -495,7 +497,9 @@
s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
}
}
val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns)
val resolved = cmd.copy(
partitionColumns = resolvedPartCols,
outputColumnNames = outputColumnNames)
resolved.run(sparkSession, physicalPlan)
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
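The scaladoc above is the crux of this change: analysis is case-insensitive, but the schema that gets written should keep the analyzed names. A hedged end-to-end sketch of that behavior, mirroring the tests added later in this diff (the local SparkSession setup and table names are illustrative, not taken from the PR):

import org.apache.spark.sql.SparkSession

object CasePreservationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()

    // Source table with a lower-case column name.
    spark.range(10).toDF("id").write.format("parquet").saveAsTable("tbl")

    // The SELECT references the column as upper-case "ID"; analysis resolves it
    // case-insensitively, and with this patch the written schema keeps "ID".
    spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM tbl")
    spark.table("tbl2").printSchema() // expected: ID: long

    spark.stop()
  }
}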
@@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
case CreateTable(tableDesc, mode, Some(query))
if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema))
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output)
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name))

case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
@@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
mode,
table,
Some(t.location),
actualQuery.output)
actualQuery.output.map(_.name))
}
}

@@ -56,14 +56,14 @@ case class InsertIntoHadoopFsRelationCommand(
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumns: Seq[Attribute])
outputColumnNames: Seq[String])
extends DataWritingCommand {
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

[Member] Line 66: query.schema should be DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames).

[Member, author] Oh, then we can use this method instead:

def checkColumnNameDuplication(
    columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkSchemaColumnNameDuplication(
query.schema,
SchemaUtils.checkColumnNameDuplication(
outputColumnNames,
s"when inserting into $outputPath",
sparkSession.sessionState.conf.caseSensitiveAnalysis)

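The duplicate check above now runs directly on the name list instead of query.schema. A simplified, standalone sketch of that kind of check (not the real SchemaUtils implementation; it only illustrates honoring case sensitivity when looking for duplicates):

object DuplicateNameCheckSketch {
  def checkColumnNameDuplication(
      names: Seq[String], context: String, caseSensitive: Boolean): Unit = {
    // Normalize names first when the analysis is case-insensitive.
    val normalized = if (caseSensitive) names else names.map(_.toLowerCase)
    val duplicates = normalized.groupBy(identity).collect {
      case (name, group) if group.size > 1 => name
    }
    if (duplicates.nonEmpty) {
      throw new IllegalArgumentException(
        s"Found duplicate column(s) $context: ${duplicates.mkString(", ")}")
    }
  }

  def main(args: Array[String]): Unit = {
    checkColumnNameDuplication(Seq("id", "ID"), "when inserting into /tmp/out", caseSensitive = true)  // passes
    checkColumnNameDuplication(Seq("id", "ID"), "when inserting into /tmp/out", caseSensitive = false) // throws
  }
}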
@@ -805,6 +805,80 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
}
}

test("Insert overwrite table command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
[Contributor] Why is toDF("id") required? Why not spark.range(10) alone?

[Member, author] This is trivial... As the column name id is case-sensitive and used below, I would like to show it explicitly.

[Contributor] "Case sensitive"? How so, since Spark SQL is case-insensitive by default?

[Contributor] I think @gengliangwang meant case preserving, which is the behavior we are testing against. spark.range(10).toDF("id") is the same as spark.range(10); it's just clearer to people who don't know that spark.range outputs a single column named "id".

df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2")
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Insert overwrite table command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " +
"BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS")
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1")
val identifier = TableIdentifier("tbl2")
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: basic") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).toDF("id")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1")
val identifier = TableIdentifier("tbl2")
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("Create table as select command should output correct schema: complex") {
withTable("tbl", "tbl2") {
withView("view1") {
val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3")
df.write.format("parquet").saveAsTable("tbl")
spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl")
spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " +
"CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1")
val identifier = TableIdentifier("tbl2")
val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString
val expectedSchema = StructType(Seq(
StructField("COL1", LongType, true),
StructField("COL3", IntegerType, true),
StructField("COL2", IntegerType, true)))
assert(spark.read.parquet(location).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), df)
}
}
}

test("use Spark jobs to list files") {
withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") {
withTempDir { dir =>
@@ -149,22 +149,22 @@ object HiveAnalysis extends Rule[LogicalPlan] {
case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists)
if DDLUtils.isHiveTable(r.tableMeta) =>
InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite,
ifPartitionNotExists, query.output)
ifPartitionNotExists, query.output.map(_.name))

case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc)
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)

case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
DDLUtils.checkDataColNames(tableDesc)
CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode)
CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode)

case InsertIntoDir(isLocal, storage, provider, child, overwrite)
if DDLUtils.isHiveTable(provider) =>
val outputPath = new Path(storage.locationUri.get)
if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath)

InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output)
InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name))
}
}

@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand
case class CreateHiveTableAsSelectCommand(
tableDesc: CatalogTable,
query: LogicalPlan,
outputColumns: Seq[Attribute],
outputColumnNames: Seq[String],
mode: SaveMode)
extends DataWritingCommand {

@@ -63,13 +63,14 @@ case class CreateHiveTableAsSelectCommand(
query,
overwrite = false,
ifPartitionNotExists = false,
outputColumns = outputColumns).run(sparkSession, child)
outputColumnNames = outputColumnNames).run(sparkSession, child)
[Contributor] Can you remove one outputColumnNames?

} else {
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
// processing.
assert(tableDesc.schema.isEmpty)
catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false)
val schema = DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames)
catalog.createTable(tableDesc.copy(schema = schema), ignoreIfExists = false)
[Member, author] The schema naming needs to be consistent with outputColumnNames here.

try {
// Read back the metadata of the table which was created just now.
Expand All @@ -82,7 +83,7 @@ case class CreateHiveTableAsSelectCommand(
query,
overwrite = true,
ifPartitionNotExists = false,
outputColumns = outputColumns).run(sparkSession, child)
outputColumnNames = outputColumnNames).run(sparkSession, child)
[Contributor] Why is this duplication needed here?

[Contributor] What's the duplication?

[Contributor] outputColumnNames themselves. Specifying outputColumnNames as the name of the property to set using outputColumnNames does nothing but introduce a duplication. If you removed one outputColumnNames, comprehension would not suffer, would it?

[Contributor] I feel it's better to specify parameters by name if the previous parameter is already specified by name, e.g. ifPartitionNotExists = false.

} catch {
case NonFatal(e) =>
// drop the created table.
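For the CTAS path above, the created table keeps the query's field types but takes the analyzed names. A minimal sketch of that renaming using only the public StructType API (the helper mirrors logicalPlanSchemaWithNames but operates on a StructType directly; the names and values are made up):

import org.apache.spark.sql.types.{LongType, StructField, StructType}

object SchemaWithNamesSketch {
  def schemaWithNames(schema: StructType, names: Seq[String]): StructType = {
    assert(schema.length == names.length,
      "The length of provided names doesn't match the length of query schema.")
    StructType(schema.zip(names).map { case (field, outputName) =>
      field.copy(name = outputName)
    })
  }

  def main(args: Array[String]): Unit = {
    // The optimized schema has lower-case "id"; the analyzed name was "ID".
    val optimizedSchema = StructType(Seq(StructField("id", LongType)))
    println(schemaWithNames(optimizedSchema, Seq("ID")))
    // StructType(StructField(ID,LongType,true))
  }
}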
@@ -57,7 +57,7 @@ case class InsertIntoHiveDirCommand(
storage: CatalogStorageFormat,
query: LogicalPlan,
overwrite: Boolean,
outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
outputColumnNames: Seq[String]) extends SaveAsHiveFile {

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
assert(storage.locationUri.nonEmpty)
@@ -69,7 +69,7 @@ case class InsertIntoHiveTable(
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean,
outputColumns: Seq[Attribute]) extends SaveAsHiveFile {
outputColumnNames: Seq[String]) extends SaveAsHiveFile {
[Member] For better test coverage, can you add tests for hive tables?

[Member, author] No problem 👍

[Member] thanks!

/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
@@ -754,6 +754,54 @@ class HiveDDLSuite
}
}

test("Insert overwrite Hive table should output correct schema") {
withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") {
withTable("tbl", "tbl2") {
withView("view1") {
spark.sql("CREATE TABLE tbl(id long)")
spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4")
[Contributor] I might be missing something, but why does this test use SQL statements and not the DataFrameWriter API, e.g. Seq(4).toDF("id").write.mode(SaveMode.Overwrite).saveAsTable("tbl")?

[Contributor] We can, but it's important to keep the code style consistent with the existing code in the same file. In this test suite, it seems SQL statements are preferred.

spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
withTempPath { path =>
sql(
s"""
|CREATE TABLE tbl2(ID long) USING hive
|OPTIONS(fileFormat 'parquet')
|LOCATION '${path.toURI}'
""".stripMargin)
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(path.toString).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), Seq(Row(4)))
}
}
}
}
}

test("Create Hive table as select should output correct schema") {
withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") {
withTable("tbl", "tbl2") {
withView("view1") {
spark.sql("CREATE TABLE tbl(id long)")
spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4")
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
withTempPath { path =>
sql(
s"""
|CREATE TABLE tbl2 USING hive
|OPTIONS(fileFormat 'parquet')
|LOCATION '${path.toURI}'
|AS SELECT ID FROM view1
""".stripMargin)
val expectedSchema = StructType(Seq(StructField("ID", LongType, true)))
assert(spark.read.parquet(path.toString).schema == expectedSchema)
checkAnswer(spark.table("tbl2"), Seq(Row(4)))
}
}
}
}
}

test("alter table partition - storage information") {
sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)")
sql("INSERT OVERWRITE TABLE boxes PARTITION (width=4) SELECT 4, 4")