Skip to content

Commit

Permalink
More efficient database unmarshaling (#775)
Browse files Browse the repository at this point in the history
* More efficient database unmarshaling

* Update workbenchGoogle to the latest
  • Loading branch information
rtitle authored and Qi77Qi committed Jun 21, 2019
1 parent 9fc1c8e commit 0216110
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 17 deletions.
7 changes: 4 additions & 3 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ object Dependencies {
val scalaLoggingV = "3.9.0"
val scalaTestV = "3.0.5"
val slickV = "3.2.3"
val catsV = "1.3.1"

val workbenchUtilV = "0.3-0e9d080"
val workbenchUtilV = "0.5-6942040"
val workbenchModelV = "0.11-2bddd5b"
val workbenchGoogleV = "0.16-4fe117d"
val workbenchGoogleV = "0.18-6942040"
val workbenchMetricsV = "0.3-c5b80d2"

val samV = "1.0-5cdffb4"
Expand Down Expand Up @@ -48,7 +49,7 @@ object Dependencies {
val scalaLogging: ModuleID = "com.typesafe.scala-logging" %% "scala-logging" % scalaLoggingV
val swaggerUi: ModuleID = "org.webjars" % "swagger-ui" % "2.2.5"
val ficus: ModuleID = "com.iheart" %% "ficus" % "1.4.3"
val cats: ModuleID = "org.typelevel" %% "cats" % "0.9.0"
val cats: ModuleID = "org.typelevel" %% "cats-core" % catsV
val httpClient: ModuleID = "org.apache.httpcomponents" % "httpclient" % "4.5.5" // upgrading a transitive dependency to avoid security warnings
val enumeratum: ModuleID = "com.beachape" %% "enumeratum" % "1.5.13"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class HttpGoogleDataprocDAO(appName: String,
override def getClusterStatus(googleProject: GoogleProject, clusterName: ClusterName): Future[ClusterStatus] = {
val transformed = for {
cluster <- OptionT(getCluster(googleProject, clusterName))
status <- OptionT.pure[Future, ClusterStatus](
status <- OptionT.pure[Future](
Try(ClusterStatus.withNameInsensitive(cluster.getStatus.getState)).toOption.getOrElse(ClusterStatus.Unknown))
} yield status

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import java.time.Instant
import java.sql.Timestamp
import java.util.UUID

import cats.data.Chain
import cats.implicits._
import org.broadinstitute.dsde.workbench.leonardo.model.Cluster.LabelMap
import org.broadinstitute.dsde.workbench.leonardo.model._
Expand Down Expand Up @@ -193,6 +194,19 @@ trait ClusterComponent extends LeoComponent {
}
}

def getActiveClusterForDnsCache(project: GoogleProject, name: ClusterName): DBIO[Option[Cluster]] = {
clusterQuery
.filter { _.googleProject === project.value }
.filter { _.clusterName === name.value }
.filter { _.destroyedDate === Timestamp.from(dummyDate) }
.result
.map { recs =>
recs.headOption.map { clusterRec =>
unmarshalCluster(clusterRec, Seq.empty, List.empty, Map.empty, List.empty, List.empty, List.empty)
}
}
}

def getClusterById(id: Long): DBIO[Option[Cluster]] = {
fullClusterQuery.filter { _._1.id === id }.result map { recs =>
unmarshalFullCluster(recs).headOption
Expand Down Expand Up @@ -404,32 +418,36 @@ trait ClusterComponent extends LeoComponent {

private def unmarshalMinimalCluster(clusterLabels: Seq[(ClusterRecord, Option[LabelRecord])]): Seq[Cluster] = {
// Call foldMap to aggregate a Seq[(ClusterRecord, LabelRecord)] returned by the query to a Map[ClusterRecord, Map[labelKey, labelValue]].
val clusterLabelMap: Map[ClusterRecord, Map[String, List[String]]] = clusterLabels.toList.foldMap { case (clusterRecord, labelRecordOpt) =>
val labelMap = labelRecordOpt.map(labelRecordOpt => labelRecordOpt.key -> List(labelRecordOpt.value)).toMap
// Note we use Chain instead of List inside the foldMap because the Chain monoid is much more efficient than the List monoid.
// See: https://typelevel.org/cats/datatypes/chain.html
val clusterLabelMap: Map[ClusterRecord, Map[String, Chain[String]]] = clusterLabels.toList.foldMap { case (clusterRecord, labelRecordOpt) =>
val labelMap = labelRecordOpt.map(labelRecord => labelRecord.key -> Chain(labelRecord.value)).toMap
Map(clusterRecord -> labelMap)
}

// Unmarshal each (ClusterRecord, Map[labelKey, labelValue]) to a Cluster object
clusterLabelMap.map { case (clusterRec, labelMap) =>
unmarshalCluster(clusterRec, Seq.empty, List.empty, labelMap.mapValues(_.toSet.head), List.empty, List.empty, List.empty)
unmarshalCluster(clusterRec, Seq.empty, List.empty, labelMap.mapValues(_.toList.toSet.head), List.empty, List.empty, List.empty)
}.toSeq
}

private def unmarshalFullCluster(clusterRecords: Seq[(ClusterRecord, Option[InstanceRecord], Option[ClusterErrorRecord], Option[LabelRecord], Option[ExtensionRecord], Option[ClusterImageRecord], Option[ScopeRecord])]): Seq[Cluster] = {
// Call foldMap to aggregate a flat sequence of (cluster, instance, label) triples returned by the query
// to a grouped (cluster -> (instances, labels)) structure.
val clusterRecordMap: Map[ClusterRecord, (List[InstanceRecord], List[ClusterErrorRecord], Map[String, List[String]], List[ExtensionRecord], List[ClusterImageRecord], List[ScopeRecord])] = clusterRecords.toList.foldMap { case (clusterRecord, instanceRecordOpt, errorRecordOpt, labelRecordOpt, extensionOpt, clusterImageOpt, scopeOpt) =>
// Note we use Chain instead of List inside the foldMap because the Chain monoid is much more efficient than the List monoid.
// See: https://typelevel.org/cats/datatypes/chain.html
val clusterRecordMap: Map[ClusterRecord, (Chain[InstanceRecord], Chain[ClusterErrorRecord], Map[String, Chain[String]], Chain[ExtensionRecord], Chain[ClusterImageRecord], Chain[ScopeRecord])] = clusterRecords.toList.foldMap { case (clusterRecord, instanceRecordOpt, errorRecordOpt, labelRecordOpt, extensionOpt, clusterImageOpt, scopeOpt) =>
val instanceList = instanceRecordOpt.toList
val labelMap = labelRecordOpt.map(labelRecordOpt => labelRecordOpt.key -> List(labelRecordOpt.value)).toMap
val labelMap = labelRecordOpt.map(labelRecordOpt => labelRecordOpt.key -> Chain(labelRecordOpt.value)).toMap
val errorList = errorRecordOpt.toList
val extList = extensionOpt.toList
val clusterImageList = clusterImageOpt.toList
val scopeList = scopeOpt.toList
Map(clusterRecord -> (instanceList, errorList, labelMap, extList, clusterImageList, scopeList))
Map(clusterRecord -> (Chain.fromSeq(instanceList), Chain.fromSeq(errorList), labelMap, Chain.fromSeq(extList), Chain.fromSeq(clusterImageList), Chain.fromSeq(scopeList)))
}

clusterRecordMap.map { case (clusterRecord, (instanceRecords, errorRecords, labels, extensions, clusterImages, scopes)) =>
unmarshalCluster(clusterRecord, instanceRecords.toSet.toSeq, errorRecords.groupBy(_.timestamp).map(_._2.head).toList, labels.mapValues(_.toSet.head), extensions, clusterImages.toSet.toList, scopes)
unmarshalCluster(clusterRecord, instanceRecords.toList, errorRecords.toList.groupBy(_.timestamp).map(_._2.head).toList, labels.mapValues(_.toList.toSet.head), extensions.toList, clusterImages.toList, scopes.toList)
}.toSeq
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ClusterDnsCache(proxyConfig: ProxyConfig, dbRef: DbReference, dnsCacheConf
def load(key: DnsCacheKey) = {
logger.debug(s"DNS Cache miss for ${key.clusterName} / ${key.clusterName}...loading from DB...")
dbRef
.inTransaction { _.clusterQuery.getActiveClusterByName(key.googleProject, key.clusterName) }
.inTransaction { _.clusterQuery.getActiveClusterForDnsCache(key.googleProject, key.clusterName) }
.map {
case Some(cluster) => getHostStatusAndUpdateHostToIpIfHostReady(cluster)
case None => HostNotFound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ class ZombieClusterMonitor(config: ZombieClusterConfig, gdDAO: GoogleDataprocDAO

private def isProjectActiveInGoogle(googleProject: GoogleProject): Future[Boolean] = {
// Check the project and its billing info
(googleProjectDAO.isProjectActive(googleProject.value) |@| googleProjectDAO.isBillingActive(googleProject.value))
.map(_ && _)
(googleProjectDAO.isProjectActive(googleProject.value), googleProjectDAO.isBillingActive(googleProject.value))
.mapN(_ && _)
.recover { case e =>
logger.warn(s"Unable to check status of project ${googleProject.value} for zombie cluster detection", e)
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,15 +728,15 @@ class LeonardoService(protected val dataprocConfig: DataprocConfig,
// Validate the user script URI
_ <- clusterRequest.jupyterUserScriptUri match {
case Some(userScriptUri) => OptionT.liftF[Future, Unit](validateBucketObjectUri(userEmail, petToken, userScriptUri.toUri))
case None => OptionT.pure[Future, Unit](())
case None => OptionT.pure[Future](())
}

// Validate the extension URIs
_ <- clusterRequest.userJupyterExtensionConfig match {
case Some(config) =>
val extensionsToValidate = (config.nbExtensions.values ++ config.serverExtensions.values ++ config.combinedExtensions.values).filter(_.startsWith("gs://"))
OptionT.liftF(Future.traverse(extensionsToValidate)(x => validateBucketObjectUri(userEmail, petToken, x)))
case None => OptionT.pure[Future, Unit](())
case None => OptionT.pure[Future](())
}
} yield ()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ object ClusterEnrichments {
fcs1 == fcs2
}

def stripFieldsForListCluster(cluster: Cluster): Cluster = {
def stripFieldsForListCluster: Cluster => Cluster = { cluster =>
cluster.copy(
instances = Set.empty,
clusterImages = Set.empty,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,16 @@ class ClusterComponentSpec extends TestComponent with FlatSpecLike with CommonTe
dbFutureValue { _.clusterQuery.listByLabels(Map("a" -> "b"), true, Some(project)) }.toSet shouldEqual Set(savedCluster3).map(stripFieldsForListCluster)
dbFutureValue { _.clusterQuery.listByLabels(Map("a" -> "b"), true, Some(project2)) }.toSet shouldEqual Set.empty[Cluster]
}

it should "get for dns cache" in isolatedDbTest {
val savedCluster1 = makeCluster(1)
.copy(
labels = Map("bam" -> "yes", "vcf" -> "no", "foo" -> "bar"),
instances = Set(masterInstance, workerInstance1, workerInstance2))
.save(Some(serviceAccountKey.id))

// Result should not include labels or instances
dbFutureValue { _.clusterQuery.getActiveClusterForDnsCache(savedCluster1.googleProject, savedCluster1.clusterName) } shouldEqual
Some(savedCluster1).map(stripFieldsForListCluster andThen (_.copy(labels = Map.empty)))
}
}

0 comments on commit 0216110

Please sign in to comment.