Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BT-711 Refresh SAS token for filesystem on expiry #6831

Merged
merged 18 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package cromwell.filesystems.blob

import com.azure.core.credential.AzureSasCredential
import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.resourcemanager.AzureResourceManager
import com.azure.storage.blob.nio.AzureFileSystem
import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSignatureValues}
import com.azure.storage.blob.{BlobContainerClient, BlobContainerClientBuilder}
import com.azure.storage.common.StorageSharedKeyCredential

import java.net.URI
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
import java.time.temporal.ChronoUnit
import java.time.{Duration, Instant, OffsetDateTime}
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}
import com.azure.resourcemanager.storage.models.StorageAccountKey

/**
  * Thin wrapper around the static `java.nio.file.FileSystems` entry points so that
  * filesystem lookup, creation and closing can be substituted (e.g. in tests).
  */
case class FileSystemAPI() {

  /** Looks up an already-open filesystem for `uri`; any lookup exception is captured in the `Try`. */
  def getFileSystem(uri: URI): Try[FileSystem] = Try {
    FileSystems.getFileSystem(uri)
  }

  /** Opens a new filesystem for `uri`, handing `config` to the provider as a Java map. */
  def newFileSystem(uri: URI, config: Map[String, Object]): FileSystem = {
    val javaConfig = config.asJava
    FileSystems.newFileSystem(uri, javaConfig)
  }

  /**
    * Closes the filesystem registered for `uri` if one can be looked up.
    *
    * @return `Some(())` when a filesystem was found and closed, `None` when the lookup failed.
    *         Exceptions thrown by `close` itself are not caught here.
    */
  def closeFileSystem(uri: URI): Option[Unit] =
    getFileSystem(uri) match {
      case Success(openFileSystem) => Some(openFileSystem.close())
      case Failure(_) => None
    }
}

/** Stateless helpers used by the [[BlobFileSystemManager]] class for SAS-token and filesystem handling. */
object BlobFileSystemManager {

  /**
    * Extracts the signed-expiry (`se`) parameter from a SAS token and parses it as an [[Instant]].
    *
    * The parameter is matched on the exact key `se=` so that other parameters whose names merely
    * start with "se" (for example `ses=`) are never mistaken for the expiry. A leading `?` on the
    * token is tolerated, and the percent-encoded colon is decoded regardless of hex-digit case.
    *
    * @param token the SAS credential whose signature string is inspected
    * @return the expiry instant, or `None` when no `se=` parameter is present or its value
    *         is not a parsable ISO-8601 instant
    */
  def parseTokenExpiry(token: AzureSasCredential): Option[Instant] = for {
    expiryString <- token.getSignature
      .stripPrefix("?")
      .split("&")
      .find(_.startsWith("se="))
      .map(_.stripPrefix("se=").replaceAll("(?i)%3A", ":"))
    // Instant.parse throws on malformed input; this method's contract is Option, so absorb it here
    instant <- Try(Instant.parse(expiryString)).toOption
  } yield instant

  /**
    * Builds the configuration map passed when opening a new Azure NIO filesystem:
    * the SAS credential, the container to expose as a file store, and a flag that
    * skips the initial container check.
    */
  def buildConfigMap(credential: AzureSasCredential, container: BlobContainerName): Map[String, Object] = {
    Map(
      (AzureFileSystem.AZURE_STORAGE_SAS_TOKEN_CREDENTIAL, credential),
      (AzureFileSystem.AZURE_STORAGE_FILE_STORES, container.value),
      (AzureFileSystem.AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK, java.lang.Boolean.TRUE)
    )
  }

  /** True when `tokenExpiry` falls before "now + buffer", i.e. the token is expired or about to be. */
  def hasTokenExpired(tokenExpiry: Instant, buffer: Duration): Boolean =
    Instant.now.plus(buffer).isAfter(tokenExpiry)

