diff --git a/core/src/main/scala/com/microsoft/ml/spark/automl/TuneHyperparameters.scala b/core/src/main/scala/com/microsoft/ml/spark/automl/TuneHyperparameters.scala index a0c4ffdd28..1d74928780 100644 --- a/core/src/main/scala/com/microsoft/ml/spark/automl/TuneHyperparameters.scala +++ b/core/src/main/scala/com/microsoft/ml/spark/automl/TuneHyperparameters.scala @@ -21,9 +21,11 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql._ import org.apache.spark.sql.types.StructType +import java.lang.reflect.Method import scala.collection.mutable.ListBuffer import scala.concurrent.{Awaitable, ExecutionContext, Future} import scala.concurrent.duration.Duration +import scala.reflect.internal.util.ScalaClassLoader import scala.util.control.NonFatal /** Tunes model hyperparameters @@ -95,7 +97,21 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp private def getExecutionContext: ExecutionContext = { getParallelism match { case 1 => - ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor()) + val classPath = "com.google.common.util.concurrent.MoreExecutors" + val funcNameOld = "sameThreadExecutor" + val funcNameNew = "newDirectExecutorService" + val c = ScalaClassLoader(getClass.getClassLoader).tryToLoadClass(classPath) + val method: Method = { + try { + c.get.getMethod(funcNameNew) + } + catch { + case _: NoSuchMethodError => c.get.getMethod(funcNameOld) + case _: NoSuchMethodException => c.get.getMethod(funcNameOld) + } + } + val executorService = method.invoke(c.get).asInstanceOf[ExecutorService] + ExecutionContext.fromExecutorService(executorService) case _ => val keepAliveSeconds = 60L val prefix = s"${this.getClass.getSimpleName}-thread-pool"