feat: add streaming API for MVAD #1893

Merged: 12 commits, Apr 21, 2023
@@ -10,21 +10,25 @@ import com.microsoft.azure.synapse.ml.cognitive.anomaly.MADJsonProtocol._
import com.microsoft.azure.synapse.ml.cognitive.vision.HasAsyncReply
import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCols, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.io.http.HandlingUtils.{convertAndClose, sendWithRetries}
import com.microsoft.azure.synapse.ml.io.http.RESTHelpers.{Client, retry}
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.stages._
import com.microsoft.azure.synapse.ml.param.CognitiveServiceStructParam
import org.apache.commons.io.IOUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.http.client.methods._
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.injections.UDFUtils
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spray.json._
@@ -49,7 +53,7 @@ private[ml] object Conversions {
scala.collection.Iterator[T] = RemoteIteratorWrapper[T](underlying)
}

object MADUtils {
object MADUtils extends Logging {

private[ml] val CreatedModels: mutable.ParHashSet[String] = new ParHashSet[String]()

@@ -154,6 +158,27 @@ object MADUtils {
CreatedModels.clear()
}

private[ml] def checkModelStatus(url: String, modelId: String, subscriptionKey: String): Unit = try {
val response = madGetModel(url, modelId, subscriptionKey)
.parseJson.asJsObject.fields

val modelInfo = response("modelInfo").asJsObject.fields
val modelStatus = modelInfo("status").asInstanceOf[JsString].value.toLowerCase
modelStatus match {
case "failed" =>
val errors = modelInfo("errors").convertTo[Seq[DMAError]].toJson.compactPrint
throw new RuntimeException(s"Caught errors during fitting: $errors")
case "created" | "running" =>
throw new RuntimeException(s"model $modelId is not ready yet")
case "ready" =>
logInfo("model is ready for inference")
}
} catch {
case e: RuntimeException =>
throw new RuntimeException(s"Encounter error while fetching model $modelId, " +
s"please double check the modelId is correct: ${e.getMessage}")
}

}

trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply {
@@ -242,12 +267,8 @@ trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply {

private case class StorageInfo(account: String, container: String, key: String, blob: String)

trait MADBase extends HasOutputCol
with MADHttpRequest with HasSetLocation with HasInputCols
with ComplexParamsWritable with Wrappable
with HasErrorCol with SynapseMLLogging {

private def convertTimeFormat(name: String, v: String): String = {
trait TimeConverter {
protected def convertTimeFormat(name: String, v: String): String = {
try {
DateTimeFormatter.ISO_INSTANT.format(DateTimeFormatter.ISO_INSTANT.parse(v))
}
@@ -257,6 +278,22 @@ trait MADBase extends HasOutputCol
s"${name.capitalize} should be ISO8601 format. e.g. 2021-01-01T00:00:00Z, received: ${e.toString}")
}
}
}

trait HasTimestampCol extends Params {
val timestampCol = new Param[String](this, "timestampCol", "Timestamp column name")

def setTimestampCol(v: String): this.type = set(timestampCol, v)

def getTimestampCol: String = $(timestampCol)

setDefault(timestampCol -> "timestamp")
}

trait MADBase extends HasOutputCol with TimeConverter
with MADHttpRequest with HasSetLocation with HasInputCols
with ComplexParamsWritable with Wrappable with HasTimestampCol
with HasErrorCol with SynapseMLLogging {

val startTime = new Param[String](this, "startTime", "A required field, start time" +
" of data to be used for detection/generating multivariate anomaly detection model, should be date-time.")
@@ -272,14 +309,6 @@ trait MADBase extends HasOutputCol

def getEndTime: String = $(endTime)

val timestampCol = new Param[String](this, "timestampCol", "Timestamp column name")

def setTimestampCol(v: String): this.type = set(timestampCol, v)

def getTimestampCol: String = $(timestampCol)

setDefault(timestampCol -> "timestamp")

private def validateIntermediateSaveDir(dir: String): Boolean = {
if (!dir.startsWith("wasbs://") && !dir.startsWith("abfss://")) {
throw new IllegalArgumentException("improper HDFS loacation. Please use a wasb path such as: \n" +
@@ -510,16 +539,7 @@ class SimpleFitMultivariateAnomaly(override val uid: String) extends Estimator[S

}

object SimpleDetectMultivariateAnomaly extends ComplexParamsReadable[SimpleDetectMultivariateAnomaly] with Serializable

class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[SimpleDetectMultivariateAnomaly]
with MADBase with HasHandler {
logClass()

def this() = this(Identifiable.randomUID("SimpleDetectMultivariateAnomaly"))

def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

trait DetectMAParams extends Params {
val modelId = new Param[String](this, "modelId", "Format - uuid. Model identifier.")

def setModelId(v: String): this.type = set(modelId, v)
@@ -544,6 +564,17 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
def getTopContributorCount: Int = $(topContributorCount)

setDefault(topContributorCount -> 10)
}

object SimpleDetectMultivariateAnomaly extends ComplexParamsReadable[SimpleDetectMultivariateAnomaly] with Serializable

class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[SimpleDetectMultivariateAnomaly]
with MADBase with HasHandler with DetectMAParams {
logClass()

def this() = this(Identifiable.randomUID("SimpleDetectMultivariateAnomaly"))

def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

protected def prepareEntity(dataSource: String): Option[AbstractHttpEntity] = {
Some(new StringEntity(
@@ -561,26 +592,7 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
logTransform[DataFrame] {

// check model status first
try {
val response = MADUtils.madGetModel(getUrl, getModelId, getSubscriptionKey)
.parseJson.asJsObject.fields

val modelInfo = response("modelInfo").asJsObject.fields
val modelStatus = modelInfo("status").asInstanceOf[JsString].value.toLowerCase
modelStatus match {
case "failed" =>
val errors = modelInfo("errors").convertTo[Seq[DMAError]].toJson.compactPrint
throw new RuntimeException(s"Caught errors during fitting: $errors")
case "created" | "running" =>
throw new RuntimeException(s"model $getModelId is not ready yet")
case "ready" =>
logInfo("model is ready for inference")
}
} catch {
case e: RuntimeException =>
throw new RuntimeException(s"Encounter error while fetching model $getModelId, " +
s"please double check the modelId is correct: ${e.getMessage}")
}
MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)

val spark = dataset.sparkSession
val responseJson = submitDatasetAndJob(dataset)
@@ -635,3 +647,112 @@ class SimpleDetectMultivariateAnomaly(override val uid: String) extends Model[Si
}

}

object DetectLastMultivariateAnomaly extends ComplexParamsReadable[DetectLastMultivariateAnomaly] with Serializable

class DetectLastMultivariateAnomaly(override val uid: String) extends CognitiveServicesBase(uid)
with HasInternalJsonOutputParser with TimeConverter with HasTimestampCol
with HasSetLocation with HasCognitiveServiceInput with HasBatchSize
with ComplexParamsWritable with Wrappable
with HasErrorCol with SynapseMLLogging with DetectMAParams {
logClass()

def this() = this(Identifiable.randomUID("DetectLastMultivariateAnomaly"))

def urlPath: String = "anomalydetector/v1.1/multivariate/models/"

val inputVariablesCols = new StringArrayParam(this, "inputVariablesCols",
"The names of the input variables columns")

def setInputVariablesCols(value: Array[String]): this.type = set(inputVariablesCols, value)

def getInputVariablesCols: Array[String] = $(inputVariablesCols)

override def setBatchSize(value: Int): this.type = {
logWarning("batchSize should be equal to 1 sliding window.")
set(batchSize, value)
}

setDefault(batchSize -> 300)

override protected def prepareUrl: Row => String = {
row: Row => getUrl + s"$getModelId:detect-last"
}

protected def prepareEntity: Row => Option[AbstractHttpEntity] = { row =>
val timestamps = row.getAs[Seq[String]](s"${getTimestampCol}_list")
val variables = getInputVariablesCols.map(
variable => Variable(timestamps, row.getAs[Seq[Double]](s"${variable}_list"), variable))
Some(new StringEntity(
DLMARequest(variables, getTopContributorCount).toJson.compactPrint
))
}

// scalastyle:off null
override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
// check model status first
MADUtils.checkModelStatus(getUrl, getModelId, getSubscriptionKey)

val convertTimeFormatUdf = UDFUtils.oldUdf(
{ value: String => convertTimeFormat("Timestamp column", value) },
Collaborator:
What is the purpose of "Timestamp column" here? It doesn't look like it's necessary.

StringType
)
val formattedDF = dataset.withColumn(getTimestampCol, convertTimeFormatUdf(col(getTimestampCol)))
.sort(col(getTimestampCol).asc)
.withColumn("group", lit(1))

val window = Window.partitionBy("group").rowsBetween(-getBatchSize, 0)
var collectedDF = formattedDF
var columnNames = Array(getTimestampCol) ++ getInputVariablesCols
for (columnName <- columnNames) {
collectedDF = collectedDF.withColumn(s"${columnName}_list", collect_list(columnName).over(window))
}
collectedDF = collectedDF.drop("group")
columnNames = columnNames.map(name => s"${name}_list")

val testDF = getInternalTransformer(collectedDF.schema).transform(collectedDF)

testDF
.withColumn("isAnomaly", when(col(getOutputCol).isNotNull,
col(s"$getOutputCol.results.value.isAnomaly")(0)).otherwise(null))
.withColumn("DetectDataTimestamp", when(col(getOutputCol).isNotNull,
col(s"$getOutputCol.results.timestamp")(0)).otherwise(null))
.drop(columnNames: _*)

Comment on lines +705 to +722

Collaborator:
It looks like a lot of the column names here are hard-coded. Will there be any instances where this doesn't work with a given input DataFrame or setting of the params?

Contributor (author):
The hard-coded part "{name}_list" exists because collect_list is used to aggregate the values from different rows into a single row holding a list of values. Only the suffix "_list" is hard-coded; those columns are dropped at the end and the values are mapped back to the original DataFrame.
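For illustration only (not part of this diff): a minimal, self-contained sketch of the collect_list-over-a-trailing-window aggregation described in the reply above. The toy data, the value_list column name, the window length of 2, and the explicit orderBy are assumptions made for the sketch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, collect_list, lit}

object CollectListSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("collect_list sketch").getOrCreate()
    import spark.implicits._

    // Toy series; in the transformer each input variable column gets the same treatment.
    val df = Seq(
      ("2021-01-01T00:00:00Z", 1.0),
      ("2021-01-01T00:01:00Z", 2.0),
      ("2021-01-01T00:02:00Z", 3.0)
    ).toDF("timestamp", "value")

    // A single constant group plus a trailing row frame, mirroring the transform above.
    val window = Window.partitionBy("group").orderBy(col("timestamp")).rowsBetween(-2, 0)

    val collected = df
      .withColumn("group", lit(1))
      .withColumn("value_list", collect_list(col("value")).over(window))
      .drop("group")

    // Each row now carries the list of values in its trailing window; the *_list columns
    // are what get sent to the service and are dropped afterwards, as described above.
    collected.show(truncate = false)

    spark.stop()
  }
}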

})
}
// scalastyle:on null

override protected def getInternalTransformer(schema: StructType): PipelineModel = {
val dynamicParamColName = DatasetExtensions.findUnusedColumnName("dynamic", schema)
val lambda = Lambda(_.withColumn(dynamicParamColName, struct(
s"${getTimestampCol}_list", getInputVariablesCols.map(name => s"${name}_list"): _*)))
Collaborator:
Likewise here.


val stages = Array(
lambda,
new SimpleHTTPTransformer()
.setInputCol(dynamicParamColName)
.setOutputCol(getOutputCol)
.setInputParser(getInternalInputParser(schema))
.setOutputParser(getInternalOutputParser(schema))
.setHandler(handlingFunc _)
.setConcurrency(getConcurrency)
.setConcurrentTimeout(get(concurrentTimeout))
.setErrorCol(getErrorCol),
new DropColumns().setCol(dynamicParamColName)
)

NamespaceInjections.pipelineModel(stages)

}

override def transformSchema(schema: StructType): StructType = {
schema.add(getErrorCol, DMAError.schema)
.add(getOutputCol, DLMAResponse.schema)
.add("isAnomaly", BooleanType)
}

override def responseDataType: DataType = DLMAResponse.schema

}
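For context only (not part of this diff): a hedged sketch of how the new DetectLastMultivariateAnomaly transformer might be configured. The package path and the setters for subscription key, location, output column, and error column are assumed to come from the inherited SynapseML traits (HasSubscriptionKey, HasSetLocation, HasErrorCol, and the cognitive base class); the key, location, model id, and column names below are placeholders.

import com.microsoft.azure.synapse.ml.cognitive.anomaly.DetectLastMultivariateAnomaly
import org.apache.spark.sql.DataFrame

object DetectLastSketch {
  // Hypothetical wiring; every literal value below is a placeholder.
  def detectLastAnomalies(df: DataFrame): DataFrame = {
    val dlma = new DetectLastMultivariateAnomaly()
      .setSubscriptionKey("<anomaly-detector-key>")    // assumed from HasSubscriptionKey
      .setLocation("westus2")                          // assumed from HasSetLocation
      .setModelId("<model-id-from-SimpleFitMultivariateAnomaly>")
      .setInputVariablesCols(Array("sensor_1", "sensor_2"))
      .setTimestampCol("timestamp")
      .setBatchSize(300)                               // sliding-window length the trained model expects
      .setOutputCol("result")
      .setErrorCol("errors")

    // Appends the isAnomaly and DetectDataTimestamp columns, as in the transform above.
    dlma.transform(df)
  }
}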
@@ -114,6 +114,18 @@ case class ModelState(epochIds: Option[Seq[Int]],
def getLatenciesInSeconds: java.util.List[Double] = this.latenciesInSeconds.getOrElse(Seq()).asJava
}

object DLMARequest extends SparkBindings[DLMARequest]

case class DLMARequest(variables: Seq[Variable], topContributorCount: Int)

object Variable extends SparkBindings[Variable]

case class Variable(timestamps: Seq[String], values: Seq[Double], variable: String)

object DLMAResponse extends SparkBindings[DLMAResponse]

case class DLMAResponse(variableStates: Option[Seq[DMAVariableState]], results: Option[Seq[DMAResult]])

object MADJsonProtocol extends DefaultJsonProtocol {
implicit val DMAReqEnc: RootJsonFormat[DMARequest] = jsonFormat4(DMARequest.apply)
implicit val EEnc: RootJsonFormat[DMAError] = jsonFormat2(DMAError.apply)
@@ -129,4 +141,6 @@ object MADJsonProtocol extends DefaultJsonProtocol {
implicit val DMASetupInfoEnc: RootJsonFormat[DMASetupInfo] = jsonFormat4(DMASetupInfo.apply)
implicit val DMASummaryEnc: RootJsonFormat[DMASummary] = jsonFormat4(DMASummary.apply)
implicit val MAEModelInfoEnc: RootJsonFormat[MAEModelInfo] = jsonFormat10(MAEModelInfo.apply)
implicit val VariableEnc: RootJsonFormat[Variable] = jsonFormat3(Variable.apply)
implicit val DLMARequestEnc: RootJsonFormat[DLMARequest] = jsonFormat2(DLMARequest.apply)
}