feat: created predict endpoint #47

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: created predict endpoint
telpirion committed Nov 25, 2024

Verified: this commit was signed with the committer's verified signature.
commit ecd3c3e2e9329535f28a530a3fdaa62ef720916d
4 changes: 3 additions & 1 deletion Dockerfile
@@ -16,7 +16,9 @@ COPY site/css ./site/css
 COPY site/html ./site/html
 COPY prompts ./server/templates
 COPY server/favicon.ico ./server/favicon.ico
-COPY server/* ./server
+COPY server/generated ./server/generated
+COPY server/ai ./server/ai
+COPY server/*.go ./server
 
 COPY server/go.mod server/go.sum ./server/
 WORKDIR /server
25 changes: 25 additions & 0 deletions docs/server.md
@@ -46,3 +46,28 @@ $ docker tag myherodotus us-west1-docker.pkg.dev/${PROJECT_ID}/my-herodotus/base
 $ docker push us-west1-docker.pkg.dev/${PROJECT_ID}/my-herodotus/base-image:${SEMVER}
 ```
+
+## Get predictions directly from the API
+
+The MyHerodotus app exposes an API endpoint, `/predict`, that allows callers to send
+raw prediction requests to the AI system.
+
+The following code sample demonstrates how to get a simple prediction from the `/predict`
+endpoint using `curl`. It assumes that the MyHerodotus app is running locally and
+listening on port `8080`.
+
+```sh
+curl --header "Content-Type: application/json" \
+  --request POST \
+  --data '{"message":"I want to go to Greece","model":"gemini"}' \
+  http://localhost:8080/predict
+```
+
+The following code sample demonstrates how to get a simple prediction from the `/predict`
+endpoint of the deployed MyHerodotus app using `curl`.
+
+```sh
+curl --header "Content-Type: application/json" \
+  --request POST \
+  --data '{"message":"I want to go to Greece","model":"gemini"}' \
+  https://myherodotus-1025771077852.us-west1.run.app/predict
+```
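
For callers working in Go rather than `curl`, a request to the same endpoint can be made with the standard library alone. The sketch below reuses the JSON body from the examples above and prints the raw response, since this page does not document the response schema; the local URL is the same assumption as before.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Same payload as the curl examples above.
	body := []byte(`{"message":"I want to go to Greece","model":"gemini"}`)

	resp, err := http.Post(
		"http://localhost:8080/predict", // assumes a locally running app
		"application/json",
		bytes.NewReader(body),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The response schema isn't documented here, so print it raw.
	raw, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(raw))
}
```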
101 changes: 80 additions & 21 deletions server/vertex.go → server/ai/vertex.go
@@ -1,4 +1,4 @@
-package main
+package ai
 
 import (
 	"bytes"
@@ -17,6 +17,8 @@ import (
 	"cloud.google.com/go/vertexai/genai"
 	"google.golang.org/api/option"
 	"google.golang.org/protobuf/types/known/structpb"
+
+	"github.com/telpirion/MyHerodotus/generated"
 )
 
 const (
@@ -29,7 +31,30 @@ const (
 	MaxGemmaTokens int32 = 2048
 )
 
-var cachedContext string = ""
+var (
+	cachedContext = ""
+	convoContext  = ""
+)
+
+// Modality identifies which model backend serves a prediction request.
+type Modality int
+
+const (
+	Gemini Modality = iota
+	GeminiTuned
+	Gemma
+	AgentAssisted
+	EmbeddingsAssisted
+)
+
+var (
+	modalitiesMap = map[string]Modality{
+		"gemini":              Gemini,
+		"gemini-tuned":        GeminiTuned,
+		"gemma":               Gemma,
+		"agent-assisted":      AgentAssisted,
+		"embeddings-assisted": EmbeddingsAssisted,
+	}
+)
 
 type MinCacheNotReachedError struct {
 	ConversationCount int
@@ -45,8 +70,29 @@ type promptInput struct {
 	History string
 }
 
-// getTokenCount uses the Gemini tokenizer to count the tokens in some text.
-func getTokenCount(text string) (int32, error) {
+// Predict routes a query to the model backend selected by modality and
+// returns the model's response, appending the exchange to the cached context.
+func Predict(query, modality, projectID string) (response string, templateName string, err error) {
+	switch modalitiesMap[strings.ToLower(modality)] {
+	case Gemini:
+		response, err = textPredictGemini(query, projectID, Gemini)
+	case Gemma:
+		response, err = textPredictGemma(query, projectID)
+	case GeminiTuned:
+		response, err = textPredictGemini(query, projectID, GeminiTuned)
+	default:
+		response, err = textPredictGemini(query, projectID, Gemini)
+	}
+
+	if err != nil {
+		return "", "", err
+	}
+
+	cachedContext += fmt.Sprintf("### Human: %s\n### Assistant: %s\n", query, response)
+	return response, templateName, nil
+}
+
+// GetTokenCount uses the Gemini tokenizer to count the tokens in some text.
+func GetTokenCount(text, projectID string) (int32, error) {
 	location := "us-west1"
 	ctx := context.Background()
 	client, err := genai.NewClient(ctx, projectID, location)
@@ -65,9 +111,9 @@ func getTokenCount(text string) (int32, error) {
 	return resp.TotalTokens, nil
 }
 
-// setConversationContext creates string out of past conversation between user and model.
+// SetConversationContext creates a string out of past conversations between the user and the model.
 // This conversation history is used as grounding for the prompt template.
-func setConversationContext(convoHistory []ConversationBit) error {
+func SetConversationContext(convoHistory []generated.ConversationBit) error {
 	tmp, err := template.ParseFiles(HistoryTemplate)
 	if err != nil {
 		return err
@@ -93,15 +139,15 @@ func extractAnswer(response string) string {
 }
 
 // createPrompt generates a new prompt based upon the stored prompt template.
-func createPrompt(message, templateName string) (string, error) {
+func createPrompt(message, templateName, history string) (string, error) {
 	tmp, err := template.ParseFiles(templateName)
 	if err != nil {
 		return "", err
 	}
 
 	promptInputs := promptInput{
 		Query:   message,
-		History: cachedContext,
+		History: history,
 	}
 
 	var buf bytes.Buffer
@@ -122,25 +168,33 @@ func textPredictGemma(message, projectID string) (string, error) {
 	apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
 	client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
 	if err != nil {
-		LogError(fmt.Sprintf("unable to create prediction client: %v\n", err))
 		return "", err
 	}
 	defer client.Close()
 
 	parameters := map[string]interface{}{}
 
-	prompt, err := createPrompt(message, GemmaTemplate)
+	prompt, err := createPrompt(message, GemmaTemplate, cachedContext)
 	if err != nil {
-		LogError(fmt.Sprintf("unable to create Gemma prompt: %v\n", err))
 		return "", err
 	}
 
+	// If the prompt exceeds the Gemma context window, rebuild it from a
+	// trimmed history; fall back to the bare message if that fails too.
+	tokenCount, err := GetTokenCount(prompt, projectID)
+	if err != nil {
+		return "", fmt.Errorf("error counting input tokens: %w", err)
+	}
+	if tokenCount > MaxGemmaTokens {
+		prompt, err = createPrompt(message, GemmaTemplate, trimContext())
+		if err != nil {
+			prompt = message
+		}
+	}
+
 	promptValue, err := structpb.NewValue(map[string]interface{}{
 		"inputs":     prompt,
 		"parameters": parameters,
 	})
 	if err != nil {
-		LogError(fmt.Sprintf("unable to create prompt value: %v\n", err))
 		return "", err
 	}
 
@@ -151,7 +205,6 @@ func textPredictGemma(message, projectID string) (string, error) {
 
 	resp, err := client.Predict(ctx, req)
 	if err != nil {
-		LogError(fmt.Sprintf("unable to make prediction: %v\n", err))
 		return "", err
 	}
 
@@ -163,19 +216,18 @@ func textPredictGemma(message, projectID string) (string, error) {
 }
 
 // textPredictGemini generates text using a Gemini 1.5 Flash model
-func textPredictGemini(message, projectID, modelVersion string) (string, error) {
+func textPredictGemini(message, projectID string, modality Modality) (string, error) {
 	ctx := context.Background()
 	location := "us-west1"
 
 	client, err := genai.NewClient(ctx, projectID, location)
 	if err != nil {
-		LogError(fmt.Sprintf("unable to create genai client: %v\n", err))
 		return "", err
 	}
 	defer client.Close()
 
 	modelName := GeminiModel
-	if modelVersion == "gemini-tuned" {
+	if modality == GeminiTuned {
 		endpointID := os.Getenv("TUNED_MODEL_ENDPOINT_ID")
 		modelName = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
 	}
@@ -185,21 +237,18 @@ func textPredictGemini(message, projectID, modelVersion string) (string, error)
 		llm.CachedContentName = convoContext
 	}
 
-	prompt, err := createPrompt(message, GeminiTemplate)
+	prompt, err := createPrompt(message, GeminiTemplate, cachedContext)
 	if err != nil {
-		LogError(fmt.Sprintf("unable to create Gemini prompt: %v\n", err))
 		return "", err
 	}
 
 	resp, err := llm.GenerateContent(ctx, genai.Text(prompt))
 	if err != nil {
-		LogError(fmt.Sprintf("unable to generate content: %v\n", err))
 		return "", err
 	}
 
 	candidate, err := getCandidate(resp)
 	if err != nil {
-		LogError(err.Error())
 		return "I'm not sure how to answer that. Would you please repeat the question?", nil
 	}
 	return extractAnswer(candidate), nil
@@ -224,7 +273,7 @@ func getCandidate(resp *genai.GenerateContentResponse) (string, error) {
 
 // StoreConversationContext uploads past user conversations with the model into a Gen AI context.
 // This context is used when the model is answering questions from the user.
-func storeConversationContext(conversationHistory []ConversationBit, projectID string) (string, error) {
+func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
 	if len(conversationHistory) < MinimumConversationNum {
 		return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
 	}
@@ -266,3 +315,13 @@ func storeConversationContext(conversationHistory []ConversationBit, projectID s
 
 	return resourceName, nil
 }
+
+// trimContext returns a shortened version of the cached conversation context,
+// keeping only the most recent turns, for rebuilding prompts that exceed the
+// model's token limit.
+func trimContext() (last string) {
+	sep := "###"
+	convos := strings.Split(cachedContext, sep)
+	length := len(convos)
+	if length > 3 {
+		last = strings.Join(convos[length-3:length-1], sep)
+	}
+	return last
+}
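
With this refactor, the exported surface of the new `ai` package is `Predict`, `GetTokenCount`, `SetConversationContext`, and `StoreConversationContext`. A minimal sketch of how a `/predict` handler in `package main` might call into it follows; the handler, request struct, and project ID are illustrative assumptions, not code from this PR.

```go
package main

import (
	"encoding/json"
	"net/http"

	"github.com/telpirion/MyHerodotus/ai"
)

// predictRequest mirrors the JSON body shown in docs/server.md:
// {"message": "...", "model": "..."}
type predictRequest struct {
	Message string `json:"message"`
	Model   string `json:"model"`
}

// predictHandler is a hypothetical handler; the PR's actual handler code is
// not shown in this diff.
func predictHandler(w http.ResponseWriter, r *http.Request) {
	var req predictRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid JSON body", http.StatusBadRequest)
		return
	}

	// In the real server the project ID would come from configuration.
	response, _, err := ai.Predict(req.Message, req.Model, "my-project-id")
	if err != nil {
		http.Error(w, "prediction failed", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{"response": response})
}
```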
4 changes: 2 additions & 2 deletions server/vertex_test.go → server/ai/vertex_test.go
@@ -1,4 +1,4 @@
-package main
+package ai
 
 import (
 	"strings"
@@ -28,7 +28,7 @@ func TestSetConversationContext(t *testing.T) {
 			BotResponse: "test bot response 2",
 		},
 	}
-	err := setConversationContext(convoHistory)
+	err := SetConversationContext(convoHistory)
 	if err != nil {
 		t.Fatal(err)
 	}
12 changes: 7 additions & 5 deletions server/db.go
@@ -34,6 +34,8 @@ import (
 
 	"cloud.google.com/go/firestore"
 	"google.golang.org/api/iterator"
+
+	"github.com/telpirion/MyHerodotus/generated"
 )
 
 const (
@@ -55,10 +57,10 @@ var CollectionName string = "HerodotusDev"
 */
 type ConversationHistory struct {
 	UserEmail     string
-	Conversations []ConversationBit
+	Conversations []generated.ConversationBit
 }
 
-func saveConversation(convo ConversationBit, userEmail, projectID string) (string, error) {
+func saveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) {
 	ctx := context.Background()
 
 	// Get CollectionName for running in staging or prod
@@ -104,9 +106,9 @@ func updateConversation(documentId, userEmail, rating, projectID string) error {
 	return nil
 }
 
-func getConversation(userEmail, projectID string) ([]ConversationBit, error) {
+func getConversation(userEmail, projectID string) ([]generated.ConversationBit, error) {
 	ctx := context.Background()
-	conversations := []ConversationBit{}
+	conversations := []generated.ConversationBit{}
 	client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName)
 	if err != nil {
 		LogError(fmt.Sprintf("firestore.Client: %v\n", err))
@@ -126,7 +128,7 @@ func getConversation(userEmail, projectID string) ([]ConversationBit, error) {
 		LogError(fmt.Sprintf("Firestore Iterator: %v\n", err))
 		return conversations, err
 	}
-	var convo ConversationBit
+	var convo generated.ConversationBit
 	err = doc.DataTo(&convo)
 	if err != nil {
 		LogError(fmt.Sprintf("Firestore document unmarshaling: %v\n", err))

Some generated files are not rendered by default.
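
The `generated` package itself is not rendered here, but the visible call sites constrain it: `vertex_test.go` constructs `generated.ConversationBit` values with a `BotResponse` field, and `db.go` unmarshals Firestore documents into the type via `doc.DataTo`. A partial stand-in consistent with those uses, with all other fields unknown, might look like this:

```go
// Package generated is added by this PR but not rendered in the diff.
// This stand-in shows only what the visible call sites require.
package generated

// ConversationBit represents one exchange between the user and the model.
type ConversationBit struct {
	// BotResponse appears in vertex_test.go's test fixtures.
	BotResponse string
	// Further fields (e.g. the user's message) presumably exist but are
	// not visible in this diff.
}
```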

2 changes: 1 addition & 1 deletion server/go.mod
@@ -1,4 +1,4 @@
-module my-herodotus
+module github.com/telpirion/MyHerodotus
 
 go 1.23.0

Binary file added server/main
Binary file not shown.