diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
index f9bc76de6ef94..1c3699058e462 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala
@@ -79,6 +79,20 @@ private[spark] object Utils {
     builder.result()
   }
 
+  /**
+   * Same function as `keys.zipWithIndex.toMap`, but has perf gain.
+   */
+  def toMapWithIndex[K](keys: Iterable[K]): Map[K, Int] = {
+    val builder = immutable.Map.newBuilder[K, Int]
+    val keyIter = keys.iterator
+    var idx = 0
+    while (keyIter.hasNext) {
+      builder += (keyIter.next(), idx).asInstanceOf[(K, Int)]
+      idx = idx + 1
+    }
+    builder.result()
+  }
+
   /**
    * Same function as `keys.zip(values).toMap.asJava`, but has perf gain.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 576c771d83bec..f11cd865843d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.attribute
 import scala.annotation.varargs
 
 import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField}
+import org.apache.spark.util.collection.Utils
 
 /**
  * Abstract class for ML attributes.
@@ -338,7 +339,7 @@ class NominalAttribute private[ml] (
   override def isNominal: Boolean = true
 
   private lazy val valueToIndex: Map[String, Int] = {
-    values.map(_.zipWithIndex.toMap).getOrElse(Map.empty)
+    values.map(Utils.toMapWithIndex(_)).getOrElse(Map.empty)
   }
 
   /** Index of a specific value. */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index fd07073c306e3..ca0340949fe46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils}
 
 /**
  * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
@@ -305,7 +305,7 @@ class CountVectorizerModel(
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
     if (broadcastDict.isEmpty) {
-      val dict = vocabulary.zipWithIndex.toMap
+      val dict = Utils.toMapWithIndex(vocabulary)
       broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
     }
     val dictBr = broadcastDict.get
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 0e571ad508ff0..f36e98046afa6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -35,7 +35,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{StructField, StructType}
-import org.apache.spark.util.collection.OpenHashSet
+import org.apache.spark.util.collection.{OpenHashSet, Utils}
 
 /** Private trait for params for VectorIndexer and VectorIndexerModel */
 private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
@@ -235,7 +235,7 @@ object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
         if (zeroExists) {
           sortedFeatureValues = 0.0 +: sortedFeatureValues
         }
-        val categoryMap: Map[Double, Int] = sortedFeatureValues.zipWithIndex.toMap
+        val categoryMap: Map[Double, Int] = Utils.toMapWithIndex(sortedFeatureValues)
         (featureIndex, categoryMap)
       }.toMap
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 4464cfe2c0149..97f277d53ca9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -39,6 +39,7 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.{Utils => CUtils}
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
@@ -470,7 +471,7 @@ class Word2Vec extends Serializable with Logging {
     newSentences.unpersist()
 
     val wordArray = vocab.map(_.word)
-    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
+    new Word2VecModel(CUtils.toMapWithIndex(wordArray), syn0Global)
   }
 
   /**
@@ -639,7 +640,7 @@ object Word2VecModel extends Loader[Word2VecModel] {
 
   private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
-    model.keys.zipWithIndex.toMap
+    CUtils.toMapWithIndex(model.keys)
   }
 
   private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 3531822e77b78..ecdc28dea37fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.Utils
 
 /**
  * Model trained by [[FPGrowth]], which holds frequent itemsets.
@@ -269,7 +270,7 @@ class FPGrowth private[spark] (
       minCount: Long,
       freqItems: Array[Item],
       partitioner: Partitioner): RDD[FreqItemset[Item]] = {
-    val itemToRank = freqItems.zipWithIndex.toMap
+    val itemToRank = Utils.toMapWithIndex(freqItems)
     data.flatMap { transaction =>
       genCondTransactions(transaction, itemToRank, partitioner)
     }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 6f71801814398..7c023bcfa72a4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.Utils
 
 /**
  * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
@@ -147,7 +148,7 @@ class PrefixSpan private (
     logInfo(s"number of frequent items: ${freqItems.length}")
 
     // Keep only frequent items from input sequences and convert them to internal storage.
-    val itemToInt = freqItems.zipWithIndex.toMap
+    val itemToInt = Utils.toMapWithIndex(freqItems)
     val dataInternalRepr = toDatabaseInternalRepr(data, itemToInt)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 9f0832804f27f..9c761824134c3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils}
 
 /**
  * Conduct the chi-squared test for the input RDDs using the specified method.
@@ -181,14 +181,14 @@ private[spark] object ChiSqTest extends Logging {
       counts: Map[(Double, Double), Long],
       methodName: String,
       col: Int): ChiSqTestResult = {
-    val label2Index = counts.iterator.map(_._1._1).toArray.distinct.sorted.zipWithIndex.toMap
+    val label2Index = Utils.toMapWithIndex(counts.iterator.map(_._1._1).toArray.distinct.sorted)
     val numLabels = label2Index.size
     if (numLabels > maxCategories) {
       throw new SparkException(s"Chi-square test expect factors (categorical values) but "
         + s"found more than $maxCategories distinct label values.")
     }
 
-    val value2Index = counts.iterator.map(_._1._2).toArray.distinct.sorted.zipWithIndex.toMap
+    val value2Index = Utils.toMapWithIndex(counts.iterator.map(_._1._2).toArray.distinct.sorted)
     val numValues = value2Index.size
     if (numValues > maxCategories) {
       throw new SparkException(s"Chi-square test expect factors (categorical values) but "
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 1baa5c20ba4ff..543bf7550cc31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.Utils
 import org.apache.spark.util.random.RandomSampler
 
 /**
@@ -1235,7 +1236,7 @@ object Expand {
       groupByAttrs: Seq[Attribute],
       gid: Attribute,
       child: LogicalPlan): Expand = {
-    val attrMap = groupByAttrs.zipWithIndex.toMap
+    val attrMap = Utils.toMapWithIndex(groupByAttrs)
 
     val hasDuplicateGroupingSets = groupingSetsAttrs.size !=
       groupingSetsAttrs.map(_.map(_.exprId).toSet).distinct.size
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 92735c5831153..d5f32aac55a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.collection.Utils
 
 /**
  * A [[StructType]] object can be constructed by
@@ -117,7 +118,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
   private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
-  private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
+  private lazy val nameToIndex: Map[String, Int] = Utils.toMapWithIndex(fieldNames)
 
   override def equals(that: Any): Boolean = {
     that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index e38416bfc4e25..9101e7d0ac525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.Utils
 
 /**
  * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some
@@ -207,7 +208,7 @@ private[parquet] class ParquetRowConverter(
   private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
     // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false
     // to prevent throwing IllegalArgumentException when searching catalyst type's field index
-    def nameToIndex: Map[String, Int] = catalystType.fieldNames.zipWithIndex.toMap
+    def nameToIndex: Map[String, Int] = Utils.toMapWithIndex(catalystType.fieldNames)
 
     val catalystFieldIdxByName = if (SQLConf.get.caseSensitiveAnalysis) {
       nameToIndex
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index f32f6620265b9..c9d3b99990830 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.Utils
 
 object StatFunctions extends Logging {
 
@@ -198,7 +199,7 @@ object StatFunctions extends Logging {
     }
     // get the distinct sorted values of column 2, so that we can make them the column names
     val distinctCol2: Map[Any, Int] =
-      counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap
+      Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted)
     val columnSize = distinctCol2.size
     require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
       s"exceed 1e4. Currently $columnSize")
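
For reference, the new `Utils.toMapWithIndex` helper is a drop-in replacement for the `keys.zipWithIndex.toMap` idiom: it walks the keys once with an iterator and feeds a `Map` builder directly, rather than first materializing an intermediate collection of `(key, index)` tuples. Below is a minimal sketch of the intended equivalence; the object name and sample data are illustrative only and not part of the patch, and the sketch sits inside the `org.apache.spark` namespace because `Utils` is `private[spark]`.

```scala
package org.apache.spark.util.collection

// Illustrative sketch only; not part of the patch. Utils is private[spark],
// so this example lives inside the org.apache.spark namespace, as a test would.
object ToMapWithIndexExample {
  def main(args: Array[String]): Unit = {
    val vocabulary = Array("spark", "scala", "ml", "sql")

    // Old idiom: builds an intermediate collection of (key, index) pairs first.
    val viaZip: Map[String, Int] = vocabulary.zipWithIndex.toMap

    // New helper: single pass over the keys, with the index tracked by a counter.
    val viaHelper: Map[String, Int] = Utils.toMapWithIndex(vocabulary)

    // Both produce the same mapping: "spark" -> 0, "scala" -> 1, "ml" -> 2, "sql" -> 3.
    assert(viaZip == viaHelper)
  }
}
```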