Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix Guava version issue in Azure Synapse and Databricks #1103

Merged
merged 12 commits into from
Jul 8, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql._
import org.apache.spark.sql.types.StructType

import java.lang.reflect.Method
import scala.collection.mutable.ListBuffer
import scala.concurrent.{Awaitable, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.reflect.internal.util.ScalaClassLoader
import scala.util.control.NonFatal

/** Tunes model hyperparameters
Expand Down Expand Up @@ -95,7 +97,21 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp
private def getExecutionContext: ExecutionContext = {
getParallelism match {
case 1 =>
ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor())
val classPath = "com.google.common.util.concurrent.MoreExecutors"
val funcNameOld = "sameThreadExecutor"
val funcNameNew = "newDirectExecutorService"
val c = ScalaClassLoader(getClass.getClassLoader).tryToLoadClass(classPath)
val method: Method = {
try {
c.get.getMethod(funcNameNew)
}
catch {
case _: NoSuchMethodError => c.get.getMethod(funcNameOld)
case _: NoSuchMethodException => c.get.getMethod(funcNameOld)
}
}
val executorService = method.invoke(c.get).asInstanceOf[ExecutorService]
ExecutionContext.fromExecutorService(executorService)
case _ =>
val keepAliveSeconds = 60L
val prefix = s"${this.getClass.getSimpleName}-thread-pool"
Expand Down