From 4845c31f7666ef50f58363da011307e516c2755e Mon Sep 17 00:00:00 2001 From: Assaf Israel Date: Wed, 16 Jun 2021 11:25:25 -0700 Subject: [PATCH] feat: TextAnalytics SDK connector POC Co-authored-by: Assaf Israel Co-authored-by: Assaf Israel Co-authored-by: Samantha Konigsberg Co-authored-by: victoriajmicrosoft Co-authored-by: Suhas Mehta Co-authored-by: Preeti Pidatala --- build.sbt | 4 +- .../ml/cognitive/CognitiveServiceBase.scala | 2 +- .../MultivariateAnomalyDetection.scala | 2 +- .../synapse/ml/cognitive/TextAnalytics.scala | 12 +- .../ml/cognitive/TextAnalyticsSDK.scala | 397 ++++++++++++++++++ .../cognitive/TextAnalyticsSDKSchemasV4.scala | 351 ++++++++++++++++ .../split1/TextAnalyticsSDKSuite.scala | 265 ++++++++++++ .../cognitive/split1/TextAnalyticsSuite.scala | 18 +- .../synapse/ml/core/contracts/Params.scala | 1 - .../synapse/ml/io/http/HTTPTransformer.scala | 7 +- .../ml/io/http/SimpleHTTPTransformer.scala | 4 +- .../ml/stages/PartitionConsolidator.scala | 4 +- 12 files changed, 1042 insertions(+), 25 deletions(-) create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalyticsSDK.scala create mode 100644 cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalyticsSDKSchemasV4.scala create mode 100644 cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TextAnalyticsSDKSuite.scala diff --git a/build.sbt b/build.sbt index 148da1c8ff..1f49671ec4 100644 --- a/build.sbt +++ b/build.sbt @@ -22,6 +22,7 @@ val excludes = Seq( ) val coreDependencies = Seq( + "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.12.3", "org.apache.spark" %% "spark-core" % sparkVersion % "compile", "org.apache.spark" %% "spark-mllib" % sparkVersion % "compile", "org.apache.spark" %% "spark-avro" % sparkVersion % "provided", @@ -312,7 +313,8 @@ lazy val cognitive = (project in file("cognitive")) libraryDependencies ++= Seq( "com.microsoft.cognitiveservices.speech" % "client-sdk" % "1.14.0", 
"com.azure" % "azure-storage-blob" % "12.8.0", // can't upgrade higher due to conflict with jackson-databind -), + "com.azure" % "azure-ai-textanalytics" % "5.1.4", + ), resolvers += speechResolver, name := "synapseml-cognitive" ): _*) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala index 0d7eb466a3..8e632a86d0 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala @@ -269,7 +269,7 @@ trait HasSetLocation extends Wrappable with HasURL with HasUrlPath { } abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transformer - with HTTPParams with HasOutputCol + with ConcurrencyParams with HasOutputCol with HasURL with ComplexParamsWritable with HasSubscriptionKey with HasErrorCol with BasicLogging { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/MultivariateAnomalyDetection.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/MultivariateAnomalyDetection.scala index 1e88d0acf8..ba326a8eb3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/MultivariateAnomalyDetection.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/MultivariateAnomalyDetection.scala @@ -115,7 +115,7 @@ trait MADHttpRequest extends HasURL with HasSubscriptionKey with HasAsyncReply { trait MADBase extends HasOutputCol with MADHttpRequest with HasSetLocation with HasInputCols - with HTTPParams with ComplexParamsWritable with Wrappable + with ConcurrencyParams with ComplexParamsWritable with Wrappable with HasSubscriptionKey with HasErrorCol with BasicLogging { val startTime = new Param[String](this, "startTime", "A required field, start time" + diff --git 
a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalytics.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalytics.scala index e2df2b30e4..87b39d65d5 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalytics.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalytics.scala @@ -25,10 +25,8 @@ import java.net.URI import java.util import scala.collection.JavaConverters._ -abstract class TextAnalyticsBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid) - with HasCognitiveServiceInput with HasInternalJsonOutputParser with HasSetLocation - with HasSetLinkedService { +trait TextAnalyticsInputParams extends HasServiceParams { val text = new ServiceParam[Seq[String]](this, "text", "the text in the request body", isRequired = true) def setTextCol(v: String): this.type = setVectorParam(text, v) @@ -48,7 +46,13 @@ abstract class TextAnalyticsBase(override val uid: String) extends CognitiveServ def setLanguage(v: String): this.type = setScalarParam(language, Seq(v)) - setDefault(language -> Left(Seq("en"))) + //setDefault(language -> Left(Seq("en"))) + +} + +abstract class TextAnalyticsBase(override val uid: String) extends CognitiveServicesBaseNoHandler(uid) + with HasCognitiveServiceInput with HasInternalJsonOutputParser with HasSetLocation + with HasSetLinkedService with TextAnalyticsInputParams { protected def innerResponseDataType: StructType = responseDataType("documents").dataType match { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalyticsSDK.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalyticsSDK.scala new file mode 100644 index 0000000000..9e5d847f9e --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/TextAnalyticsSDK.scala @@ -0,0 +1,397 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.cognitive

import com.azure.ai.textanalytics.models._
import com.azure.ai.textanalytics.{TextAnalyticsClient, TextAnalyticsClientBuilder}
import com.azure.core.credential.AzureKeyCredential
import com.azure.core.http.policy.RetryPolicy
import com.azure.core.util.{ClientOptions, Context, Header}
import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.cognitive.SDKConverters._
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import com.microsoft.azure.synapse.ml.core.utils.AsyncUtils.bufferedAwait
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL, HeaderValues}
import com.microsoft.azure.synapse.ml.logging.BasicLogging
import com.microsoft.azure.synapse.ml.stages.{FixedMiniBatchTransformer, FlattenBatch, HasBatchSize}
import org.apache.spark.ml.param.{ParamMap, ServiceParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import spray.json.DefaultJsonProtocol._

import java.time.temporal.ChronoUnit
import scala.collection.JavaConverters._
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}

/** Request options shared by every Text Analytics SDK transformer: model version,
  * per-document statistics, and service-side logging. Each option can be set as a
  * scalar (one value for all rows) or bound to an input column.
  */
trait HasOptions extends HasServiceParams {
  val modelVersion = new ServiceParam[String](
    this, name = "modelVersion", "modelVersion option")

  def getModelVersion: String = $(modelVersion).left.get

  def setModelVersion(v: String): this.type = setScalarParam(modelVersion, v)

  def setModelVersionCol(v: String): this.type = setVectorParam(modelVersion, v)

  val includeStatistics = new ServiceParam[Boolean](
    this, name = "includeStatistics", "includeStatistics option")

  def getIncludeStatistics: Boolean = $(includeStatistics).left.get

  def setIncludeStatistics(v: Boolean): this.type = setScalarParam(includeStatistics, v)

  def setIncludeStatisticsCol(v: String): this.type = setVectorParam(includeStatistics, v)

  val disableServiceLogs = new ServiceParam[Boolean](
    this, name = "disableServiceLogs", "disableServiceLogs option")

  def getDisableServiceLogs: Boolean = $(disableServiceLogs).left.get

  def setDisableServiceLogs(v: Boolean): this.type = setScalarParam(disableServiceLogs, v)

  def setDisableServiceLogsCol(v: String): this.type = setVectorParam(disableServiceLogs, v)

  setDefault(
    modelVersion -> Left("latest"),
    includeStatistics -> Left(false),
    disableServiceLogs -> Left(true)
  )

}

/** Opinion-mining toggle, used only by the sentiment transformer. */
trait HasOpinionMining extends HasServiceParams {
  val includeOpinionMining = new ServiceParam[Boolean](
    this, name = "includeOpinionMining", "includeOpinionMining option")

  def getIncludeOpinionMining: Boolean = $(includeOpinionMining).left.get

  def setIncludeOpinionMining(v: Boolean): this.type = setScalarParam(includeOpinionMining, v)

  def setIncludeOpinionMiningCol(v: String): this.type = setVectorParam(includeOpinionMining, v)

  setDefault(
    includeOpinionMining -> Left(false)
  )
}

/** Base transformer for all Azure Text Analytics SDK connectors.
  *
  * Per partition it builds one [[TextAnalyticsClient]], batches the input
  * (auto-batching string columns via [[FixedMiniBatchTransformer]] and
  * re-flattening the result), and dispatches each batch concurrently via
  * [[bufferedAwait]]. Subclasses supply the SDK call in `invokeTextAnalytics`
  * and the result schema in `responseBinding`.
  *
  * @tparam T the per-document result type appended in `getOutputCol`.
  */
private[ml] abstract class TextAnalyticsSDKBase[T]()
  extends Transformer with HasErrorCol with HasURL with HasSetLocation with HasSubscriptionKey
    with TextAnalyticsInputParams with HasOutputCol with ConcurrencyParams with HasBatchSize with HasOptions
    with ComplexParamsWritable with BasicLogging {

  override def urlPath: String = ""

  /** Spark row bindings for the wrapped response type `TAResponseSDK[T]`. */
  val responseBinding: SparkBindings[TAResponseSDK[T]]

  /** Performs one batched SDK call for a single (already-batched) row. */
  def invokeTextAnalytics(client: TextAnalyticsClient,
                          text: Seq[String],
                          lang: Seq[String],
                          row: Row
                         ): Seq[TAResponseSDK[T]]

  setDefault(batchSize -> 5)

  /** Maps a partition of batched rows to rows with the response column appended. */
  protected def transformTextRows(toRow: TAResponseSDK[T] => Row)
                                 (rows: Iterator[Row]): Iterator[Row] = {
    if (rows.hasNext) {
      // One client per partition; the credential starts as a placeholder and is
      // refreshed with the row's subscription key right before each request.
      val key = new AzureKeyCredential("placeholder")
      val client = new TextAnalyticsClientBuilder()
        .retryPolicy(new RetryPolicy("Retry-After", ChronoUnit.SECONDS))
        .clientOptions(new ClientOptions().setHeaders(Seq(
          new Header("User-Agent", s"synapseml/${BuildInfo.version}${HeaderValues.PlatformInfo}")).asJava))
        .credential(key)
        .endpoint(getUrl)
        .buildClient()

      val dur = get(concurrentTimeout)
        .map(ct => Duration.fromNanos((ct * math.pow(10, 9)).toLong)) //scalastyle:ignore magic.number
        .getOrElse(Duration.Inf)

      val futures = rows.map { row =>
        Future {
          val validText = getValue(row, text)
          // A single language value is broadcast across the whole batch.
          val langs = getValueOpt(row, language).getOrElse(Seq.fill(validText.length)(""))
          val validLanguages = if (langs.length == 1) {
            Seq.fill(validText.length)(langs.head)
          } else {
            langs
          }
          assert(validLanguages.length == validText.length)

          // NOTE(review): `key` is shared by all futures of this partition. If
          // subscriptionKey is bound to a column with differing values, a
          // concurrent future can overwrite it mid-request — confirm the key is
          // constant per partition, or make the credential per-request.
          key.update(getValue(row, subscriptionKey))
          val results = invokeTextAnalytics(client, validText, validLanguages, row)
          Row.fromSeq(row.toSeq ++ Seq(results.map(toRow))) // Adding a new column
        }(ExecutionContext.global)
      }
      bufferedAwait(futures, getConcurrency, dur)(ExecutionContext.global)
    } else {
      Iterator.empty
    }
  }

  /** True when the text (and language) columns are plain strings and must be
    * mini-batched before calling the SDK; false when the user pre-batched them
    * as arrays. Mismatched or unknown column types are rejected eagerly.
    */
  protected def shouldAutoBatch(schema: StructType): Boolean = {
    ($(text), get(language)) match {
      case (Left(_), Some(Right(b))) =>
        schema(b).dataType.isInstanceOf[StringType]
      case (Left(_), None) =>
        true
      case (Right(a), Some(Right(b))) =>
        (schema(a).dataType, schema(b).dataType) match {
          case (_: StringType, _: StringType) => true
          case (_: ArrayType, _: ArrayType) => false
          case (_: StringType, _: ArrayType) | (_: ArrayType, _: StringType) =>
            throw new IllegalArgumentException(s"Mismatched column types. " +
              s"Both columns $a and $b need to be StringType (for auto batching)" +
              s" or ArrayType(StringType) (for user batching)")
          case _ =>
            throw new IllegalArgumentException(s"Unknown column types. " +
              s"Both columns $a and $b need to be StringType (for auto batching)" +
              s" or ArrayType(StringType) (for user batching)")
        }
      case (Right(a), _) =>
        schema(a).dataType.isInstanceOf[StringType]
      case _ => false
    }
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      val df = dataset.toDF

      val batchedDF = if (shouldAutoBatch(df.schema)) {
        new FixedMiniBatchTransformer().setBatchSize(getBatchSize).transform(df)
      } else {
        df
      }

      val enc = RowEncoder(batchedDF.schema.add(getOutputCol, ArrayType(responseBinding.schema)))
      val toRow = responseBinding.makeToRowConverter
      val resultDF = batchedDF.mapPartitions(transformTextRows(toRow))(enc)

      // Undo the auto-batching so output rows align 1:1 with input rows.
      if (shouldAutoBatch(df.schema)) {
        new FlattenBatch().transform(resultDF)
      } else {
        resultDF
      }
    })
  }

  override def transformSchema(schema: StructType): StructType = {
    if (shouldAutoBatch(schema)) {
      schema.add(getOutputCol, responseBinding.schema)
    } else {
      schema.add(getOutputCol, ArrayType(responseBinding.schema))
    }
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
}

object LanguageDetectorSDK extends ComplexParamsReadable[LanguageDetectorSDK]

/** Detects the language of each document; the language param acts as a country hint. */
class LanguageDetectorSDK(override val uid: String)
  extends TextAnalyticsSDKBase[DetectedLanguageSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("LanguageDetectorSDK"))

  override val responseBinding: DetectLanguageResponseSDK.type = DetectLanguageResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   hints: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[DetectedLanguageSDK]] = {
    val documents = (input, hints, input.indices).zipped.map { (doc, hint, i) =>
      new DetectLanguageInput(i.toString, doc, hint)
    }.asJava

    val options = new TextAnalyticsRequestOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val response = client.detectLanguageBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

object KeyPhraseExtractorSDK extends ComplexParamsReadable[KeyPhraseExtractorSDK]

/** Extracts key phrases from each document. */
class KeyPhraseExtractorSDK(override val uid: String)
  extends TextAnalyticsSDKBase[KeyphraseSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("KeyPhraseExtractorSDK"))

  override val responseBinding: KeyPhraseResponseSDK.type = KeyPhraseResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[KeyphraseSDK]] = {

    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava
    val options = new TextAnalyticsRequestOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val response = client.extractKeyPhrasesBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

// Fixed: the companion previously extended ComplexParamsReadable[TextSentiment],
// so TextSentimentSDK.load() would have instantiated the HTTP-based TextSentiment
// transformer instead of this class.
object TextSentimentSDK extends ComplexParamsReadable[TextSentimentSDK]

/** Scores document/sentence sentiment, optionally with opinion mining. */
class TextSentimentSDK(override val uid: String)
  extends TextAnalyticsSDKBase[SentimentScoredDocumentSDK]() with HasOpinionMining {
  logClass()

  def this() = this(Identifiable.randomUID("TextSentimentSDK"))

  override val responseBinding: SentimentResponseSDK.type = SentimentResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[SentimentScoredDocumentSDK]] = {
    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava

    val options = new AnalyzeSentimentOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))
      .setIncludeOpinionMining(getValue(row, includeOpinionMining))

    val response = client.analyzeSentimentBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

