Skip to content

Commit

Permalink
fix(ai): add context possibiloty
Browse files Browse the repository at this point in the history
  • Loading branch information
RISCH Francois committed Jun 12, 2024
1 parent 530caf1 commit aeb95b2
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 25 deletions.
15 changes: 12 additions & 3 deletions src/main/java/com/datagen/model/type/BedrockField.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ public class BedrockField extends Field<String> {
private final String modelId;
private final BedrockModelType bedrockmodeltype;
private JSONObject preparedRequest = null;
private final String context;

public BedrockField(String name, String url, String user, String password,
String request, String modelType, Float temperature, String region, Integer maxTokens) {
String request, String modelType, Float temperature, String region, Integer maxTokens, String context) {
this.name = name;
this.url = url;
this.user = user;
Expand All @@ -76,6 +77,10 @@ public BedrockField(String name, String url, String user, String password,
.region(this.region)
.build();

var contextAsMessage = context!=null?"Use the following information to answer the question:"+System.lineSeparator()+context:"";
this.context = "Generate only the answer."+System.lineSeparator()+contextAsMessage;
log.debug("Will provide following System information to the model: {}", context);

// See model Ids available at: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
this.modelId = modelType == null ? "amazon.titan-text-lite-v1" : modelType;
/*
Expand Down Expand Up @@ -107,6 +112,7 @@ public BedrockField(String name, String url, String user, String password,
yield new JSONObject()
.put("temperature", this.temperature)
.put("stop_sequences", List.of("\n\nHuman:"))
.put("system", context)
.put("max_tokens_to_sample", this.maxTokens);
case MISTRAL:
yield new JSONObject()
Expand Down Expand Up @@ -134,11 +140,14 @@ public String generateComputedValue(Row row) {
case ANTHROPIC -> preparedRequest.put("prompt",
"Human: " + stringToEvaluate + "\\n\\nAssistant:");
case MISTRAL -> preparedRequest.put("prompt",
"<s>[INST] " + stringToEvaluate + "[/INST]");
"<s>[INST] " + stringToEvaluate + "[/INST]" + context);
case TITAN -> preparedRequest.put("inputText", stringToEvaluate);
case LLAMA -> preparedRequest.put("prompt", stringToEvaluate);
default -> preparedRequest.put("prompt", stringToEvaluate);
}

log.debug("Request to Bedrock is: {}", preparedRequest.toString());

// Encode and send the request.
var response = bedrockRuntimeClient.invokeModel(req -> req
.accept("application/json")
Expand Down Expand Up @@ -169,7 +178,7 @@ public String generateComputedValue(Row row) {
e);
}

return responseText;
return responseText.trim().replaceAll("\\n[ \\t]*\\n","");
}

@Override
Expand Down
12 changes: 8 additions & 4 deletions src/main/java/com/datagen/model/type/Field.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public static Field instantiateField(
Float frequencyPenalty,
Float presencePenalty,
Integer maxTokens,
Float topP
Float topP,
String context
) {
if (name == null || name.isEmpty()) {
throw new IllegalStateException(
Expand Down Expand Up @@ -296,7 +297,8 @@ yield new OllamaField(name,
presencePenalty == null ? Float.valueOf(properties.get(
ApplicationConfigs.OLLAMA_PRESENCE_PENALTY_DEFAULT)) : presencePenalty,
topP == null ? Float.valueOf(properties.get(
ApplicationConfigs.OLLAMA_TOP_P_DEFAULT)) : topP
ApplicationConfigs.OLLAMA_TOP_P_DEFAULT)) : topP,
context
);

case "BEDROCK":
Expand All @@ -314,7 +316,8 @@ yield new BedrockField(name, url,
ApplicationConfigs.BEDROCK_TEMPERATURE_DEFAULT)) : temperature,
properties.get(ApplicationConfigs.BEDROCK_REGION),
maxTokens == null ? Integer.valueOf(properties.get(
ApplicationConfigs.BEDROCK_MAX_TOKENS_DEFAULT)) : maxTokens
ApplicationConfigs.BEDROCK_MAX_TOKENS_DEFAULT)) : maxTokens,
context
);

case "OPENAI":
Expand All @@ -336,7 +339,8 @@ yield new OpenAIField(name, url,
maxTokens == null ? Integer.valueOf(properties.get(
ApplicationConfigs.OPENAI_MAX_TOKENS_DEFAULT)) : maxTokens,
topP == null ? Float.valueOf(properties.get(
ApplicationConfigs.OPENAI_TOP_P_DEFAULT)) : topP
ApplicationConfigs.OPENAI_TOP_P_DEFAULT)) : topP,
context
);

default:
Expand Down
16 changes: 13 additions & 3 deletions src/main/java/com/datagen/model/type/OllamaField.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
import org.apache.kudu.Type;
import org.apache.kudu.client.PartialRow;
import org.apache.orc.TypeDescription;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;

import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List;

@Slf4j
public class OllamaField extends Field<String> {
Expand All @@ -46,11 +49,12 @@ public class OllamaField extends Field<String> {
private final OllamaApi ollamaApi;
private final OllamaChatClient ollamaChatClient;
private final OllamaOptions ollamaOptions;
private final SystemMessage systemMessage;


public OllamaField(String name, String url, String user, String password, String request,
String modelType, Float temperature, Float frequencyPenalty,
Float presencePenalty, Float topP) {
Float presencePenalty, Float topP, String context) {
this.name = name;
this.url = url;
this.user = user;
Expand All @@ -64,18 +68,24 @@ public OllamaField(String name, String url, String user, String password, String
.withFrequencyPenalty(frequencyPenalty == null ? 1.0f : frequencyPenalty)
.withPresencePenalty(presencePenalty == null ? 1.0f : presencePenalty)
.withTopP(topP == null ? 1.0f : topP);

var contextAsMessage = context!=null?"Use the following information to answer the question:"+System.lineSeparator()+context:"";
this.systemMessage = new SystemMessage("Generate only the answer."+System.lineSeparator()+contextAsMessage);
log.debug("Will provide following System information to the model: {}", systemMessage.getContent());
}

@Override
public String generateComputedValue(Row row) {
String stringToEvaluate = ParsingUtils.injectRowValuesToAString(row, requestToInject);
log.debug("Asking to Ollama: {}", stringToEvaluate);
UserMessage userMessage = new UserMessage(stringToEvaluate);

return this.ollamaChatClient.call(
new Prompt(
stringToEvaluate,
List.of(userMessage, this.systemMessage),
this.ollamaOptions
)).getResult().getOutput().getContent();
)).getResult().getOutput().getContent()
.trim().replaceAll("\\n[ \\t]*\\n","");
}

@Override
Expand Down
19 changes: 17 additions & 2 deletions src/main/java/com/datagen/model/type/OpenAIField.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
import org.apache.kudu.Type;
import org.apache.kudu.client.PartialRow;
import org.apache.orc.TypeDescription;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;

import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List;

@Slf4j
public class OpenAIField extends Field<String> {
Expand All @@ -47,10 +51,11 @@ public class OpenAIField extends Field<String> {
private final OpenAiChatClient openAiChatClient;
private final OpenAiChatOptions openAiChatOptions;
private final String modelId;
private final SystemMessage systemMessage;

public OpenAIField(String name, String url, String user, String password,
String request, String modelType, Float temperature, Float frequencyPenalty,
Float presencePenalty, Integer maxTokens, Float topP) {
Float presencePenalty, Integer maxTokens, Float topP, String context) {
this.name = name;
this.url = url;
this.user = user;
Expand All @@ -71,14 +76,24 @@ public OpenAIField(String name, String url, String user, String password,
.build();
this.openAiChatClient = new OpenAiChatClient(openAiApi, openAiChatOptions);

var contextAsMessage = context!=null?"Use the following information to answer the question:"+System.lineSeparator()+context:"";
this.systemMessage = new SystemMessage("Generate only the answer."+System.lineSeparator()+contextAsMessage);
log.debug("Will provide following System information to the model: {}", systemMessage.getContent());

}

@Override
public String generateComputedValue(Row row) {
String stringToEvaluate =
ParsingUtils.injectRowValuesToAString(row, requestToInject);
log.debug("Asking to OpenAI: {}", stringToEvaluate);
return openAiChatClient.call(stringToEvaluate);
UserMessage userMessage = new UserMessage(stringToEvaluate);

return openAiChatClient.call(new Prompt(
List.of(userMessage, this.systemMessage),
this.openAiChatOptions
)).getResult().getOutput().getContent()
.trim().replaceAll("\\n[ \\t]*\\n","");
}

@Override
Expand Down
10 changes: 9 additions & 1 deletion src/main/java/com/datagen/parsers/JsonParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,13 @@ private T getOneField(JsonNode jsonField, Map<ApplicationConfigs, String> proper
topP = null;
}

String context;
try {
context = jsonField.get("context").asText();
} catch (NullPointerException e) {
context = null;
}

JsonNode filtersArray = jsonField.get("filters");
List<JsonNode> filters = new ArrayList<>();
try {
Expand Down Expand Up @@ -391,7 +398,8 @@ private T getOneField(JsonNode jsonField, Map<ApplicationConfigs, String> proper
frequencyPenalty,
presencePenalty,
maxTokens,
topP);
topP,
context);
}

private Map<String, String> mapColNameToColQual(String mapping) {
Expand Down
9 changes: 6 additions & 3 deletions src/main/resources/models/example-model-ai.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"user": "<YOUR_ACCESS_KEY>",
"password": "<YOUR_SECRET_KEY>",
"temperature": 1.0,
"max_tokens": 256
"max_tokens": 256,
"context": "For their birthday, kids wants a toy, young wants a new car, adults wants a prosper family and friends, elder wants a strong health"
},
{
"name": "birthday_wish_ollama",
Expand All @@ -28,7 +29,8 @@
"temperature": 1.0,
"frequency_penalty": 1.5,
"presence_penalty": 1.3,
"top_p": 1.0
"top_p": 1.0,
"context": "For their birthday, kids wants a toy, young wants a new car, adults wants a prosper family and friends, elder wants a strong health"
},
{
"name": "birthday_wish_openai",
Expand All @@ -40,7 +42,8 @@
"frequency_penalty": 1.5,
"presence_penalty": 1.3,
"max_tokens": 256,
"top_p": 1.0
"top_p": 1.0,
"context": "For their birthday, kids wants a toy, young wants a new car, adults wants a prosper family and friends, elder wants a strong health"
}
],
"Table_Names": {
Expand Down
15 changes: 6 additions & 9 deletions src/main/resources/models/example-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,17 @@
{
"name": "age",
"type": "LONG",
"min": 18,
"max": 99
"min": 5,
"max": 85
},
{
"name": "birthday_wish_openai",
"type": "OPENAI",
"name": "birthday_wish_bedrock",
"type": "BEDROCK",
"request": "generate a one line birthday wish to ${name} who is ${age} years old today",
"model_type": "gpt-4o",
"password": "",
"model_type": "mistral.mistral-small-2402-v1:0",
"temperature": 1.0,
"frequency_penalty": 1.5,
"presence_penalty": 1.3,
"max_tokens": 256,
"top_p": 1.0
"context": "For their birthday, kids wants a toy, young wants a new car, adults wants a prosper family and friends, elder wants a strong health"
}
],
"Table_Names": {
Expand Down

0 comments on commit aeb95b2

Please sign in to comment.