diff --git a/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala b/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala
index 1e474bca07..975de955c2 100644
--- a/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala
+++ b/deep-learning/src/main/scala/com/microsoft/azure/synapse/ml/onnx/ONNXHub.scala
@@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.onnx
 import com.microsoft.azure.synapse.ml.core.env.FileUtilities
 import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
 import com.microsoft.azure.synapse.ml.core.utils.FaultToleranceUtils
+import com.microsoft.azure.synapse.ml.onnx.ONNXHub.DefaultCacheDir
 import org.apache.commons.codec.digest.DigestUtils
 import org.apache.commons.io.IOUtils
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -71,28 +72,44 @@ object ONNXHubJsonProtocol extends DefaultJsonProtocol {
 object ONNXHub {
   val DefaultRepo: String = "onnx/models:main"
   val AuthenticatedRepo: (String, String, String) = ("onnx", "models", "main")
-  val DefaultConnectTimeout = 15000
-  val DefaultReadTimeout = 5000
+  val DefaultConnectTimeout = 30000
+  val DefaultReadTimeout = 30000
   val DefaultRetryCount = 3
   val DefaultRetryTimeoutInSeconds = 600
-}
 
-class ONNXHub(val modelCacheDir: Path,
-              val connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
-              val readTimeout: Int = ONNXHub.DefaultReadTimeout,
-              val retryCount: Int = ONNXHub.DefaultRetryCount,
-              val retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) extends Logging {
-  def this() = {
-    this(sys.env.get("ONNX_HOME")
+  lazy val DefaultCacheDir: Path = {
+    sys.env.get("ONNX_HOME")
       .map(oh => new Path(oh, "hub"))
       .orElse(sys.env.get("XDG_CACHE_HOME")
-        .map(xch => new Path(new Path(xch, "onnx"), "hub")))
+          .map(xch => new Path(new Path(xch, "onnx"), "hub")))
       .getOrElse({
         val home = new Path("placeholder")
           .getFileSystem(SparkContext.getOrCreate().hadoopConfiguration)
           .getHomeDirectory
         FileUtilities.join(home, ".cache", "onnx", "hub")
-      }))
+      })
+  }
+}
+
+class ONNXHub(val modelCacheDir: Path,
+              val connectTimeout: Int,
+              val readTimeout: Int,
+              val retryCount: Int,
+              val retryTimeoutInSeconds: Int) extends Logging {
+  def this(connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
+           readTimeout: Int = ONNXHub.DefaultReadTimeout,
+           retryCount: Int = ONNXHub.DefaultRetryCount,
+           retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) = {
+    this(DefaultCacheDir, connectTimeout, readTimeout, retryCount, retryTimeoutInSeconds)
+  }
+
+  def this(modelCacheDir: Path) = {
+    this(
+      modelCacheDir,
+      ONNXHub.DefaultConnectTimeout,
+      ONNXHub.DefaultReadTimeout,
+      ONNXHub.DefaultRetryCount,
+      ONNXHub.DefaultRetryTimeoutInSeconds)
   }
 
   def getDir: Path = modelCacheDir