feat: add streaming API for MVAD #1893
Changes from all commits: b9c72c4, 42ad7e3, 7bb6245, d4f3cd8, f850905, 8cc205f, 9529be5, 3b38b63, 23aa74a, ea495ff, 685f0e6, 5244159
@@ -10,21 +10,25 @@ import com.microsoft.azure.synapse.ml.cognitive.anomaly.MADJsonProtocol._
import com.microsoft.azure.synapse.ml.cognitive.vision.HasAsyncReply
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCols, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.{convertAndClose, sendWithRetries}
import com.microsoft.azure.synapse.ml.io.http.RESTHelpers.{Client, retry}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.stages._
import com.microsoft.azure.synapse.ml.param.CognitiveServiceStructParam
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.http.client.methods._
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.injections.UDFUtils
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spray.json._
@@ -49,7 +53,7 @@ private[ml] object Conversions {
     scala.collection.Iterator[T] = RemoteIteratorWrapper[T](underlying)
 }

-object MADUtils {
+object MADUtils extends Logging {

   private[ml] val CreatedModels: mutable.ParHashSet[String] = new ParHashSet[String]()

@@ -154,6 +158,27 @@ object MADUtils {
     CreatedModels.clear()
   }

+  private[ml] def checkModelStatus(url: String, modelId: String, subscriptionKey: String): Unit = try {
+    val response = madGetModel(url, modelId, subscriptionKey)
+      .parseJson.asJsObject.fields
+
+    val modelInfo = response("modelInfo").asJsObject.fields
+    val modelStatus = modelInfo("status").asInstanceOf[JsString].value.toLowerCase
+    modelStatus match {
+      case "failed" =>
+        val errors = modelInfo("errors").convertTo[Seq[DMAError]].toJson.compactPrint
+        throw new RuntimeException(s"Caught errors during fitting: $errors")
+      case "created" | "running" =>
+        throw new RuntimeException(s"model $modelId is not ready yet")
+      case "ready" =>
+        logInfo("model is ready for inference")
+    }
+  } catch {
+    case e: RuntimeException =>
+      throw new RuntimeException(s"Encounter error while fetching model $modelId, " +
+        s"please double check the modelId is correct: ${e.getMessage}")
+  }

 }

trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply {
@@ -242,12 +267,8 @@ trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply {

 private case class StorageInfo(account: String, container: String, key: String, blob: String)

-trait MADBase extends HasOutputCol
-  with MADHttpRequest with HasSetLocation with HasInputCols
-  with ComplexParamsWritable with Wrappable
-  with HasErrorCol with SynapseMLLogging {
-
-  private def convertTimeFormat(name: String, v: String): String = {
+trait TimeConverter {
+  protected def convertTimeFormat(name: String, v: String): String = {
     try {
       DateTimeFormatter.ISO_INSTANT.format(DateTimeFormatter.ISO_INSTANT.parse(v))
     }
@@ -257,6 +278,22 @@ trait MADBase extends HasOutputCol
         s"${name.capitalize} should be ISO8601 format. e.g. 2021-01-01T00:00:00Z, received: ${e.toString}")
     }
   }
+}
+
+trait HasTimestampCol extends Params {
+  val timestampCol = new Param[String](this, "timestampCol", "Timestamp column name")
+
+  def setTimestampCol(v: String): this.type = set(timestampCol, v)
+
+  def getTimestampCol: String = $(timestampCol)
+
+  setDefault(timestampCol -> "timestamp")
+}
+
+trait MADBase extends HasOutputCol with TimeConverter
+  with MADHttpRequest with HasSetLocation with HasInputCols
+  with ComplexParamsWritable with Wrappable with HasTimestampCol
+  with HasErrorCol with SynapseMLLogging {

   val startTime = new Param[String](this, "startTime", "A required field, start time" +
     " of data to be used for detection/generating multivariate anomaly detection model, should be date-time.")
@@ -272,14 +309,6 @@ trait MADBase extends HasOutputCol

   def getEndTime: String = $(endTime)

-  val timestampCol = new Param[String](this, "timestampCol", "Timestamp column name")
-
-  def setTimestampCol(v: String): this.type = set(timestampCol, v)
-
-  def getTimestampCol: String = $(timestampCol)
-
-  setDefault(timestampCol -> "timestamp")
-
   private def validateIntermediateSaveDir(dir: String): Boolean = {
     if (!dir.startsWith("wasbs://") && !dir.startsWith("abfss://")) {
       throw new IllegalArgumentException("improper HDFS location. Please use a wasb path such as: \n" +
@@ -510,16 +539,7 @@ class SimpleFitMultivariateAnomaly(override val uid: String) extends Estimator[S

 }

-object SimpleDetectMultivariateAnomaly extends ComplexParamsReadable[SimpleDetectMultivariateAnomaly] with Serializable
-
-class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[SimpleDetectMultivariateAnomaly]
-  with MADBase with HasHandler {
-  logClass()
-
-  def this() = this(Identifiable.randomUID("SimpleDetectMultivariateAnomaly"))
-
-  def urlPath: String = "anomalydetector/v1.1/multivariate/models/"
+trait DetectMAParams extends Params {
+  val modelId = new Param[String](this, "modelId", "Format - uuid. Model identifier.")

   def setModelId(v: String): this.type = set(modelId, v)
@@ -544,6 +564,17 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
   def getTopContributorCount: Int = $(topContributorCount)

   setDefault(topContributorCount -> 10)
+}
+
+object SimpleDetectMultivariateAnomaly extends ComplexParamsReadable[SimpleDetectMultivariateAnomaly] with Serializable
+
+class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[SimpleDetectMultivariateAnomaly]
+  with MADBase with HasHandler with DetectMAParams {
+  logClass()
+
+  def this() = this(Identifiable.randomUID("SimpleDetectMultivariateAnomaly"))
+
+  def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

   protected def prepareEntity(dataSource: String): Option[AbstractHttpEntity] = {
     Some(new StringEntity(
@@ -561,26 +592,7 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
     logTransform[DataFrame] {

-      // check model status first
-      try {
-        val response = MADUtils.madGetModel(getUrl, getModelId, getSubscriptionKey)
-          .parseJson.asJsObject.fields
-
-        val modelInfo = response("modelInfo").asJsObject.fields
-        val modelStatus = modelInfo("status").asInstanceOf[JsString].value.toLowerCase
-        modelStatus match {
-          case "failed" =>
-            val errors = modelInfo("errors").convertTo[Seq[DMAError]].toJson.compactPrint
-            throw new RuntimeException(s"Caught errors during fitting: $errors")
-          case "created" | "running" =>
-            throw new RuntimeException(s"model $getModelId is not ready yet")
-          case "ready" =>
-            logInfo("model is ready for inference")
-        }
-      } catch {
-        case e: RuntimeException =>
-          throw new RuntimeException(s"Encounter error while fetching model $getModelId, " +
-            s"please double check the modelId is correct: ${e.getMessage}")
-      }
+      MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)

       val spark = dataset.sparkSession
       val responseJson = submitDatasetAndJob(dataset)
@@ -635,3 +647,112 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
 }

 }

+object DetectLastMultivariateAnomaly extends ComplexParamsReadable[DetectLastMultivariateAnomaly] with Serializable
+
+class DetectLastMultivariateAnomaly(override val uid: String) extends CognitiveServicesBase(uid)
+  with HasInternalJsonOutputParser with TimeConverter with HasTimestampCol
+  with HasSetLocation with HasCognitiveServiceInput with HasBatchSize
+  with ComplexParamsWritable with Wrappable
+  with HasErrorCol with SynapseMLLogging with DetectMAParams {
+  logClass()
+
+  def this() = this(Identifiable.randomUID("DetectLastMultivariateAnomaly"))
+
+  def urlPath: String = "anomalydetector/v1.1/multivariate/models/"
+
+  val inputVariablesCols = new StringArrayParam(this, "inputVariablesCols",
+    "The names of the input variables columns")
+
+  def setInputVariablesCols(value: Array[String]): this.type = set(inputVariablesCols, value)
+
+  def getInputVariablesCols: Array[String] = $(inputVariablesCols)
+
+  override def setBatchSize(value: Int): this.type = {
+    logWarning("batchSize should be equal to 1 sliding window.")
+    set(batchSize, value)
+  }
+
+  setDefault(batchSize -> 300)
+
+  override protected def prepareUrl: Row => String = {
+    row: Row => getUrl + s"$getModelId:detect-last"
+  }
+
+  protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
+    val timestamps = row.getAs[Seq[String]](s"${getTimestampCol}_list")
+    val variables = getInputVariablesCols.map(
+      variable => Variable(timestamps, row.getAs[Seq[Double]](s"${variable}_list"), variable))
+    Some(new StringEntity(
+      DLMARequest(variables, getTopContributorCount).toJson.compactPrint
+    ))
+  }
+
+  // scalastyle:off null
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    logTransform[DataFrame]({
+      // check model status first
+      MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)
+
+      val convertTimeFormatUdf = UDFUtils.oldUdf(
+        { value: String => convertTimeFormat("Timestamp column", value) },
+        StringType
+      )
+      val formattedDF = dataset.withColumn(getTimestampCol, convertTimeFormatUdf(col(getTimestampCol)))
+        .sort(col(getTimestampCol).asc)
+        .withColumn("group", lit(1))
+
+      val window = Window.partitionBy("group").rowsBetween(-getBatchSize, 0)
+      var collectedDF = formattedDF
+      var columnNames = Array(getTimestampCol) ++ getInputVariablesCols
+      for (columnName <- columnNames) {
+        collectedDF = collectedDF.withColumn(s"${columnName}_list", collect_list(columnName).over(window))
+      }
+      collectedDF = collectedDF.drop("group")
+      columnNames = columnNames.map(name => s"${name}_list")
+
+      val testDF = getInternalTransformer(collectedDF.schema).transform(collectedDF)
+
+      testDF
+        .withColumn("isAnomaly", when(col(getOutputCol).isNotNull,
+          col(s"$getOutputCol.results.value.isAnomaly")(0)).otherwise(null))
+        .withColumn("DetectDataTimestamp", when(col(getOutputCol).isNotNull,
+          col(s"$getOutputCol.results.timestamp")(0)).otherwise(null))
+        .drop(columnNames: _*)

Review comment on lines +705 to +722:

> Looks like a lot of the column names here are hard-coded. Will there be any instances where this doesn't work with a given input df or setting of the params?

Author reply:

> The hard-coded part "{name}_list" is because I use collect_list to aggregate those values in different rows into a single row with a list of values. Only the suffix '_list' is hard-coded, and finally those columns will be dropped and the values will be mapped back to the original dataframe.
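For context, here is a minimal, hypothetical sketch of the collect_list pattern the author describes. The column names ("timestamp", "feature1"), the window size of 2, and the object name are illustrative stand-ins for getTimestampCol, getInputVariablesCols, and getBatchSize; the sketch uses an explicit orderBy, where the transformer instead pre-sorts the DataFrame.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object CollectListSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("2021-01-01T00:00:00Z", 1.0),
      ("2021-01-01T00:01:00Z", 2.0),
      ("2021-01-01T00:02:00Z", 3.0)
    ).toDF("timestamp", "feature1")
      .withColumn("group", lit(1)) // constant key so one window spans all rows

    // Each row sees itself plus up to 2 preceding rows, mirroring
    // rowsBetween(-getBatchSize, 0) in the transform above.
    val window = Window.partitionBy("group").orderBy("timestamp").rowsBetween(-2, 0)

    // Build the "<name>_list" columns the same way the loop in transform does,
    // then drop the helper key; the suffixed columns are dropped later, once
    // the request payload has been built from them.
    val collected = Seq("timestamp", "feature1").foldLeft(df) { (acc, name) =>
      acc.withColumn(s"${name}_list", collect_list(name).over(window))
    }.drop("group")

    collected.show(truncate = false)
    // The last row's timestamp_list / feature1_list hold the full sliding window.
  }
}
```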
+    })
+  }
+  // scalastyle:on null

+  override protected def getInternalTransformer(schema: StructType): PipelineModel = {
+    val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", schema)
+    val lambda = Lambda(_.withColumn(dynamicParamColName, struct(
+      s"${getTimestampCol}_list", getInputVariablesCols.map(name => s"${name}_list"): _*)))
Review comment:

> likewise here

+    val stages = Array(
+      lambda,
+      new SimpleHTTPTransformer()
+        .setInputCol(dynamicParamColName)
+        .setOutputCol(getOutputCol)
+        .setInputParser(getInternalInputParser(schema))
+        .setOutputParser(getInternalOutputParser(schema))
+        .setHandler(handlingFunc _)
+        .setConcurrency(getConcurrency)
+        .setConcurrentTimeout(get(concurrentTimeout))
+        .setErrorCol(getErrorCol),
+      new DropColumns().setCol(dynamicParamColName)
+    )
+
+    NamespaceInjections.pipelineModel(stages)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    schema.add(getErrorCol, DMAError.schema)
+      .add(getOutputCol, DLMAResponse.schema)
+      .add("isAnomaly", BooleanType)
+  }
+
+  override def responseDataType: DataType = DLMAResponse.schema
+
+}
Review comment:

> What is the purpose of "Timestamp column" here? It doesn't look like it's necessary.
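A plausible reading, based on the TimeConverter hunk above: the string is just the human-readable label interpolated into the ISO8601 error message, so it only matters when parsing fails. A standalone sketch follows; the exception type caught here is an assumption, since the hunk truncates the catch clause.

```scala
import java.time.format.{DateTimeFormatter, DateTimeParseException}

object TimeConverterSketch {
  def convertTimeFormat(name: String, v: String): String =
    try {
      // Normalizes the input to strict ISO-8601 instant format.
      DateTimeFormatter.ISO_INSTANT.format(DateTimeFormatter.ISO_INSTANT.parse(v))
    } catch {
      case e: DateTimeParseException =>
        // The name parameter's only use: labeling the error message.
        throw new IllegalArgumentException(
          s"${name.capitalize} should be ISO8601 format. e.g. 2021-01-01T00:00:00Z, received: ${e.toString}")
    }

  def main(args: Array[String]): Unit = {
    println(convertTimeFormat("Timestamp column", "2021-01-01T00:00:00Z")) // passes through unchanged
    convertTimeFormat("Timestamp column", "not-a-date") // throws, with the label in the message
  }
}
```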