@@ -139,7 +139,8 @@ class RelationResolution(override val catalogManager: CatalogManager)
ident,
table,
u.clearWritePrivileges.options,
u.isStreaming
u.isStreaming,
finalTimeTravelSpec
)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG)
@@ -162,7 +163,8 @@ class RelationResolution(override val catalogManager: CatalogManager)
ident: Identifier,
table: Option[Table],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
isStreaming: Boolean,
timeTravelSpec: Option[TimeTravelSpec]): Option[LogicalPlan] = {
table.map {
// To utilize this code path to execute V1 commands, e.g. INSERT,
// either it must be session catalog, or tracksPartitionsInCatalog
@@ -189,6 +191,7 @@ class RelationResolution(override val catalogManager: CatalogManager)

case table =>
if (isStreaming) {
assert(timeTravelSpec.isEmpty, "time travel is not allowed in streaming")
Comment (Contributor Author): It should be impossible to reach this line with a valid time travel spec. Just a sanity check.

val v1Fallback = table match {
case withFallback: V2TableWithV1Fallback =>
Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true))
@@ -210,7 +213,7 @@ class RelationResolution(override val catalogManager: CatalogManager)
} else {
SubqueryAlias(
catalog.name +: ident.asMultipartIdentifier,
DataSourceV2Relation.create(table, Some(catalog), Some(ident), options)
DataSourceV2Relation.create(table, Some(catalog), Some(ident), options, timeTravelSpec)
)
}
}
@@ -27,8 +27,13 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap

sealed trait TimeTravelSpec

case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec
case class AsOfVersion(version: String) extends TimeTravelSpec
case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec {
override def toString: String = s"TIMESTAMP AS OF $timestamp"
Comment (Contributor Author): Needed for proper simpleString implementation in DataSourceV2Relation. See tests below.

}

case class AsOfVersion(version: String) extends TimeTravelSpec {
override def toString: String = s"VERSION AS OF '$version'"
}
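
For illustration, a minimal sketch of what the overridden toString methods produce (the literal values are hypothetical):

assert(AsOfVersion("123").toString == "VERSION AS OF '123'")
assert(AsOfTimestamp(1700000000L).toString == "TIMESTAMP AS OF 1700000000")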

object TimeTravelSpec {
def create(
@@ -748,6 +748,13 @@ class SessionCatalog(
getRawLocalOrGlobalTempView(toNameParts(name)).map(getTempViewPlan)
}

Comment (Contributor Author): An overloaded version that takes name parts directly instead of taking TableIdentifier.

/**
* Generate a [[View]] operator from the local or global temporary view stored.
*/
def getLocalOrGlobalTempView(name: Seq[String]): Option[View] = {
getRawLocalOrGlobalTempView(name).map(getTempViewPlan)
}
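
A hedged usage sketch of the new overload (the catalog value and view names are hypothetical; "global_temp" is the default global temp database):

// Local temp view, addressed by a single name part.
catalog.getLocalOrGlobalTempView(Seq("my_temp_view"))
// Global temp view, addressed by the global temp database plus the view name.
catalog.getLocalOrGlobalTempView(Seq("global_temp", "my_global_view"))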

/**
* Return the raw logical plan of a temporary local or global view for the given name.
*/
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation, TimeTravelSpec}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -45,7 +45,8 @@ abstract class DataSourceV2RelationBase(
output: Seq[AttributeReference],
catalog: Option[CatalogPlugin],
identifier: Option[Identifier],
options: CaseInsensitiveStringMap)
options: CaseInsensitiveStringMap,
timeTravelSpec: Option[TimeTravelSpec] = None)
extends LeafNode with MultiInstanceRelation with NamedRelation {

import DataSourceV2Implicits._
@@ -65,7 +66,12 @@ abstract class DataSourceV2RelationBase(
override def skipSchemaResolution: Boolean = table.supports(TableCapability.ACCEPT_ANY_SCHEMA)

override def simpleString(maxFields: Int): String = {
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
val outputString = truncatedString(output, "[", ", ", "]", maxFields)
Comment (Contributor Author): Covered with tests.

val nameWithTimeTravelSpec = timeTravelSpec match {
Comment (Member): nit: seems clearer to have this just return either $spec or empty, and have the final string be "RelationV2$outputString $timeTravelSpec $name".

Comment (Contributor Author): This would require an extra unnecessary space after the name if the spec is empty.

case Some(spec) => s"$name $spec"
case _ => name
}
s"RelationV2$outputString $nameWithTimeTravelSpec"
}
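
As a hedged illustration of the resulting plan strings (attribute ids and the table name are hypothetical), the spec is appended only when present, so no trailing space is produced for plain reads:

// With a time travel spec:    RelationV2[id#1, data#2] testcat.ns.tbl VERSION AS OF '1'
// Without a time travel spec: RelationV2[id#1, data#2] testcat.ns.tbl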

override def computeStats(): Statistics = {
@@ -96,8 +102,9 @@ case class DataSourceV2Relation(
override val output: Seq[AttributeReference],
catalog: Option[CatalogPlugin],
identifier: Option[Identifier],
options: CaseInsensitiveStringMap)
extends DataSourceV2RelationBase(table, output, catalog, identifier, options)
options: CaseInsensitiveStringMap,
timeTravelSpec: Option[TimeTravelSpec] = None)
extends DataSourceV2RelationBase(table, output, catalog, identifier, options, timeTravelSpec)
with ExposesMetadataColumns {

import DataSourceV2Implicits._
@@ -117,7 +124,7 @@ case class DataSourceV2Relation(
def withMetadataColumns(): DataSourceV2Relation = {
val newMetadata = metadataOutput.filterNot(outputSet.contains)
if (newMetadata.nonEmpty) {
DataSourceV2Relation(table, output ++ newMetadata, catalog, identifier, options)
copy(output = output ++ newMetadata)
} else {
this
}
@@ -151,7 +158,12 @@ case class DataSourceV2ScanRelation(
override def name: String = relation.name

override def simpleString(maxFields: Int): String = {
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
val outputString = truncatedString(output, "[", ", ", "]", maxFields)
val nameWithTimeTravelSpec = relation.timeTravelSpec match {
Comment (Member): Same comment as above.

case Some(spec) => s"$name $spec"
case _ => name
}
s"RelationV2$outputString $nameWithTimeTravelSpec"
}

override def computeStats(): Statistics = {
@@ -235,17 +247,29 @@ object ExtractV2Table {
def unapply(relation: DataSourceV2Relation): Option[Table] = Some(relation.table)
}

object ExtractV2CatalogAndIdentifier {
def unapply(relation: DataSourceV2Relation): Option[(CatalogPlugin, Identifier)] = {
relation match {
case DataSourceV2Relation(_, _, Some(catalog), Some(identifier), _, _) =>
Some((catalog, identifier))
case _ =>
None
}
}
}
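
A hedged sketch of how the new extractor might be used (the relation value is hypothetical):

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper

// Derive the fully qualified name parts when both catalog and identifier are present.
val qualifiedNameParts: Seq[String] = relation match {
  case ExtractV2CatalogAndIdentifier(catalog, ident) => ident.toQualifiedNameParts(catalog)
  case _ => Seq.empty
}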

object DataSourceV2Relation {
def create(
table: Table,
catalog: Option[CatalogPlugin],
identifier: Option[Identifier],
options: CaseInsensitiveStringMap): DataSourceV2Relation = {
options: CaseInsensitiveStringMap,
timeTravelSpec: Option[TimeTravelSpec] = None): DataSourceV2Relation = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
// The v2 source may return schema containing char/varchar type. We replace char/varchar
// with "annotated" string type here as the query engine doesn't support char/varchar yet.
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.columns.asSchema)
DataSourceV2Relation(table, toAttributes(schema), catalog, identifier, options)
DataSourceV2Relation(table, toAttributes(schema), catalog, identifier, options, timeTravelSpec)
}

def create(
@@ -725,6 +725,10 @@ abstract class InMemoryBaseTable(
}
}
}

def copy(): Table = {
throw new UnsupportedOperationException(s"copy is not supported for ${getClass.getName}")
}
}

object InMemoryBaseTable {
@@ -125,6 +125,41 @@ class InMemoryTable(
new InMemoryWriterBuilderWithOverWrite(info)
}

override def copy(): Table = {
val copiedTable = new InMemoryTable(
name,
columns(),
partitioning,
properties,
constraints,
distribution,
ordering,
numPartitions,
advisoryPartitionSize,
isDistributionStrictlyRequired,
numRowsPerSplit)

dataMap.synchronized {
dataMap.foreach { case (key, splits) =>
val copiedSplits = splits.map { bufferedRows =>
val copiedBufferedRows = new BufferedRows(bufferedRows.key, bufferedRows.schema)
copiedBufferedRows.rows ++= bufferedRows.rows.map(_.copy())
copiedBufferedRows
}
copiedTable.dataMap.put(key, copiedSplits)
}
}

copiedTable.commits ++= commits.map(_.copy())

copiedTable.setCurrentVersion(currentVersion())
if (validatedVersion() != null) {
copiedTable.setValidatedVersion(validatedVersion())
}

copiedTable
}

class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwrite {

@@ -66,6 +66,19 @@ class BasicInMemoryTableCatalog extends TableCatalog {
}
}

def pinTable(ident: Identifier, version: String): Unit = {
Comment (Contributor Author): Used in time travel tests.

Option(tables.get(ident)) match {
case Some(table: InMemoryBaseTable) =>
val versionIdent = Identifier.of(ident.namespace, ident.name + version)
val versionTable = table.copy()
tables.put(versionIdent, versionTable)
case Some(table) =>
throw new UnsupportedOperationException(s"Can't pin ${table.getClass.getName}")
case _ =>
throw new NoSuchTableException(ident.asMultipartIdentifier)
}
}
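
A hedged sketch of how pinTable might be used in a time travel test (catalog registration and names are hypothetical). Because loadTable(ident, version) looks up ident.name + version, a pinned copy becomes visible to VERSION AS OF queries:

val ident = Identifier.of(Array("ns"), "tbl")
catalog.pinTable(ident, "v1")  // deep-copies the current state under the identifier "tblv1"
// ... subsequent writes mutate ns.tbl ...
spark.sql("SELECT * FROM testcat.ns.tbl VERSION AS OF 'v1'")  // served by loadTable(ident, "v1")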

override def loadTable(ident: Identifier, version: String): Table = {
val versionIdent = Identifier.of(ident.namespace, ident.name + version)
Option(tables.get(versionIdent)) match {
30 changes: 14 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Catalog.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper,
import org.apache.spark.sql.connector.catalog.CatalogV2Util.v2ColumnsToStructType
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.{ShowNamespacesCommand, ShowTablesCommand}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.connector.V1Function
@@ -810,20 +811,13 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
Comment (Contributor Author): I needed to migrate this logic to uncaching by name. I feel like the try/catch was redundant. Keep in mind that spark.table(tableName) internally calls parseMultipartIdentifier, so it is not different compared to the old implementation. I see that getLocalOrGlobalTempView already handles multi-part names correctly. It would be great to have another pair of eyes on this one, though.

// We first try to parse `tableName` to see if it is 2 part name. If so, then in HMS we check
// if it is a temp view and uncache the temp view from HMS, otherwise we uncache it from the
// cache manager.
// if `tableName` is not 2 part name, then we directly uncache it from the cache manager.
try {
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
sessionCatalog.getLocalOrGlobalTempView(tableIdent).map(uncacheView).getOrElse {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
cascade = true)
}
} catch {
case e: org.apache.spark.sql.catalyst.parser.ParseException =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
cascade = true)
// parse the table name and check if it's a temp view (must have 1-2 name parts)
// temp views are uncached using uncacheView which respects view text semantics (SPARK-33142)
// use CommandUtils for all tables (including with 3+ part names)
val nameParts = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName)
sessionCatalog.getLocalOrGlobalTempView(nameParts).map(uncacheView).getOrElse {
val relation = resolveRelation(tableName)
CommandUtils.uncacheTableOrView(sparkSession, relation, cascade = true)
}
}
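
A hedged sketch of how the rewritten method behaves for different name shapes (all names are hypothetical):

spark.catalog.uncacheTable("my_temp_view")         // matched as a local temp view, uncached via uncacheView
spark.catalog.uncacheTable("global_temp.my_view")  // matched as a global temp view
spark.catalog.uncacheTable("testcat.ns.tbl")       // not a temp view: resolved and passed to CommandUtils.uncacheTableOrView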

@@ -868,7 +862,7 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
* @since 2.0.0
*/
override def refreshTable(tableName: String): Unit = {
val relation = sparkSession.table(tableName).queryExecution.analyzed
val relation = resolveRelation(tableName)

relation.refresh()

@@ -891,7 +885,11 @@ class Catalog(sparkSession: SparkSession) extends catalog.Catalog {
// Note this is a no-op for the relation itself if it's not cached, but will clear all
// caches referencing this relation. If this relation is cached as an InMemoryRelation,
// this will clear the relation cache and caches of all its dependents.
sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation)
CommandUtils.recacheTableOrView(sparkSession, relation)
}

private def resolveRelation(tableName: String): LogicalPlan = {
sparkSession.table(tableName).queryExecution.analyzed
}

/**
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPla
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.command.CommandUtils
@@ -83,6 +84,11 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
cachedData.isEmpty
}

// Test-only
private[sql] def numCachedEntries: Int = {
cachedData.size
}

// Test-only
def cacheQuery(query: Dataset[_]): Unit = {
cacheQuery(query, tableName = None, storageLevel = MEMORY_AND_DISK)
@@ -215,12 +221,23 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
uncacheByCondition(spark, _.sameResult(plan), cascade, blocking)
}

def uncacheTableOrView(spark: SparkSession, name: Seq[String], cascade: Boolean): Unit = {
def uncacheTableOrView(
spark: SparkSession,
name: Seq[String],
cascade: Boolean,
blocking: Boolean = false): Unit = {
uncacheByCondition(
spark, isMatchedTableOrView(_, name, spark.sessionState.conf), cascade, blocking = false)
spark,
isMatchedTableOrView(_, name, spark.sessionState.conf, includeTimeTravel = true),
cascade,
blocking)
}

private def isMatchedTableOrView(plan: LogicalPlan, name: Seq[String], conf: SQLConf): Boolean = {
private def isMatchedTableOrView(
plan: LogicalPlan,
name: Seq[String],
conf: SQLConf,
includeTimeTravel: Boolean): Boolean = {
def isSameName(nameInCache: Seq[String]): Boolean = {
nameInCache.length == name.length && nameInCache.zip(name).forall(conf.resolver.tupled)
}
@@ -229,9 +246,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
case LogicalRelationWithTable(_, Some(catalogTable)) =>
isSameName(catalogTable.identifier.nameParts)

case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _) =>
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
isSameName(v2Ident.toQualifiedNameParts(catalog))
case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _, timeTravelSpec) =>
val nameInCache = v2Ident.toQualifiedNameParts(catalog)
isSameName(nameInCache) && (includeTimeTravel || timeTravelSpec.isEmpty)

case v: View =>
isSameName(v.desc.identifier.nameParts)
@@ -304,6 +321,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
recacheByCondition(spark, _.plan.exists(_.sameResult(normalized)))
}

/**
* Re-caches all cache entries that reference the given table name.
*/
def recacheTableOrView(
spark: SparkSession,
name: Seq[String],
includeTimeTravel: Boolean = true): Unit = {
def shouldInvalidate(entry: CachedData): Boolean = {
entry.plan.exists(isMatchedTableOrView(_, name, spark.sessionState.conf, includeTimeTravel))
}
recacheByCondition(spark, shouldInvalidate)
}
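
A hedged usage sketch of the time-travel-aware matching (catalog and table names are hypothetical; assumes access to the internal CacheManager API):

spark.sql("CACHE TABLE pinned AS SELECT * FROM testcat.ns.tbl VERSION AS OF '1'")
spark.sql("CACHE TABLE current AS SELECT * FROM testcat.ns.tbl")

// Uncaching by name is expected to drop both entries: the time-travelled relation
// also matches because uncacheTableOrView passes includeTimeTravel = true.
spark.sharedState.cacheManager.uncacheTableOrView(spark, Seq("testcat", "ns", "tbl"), cascade = true)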

/**
* Re-caches all the cache entries that satisfies the given `condition`.
*/