diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
index b1b3d21499..b57f4d65da 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
@@ -9,6 +9,7 @@ import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
 import com.microsoft.azure.synapse.ml.param.ServiceParam
 import com.microsoft.azure.synapse.ml.services._
 import org.apache.spark.ml.PipelineModel
+import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
 import spray.json.DefaultJsonProtocol._
@@ -40,6 +41,16 @@ trait HasPromptInputs extends HasServiceParams {
 
 }
 
+trait HasMessagesInput extends Params {
+  val messagesCol: Param[String] = new Param[String](
+    this, "messagesCol", "The column messages to generate chat completions for," +
+      " in the chat format. This column should have type Array(Struct(role: String, content: String)).")
+
+  def getMessagesCol: String = $(messagesCol)
+
+  def setMessagesCol(v: String): this.type = set(messagesCol, v)
+}
+
 trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {
 
   val deploymentName = new ServiceParam[String](
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
index aeace84127..57837ad276 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala
@@ -20,18 +20,10 @@ import scala.language.existentials
 object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]
 
 class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
-  with HasOpenAITextParams with HasOpenAICognitiveServiceInput
+  with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
   with HasInternalJsonOutputParser with SynapseMLLogging {
   logClass(FeatureNames.AiServices.OpenAI)
 
-  val messagesCol: Param[String] = new Param[String](
-    this, "messagesCol", "The column messages to generate chat completions for," +
-      " in the chat format. This column should have type Array(Struct(role: String, content: String)).")
-
-  def getMessagesCol: String = $(messagesCol)
-
-  def setMessagesCol(v: String): this.type = set(messagesCol, v)
-
   def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))
 
   def urlPath: String = ""
diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
index 52661a4e70..b17b5c59c1 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala
@@ -12,6 +12,7 @@ import com.microsoft.azure.synapse.ml.param.StringStringMapParam
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
+import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}
 
@@ -20,7 +21,7 @@ import scala.collection.JavaConverters._
 object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]
 
 class OpenAIPrompt(override val uid: String) extends Transformer
-  with HasOpenAITextParams
+  with HasOpenAITextParams with HasMessagesInput
   with HasErrorCol with HasOutputCol
   with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
   with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
@@ -62,18 +63,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
     set(postProcessingOptions, v.asScala.toMap)
 
   val dropPrompt = new BooleanParam(
-    this, "dropPrompt", "whether to drop the column of prompts after templating")
+    this, "dropPrompt", "whether to drop the column of prompts (or chat messages) after templating")
 
   def getDropPrompt: Boolean = $(dropPrompt)
 
   def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value)
 
+  val systemPrompt = new Param[String](
+    this, "systemPrompt", "The initial system prompt to be used.")
+
+  def getSystemPrompt: String = $(systemPrompt)
+
+  def setSystemPrompt(value: String): this.type = set(systemPrompt, value)
+
+  private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
+    "Follow their instructions carefully and be brief if they don't say otherwise."
+
   setDefault(
     postProcessing -> "",
     postProcessingOptions -> Map.empty,
     outputCol -> (this.uid + "_output"),
     errorCol -> (this.uid + "_error"),
+    messagesCol -> (this.uid + "_messages"),
     dropPrompt -> true,
+    systemPrompt -> defaultSystemPrompt,
     timeout -> 360.0
   )
 
@@ -82,7 +95,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
   }
 
   private val localParamNames = Seq(
-    "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt")
+    "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
+    "systemPrompt")
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._
@@ -90,32 +104,73 @@ class OpenAIPrompt(override val uid: String) extends Transformer
     logTransform[DataFrame]({
       val df = dataset.toDF
 
-      val promptColName = df.withDerivativeCol("prompt")
-
-      val dfTemplated = df.withColumn(promptColName, Functions.template(getPromptTemplate))
-
-      val completion = openAICompletion.setPromptCol(promptColName)
-
-      // run completion
-      val results = completion
-        .transform(dfTemplated)
-        .withColumn(getOutputCol,
-          getParser.parse(F.element_at(F.col(completion.getOutputCol).getField("choices"), 1)
-            .getField("text")))
-        .drop(completion.getOutputCol)
-
-      if (getDropPrompt) {
-        results.drop(promptColName)
-      } else {
-        results
+      val completion = openAICompletion
+      val promptCol = Functions.template(getPromptTemplate)
+      // Read the system prompt once on the driver so the UDF closure captures a plain String
+      // instead of a reference to this transformer.
+      val systemPromptText = getSystemPrompt
+      val createMessagesUDF = udf((userMessage: String) => {
+        Seq(
+          OpenAIMessage("system", systemPromptText),
+          OpenAIMessage("user", userMessage)
+        )
+      })
+      completion match {
+        case chatCompletion: OpenAIChatCompletion =>
+          val messageColName = getMessagesCol
+          val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
+          val completionNamed = chatCompletion.setMessagesCol(messageColName)
+
+          val results = completionNamed
+            .transform(dfTemplated)
+            .withColumn(getOutputCol,
+              getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
+                .getField("message").getField("content")))
+            .drop(completionNamed.getOutputCol)
+
+          // dropPrompt also controls whether the generated messages column is kept
+          if (getDropPrompt) {
+            results.drop(messageColName)
+          } else {
+            results
+          }
+
+        case legacyCompletion: OpenAICompletion =>
+          val promptColName = df.withDerivativeCol("prompt")
+          val dfTemplated = df.withColumn(promptColName, promptCol)
+          val completionNamed = legacyCompletion.setPromptCol(promptColName)
+
+          // run completion
+          val results = completionNamed
+            .transform(dfTemplated)
+            .withColumn(getOutputCol,
+              getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
+                .getField("text")))
+            .drop(completionNamed.getOutputCol)
+
+          if (getDropPrompt) {
+            results.drop(promptColName)
+          } else {
+            results
+          }
       }
     }, dataset.columns.length)
   }
 
 