  /**
    * Builds the `azb://` filesystem URI for `endpoint`.
    * Uses the wrapped string explicitly rather than relying on the wrapper's `toString`.
    */
  def uri(endpoint: EndpointURL): URI = new URI(s"azb://?endpoint=${endpoint.value}")
}
/**
  * Owns the lifecycle of the blob filesystem for one container/endpoint pair, reopening it
  * with a freshly generated SAS token once the recorded token expiry is (nearly) reached.
  *
  * @param container           container exposed through the filesystem
  * @param endpoint            storage endpoint the filesystem URI is derived from
  * @param expiryBufferMinutes how many minutes before the actual expiry a token is treated as expired
  * @param blobTokenGenerator  source of new SAS tokens
  * @param fileSystemAPI       indirection over `java.nio.file.FileSystems`
  * @param initialExpiration   expiry to assume before any token has been generated by this instance
  */
case class BlobFileSystemManager(
    container: BlobContainerName,
    endpoint: EndpointURL,
    expiryBufferMinutes: Long,
    blobTokenGenerator: BlobTokenGenerator,
    fileSystemAPI: FileSystemAPI = FileSystemAPI(),
    private val initialExpiration: Option[Instant] = None) {

  // Expiry of the most recently parsed token; written in generateFilesystem
  private var expiry: Option[Instant] = initialExpiration
  val buffer: Duration = Duration.of(expiryBufferMinutes, ChronoUnit.MINUTES)

  def getExpiry: Option[Instant] = expiry
  def uri: URI = BlobFileSystemManager.uri(endpoint)
  def isTokenExpired: Boolean = expiry.exists(BlobFileSystemManager.hasTokenExpired(_, buffer))
  def shouldReopenFilesystem: Boolean = isTokenExpired || expiry.isEmpty

  /**
    * Returns a usable filesystem, reopening it with a new token when needed.
    *
    * The whole check-then-open sequence runs under this instance's monitor.
    *
    * @return the filesystem, or a `Failure` describing why it could not be obtained; a failure
    *         while closing the stale filesystem is also reported as a `Failure` rather than thrown
    */
  def retrieveFilesystem(): Try[FileSystem] = synchronized {
    if (shouldReopenFilesystem) {
      // If the token has expired, OR there is no token record, try to close the FS and regenerate
      Try(fileSystemAPI.closeFileSystem(uri)).flatMap(_ => openWithFreshToken())
    } else {
      fileSystemAPI.getFileSystem(uri).recoverWith {
        // If no filesystem already exists, this will create a new connection, with the provided configs
        case _: FileSystemNotFoundException => openWithFreshToken()
      }
    }
  }

  /** Requests a new SAS token and opens a filesystem with it. */
  private def openWithFreshToken(): Try[FileSystem] =
    blobTokenGenerator.generateAccessToken.flatMap(generateFilesystem(uri, container, _))

  /**
    * Records the expiry parsed from `token` and opens a new filesystem configured with it.
    * Fails without opening anything when the token carries no parsable expiry.
    */
  private def generateFilesystem(uri: URI, container: BlobContainerName, token: AzureSasCredential): Try[FileSystem] = {
    expiry = BlobFileSystemManager.parseTokenExpiry(token)
    expiry match {
      case None => Failure(new Exception("Could not reopen filesystem, no expiration found"))
      case Some(_) => Try(fileSystemAPI.newFileSystem(uri, BlobFileSystemManager.buildConfigMap(token, container)))
    }
  }

}

/** A source of SAS credentials for blob storage; implementations report failure through the returned `Try`. */
sealed trait BlobTokenGenerator {
  def generateAccessToken: Try[AzureSasCredential]
}
/** Factory overloads that pick a [[BlobTokenGenerator]] implementation from the arguments supplied. */
object BlobTokenGenerator {

