Skip to content

Commit

Permalink
more better tests
Browse files Browse the repository at this point in the history
  • Loading branch information
THWiseman committed Sep 12, 2023
1 parent a05f63e commit c66ee6f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.DrsLocalizerMain.toUriType
import drs.localizer.DrsLocalizerMain.toValidatedUriType
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}

Expand Down Expand Up @@ -84,27 +84,50 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
}
}

def toUriType(accessUrl: Option[AccessUrl], gsUri: Option[String]): URIType = {
/**
* Helper function to decide which downloader to use based on data from the DRS response.
* Throws a runtime exception if the DRS response is invalid.
*/
def toValidatedUriType(accessUrl: Option[AccessUrl], gsUri: Option[String]): URIType = {
// if both are provided, prefer using access urls
(accessUrl, gsUri) match {
case (Some(_), _) =>
if(!accessUrl.get.url.startsWith("https://")) { throw new RuntimeException("Resolved Access URL does not start with https://")}
URIType.ACCESS
case (_, Some(_)) =>
if(!gsUri.get.startsWith("gs://")) { throw new RuntimeException("Resolved Google URL does not start with gs://")}
URIType.GCS
case (Some(_), _) =>
URIType.HTTPS
case (_, _) =>
URIType.UNKNOWN
throw new RuntimeException("DRS response did not contain any URLs")
}
}
}

object URIType extends Enumeration {
type URIType = Value
val GCS, HTTPS, UNKNOWN = Value
val GCS, ACCESS, UNKNOWN = Value
}

class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
downloaderFactory: DownloaderFactory,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

/**
 * Builds every downloader required for the requested DRS URLs (see buildDownloaders,
 * which resolves the URLs first) and then runs each downloader's `download`.
 *
 * The steps are composed as IO values instead of being forced with `unsafeRunSync()`,
 * so nothing executes until the caller runs the returned IO.
 * Every downloader is run, even when an earlier one reports a non-success result.
 *
 * @return DownloadSuccess if every download reports DownloadSuccess; otherwise the first
 *         non-success DownloadResult, in downloader order.
 */
def resolveAndDownload(): IO[DownloadResult] =
  for {
    downloaders <- buildDownloaders()
    // traverse over a List sequences the IOs, so downloads run one at a time in list order.
    results <- downloaders.traverse(_.download)
  } yield results.find(_ != DownloadSuccess).getOrElse(DownloadSuccess)

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
Expand All @@ -113,42 +136,41 @@ class DrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
}
}

/**
 * Resolves every DRS URL in the supplied list.
 *
 * A single resolver is obtained from getDrsPathResolver and shared by all URLs; each URL is
 * handed to resolveSingleUrl, and the results keep the order of the input list.
 *
 * @param unresolvedUrls effect producing the DRS URLs to resolve
 * @return effect producing one ResolvedDrsUrl per input URL
 */
def resolveUrls(unresolvedUrls: IO[List[UnresolvedDrsUrl]]): IO[List[ResolvedDrsUrl]] =
  for {
    unresolvedList <- unresolvedUrls
    resolver <- getDrsPathResolver
    // traverse(f) is the direct form of map(f) followed by traverse(identity).
    resolved <- unresolvedList.traverse(resolveSingleUrl(resolver, _))
  } yield resolved

/**
* Runs a synchronous HTTP request to resolve the provided DRS URL with the provided resolver.
*/
def resolveSingleUrl(resolverObject: DrsLocalizerDrsPathResolver, drsUrlToResolve: UnresolvedDrsUrl): IO[ResolvedDrsUrl] = {
IO {
val fields = NonEmptyList.of(DrsResolverField.GsUri, DrsResolverField.GoogleServiceAccount, DrsResolverField.AccessUrl, DrsResolverField.Hashes)
//Insert retry logic here.
val drsResponse = resolverObject.resolveDrs(drsUrlToResolve.drsUrl, fields).unsafeRunSync()
ResolvedDrsUrl(drsResponse, drsUrlToResolve.downloadDestinationPath, toUriType(drsResponse.accessUrl, drsResponse.gsUri))
ResolvedDrsUrl(drsResponse, drsUrlToResolve.downloadDestinationPath, toValidatedUriType(drsResponse.accessUrl, drsResponse.gsUri))
}
}