-  private def openAICompletion: OpenAICompletion = {
-    // apply template
-    val completion = new OpenAICompletion()
+  // Deployment names routed to OpenAICompletion; every other deployment is sent to OpenAIChatCompletion
+  private val legacyModels = Set("ada", "babbage", "curie", "davinci",
+    "text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
+    "code-cushman-001", "code-davinci-002")
+
+  private def openAICompletion: OpenAIServicesBase = {
+    val completion: OpenAIServicesBase =
+      if (legacyModels.contains(getDeploymentName)) {
+        new OpenAICompletion()
+      }
+      else {
+        new OpenAIChatCompletion()
+      }
     // apply all parameters
     extractParamMap().toSeq
       .filter(p => !localParamNames.contains(p.param.name))
@@ -136,10 +191,21 @@ class OpenAIPrompt(override val uid: String) extends Transformer
     }
   }
 
-  override def transformSchema(schema: StructType): StructType =
-    openAICompletion
-      .transformSchema(schema)
-      .add(getPostProcessing, getParser.outputSchema)
+  override def transformSchema(schema: StructType): StructType = {
+    openAICompletion match {
+      case chatCompletion: OpenAIChatCompletion =>
+        // Declare the messages column with the type documented on messagesCol: Array(Struct(role, content))
+        val messagesType = T.ArrayType(T.StructType(Seq(
+          T.StructField("role", T.StringType), T.StructField("content", T.StringType))))
+        chatCompletion
+          .transformSchema(schema.add(getMessagesCol, messagesType))
+          .add(getOutputCol, getParser.outputSchema)
+      case completion: OpenAICompletion =>
+        completion
+          .transformSchema(schema)
+          .add(getOutputCol, getParser.outputSchema)
+    }
+  }
 }
 
 trait OutputParser {
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
index 68910407bc..6282067b0d 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala
@@ -49,10 +49,10 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
 
   test("Basic Usage JSON") {
     prompt.setPromptTemplate(
-      """Split a word into prefix and postfix a respond in JSON
-        |Cherry: {{"prefix": "Che", "suffix": "rry"}}
-        |{text}:
-        |""".stripMargin)
+        """Split a word into prefix and postfix a respond in JSON
+          |Cherry: {{"prefix": "Che", "suffix": "rry"}}
+          |{text}:
+          |""".stripMargin)
       .setPostProcessing("json")
       .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
       .transform(df)
@@ -62,6 +62,58 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
       .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
   }
 
+  lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
+    .setSubscriptionKey(openAIAPIKey)
+    .setDeploymentName(deploymentNameGpt4)
+    .setCustomServiceName(openAIServiceName)
+    .setOutputCol("outParsed")
+    .setTemperature(0)
+
+  test("Basic Usage - Gpt 4") {
+    val nonNullCount = promptGpt4
+      .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
+      .setPostProcessing("csv")
+      .transform(df)
+      .select("outParsed")
+      .collect()
+      .count(r => Option(r.getSeq[String](0)).isDefined)
+
+    assert(nonNullCount == 3)
+  }
+
+  test("Basic Usage JSON - Gpt 4") {
+    promptGpt4.setPromptTemplate(
+        """Split a word into prefix and postfix a respond in JSON
+          |Cherry: {{"prefix": "Che", "suffix": "rry"}}
+          |{text}:
+          |""".stripMargin)
+      .setPostProcessing("json")
+      .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
+      .transform(df)
+      .select("outParsed")
+      .where(col("outParsed").isNotNull)
+      .collect()
+      .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
+  }
+
+  test("Setting and Keeping Messages Col - Gpt 4") {
+    val messageRows = promptGpt4.setMessagesCol("messages")
+      .setDropPrompt(false)
+      .setPromptTemplate(
+        """Classify each word as to whether they are an F1 team or not
+          |ferrari: TRUE
+          |tomato: FALSE
+          |{text}:
+          |""".stripMargin)
+      .transform(df)
+      .select("messages")
+      .collect()
+
+    // Require at least one row and no null messages; filtering nulls first would make the check vacuous
+    assert(messageRows.nonEmpty)
+    messageRows.foreach(r => assert(!r.isNullAt(0)))
+  }
+
   override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
     super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq)
   }