  /** No workspace details and no subscription: delegates to the subscription-taking overload with `None`. */
  def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL): BlobTokenGenerator =
    createBlobTokenGenerator(container, endpoint, None)

  /** No workspace details: delegates with both workspace arguments absent. */
  def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, subscription: Option[String]): BlobTokenGenerator =
    createBlobTokenGenerator(container, endpoint, None, None, subscription)

  /** Workspace details without a subscription: delegates with `subscription = None`. */
  def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, workspaceId: Option[WorkspaceId], workspaceManagerURL: Option[WorkspaceManagerURL]): BlobTokenGenerator =
    createBlobTokenGenerator(container, endpoint, workspaceId, workspaceManagerURL, None)

  /**
    * Selects the implementation:
    *  - neither workspace argument present: [[NativeBlobTokenGenerator]]
    *  - both workspace arguments present: [[WSMBlobTokenGenerator]]
    *  - exactly one present: throws, since no implementation matches
    */
  def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, workspaceId: Option[WorkspaceId], workspaceManagerURL: Option[WorkspaceManagerURL], subscription: Option[String]): BlobTokenGenerator =
    (workspaceId, workspaceManagerURL) match {
      case (None, None) =>
        NativeBlobTokenGenerator(container, endpoint, subscription)
      case (Some(id), Some(managerUrl)) =>
        WSMBlobTokenGenerator(container, endpoint, id, managerUrl)
      case _ =>
        throw new Exception("Arguments provided do not match any available BlobTokenGenerator implementation.")
    }

}

/** Placeholder generator for workspace-manager-issued tokens; token generation currently always fails. */
case class WSMBlobTokenGenerator(
    container: BlobContainerName,
    endpoint: EndpointURL,
    workspaceId: WorkspaceId,
    workspaceManagerURL: WorkspaceManagerURL
) extends BlobTokenGenerator {

  /** Always a `Failure` wrapping a `NotImplementedError`. */
  def generateAccessToken: Try[AzureSasCredential] = {
    val notImplemented = new NotImplementedError
    Failure(notImplemented)
  }
}

/**
  * Generates container-level SAS tokens by looking up the storage account through the
  * Azure resource manager and signing with the account's first key.
  *
  * @param container    container the SAS token is generated for
  * @param endpoint     storage endpoint; its host supplies the storage account name
  * @param subscription subscription to authenticate against; the default subscription is used when absent
  */
