Python API for restoring delta table\removing cache() of filesToRemove\moving…

 * Add the possibility to restore a Delta table to a given version or timestamp from PySpark
   (a fuller usage sketch follows the commit message). Examples:
   ```
   DeltaTable.forPath(spark, path).restoreToVersion(0)
   DeltaTable.forPath(spark, path).restoreToTimestamp('2021-01-01 01:01:01')
   ```
 * Remove unnecessary caching of filesToRemove in RestoreTableCommand
 * Move RestoreTableCommand to org.apache.spark.sql.delta.commands

Fixes delta-io#890

Signed-off-by: Maksym Dovhal <maksym.dovhal@gmail.com>
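
For context, a minimal end-to-end sketch of the new PySpark API described above (not part of the commit). It assumes an existing `spark` session with the Delta Lake extension configured and uses a hypothetical local path:

```python
from delta.tables import DeltaTable

path = "/tmp/delta-restore-demo"  # hypothetical location

spark.range(0, 5).write.format("delta").save(path)                     # creates version 0
spark.range(5, 10).write.format("delta").mode("overwrite").save(path)  # creates version 1

# Restore back to the first version; the call returns a DataFrame of restore metrics.
metrics = DeltaTable.forPath(spark, path).restoreToVersion(0)
metrics.show(truncate=False)

# Restoring to a point in time works the same way; accepted formats are
# 'yyyy-MM-dd' and 'yyyy-MM-dd HH:mm:ss'.
# DeltaTable.forPath(spark, path).restoreToTimestamp('2021-01-01 01:01:01')
```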
Maksym Dovhal committed Jan 22, 2022
1 parent 0d07d09 commit c139bc0
Showing 4 changed files with 89 additions and 22 deletions.
@@ -17,12 +17,10 @@
package io.delta.tables.execution

import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaLog, RestoreTableCommand}
import org.apache.spark.sql.delta.commands.{DeltaGenerateCommand, VacuumCommand}
import org.apache.spark.sql.delta.DeltaLog
import org.apache.spark.sql.delta.commands.{DeltaGenerateCommand, RestoreTableCommand, VacuumCommand}
import org.apache.spark.sql.delta.util.AnalysisHelper
import io.delta.tables.DeltaTable

import org.apache.spark.sql.{functions, Column, DataFrame}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -14,21 +14,18 @@
* limitations under the License.
*/

package org.apache.spark.sql.delta
package org.apache.spark.sql.delta.commands

import java.sql.Timestamp

import scala.collection.JavaConverters._
import scala.util.{Success, Try}

import org.apache.spark.sql.delta.DeltaErrors.timestampInvalid
import org.apache.spark.sql.delta.actions.{AddFile, RemoveFile}
import org.apache.spark.sql.delta.commands.DeltaCommand
import org.apache.spark.sql.delta.util.DeltaFileOperations.absolutePath

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, Snapshot}
import org.apache.spark.sql.execution.command.LeafRunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.IGNORE_MISSING_FILES
@@ -91,9 +88,6 @@ case class RestoreTableCommand(
"left_anti")
.as[AddFile]
.map(_.copy(dataChange = true))
// To avoid recompute of Dataset with wide transformation by toLocalIterator and
// checkSnapshotFilesAvailability method with spark.sql.files.ignoreMissingFiles=false
.cache()

val filesToRemove = latestSnapshotFiles
.join(
@@ -102,8 +96,6 @@
"left_anti")
.as[AddFile]
.map(_.removeWithTimestamp())
// To avoid recompute of Dataset with wide transformation by toLocalIterator
.cache()

try {
checkSnapshotFilesAvailability(deltaLog, filesToAdd, versionToRestore)
Expand All @@ -124,7 +116,6 @@ case class RestoreTableCommand(
metrics)
} finally {
filesToAdd.unpersist()
filesToRemove.unpersist()
}
}

@@ -184,6 +175,8 @@ case class RestoreTableCommand(
.getConf(IGNORE_MISSING_FILES)

if (!ignore) {
// To avoid recompute of files Dataset in calling method
files.cache()
val path = deltaLog.dataPath
val hadoopConf = spark.sparkContext.broadcast(
new SerializableConfiguration(deltaLog.newDeltaHadoopConf()))
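
The hunk above drops the unconditional `.cache()` calls on `filesToAdd`/`filesToRemove` and instead caches the files Dataset inside `checkSnapshotFilesAvailability`, i.e. only when `spark.sql.files.ignoreMissingFiles` is false and the Dataset really is traversed a second time. A toy PySpark sketch of that general pattern (illustration only, not the command's code; it reuses `spark` and `path` from the sketch above):

```python
# Cache a DataFrame only if it will actually be consumed more than once.
files = spark.read.format("delta").load(path).where("id % 2 = 0")

ignore_missing = spark.conf.get("spark.sql.files.ignoreMissingFiles") == "true"
if not ignore_missing:
    files.cache()                        # the extra validation pass below makes caching worthwhile
    files.foreach(lambda row: None)      # stand-in for an availability/validation check
try:
    for row in files.toLocalIterator():  # the "real" consumer of the data
        pass
finally:
    files.unpersist()                    # no-op if the DataFrame was never cached
```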
40 changes: 40 additions & 0 deletions python/delta/tables.py
@@ -537,6 +537,46 @@ def upgradeTableProtocol(self, readerVersion: int, writerVersion: int) -> None:
                             type(writerVersion))
        jdt.upgradeTableProtocol(readerVersion, writerVersion)

    @since(1.2)  # type: ignore[arg-type]
    def restoreToVersion(self, version: int) -> DataFrame:
        """
        Restore the DeltaTable to an older version of the table specified by version number.

        Example::

            deltaTable.restoreToVersion(1)

        :param version: target version of restored table
        :return: Dataframe with metrics of restore operation.
        :rtype: pyspark.sql.DataFrame
        """

        return DataFrame(
            self._jdt.restoreToVersion(version),
            self._spark._wrapped  # type: ignore[attr-defined]
        )

    @since(1.2)  # type: ignore[arg-type]
    def restoreToTimestamp(self, timestamp: str) -> DataFrame:
        """
        Restore the DeltaTable to an older version of the table specified by a timestamp.
        Timestamp can be of the format yyyy-MM-dd or yyyy-MM-dd HH:mm:ss

        Example::

            deltaTable.restoreToTimestamp('2021-01-01')
            deltaTable.restoreToTimestamp('2021-01-01 01:01:01')

        :param timestamp: target timestamp of restored table
        :return: Dataframe with metrics of restore operation.
        :rtype: pyspark.sql.DataFrame
        """

        return DataFrame(
            self._jdt.restoreToTimestamp(timestamp),
            self._spark._wrapped  # type: ignore[attr-defined]
        )

    @staticmethod
    def _dict_to_jmap(
        sparkSession: SparkSession,
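
A small follow-up sketch (same assumptions as above, not part of the commit): the timestamp passed to `restoreToTimestamp` can be derived from the table's own history, which is also how the new test below obtains it.

```python
dt = DeltaTable.forPath(spark, path)

# Take the commit timestamp of the earliest version in the history and restore to it.
first_commit = dt.history().orderBy("version").head()
dt.restoreToTimestamp(first_commit.timestamp.strftime('%Y-%m-%d %H:%M:%S.%f'))
```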
50 changes: 43 additions & 7 deletions python/delta/tests/test_deltatable.py
@@ -791,6 +791,34 @@ def test_protocolUpgrade(self) -> None:
        with self.assertRaisesRegex(ValueError, "writerVersion"):
            dt.upgradeTableProtocol(1, {})  # type: ignore[arg-type]

    def test_restoreToVersion(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2)])
        self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                                   schema=["key_new", "value_new"],
                                   overwriteSchema='true')

        DeltaTable.forPath(self.spark, self.tempFile).restoreToVersion(0)
        restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()

        self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])

    def test_restoreToTimestamp(self) -> None:
        self.__writeDeltaTable([('a', 1), ('b', 2)])
        timestampToRestore = DeltaTable.forPath(self.spark, self.tempFile) \
            .history() \
            .head() \
            .timestamp \
            .strftime('%Y-%m-%d %H:%M:%S.%f')

        self.__overwriteDeltaTable([('a', 3), ('b', 2)],
                                   schema=["key_new", "value_new"],
                                   overwriteSchema='true')

        DeltaTable.forPath(self.spark, self.tempFile).restoreToTimestamp(timestampToRestore)

        restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()

        self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])

    def __checkAnswer(self, df: DataFrame,
                      expectedAnswer: List[Any],
                      schema: Union[StructType, List[str]] = ["key", "value"]) -> None:
@@ -810,17 +838,25 @@ def __checkAnswer(self, df: DataFrame,
            df.show()
            raise

    def __writeDeltaTable(self, datalist: List[Tuple[Any, Any]]) -> None:
        df = self.spark.createDataFrame(datalist, ["key", "value"])
    def __writeDeltaTable(self, datalist: List[Tuple[Any, Any]],
                          schema: Union[StructType, List[str]] = ["key", "value"]) -> None:
        df = self.spark.createDataFrame(datalist, schema)
        df.write.format("delta").save(self.tempFile)

    def __writeAsTable(self, datalist: List[Tuple[Any, Any]], tblName: str) -> None:
        df = self.spark.createDataFrame(datalist, ["key", "value"])
    def __writeAsTable(self, datalist: List[Tuple[Any, Any]],
                       tblName: str,
                       schema: Union[StructType, List[str]] = ["key", "value"]) -> None:
        df = self.spark.createDataFrame(datalist, schema)
        df.write.format("delta").saveAsTable(tblName)

    def __overwriteDeltaTable(self, datalist: List[Tuple[Any, Any]]) -> None:
        df = self.spark.createDataFrame(datalist, ["key", "value"])
        df.write.format("delta").mode("overwrite").save(self.tempFile)
    def __overwriteDeltaTable(self, datalist: List[Tuple[Any, Any]],
                              schema: Union[StructType, List[str]] = ["key", "value"],
                              overwriteSchema: str = 'false') -> None:
        df = self.spark.createDataFrame(datalist, schema)
        df.write.format("delta") \
            .option('overwriteSchema', overwriteSchema) \
            .mode("overwrite") \
            .save(self.tempFile)

    def __createFile(self, fileName: str, content: Any) -> None:
        with open(os.path.join(self.tempFile, fileName), 'w') as f:
