Skip to content

Commit

Permalink
GC combine to one command and handle deduplications (#2196)
Browse files Browse the repository at this point in the history
  • Loading branch information
guy-har authored Jul 6, 2021
1 parent 8288f4f commit 80405c0
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package io.treeverse.clients

import com.google.common.cache.CacheBuilder
import io.lakefs.clients.api
import io.lakefs.clients.api.RetentionApi
import io.lakefs.clients.api.model.{
GarbageCollectionPrepareRequest,
GarbageCollectionPrepareResponse
}

import java.net.URI
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -31,6 +36,7 @@ class ApiClient(apiUrl: String, accessKey: String, secretKey: String) {
private val commitsApi = new api.CommitsApi(client)
private val metadataApi = new api.MetadataApi(client)
private val branchesApi = new api.BranchesApi(client)
private val retentionApi = new RetentionApi(client)

private val storageNamespaceCache =
CacheBuilder.newBuilder().expireAfterWrite(2, TimeUnit.MINUTES).build[String, String]()
Expand All @@ -50,13 +56,24 @@ class ApiClient(apiUrl: String, accessKey: String, secretKey: String) {
)
}

def prepareGarbageCollectionCommits(
repoName: String,
previousRunID: String
): GarbageCollectionPrepareResponse = {
retentionApi.prepareGarbageCollectionCommits(
repoName,
new GarbageCollectionPrepareRequest().previousRunId(previousRunID)
)
}

def getMetaRangeURL(repoName: String, commitID: String): String = {
val commit = commitsApi.getCommit(repoName, commitID)
val metaRangeID = commit.getMetaRangeId

val metaRange = metadataApi.getMetaRange(repoName, metaRangeID)
val location = metaRange.getLocation
URI.create(getStorageNamespace(repoName) + "/" + location).normalize().toString
if (metaRangeID != "") {
val metaRange = metadataApi.getMetaRange(repoName, metaRangeID)
val location = metaRange.getLocation
URI.create(getStorageNamespace(repoName) + "/").resolve(location).normalize().toString
} else ""
}

def getRangeURL(repoName: String, rangeID: String): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import io.treeverse.clients.LakeFSContext.{
LAKEFS_CONF_API_URL_KEY
}
import org.apache.hadoop.conf.Configuration
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{SparkSession, _}
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration
import software.amazon.awssdk.core.retry.RetryPolicy
Expand All @@ -28,7 +30,6 @@ object GarbageCollector {
.option("header", value = true)
.option("inferSchema", value = true)
.csv(commitDFLocation)
.where(col("run_id") === runID)
}

private def getRangeTuples(
Expand All @@ -38,13 +39,16 @@ object GarbageCollector {
): Set[(String, Array[Byte], Array[Byte])] = {
val location =
new ApiClient(conf.apiURL, conf.accessKey, conf.secretKey).getMetaRangeURL(repo, commitID)
SSTableReader
.forMetaRange(new Configuration(), location)
.newIterator()
.map(range =>
(new String(range.id), range.message.minKey.toByteArray, range.message.maxKey.toByteArray)
)
.toSet
// continue on empty location, empty location is a result of a commit with no metaRangeID (e.g 'Repository created' commit)
if (location == "") Set()
else
SSTableReader
.forMetaRange(new Configuration(), location)
.newIterator()
.map(range =>
(new String(range.id), range.message.minKey.toByteArray, range.message.maxKey.toByteArray)
)
.toSet
}

def getRangesDFFromCommits(
Expand All @@ -67,14 +71,18 @@ object GarbageCollector {
.distinct
}

def getRangeAddresses(rangeID: String, conf: APIConfigurations, repo: String): Set[String] = {
def getRangeAddresses(
rangeID: String,
conf: APIConfigurations,
repo: String
): Seq[String] = {
val location =
new ApiClient(conf.apiURL, conf.accessKey, conf.secretKey).getRangeURL(repo, rangeID)
SSTableReader
.forRange(new Configuration(), location)
.newIterator()
.map(a => new String(a.key))
.toSet
.map(a => a.message.address)
.toSeq
}

def getEntryTuples(
Expand Down Expand Up @@ -209,48 +217,84 @@ object GarbageCollector {
): Dataset[Row] = {
val commitsDF = getCommitsDF(runID, commitDFLocation, spark)
val rangesDF = getRangesDFFromCommits(commitsDF, repo, conf)
getExpiredEntriesFromRanges(rangesDF, conf, repo)
val expired = getExpiredEntriesFromRanges(rangesDF, conf, repo)

val activeRangesDF = rangesDF.where("!expired")
subtractDeduplications(expired, activeRangesDF, conf, repo, spark)
}

private def subtractDeduplications(
expired: Dataset[Row],
activeRangesDF: Dataset[Row],
conf: APIConfigurations,
repo: String,
spark: SparkSession
): Dataset[Row] = {
val activeRangesRDD: RDD[String] =
activeRangesDF.select("range_id").rdd.distinct().map(x => x.getString(0))
val activeAddresses: RDD[String] = activeRangesRDD
.flatMap(range => {
getRangeAddresses(range, conf, repo)
})
.distinct()
val activeAddressesRows: RDD[Row] = activeAddresses.map(x => Row(x))
val schema = new StructType().add(StructField("address", StringType, true))
val activeDF = spark.createDataFrame(activeAddressesRows, schema)
// remove active addresses from delete candidates
expired.join(
activeDF,
expired("address") === activeDF("address"),
"leftanti"
)
}

def main(args: Array[String]) {
val spark = SparkSession.builder().getOrCreate()
if (args.length != 4) {
if (args.length != 2) {
Console.err.println(
"Usage: ... <repo_name> <runID> s3://storageNamespace/prepared_commits_table s3://storageNamespace/output_destination_table"
"Usage: ... <repo_name> <region>"
)
System.exit(1)
}

val repo = args(0)
val runID = args(1)
val commitDFLocation = args(2)
val addressesDFLocation = args(3)

val region = args(1)
val previousRunID =
"" //args(2) // TODO(Guys): get previous runID from arguments or from storage
val hc = spark.sparkContext.hadoopConfiguration
val apiURL = hc.get(LAKEFS_CONF_API_URL_KEY)
val accessKey = hc.get(LAKEFS_CONF_API_ACCESS_KEY_KEY)
val secretKey = hc.get(LAKEFS_CONF_API_SECRET_KEY_KEY)
val res = new ApiClient(apiURL, accessKey, secretKey)
.prepareGarbageCollectionCommits(repo, previousRunID)
val runID = res.getRunId

val gcCommitsLocation = ApiClient.translateS3(new URI(res.getGcCommitsLocation)).toString
val gcAddressesLocation = ApiClient.translateS3(new URI(res.getGcAddressesLocation)).toString

val expiredAddresses = getExpiredAddresses(repo,
runID,
commitDFLocation,
gcCommitsLocation,
spark,
APIConfigurations(apiURL, accessKey, secretKey)
).withColumn("run_id", lit(runID))
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
expiredAddresses.write
.partitionBy("run_id")
.mode(SaveMode.Append)
.parquet(addressesDFLocation) // TODO(Guys): consider changing to overwrite
.mode(SaveMode.Overwrite)
.parquet(gcAddressesLocation)
S3BulkDeleter.remove(repo, gcAddressesLocation, runID, region, spark)
}
}

object S3BulkDeleter {
def repartitionBySize(df: DataFrame, maxSize: Int, column: String): DataFrame = {
private def repartitionBySize(df: DataFrame, maxSize: Int, column: String): DataFrame = {
val nRows = df.count()
val nPartitions = math.max(1, math.floor(nRows / maxSize)).toInt
df.repartitionByRange(nPartitions, col(column))
}

def delObjIteration(
private def delObjIteration(
bucket: String,
keys: Seq[String],
s3Client: S3Client,
Expand Down Expand Up @@ -297,21 +341,15 @@ object S3BulkDeleter {
})
}

def main(args: Array[String]): Unit = {
if (args.length != 5) {
Console.err.println(
"Usage: ... <repo_name> <runID> <region> s3://storageNamespace/prepared_addresses_table s3://storageNamespace/output_destination_table"
)
System.exit(1)
}
def remove(
repo: String,
addressDFLocation: String,
runID: String,
region: String,
spark: SparkSession
) = {
val MaxBulkSize = 1000
val awsRetries = 1000
val repo = args(0)
val runID = args(1)
val region = args(2)
val addressesDFLocation = args(3)
val deletedAddressesDFLocation = args(4)
val spark = SparkSession.builder().getOrCreate()

val hc = spark.sparkContext.hadoopConfiguration
val apiURL = hc.get(LAKEFS_CONF_API_URL_KEY)
Expand All @@ -326,16 +364,27 @@ object S3BulkDeleter {
if (addSuffixSlash.startsWith("/")) addSuffixSlash.substring(1) else addSuffixSlash

val df = spark.read
.parquet(addressesDFLocation)
.parquet(addressDFLocation)
.where(col("run_id") === runID)
.where(col("relative") === true)
val res =
bulkRemove(df, MaxBulkSize, spark, bucket, region, awsRetries, snPrefix).toDF("addresses")
res
.withColumn("run_id", lit(runID))
.write
.partitionBy("run_id")
.mode(SaveMode.Append)
.parquet(deletedAddressesDFLocation)
bulkRemove(df, MaxBulkSize, spark, bucket, region, awsRetries, snPrefix)
.toDF("addresses")
.collect()
}

def main(args: Array[String]): Unit = {
if (args.length != 4) {
Console.err.println(
"Usage: ... <repo_name> <runID> <region> s3://storageNamespace/prepared_addresses_table"
)
System.exit(1)
}
val repo = args(0)
val runID = args(1)
val region = args(2)
val addressesDFLocation = args(3)
val spark = SparkSession.builder().getOrCreate()
remove(repo, addressesDFLocation, runID, region, spark)
}
}

0 comments on commit 80405c0

Please sign in to comment.