case class NativeBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, subscription: Option[String] = None) extends BlobTokenGenerator {

  private val azureProfile = new AzureProfile(AzureEnvironment.AZURE)

  // A new credential is built on every access (this is a def, not a val)
  private def azureCredentialBuilder =
    new DefaultAzureCredentialBuilder()
      .authorityHost(azureProfile.getEnvironment.getActiveDirectoryEndpoint)
      .build

  private def authenticateWithSubscription(sub: String) =
    AzureResourceManager.authenticate(azureCredentialBuilder, azureProfile).withSubscription(sub)

  private def authenticateWithDefaultSubscription =
    AzureResourceManager.authenticate(azureCredentialBuilder, azureProfile).withDefaultSubscription()

  // Authenticates against the configured subscription when given, otherwise the default one
  private def azure = subscription match {
    case Some(sub) => authenticateWithSubscription(sub)
    case None => authenticateWithDefaultSubscription
  }

  // Finds the storage account whose name equals `name.value` among the listed accounts
  private def findAzureStorageAccount(name: StorageAccountName) = {
    val matchingAccount = azure.storageAccounts.list.asScala.find(_.name.equals(name.value))
    matchingAccount match {
      case Some(account) => Success(account)
      case None => Failure(new Exception("Azure Storage Account not found"))
    }
  }

  private def buildBlobContainerClient(credential: StorageSharedKeyCredential, endpoint: EndpointURL, container: BlobContainerName): BlobContainerClient = {
    val clientBuilder = new BlobContainerClientBuilder()
      .credential(credential)
      .endpoint(endpoint.value)
      .containerName(container.value)
    clientBuilder.buildClient()
  }

  // Permissions granted by generated tokens: read, create and list
  private val sasPermissions = new BlobContainerSasPermission()
    .setReadPermission(true)
    .setCreatePermission(true)
    .setListPermission(true)

  /**
    * Produces a SAS credential for the container, expiring one day from now.
    *
    * Fails when the endpoint cannot be parsed, the storage account cannot be found,
    * the account has no keys, or any Azure call throws.
    */
  def generateAccessToken: Try[AzureSasCredential] = for {
    endpointUri <- BlobPathBuilder.parseURI(endpoint.value)
    accountName <- BlobPathBuilder.parseStorageAccount(endpointUri)
    storageAccount <- findAzureStorageAccount(accountName)
    accountKeys = storageAccount.getKeys.asScala
    firstKey <- accountKeys.headOption.fold[Try[StorageAccountKey]](Failure(new Exception("Storage account has no keys")))(Success(_))
    sharedKeyCredential = new StorageSharedKeyCredential(accountName.value, firstKey.value)
    containerClient = buildBlobContainerClient(sharedKeyCredential, endpoint, container)
    signatureValues = new BlobServiceSasSignatureValues(OffsetDateTime.now.plusDays(1), sasPermissions)
  } yield new AzureSasCredential(containerClient.generateSas(signatureValues))
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
package cromwell.filesystems.blob

import com.azure.core.credential.AzureSasCredential
import com.azure.storage.blob.nio.AzureFileSystem
import com.google.common.net.UrlEscapers
import cromwell.core.path.{NioPath, Path, PathBuilder}
import cromwell.filesystems.blob.BlobPathBuilder._

import java.net.{MalformedURLException, URI}
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
import scala.jdk.CollectionConverters._
import scala.language.postfixOps
import scala.util.{Failure, Try}
import scala.util.{Failure, Success, Try}

object BlobPathBuilder {

sealed trait BlobPathValidation
case class ValidBlobPath(path: String) extends BlobPathValidation
case class UnparsableBlobPath(errorMessage: Throwable) extends BlobPathValidation

def invalidBlobPathMessage(container: String, endpoint: String) = s"Malformed Blob URL for this builder. Expecting a URL for a container $container and endpoint $endpoint"
def parseURI(string: String) = URI.create(UrlEscapers.urlFragmentEscaper().escape(string))
def parseStorageAccount(uri: URI) = uri.getHost().split("\\.").filter(!_.isEmpty()).headOption
def invalidBlobPathMessage(container: BlobContainerName, endpoint: EndpointURL) = s"Malformed Blob URL for this builder. Expecting a URL for a container $container and endpoint $endpoint"
def parseURI(string: String): Try[URI] = Try(URI.create(UrlEscapers.urlFragmentEscaper().escape(string)))
def parseStorageAccount(uri: URI): Try[StorageAccountName] = uri.getHost.split("\\.").find(_.nonEmpty).map(StorageAccountName(_))
.map(Success(_)).getOrElse(Failure(new Exception("Could not parse storage account")))

/**
* Validates that a path from a string is a valid BlobPath of the format:
Expand All @@ -40,54 +37,48 @@ object BlobPathBuilder {
*
* If the configured container and storage account do not match, the string is considered unparsable
*/
def validateBlobPath(string: String, container: String, endpoint: String): BlobPathValidation = {
Try {
val uri = parseURI(string)
val storageAccount = parseStorageAccount(parseURI(endpoint))
val hasContainer = uri.getPath().split("/").filter(!_.isEmpty()).headOption.contains(container)
def hasEndpoint = parseStorageAccount(uri).contains(storageAccount.get)
if (hasContainer && !storageAccount.isEmpty && hasEndpoint) {
ValidBlobPath(uri.getPath.replaceFirst("/" + container, ""))
} else {
UnparsableBlobPath(new MalformedURLException(invalidBlobPathMessage(container, endpoint)))
def validateBlobPath(string: String, container: BlobContainerName, endpoint: EndpointURL): BlobPathValidation = {
val blobValidation = for {
testUri <- parseURI(string)
endpointUri <- parseURI(endpoint.value)
testStorageAccount <- parseStorageAccount(testUri)
endpointStorageAccount <- parseStorageAccount(endpointUri)
hasContainer = testUri.getPath.split("/").find(_.nonEmpty).contains(container.value)
hasEndpoint = testStorageAccount.equals(endpointStorageAccount)
blobPathValidation = (hasContainer && hasEndpoint) match {
case true => ValidBlobPath(testUri.getPath.replaceFirst("/" + container, ""))
case false => UnparsableBlobPath(new MalformedURLException(invalidBlobPathMessage(container, endpoint)))
}
} recover { case t => UnparsableBlobPath(t) } get
} yield blobPathValidation
blobValidation recover { case t => UnparsableBlobPath(t) } get
}
}

class BlobPathBuilder(blobTokenGenerator: BlobTokenGenerator, container: String, endpoint: String) extends PathBuilder {

val credential: AzureSasCredential = new AzureSasCredential(blobTokenGenerator.getAccessToken)
val fileSystemConfig: Map[String, Object] = Map((AzureFileSystem.AZURE_STORAGE_SAS_TOKEN_CREDENTIAL, credential),
(AzureFileSystem.AZURE_STORAGE_FILE_STORES, container),
(AzureFileSystem.AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK, java.lang.Boolean.TRUE))

def retrieveFilesystem(uri: URI): Try[FileSystem] = {
Try(FileSystems.getFileSystem(uri)) recover {
// If no filesystem already exists, this will create a new connection, with the provided configs
case _: FileSystemNotFoundException => FileSystems.newFileSystem(uri, fileSystemConfig.asJava)
}
}
class BlobPathBuilder(container: BlobContainerName, endpoint: EndpointURL)(private val fsm: BlobFileSystemManager) extends PathBuilder {

def build(string: String): Try[BlobPath] = {
validateBlobPath(string, container, endpoint) match {
case ValidBlobPath(path) => for {
fileSystem <- retrieveFilesystem(new URI("azb://?endpoint=" + endpoint))
nioPath <- Try(fileSystem.getPath(path))
blobPath = BlobPath(nioPath, endpoint, container)
} yield blobPath
case ValidBlobPath(path) => Try(BlobPath(path, endpoint, container)(fsm))
case UnparsableBlobPath(errorMessage: Throwable) => Failure(errorMessage)
}
}

override def name: String = "Azure Blob Storage"
}

// Add args for container, storage account name
case class BlobPath private[blob](nioPath: NioPath, endpoint: String, container: String) extends Path {
override protected def newPath(nioPath: NioPath): Path = BlobPath(nioPath, endpoint, container)
case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, container: BlobContainerName)(private val fsm: BlobFileSystemManager) extends Path {
override def nioPath: NioPath = findNioPath(pathString)

override protected def newPath(nioPath: NioPath): Path = BlobPath(nioPath.toString, endpoint, container)(fsm)

override def pathAsString: String = List(endpoint, container, nioPath.toString).mkString("/")

override def pathAsString: String = List(endpoint, container, nioPath.toString()).mkString("/")
//This is purposefully an unprotected get because if the endpoint cannot be parsed this should fail loudly rather than quietly
override def pathWithoutScheme: String = parseURI(endpoint.value).map(_.getHost + "/" + container + "/" + nioPath.toString).get

override def pathWithoutScheme: String = parseURI(endpoint).getHost + "/" + container + "/" + nioPath.toString()
// Resolves `path` against the filesystem returned by `fsm.retrieveFilesystem()`.
// A failure to retrieve the filesystem, or an exception from `getPath`, is rethrown by the trailing `.get`.
private def findNioPath(path: String): NioPath = (for {
fileSystem <- fsm.retrieveFilesystem()
nioPath = fileSystem.getPath(path)
// This is purposefully an unprotected get because the NIO API needs an unwrapped path object.
// If an error occurs, the API expects a thrown exception.
} yield nioPath).get
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've discussed begrudgingly accepting that this method will throw. We should ensure it throws something useful, though. I think we should throw different informative error messages depending on whether we failed to get the filesystem or failed to create the NIO path.

}
Loading