feat: integrated agent, live evaluations #50

Merged
merged 3 commits into from
Nov 27, 2024
31 changes: 26 additions & 5 deletions README.md
@@ -40,8 +40,7 @@ This system allows the usage of three related LLM models:

+ The out-of-the-box [Gemini 1.5 Flash model][gemini]
+ A tuned version of the Gemini 1.5 Flash model, trained on the [Guanaco dataset][guanaco].
+ A [Gemma 2][gemma2] open source model. This model currently cannot be
evaluated with the Evaluations API.
+ A [Gemma 2][gemma2] open source model.

These models have been evaluated against the following set of metrics.

@@ -55,10 +54,29 @@ The following table shows the evaluation scores for each of these models.

| Model | ROUGE | Closed domain | Open domain | Groundedness | Coherence | Date of eval |
| ---------------- | ------ | ------------- | ----------- | ------------ | --------- | ------------ |
| Gemini 1.5 Flash | 1.0[1] | 0.52 | 1.0 | 1.0[1] | 3.8 | 2024-11-07 |
| Tuned Gemini | 0.41 | 0.8 | 1.0 | 0.6 | 3.8 | 2024-11-07 |
| Gemini 1.5 Flash | 0.20[1]| 0.0 | 1.0 | 1.0[1] | 3.3 | 2024-11-25 |
| Tuned Gemini | 0.21 | 0.4 | 1.0 | 1.0 | 2.4 | 2024-11-25 |
| Gemma | 0.05 | 0.6 | 0.4 | 0.8 | 1.4 | 2024-11-25 |

[1]: Gemini 1.5 Flash responses were used as the ground truth for all other models.
[1]: Gemini 1.5 Flash responses from 2024-11-05 are used as the ground truth
for all other models.

## Adversarial evaluations

These models have been evaluated against the following set of adversarial
techniques.

+ [Prompt injection][injection]
+ [Prompt leaking][leaking]
+ [Jailbreaking][jailbreaking]

The following table shows the evaluation scores for adversarial prompting.

| Model | Prompt injection | Prompt leaking | Jailbreaking | Date of eval |
| ---------------- | ----------------- | -------------- | ------------ | ------------ |
| Gemini 1.5 Flash | 0.66 | 0.66 | 1.0 | 2024-11-25 |
| Tuned Gemini | 0.33 | 1.0 | 1.0 | 2024-11-25 |
| Gemma | 1.0 | 0.66 | 0.66 | 2024-11-25 |

[bigquery]: https://cloud.google.com/bigquery/docs
[bulma]: https://bulma.io/documentation/components/message/
Expand All @@ -75,6 +93,9 @@ The following table shows the evaluation scores for each of these models.
[groundedness]: https://cloud.google.com/vertex-ai/generative-ai/docs/models/metrics-templates#pointwise_groundedness
[guanaco]: https://huggingface.co/datasets/timdettmers/openassistant-guanaco
[herodotus]: https://en.wikipedia.org/wiki/Herodotus
[injection]: https://www.promptingguide.ai/prompts/adversarial-prompting/prompt-injection
[jailbreaking]: https://www.promptingguide.ai/prompts/adversarial-prompting/jailbreaking-llms
[leaking]: https://www.promptingguide.ai/prompts/adversarial-prompting/prompt-leaking
[pytorch]: https://pytorch.org/
[rouge]: https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#rouge
[run]: https://cloud.google.com/run/docs/overview/what-is-cloud-run
55 changes: 49 additions & 6 deletions docs/services.md
@@ -162,14 +162,57 @@ $ gcloud run jobs execute embeddings --region us-west1
## Reddit tool / agent

The [Reddit tool](../services/reddit-tool/) allows the LLM to read [r/travel][subreddit] posts based
upon a user query. The tool is packaged as a Vertex AI [Reasoning Engine agent][reasoning]. Internally,
the tool uses [LangChain][langchain] along with the Vertex AI Python SDK to perform its
magic.
upon a user query. The tool is packaged as a Vertex AI [Reasoning Engine agent][reasoning].
Internally, the tool uses [LangChain][langchain] along with the Vertex AI Python
SDK to perform its magic.

### Deploy the agent
**WARNING**: As of writing (2024-11-26), the Vertex AI Reasoning Engine agent
doesn't work as intended. Instead, the agent is published to Cloud Functions.

**NOTE**: You might need to install `pyenv` first before completing these instructions.
See [Troubleshooting](./troubleshooting.md) for more details.
### Test the agent locally (Cloud Functions)

1. Run the Cloud Function locally.

```sh
functions-framework-python --target get_agent_request
```

1. Send a request to the app with `curl`.

```sh
curl --header "Content-Type: application/json" \
  --request POST \
  --data '{"query":"I want to go to Crete. Where should I stay?"}' \
  http://localhost:8080
```

Deployed location:
https://reddit-tool-1025771077852.us-west1.run.app
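
For completeness, the same request can be built programmatically. The following Go sketch mirrors the `curl` call above; the `agentQuery` type and `buildAgentRequest` helper are illustrative names, not part of this PR:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// agentQuery mirrors the JSON payload the Cloud Function expects,
// e.g. {"query":"I want to go to Crete. Where should I stay?"}.
type agentQuery struct {
	Query string `json:"query"`
}

// buildAgentRequest constructs the POST request sent to the locally
// running function. The endpoint URL is taken from the curl example.
func buildAgentRequest(url, query string) (*http.Request, error) {
	body, err := json.Marshal(agentQuery{Query: query})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	return req, nil
}

func main() {
	req, err := buildAgentRequest("http://localhost:8080",
		"I want to go to Crete. Where should I stay?")
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Method, req.URL, req.Header.Get("Content-Type"))
}
```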

### Deploy the agent (Cloud Functions)

Run the following from the root of the reddit-tool directory.

```sh
gcloud functions deploy reddit-tool \
  --gen2 \
  --memory=512MB \
  --timeout=120s \
  --runtime=python312 \
  --region=us-west1 \
  --set-env-vars PROJECT_ID=${PROJECT_ID},BUCKET=${BUCKET} \
  --source=. \
  --entry-point=get_agent_request \
  --trigger-http \
  --allow-unauthenticated
```

### Deploy the agent (Reasoning Engine)

**NOTE**: You might need to install `pyenv` first before completing these
instructions. See [Troubleshooting](./troubleshooting.md) for more details.

1. Create a virtual environment. The virtual environment must use a Python version between v3.6 and v3.11 (inclusive).

49 changes: 49 additions & 0 deletions server/ai/reddit.go
@@ -0,0 +1,49 @@
package ai

import (
	"context"
	"fmt"

	"github.com/vartanbeno/go-reddit/v2/reddit"
)

const subredditName = "travel"

func getRedditPosts(location string) (string, error) {
	client, err := reddit.NewReadonlyClient()
	if err != nil {
		return "", err
	}

	ctx := context.Background()
	posts, _, err := client.Subreddit.SearchPosts(ctx, location, subredditName, &reddit.ListPostSearchOptions{
		ListPostOptions: reddit.ListPostOptions{
			ListOptions: reddit.ListOptions{
				Limit: 5,
			},
			Time: "all",
		},
	})
	if err != nil {
		return "", err
	}

	response := ""

	for _, post := range posts {
		if post.Body == "" {
			continue
		}

		// Fall back to title and body alone when comments can't be fetched
		// or the post has none; Comments[0] would otherwise panic.
		postAndComments, _, err := client.Post.Get(ctx, post.ID)
		if err != nil || len(postAndComments.Comments) == 0 {
			response += fmt.Sprintf("Title: %s, Post: %s\n",
				post.Title, post.Body)
			continue
		}

		response += fmt.Sprintf("Title: %s, Post: %s, Top Comment: %s\n",
			post.Title, post.Body, postAndComments.Comments[0].Body)
	}

	return response, nil
}
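
The string assembly inside the loop above can be isolated into a small pure helper, which makes the formatting easy to unit test. This is an illustrative sketch (`formatPost` is not a name from this PR):

```go
package main

import "fmt"

// formatPost flattens one Reddit post into the line of text handed back
// to the model; topComment may be empty when no comments were fetched.
func formatPost(title, body, topComment string) string {
	if topComment == "" {
		return fmt.Sprintf("Title: %s, Post: %s\n", title, body)
	}
	return fmt.Sprintf("Title: %s, Post: %s, Top Comment: %s\n",
		title, body, topComment)
}

func main() {
	fmt.Print(formatPost("Crete tips", "Stay in Chania.", "Agreed!"))
}
```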
148 changes: 110 additions & 38 deletions server/ai/vertex.go
@@ -79,6 +79,8 @@ func Predict(query, modality, projectID string) (response string, templateName s
response, err = textPredictGemma(query, projectID)
case GeminiTuned:
response, err = textPredictGemini(query, projectID, GeminiTuned)
case AgentAssisted:
response, err = textPredictWithReddit(query, projectID)
default:
response, err = textPredictGemini(query, projectID, Gemini)
}
@@ -125,6 +127,51 @@ func SetConversationContext(convoHistory []generated.ConversationBit) error {
return nil
}

// storeConversationContext uploads past user conversations with the model into a Gen AI context.
// This context is used when the model is answering questions from the user.
func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
if len(conversationHistory) < MinimumConversationNum {
return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
}

ctx := context.Background()
location := "us-west1"
client, err := genai.NewClient(ctx, projectID, location)
if err != nil {
return "", fmt.Errorf("unable to create client: %w", err)
}
defer client.Close()

var userParts []genai.Part
var modelParts []genai.Part
for _, p := range conversationHistory {
userParts = append(userParts, genai.Text(p.UserQuery))
modelParts = append(modelParts, genai.Text(p.BotResponse))
}

content := &genai.CachedContent{
Model: GeminiModel,
Expiration: genai.ExpireTimeOrTTL{TTL: 60 * time.Minute},
Contents: []*genai.Content{
{
Role: "user",
Parts: userParts,
},
{
Role: "model",
Parts: modelParts,
},
},
}
result, err := client.CreateCachedContent(ctx, content)
if err != nil {
return "", fmt.Errorf("CreateCachedContent: %w", err)
}
resourceName := result.Name

return resourceName, nil
}

// extractAnswer cleans up the response returned from the models
func extractAnswer(response string) string {
// I am not a regex expert :/
@@ -271,57 +318,82 @@ func getCandidate(resp *genai.GenerateContentResponse) (string, error) {
return string(candidate), nil
}

// storeConversationContext uploads past user conversations with the model into a Gen AI context.
// This context is used when the model is answering questions from the user.
func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
if len(conversationHistory) < MinimumConversationNum {
return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
func trimContext() (last string) {
sep := "###"
convos := strings.Split(cachedContext, sep)
length := len(convos)
if len(convos) > 3 {
last = strings.Join(convos[length-3:length-1], sep)
}
return last
}

func textPredictWithReddit(query, projectID string) (string, error) {
funcName := "GetRedditPosts"
ctx := context.Background()
location := "us-west1"
client, err := genai.NewClient(ctx, projectID, location)
client, err := genai.NewClient(ctx, projectID, "us-west1")
if err != nil {
return "", fmt.Errorf("unable to create client: %w", err)
return "", err
}
defer client.Close()

var userParts []genai.Part
var modelParts []genai.Part
for _, p := range conversationHistory {
userParts = append(userParts, genai.Text(p.UserQuery))
modelParts = append(modelParts, genai.Text(p.BotResponse))
}

content := &genai.CachedContent{
Model: GeminiModel,
Expiration: genai.ExpireTimeOrTTL{TTL: 60 * time.Minute},
Contents: []*genai.Content{
{
Role: "user",
Parts: userParts,
},
{
Role: "model",
Parts: modelParts,
schema := &genai.Schema{
Type: genai.TypeObject,
Properties: map[string]*genai.Schema{
"location": {
Type: genai.TypeString,
Description: "the place the user wants to go, e.g. Crete, Greece",
},
},
Required: []string{"location"},
}
result, err := client.CreateCachedContent(ctx, content)

redditTool := &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{{
Name: funcName,
Description: "Get Reddit posts about a location from the Travel subreddit",
Parameters: schema,
}},
}

model := client.GenerativeModel(GeminiModel)
model.Tools = []*genai.Tool{redditTool}

session := model.StartChat()

res, err := session.SendMessage(ctx, genai.Text(query))
if err != nil {
return "", fmt.Errorf("CreateCachedContent: %w", err)
return "", err
}
resourceName := result.Name

return resourceName, nil
}
part := res.Candidates[0].Content.Parts[0]
funcCall, ok := part.(genai.FunctionCall)
if !ok {
return "", fmt.Errorf("expected function call: %v", part)
}
if funcCall.Name != funcName {
return "", fmt.Errorf("expected %s, got: %v", funcName, funcCall.Name)
}
locArg, ok := funcCall.Args["location"].(string)
if !ok {
return "", fmt.Errorf("expected string, got: %v", funcCall.Args["location"])
}

func trimContext() (last string) {
sep := "###"
convos := strings.Split(cachedContext, sep)
length := len(convos)
if len(convos) > 3 {
last = strings.Join(convos[length-3:length-1], sep)
redditData, err := getRedditPosts(locArg)
if err != nil {
return "", err
}
return last

res, err = session.SendMessage(ctx, genai.FunctionResponse{
Name: redditTool.FunctionDeclarations[0].Name,
Response: map[string]any{
"output": redditData,
},
})
if err != nil {
return "", err
}

output := string(res.Candidates[0].Content.Parts[0].(genai.Text))
return output, nil
}
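
The windowing logic in `trimContext` depends on the package-level `cachedContext`; rewritten as a pure function it is easier to reason about and test. A sketch (the `lastConversations` name is illustrative):

```go
package main

import (
	"fmt"
	"strings"
)

// lastConversations mirrors trimContext: split the cached context on the
// "###" separator and keep only the two most recent completed
// conversations. The final element after Split is typically a trailing
// fragment, so it is excluded; fewer than four elements yields "".
func lastConversations(cached, sep string) string {
	convos := strings.Split(cached, sep)
	n := len(convos)
	if n <= 3 {
		return ""
	}
	return strings.Join(convos[n-3:n-1], sep)
}

func main() {
	fmt.Println(lastConversations("a###b###c###d###", "###"))
}
```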
6 changes: 5 additions & 1 deletion server/go.mod
@@ -15,7 +15,11 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.4
)

require cloud.google.com/go/longrunning v0.6.1 // indirect
require (
cloud.google.com/go/longrunning v0.6.1 // indirect
github.com/google/go-querystring v1.0.0 // indirect
github.com/vartanbeno/go-reddit/v2 v2.0.1 // indirect
)

require (
cloud.google.com/go v0.116.0 // indirect