diff --git a/build.sbt b/build.sbt index bb3c22ef5..e729597fe 100644 --- a/build.sbt +++ b/build.sbt @@ -20,6 +20,7 @@ lazy val qbeastSpark = (project in file(".")) sparkSql % Provided, hadoopClient % Provided, deltaCore % Provided, + sparkml % Provided, amazonAws % Test, hadoopCommons % Test, hadoopAws % Test), diff --git a/core/src/main/scala/io/qbeast/core/model/ColumnsToIndexSelector.scala b/core/src/main/scala/io/qbeast/core/model/ColumnsToIndexSelector.scala new file mode 100644 index 000000000..508b8ec33 --- /dev/null +++ b/core/src/main/scala/io/qbeast/core/model/ColumnsToIndexSelector.scala @@ -0,0 +1,36 @@ +package io.qbeast.core.model + +/** + * ColumnsToIndexSelector interface to automatically select which columns to index. + * @tparam DATA + * the data to index + */ +trait ColumnsToIndexSelector[DATA] { + + /** + * The maximum number of columns to index. + * @return + */ + def MAX_COLUMNS_TO_INDEX: Int + + /** + * Selects the columns to index given a DataFrame + * @param data + * the data to index + * @return + */ + def selectColumnsToIndex(data: DATA): Seq[String] = + selectColumnsToIndex(data, MAX_COLUMNS_TO_INDEX) + + /** + * Selects the columns to index with a given number of columns to index + * @param data + * the data to index + * @param numColumnsToIndex + * the number of columns to index + * @return + * A sequence with the names of the columns to index + */ + def selectColumnsToIndex(data: DATA, numColumnsToIndex: Int): Seq[String] + +} diff --git a/core/src/main/scala/io/qbeast/core/model/QbeastCoreContext.scala b/core/src/main/scala/io/qbeast/core/model/QbeastCoreContext.scala index 9f4c6fb73..11b3eaa92 100644 --- a/core/src/main/scala/io/qbeast/core/model/QbeastCoreContext.scala +++ b/core/src/main/scala/io/qbeast/core/model/QbeastCoreContext.scala @@ -21,6 +21,7 @@ trait QbeastCoreContext[DATA, DataSchema, QbeastOptions, FileDescriptor] { def indexManager: IndexManager[DATA] def queryManager[QUERY: ClassTag]: QueryManager[QUERY, DATA] def revisionBuilder: RevisionFactory[DataSchema, QbeastOptions] + def columnSelector: ColumnsToIndexSelector[DATA] def keeper: Keeper } diff --git a/docs/AdvancedConfiguration.md b/docs/AdvancedConfiguration.md index ded3f3e5f..c0373f6fd 100644 --- a/docs/AdvancedConfiguration.md +++ b/docs/AdvancedConfiguration.md @@ -65,6 +65,22 @@ You can specify different advanced options to the columns to index: df.write.format("qbeast").option("columnsToIndex", "column:type,column2:type...") ``` +## Automatic Column Selection + +To **avoid specifying the `columnsToIndex`**, you can enable auto indexer through the Spark Configuration: + +```shell +--conf spark.qbeast.index.columnsToIndex.auto=true \ +--conf spark.qbeast.index.columnsToIndex.auto.max=10 +``` +And write the DataFrame without any extra option: + +```scala +df.write.format("qbeast").save("path/to/table") +``` + +Read more about it in the [Columns to Index selector](ColumnsToIndexSelector.md) section. + ## CubeSize CubeSize option lets you specify the maximum size of the cube, in number of records. By default, it's set to 5M. diff --git a/docs/ColumnsToIndexSelector.md b/docs/ColumnsToIndexSelector.md new file mode 100644 index 000000000..b02a8dfa7 --- /dev/null +++ b/docs/ColumnsToIndexSelector.md @@ -0,0 +1,73 @@ +## Columns To Index Selector + +Qbeast Format organizes the records using a multidimensional index. This index is built on a subset of the columns in the table. 
Since version `1.0.0`, **the columns to index can be selected automatically, by enabling the automatic columns-to-index selector, or manually by the user**.
+
+If you want to forget about the data distribution and let Qbeast handle the whole indexing pre-process, there is no need to specify the `columnsToIndex` option when writing the **DataFrame**.
+
+You only need to **enable the Columns To Index Selector in the `SparkConf`**:
+
+```shell
+--conf spark.qbeast.index.columnsToIndex.auto=true \
+--conf spark.qbeast.index.columnsToIndex.auto.max=10
+```
+
+And **write the DataFrame as usual**:
+
+```scala
+df.write.format("qbeast").save("path/to/table")
+```
+
+Or use SQL:
+
+```scala
+spark.sql("CREATE TABLE table_name USING qbeast LOCATION 'path/to/table'")
+```
+
+### Interface
+
+The `ColumnsToIndexSelector` is an interface that can be implemented by different classes. The interface is defined as follows:
+
+```scala
+trait ColumnsToIndexSelector[DATA] {
+
+  /**
+   * The maximum number of columns to index.
+   * @return
+   */
+  def MAX_COLUMNS_TO_INDEX: Int
+
+  /**
+   * Selects the columns to index given a DataFrame
+   * @param data
+   *   the data to index
+   * @return
+   */
+  def selectColumnsToIndex(data: DATA): Seq[String] =
+    selectColumnsToIndex(data, MAX_COLUMNS_TO_INDEX)
+
+  /**
+   * Selects the columns to index with a given number of columns to index
+   * @param data
+   *   the data to index
+   * @param numColumnsToIndex
+   *   the number of columns to index
+   * @return
+   *   A sequence with the names of the columns to index
+   */
+  def selectColumnsToIndex(data: DATA, numColumnsToIndex: Int): Seq[String]
+
+}
+```
+
+### SparkColumnsToIndexSelector
+
+`SparkColumnsToIndexSelector` is the first implementation of the `ColumnsToIndexSelector` interface. It is designed to work with Apache Spark DataFrames and **automatically selects the columns to index based on the correlation between them**.
+
+The steps are the following (a simplified sketch follows the list):
+
+1. **Convert Timestamp columns** to Unix timestamps and update the DataFrame.
+2. **Initialize a VectorAssembler** for each column. String columns are first transformed into numeric indices with a `StringIndexer`.
+3. **Combine the features** from the `VectorAssembler` stages into a single vector column.
+4. Calculate the **correlation matrix**.
+5. Calculate the **average absolute correlation** for each column.
+6. Get the **top N columns with the lowest average correlation**.
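+
+A minimal, illustrative sketch of the selection logic (steps 3 to 6) is shown below. It is **not** the exact implementation: it assumes the input columns are already numeric, whereas `SparkColumnsToIndexSelector` also handles String and Timestamp columns as described above.
+
+```scala
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.linalg.Matrix
+import org.apache.spark.ml.stat.Correlation
+import org.apache.spark.sql.DataFrame
+
+// Illustrative only: among numeric columns, pick the n with the lowest
+// average absolute Pearson correlation against the other columns.
+def selectLeastCorrelated(df: DataFrame, numericCols: Array[String], n: Int): Seq[String] = {
+  // Combine the input columns into a single vector column
+  val assembled = new VectorAssembler()
+    .setInputCols(numericCols)
+    .setOutputCol("features")
+    .setHandleInvalid("keep")
+    .transform(df)
+
+  // Correlation.corr returns a single-row DataFrame holding the correlation Matrix
+  val corr = Correlation.corr(assembled, "features").head.getAs[Matrix](0)
+
+  // Average absolute correlation per column (the matrix is symmetric)
+  val avgAbsCorr = (0 until corr.numCols).map { i =>
+    (0 until corr.numRows).map(j => math.abs(corr(j, i))).sum / corr.numRows
+  }
+
+  // Keep the n columns that are least correlated with the rest
+  avgAbsCorr.zipWithIndex
+    .sortBy { case (c, _) => c }
+    .take(n)
+    .map { case (_, i) => numericCols(i) }
+}
+```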
\ No newline at end of file diff --git a/project/Dependencies.scala b/project/Dependencies.scala index aaa1ca117..2f267cf4f 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -20,4 +20,5 @@ object Dependencies { val hadoopCommons = "org.apache.hadoop" % "hadoop-common" % hadoopVersion val hadoopAws = "org.apache.hadoop" % "hadoop-aws" % hadoopVersion val fasterxml = "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.12.0" + val sparkml = "org.apache.spark" %% "spark-mllib" % sparkVersion } diff --git a/src/main/scala/io/qbeast/context/QbeastContext.scala b/src/main/scala/io/qbeast/context/QbeastContext.scala index 51000d609..0016f5bc2 100644 --- a/src/main/scala/io/qbeast/context/QbeastContext.scala +++ b/src/main/scala/io/qbeast/context/QbeastContext.scala @@ -8,6 +8,7 @@ import io.qbeast.core.keeper.LocalKeeper import io.qbeast.core.model._ import io.qbeast.spark.delta.writer.RollupDataWriter import io.qbeast.spark.delta.SparkDeltaMetadataManager +import io.qbeast.spark.index.SparkColumnsToIndexSelector import io.qbeast.spark.index.SparkOTreeManager import io.qbeast.spark.index.SparkRevisionFactory import io.qbeast.spark.internal.QbeastOptions @@ -92,6 +93,8 @@ object QbeastContext override def revisionBuilder: RevisionFactory[StructType, QbeastOptions] = SparkRevisionFactory + override def columnSelector: ColumnsToIndexSelector[DataFrame] = SparkColumnsToIndexSelector + /** * Sets the unmanaged context. The specified context will not be disposed automatically at the * end of the Spark session. @@ -146,7 +149,8 @@ object QbeastContext indexManager, metadataManager, dataWriter, - revisionBuilder) + revisionBuilder, + columnSelector) private def destroyManaged(): Unit = this.synchronized { managedOption.foreach(_.keeper.stop()) diff --git a/src/main/scala/io/qbeast/spark/index/SparkColumnsToIndexSelector.scala b/src/main/scala/io/qbeast/spark/index/SparkColumnsToIndexSelector.scala new file mode 100644 index 000000000..4984a82af --- /dev/null +++ b/src/main/scala/io/qbeast/spark/index/SparkColumnsToIndexSelector.scala @@ -0,0 +1,146 @@ +/* + * Copyright 2021 Qbeast Analytics, S.L. + */ +package io.qbeast.spark.index + +import io.qbeast.core.model.ColumnsToIndexSelector +import org.apache.spark.ml.feature.OneHotEncoder +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.ml.linalg.Matrix +import org.apache.spark.ml.stat.Correlation +import org.apache.spark.ml.Pipeline +import org.apache.spark.qbeast.config.MAX_NUM_COLUMNS_TO_INDEX +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.unix_timestamp +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.TimestampType +import org.apache.spark.sql.DataFrame + +object SparkColumnsToIndexSelector extends ColumnsToIndexSelector[DataFrame] with Serializable { + + /** + * The maximum number of columns to index. 
+ * + * @return + */ + override def MAX_COLUMNS_TO_INDEX: Int = MAX_NUM_COLUMNS_TO_INDEX + + /** + * Adds unix timestamp columns to the DataFrame for the columns specified + * @param data + * @param inputCols + * @return + */ + private def withUnixTimestamp(data: DataFrame, inputCols: Seq[StructField]): DataFrame = { + val timestampColsTransformation = inputCols + .filter(_.dataType == TimestampType) + .map(c => (c.name, unix_timestamp(col(c.name)))) + .toMap + + data.withColumns(timestampColsTransformation) + } + + /** + * Adds preprocessing transformers to the DataFrame for the columns specified + * @param data + * the DataFrame + * @param inputCols + * the columns to preprocess + * @return + */ + protected def withPreprocessedPipeline( + data: DataFrame, + inputCols: Seq[StructField]): DataFrame = { + + val transformers = inputCols + .collect { + case column if column.dataType == StringType => + val colName = column.name + val indexer = new StringIndexer().setInputCol(colName).setOutputCol(s"${colName}_Index") + val encoder = + new OneHotEncoder().setInputCol(s"${colName}_Index").setOutputCol(s"${colName}_Vec") + Seq(indexer, encoder) + + case column => + val colName = column.name + Seq( + new VectorAssembler() + .setInputCols(Array(colName)) + .setOutputCol(s"${colName}_Vec") + .setHandleInvalid("keep")) + } + .flatten + .toArray + + val preprocessingPipeline = new Pipeline().setStages(transformers) + val preprocessingModel = preprocessingPipeline.fit(data) + val preprocessedData = preprocessingModel.transform(data) + + preprocessedData + } + + /** + * Selects the top N minimum absolute correlated columns + * @param data + * the DataFrame + * @param inputCols + * the columns to preprocess + * @param numCols + * the number of columns to return + * @return + */ + protected def selectTopNCorrelatedColumns( + data: DataFrame, + inputCols: Seq[StructField], + numCols: Int): Array[String] = { + + val inputVecCols = inputCols.map(_.name + "_Vec").toArray + + val assembler = new VectorAssembler() + .setInputCols(inputVecCols) + .setOutputCol("features") + .setHandleInvalid("keep") + + val vectorDf = assembler.transform(data) + + // Calculate the correlation matrix + val correlationMatrix: DataFrame = Correlation.corr(vectorDf, "features") + // Extract the correlation matrix as a Matrix + val corrArray = correlationMatrix.select("pearson(features)").head.getAs[Matrix](0) + + // Calculate the average absolute correlation for each column + val averageCorrelation = + corrArray.toArray.map(Math.abs).grouped(inputVecCols.length).toArray.head + + // Get the indices of columns with the lowest average correlation + val sortedIndices = averageCorrelation.zipWithIndex.sortBy { case (corr, _) => corr } + val selectedIndices = sortedIndices.take(numCols).map(_._2) + + val selectedCols = selectedIndices.map(inputCols(_).name) + selectedCols + + } + + override def selectColumnsToIndex(data: DataFrame, numColumnsToIndex: Int): Seq[String] = { + + // IF there's no data to write, we return all the columns to index + if (data.isEmpty) { + return data.columns.take(numColumnsToIndex) + } + + val inputCols = data.schema + // Add unix timestamp columns + val updatedData = withUnixTimestamp(data, inputCols) + // Add column transformers + val preprocessedPipeline = withPreprocessedPipeline(updatedData, inputCols) + // Calculate the top N minimum absolute correlated columns + val selectedColumns = + selectTopNCorrelatedColumns(preprocessedPipeline, inputCols, numColumnsToIndex) + + selectedColumns + + } + +} diff --git 
a/src/main/scala/io/qbeast/spark/internal/sources/QbeastDataSource.scala b/src/main/scala/io/qbeast/spark/internal/sources/QbeastDataSource.scala index 96b76954b..834bc180f 100644 --- a/src/main/scala/io/qbeast/spark/internal/sources/QbeastDataSource.scala +++ b/src/main/scala/io/qbeast/spark/internal/sources/QbeastDataSource.scala @@ -10,6 +10,7 @@ import io.qbeast.spark.internal.QbeastOptions import io.qbeast.spark.table.IndexedTableFactory import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.fs.Path +import org.apache.spark.qbeast.config.COLUMN_SELECTOR_ENABLED import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.Transform @@ -94,7 +95,7 @@ class QbeastDataSource private[sources] (private val tableFactory: IndexedTableF data: DataFrame): BaseRelation = { require( - parameters.contains("columnsToIndex") || mode == SaveMode.Append, + parameters.contains("columnsToIndex") || mode == SaveMode.Append || COLUMN_SELECTOR_ENABLED, throw AnalysisExceptionFactory.create("'columnsToIndex' is not specified")) val tableId = QbeastOptions.loadTableIDFromParameters(parameters) diff --git a/src/main/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalog.scala b/src/main/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalog.scala index 945570a37..4fa6a90e9 100644 --- a/src/main/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalog.scala +++ b/src/main/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalog.scala @@ -6,7 +6,6 @@ package io.qbeast.spark.internal.sources.catalog import io.qbeast.context.QbeastContext import io.qbeast.spark.internal.sources.v2.QbeastStagedTableImpl import io.qbeast.spark.internal.sources.v2.QbeastTableImpl -import io.qbeast.spark.internal.QbeastOptions.checkQbeastProperties import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException @@ -109,7 +108,6 @@ class QbeastCatalog[T <: TableCatalog with SupportsNamespaces with FunctionCatal properties: util.Map[String, String]): Table = { if (QbeastCatalogUtils.isQbeastProvider(properties)) { - checkQbeastProperties(properties.asScala.toMap) // Create the table QbeastCatalogUtils.createQbeastTable( ident, diff --git a/src/main/scala/io/qbeast/spark/internal/sources/v2/QbeastStagedTableImpl.scala b/src/main/scala/io/qbeast/spark/internal/sources/v2/QbeastStagedTableImpl.scala index 70326a3a2..349fbd5b7 100644 --- a/src/main/scala/io/qbeast/spark/internal/sources/v2/QbeastStagedTableImpl.scala +++ b/src/main/scala/io/qbeast/spark/internal/sources/v2/QbeastStagedTableImpl.scala @@ -5,7 +5,6 @@ package io.qbeast.spark.internal.sources.v2 import io.qbeast.spark.internal.sources.catalog.CreationMode import io.qbeast.spark.internal.sources.catalog.QbeastCatalogUtils -import io.qbeast.spark.internal.QbeastOptions.checkQbeastProperties import io.qbeast.spark.table.IndexedTableFactory import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.connector.catalog.Identifier @@ -72,9 +71,6 @@ private[sources] class QbeastStagedTableImpl( // we pass all the writeOptions to the properties as well writeOptions.foreach { case (k, v) => props.put(k, v) } - // Check all the Qbeast properties are correctly specified - checkQbeastProperties(props.asScala.toMap) - // Creates the corresponding table on the Catalog and executes // the writing of the dataFrame (if any) 
QbeastCatalogUtils.createQbeastTable(
diff --git a/src/main/scala/io/qbeast/spark/table/IndexedTable.scala b/src/main/scala/io/qbeast/spark/table/IndexedTable.scala
index fe37d969c..0ff8f7ede 100644
--- a/src/main/scala/io/qbeast/spark/table/IndexedTable.scala
+++ b/src/main/scala/io/qbeast/spark/table/IndexedTable.scala
@@ -11,6 +11,7 @@ import io.qbeast.spark.internal.sources.QbeastBaseRelation
 import io.qbeast.spark.internal.QbeastOptions
 import io.qbeast.spark.internal.QbeastOptions.COLUMNS_TO_INDEX
 import io.qbeast.spark.internal.QbeastOptions.CUBE_SIZE
+import org.apache.spark.qbeast.config.COLUMN_SELECTOR_ENABLED
 import org.apache.spark.qbeast.config.DEFAULT_NUMBER_OF_RETRIES
 import org.apache.spark.sql.delta.actions.FileAction
 import org.apache.spark.sql.sources.BaseRelation
@@ -129,7 +130,8 @@ final class IndexedTableFactoryImpl(
     private val indexManager: IndexManager[DataFrame],
     private val metadataManager: MetadataManager[StructType, FileAction, QbeastOptions],
     private val dataWriter: DataWriter[DataFrame, StructType, FileAction],
-    private val revisionFactory: RevisionFactory[StructType, QbeastOptions])
+    private val revisionFactory: RevisionFactory[StructType, QbeastOptions],
+    private val autoIndexer: ColumnsToIndexSelector[DataFrame])
     extends IndexedTableFactory {

   override def getIndexedTable(tableID: QTableID): IndexedTable =
@@ -139,7 +141,8 @@ final class IndexedTableFactoryImpl(
       indexManager,
       metadataManager,
       dataWriter,
-      revisionFactory)
+      revisionFactory,
+      autoIndexer)

 }
@@ -158,6 +161,8 @@ final class IndexedTableFactoryImpl(
  *   the data writer
  * @param revisionBuilder
  *   the revision builder
+ * @param autoIndexer
+ *   the auto indexer
  */
 private[table] class IndexedTableImpl(
     val tableID: QTableID,
@@ -165,7 +170,8 @@ private[table] class IndexedTableImpl(
     private val indexManager: IndexManager[DataFrame],
     private val metadataManager: MetadataManager[StructType, FileAction, QbeastOptions],
     private val dataWriter: DataWriter[DataFrame, StructType, FileAction],
-    private val revisionFactory: RevisionFactory[StructType, QbeastOptions])
+    private val revisionFactory: RevisionFactory[StructType, QbeastOptions],
+    private val autoIndexer: ColumnsToIndexSelector[DataFrame])
     extends IndexedTable
     with StagingUtils {
   private var snapshotCache: Option[QbeastSnapshot] = None
@@ -208,7 +214,7 @@ private[table] class IndexedTableImpl(
    * @param parameters
    *   the parameters required for indexing
    */
-  private def addRequiredParams(
+  def addRequiredParams(
       latestRevision: Revision,
       parameters: Map[String, String]): Map[String, String] = {
     val columnsToIndex = latestRevision.columnTransformers.map(_.columnName).mkString(",")
@@ -267,7 +273,17 @@ private[table] class IndexedTableImpl(
         }
       }
     } else {
-      val options = QbeastOptions(parameters)
+      // If auto indexing is enabled, choose the columns to index automatically
+      val optionalColumnsToIndex = parameters.contains(COLUMNS_TO_INDEX)
+      val updatedParameters = if (!optionalColumnsToIndex && !COLUMN_SELECTOR_ENABLED) {
+        throw AnalysisExceptionFactory.create(
+          "Auto indexing is disabled. Please specify the columns to index in a comma-separated way" +
+            " as .option(columnsToIndex, ...) 
or enable auto indexing with spark.qbeast.index.columnsToIndex.auto=true")
+      } else if (COLUMN_SELECTOR_ENABLED) {
+        val columnsToIndex = autoIndexer.selectColumnsToIndex(data)
+        parameters + (COLUMNS_TO_INDEX -> columnsToIndex.mkString(","))
+      } else parameters
+      val options = QbeastOptions(updatedParameters)
       val revision = revisionFactory.createNewRevision(tableID, data.schema, options)
       (IndexStatus(revision), options)
     }
diff --git a/src/main/scala/org/apache/spark/qbeast/config.scala b/src/main/scala/org/apache/spark/qbeast/config.scala
index 3109bc3c2..cc68ca164 100644
--- a/src/main/scala/org/apache/spark/qbeast/config.scala
+++ b/src/main/scala/org/apache/spark/qbeast/config.scala
@@ -46,6 +46,18 @@ package object config {
     .longConf
     .createOptional

+  private[config] val columnsToIndexSelectorEnabled: ConfigEntry[Boolean] =
+    ConfigBuilder("spark.qbeast.index.columnsToIndex.auto")
+      .version("0.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[config] val maxNumColumnsToIndex: ConfigEntry[Int] =
+    ConfigBuilder("spark.qbeast.index.columnsToIndex.auto.max")
+      .version("0.2.0")
+      .intConf
+      .createWithDefault(3)
+
   def DEFAULT_NUMBER_OF_RETRIES: Int = QbeastContext.config
     .get(defaultNumberOfRetries)
@@ -63,4 +75,8 @@ package object config {

   def STAGING_SIZE_IN_BYTES: Option[Long] = QbeastContext.config.get(stagingSizeInBytes)

+  def COLUMN_SELECTOR_ENABLED: Boolean = QbeastContext.config.get(columnsToIndexSelectorEnabled)
+
+  def MAX_NUM_COLUMNS_TO_INDEX: Int = QbeastContext.config.get(maxNumColumnsToIndex)
+
 }
diff --git a/src/test/scala/io/qbeast/context/QbeastContextTest.scala b/src/test/scala/io/qbeast/context/QbeastContextTest.scala
index b7e18e455..e1c3e4754 100644
--- a/src/test/scala/io/qbeast/context/QbeastContextTest.scala
+++ b/src/test/scala/io/qbeast/context/QbeastContextTest.scala
@@ -7,6 +7,7 @@ import io.qbeast.core.keeper.Keeper
 import io.qbeast.core.keeper.LocalKeeper
 import io.qbeast.spark.delta.writer.RollupDataWriter
 import io.qbeast.spark.delta.SparkDeltaMetadataManager
+import io.qbeast.spark.index.SparkColumnsToIndexSelector
 import io.qbeast.spark.index.SparkOTreeManager
 import io.qbeast.spark.index.SparkRevisionFactory
 import io.qbeast.spark.table.IndexedTableFactoryImpl
@@ -24,7 +25,8 @@ class QbeastContextTest extends AnyFlatSpec with Matchers with QbeastIntegration
       SparkOTreeManager,
       SparkDeltaMetadataManager,
       RollupDataWriter,
-      SparkRevisionFactory)
+      SparkRevisionFactory,
+      SparkColumnsToIndexSelector)
     val unmanaged = new QbeastContextImpl(
       config = SparkSession.active.sparkContext.getConf,
       keeper = keeper,
@@ -44,7 +46,8 @@ class QbeastContextTest extends AnyFlatSpec with Matchers with QbeastIntegration
       SparkOTreeManager,
       SparkDeltaMetadataManager,
       RollupDataWriter,
-      SparkRevisionFactory)
+      SparkRevisionFactory,
+      SparkColumnsToIndexSelector)
     val unmanaged = new QbeastContextImpl(
       config = SparkSession.active.sparkContext.getConf,
       keeper = keeper,
diff --git a/src/test/scala/io/qbeast/spark/QbeastIntegrationTestSpec.scala b/src/test/scala/io/qbeast/spark/QbeastIntegrationTestSpec.scala
index a194389f7..96427626d 100644
--- a/src/test/scala/io/qbeast/spark/QbeastIntegrationTestSpec.scala
+++ b/src/test/scala/io/qbeast/spark/QbeastIntegrationTestSpec.scala
@@ -11,6 +11,7 @@ import io.qbeast.core.keeper.LocalKeeper
 import io.qbeast.core.model.IndexManager
 import io.qbeast.spark.delta.writer.RollupDataWriter
 import io.qbeast.spark.delta.SparkDeltaMetadataManager
+import io.qbeast.spark.index.SparkColumnsToIndexSelector
 import
io.qbeast.spark.index.SparkOTreeManager import io.qbeast.spark.index.SparkRevisionFactory import io.qbeast.spark.table.IndexedTableFactoryImpl @@ -150,7 +151,8 @@ trait QbeastIntegrationTestSpec extends AnyFlatSpec with Matchers with DatasetCo SparkOTreeManager, SparkDeltaMetadataManager, RollupDataWriter, - SparkRevisionFactory) + SparkRevisionFactory, + SparkColumnsToIndexSelector) val context = new QbeastContextImpl(spark.sparkContext.getConf, keeper, indexedTableFactory) try { QbeastContext.setUnmanaged(context) diff --git a/src/test/scala/io/qbeast/spark/index/SparkColumnsToIndexSelectorTest.scala b/src/test/scala/io/qbeast/spark/index/SparkColumnsToIndexSelectorTest.scala new file mode 100644 index 000000000..597639413 --- /dev/null +++ b/src/test/scala/io/qbeast/spark/index/SparkColumnsToIndexSelectorTest.scala @@ -0,0 +1,119 @@ +package io.qbeast.spark.index + +import io.qbeast.spark.QbeastIntegrationTestSpec + +class SparkColumnsToIndexSelectorTest extends QbeastIntegrationTestSpec { + + behavior of "SparkColumnsToIndexSelector" + + it should "select correct columns for indexing" in withSpark(spark => { + + import spark.implicits._ + // Create test data + val testDF = Seq( + (1, "Alice", java.sql.Timestamp.valueOf("2023-01-01 10:00:00"), 12.5), + (2, "Bob", java.sql.Timestamp.valueOf("2023-01-02 11:30:00"), 15.0) + // Add more rows as needed + ).toDF("id", "name", "timestamp", "value") + + // Initialize SparkColumnsToIndexSelector + val autoIndexer = SparkColumnsToIndexSelector + val numColumnsToSelect = 2 // Adjust as needed + val selectedColumns = autoIndexer.selectColumnsToIndex(testDF, numColumnsToSelect) + + // Assertions + selectedColumns.length shouldBe numColumnsToSelect + selectedColumns should contain theSameElementsAs Seq("name", "value") + }) + + it should "not discard string columns" in withSpark(spark => { + + import spark.implicits._ + // Create test data + val testDF = Seq( + ("a", 20), + ("b", 30), + ("c", 40) + // Add more rows as needed + ).toDF("s", "i") + + // Initialize SparkColumnsToIndexSelector + val autoIndexer = SparkColumnsToIndexSelector + val selectedColumns = autoIndexer.selectColumnsToIndex(testDF) + + selectedColumns should contain theSameElementsAs testDF.columns + }) + + // TODO - Check if this should be the default behavior + it should "select maximum 3 columns by default" in withSpark(spark => { + + import spark.implicits._ + // Create test data + val testDF = Seq( + (1, "Alice", java.sql.Timestamp.valueOf("2023-01-01 10:00:00"), 12.5), + (2, "Bob", java.sql.Timestamp.valueOf("2023-01-02 11:30:00"), 15.0) + // Add more rows as needed + ).toDF("id", "name", "timestamp", "value") + + // Initialize SparkColumnsToIndexSelector + val autoIndexer = SparkColumnsToIndexSelector + val selectedColumns = autoIndexer.selectColumnsToIndex(testDF) + + selectedColumns.length shouldBe 3 + }) + + it should "select all columns if maxColumnsToIndex > num columns of dataframe" in withExtendedSpark( + sparkConf = + sparkConfWithSqlAndCatalog.set("spark.qbeast.index.columnsToIndex.auto.max", "10"))( + spark => { + + import spark.implicits._ + // Create test data + val testDF = Seq( + (1, 5, java.sql.Timestamp.valueOf("2023-01-01 10:00:00"), 12.5), + (2, 6, java.sql.Timestamp.valueOf("2023-01-02 11:30:00"), 15.0) + // Add more rows as needed + ).toDF("id", "name", "timestamp", "value") + + // Initialize SparkColumnsToIndexSelector + val autoIndexer = SparkColumnsToIndexSelector + + // Invoke method + val selectedColumns = autoIndexer.selectColumnsToIndex(testDF) + + 
// Assertions + selectedColumns.length shouldBe 4 // 4 columns in the test data + selectedColumns should contain theSameElementsAs testDF.columns + }) + + it should "not select more than maxColumnsToIndex" in withExtendedSpark(sparkConf = + sparkConfWithSqlAndCatalog.set("spark.qbeast.index.columnsToIndex.auto.max", "10"))(spark => { + + import spark.implicits._ + val largeColumnDF = + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)) + .toDF() + + val autoIndexer = SparkColumnsToIndexSelector + val selectedLargeColumns = autoIndexer.selectColumnsToIndex(largeColumnDF) + selectedLargeColumns.length shouldBe 10 + }) + + // TODO - Check if this should be the default behavior + it should "use the 3 first columns if no data is provided" in withSpark(spark => { + + import spark.implicits._ + // Create test data + val testDF = Seq + .empty[(Int, String, java.sql.Timestamp, Double)] + .toDF("id", "name", "timestamp", "value") + + // Initialize SparkColumnsToIndexSelector + val autoIndexer = SparkColumnsToIndexSelector + val selectedColumns = autoIndexer.selectColumnsToIndex(testDF) + + // Assertions + selectedColumns should contain theSameElementsAs testDF.columns.take(3) + }) + +} diff --git a/src/test/scala/io/qbeast/spark/internal/sources/QbeastDataSourceTest.scala b/src/test/scala/io/qbeast/spark/internal/sources/QbeastDataSourceTest.scala index fdf919052..896b6f40f 100644 --- a/src/test/scala/io/qbeast/spark/internal/sources/QbeastDataSourceTest.scala +++ b/src/test/scala/io/qbeast/spark/internal/sources/QbeastDataSourceTest.scala @@ -6,6 +6,7 @@ package io.qbeast.spark.internal.sources import io.qbeast.core.model.QTableID import io.qbeast.spark.table.IndexedTable import io.qbeast.spark.table.IndexedTableFactory +import org.apache.log4j.Level import org.apache.spark.sql.connector.catalog.SparkCatalogV2Util import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.Transform @@ -16,6 +17,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SQLContext import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.SparkSession +import org.apache.spark.SparkConf import org.mockito.ArgumentMatchers.any import org.mockito.ArgumentMatchers.anyBoolean import org.mockito.Mockito.verify @@ -43,6 +46,20 @@ class QbeastDataSourceTest extends FixtureAnyFlatSpec with MockitoSugar with Mat override type FixtureParam = Fixture + private def withSpark[T](testCode: SparkSession => T): T = { + val spark = SparkSession + .builder() + .appName("QbeastDataSource") + .config(new SparkConf().setMaster("local[2]")) + .getOrCreate() + spark.sparkContext.setLogLevel(Level.WARN.toString) + try { + testCode(spark) + } finally { + spark.close() + } + } + override protected def withFixture(test: OneArgTest): Outcome = { val sqlContext = mock[SQLContext] val relation = mock[BaseRelation] @@ -161,10 +178,12 @@ class QbeastDataSourceTest extends FixtureAnyFlatSpec with MockitoSugar with Mat } it should "throw exception for write if columns to index are not specified" in { f => - val parameters = Map("path" -> path) - val data = mock[DataFrame] - a[AnalysisException] shouldBe thrownBy { - f.dataSource.createRelation(f.sqlContext, SaveMode.Overwrite, parameters, data) + withSpark { _ => + val parameters = Map("path" -> path) + val data = mock[DataFrame] + a[AnalysisException] shouldBe thrownBy { + f.dataSource.createRelation(f.sqlContext, SaveMode.Overwrite, parameters, 
data) + } } } diff --git a/src/test/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalogIntegrationTest.scala b/src/test/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalogIntegrationTest.scala index 94d682a99..a49d34f53 100644 --- a/src/test/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalogIntegrationTest.scala +++ b/src/test/scala/io/qbeast/spark/internal/sources/catalog/QbeastCatalogIntegrationTest.scala @@ -188,15 +188,6 @@ class QbeastCatalogIntegrationTest extends QbeastIntegrationTestSpec with Catalo }) - it should "throw an error when no columnsToIndex is specified" in - withQbeastContextSparkAndTmpWarehouse((spark, _) => { - - an[AnalysisException] shouldBe thrownBy( - spark.sql("CREATE OR REPLACE TABLE student (id INT, name STRING, age INT)" + - " USING qbeast")) - - }) - it should "throw an error when trying to replace a non-qbeast table" in withQbeastContextSparkAndTmpWarehouse((spark, _) => { diff --git a/src/test/scala/io/qbeast/spark/utils/QbeastSQLIntegrationTest.scala b/src/test/scala/io/qbeast/spark/utils/QbeastSQLIntegrationTest.scala index 1de638a00..c4feda69c 100644 --- a/src/test/scala/io/qbeast/spark/utils/QbeastSQLIntegrationTest.scala +++ b/src/test/scala/io/qbeast/spark/utils/QbeastSQLIntegrationTest.scala @@ -1,6 +1,8 @@ package io.qbeast.spark.utils +import io.qbeast.spark.index.SparkColumnsToIndexSelector import io.qbeast.spark.QbeastIntegrationTestSpec +import io.qbeast.spark.QbeastTable import io.qbeast.TestClasses.Student import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.DataFrame @@ -180,4 +182,25 @@ class QbeastSQLIntegrationTest extends QbeastIntegrationTestSpec { } } + it should "work without providing columnsToIndex" in withExtendedSparkAndTmpDir( + sparkConfWithSqlAndCatalog.set("spark.qbeast.index.columnsToIndex.auto", "true")) { + (spark, tmpDir) => + { + val data = createTestData(spark) + data.createOrReplaceTempView("data") + + spark.sql( + "CREATE OR REPLACE TABLE student USING qbeast " + + s"LOCATION '$tmpDir' " + + "AS SELECT * FROM data;") + + val autoColumnsToIndex = SparkColumnsToIndexSelector.selectColumnsToIndex(data) + + val qbeastTable = QbeastTable.forPath(spark, tmpDir) + qbeastTable.indexedColumns() shouldBe autoColumnsToIndex + qbeastTable.latestRevisionID() shouldBe 1L + + } + } + } diff --git a/src/test/scala/io/qbeast/spark/utils/QbeastSparkIntegrationTest.scala b/src/test/scala/io/qbeast/spark/utils/QbeastSparkIntegrationTest.scala index b8e424fb6..ed1096cb4 100644 --- a/src/test/scala/io/qbeast/spark/utils/QbeastSparkIntegrationTest.scala +++ b/src/test/scala/io/qbeast/spark/utils/QbeastSparkIntegrationTest.scala @@ -1,6 +1,7 @@ package io.qbeast.spark.utils import io.qbeast.spark.QbeastIntegrationTestSpec +import io.qbeast.spark.QbeastTable import io.qbeast.TestClasses.Student import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -131,4 +132,30 @@ class QbeastSparkIntegrationTest extends QbeastIntegrationTestSpec { assertSmallDatasetEquality(indexed, data, orderedComparison = false, ignoreNullable = true) }) + it should "work without providing columnsToIndex" in withExtendedSparkAndTmpDir( + sparkConfWithSqlAndCatalog.set("spark.qbeast.index.columnsToIndex.auto", "true")) { + (spark, tmpDir) => + { + val data = createStudentsTestData(spark) + data.write.format("qbeast").save(tmpDir) + + val indexed = spark.read.format("qbeast").load(tmpDir) + + indexed.count() shouldBe data.count() + + indexed.columns.toSet shouldBe data.columns.toSet + + 
assertSmallDatasetEquality( + indexed, + data, + orderedComparison = false, + ignoreNullable = true) + + val qbeastTable = QbeastTable.forPath(spark, tmpDir) + qbeastTable.indexedColumns() shouldBe Seq("name", "age", "id") + qbeastTable.latestRevisionID() shouldBe 1L + + } + } + }
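
Putting the pieces together, once `spark.qbeast.index.columnsToIndex.auto` is enabled a write without `columnsToIndex` looks roughly as follows. This is an illustrative sketch only: the paths are hypothetical, and it assumes a Spark session already configured for the Qbeast datasource and catalog (as in the test configuration above).

```scala
import io.qbeast.spark.QbeastTable
import org.apache.spark.sql.SparkSession

// Enable automatic column selection at session creation time,
// equivalent to the --conf flags shown in the documentation.
val spark = SparkSession
  .builder()
  .config("spark.qbeast.index.columnsToIndex.auto", "true")
  .config("spark.qbeast.index.columnsToIndex.auto.max", "10")
  .getOrCreate()

// Hypothetical input data and output path
val df = spark.read.parquet("/tmp/source_data")

// No columnsToIndex option: the selector chooses the columns automatically
df.write.format("qbeast").save("/tmp/qbeast_table")

// Inspect which columns were picked for indexing
val qbeastTable = QbeastTable.forPath(spark, "/tmp/qbeast_table")
println(qbeastTable.indexedColumns())
```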