def resolveAndDownload() : IO[DownloadResult] = {
IO {
val downloaders: List[Downloader] = buildDownloaders().unsafeRunSync()
val results: List[DownloadResult] = downloaders.map(downloader => downloader.download.unsafeRunSync())
results.find(res => res!= DownloadSuccess).getOrElse(DownloadSuccess)
/**
 * Resolves all of the supplied DRS URLs.
 *
 * A single resolver is obtained from getDrsPathResolver and shared by all URLs; each URL is
 * handed to resolveSingleUrl, and the results keep the order of the input list.
 *
 * @param unresolvedUrls effect producing the DRS URLs to resolve
 * @return effect producing one ResolvedDrsUrl per input URL
 */
def resolveUrls(unresolvedUrls: IO[List[UnresolvedDrsUrl]]): IO[List[ResolvedDrsUrl]] =
  for {
    unresolvedList <- unresolvedUrls
    resolver <- getDrsPathResolver
    // traverse(f) is the direct form of map(f) followed by traverse(identity).
    resolved <- unresolvedList.traverse(resolveSingleUrl(resolver, _))
  } yield resolved

/**
* After resolving all of the URLs, this sorts them into an "Access" or "GCS" bucket.
* All access URLs will be downloaded as a batch with a single bulk downloader.
* All google URLs will be downloaded individually in their own google downloader.
* @return List of all downloaders required to fulfill the request.
*/
def buildDownloaders() : IO[List[Downloader]] = {
resolveUrls(toResolveAndDownload).flatMap { pendingDownloads =>
val accessUrls = pendingDownloads.filter(url => url.uriType == URIType.HTTPS)
val accessUrls = pendingDownloads.filter(url => url.uriType == URIType.ACCESS)
val googleUrls = pendingDownloads.filter(url => url.uriType == URIType.GCS)
pendingDownloads.filter(url => url.uriType == URIType.UNKNOWN).map(
unknown => logger.error(
s"Url is not an https:// or gs:// URL: " +
s"${unknown.drsResponse.accessUrl.getOrElse(unknown.drsResponse.gsUri.getOrElse("missing_url"))}"))
val bulkDownloader: Option[List[IO[Downloader]]] = if(accessUrls.isEmpty) None else Option(List(buildBulkAccessUrlDownloader(accessUrls)))
val googleDownloaders: Option[List[IO[Downloader]]] = if(googleUrls.isEmpty) None else Option(buildGoogleDownloaders(googleUrls))
val combined: List[IO[Downloader]] = googleDownloaders.map(list => list).getOrElse(List()) ++ bulkDownloader.map(list => list).getOrElse(List())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,10 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
val downloaders: List[Downloader] = mockdrsLocalizer.buildDownloaders().unsafeRunSync()
downloaders.length shouldBe 1

val correct = downloaders.head match {
case _: BulkAccessUrlDownloader => true
case _ => false
}
correct shouldBe true
val expected = BulkAccessUrlDownloader(
List(fakeAccessUrls.head._2)
)
expected shouldEqual downloaders.head
}

it should "build correct downloader(s) for multiple google URLs" in {
Expand Down Expand Up @@ -159,8 +158,12 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
case _: BulkAccessUrlDownloader => true
case _ => false
})
// We expect one GCS downloader for each GCS uri provided
// We expect one total Bulk downloader for all access URIs to share
countBulkDownloaders shouldBe 1
val expected = BulkAccessUrlDownloader(
fakeAccessUrls.map(pair => pair._2).toList
)
expected shouldEqual downloaders.head
}

it should "build 1 bulk downloader and 5 google downloaders for a mix of URLs" in {
Expand Down Expand Up @@ -198,45 +201,46 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
downloader shouldBe expected
}

it should "run successfully with all 3 arguments" in {
  // First fake Google entry: the key is the unresolved URL, the value is its canned resolution.
  val (unresolved, resolved) = fakeGoogleUrls.head
  val mockDrsLocalizer = new MockDrsLocalizerMain(
    IO(List(unresolved)),
    DrsLocalizerMain.defaultDownloaderFactory,
    FakeAccessTokenStrategy,
    Option(fakeRequesterPaysId)
  )
  val expected = GcsUriDownloader(
    gcsUrl = resolved.drsResponse.gsUri.get,
    serviceAccountJson = None,
    downloadLoc = unresolved.downloadDestinationPath,
    requesterPaysProjectIdOption = Option(fakeRequesterPaysId)
  )
  // Only the first downloader that was built is compared against the expectation.
  val builtDownloaders: List[Downloader] = mockDrsLocalizer.buildDownloaders().unsafeRunSync()
  builtDownloaders.head shouldBe expected
}
/*
it should "fail and throw error if the DRS Resolver response does not have gs:// url" in {
val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithoutAnyResolution, fakeDownloadLocation, None)
it should "run successfully with all 3 arguments" in {
  // First fake Google entry: the key is the unresolved URL, the value is its canned resolution.
  val (unresolved, resolved) = fakeGoogleUrls.head
  val mockDrsLocalizer = new MockDrsLocalizerMain(
    IO(List(unresolved)),
    DrsLocalizerMain.defaultDownloaderFactory,
    FakeAccessTokenStrategy,
    Option(fakeRequesterPaysId)
  )
  val expected = GcsUriDownloader(
    gcsUrl = resolved.drsResponse.gsUri.get,
    serviceAccountJson = None,
    downloadLoc = unresolved.downloadDestinationPath,
    requesterPaysProjectIdOption = Option(fakeRequesterPaysId)
  )
  // Only the first downloader that was built is compared against the expectation.
  val builtDownloaders: List[Downloader] = mockDrsLocalizer.buildDownloaders().unsafeRunSync()
  builtDownloaders.head shouldBe expected
}

the[RuntimeException] thrownBy {
mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync()
} should have message "No access URL nor GCS URI starting with 'gs://' found in the DRS Resolver response!"
}
it should "successfully identify uri types, preferring access" in {
  // Classifies a response from its accessUrl and gsUri fields.
  def uriTypeOf(response: DrsResolverResponse) =
    DrsLocalizerMain.toValidatedUriType(response.accessUrl, response.gsUri)

  val accessOnly = DrsResolverResponse(accessUrl = Option(AccessUrl("https://something.com", FakeHashes)))
  val googleOnly = DrsResolverResponse(gsUri = Option("gs://something"))
  val accessAndGoogle = DrsResolverResponse(
    accessUrl = Option(AccessUrl("https://something.com", FakeHashes)),
    gsUri = Option("gs://something")
  )

  uriTypeOf(accessOnly) shouldBe URIType.ACCESS
  uriTypeOf(googleOnly) shouldBe URIType.GCS
  // When both kinds of URL are present, the access URL wins.
  uriTypeOf(accessAndGoogle) shouldBe URIType.ACCESS
}

it should "resolve to use the correct downloader for an access url" in {
val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlResolutionOnly, fakeDownloadLocation, None)
val expected = AccessUrlDownloader(
accessUrl = AccessUrl(url = "http://abc/def/ghi.bam", headers = None),
downloadLoc = fakeDownloadLocation,
hashes = FakeHashes
)
mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected
}
it should "throw an exception if the DRS Resolver response is invalid" in {
val badAccessResponse = DrsResolverResponse(accessUrl = Option(AccessUrl("hQQps://something.com", FakeHashes)))
val badGoogleResponse = DrsResolverResponse(gsUri = Option("gQQs://something"))
val emptyResponse = DrsResolverResponse()

it should "resolve to use the correct downloader for an access url when the DRS Resolver response also contains a gs url" in {
val mockDrsLocalizer = new MockDrsLocalizerMain(MockDrsPaths.fakeDrsUrlWithAccessUrlAndGcsResolution, fakeDownloadLocation, None)
val expected = AccessUrlDownloader(
accessUrl = AccessUrl(url = "http://abc/def/ghi.bam", headers = None), downloadLoc = fakeDownloadLocation,
hashes = FakeHashes
)
mockDrsLocalizer.resolve(DrsLocalizerMain.defaultDownloaderFactory).unsafeRunSync() shouldBe expected
}
the[RuntimeException] thrownBy {
DrsLocalizerMain.toValidatedUriType(badAccessResponse.accessUrl, badAccessResponse.gsUri)
} should have message "Resolved Access URL does not start with https://"

the[RuntimeException] thrownBy {
DrsLocalizerMain.toValidatedUriType(badGoogleResponse.accessUrl, badGoogleResponse.gsUri)
} should have message "Resolved Google URL does not start with gs://"

the[RuntimeException] thrownBy {
DrsLocalizerMain.toValidatedUriType(emptyResponse.accessUrl, emptyResponse.gsUri)
} should have message "DRS response did not contain any URLs"
}

/*
it should "not retry on access URL download success" in {
var actualAttempts = 0
Expand Down Expand Up @@ -397,11 +401,6 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
})
val downloaderFactory = new DownloaderFactory {
override def buildAccessUrlDownloader(accessUrl: AccessUrl, downloadLoc: String, hashes: Hashes): IO[Downloader] = {
// This test path should never ask for the access URL downloader
throw new RuntimeException("test failure")
}
override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
gcsUriDownloader
}
Expand Down Expand Up @@ -431,10 +430,6 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
super.resolveAndDownload(downloaderFactory)
}
}
val accessUrlDownloader = IO.pure(new Downloader {
override def download: IO[DownloadResult] =
IO.pure(ChecksumFailure)
})
val downloaderFactory = new DownloaderFactory {
override def buildGcsUriDownloader(gcsPath: String, serviceAccountJsonOption: Option[String], downloadLoc: String, requesterPaysProjectOption: Option[String]): IO[Downloader] = {
Expand All @@ -455,7 +450,6 @@ class DrsLocalizerMainSpec extends AnyFlatSpec with CromwellTimeoutSpec with Mat
downloaderFactory = downloaderFactory,
backoff = None).unsafeRunSync()
}
actualAttempts shouldBe 2 // 1 initial attempt + 1 retry = 2 total attempts
}
*/
Expand All @@ -476,11 +470,11 @@ object MockDrsPaths {
)

val fakeAccessUrls: Map[UnresolvedDrsUrl, ResolvedDrsUrl] = Map(
(UnresolvedDrsUrl("drs://abc/foo-123/access/0", "/path/to/access/local0"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/0", FakeHashes))), "/path/to/access/local0", URIType.HTTPS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/1", "/path/to/access/local1"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/1", FakeHashes))), "/path/to/access/local1", URIType.HTTPS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/2", "/path/to/access/local2"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/2", FakeHashes))), "/path/to/access/local2", URIType.HTTPS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/3", "/path/to/access/local3"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/3", FakeHashes))), "/path/to/access/local3", URIType.HTTPS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/4", "/path/to/access/local4"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/4", FakeHashes))), "/path/to/access/local4", URIType.HTTPS))
(UnresolvedDrsUrl("drs://abc/foo-123/access/0", "/path/to/access/local0"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/0", FakeHashes))), "/path/to/access/local0", URIType.ACCESS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/1", "/path/to/access/local1"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/1", FakeHashes))), "/path/to/access/local1", URIType.ACCESS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/2", "/path/to/access/local2"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/2", FakeHashes))), "/path/to/access/local2", URIType.ACCESS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/3", "/path/to/access/local3"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/3", FakeHashes))), "/path/to/access/local3", URIType.ACCESS)),
(UnresolvedDrsUrl("drs://abc/foo-123/access/4", "/path/to/access/local4"), ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://abc/foo-123/access/4", FakeHashes))), "/path/to/access/local4", URIType.ACCESS))
)
}

Expand Down Expand Up @@ -508,7 +502,6 @@ class MockDrsLocalizerMain(toResolveAndDownload: IO[List[UnresolvedDrsUrl]],
}
}


class MockDrsLocalizerDrsPathResolver(drsConfig: DrsConfig) extends
DrsLocalizerDrsPathResolver(drsConfig, FakeAccessTokenStrategy) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import org.scalatest.matchers.should.Matchers
import java.nio.file.Path

class BulkAccessUrlDownloaderSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matchers {
val ex1 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url123", None))), "path/to/local/download/dest", URIType.HTTPS)
val ex2 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1234", None))), "path/to/local/download/dest2", URIType.HTTPS)
val ex3 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1235", None))), "path/to/local/download/dest3", URIType.HTTPS)
val ex1 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url123", None))), "path/to/local/download/dest", URIType.ACCESS)
val ex2 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1234", None))), "path/to/local/download/dest2", URIType.ACCESS)
val ex3 = ResolvedDrsUrl(DrsResolverResponse(accessUrl = Option(AccessUrl("https://my.fake/url1235", None))), "path/to/local/download/dest3", URIType.ACCESS)
val emptyList : List[ResolvedDrsUrl] = List()
val oneElement: List[ResolvedDrsUrl] = List(ex1)
val threeElements: List[ResolvedDrsUrl] = List(ex1, ex2, ex3)
Expand Down

0 comments on commit c66ee6f

Please sign in to comment.