Skip to content

Commit

Permalink
test: Improve ONNXtests reliability (#1713)
Browse files Browse the repository at this point in the history
* Improve ONNXtests reliability

* fix constructor build

* make new defaults
  • Loading branch information
svotaw authored Nov 14, 2022
1 parent fe4c5d2 commit 0ff6802
Showing 1 changed file with 29 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.onnx
import com.microsoft.azure.synapse.ml.core.env.FileUtilities
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.utils.FaultToleranceUtils
import com.microsoft.azure.synapse.ml.onnx.ONNXHub.DefaultCacheDir
import org.apache.commons.codec.digest.DigestUtils
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileSystem, Path}
Expand Down Expand Up @@ -71,28 +72,44 @@ object ONNXHubJsonProtocol extends DefaultJsonProtocol {
object ONNXHub {
val DefaultRepo: String = "onnx/models:main"
val AuthenticatedRepo: (String, String, String) = ("onnx", "models", "main")
val DefaultConnectTimeout = 15000
val DefaultReadTimeout = 5000
val DefaultConnectTimeout = 30000
val DefaultReadTimeout = 30000
val DefaultRetryCount = 3
val DefaultRetryTimeoutInSeconds = 600
}

class ONNXHub(val modelCacheDir: Path,
val connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
val readTimeout: Int = ONNXHub.DefaultReadTimeout,
val retryCount: Int = ONNXHub.DefaultRetryCount,
val retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) extends Logging {
def this() = {
this(sys.env.get("ONNX_HOME")
lazy val DefaultCacheDir: Path = {
sys.env.get("ONNX_HOME")
.map(oh => new Path(oh, "hub"))
.orElse(sys.env.get("XDG_CACHE_HOME")
.map(xch => new Path(new Path(xch, "onnx"), "hub")))
.map(xch => new Path(new Path(xch, "onnx"), "hub")))
.getOrElse({
val home = new Path("placeholder")
.getFileSystem(SparkContext.getOrCreate().hadoopConfiguration)
.getHomeDirectory
FileUtilities.join(home, ".cache", "onnx", "hub")
}))
})
}
}

class ONNXHub(val modelCacheDir: Path,
val connectTimeout: Int,
val readTimeout: Int,
val retryCount: Int,
val retryTimeoutInSeconds: Int) extends Logging {
def this(connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
readTimeout: Int = ONNXHub.DefaultReadTimeout,
retryCount: Int = ONNXHub.DefaultRetryCount,
retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) = {
this(DefaultCacheDir, connectTimeout, readTimeout, retryCount, retryTimeoutInSeconds)
}

def this(modelCacheDir: Path) = {
this(
modelCacheDir,
ONNXHub.DefaultConnectTimeout,
ONNXHub.DefaultReadTimeout,
ONNXHub.DefaultRetryCount,
ONNXHub.DefaultRetryTimeoutInSeconds)
}

def getDir: Path = modelCacheDir
Expand Down

0 comments on commit 0ff6802

Please sign in to comment.