Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Improve ONNXtests reliability #1713

Merged
merged 4 commits into from
Nov 14, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package com.microsoft.azure.synapse.ml.onnx
import com.microsoft.azure.synapse.ml.core.env.FileUtilities
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.utils.FaultToleranceUtils
import com.microsoft.azure.synapse.ml.onnx.ONNXHub.DefaultCacheDir
import org.apache.commons.codec.digest.DigestUtils
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileSystem, Path}
Expand Down Expand Up @@ -71,28 +72,44 @@ object ONNXHubJsonProtocol extends DefaultJsonProtocol {
object ONNXHub {
val DefaultRepo: String = "onnx/models:main"
val AuthenticatedRepo: (String, String, String) = ("onnx", "models", "main")
val DefaultConnectTimeout = 15000
val DefaultReadTimeout = 5000
val DefaultConnectTimeout = 30000
val DefaultReadTimeout = 30000
val DefaultRetryCount = 3
val DefaultRetryTimeoutInSeconds = 600
}

class ONNXHub(val modelCacheDir: Path,
val connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
val readTimeout: Int = ONNXHub.DefaultReadTimeout,
val retryCount: Int = ONNXHub.DefaultRetryCount,
val retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) extends Logging {
def this() = {
this(sys.env.get("ONNX_HOME")
lazy val DefaultCacheDir: Path = {
sys.env.get("ONNX_HOME")
.map(oh => new Path(oh, "hub"))
.orElse(sys.env.get("XDG_CACHE_HOME")
.map(xch => new Path(new Path(xch, "onnx"), "hub")))
.map(xch => new Path(new Path(xch, "onnx"), "hub")))
.getOrElse({
val home = new Path("placeholder")
.getFileSystem(SparkContext.getOrCreate().hadoopConfiguration)
.getHomeDirectory
FileUtilities.join(home, ".cache", "onnx", "hub")
}))
})
}
}

class ONNXHub(val modelCacheDir: Path,
val connectTimeout: Int,
val readTimeout: Int,
val retryCount: Int,
val retryTimeoutInSeconds: Int) extends Logging {
def this(connectTimeout: Int = ONNXHub.DefaultConnectTimeout,
readTimeout: Int = ONNXHub.DefaultReadTimeout,
retryCount: Int = ONNXHub.DefaultRetryCount,
retryTimeoutInSeconds: Int = ONNXHub.DefaultRetryTimeoutInSeconds) = {
this(DefaultCacheDir, connectTimeout, readTimeout, retryCount, retryTimeoutInSeconds)
}

def this(modelCacheDir: Path) = {
this(
modelCacheDir,
ONNXHub.DefaultConnectTimeout,
ONNXHub.DefaultReadTimeout,
ONNXHub.DefaultRetryCount,
ONNXHub.DefaultRetryTimeoutInSeconds)
}
svotaw marked this conversation as resolved.
Show resolved Hide resolved

def getDir: Path = modelCacheDir
Expand Down