Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
Add constraint support for Zone and Region (#5574)
Browse files Browse the repository at this point in the history
* Add constraint support for Zone and Region

* Zone / Region info added to agentInfo
* Consolidate matcher code

JIRA Issues: MARATHON_EE-1665

* ov => offerValue

* use @ instead of * for field prefix and other things

* introduce meetConstraint matcher
* more consistent variable naming
* selectInstancesToKill recognizes native region and zone fields

* convert one more test case to use matcher
  • Loading branch information
timcharper authored Oct 9, 2017
1 parent 50449fb commit 3fc89ec
Show file tree
Hide file tree
Showing 24 changed files with 370 additions and 430 deletions.
29 changes: 26 additions & 3 deletions src/main/scala/mesosphere/marathon/core/instance/Instance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import mesosphere.marathon.core.condition.Condition
import mesosphere.marathon.core.instance.Instance.{ AgentInfo, InstanceState }
import mesosphere.marathon.core.task.Task
import mesosphere.marathon.state.{ MarathonState, PathId, Timestamp, UnreachableStrategy, UnreachableDisabled, UnreachableEnabled }
import mesosphere.marathon.tasks.OfferUtil
import mesosphere.marathon.stream.Implicits._
import mesosphere.mesos.Placed
import mesosphere.marathon.raml.Raml
Expand Down Expand Up @@ -69,6 +70,10 @@ case class Instance(
override def hostname: String = agentInfo.host

override def attributes: Seq[Attribute] = agentInfo.attributes

override def zone: Option[String] = agentInfo.zone

override def region: Option[String] = agentInfo.region
}

@SuppressWarnings(Array("DuplicateImport"))
Expand Down Expand Up @@ -267,12 +272,16 @@ object Instance {
case class AgentInfo(
host: String,
agentId: Option[String],
attributes: Seq[mesos.Protos.Attribute])
region: Option[String],
zone: Option[String],
attributes: Seq[Attribute])

object AgentInfo {
def apply(offer: org.apache.mesos.Protos.Offer): AgentInfo = AgentInfo(
host = offer.getHostname,
agentId = Some(offer.getSlaveId.getValue),
region = OfferUtil.region(offer),
zone = OfferUtil.zone(offer),
attributes = offer.getAttributesList.toIndexedSeq
)
}
Expand All @@ -290,7 +299,7 @@ object Instance {
throw new IllegalStateException(s"No task in ${instance.instanceId}"))
}

