diff --git a/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala
index 438e6b3..071d622 100644
--- a/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala
+++ b/src/main/scala/com/qubole/spark/hiveacid/writer/hive/HiveAcidWriter.scala
@@ -100,16 +100,24 @@ abstract private[writer] class HiveAcidWriter(val options: WriterOptions,
   lazy protected val taskId: Int =
     Utilities.getTaskIdFromFilename(TaskContext.get.taskAttemptId().toString).toInt
 
-  protected def getOrCreateWriter(partitionRow: InternalRow, acidBucketId: Int): Any = {
+  private val rootPath = new Path(HiveAcidOptions.rootPath)
+
+  private val partitionPathCache: mutable.Map[String, Path] =
+    scala.collection.mutable.Map[String, Path]()
+
+  protected def getOrCreateWriter(partitionRow: InternalRow, acidBucketId: Int): Any = {
     val partitionBasePath = if (options.partitionColumns.isEmpty) {
-      new Path(HiveAcidOptions.rootPath)
+      rootPath
     } else {
-      val path = getPartitionPath(partitionRow)
-      partitionsTouchedSet.add(PartitioningUtils.parsePathFragment(path))
-      new Path(HiveAcidOptions.rootPath, path)
+      val pathString = getPartitionPath(partitionRow)
+      // use a cache so that we don't create a new Path object on
+      // every getOrCreateWriter call, since getOrCreateWriter
+      // is called for every InternalRow that is processed
+      partitionPathCache.getOrElseUpdate(pathString, {
+        partitionsTouchedSet.add(PartitioningUtils.parsePathFragment(pathString))
+        new Path(rootPath, pathString)
+      })
     }
-
     writers.getOrElseUpdate((partitionBasePath.toUri.toString, taskId, acidBucketId),
       createWriter(partitionBasePath, acidBucketId))
   }
@@ -177,8 +185,10 @@ private[writer] class HiveAcidFullAcidWriter(options: WriterOptions,
       throw new RuntimeException(s"Invalid write operation $x")
   }
 
-  override protected def createWriter(path: Path, acidBucketId: Int): Any = {
+  val taskToBucketId = Utilities.getTaskIdFromFilename(TaskContext.getPartitionId().toString)
+    .toInt
+
+  override protected def createWriter(path: Path, acidBucketId: Int): Any = {
     val tableDesc = HiveAcidOptions.getFileSinkDesc.getTableInfo
 
     val recordUpdater = HiveFileFormatUtils.getAcidRecordUpdater(
@@ -257,8 +267,7 @@ private[writer] class HiveAcidFullAcidWriter(options: WriterOptions,
     } else {
      options.operationType match {
        case HiveAcidOperation.INSERT_INTO | HiveAcidOperation.INSERT_OVERWRITE =>
-         Utilities.getTaskIdFromFilename(TaskContext.getPartitionId().toString)
-           .toInt
+         taskToBucketId
        case HiveAcidOperation.DELETE | HiveAcidOperation.UPDATE =>
          val rowID = dataRow.get(rowIdColNum, options.rowIDSchema)
          // FIXME: Currently hard coding codec as V1 and also bucket ordinal as 1.
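The first hunk memoizes the per-partition `Path` with `mutable.Map.getOrElseUpdate`, so the object is built (and the partition recorded in `partitionsTouchedSet`) only the first time a partition value is seen, not on every row. Below is a minimal, self-contained sketch of that pattern; `PathCacheSketch`, `FakePath`, and `getOrCreate` are hypothetical names standing in for the real classes, not part of the patch.

```scala
import scala.collection.mutable

// Sketch of the memoization pattern used in getOrCreateWriter above.
// FakePath stands in for org.apache.hadoop.fs.Path, which the real patch caches.
object PathCacheSketch {
  final case class FakePath(parent: String, child: String)

  private val cache = mutable.Map.empty[String, FakePath]

  def getOrCreate(root: String, fragment: String): FakePath =
    // the second argument is by-name, so it only runs on a cache miss
    cache.getOrElseUpdate(fragment, FakePath(root, fragment))

  def main(args: Array[String]): Unit = {
    val first  = getOrCreate("/warehouse/table", "p=1")
    val second = getOrCreate("/warehouse/table", "p=1")
    println(first eq second) // true: the cached instance is reused on the hot path
  }
}
```

The same reasoning motivates hoisting `rootPath` and `taskToBucketId` into vals in the hunks above: both are constant for the task, so computing them once keeps the per-row path free of repeated parsing and allocation.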
@@ -279,9 +288,11 @@ private[writer] class HiveAcidFullAcidWriter(options: WriterOptions,
     val partitionColRow = getPartitionValues(row)
     val dataColRow = getDataValues(row)
 
+    val bucketId = getBucketID(dataColRow)
+
     // Get the recordWriter for this partitionedRow
     val recordUpdater =
-      getOrCreateWriter(partitionColRow, getBucketID(dataColRow)).asInstanceOf[RecordUpdater]
+      getOrCreateWriter(partitionColRow, bucketId).asInstanceOf[RecordUpdater]
 
     val recordValue = sparkHiveRowConverter.toHiveRow(dataColRow, hiveRow)
 
@@ -458,10 +469,11 @@ private[hive] class SparkHiveRowConverter(options: WriterOptions,
     serializer.serialize(hiveRow, objectInspector.asInstanceOf[ObjectInspector])
   }
 
-  def toHiveRow(sparkRow: InternalRow, hiveRow: Array[Any]): Array[Any] = {
-    val dataTypes = options.dataColumns.map(_.dataType).toArray
-    val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) }
+  private val dataTypes = options.dataColumns.map(_.dataType).toArray
+  private val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt) }
+
+  def toHiveRow(sparkRow: InternalRow, hiveRow: Array[Any]): Array[Any] = {
     var i = 0
     while (i < fieldOIs.length) {
       hiveRow(i) = if (sparkRow.isNullAt(i)) null else wrappers(i)(sparkRow.get(i, dataTypes(i)))
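The last hunk hoists `dataTypes` and `wrappers` out of `toHiveRow`, so the per-column wrapper functions are built once per converter instance rather than once per row. A rough sketch of the same shape is below; `RowConverterSketch`, the string-keyed column types, and the toy wrappers are illustrative assumptions, not the project's object-inspector machinery.

```scala
// Hypothetical converter illustrating the hoisting applied to SparkHiveRowConverter:
// the per-column wrapper functions are computed once at construction time,
// leaving only a tight, allocation-free loop in the per-row convert() call.
class RowConverterSketch(columnTypes: Array[String]) {
  // built once per converter instance, not once per row
  private val wrappers: Array[Any => Any] = columnTypes.map {
    case "int"    => (v: Any) => v.asInstanceOf[Int]
    case "string" => (v: Any) => v.toString
    case _        => (v: Any) => v
  }

  def convert(row: Array[Any], out: Array[Any]): Array[Any] = {
    var i = 0
    while (i < wrappers.length) {
      out(i) = if (row(i) == null) null else wrappers(i)(row(i))
      i += 1
    }
    out
  }
}
```

Only the wrapper construction differs in the real class; the structural point is that nothing per-column is rebuilt inside the row loop, which is also why `getBucketID(dataColRow)` is pulled into a local `bucketId` before the writer lookup.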