diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala index b81a5a71f8..3ffcb09719 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/CognitiveServiceBase.scala @@ -233,6 +233,22 @@ trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath { } } +trait HasAPIVersion extends HasServiceParams { + val apiVersion: ServiceParam[String] = new ServiceParam[String]( + this, "apiVersion", "version of the api", isRequired = true, isURLParam = true) { + override val payloadName: String = "api-version" + } + + def getApiVersion: String = getScalarParam(apiVersion) + + def setApiVersion(v: String): this.type = setScalarParam(apiVersion, v) + + def getApiVersionCol: String = getVectorParam(apiVersion) + + def setApiVersionCol(v: String): this.type = setVectorParam(apiVersion, v) + +} + object URLEncodingUtils { private case class NameValuePairInternal(t: (String, String)) extends NameValuePair { diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/form/FormRecognizerV3.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/form/FormRecognizerV3.scala index bfd7b92bdf..3b9dbd36ac 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/form/FormRecognizerV3.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/form/FormRecognizerV3.scala @@ -4,7 +4,6 @@ package com.microsoft.azure.synapse.ml.cognitive.form import com.microsoft.azure.synapse.ml.cognitive._ -import com.microsoft.azure.synapse.ml.cognitive.openai.HasAPIVersion import com.microsoft.azure.synapse.ml.cognitive.vision.{BasicAsyncReply, HasImageInput} import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.param.ServiceParam diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/language/AnalyzeText.scala index 782f4a5fbb..d4847b8bb0 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/language/AnalyzeText.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/language/AnalyzeText.scala @@ -4,7 +4,6 @@ package com.microsoft.azure.synapse.ml.cognitive.language import com.microsoft.azure.synapse.ml.cognitive._ -import com.microsoft.azure.synapse.ml.cognitive.openai.HasAPIVersion import com.microsoft.azure.synapse.ml.cognitive.text.{TADocument, TextAnalyticsAutoBatch} import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.param.ServiceParam diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAI.scala index 96d000cbe0..1d1a655ea3 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAI.scala @@ -4,24 +4,14 @@ package com.microsoft.azure.synapse.ml.cognitive.openai import com.microsoft.azure.synapse.ml.codegen.GenerationUtils -import com.microsoft.azure.synapse.ml.cognitive.{ - CognitiveServicesBase, HasCognitiveServiceInput, - HasInternalJsonOutputParser, HasServiceParams -} -import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging -import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat +import com.microsoft.azure.synapse.ml.cognitive.{HasAPIVersion, HasServiceParams} import com.microsoft.azure.synapse.ml.param.ServiceParam -import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} -import org.apache.spark.ml.ComplexParamsReadable -import org.apache.spark.ml.util._ import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ import spray.json.DefaultJsonProtocol._ -import spray.json._ import scala.language.existentials -trait HasPrompt extends HasServiceParams { +trait HasPromptInputs extends HasServiceParams { val prompt: ServiceParam[String] = new ServiceParam[String]( this, "prompt", "The text to complete", isRequired = false) @@ -32,9 +22,7 @@ trait HasPrompt extends HasServiceParams { def getPromptCol: String = getVectorParam(prompt) def setPromptCol(v: String): this.type = setVectorParam(prompt, v) -} -trait HasBatchPrompt extends HasServiceParams { val batchPrompt: ServiceParam[Seq[String]] = new ServiceParam[Seq[String]]( this, "batchPrompt", "Sequence of prompts to complete", isRequired = false) @@ -45,65 +33,40 @@ trait HasBatchPrompt extends HasServiceParams { def getBatchPromptCol: String = getVectorParam(batchPrompt) def setBatchPromptCol(v: String): this.type = setVectorParam(batchPrompt, v) -} - -trait HasIndexPrompt extends HasServiceParams { - val indexPrompt: ServiceParam[Seq[Int]] = new ServiceParam[Seq[Int]]( - this, "indexPrompt", "Sequence of indexes to complete", isRequired = false) - def getIndexPrompt: Seq[Int] = getScalarParam(indexPrompt) - - def setIndexPrompt(v: Seq[Int]): this.type = setScalarParam(indexPrompt, v) - - def getIndexPromptCol: String = getVectorParam(indexPrompt) - - def setIndexPromptCol(v: String): this.type = setVectorParam(indexPrompt, v) } -trait HasBatchIndexPrompt extends HasServiceParams { - val batchIndexPrompt: ServiceParam[Seq[Seq[Int]]] = new ServiceParam[Seq[Seq[Int]]]( - this, "batchIndexPrompt", "Sequence of index sequences to complete", isRequired = false) - - def getBatchIndexPrompt: Seq[Seq[Int]] = getScalarParam(batchIndexPrompt) +trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { - def setBatchIndexPrompt(v: Seq[Seq[Int]]): this.type = setScalarParam(batchIndexPrompt, v) - - def getBatchIndexPromptCol: String = getVectorParam(batchIndexPrompt) - - def setBatchIndexPromptCol(v: String): this.type = setVectorParam(batchIndexPrompt, v) -} + val deploymentName = new ServiceParam[String]( + this, "deploymentName", "The name of the deployment", isRequired = true) -trait HasAPIVersion extends HasServiceParams { - val apiVersion: ServiceParam[String] = new ServiceParam[String]( - this, "apiVersion", "version of the api", isRequired = true, isURLParam = true) { - override val payloadName: String = "api-version" - } + def getDeploymentName: String = getScalarParam(deploymentName) - def getApiVersion: String = getScalarParam(apiVersion) + def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) - def setApiVersion(v: String): this.type = setScalarParam(apiVersion, v) + def getDeploymentNameCol: String = getVectorParam(deploymentName) - def getApiVersionCol: String = getVectorParam(apiVersion) + def setDeploymentNameCol(v: String): this.type = setVectorParam(deploymentName, v) - def setApiVersionCol(v: String): this.type = setVectorParam(apiVersion, v) + val user: ServiceParam[String] = new ServiceParam[String]( + this, "user", + "The ID of the end-user, for use in tracking and rate-limiting.", + isRequired = false) - setDefault(apiVersion -> Left("2022-03-01-preview")) -} + def getUser: String = getScalarParam(user) -trait HasDeploymentName extends HasServiceParams { - val deploymentName = new ServiceParam[String]( - this, "deploymentName", "The name of the deployment", isRequired = true) + def setUser(v: String): this.type = setScalarParam(user, v) - def getDeploymentName: String = getScalarParam(deploymentName) + def getUserCol: String = getVectorParam(user) - def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) + def setUserCol(v: String): this.type = setVectorParam(user, v) - def getDeploymentNameCol: String = getVectorParam(deploymentName) + setDefault(apiVersion -> Left("2023-03-15-preview")) - def setDeploymentNameCol(v: String): this.type = setVectorParam(deploymentName, v) } -trait HasMaxTokens extends HasServiceParams { +trait HasOpenAITextParams extends HasOpenAISharedParams { val maxTokens: ServiceParam[Int] = new ServiceParam[Int]( this, "maxTokens", @@ -118,9 +81,6 @@ trait HasMaxTokens extends HasServiceParams { def setMaxTokensCol(v: String): this.type = setVectorParam(maxTokens, v) -} - -trait HasTemperature extends HasServiceParams { val temperature: ServiceParam[Double] = new ServiceParam[Double]( this, "temperature", "What sampling temperature to use. Higher values means the model will take more risks." + @@ -135,24 +95,7 @@ trait HasTemperature extends HasServiceParams { def getTemperatureCol: String = getVectorParam(temperature) def setTemperatureCol(v: String): this.type = setVectorParam(temperature, v) -} - -trait HasModel extends HasServiceParams { - val model: ServiceParam[String] = new ServiceParam[String]( - this, "model", - "The name of the model to use", - isRequired = false) - def getModel: String = getScalarParam(model) - - def setModel(v: String): this.type = setScalarParam(model, v) - - def getModelCol: String = getVectorParam(model) - - def setModelCol(v: String): this.type = setVectorParam(model, v) -} - -trait HasStop extends HasServiceParams { val stop: ServiceParam[String] = new ServiceParam[String]( this, "stop", "A sequence which indicates the end of the current document.", @@ -165,12 +108,6 @@ trait HasStop extends HasServiceParams { def getStopCol: String = getVectorParam(stop) def setStopCol(v: String): this.type = setVectorParam(stop, v) -} - -trait HasOpenAIParams extends HasServiceParams - with HasPrompt with HasBatchPrompt with HasIndexPrompt with HasBatchIndexPrompt - with HasTemperature with HasModel with HasStop - with HasAPIVersion with HasDeploymentName with HasMaxTokens { val topP: ServiceParam[Double] = new ServiceParam[Double]( this, "topP", @@ -189,19 +126,6 @@ trait HasOpenAIParams extends HasServiceParams def setTopPCol(v: String): this.type = setVectorParam(topP, v) - val user: ServiceParam[String] = new ServiceParam[String]( - this, "user", - "The ID of the end-user, for use in tracking and rate-limiting.", - isRequired = false) - - def getUser: String = getScalarParam(user) - - def setUser(v: String): this.type = setScalarParam(user, v) - - def getUserCol: String = getVectorParam(user) - - def setUserCol(v: String): this.type = setVectorParam(user, v) - val n: ServiceParam[Int] = new ServiceParam[Int]( this, "n", "How many snippets to generate for each prompt. Minimum of 1 and maximum of 128 allowed.", @@ -299,5 +223,24 @@ trait HasOpenAIParams extends HasServiceParams def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v) + private[ml] def getOptionalParams(r: Row): Map[String, Any] = { + Seq( + maxTokens, + temperature, + topP, + user, + n, + echo, + stop, + cacheLevel, + presencePenalty, + frequencyPenalty, + bestOf + ).flatMap(param => + getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v)) + ).++(Seq( + getValueOpt(r, logProbs).map(v => ("logprobs", v)) + ).flatten).toMap + } } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletion.scala new file mode 100644 index 0000000000..1af328474e --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletion.scala @@ -0,0 +1,81 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.cognitive.openai + +import com.microsoft.azure.synapse.ml.cognitive.{ + CognitiveServicesBase, HasCognitiveServiceInput, HasInternalJsonOutputParser +} +import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging +import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat +import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity} +import org.apache.spark.ml.ComplexParamsReadable +import org.apache.spark.ml.param.Param +import org.apache.spark.ml.util._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import spray.json._ +import spray.json.DefaultJsonProtocol._ + +import scala.language.existentials + +object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion] + +class OpenAIChatCompletion(override val uid: String) extends CognitiveServicesBase(uid) + with HasOpenAITextParams with HasCognitiveServiceInput + with HasInternalJsonOutputParser with SynapseMLLogging { + logClass() + + val messagesCol: Param[String] = new Param[String]( + this, "messagesCol", "The column messages to generate chat completions for," + + " in the chat format. This column should have type Array(Struct(role: String, content: String)).") + + def getMessagesCol: String = $(messagesCol) + + def setMessagesCol(v: String): this.type = set(messagesCol, v) + + def this() = this(Identifiable.randomUID("OpenAIChatCompletion")) + + def urlPath: String = "" + + override private[ml] def internalServiceType: String = "openai" + + override def setCustomServiceName(v: String): this.type = { + setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/")) + } + + override protected def prepareUrlRoot: Row => String = { row => + s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/chat/completions" + } + + override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { + r => + lazy val optionalParams: Map[String, Any] = getOptionalParams(r) + val messages = r.getAs[Seq[Row]](getMessagesCol) + Some(getStringEntity(messages, optionalParams)) + } + + override val subscriptionKeyHeaderName: String = "api-key" + + override def shouldSkip(row: Row): Boolean = + super.shouldSkip(row) || Option(row.getAs[Row](getMessagesCol)).isEmpty + + override protected def getVectorParamMap: Map[String, String] = super.getVectorParamMap + .updated("messages", getMessagesCol) + + override def responseDataType: DataType = ChatCompletionResponse.schema + + private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = { + val mappedMessages: Seq[Map[String, String]] = messages.map { m => + Seq("role", "content", "name").map(n => + n -> Option(m.getAs[String](n)) + ).toMap.filter(_._2.isDefined).mapValues(_.get) + } + val fullPayload = optionalParams.updated("messages", mappedMessages) + new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON) + } + +} + + + diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletion.scala index 5cebee8d4c..7c7d7869ec 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletion.scala @@ -3,10 +3,8 @@ package com.microsoft.azure.synapse.ml.cognitive.openai -import com.microsoft.azure.synapse.ml.codegen.GenerationUtils import com.microsoft.azure.synapse.ml.cognitive.{ - CognitiveServicesBase, HasCognitiveServiceInput, - HasInternalJsonOutputParser + CognitiveServicesBase, HasCognitiveServiceInput, HasInternalJsonOutputParser } import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat @@ -23,11 +21,11 @@ import scala.language.existentials object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion] class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(uid) - with HasOpenAIParams with HasCognitiveServiceInput + with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput with HasInternalJsonOutputParser with SynapseMLLogging { logClass() - def this() = this(Identifiable.randomUID("OpenAPICompletion")) + def this() = this(Identifiable.randomUID("OpenAICompletion")) def urlPath: String = "" @@ -43,33 +41,11 @@ class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(u override protected def prepareEntity: Row => Option[AbstractHttpEntity] = { r => - lazy val optionalParams: Map[String, Any] = Seq( - maxTokens, - temperature, - topP, - user, - n, - model, - echo, - stop, - cacheLevel, - presencePenalty, - frequencyPenalty, - bestOf - ).flatMap(param => - getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v)) - ).++(Seq( - getValueOpt(r, logProbs).map(v => ("logprobs", v)) - ).flatten).toMap - + lazy val optionalParams: Map[String, Any] = getOptionalParams(r) getValueOpt(r, prompt) .map(prompt => getStringEntity(prompt, optionalParams)) .orElse(getValueOpt(r, batchPrompt) .map(batchPrompt => getStringEntity(batchPrompt, optionalParams))) - .orElse(getValueOpt(r, indexPrompt) - .map(indexPrompt => getStringEntity(indexPrompt, optionalParams))) - .orElse(getValueOpt(r, batchIndexPrompt) - .map(batchIndexPrompt => getStringEntity(batchIndexPrompt, optionalParams))) .orElse(throw new IllegalArgumentException( "Please set one of prompt, batchPrompt, indexPrompt or batchIndexPrompt.")) } @@ -77,10 +53,8 @@ class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(u override val subscriptionKeyHeaderName: String = "api-key" override def shouldSkip(row: Row): Boolean = - emptyParamData(row, prompt) && - emptyParamData(row, batchPrompt) && - emptyParamData(row, indexPrompt) && - emptyParamData(row, batchIndexPrompt) + super.shouldSkip(row) || + (emptyParamData(row, prompt) && emptyParamData(row, batchPrompt)) override def responseDataType: DataType = CompletionResponse.schema @@ -90,3 +64,4 @@ class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(u } } + diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbedding.scala index 8e46ef8f49..c241d7e33f 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbedding.scala @@ -23,8 +23,7 @@ import scala.language.existentials object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding] class OpenAIEmbedding (override val uid: String) extends CognitiveServicesBase(uid) - with HasServiceParams with HasAPIVersion with HasDeploymentName - with HasCognitiveServiceInput with SynapseMLLogging { + with HasOpenAISharedParams with HasCognitiveServiceInput with SynapseMLLogging { logClass() def this() = this(Identifiable.randomUID("OpenAIEmbedding")) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPrompt.scala index d5855323b0..894ef1c93a 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPrompt.scala @@ -9,7 +9,7 @@ import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.param.StringStringMapParam -import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer} import org.apache.spark.sql.types.{DataType, StructType} @@ -20,11 +20,12 @@ import scala.collection.JavaConverters._ object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt] class OpenAIPrompt(override val uid: String) extends Transformer - with HasAPIVersion with HasDeploymentName with HasMaxTokens - with HasTemperature with HasModel with HasStop - with HasURL with HasSubscriptionKey with HasAADToken - with ComplexParamsWritable with HasErrorCol with ConcurrencyParams - with HasOutputCol with HasCustomCogServiceDomain with SynapseMLLogging { + with HasOpenAITextParams + with HasErrorCol with HasOutputCol + with HasURL with HasCustomCogServiceDomain with ConcurrencyParams + with HasSubscriptionKey with HasAADToken + with ComplexParamsWritable with SynapseMLLogging { + logClass() def this() = this(Identifiable.randomUID("OpenAIPrompt")) @@ -60,14 +61,27 @@ class OpenAIPrompt(override val uid: String) extends Transformer def setPostProcessingOptions(v: java.util.HashMap[String, String]): this.type = set(postProcessingOptions, v.asScala.toMap) - setDefault(outputCol -> "out", - postProcessing -> "", postProcessingOptions -> Map.empty) + val dropPrompt = new BooleanParam( + this, "dropPrompt", "whether to drop the column of prompts after templating") + + def getDropPrompt: Boolean = $(dropPrompt) + + def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value) + + setDefault( + postProcessing -> "", + postProcessingOptions -> Map.empty, + outputCol -> (this.uid + "_output"), + errorCol -> (this.uid + "_error"), + dropPrompt -> true + ) override def setCustomServiceName(v: String): this.type = { setUrl(s"https://$v.openai.azure.com/" + urlPath.stripPrefix("/")) } - private val localParamNames = Seq("promptTemplate", "outputCol", "postProcessing", "postProcessingOptions") + private val localParamNames = Seq( + "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt") override def transform(dataset: Dataset[_]): DataFrame = { import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._ @@ -82,12 +96,18 @@ class OpenAIPrompt(override val uid: String) extends Transformer val completion = openAICompletion.setPromptCol(promptColName) // run completion - completion + val results = completion .transform(dfTemplated) .withColumn(getOutputCol, getParser.parse(F.element_at(F.col(completion.getOutputCol).getField("choices"), 1) .getField("text"))) .drop(completion.getOutputCol) + + if (getDropPrompt) { + results.drop(promptColName) + } else { + results + } }) } @@ -114,6 +134,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer case _ => throw new IllegalArgumentException(s"Unsupported postProcessing type: '$getPostProcessing'") } } + override def transformSchema(schema: StructType): StructType = openAICompletion .transformSchema(schema) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAISchemas.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAISchemas.scala index c44bf2073a..6644eaaa11 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAISchemas.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAISchemas.scala @@ -4,6 +4,8 @@ package com.microsoft.azure.synapse.ml.cognitive.openai import com.microsoft.azure.synapse.ml.core.schema.SparkBindings +import org.apache.spark.sql.Row +import spray.json.{DefaultJsonProtocol, RootJsonFormat} object CompletionResponse extends SparkBindings[CompletionResponse] @@ -32,3 +34,21 @@ case class EmbeddingResponse(`object`: String, case class EmbeddingObject(`object`: String, embedding: Array[Double], index: Int) + +case class OpenAIMessage(role: String, content: String, name: Option[String] = None) + +case class OpenAIChatChoice(message: OpenAIMessage, + index: Long, + finish_reason: String) + +case class ChatCompletionResponse(id: String, + `object`: String, + created: String, + model: String, + choices: Seq[OpenAIChatChoice]) + +object ChatCompletionResponse extends SparkBindings[ChatCompletionResponse] + +object OpenAIJsonProtocol extends DefaultJsonProtocol { + implicit val MessageEnc: RootJsonFormat[OpenAIMessage] = jsonFormat3(OpenAIMessage.apply) +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletionSuite.scala new file mode 100644 index 0000000000..b138a488fc --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIChatCompletionSuite.scala @@ -0,0 +1,79 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.cognitive.openai + +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import org.apache.spark.ml.util.MLReadable +import org.apache.spark.sql.{DataFrame, Row} +import org.scalactic.Equality + +class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion] with OpenAIAPIKey with Flaky { + + import spark.implicits._ + + lazy val completion: OpenAIChatCompletion = new OpenAIChatCompletion() + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setMaxTokens(200) + .setOutputCol("out") + .setMessagesCol("messages") + .setSubscriptionKey(openAIAPIKey) + + + lazy val goodDf: DataFrame = Seq( + Seq( + OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"), + OpenAIMessage("user", "Whats your favorite color") + ), + Seq( + OpenAIMessage("system", "You are very excited"), + OpenAIMessage("user", "How are you today") + ), + Seq( + OpenAIMessage("system", "You are very excited"), + OpenAIMessage("user", "How are you today"), + OpenAIMessage("system", "Better than ever"), + OpenAIMessage("user", "Why?") + ) + ).toDF("messages") + + lazy val badDf: DataFrame = Seq( + Seq(), + Seq( + OpenAIMessage("system", "You are very excited"), + OpenAIMessage("user", null) //scalastyle:ignore null + ), + null //scalastyle:ignore null + ).toDF("messages") + + test("Basic Usage") { + testCompletion(completion, goodDf) + } + + test("Robustness to bad inputs") { + val results = completion.transform(badDf).collect() + assert(Option(results.head.getAs[Row](completion.getErrorCol)).isDefined) + assert(Option(results.apply(1).getAs[Row](completion.getErrorCol)).isDefined) + assert(Option(results.apply(2).getAs[Row](completion.getErrorCol)).isEmpty) + assert(Option(results.apply(2).getAs[Row]("out")).isEmpty) + } + + def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = { + val fromRow = ChatCompletionResponse.makeFromRowConverter + completion.transform(df).collect().foreach(r => + fromRow(r.getAs[Row]("out")).choices.foreach(c => + assert(c.message.content.length > requiredLength))) + } + + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { + super.assertDFEq(df1.drop("out"), df2.drop("out"))(eq) + } + + override def testObjects(): Seq[TestObject[OpenAIChatCompletion]] = + Seq(new TestObject(completion, goodDf)) + + override def reader: MLReadable[_] = OpenAIChatCompletion + +} diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletionSuite.scala index b6cf1520b7..ac2329a868 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAICompletionSuite.scala @@ -14,27 +14,23 @@ import org.scalactic.Equality trait OpenAIAPIKey { lazy val openAIAPIKey: String = sys.env.getOrElse("OPENAI_API_KEY", Secrets.OpenAIApiKey) lazy val openAIServiceName: String = "synapseml-openai" + lazy val deploymentName: String = "gpt-35-turbo" + lazy val modelName: String = "gpt-35-turbo" } class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with OpenAIAPIKey with Flaky { import spark.implicits._ - lazy val completion: OpenAICompletion = new OpenAICompletion() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName("text-davinci-001") - .setModel("text-davinci-003") + def newCompletion: OpenAICompletion = new OpenAICompletion() + .setDeploymentName(deploymentName) .setCustomServiceName(openAIServiceName) - .setMaxTokens(20) - .setLogProbs(5) - .setPromptCol("prompt") + .setMaxTokens(200) .setOutputCol("out") + .setSubscriptionKey(openAIAPIKey) lazy val promptCompletion: OpenAICompletion = newCompletion.setPromptCol("prompt") lazy val batchPromptCompletion: OpenAICompletion = newCompletion.setBatchPromptCol("batchPrompt") - lazy val indexPromptCompletion: OpenAICompletion = newCompletion.setIndexPromptCol("indexPrompt") - lazy val batchIndexPromptCompletion: OpenAICompletion = newCompletion.setBatchIndexPromptCol("batchIndexPrompt") - lazy val df: DataFrame = Seq( "Once upon a time", @@ -55,51 +51,30 @@ class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with Op "Knock, knock") ).toDF("batchPrompt") - lazy val indexPromptDF: DataFrame = Seq( - Seq(3, 1, 5, 4) - ).toDF("indexPrompt") - - lazy val batchIndexPromptDF: DataFrame = Seq( - Seq( - Seq(1, 8, 4, 2), - Seq(7, 3, 8, 5, 9), - Seq(8, 0, 11, 3, 14, 1)) - ).toDF("batchIndexPrompt") - test("Basic Usage") { testCompletion(promptCompletion, promptDF) } test("Basic usage with AAD auth") { - val aadToken = getAccessToken(Secrets.ServicePrincipalClientId, + val aadToken = getAccessToken( + Secrets.ServicePrincipalClientId, Secrets.ServiceConnectionSecret, "https://cognitiveservices.azure.com/") + val completion = new OpenAICompletion() .setAADToken(aadToken) - .setDeploymentName("text-davinci-001") - .setModel("text-davinci-003") + .setDeploymentName(deploymentName) .setCustomServiceName(openAIServiceName) - .setMaxTokens(20) - .setLogProbs(5) .setPromptCol("prompt") .setOutputCol("out") testCompletion(completion, promptDF) } - ignore("Batch Prompt") { + test("Batch Prompt") { testCompletion(batchPromptCompletion, batchPromptDF) } - // TODO: see if data type failure here is due to a change on the OpenAI side of things - ignore("Index Prompt") { - testCompletion(indexPromptCompletion, indexPromptDF) - } - - ignore("Batch Index Prompt") { - testCompletion(batchIndexPromptCompletion, batchIndexPromptDF) - } - def testCompletion(completion: OpenAICompletion, df: DataFrame, requiredLength: Int = 10): Unit = { val fromRow = CompletionResponse.makeFromRowConverter completion.transform(df).collect().foreach(r => @@ -107,22 +82,13 @@ class OpenAICompletionSuite extends TransformerFuzzing[OpenAICompletion] with Op assert(c.text.length > requiredLength))) } - def newCompletion: OpenAICompletion = { - new OpenAICompletion() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName("text-davinci-001") - .setCustomServiceName(openAIServiceName) - .setMaxTokens(20) - .setLogProbs(5) - .setOutputCol("out") - } override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out"), df2.drop("out"))(eq) } override def testObjects(): Seq[TestObject[OpenAICompletion]] = - Seq(new TestObject(completion, df)) + Seq(new TestObject(newCompletion, df)) override def reader: MLReadable[_] = OpenAICompletion diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbeddingsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbeddingsSuite.scala index 4b192b7fee..ed01308e2d 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbeddingsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIEmbeddingsSuite.scala @@ -29,7 +29,6 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope test("Basic Usage") { embedding.transform(df).collect().foreach(r => { val v = r.getAs[Vector]("out") - assert(v.size > 0) }) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPromptSuite.scala index cb023a82cd..84283ebfef 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/openai/OpenAIPromptSuite.scala @@ -3,12 +3,11 @@ package com.microsoft.azure.synapse.ml.cognitive.openai -import com.microsoft.azure.synapse.ml.core.spark.Functions.template import com.microsoft.azure.synapse.ml.core.test.base.Flaky import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.functions.{col, lit, when} -import org.apache.spark.sql.{DataFrame, functions => F} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col import org.scalactic.Equality class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { @@ -17,8 +16,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK lazy val prompt: OpenAIPrompt = new OpenAIPrompt() .setSubscriptionKey(openAIAPIKey) - .setDeploymentName("text-davinci-001") - .setModel("text-davinci-003") + .setDeploymentName(deploymentName) .setCustomServiceName(openAIServiceName) .setOutputCol("outParsed") .setTemperature(0) @@ -37,24 +35,17 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .transform(df) .select("outParsed") .collect() - .map(r => - if (r.isNullAt(0)) 0 - else { - assert(r.getSeq[String](0).length > 0) - 1 - }) - .toSeq.sum + .count(r => Option(r.getSeq[String](0)).isDefined) assert(nonNullCount == 3) } test("Basic Usage JSON") { - val result = prompt - .setPromptTemplate( - """Split a word into prefix and postfix a respond in JSON - |Cherry: {{"prefix": "Che", "suffix": "rry"}} - |{text}: - |""".stripMargin) + prompt.setPromptTemplate( + """Split a word into prefix and postfix a respond in JSON + |Cherry: {{"prefix": "Che", "suffix": "rry"}} + |{text}: + |""".stripMargin) .setPostProcessing("json") .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING")) .transform(df) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/translate/TranslatorSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/translate/TranslatorSuite.scala index 9fd51b6dc4..8062f0bdc8 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/translate/TranslatorSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/translate/TranslatorSuite.scala @@ -105,7 +105,7 @@ class TranslateSuite extends TransformerFuzzing[Translate] .withColumn("translation", col("translation.text")) .select("translation", "transliteration").collect() assert(results.head.getSeq(0).mkString("\n") === "再见") - assert(results.head.getSeq(1).mkString("\n") === "zài jiàn") + assert(results.head.getSeq(1).mkString("\n").replaceAllLiterally(" ", "") === "zàijiàn") } test("Translate to multiple languages") { diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala index 25c78d622f..b3810d1374 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/causal/VerifyDoubleMLEstimator.scala @@ -3,6 +3,7 @@ package com.microsoft.azure.synapse.ml.causal +import com.microsoft.azure.synapse.ml.core.test.base.Flaky import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject} import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.classification.LogisticRegression @@ -10,7 +11,7 @@ import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.Row import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StringType, StructField, StructType} -class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] { +class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] with Flaky { val cat = "Cat" val dog = "Dog" val bird = "Bird" diff --git a/notebooks/features/cognitive_services/CognitiveServices - OpenAI.ipynb b/notebooks/features/cognitive_services/CognitiveServices - OpenAI.ipynb index 1c889a6e34..e0db58dd90 100644 --- a/notebooks/features/cognitive_services/CognitiveServices - OpenAI.ipynb +++ b/notebooks/features/cognitive_services/CognitiveServices - OpenAI.ipynb @@ -63,7 +63,7 @@ "\n", "# Fill in the following lines with your service information\n", "service_name = \"synapseml-openai\"\n", - "deployment_name = \"text-davinci-001\"\n", + "deployment_name = \"gpt-35-turbo\"\n", "deployment_name_embeddings = \"text-search-ada-doc-001\"\n", "deployment_name_embeddings_query = \"text-search-ada-query-001\"\n", "\n", @@ -858,4 +858,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file