// Fixed: the companion previously extended ComplexParamsReadable[PII], so
// PIISDK.load() would have instantiated the HTTP-based PII transformer.
object PIISDK extends ComplexParamsReadable[PIISDK]

/** Recognizes personally-identifiable-information entities in each document. */
class PIISDK(override val uid: String) extends TextAnalyticsSDKBase[PIIEntityCollectionSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("PIISDK"))

  override val responseBinding: PIIResponseSDK.type = PIIResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[PIIEntityCollectionSDK]] = {
    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava

    val options = new RecognizePiiEntitiesOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val response = client.recognizePiiEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

object HealthcareSDK extends ComplexParamsReadable[HealthcareSDK]

/** Extracts healthcare entities via the long-running analyze operation;
  * blocks on the poller until the operation completes.
  */
class HealthcareSDK(override val uid: String) extends TextAnalyticsSDKBase[HealthEntitiesResultSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("HealthcareSDK"))

  override val responseBinding: HealthcareResponseSDK.type = HealthcareResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[HealthEntitiesResultSDK]] = {
    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava

    val options = new AnalyzeHealthcareEntitiesOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val poller = client.beginAnalyzeHealthcareEntities(documents, options, Context.NONE)
    poller.waitForCompletion()

    val pagedResults = poller.getFinalResult.asScala
    toResponse(pagedResults.flatMap(_.asScala))
  }
}

