diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index bdf7b99966e4..3b725cf29553 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -121,6 +121,43 @@ To use a custom metrics.properties for the application master and executors, upd
Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively.
+
+<tr>
+  <td><code>spark.yarn.am.resource.{resource-type}</code></td>
+  <td>(none)</td>
+  <td>
+    Amount of resource to use for the YARN Application Master in client mode.
+    In cluster mode, use <code>spark.yarn.driver.resource.&lt;resource-type&gt;</code> instead.
+    Please note that this feature can be used only with YARN 3.0+.
+    For reference, see the YARN Resource Model documentation:
+    https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html
+    <p/>
+    Example: To request GPU resources from YARN, use
+    <code>spark.yarn.am.resource.yarn.io/gpu</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.driver.resource.{resource-type}</code></td>
+  <td>(none)</td>
+  <td>
+    Amount of resource to use for the YARN Application Master in cluster mode.
+    Please note that this feature can be used only with YARN 3.0+.
+    For reference, see the YARN Resource Model documentation:
+    https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html
+    <p/>
+    Example: To request GPU resources from YARN, use
+    <code>spark.yarn.driver.resource.yarn.io/gpu</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.yarn.executor.resource.{resource-type}</code></td>
+  <td>(none)</td>
+  <td>
+    Amount of resource to use per executor process.
+    Please note that this feature can be used only with YARN 3.0+.
+    For reference, see the YARN Resource Model documentation:
+    https://hadoop.apache.org/docs/r3.0.1/hadoop-yarn/hadoop-yarn-site/ResourceModel.html
+    <p/>
+    Example: To request GPU resources from YARN, use
+    <code>spark.yarn.executor.resource.yarn.io/gpu</code>.
+  </td>
+</tr>
<tr>
  <td><code>spark.yarn.am.cores</code></td>
  <td>1</td>
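
A brief usage sketch (not part of the patch) of setting the new properties from application code. The resource name `yarn.io/gpu` comes from the examples above; the amounts are arbitrary and assume a YARN 3.0+ cluster with that resource type configured:

```scala
import org.apache.spark.SparkConf

// Hypothetical example: request 2 GPUs per executor and 1 GPU for the
// cluster-mode driver via the new YARN resource-type configs.
val conf = new SparkConf()
  .set("spark.yarn.executor.resource.yarn.io/gpu", "2")
  .set("spark.yarn.driver.resource.yarn.io/gpu", "1")
```
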
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 01bdebc000b9..67d2c8610e91 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -154,6 +154,8 @@ private[spark] class Client(
* available in the alpha API.
*/
def submitApplication(): ApplicationId = {
+ ResourceRequestHelper.validateResources(sparkConf)
+
var appId: ApplicationId = null
try {
launcherBackend.connect()
@@ -234,6 +236,13 @@ private[spark] class Client(
def createApplicationSubmissionContext(
newApp: YarnClientApplication,
containerContext: ContainerLaunchContext): ApplicationSubmissionContext = {
+ val amResources =
+ if (isClusterMode) {
+ sparkConf.getAllWithPrefix(config.YARN_DRIVER_RESOURCE_TYPES_PREFIX).toMap
+ } else {
+ sparkConf.getAllWithPrefix(config.YARN_AM_RESOURCE_TYPES_PREFIX).toMap
+ }
+ logDebug(s"AM resources: $amResources")
val appContext = newApp.getApplicationSubmissionContext
appContext.setApplicationName(sparkConf.get("spark.app.name", "Spark"))
appContext.setQueue(sparkConf.get(QUEUE_NAME))
@@ -256,6 +265,10 @@ private[spark] class Client(
val capability = Records.newRecord(classOf[Resource])
capability.setMemory(amMemory + amMemoryOverhead)
capability.setVirtualCores(amCores)
+ if (amResources.nonEmpty) {
+ ResourceRequestHelper.setResourceRequests(amResources, capability)
+ }
+ logDebug(s"Created resource capability for AM request: $capability")
sparkConf.get(AM_NODE_LABEL_EXPRESSION) match {
case Some(expr) =>
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala
new file mode 100644
index 000000000000..9534f3aaa243
--- /dev/null
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ResourceRequestHelper.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.lang.{Long => JLong}
+import java.lang.reflect.InvocationTargetException
+
+import scala.collection.mutable
+import scala.util.Try
+
+import org.apache.hadoop.yarn.api.records.Resource
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+/**
+ * This helper class uses some Hadoop 3 methods from the YARN API,
+ * so we need to use reflection to avoid compile errors when building against Hadoop 2.x.
+ */
+private object ResourceRequestHelper extends Logging {
+ private val AMOUNT_AND_UNIT_REGEX = "([0-9]+)([A-Za-z]*)".r
+ private val RESOURCE_INFO_CLASS = "org.apache.hadoop.yarn.api.records.ResourceInformation"
+
+ /**
+   * Validates sparkConf and throws a SparkException if any of the standard resources
+   * (memory or cores) is defined with the property spark.yarn.x.resource.y.
+   * We need to reject all combinations of AM / driver / executor and memory / CPU core
+   * resources, as Spark has its own names for them (memory, cores),
+   * but YARN has its own names too: (memory, memory-mb, mb) and (cores, vcores, cpu-vcores).
+   * We need to disable every possible way YARN could receive the resource definitions above.
+ */
+ def validateResources(sparkConf: SparkConf): Unit = {
+ val resourceDefinitions = Seq[(String, String)](
+ (AM_MEMORY.key, YARN_AM_RESOURCE_TYPES_PREFIX + "memory"),
+ (DRIVER_MEMORY.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "memory"),
+ (EXECUTOR_MEMORY.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "memory"),
+ (AM_MEMORY.key, YARN_AM_RESOURCE_TYPES_PREFIX + "mb"),
+ (DRIVER_MEMORY.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "mb"),
+ (EXECUTOR_MEMORY.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "mb"),
+ (AM_MEMORY.key, YARN_AM_RESOURCE_TYPES_PREFIX + "memory-mb"),
+ (DRIVER_MEMORY.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "memory-mb"),
+ (EXECUTOR_MEMORY.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "memory-mb"),
+ (AM_CORES.key, YARN_AM_RESOURCE_TYPES_PREFIX + "cores"),
+ (DRIVER_CORES.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "cores"),
+ (EXECUTOR_CORES.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "cores"),
+ (AM_CORES.key, YARN_AM_RESOURCE_TYPES_PREFIX + "vcores"),
+ (DRIVER_CORES.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "vcores"),
+ (EXECUTOR_CORES.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "vcores"),
+ (AM_CORES.key, YARN_AM_RESOURCE_TYPES_PREFIX + "cpu-vcores"),
+ (DRIVER_CORES.key, YARN_DRIVER_RESOURCE_TYPES_PREFIX + "cpu-vcores"),
+ (EXECUTOR_CORES.key, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "cpu-vcores"))
+ val errorMessage = new mutable.StringBuilder()
+
+ resourceDefinitions.foreach { case (sparkName, resourceRequest) =>
+ if (sparkConf.contains(resourceRequest)) {
+ errorMessage.append(s"Error: Do not use $resourceRequest, " +
+ s"please use $sparkName instead!\n")
+ }
+ }
+
+ if (errorMessage.nonEmpty) {
+ throw new SparkException(errorMessage.toString())
+ }
+ }
+
+ /**
+   * Sets the requested resource amounts, with their corresponding units,
+   * on the passed resource object.
+   * @param resources resource names mapped to the amounts (with optional units) to set
+   * @param resource resource object to update
+ */
+ def setResourceRequests(
+ resources: Map[String, String],
+ resource: Resource): Unit = {
+ require(resource != null, "Resource parameter should not be null!")
+
+ logDebug(s"Custom resources requested: $resources")
+ if (!isYarnResourceTypesAvailable()) {
+ if (resources.nonEmpty) {
+ logWarning("Ignoring custom resource requests because " +
+ "the version of YARN does not support it!")
+ }
+ return
+ }
+
+ val resInfoClass = Utils.classForName(RESOURCE_INFO_CLASS)
+ val setResourceInformationMethod =
+ resource.getClass.getMethod("setResourceInformation", classOf[String], resInfoClass)
+ resources.foreach { case (name, rawAmount) =>
+ try {
+ val AMOUNT_AND_UNIT_REGEX(amountPart, unitPart) = rawAmount
+ val amount = amountPart.toLong
+        // Normalize Spark's lower-case g/t/p suffixes to the upper-case units YARN
+        // expects; any other unit (including an empty one) is passed through unchanged.
+        val unit = unitPart match {
+ case "g" => "G"
+ case "t" => "T"
+ case "p" => "P"
+ case _ => unitPart
+ }
+ logDebug(s"Registering resource with name: $name, amount: $amount, unit: $unit")
+ val resourceInformation = createResourceInformation(name, amount, unit, resInfoClass)
+ setResourceInformationMethod.invoke(
+ resource, name, resourceInformation.asInstanceOf[AnyRef])
+ } catch {
+ case _: MatchError =>
+ throw new IllegalArgumentException(s"Resource request for '$name' ('$rawAmount') " +
+ s"does not match pattern $AMOUNT_AND_UNIT_REGEX.")
+ case e: InvocationTargetException if e.getCause != null => throw e.getCause
+ }
+ }
+ }
+
+ private def createResourceInformation(
+ resourceName: String,
+ amount: Long,
+ unit: String,
+ resInfoClass: Class[_]): Any = {
+ val resourceInformation =
+ if (unit.nonEmpty) {
+ val resInfoNewInstanceMethod = resInfoClass.getMethod("newInstance",
+ classOf[String], classOf[String], JLong.TYPE)
+ resInfoNewInstanceMethod.invoke(null, resourceName, unit, amount.asInstanceOf[JLong])
+ } else {
+ val resInfoNewInstanceMethod = resInfoClass.getMethod("newInstance",
+ classOf[String], JLong.TYPE)
+ resInfoNewInstanceMethod.invoke(null, resourceName, amount.asInstanceOf[JLong])
+ }
+ resourceInformation
+ }
+
+ /**
+   * Checks whether Hadoop 2.x or Hadoop 3 is used as a dependency:
+   * with Hadoop 3 and later, the ResourceInformation class
+   * should be available on the classpath.
+ */
+ def isYarnResourceTypesAvailable(): Boolean = {
+ Try(Utils.classForName(RESOURCE_INFO_CLASS)).isSuccess
+ }
+}
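
A standalone sketch of the amount/unit parsing performed above; it mirrors `AMOUNT_AND_UNIT_REGEX` and the unit normalization in `setResourceRequests` but uses no YARN classes, and the object name is illustrative only:

```scala
object AmountAndUnitSketch extends App {
  // Same pattern as AMOUNT_AND_UNIT_REGEX in ResourceRequestHelper
  private val AmountAndUnit = "([0-9]+)([A-Za-z]*)".r

  def parse(rawAmount: String): (Long, String) = rawAmount match {
    case AmountAndUnit(amount, unit) =>
      // Lower-case g/t/p are rewritten to the upper-case units YARN expects;
      // any other unit (including an empty one) is passed through unchanged.
      val normalized = unit match {
        case "g" => "G"
        case "t" => "T"
        case "p" => "P"
        case other => other
      }
      (amount.toLong, normalized)
    case _ =>
      throw new IllegalArgumentException(s"'$rawAmount' does not match the pattern")
  }

  assert(parse("123") == (123L, ""))  // plain amount, empty unit
  assert(parse("2g") == (2L, "G"))    // lower-case suffix normalized
  assert(parse("10G") == (10L, "G"))  // valid unit kept as-is
}
```
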
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 8a7551de7c08..ebdcf45603ce 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -140,10 +140,18 @@ private[yarn] class YarnAllocator(
}
// Number of cores per executor.
protected val executorCores = sparkConf.get(EXECUTOR_CORES)
- // Resource capability requested for each executors
- private[yarn] val resource = Resource.newInstance(
- executorMemory + memoryOverhead + pysparkWorkerMemory,
- executorCores)
+
+ private val executorResourceRequests =
+ sparkConf.getAllWithPrefix(config.YARN_EXECUTOR_RESOURCE_TYPES_PREFIX).toMap
+
+ // Resource capability requested for each executor
+ private[yarn] val resource: Resource = {
+ val resource = Resource.newInstance(
+ executorMemory + memoryOverhead + pysparkWorkerMemory, executorCores)
+ ResourceRequestHelper.setResourceRequests(executorResourceRequests, resource)
+ logDebug(s"Created resource capability: $resource")
+ resource
+ }
private val launcherPool = ThreadUtils.newDaemonCachedThreadPool(
"ContainerLauncher", sparkConf.get(CONTAINER_LAUNCH_MAX_THREADS))
@@ -288,9 +296,16 @@ private[yarn] class YarnAllocator(
s"executorsStarting: ${numExecutorsStarting.get}")
if (missing > 0) {
- logInfo(s"Will request $missing executor container(s), each with " +
- s"${resource.getVirtualCores} core(s) and " +
- s"${resource.getMemory} MB memory (including $memoryOverhead MB of overhead)")
+ if (log.isInfoEnabled()) {
+ var requestContainerMessage = s"Will request $missing executor container(s), each with " +
+ s"${resource.getVirtualCores} core(s) and " +
+ s"${resource.getMemory} MB memory (including $memoryOverhead MB of overhead)"
+ if (ResourceRequestHelper.isYarnResourceTypesAvailable() &&
+ executorResourceRequests.nonEmpty) {
+            requestContainerMessage ++= " with custom resources: " + resource.toString
+ }
+ logInfo(requestContainerMessage)
+ }
// Split the pending container request into three groups: locality matched list, locality
// unmatched list and non-locality list. Take the locality matched container request into
@@ -456,13 +471,20 @@ private[yarn] class YarnAllocator(
// memory, but use the asked vcore count for matching, effectively disabling matching on vcore
// count.
val matchingResource = Resource.newInstance(allocatedContainer.getResource.getMemory,
- resource.getVirtualCores)
+ resource.getVirtualCores)
+
+ ResourceRequestHelper.setResourceRequests(executorResourceRequests, matchingResource)
+
+ logDebug(s"Calling amClient.getMatchingRequests with parameters: " +
+ s"priority: ${allocatedContainer.getPriority}, " +
+ s"location: $location, resource: $matchingResource")
val matchingRequests = amClient.getMatchingRequests(allocatedContainer.getPriority, location,
matchingResource)
// Match the allocation to a request
if (!matchingRequests.isEmpty) {
val containerRequest = matchingRequests.get(0).iterator.next
+ logDebug(s"Removing container request via AM client: $containerRequest")
amClient.removeContainerRequest(containerRequest)
containersToUse += allocatedContainer
} else {
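
One note on the `matchingResource` change above: `getMatchingRequests` matches against the requested capability, so the capability rebuilt here has to carry the same custom resources as the original container request, otherwise the outstanding request would not be found. A minimal sketch of the pattern (arbitrary values; assumes code in the `org.apache.spark.deploy.yarn` package, since `ResourceRequestHelper` is package-private):

```scala
import org.apache.hadoop.yarn.api.records.Resource

// Rebuild the capability exactly as it was requested, custom resources included,
// before asking the AM client for matching outstanding requests.
val executorResourceRequests = Map("yarn.io/gpu" -> "2")
val matchingResource = Resource.newInstance(2048, 5) // memory (MB), vcores
ResourceRequestHelper.setResourceRequests(executorResourceRequests, matchingResource)
// amClient.getMatchingRequests(priority, location, matchingResource) can now
// find the request that carried the same custom resources.
```
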
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
index ab8273bd6321..f2ed555edc1d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -345,4 +345,8 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[yarn] val YARN_EXECUTOR_RESOURCE_TYPES_PREFIX = "spark.yarn.executor.resource."
+ private[yarn] val YARN_DRIVER_RESOURCE_TYPES_PREFIX = "spark.yarn.driver.resource."
+ private[yarn] val YARN_AM_RESOURCE_TYPES_PREFIX = "spark.yarn.am.resource."
+
}
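
For context, a sketch of how these prefixes are consumed (assuming `SparkConf.getAllWithPrefix` strips the prefix from each matching key, which is how `Client` and `YarnAllocator` use it above):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf(loadDefaults = false)
  .set("spark.yarn.executor.resource.yarn.io/gpu", "2")
  .set("spark.yarn.executor.resource.fpga", "1")

// The prefix is stripped, leaving resource-type -> amount pairs.
val executorResources: Map[String, String] =
  conf.getAllWithPrefix("spark.yarn.executor.resource.").toMap
// executorResources == Map("yarn.io/gpu" -> "2", "fpga" -> "1")
```
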
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 26013a109c42..533cb2b0f0bd 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -23,6 +23,7 @@ import java.util.Properties
import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap => MutableHashMap}
+import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -199,6 +200,20 @@ class ClientSuite extends SparkFunSuite with Matchers {
appContext.getMaxAppAttempts should be (42)
}
+ test("resource request (client mode)") {
+ val sparkConf = new SparkConf().set("spark.submit.deployMode", "client")
+ .set(YARN_AM_RESOURCE_TYPES_PREFIX + "fpga", "2")
+ .set(YARN_AM_RESOURCE_TYPES_PREFIX + "gpu", "3")
+ testResourceRequest(sparkConf, List("gpu", "fpga"), Seq(("fpga", 2), ("gpu", 3)))
+ }
+
+ test("resource request (cluster mode)") {
+ val sparkConf = new SparkConf().set("spark.submit.deployMode", "cluster")
+ .set(YARN_DRIVER_RESOURCE_TYPES_PREFIX + "fpga", "4")
+ .set(YARN_DRIVER_RESOURCE_TYPES_PREFIX + "gpu", "5")
+ testResourceRequest(sparkConf, List("gpu", "fpga"), Seq(("fpga", 4), ("gpu", 5)))
+ }
+
test("spark.yarn.jars with multiple paths and globs") {
val libs = Utils.createTempDir()
val single = Utils.createTempDir()
@@ -433,4 +448,30 @@ class ClientSuite extends SparkFunSuite with Matchers {
classpath(env)
}
+ private def testResourceRequest(
+ sparkConf: SparkConf,
+ resources: List[String],
+ expectedResources: Seq[(String, Long)]): Unit = {
+ assume(ResourceRequestHelper.isYarnResourceTypesAvailable())
+ ResourceRequestTestHelper.initializeResourceTypes(resources)
+
+ val args = new ClientArguments(Array())
+
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
+ val containerLaunchContext = Records.newRecord(classOf[ContainerLaunchContext])
+
+ val client = new Client(args, sparkConf)
+ client.createApplicationSubmissionContext(
+ new YarnClientApplication(getNewApplicationResponse, appContext),
+ containerLaunchContext)
+
+ appContext.getAMContainerSpec should be (containerLaunchContext)
+ appContext.getApplicationType should be ("SPARK")
+
+ expectedResources.foreach { case (name, value) =>
+ ResourceRequestTestHelper.getResourceTypeValue(appContext.getResource, name) should be (value)
+ }
+ }
+
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala
new file mode 100644
index 000000000000..60059987ba3f
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestHelperSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.hadoop.yarn.api.records.Resource
+import org.apache.hadoop.yarn.util.Records
+import org.scalatest.Matchers
+
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.yarn.ResourceRequestTestHelper.ResourceInformation
+import org.apache.spark.deploy.yarn.config._
+import org.apache.spark.internal.config.{DRIVER_MEMORY, EXECUTOR_MEMORY}
+
+class ResourceRequestHelperSuite extends SparkFunSuite with Matchers {
+
+ private val CUSTOM_RES_1 = "custom-resource-type-1"
+ private val CUSTOM_RES_2 = "custom-resource-type-2"
+ private val MEMORY = "memory"
+ private val CORES = "cores"
+ private val NEW_CONFIG_EXECUTOR_MEMORY = YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + MEMORY
+ private val NEW_CONFIG_EXECUTOR_CORES = YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + CORES
+ private val NEW_CONFIG_AM_MEMORY = YARN_AM_RESOURCE_TYPES_PREFIX + MEMORY
+ private val NEW_CONFIG_AM_CORES = YARN_AM_RESOURCE_TYPES_PREFIX + CORES
+ private val NEW_CONFIG_DRIVER_MEMORY = YARN_DRIVER_RESOURCE_TYPES_PREFIX + MEMORY
+ private val NEW_CONFIG_DRIVER_CORES = YARN_DRIVER_RESOURCE_TYPES_PREFIX + CORES
+
+ test("resource request value does not match pattern") {
+ verifySetResourceRequestsException(List(CUSTOM_RES_1),
+ Map(CUSTOM_RES_1 -> "**@#"), CUSTOM_RES_1)
+ }
+
+ test("resource request just unit defined") {
+ verifySetResourceRequestsException(List(), Map(CUSTOM_RES_1 -> "m"), CUSTOM_RES_1)
+ }
+
+ test("resource request with null value should not be allowed") {
+ verifySetResourceRequestsException(List(), null, Map(CUSTOM_RES_1 -> "123"),
+ "requirement failed: Resource parameter should not be null!")
+ }
+
+ test("resource request with valid value and invalid unit") {
+ verifySetResourceRequestsException(List(CUSTOM_RES_1), createResource,
+ Map(CUSTOM_RES_1 -> "123ppp"), "")
+ }
+
+ test("resource request with valid value and without unit") {
+ verifySetResourceRequestsSuccessful(List(CUSTOM_RES_1), Map(CUSTOM_RES_1 -> "123"),
+ Map(CUSTOM_RES_1 -> ResourceInformation(CUSTOM_RES_1, 123, "")))
+ }
+
+ test("resource request with valid value and unit") {
+ verifySetResourceRequestsSuccessful(List(CUSTOM_RES_1), Map(CUSTOM_RES_1 -> "2g"),
+ Map(CUSTOM_RES_1 -> ResourceInformation(CUSTOM_RES_1, 2, "G")))
+ }
+
+ test("two resource requests with valid values and units") {
+ verifySetResourceRequestsSuccessful(List(CUSTOM_RES_1, CUSTOM_RES_2),
+ Map(CUSTOM_RES_1 -> "123m", CUSTOM_RES_2 -> "10G"),
+ Map(CUSTOM_RES_1 -> ResourceInformation(CUSTOM_RES_1, 123, "m"),
+ CUSTOM_RES_2 -> ResourceInformation(CUSTOM_RES_2, 10, "G")))
+ }
+
+ test("empty SparkConf should be valid") {
+ val sparkConf = new SparkConf()
+ ResourceRequestHelper.validateResources(sparkConf)
+ }
+
+ test("just normal resources are defined") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(DRIVER_MEMORY.key, "3G")
+ sparkConf.set(DRIVER_CORES.key, "4")
+ sparkConf.set(EXECUTOR_MEMORY.key, "4G")
+ sparkConf.set(EXECUTOR_CORES.key, "2")
+ ResourceRequestHelper.validateResources(sparkConf)
+ }
+
+ test("memory defined with new config for executor") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(NEW_CONFIG_EXECUTOR_MEMORY, "30G")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_EXECUTOR_MEMORY)
+ }
+
+ test("memory defined with new config for executor 2") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "memory-mb", "30G")
+ verifyValidateResourcesException(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "memory-mb")
+ }
+
+ test("memory defined with new config for executor 3") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "mb", "30G")
+ verifyValidateResourcesException(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "mb")
+ }
+
+ test("cores defined with new config for executor") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(NEW_CONFIG_EXECUTOR_CORES, "5")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_EXECUTOR_CORES)
+ }
+
+ test("cores defined with new config for executor 2") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "vcores", "5")
+ verifyValidateResourcesException(sparkConf, YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "vcores")
+ }
+
+ test("memory defined with new config, client mode") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(NEW_CONFIG_AM_MEMORY, "1G")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_AM_MEMORY)
+ }
+
+ test("memory defined with new config for driver, cluster mode") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(NEW_CONFIG_DRIVER_MEMORY, "1G")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_DRIVER_MEMORY)
+ }
+
+ test("cores defined with new config, client mode") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(NEW_CONFIG_AM_CORES, "3")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_AM_CORES)
+ }
+
+ test("cores defined with new config for driver, cluster mode") {
+ val sparkConf = new SparkConf()
+    sparkConf.set(NEW_CONFIG_DRIVER_CORES, "3")
+ verifyValidateResourcesException(sparkConf, NEW_CONFIG_DRIVER_CORES)
+ }
+
+ test("various duplicated definitions") {
+ val sparkConf = new SparkConf()
+ sparkConf.set(DRIVER_MEMORY.key, "2G")
+ sparkConf.set(DRIVER_CORES.key, "2")
+ sparkConf.set(EXECUTOR_MEMORY.key, "2G")
+ sparkConf.set(EXECUTOR_CORES.key, "4")
+ sparkConf.set(AM_MEMORY.key, "3G")
+ sparkConf.set(NEW_CONFIG_EXECUTOR_MEMORY, "3G")
+ sparkConf.set(NEW_CONFIG_AM_MEMORY, "2G")
+ sparkConf.set(NEW_CONFIG_DRIVER_MEMORY, "2G")
+
+ val thrown = intercept[SparkException] {
+ ResourceRequestHelper.validateResources(sparkConf)
+ }
+ thrown.getMessage should (
+ include(NEW_CONFIG_EXECUTOR_MEMORY) and
+ include(NEW_CONFIG_AM_MEMORY) and
+ include(NEW_CONFIG_DRIVER_MEMORY))
+ }
+
+ private def verifySetResourceRequestsSuccessful(
+ definedResourceTypes: List[String],
+ resourceRequests: Map[String, String],
+ expectedResources: Map[String, ResourceInformation]): Unit = {
+ assume(ResourceRequestHelper.isYarnResourceTypesAvailable())
+ ResourceRequestTestHelper.initializeResourceTypes(definedResourceTypes)
+
+ val resource = createResource()
+ ResourceRequestHelper.setResourceRequests(resourceRequests, resource)
+
+ expectedResources.foreach { case (name, ri) =>
+ val resourceInfo = ResourceRequestTestHelper.getResourceInformationByName(resource, name)
+ assert(resourceInfo === ri)
+ }
+ }
+
+ private def verifySetResourceRequestsException(
+ definedResourceTypes: List[String],
+ resourceRequests: Map[String, String],
+ message: String): Unit = {
+ val resource = createResource()
+ verifySetResourceRequestsException(definedResourceTypes, resource, resourceRequests, message)
+ }
+
+ private def verifySetResourceRequestsException(
+ definedResourceTypes: List[String],
+ resource: Resource,
+ resourceRequests: Map[String, String],
+ message: String) = {
+ assume(ResourceRequestHelper.isYarnResourceTypesAvailable())
+ ResourceRequestTestHelper.initializeResourceTypes(definedResourceTypes)
+ val thrown = intercept[IllegalArgumentException] {
+ ResourceRequestHelper.setResourceRequests(resourceRequests, resource)
+ }
+ if (!message.isEmpty) {
+ thrown.getMessage should include (message)
+ }
+ }
+
+ private def verifyValidateResourcesException(sparkConf: SparkConf, message: String) = {
+ val thrown = intercept[SparkException] {
+ ResourceRequestHelper.validateResources(sparkConf)
+ }
+ thrown.getMessage should include (message)
+ }
+
+ private def createResource(): Resource = {
+ val resource = Records.newRecord(classOf[Resource])
+ resource.setMemory(512)
+ resource.setVirtualCores(2)
+ resource
+ }
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestTestHelper.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestTestHelper.scala
new file mode 100644
index 000000000000..c46f3c5faff9
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ResourceRequestTestHelper.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.yarn.api.records.Resource
+
+import org.apache.spark.util.Utils
+
+object ResourceRequestTestHelper {
+ def initializeResourceTypes(resourceTypes: List[String]): Unit = {
+ if (!ResourceRequestHelper.isYarnResourceTypesAvailable()) {
+ throw new IllegalStateException("This method should not be invoked " +
+ "since YARN resource types is not available because of old Hadoop version!" )
+ }
+
+ val allResourceTypes = new ListBuffer[AnyRef]
+    // ResourceUtils.reinitializeResources() is the YARN way
+    // to specify resources for the execution of the tests.
+    // This method must receive the standard resources under the names memory-mb and vcores.
+    // If the standard resources are omitted, or specified under different names
+    // (e.g. memory), YARN throws various exceptions,
+    // because it relies on the standard resources always being specified.
+ val defaultResourceTypes = List(
+ createResourceTypeInfo("memory-mb"),
+ createResourceTypeInfo("vcores"))
+ val customResourceTypes = resourceTypes.map(createResourceTypeInfo)
+ allResourceTypes ++= defaultResourceTypes
+ allResourceTypes ++= customResourceTypes
+
+ val resourceUtilsClass =
+ Utils.classForName("org.apache.hadoop.yarn.util.resource.ResourceUtils")
+ val reinitializeResourcesMethod = resourceUtilsClass.getMethod("reinitializeResources",
+ classOf[java.util.List[AnyRef]])
+ reinitializeResourcesMethod.invoke(null, allResourceTypes.asJava)
+ }
+
+ private def createResourceTypeInfo(resourceName: String): AnyRef = {
+ val resTypeInfoClass = Utils.classForName("org.apache.hadoop.yarn.api.records.ResourceTypeInfo")
+ val resTypeInfoNewInstanceMethod = resTypeInfoClass.getMethod("newInstance", classOf[String])
+ resTypeInfoNewInstanceMethod.invoke(null, resourceName)
+ }
+
+ def getResourceTypeValue(res: Resource, name: String): AnyRef = {
+ val resourceInformation = getResourceInformation(res, name)
+ invokeMethod(resourceInformation, "getValue")
+ }
+
+ def getResourceInformationByName(res: Resource, nameParam: String): ResourceInformation = {
+ val resourceInformation: AnyRef = getResourceInformation(res, nameParam)
+ val name = invokeMethod(resourceInformation, "getName").asInstanceOf[String]
+ val value = invokeMethod(resourceInformation, "getValue").asInstanceOf[Long]
+ val units = invokeMethod(resourceInformation, "getUnits").asInstanceOf[String]
+ ResourceInformation(name, value, units)
+ }
+
+ private def getResourceInformation(res: Resource, name: String): AnyRef = {
+ if (!ResourceRequestHelper.isYarnResourceTypesAvailable()) {
+ throw new IllegalStateException("assertResourceTypeValue() should not be invoked " +
+ "since yarn resource types is not available because of old Hadoop version!")
+ }
+
+ val getResourceInformationMethod = res.getClass.getMethod("getResourceInformation",
+ classOf[String])
+ val resourceInformation = getResourceInformationMethod.invoke(res, name)
+ resourceInformation
+ }
+
+ private def invokeMethod(resourceInformation: AnyRef, methodName: String): AnyRef = {
+    val method = resourceInformation.getClass.getMethod(methodName)
+    method.invoke(resourceInformation)
+ }
+
+ case class ResourceInformation(name: String, value: Long, units: String)
+}
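
A sketch of how this helper composes with `ResourceRequestHelper` in a test, mirroring what `ClientSuite` and `YarnAllocatorSuite` do above (assumes a Hadoop 3 classpath and code in the same package):

```scala
import org.apache.hadoop.yarn.api.records.Resource
import org.apache.hadoop.yarn.util.Records

// Register a custom resource type with YARN, request some of it, read it back.
ResourceRequestTestHelper.initializeResourceTypes(List("gpu"))

val resource = Records.newRecord(classOf[Resource])
ResourceRequestHelper.setResourceRequests(Map("gpu" -> "3"), resource)

val info = ResourceRequestTestHelper.getResourceInformationByName(resource, "gpu")
assert(info == ResourceRequestTestHelper.ResourceInformation("gpu", 3, ""))
```
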
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 3f783baed110..35299166d981 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.mockito.ArgumentCaptor
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfterEach, Matchers}
@@ -86,7 +87,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
def createAllocator(
maxExecutors: Int = 5,
- rmClient: AMRMClient[ContainerRequest] = rmClient): YarnAllocator = {
+ rmClient: AMRMClient[ContainerRequest] = rmClient,
+ additionalConfigs: Map[String, String] = Map()): YarnAllocator = {
val args = Array(
"--jar", "somejar.jar",
"--class", "SomeClass")
@@ -95,6 +97,11 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
.set("spark.executor.instances", maxExecutors.toString)
.set("spark.executor.cores", "5")
.set("spark.executor.memory", "2048")
+
+ for ((name, value) <- additionalConfigs) {
+ sparkConfClone.set(name, value)
+ }
+
new YarnAllocator(
"not used",
mock(classOf[RpcEndpointRef]),
@@ -108,12 +115,12 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
clock)
}
- def createContainer(host: String): Container = {
+ def createContainer(host: String, resource: Resource = containerResource): Container = {
// When YARN 2.6+ is required, avoid deprecation by using version with long second arg
val containerId = ContainerId.newInstance(appAttemptId, containerNum)
containerNum += 1
val nodeId = NodeId.newInstance(host, 1000)
- Container.newInstance(containerId, nodeId, "", containerResource, RM_REQUEST_PRIORITY, null)
+ Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null)
}
test("single container allocated") {
@@ -134,6 +141,29 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
size should be (0)
}
+ test("custom resource requested from yarn") {
+ assume(ResourceRequestHelper.isYarnResourceTypesAvailable())
+ ResourceRequestTestHelper.initializeResourceTypes(List("gpu"))
+
+ val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]])
+ val handler = createAllocator(1, mockAmClient,
+ Map(YARN_EXECUTOR_RESOURCE_TYPES_PREFIX + "gpu" -> "2G"))
+
+ handler.updateResourceRequests()
+ val container = createContainer("host1", handler.resource)
+ handler.handleAllocatedContainers(Array(container))
+
+    // Take the memory and vcore amounts from the allocator's resource,
+    // effectively skipping their validation here
+ val expectedResources = Resource.newInstance(handler.resource.getMemory(),
+ handler.resource.getVirtualCores)
+ ResourceRequestHelper.setResourceRequests(Map("gpu" -> "2G"), expectedResources)
+ val captor = ArgumentCaptor.forClass(classOf[ContainerRequest])
+
+ verify(mockAmClient).addContainerRequest(captor.capture())
+ val containerRequest: ContainerRequest = captor.getValue
+ assert(containerRequest.getCapability === expectedResources)
+ }
+
test("container should not be created if requested number if met") {
// request a single container and receive it
val handler = createAllocator(1)