Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable GPT-4 in OpenAIPrompt #2248

Merged
merged 10 commits into from
Jul 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -40,6 +41,16 @@ trait HasPromptInputs extends HasServiceParams {

}

trait HasMessagesInput extends Params {
val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,10 @@ import scala.language.existentials
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasOpenAICognitiveServiceInput
with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")

def getMessagesCol: String = $(messagesCol)

def setMessagesCol(v: String): this.type = set(messagesCol, v)

def this() = this(Identifiable.randomUID("OpenAIChatCompletion"))

def urlPath: String = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}

Expand All @@ -20,7 +21,7 @@ import scala.collection.JavaConverters._
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams
with HasOpenAITextParams with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
Expand Down Expand Up @@ -62,18 +63,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
set(postProcessingOptions, v.asScala.toMap)

val dropPrompt = new BooleanParam(
this, "dropPrompt", "whether to drop the column of prompts after templating")
this, "dropPrompt", "whether to drop the column of prompts after templating (when using legacy models)")

def getDropPrompt: Boolean = $(dropPrompt)

def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value)

val systemPrompt = new Param[String](
this, "systemPrompt", "The initial system prompt to be used.")

def getSystemPrompt: String = $(systemPrompt)

def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
"Follow their instructions carefully and be brief if they don't say otherwise."

setDefault(
postProcessing -> "",
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"),
messagesCol -> (this.uid + "_messages"),
dropPrompt -> true,
systemPrompt -> defaultSystemPrompt,
timeout -> 360.0
)

Expand All @@ -82,40 +95,77 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt")
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt")

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

logTransform[DataFrame]({
val df = dataset.toDF

val promptColName = df.withDerivativeCol("prompt")

val dfTemplated = df.withColumn(promptColName, Functions.template(getPromptTemplate))

val completion = openAICompletion.setPromptCol(promptColName)

// run completion
val results = completion
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completion.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completion.getOutputCol)

if (getDropPrompt) {
results.drop(promptColName)
} else {
results
val completion = openAICompletion
val promptCol = Functions.template(getPromptTemplate)
val createMessagesUDF = udf((userMessage: String) => {
Seq(
OpenAIMessage("system", getSystemPrompt),
OpenAIMessage("user", userMessage)
)
})
completion match {
case chatCompletion: OpenAIChatCompletion =>
val messageColName = getMessagesCol
val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
val completionNamed = chatCompletion.setMessagesCol(messageColName)

val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

if (getDropPrompt) {
results.drop(messageColName)
} else {
results
}

case completion: OpenAICompletion =>
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)

// run completion
val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completionNamed.getOutputCol)

if (getDropPrompt) {
results.drop(promptColName)
} else {
results
}
}
}, dataset.columns.length)
}

private def openAICompletion: OpenAICompletion = {
// apply template
val completion = new OpenAICompletion()
private val legacyModels = Set("ada","babbage", "curie", "davinci",
"text-ada-001", "text-babbage-001", "text-curie-001", "text-davinci-002", "text-davinci-003",
"code-cushman-001", "code-davinci-002")

private def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
if (legacyModels.contains(getDeploymentName)) {
new OpenAICompletion()
}
else {
new OpenAIChatCompletion()
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
Expand All @@ -136,10 +186,18 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
}

override def transformSchema(schema: StructType): StructType =
openAICompletion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
.add(getPostProcessing, getParser.outputSchema)
case completion: OpenAICompletion =>
completion
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}
}
}

trait OutputParser {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

test("Basic Usage JSON") {
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessing("json")
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
Expand All @@ -62,6 +62,56 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage - Gpt 4") {
sss04 marked this conversation as resolved.
Show resolved Hide resolved
val nonNullCount = promptGpt4
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
.transform(df)
.select("outParsed")
.collect()
.count(r => Option(r.getSeq[String](0)).isDefined)

assert(nonNullCount == 3)
}

test("Basic Usage JSON - Gpt 4") {
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessing("json")
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
.select("outParsed")
.where(col("outParsed").isNotNull)
.collect()
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

test("Setting and Keeping Messages Col - Gpt 4") {
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
.setPromptTemplate(
"""Classify each word as to whether they are an F1 team or not
|ferrari: TRUE
|tomato: FALSE
|{text}:
|""".stripMargin)
.transform(df)
.select("messages")
.where(col("messages").isNotNull)
.collect()
.foreach(r => assert(r.get(0) != null))
}

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq)
}
Expand Down
Loading