object EntityDetectorSDK extends ComplexParamsReadable[EntityDetectorSDK]

/** Recognizes linked (knowledge-base) entities in each document. */
class EntityDetectorSDK(override val uid: String) extends TextAnalyticsSDKBase[LinkedEntityCollectionSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("EntityDetectorSDK"))

  override val responseBinding: LinkedEntityResponseSDK.type = LinkedEntityResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[LinkedEntityCollectionSDK]] = {
    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava

    val options = new TextAnalyticsRequestOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val response = client.recognizeLinkedEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

object NERSDK extends ComplexParamsReadable[NERSDK]

/** Recognizes named (categorized) entities in each document. */
class NERSDK(override val uid: String) extends TextAnalyticsSDKBase[NERCollectionSDK]() {
  logClass()

  def this() = this(Identifiable.randomUID("NERSDK"))

  override val responseBinding: NERResponseSDK.type = NERResponseSDK

  override def invokeTextAnalytics(client: TextAnalyticsClient,
                                   input: Seq[String],
                                   lang: Seq[String],
                                   row: Row
                                  ): Seq[TAResponseSDK[NERCollectionSDK]] = {
    val documents = (input, lang, lang.indices).zipped.map { (doc, lang, i) =>
      new TextDocumentInput(i.toString, doc).setLanguage(lang)
    }.asJava

    val options = new TextAnalyticsRequestOptions()
      .setModelVersion(getValue(row, modelVersion))
      .setIncludeStatistics(getValue(row, includeStatistics))
      .setServiceLogsDisabled(getValue(row, disableServiceLogs))

    val response = client.recognizeEntitiesBatchWithResponse(documents, options, Context.NONE).getValue
    toResponse(response.asScala)
  }
}

// --- TextAnalyticsSDKSchemasV4.scala ---
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.cognitive

import com.azure.ai.textanalytics.models._
import com.microsoft.azure.synapse.ml.core.schema.SparkBindings

import scala.collection.JavaConverters._
import scala.language.implicitConversions

// Spark row bindings for each SDK response type, used by the transformers'
// `responseBinding` to derive the output-column schema and row converters.
object DetectLanguageResponseSDK extends SparkBindings[TAResponseSDK[DetectedLanguageSDK]]

object KeyPhraseResponseSDK extends SparkBindings[TAResponseSDK[KeyphraseSDK]]

object SentimentResponseSDK extends SparkBindings[TAResponseSDK[SentimentScoredDocumentSDK]]

object PIIResponseSDK extends SparkBindings[TAResponseSDK[PIIEntityCollectionSDK]]

object HealthcareResponseSDK extends SparkBindings[TAResponseSDK[HealthEntitiesResultSDK]]

object LinkedEntityResponseSDK extends SparkBindings[TAResponseSDK[LinkedEntityCollectionSDK]]

object NERResponseSDK extends SparkBindings[TAResponseSDK[NERCollectionSDK]]

// Per-document envelope: exactly one of `result` or `error` is populated;
// `statistics` is present only when includeStatistics was requested.
case class TAResponseSDK[T](result: Option[T],
                            error: Option[TAErrorSDK],
                            statistics: Option[DocumentStatistics])

case class DetectedLanguageSDK(name: String,
                               iso6391Name: String,
                               confidenceScore: Double,
                               warnings: Seq[TAWarningSDK])

case class TAErrorSDK(errorCode: String, errorMessage: String, target: String)

case class TAWarningSDK(warningCode: String, message: String)

// NOTE(review): appears unused by the visible code — verify before removing.
case class TextDocumentInputs(id: String, text: String)

case class KeyphraseSDK(keyPhrases: Seq[String], warnings: Seq[TAWarningSDK])

case class SentimentConfidenceScoreSDK(negative: Double, neutral: Double, positive: Double)

case class SentimentScoredDocumentSDK(sentiment: String,
                                      confidenceScores: SentimentConfidenceScoreSDK,
                                      sentences: Seq[SentimentSentenceSDK],
                                      warnings: Seq[TAWarningSDK])

// `opinions` is present only when opinion mining was requested.
case class SentimentSentenceSDK(text: String,
                                sentiment: String,
                                confidenceScores: SentimentConfidenceScoreSDK,
                                opinions: Option[Seq[OpinionSDK]],
                                offset: Int,
                                length: Int)

case class OpinionSDK(target: TargetSDK, assessments: Seq[AssessmentSDK])

case class TargetSDK(text: String,
                     sentiment: String,
                     confidenceScores: SentimentConfidenceScoreSDK,
                     offset: Int,
                     length: Int)

case class AssessmentSDK(text: String,
                         sentiment: String,
                         confidenceScores: SentimentConfidenceScoreSDK,
                         isNegated: Boolean,
                         offset: Int,
                         length: Int)

case class PIIEntityCollectionSDK(entities: Seq[PIIEntitySDK],
                                  redactedText: String,
                                  warnings: Seq[TAWarningSDK])

case class PIIEntitySDK(text: String,
                        category: String,
                        subCategory: String,
                        confidenceScore: Double,
                        offset: Int,
                        length: Int)

case class HealthEntitiesResultSDK(id: String,
                                   warnings: Seq[TAWarningSDK],
                                   entities: Seq[HealthcareEntitySDK],
                                   entityRelation: Seq[HealthcareEntityRelationSDK])

case class HealthEntitiesOperationDetailSDK(createdAt: String,
                                            expiresAt: String,
                                            lastModifiedAt: String,
                                            operationId: String)

case class EntityDataSourceSDK(name: String,
                               entityId: String)

case class HealthcareEntitySDK(assertion: Option[HealthcareEntityAssertionSDK],
                               category: String,
                               confidenceScore: Double,
                               dataSources: Seq[EntityDataSourceSDK],
                               length: Int,
                               normalizedText: String,
                               offset: Int,
                               subCategory: String,
                               text: String)

case class HealthcareEntityAssertionSDK(association: Option[String],
                                        certainty: Option[String],
                                        conditionality: Option[String])

case class HealthcareEntityRelationSDK(relationType: String,
                                       roles: Seq[HealthcareEntityRelationRoleSDK])

case class HealthcareEntityRelationRoleSDK(entity: HealthcareEntitySDK, name: String)

case class LinkedEntityCollectionSDK(entities: Seq[LinkedEntitySDK],
                                     warnings: Seq[TAWarningSDK])

case class LinkedEntitySDK(name: String,
                           matches: Seq[LinkedEntityMatchSDK],
                           language: String,
                           dataSourceEntityId: String,
                           url: String,
                           dataSource: String,
                           bingEntitySearchApiId: String)

case class LinkedEntityMatchSDK(text: String,
                                confidenceScore: Double,
                                offset: Int,
                                length: Int)

case class NERCollectionSDK(entities: Seq[NEREntitySDK], warnings: Seq[TAWarningSDK])
case class NEREntitySDK(text: String,
                        category: String,
                        subCategory: String,
                        confidenceScore: Double,
                        offset: Int,
                        length: Int)

/** Implicit converters from Azure Text Analytics SDK model classes to the
  * Spark-friendly case classes above, plus helpers to unpack per-document
  * results/errors/statistics into [[TAResponseSDK]] envelopes.
  */
object SDKConverters {
  implicit def fromSDK(score: SentimentConfidenceScores): SentimentConfidenceScoreSDK = {
    SentimentConfidenceScoreSDK(
      score.getNegative,
      score.getNeutral,
      score.getPositive)
  }

  implicit def fromSDK(target: TargetSentiment): TargetSDK = {
    TargetSDK(
      target.getText,
      target.getSentiment.toString,
      target.getConfidenceScores,
      target.getOffset,
      target.getLength)
  }

  implicit def fromSDK(assess: AssessmentSentiment): AssessmentSDK = {
    AssessmentSDK(
      assess.getText,
      assess.getSentiment.toString,
      assess.getConfidenceScores,
      assess.isNegated,
      assess.getOffset,
      assess.getLength)
  }

  implicit def fromSDK(op: SentenceOpinion): OpinionSDK = {
    OpinionSDK(
      op.getTarget,
      op.getAssessments.asScala.toSeq.map(fromSDK)
    )
  }

  implicit def fromSDK(ss: SentenceSentiment): SentimentSentenceSDK = {
    SentimentSentenceSDK(
      ss.getText,
      ss.getSentiment.toString,
      ss.getConfidenceScores,
      // Opinions are null unless opinion mining was requested.
      Option(ss.getOpinions).map(sentenceOpinions =>
        sentenceOpinions.asScala.toSeq.map(fromSDK)
      ),
      ss.getOffset,
      ss.getLength)
  }

  implicit def fromSDK(warning: TextAnalyticsWarning): TAWarningSDK = {
    // Fixed: arguments were previously swapped — the warning code was stored
    // in `message` and the message in `warningCode`.
    TAWarningSDK(warning.getWarningCode.toString, warning.getMessage)
  }

  implicit def fromSDK(error: TextAnalyticsError): TAErrorSDK = {
    TAErrorSDK(
      error.getErrorCode.toString,
      error.getMessage,
      error.getTarget)
  }

  implicit def fromSDK(s: TextDocumentStatistics): DocumentStatistics = {
    DocumentStatistics(s.getCharacterCount, s.getTransactionCount)
  }

  implicit def fromSDK(doc: AnalyzeSentimentResult): SentimentScoredDocumentSDK = {
    SentimentScoredDocumentSDK(
      doc.getDocumentSentiment.getSentiment.toString,
      doc.getDocumentSentiment.getConfidenceScores,
      doc.getDocumentSentiment.getSentences.asScala.toSeq.map(fromSDK),
      doc.getDocumentSentiment.getWarnings.asScala.toSeq.map(fromSDK))
  }

  implicit def fromSDK(phrases: ExtractKeyPhraseResult): KeyphraseSDK = {
    KeyphraseSDK(
      phrases.getKeyPhrases.asScala.toSeq,
      phrases.getKeyPhrases.getWarnings.asScala.toSeq.map(fromSDK))
  }

  implicit def fromSDK(result: DetectLanguageResult): DetectedLanguageSDK = {
    DetectedLanguageSDK(
      result.getPrimaryLanguage.getName,
      result.getPrimaryLanguage.getIso6391Name,
      result.getPrimaryLanguage.getConfidenceScore,
      result.getPrimaryLanguage.getWarnings.asScala.toSeq.map(fromSDK))
  }

  implicit def fromSDK(ent: PiiEntity): PIIEntitySDK = {
    PIIEntitySDK(
      ent.getText,
      ent.getCategory.toString,
      ent.getSubcategory,
      ent.getConfidenceScore,
      ent.getOffset,
      ent.getLength)
  }

  implicit def fromSDK(entity: RecognizePiiEntitiesResult): PIIEntityCollectionSDK = {
    PIIEntityCollectionSDK(
      entity.getEntities.asScala.toSeq.map(fromSDK),
      entity.getEntities.getRedactedText,
      entity.getEntities.getWarnings.asScala.toSeq.map(fromSDK))
  }

  implicit def fromSDK(ent: EntityDataSource): EntityDataSourceSDK = {
    EntityDataSourceSDK(
      ent.getName,
      ent.getEntityId
    )
  }

  implicit def fromSDK(entity: AnalyzeHealthcareEntitiesResult): HealthEntitiesResultSDK = {
    HealthEntitiesResultSDK(
      entity.getId,
      entity.getWarnings.asScala.toSeq.map(fromSDK),
      entity.getEntities.asScala.toSeq.map(fromSDK),
      entity.getEntityRelations.asScala.toSeq.map(fromSDK)
    )
  }

  implicit def fromSDK(ent: HealthcareEntity): HealthcareEntitySDK = {
    HealthcareEntitySDK(
      Option(ent.getAssertion).map(fromSDK),
      ent.getCategory.toString,
      ent.getConfidenceScore,
      ent.getDataSources.asScala.toSeq.map(fromSDK),
      ent.getLength,
      ent.getNormalizedText,
      ent.getOffset,
      ent.getSubcategory,
      ent.getText
    )
  }

  implicit def fromSDK(entityAssertion: HealthcareEntityAssertion): HealthcareEntityAssertionSDK = {
    HealthcareEntityAssertionSDK(
      Option(entityAssertion.getAssociation).map(_.toString),
      Option(entityAssertion.getCertainty).map(_.toString),
      Option(entityAssertion.getConditionality).map(_.toString)
    )
  }

  implicit def fromSDK(rel: HealthcareEntityRelation): HealthcareEntityRelationSDK = {
    HealthcareEntityRelationSDK(
      rel.getRelationType.toString,
      rel.getRoles.asScala.toSeq.map(fromSDK)
    )
  }

  implicit def fromSDK(role: HealthcareEntityRelationRole): HealthcareEntityRelationRoleSDK = {
    HealthcareEntityRelationRoleSDK(
      role.getEntity,
      role.getName
    )
  }

  implicit def fromSDK(entity: RecognizeLinkedEntitiesResult): LinkedEntityCollectionSDK = {
    LinkedEntityCollectionSDK(
      entity.getEntities.asScala.toSeq.map(fromSDK),
      entity.getEntities.getWarnings.asScala.toSeq.map(fromSDK)
    )
  }

  implicit def fromSDK(ent: LinkedEntity): LinkedEntitySDK = {
    LinkedEntitySDK(
      ent.getName,
      ent.getMatches.asScala.toSeq.map(fromSDK),
      ent.getLanguage,
      ent.getDataSourceEntityId,
      ent.getUrl,
      ent.getDataSource,
      ent.getBingEntitySearchApiId
    )
  }

  implicit def fromSDK(ent: LinkedEntityMatch): LinkedEntityMatchSDK = {
    LinkedEntityMatchSDK(
      ent.getText,
      ent.getConfidenceScore,
      ent.getOffset,
      ent.getLength
    )
  }

  implicit def fromSDK(entity: CategorizedEntity): NEREntitySDK = {
    NEREntitySDK(
      entity.getText,
      entity.getCategory.toString,
      entity.getSubcategory,
      entity.getConfidenceScore,
      entity.getOffset,
      entity.getLength)
  }

  implicit def fromSDK(entity: RecognizeEntitiesResult): NERCollectionSDK = {
    NERCollectionSDK(
      entity.getEntities.asScala.toSeq.map(fromSDK),
      entity.getEntities.getWarnings.asScala.toSeq.map(fromSDK))
  }

  /** Splits one SDK result into (error, statistics, converted-result);
    * exactly one of error/result is Some.
    */
  def unpackResult[T <: TextAnalyticsResult, U](result: T)(implicit converter: T => U):
  (Option[TAErrorSDK], Option[DocumentStatistics], Option[U]) = {
    if (result.isError) {
      (Some(fromSDK(result.getError)), None, None)
    } else {
      (None, Option(result.getStatistics).map(fromSDK), Some(converter(result)))
    }
  }

  /** Wraps a batch of SDK results into [[TAResponseSDK]] envelopes, in order. */
  def toResponse[T <: TextAnalyticsResult, U](rc: Iterable[T])(implicit converter: T => U)
  : Seq[TAResponseSDK[U]] = {
    rc.map(unpackResult(_)(converter))
      .toSeq.map(tup => new TAResponseSDK[U](tup._3, tup._1, tup._2))
  }
}

// --- TextAnalyticsSDKSuite.scala ---
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.cognitive.split1

import com.microsoft.azure.synapse.ml.cognitive._
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.stages.FixedMiniBatchTransformer
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.{DataFrame, Row}

class LanguageDetectorSDKSuite extends TransformerFuzzing[LanguageDetectorSDK] with TextEndpoint {

  import spark.implicits._

  lazy val df: DataFrame = Seq(
    "Hello World",
    "Bonjour tout le monde",
    "La carretera estaba atascada. Había mucho tráfico el día de ayer.",
    ":) :( :D"
  ).toDF("text")

  lazy val detector: LanguageDetectorSDK = new LanguageDetectorSDK()
    .setSubscriptionKey(textKey)
    .setLocation(textApiLocation)
    .setOutputCol("replies")

  test("Basic Usage") {
    val replies = detector.transform(df)
      .select("replies.result.name")
      .collect().toList
    assert(replies.head.getString(0) == "English" && replies(2).getString(0) == "Spanish")
  }

  test("Manual Batching") {
    val batchedDF = new FixedMiniBatchTransformer().setBatchSize(10).transform(df.coalesce(1))
    val tdf = detector.transform(batchedDF)
    val replies = tdf.collect().head.getAs[Seq[Row]]("replies")
    assert(replies.length == 4)
    val languages = replies.map(_.getAs[Row]("result").getAs[String]("name")).toSet
    assert(languages("Spanish") && languages("English"))
  }

  override def testObjects(): Seq[TestObject[LanguageDetectorSDK]] =
    Seq(new TestObject[LanguageDetectorSDK](detector, df))

  override def reader: MLReadable[_] = LanguageDetectorSDK
}

class EntityDetectorSDKSuite extends TransformerFuzzing[EntityDetectorSDK] with TextEndpoint {

  import spark.implicits._

  lazy val df: DataFrame = Seq(
    ("1", "Microsoft released Windows 10"),
    ("2", "In 1975, Bill Gates III and Paul Allen founded the company.")
  ).toDF("id", "text")

  lazy val detector: EntityDetectorSDK = new EntityDetectorSDK()
    .setSubscriptionKey(textKey)
    .setLocation(textApiLocation)
    .setLanguage("en")
    .setOutputCol("replies")

  test("Basic Usage") {
    val results = detector.transform(df)
      .withColumn("entities",
        col("replies.result.entities").getItem("name"))
      .select("id", "entities").collect().toList
    println(results)
    assert(results.head.getSeq[String](1).toSet
      .intersect(Set("Windows 10", "Windows 10 Mobile", "Microsoft")).size == 2)
  }

  override def testObjects(): Seq[TestObject[EntityDetectorSDK]] =
    Seq(new TestObject[EntityDetectorSDK](detector, df))

  override def reader: MLReadable[_]
= EntityDetectorSDK +} + +class TextSentimentSDKSuite extends TransformerFuzzing[TextSentimentSDK] with TextSentimentBaseSuite { + lazy val t: TextSentimentSDK = new TextSentimentSDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguageCol("lang") + .setModelVersion("latest") + .setIncludeStatistics(true) + .setOutputCol("replies") + + test("Basic Usage") { + val results = t.transform(df).select( + col("replies.result.sentiment") + ).collect().map(_.getAs[String](0)) + + assert(List(4, 5).forall(results(_) == null)) + assert(results(0) == "positive" && results(2) == "negative") + } + + test("batch usage") { + val t = new TextSentimentSDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setTextCol("text") + .setLanguage("en") + .setOutputCol("score") + val batchedDF = new FixedMiniBatchTransformer().setBatchSize(10).transform(df.coalesce(1)) + val tdf = t.transform(batchedDF) + val replies = tdf.collect().head.getAs[Seq[Row]]("score") + assert(replies.length == 6) + } + + override def testObjects(): Seq[TestObject[TextSentimentSDK]] = + Seq(new TestObject[TextSentimentSDK](t, df)) + + override def reader: MLReadable[_] = TextSentimentSDK +} + +class KeyPhraseExtractorSDKSuite extends TransformerFuzzing[KeyPhraseExtractorSDK] with TextEndpoint { + + import spark.implicits._ + + //noinspection ScalaStyle + lazy val df: DataFrame = Seq( + ("en", "Hello world. This is some input text that I love."), + ("fr", "Bonjour tout le monde"), + ("es", "La carretera estaba atascada. 
Había mucho tráfico el día de ayer."), + ("en", null) + ).toDF("lang", "text") + + lazy val t: KeyPhraseExtractorSDK = new KeyPhraseExtractorSDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguageCol("lang") + .setOutputCol("replies") + + test("Basic Usage") { + val results = t.transform(df).withColumn("phrases", + col("replies.result.keyPhrases")) + .select("phrases").collect().toList + + println(results) + + assert(results.head.getSeq[String](0).toSet === Set("Hello world", "input text")) + assert(results(2).getSeq[String](0).toSet === Set("mucho tráfico", "día", "carretera", "ayer")) + } + + override def testObjects(): Seq[TestObject[KeyPhraseExtractorSDK]] = + Seq(new TestObject[KeyPhraseExtractorSDK](t, df)) + + override def reader: MLReadable[_] = KeyPhraseExtractorSDK +} + +class NERSDKSuite extends TransformerFuzzing[NERSDK] with TextEndpoint { + + import spark.implicits._ + + lazy val df: DataFrame = Seq( + ("en", "I had a wonderful trip to Seattle last week."), + ("en", "I visited Space Needle 2 times.") + ).toDF("language", "text") + + lazy val n: NERSDK = new NERSDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguage("en") + .setOutputCol("response") + + test("Basic Usage") { + val results = n.transform(df) + val matches = results.withColumn("match", + col("response.result.entities") + .getItem(0)) + .select("match") + + val testRow = matches.collect().head(0).asInstanceOf[GenericRowWithSchema] + + assert(testRow.getAs[String]("text") === "trip") + assert(testRow.getAs[Int]("offset") === 18) + assert(testRow.getAs[Int]("length") === 4) + assert(testRow.getAs[Double]("confidenceScore") > 0.7) + assert(testRow.getAs[String]("category") === "Event") + + } + + override def testObjects(): Seq[TestObject[NERSDK]] = + Seq(new TestObject[NERSDK](n, df)) + + override def reader: MLReadable[_] = NERSDK +} + +class PIISDKSuite extends TransformerFuzzing[PIISDK] with TextEndpoint { + + import spark.implicits._ 
+ + lazy val df: DataFrame = Seq( + ("en", "My SSN is 859-98-0987"), + ("en", "Your ABA number - 111000025 - is the first 9 digits in the lower left hand corner of check."), + ("en", "Is 998.214.865-68 your Brazilian CPF number?") + ).toDF("language", "text") + + lazy val n: PIISDK = new PIISDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguage("en") + .setOutputCol("response") + + test("Basic Usage") { + val results = n.transform(df) + + val redactedText = results + .select("response.result.redactedText") + .collect().head(0).toString + assert(redactedText === "My SSN is ***********") + + val matches = results.withColumn("match", + col("response.result.entities").getItem(0)) + .select("match") + + val testRow = matches.collect().head(0).asInstanceOf[GenericRowWithSchema] + + assert(testRow.getAs[String]("text") === "859-98-0987") + assert(testRow.getAs[Int]("offset") === 10) + assert(testRow.getAs[Int]("length") === 11) + assert(testRow.getAs[Double]("confidenceScore") > 0.6) + assert(testRow.getAs[String]("category") === "USSocialSecurityNumber") + } + + override def testObjects(): Seq[TestObject[PIISDK]] = + Seq(new TestObject[PIISDK](n, df)) + + override def reader: MLReadable[_] = PIISDK +} + +class HealthcareSDKSuite extends TransformerFuzzing[HealthcareSDK] with TextEndpoint { + + import spark.implicits._ + + lazy val df: DataFrame = Seq( + ("en", "20mg of ibuprofen twice a day"), + ("en", "1tsp of Tylenol every 4 hours"), + ("en", "6-drops of Vitamin B-12 every evening") + ).toDF("language", "text") + + lazy val extractor: HealthcareSDK = new HealthcareSDK() + .setSubscriptionKey(textKey) + .setLocation(textApiLocation) + .setLanguage("en") + .setOutputCol("response") + + test("Basic Usage") { + val results = extractor.transform(df) + val fromRow = HealthcareResponseSDK.makeFromRowConverter + val parsed = results.select("response").collect().map(r => fromRow(r.getStruct(0))) + 
assert(parsed.head.result.head.entities.head.category === "Dosage") + assert(parsed.head.result.head.entityRelation.head.relationType === "DosageOfMedication") + } + + override def testObjects(): Seq[TestObject[HealthcareSDK]] = Seq(new TestObject( + extractor, df + )) + + override def reader: MLReadable[_] = HealthcareSDK +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TextAnalyticsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TextAnalyticsSuite.scala index 653476bad5..d231492df4 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TextAnalyticsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TextAnalyticsSuite.scala @@ -134,7 +134,7 @@ class EntityDetectorSuite extends TransformerFuzzing[EntityDetectorV2] with Text override def reader: MLReadable[_] = EntityDetectorV2 } -class EntityDetectorSuiteV3 extends TransformerFuzzing[EntityDetector] with TextEndpoint { +class EntityDetectorV3Suite extends TransformerFuzzing[EntityDetector] with TextEndpoint { import spark.implicits._ @@ -202,7 +202,7 @@ class TextSentimentV3Suite extends TransformerFuzzing[TextSentiment] with TextSe test("batch usage") { val t = new TextSentiment() .setSubscriptionKey(textKey) - .setLocation("eastus") + .setLocation(textApiLocation) .setTextCol("text") .setLanguage("en") .setOutputCol("score") @@ -222,7 +222,7 @@ class TextSentimentSuite extends TransformerFuzzing[TextSentimentV2] with TextSe lazy val t: TextSentimentV2 = new TextSentimentV2() .setSubscriptionKey(textKey) - .setUrl(s"https://$textApiLocation.api.cognitive.microsoft.com/text/analytics/v2.0/sentiment") + .setLocation(textApiLocation) .setLanguageCol("lang") .setOutputCol("replies") @@ -267,7 +267,7 @@ class KeyPhraseExtractorSuite extends TransformerFuzzing[KeyPhraseExtractorV2] w lazy val t: KeyPhraseExtractorV2 = new KeyPhraseExtractorV2() .setSubscriptionKey(textKey) 
- .setUrl(s"https://$textApiLocation.api.cognitive.microsoft.com/text/analytics/v2.0/keyPhrases") + .setLocation(textApiLocation) .setLanguageCol("lang") .setOutputCol("replies") @@ -302,7 +302,7 @@ class KeyPhraseExtractorV3Suite extends TransformerFuzzing[KeyPhraseExtractor] w lazy val t: KeyPhraseExtractor = new KeyPhraseExtractor() .setSubscriptionKey(textKey) - .setUrl(s"https://$textApiLocation.api.cognitive.microsoft.com/text/analytics/v3.0/keyPhrases") + .setLocation(textApiLocation) .setLanguageCol("lang") .setOutputCol("replies") @@ -363,7 +363,7 @@ class NERSuite extends TransformerFuzzing[NERV2] with TextEndpoint { override def reader: MLReadable[_] = NERV2 } -class NERSuiteV3 extends TransformerFuzzing[NER] with TextEndpoint { +class NERV3Suite extends TransformerFuzzing[NER] with TextEndpoint { import spark.implicits._ @@ -403,7 +403,7 @@ class NERSuiteV3 extends TransformerFuzzing[NER] with TextEndpoint { override def reader: MLReadable[_] = NER } -class PIISuiteV3 extends TransformerFuzzing[PII] with TextEndpoint { +class PIIV3Suite extends TransformerFuzzing[PII] with TextEndpoint { import spark.implicits._ @@ -466,9 +466,9 @@ class TextAnalyzeSuite extends TransformerFuzzing[TextAnalyze] with TextEndpoint val batchSize = 25 lazy val dfBatched: DataFrame = Seq( ( - Seq("en", "invalid") ++ Seq.fill(batchSize-2)("en"), + Seq("en", "invalid") ++ Seq.fill(batchSize - 2)("en"), Seq("I had a wonderful trip to Seattle last week and visited Microsoft.", - "This is irrelevant as the language is invalid") ++ Seq.fill(batchSize-2)("Placeholder content") + "This is irrelevant as the language is invalid") ++ Seq.fill(batchSize - 2)("Placeholder content") ) ).toDF("language", "text") diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/contracts/Params.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/contracts/Params.scala index 18d32ce59f..bfade1d1aa 100644 --- 
a/core/src/main/scala/com/microsoft/azure/synapse/ml/core/contracts/Params.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/core/contracts/Params.scala @@ -5,7 +5,6 @@ package com.microsoft.azure.synapse.ml.core.contracts import org.apache.spark.ml.param._ - trait HasInputCol extends Params { /** The name of the input column * diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index cc117fc082..ed2ebe6553 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -37,7 +37,7 @@ trait HasHandler extends Params { getHandler(client, request) } -trait HTTPParams extends Wrappable { +trait ConcurrencyParams extends Wrappable { val concurrency: Param[Int] = new IntParam( this, "concurrency", "max number of concurrent calls") @@ -70,8 +70,7 @@ trait HTTPParams extends Wrappable { case None => clear(concurrentTimeout) } - setDefault(concurrency -> 1, - timeout -> 60.0) + setDefault(concurrency -> 1, timeout -> 60.0) } @@ -90,7 +89,7 @@ trait HasURL extends Params { object HTTPTransformer extends ComplexParamsReadable[HTTPTransformer] class HTTPTransformer(val uid: String) - extends Transformer with HTTPParams with HasInputCol + extends Transformer with ConcurrencyParams with HasInputCol with HasOutputCol with HasHandler with ComplexParamsWritable with BasicLogging { logClass() diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala index 89a64bf2bc..3e3ace8cf9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala @@ -25,7 +25,7 @@ trait 
HasErrorCol extends Params { def setErrorCol(v: String): this.type = set(errorCol, v) def getErrorCol: String = $(errorCol) - + setDefault(errorCol -> "Error") } object ErrorUtils extends Serializable { @@ -62,7 +62,7 @@ object ErrorUtils extends Serializable { } class SimpleHTTPTransformer(val uid: String) - extends Transformer with HTTPParams with HasMiniBatcher with HasHandler + extends Transformer with ConcurrencyParams with HasMiniBatcher with HasHandler with HasInputCol with HasOutputCol with ComplexParamsWritable with HasErrorCol with BasicLogging { logClass() diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/stages/PartitionConsolidator.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/stages/PartitionConsolidator.scala index fa67c26a33..a099b85e0b 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/stages/PartitionConsolidator.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/stages/PartitionConsolidator.scala @@ -4,7 +4,7 @@ package com.microsoft.azure.synapse.ml.stages import com.microsoft.azure.synapse.ml.core.contracts.{HasInputCol, HasOutputCol} -import com.microsoft.azure.synapse.ml.io.http.{HTTPParams, SharedSingleton} +import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, SharedSingleton} import com.microsoft.azure.synapse.ml.logging.BasicLogging import org.apache.spark.ml.param._ import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable} @@ -19,7 +19,7 @@ import scala.concurrent.blocking object PartitionConsolidator extends DefaultParamsReadable[PartitionConsolidator] class PartitionConsolidator(val uid: String) - extends Transformer with HTTPParams with HasInputCol + extends Transformer with ConcurrencyParams with HasInputCol with HasOutputCol with ComplexParamsWritable with BasicLogging { logClass()