diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 12c667e6e92d..5b617b3ea9d0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -87,33 +87,35 @@ case class InsertIntoHiveTable(
 
     // Note that this function is executed on executor side
    def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-      val serializer = newSerializer(fileSinkConf.getTableInfo)
-      val standardOI = ObjectInspectorUtils
-        .getStandardObjectInspector(
-          fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
-          ObjectInspectorCopyOption.JAVA)
-        .asInstanceOf[StructObjectInspector]
-
-      val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
-      val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray
-      val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)}
-      val outputData = new Array[Any](fieldOIs.length)
-
-      writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
-
-      iterator.foreach { row =>
-        var i = 0
-        while (i < fieldOIs.length) {
-          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
-          i += 1
+      if (iterator.hasNext) {
+        val serializer = newSerializer(fileSinkConf.getTableInfo)
+        val standardOI = ObjectInspectorUtils
+          .getStandardObjectInspector(
+            fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
+            ObjectInspectorCopyOption.JAVA)
+          .asInstanceOf[StructObjectInspector]
+
+        val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
+        val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray
+        val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)}
+        val outputData = new Array[Any](fieldOIs.length)
+
+        writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
+
+        iterator.foreach { row =>
+          var i = 0
+          while (i < fieldOIs.length) {
+            outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+            i += 1
+          }
+
+          writerContainer
+            .getLocalFileWriter(row, table.schema)
+            .write(serializer.serialize(outputData, standardOI))
         }
 
-        writerContainer
-          .getLocalFileWriter(row, table.schema)
-          .write(serializer.serialize(outputData, standardOI))
+        writerContainer.close()
       }
-
-      writerContainer.close()
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index d33e81227db8..64db73cbd3b9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -262,4 +262,30 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
 
     sql("DROP TABLE table_with_partition")
   }
+
+  test("SPARK-10216: Avoid creating empty files during overwrite into Hive table with group by query") {
+    val testData = TestHive.sparkContext.parallelize(
+      (1 to 2).map(i => TestData(i, i.toString))).toDF()
+    testData.registerTempTable("testData")
+
+    val tmpDir = Utils.createTempDir()
+    sql(
+      s"""
+         |CREATE TABLE table1(key int,value string)
+         |location '${tmpDir.toURI.toString}'
+       """.stripMargin)
+    sql(
+      """
+        |INSERT OVERWRITE TABLE table1
+        |SELECT count(key), value FROM testData GROUP BY value
+      """.stripMargin)
+    def listFiles(path: File): List[File] = {
+      val file = path.listFiles()
+      file.filter { e => e.isFile && !e.getName.endsWith(".crc")}.toList
+    }
+    val fileList = listFiles(tmpDir)
+    assert(fileList.filter(e => e.length > 0).sortBy(_.getName) === fileList.sortBy(_.getName))
+
+    sql("DROP TABLE table1")
+  }
 }
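
Reviewer note (not part of the patch): the behavior change in `InsertIntoHiveTable.scala` is entirely the `if (iterator.hasNext)` guard in `writeToFile`. The serializer, the standard object inspector, the column wrappers, and `writerContainer.executorSideSetup(...)` are now only invoked when the task's partition actually contains rows, so a task that receives an empty partition (for example, when a `GROUP BY` leaves it with no groups) never touches the writer container and therefore never leaves a zero-byte part file under the table location. The sketch below shows the same guard in isolation; `EmptyPartitionGuard` and `writePartition` are made-up illustrative names, not Spark or Hive APIs.

```scala
import java.io.{File, PrintWriter}

// Minimal, self-contained sketch of the pattern the patch applies: open a
// per-partition output file only when the partition's iterator has rows.
object EmptyPartitionGuard {

  // Hypothetical stand-in for the per-task write path (writeToFile plus the
  // writer container in the real code).
  def writePartition(partitionId: Int, rows: Iterator[String], outputDir: File): Unit = {
    // The guard: no rows means no writer is opened, hence no zero-byte part file.
    if (rows.hasNext) {
      val writer = new PrintWriter(new File(outputDir, s"part-$partitionId"))
      try {
        rows.foreach(line => writer.println(line))
      } finally {
        writer.close()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val dir = new File(System.getProperty("java.io.tmpdir"), "empty-partition-guard-demo")
    dir.mkdirs()
    writePartition(0, Iterator("a", "b"), dir) // creates part-0
    writePartition(1, Iterator.empty, dir)     // creates nothing
    // Only non-empty files remain, mirroring the assertion in the new test.
    dir.listFiles().filter(_.isFile).foreach(f => println(s"${f.getName}: ${f.length} bytes"))
  }
}
```

The new test in `InsertIntoHiveTableSuite` checks the same property end to end: after the `INSERT OVERWRITE ... GROUP BY` statement, every non-`.crc` file under the table's location must have a non-zero length, which is what the final `assert` on `fileList` verifies.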