implicit object AttributeFormat extends Format[mesos.Protos.Attribute] {
implicit object AttributeFormat extends Format[Attribute] {
override def reads(json: JsValue): JsResult[Attribute] = {
json.validate[String].map { base64 =>
mesos.Protos.Attribute.parseFrom(Base64.getDecoder.decode(base64))
Expand All @@ -312,7 +321,21 @@ object Instance {
}
}

implicit val agentFormat: Format[AgentInfo] = Json.format[AgentInfo]
// host: String,
// agentId: Option[String],
// region: String,
// zone: String,
// attributes: Seq[mesos.Protos.Attribute])
// private val agentFormatWrites: Writes[AgentInfo] = Json.format[AgentInfo]
private val agentReads: Reads[AgentInfo] = (
(__ \ "host").read[String] ~
(__ \ "agentId").readNullable[String] ~
(__ \ "region").readNullable[String] ~
(__ \ "zone").readNullable[String] ~
(__ \ "attributes").read[Seq[mesos.Protos.Attribute]]
)(AgentInfo(_, _, _, _, _))

implicit val agentFormat: Format[AgentInfo] = Format(agentReads, Json.writes[AgentInfo])
implicit val idFormat: Format[Instance.Id] = Json.format[Instance.Id]
implicit val instanceConditionFormat: Format[Condition] = Condition.conditionFormat
implicit val instanceStateFormat: Format[InstanceState] = Json.format[InstanceState]
Expand Down
23 changes: 23 additions & 0 deletions src/main/scala/mesosphere/marathon/tasks/OfferUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package mesosphere.marathon
package tasks

import org.apache.mesos.Protos.Offer

object OfferUtil {
def region(offer: Offer): Option[String] = if (offer.hasDomain) {
val domain = offer.getDomain
if (domain.hasFaultDomain) {
// region and name are marked as required in the protobuf definition
Some(domain.getFaultDomain.getRegion.getName)
} else None
} else None

def zone(offer: Offer): Option[String] =
if (offer.hasDomain) {
val domain = offer.getDomain
if (domain.hasFaultDomain) {
// zone and name are marked as required in the protobuf definition
Some(domain.getFaultDomain.getZone.getName)
} else None
} else None
}
150 changes: 68 additions & 82 deletions src/main/scala/mesosphere/mesos/Constraints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import mesosphere.marathon.Protos.Constraint.Operator
import mesosphere.marathon.core.instance.Instance
import mesosphere.marathon.state.RunSpec
import mesosphere.marathon.stream.Implicits._
import mesosphere.marathon.tasks.OfferUtil
import org.apache.mesos.Protos.{ Attribute, Offer, Value }
import org.slf4j.LoggerFactory

Expand All @@ -18,6 +19,8 @@ object Int {
trait Placed {
def attributes: Seq[Attribute]
def hostname: String
def region: Option[String]
def zone: Option[String]
}

object Constraints {
Expand Down Expand Up @@ -47,25 +50,36 @@ object Constraints {
s"{$s}"
}

type FieldReader = (Offer => Option[String], Placed => Option[String])
private val hostnameReader: FieldReader = (offer => Some(offer.getHostname), placed => Some(placed.hostname))
private val regionReader: FieldReader = (OfferUtil.region(_), _.region)
private val zoneReader: FieldReader = (OfferUtil.zone(_), _.zone)
private def attributeReader(field: String): FieldReader = (
{ offer => offer.getAttributesList.find(_.getName == field).map(getValueString) },
{ p => p.attributes.find(_.getName == field).map(getValueString) })

val hostnameField = "@hostname"
val regionField = "@region"
val zoneField = "@field"
def readerForField(field: String): FieldReader =
field match {
case "hostname" | `hostnameField` => hostnameReader
case `regionField` => regionReader
case `zoneField` => zoneReader
case _ => attributeReader(field)
}

private final class ConstraintsChecker(allPlaced: Seq[Placed], offer: Offer, constraint: Constraint) {
val field = constraint.getField
val value = constraint.getValue
lazy val attr = offer.getAttributesList.find(_.getName == field)

def isMatch: Boolean =
if (field == "hostname") {
checkHostName
} else if (attr.nonEmpty) {
checkAttribute
} else {
// This will be reached in case we want to schedule for an attribute
// that's not supplied.
checkMissingAttribute
}
val constraintValue = constraint.getValue

private def checkGroupBy(constraintValue: String, groupFunc: (Placed) => Option[String]) = {
def isMatch: Boolean = {
val (offerReader, placedReader) = readerForField(constraint.getField)
checkConstraint(offerReader(offer), placedReader)
}

private def checkGroupBy(offerValue: String, groupFunc: (Placed) => Option[String]) = {
// Minimum group count
val minimum = List(GroupByDefault, getIntValue(value, GroupByDefault)).max
val minimum = List(GroupByDefault, getIntValue(constraintValue, GroupByDefault)).max
// Group tasks by the constraint value, and calculate the task count of each group
val groupedTasks = allPlaced.groupBy(groupFunc).map { case (k, v) => k -> v.size }
// Task count of the smallest group
Expand All @@ -75,88 +89,63 @@ object Constraints {
// a) this offer matches the smallest grouping when there
// are >= minimum groupings
// b) the constraint value from the offer is not yet in the grouping
groupedTasks.find(_._1.contains(constraintValue))
groupedTasks.find(_._1.contains(offerValue))
.forall(pair => groupedTasks.size >= minimum && pair._2 == minCount)
}

private def checkMaxPer(constraintValue: String, maxCount: Int, groupFunc: (Placed) => Option[String]): Boolean = {
private def checkMaxPer(offerValue: String, maxCount: Int, groupFunc: (Placed) => Option[String]): Boolean = {
// Group tasks by the constraint value, and calculate the task count of each group
val groupedTasks = allPlaced.groupBy(groupFunc).map { case (k, v) => k -> v.size }

groupedTasks.find(_._1.contains(constraintValue)).forall(_._2 < maxCount)
groupedTasks.find(_._1.contains(offerValue)).forall(_._2 < maxCount)
}

private def checkHostName =
constraint.getOperator match {
case Operator.LIKE => offer.getHostname.matches(value)
case Operator.UNLIKE => !offer.getHostname.matches(value)
// All running tasks must have a hostname that is different from the one in the offer
case Operator.UNIQUE => allPlaced.forall(_.hostname != offer.getHostname)
case Operator.GROUP_BY => checkGroupBy(offer.getHostname, (p: Placed) => Some(p.hostname))
case Operator.MAX_PER => checkMaxPer(offer.getHostname, value.toInt, (p: Placed) => Some(p.hostname))
case Operator.CLUSTER =>
// Hostname must match or be empty
(value.isEmpty || value == offer.getHostname) &&
// All running tasks must have the same hostname as the one in the offer
allPlaced.forall(_.hostname == offer.getHostname)
case _ => false
}
private def checkCluster(offerValue: String, placedValue: Placed => Option[String]) =
if (constraintValue.isEmpty)
// If no placements are made, then accept (and make this offerValue) the value on which all future tasks are
// placed
allPlaced.headOption.fold(true) { p => placedValue(p) contains offerValue }
else
// Is constraint
(offerValue == constraintValue)

// All running tasks must have a value that is different from the one in the offer
private def checkUnique(offerValue: Option[String], placedValue: Placed => Option[String]) = {
allPlaced.forall { p => placedValue(p) != offerValue }
}

@SuppressWarnings(Array("OptionGet"))
private def checkAttribute: Boolean = {
def matches: Seq[Placed] = matchTaskAttributes(allPlaced, field, getValueString(attr.get))
def groupFunc = (p: Placed) => p.attributes
.find(_.getName == field)
.map(getValueString)
constraint.getOperator match {
case Operator.UNIQUE => matches.isEmpty
case Operator.CLUSTER =>
// If no value is set, accept the first one. Otherwise check for it.
(value.isEmpty || getValueString(attr.get) == value) &&
// All running tasks should have the matching attribute
matches.size == allPlaced.size
case Operator.GROUP_BY =>
checkGroupBy(getValueString(attr.get), groupFunc)
case Operator.MAX_PER =>
checkMaxPer(getValueString(attr.get), value.toInt, groupFunc)
case Operator.LIKE => checkLike
case Operator.UNLIKE => checkUnlike
def checkConstraint(maybeOfferValue: Option[String], placedValue: Placed => Option[String]) = {
maybeOfferValue match {
case Some(offerValue) =>
constraint.getOperator match {
case Operator.LIKE => checkLike(offerValue)
case Operator.UNLIKE => checkUnlike(offerValue)
case Operator.UNIQUE => checkUnique(maybeOfferValue, placedValue)
case Operator.GROUP_BY => checkGroupBy(offerValue, placedValue)
case Operator.MAX_PER => checkMaxPer(offerValue, constraintValue.toInt, placedValue)
case Operator.CLUSTER => checkCluster(offerValue, placedValue)
}
case None =>
// Only unlike can be matched if this offer does not have the specified value
constraint.getOperator == Operator.UNLIKE
}
}

@SuppressWarnings(Array("OptionGet"))
private def checkLike: Boolean = {
if (value.nonEmpty) {
getValueString(attr.get).matches(value)
private def checkLike(offerValue: String): Boolean =
if (constraintValue.nonEmpty) {
offerValue.matches(constraintValue)
} else {
log.warn("Error, value is required for LIKE operation")
false
}
}

@SuppressWarnings(Array("OptionGet"))
private def checkUnlike: Boolean = {
if (value.nonEmpty) {
!getValueString(attr.get).matches(value)
private def checkUnlike(offerValue: String): Boolean =
if (constraintValue.nonEmpty) {
!offerValue.matches(constraintValue)
} else {
log.warn("Error, value is required for UNLIKE operation")
false
}
}

private def checkMissingAttribute = constraint.getOperator == Operator.UNLIKE

/**
* Filters running tasks by matching their attributes to this field & value.
*/
private def matchTaskAttributes(allPlaced: Seq[Placed], field: String, value: String) =
allPlaced.filter {
_.attributes
.exists { y =>
y.getName == field &&
getValueString(y) == value
}
}
}

def meetsConstraint(allPlaced: Seq[Placed], offer: Offer, constraint: Constraint): Boolean =
Expand All @@ -181,12 +170,9 @@ object Constraints {

//currently, only the GROUP_BY operator is able to select instances to kill
val distributions = runSpec.constraints.withFilter(_.getOperator == Operator.GROUP_BY).map { constraint =>
def groupFn(instance: Instance): Option[String] = constraint.getField match {
case "hostname" => Some(instance.agentInfo.host)
case field: String => instance.agentInfo.attributes.find(_.getName == field).map(getValueString)
}
val (_, placed) = readerForField(constraint.getField)
val instanceGroups: Seq[Map[Instance.Id, Instance]] =
runningInstances.groupBy(groupFn).values.map(Instance.instancesById)(collection.breakOut)
runningInstances.groupBy(placed).values.map(Instance.instancesById)(collection.breakOut)
GroupByDistribution(constraint, instanceGroups)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class EndpointsHelperTest extends UnitTest {
case (numTasks, agentIndex) =>
val agentId = agentIndex + 1
val hostname = s"agent$agentId"
val agent = AgentInfo(hostname, None, Nil)
val agent = AgentInfo(hostname, None, None, None, Nil)

1.to(numTasks).map { taskIndex =>
val instanceId = Instance.Id.forRunSpec(app.id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ class PodsResourceTest extends AkkaUnitTest with Mockito {
implicit val killer = mock[TaskKiller]
val f = Fixture()
val instance = Instance(
Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, Nil),
Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, None, None, Nil),
InstanceState(Condition.Running, Timestamp.now(), Some(Timestamp.now()), None),
Map.empty,
runSpecVersion = Timestamp.now(),
Expand All @@ -684,12 +684,12 @@ class PodsResourceTest extends AkkaUnitTest with Mockito {
"attempting to kill multiple instances" in {
implicit val killer = mock[TaskKiller]
val instances = Seq(
Instance(Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, Nil),
Instance(Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, None, None, Nil),
InstanceState(Condition.Running, Timestamp.now(), Some(Timestamp.now()), None), Map.empty,
runSpecVersion = Timestamp.now(),
unreachableStrategy = UnreachableStrategy.default()
),
Instance(Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, Nil),
Instance(Instance.Id.forRunSpec("/id1".toRootPath), Instance.AgentInfo("", None, None, None, Nil),
InstanceState(Condition.Running, Timestamp.now(), Some(Timestamp.now()), None), Map.empty,
runSpecVersion = Timestamp.now(),
unreachableStrategy = UnreachableStrategy.default()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class EnrichedTaskWritesTest extends UnitTest {
val runSpecId = runSpec.id
val hostName = "agent1.mesos"
val agentId = "abcd-1234"
val agentInfo = Instance.AgentInfo(hostName, Some(agentId), attributes = Seq.empty)
val agentInfo = Instance.AgentInfo(hostName, Some(agentId), None, None, attributes = Seq.empty)

val networkInfos = Seq(
MesosProtos.NetworkInfo.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class TaskCountsTest extends UnitTest {
object Fixture {
implicit class TaskImplicits(val task: Task) extends AnyVal {
def toInstance: Instance = LegacyAppInstance(
task, AgentInfo(host = "host", agentId = Some("agent"), attributes = Nil),
task, AgentInfo(host = "host", agentId = Some("agent"), region = None, zone = None, attributes = Nil),
unreachableStrategy = UnreachableStrategy.default(resident = task.reservationWithVolumes.nonEmpty)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import org.joda.time.DateTime
class TaskLifeTimeTest extends UnitTest {
private[this] val now: Timestamp = Timestamp(new DateTime(2015, 4, 9, 12, 30))
private[this] val runSpecId = PathId("/test")
private[this] val agentInfo = AgentInfo(host = "host", agentId = Some("agent"), attributes = Nil)
private[this] val agentInfo = AgentInfo(host = "host", agentId = Some("agent"), region = None, zone = None, attributes = Nil)
private[this] def newTaskId(): Task.Id = {
Task.Id.forRunSpec(runSpecId)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class TaskStatsByVersionTest extends UnitTest {
}
private[this] def runningInstanceStartedAt(version: Timestamp, startingDelay: FiniteDuration): Instance = {
val startedAt = (version + startingDelay).millis
val agentInfo = AgentInfo(host = "host", agentId = Some("agent"), attributes = Nil)
val agentInfo = AgentInfo(host = "host", agentId = Some("agent"), region = None, zone = None, attributes = Nil)
LegacyAppInstance(
TestTaskBuilder.Helper.runningTask(newTaskId(), appVersion = version, startedAt = startedAt),
agentInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class AppInfoBaseDataTest extends UnitTest with GroupCreation {

Instance(
instanceId = instanceId,
agentInfo = Instance.AgentInfo("", None, Nil),
agentInfo = Instance.AgentInfo("", None, None, None, Nil),
state = InstanceState(None, tasks, f.clock.now(), UnreachableStrategy.default()),
tasksMap = tasks,
runSpecVersion = pod.version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class HealthCheckTest extends UnitTest {
val check = new MarathonTcpHealthCheck(portIndex = Some(PortReference(0)))
val app = MarathonTestHelper.makeBasicApp().withPortDefinitions(Seq(PortDefinition(0)))
val hostName = "hostName"
val agentInfo = AgentInfo(host = hostName, agentId = Some("agent"), attributes = Nil)
val agentInfo = AgentInfo(host = hostName, agentId = Some("agent"), region = None, zone = None, attributes = Nil)
val task = {
val t: Task.LaunchedEphemeral = TestTaskBuilder.Helper.runningTaskForApp(app.id)
val hostPorts = Seq(4321)
Expand Down
Loading

0 comments on commit 3fc89ec

Please sign in to comment.