Skip to content

Commit

Permalink
internal/llmapp: allow policy checker in overview functions
Browse files Browse the repository at this point in the history
Add an optional policy checker to the overviews client. When a
policy checker is configured, all LLM inputs and outputs will be checked
for safety against the configured policy.

Not yet used by Gaby or anywhere else.

For #70

Change-Id: I8d48048eae9651499ec937a8804ab554baca2316
Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637977
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
  • Loading branch information
tatianab committed Dec 20, 2024
1 parent 158f50b commit a315118
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 8 deletions.
84 changes: 84 additions & 0 deletions internal/llmapp/check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package llmapp

import (
"context"
"log/slog"

"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
)

// NewWithChecker returns a Client configured like one built by [New],
// with the addition of a policy checker used to vet the inputs to and
// outputs of the LLM against safety policies.
//
// Each call to one of the Overview functions runs the checker over the
// prompts sent to the LLM and over the output it produces.
// A nil checker disables policy checking.
func NewWithChecker(lg *slog.Logger, g llm.ContentGenerator, checker llm.PolicyChecker, db storage.DB) *Client {
	client := new(Client)
	client.slog = lg
	client.g = g
	client.checker = checker
	client.db = db
	return client
}

// hasPolicyViolation reports whether the policy checker flags any of the
// given prompts or the LLM output, logging the outcome of every check.
//
// It returns false immediately when no checker is configured.
// Only [llm.Text] prompt parts are checked individually; parts of any
// other type are logged and skipped. The output is checked together
// with the complete list of prompts.
// TODO(tatianabradley): Cache calls to policy checker.
func (c *Client) hasPolicyViolation(ctx context.Context, prompts []llm.Part, output string) bool {
	if c.checker == nil {
		return false
	}
	promptFlagged := false
	for _, part := range prompts {
		text, isText := part.(llm.Text)
		if !isText {
			// Other types are not supported for checks yet.
			c.slog.Info("llmapp: can't check policy for prompt part (unsupported type)", "prompt part", part)
			continue
		}
		// Call logCheck first so that every text prompt is checked and
		// logged, even after an earlier prompt was already flagged.
		promptFlagged = c.logCheck(ctx, string(text), nil) || promptFlagged
	}
	// The output is always checked, regardless of the prompt results.
	outputFlagged := c.logCheck(ctx, output, prompts)
	return outputFlagged || promptFlagged
}

// logCheck runs the policy checker on the given text (with optional prompts)
// and logs what it finds.
// It reports whether any policy violations were found. If the checker
// itself returns an error, the error is logged and logCheck reports false.
func (c *Client) logCheck(ctx context.Context, text string, prompts []llm.Part) bool {
	results, err := c.checker.CheckText(ctx, text, prompts...)
	if err != nil {
		c.slog.Error("llmapp: error checking for policy violations", "err", err)
		return false
	}
	c.slog.Info("llmapp: found policy results", "text", text, "prompts", prompts, "results", toStrings(results))

	flagged := violations(results)
	if len(flagged) == 0 {
		return false
	}
	c.slog.Warn("llmapp: found policy violations for LLM output", "text", text, "prompts", prompts, "violations", toStrings(flagged))
	return true
}

// toStrings returns the String form of each policy result in prs,
// in the same order. It returns nil when prs is empty.
func toStrings(prs []*llm.PolicyResult) []string {
	if len(prs) == 0 {
		return nil
	}
	out := make([]string, len(prs))
	for i := range prs {
		out[i] = prs[i].String()
	}
	return out
}

// violations filters prs down to the results that are in violation,
// preserving their order. It returns nil when none are violative.
func violations(prs []*llm.PolicyResult) []*llm.PolicyResult {
	var flagged []*llm.PolicyResult
	for i := range prs {
		if !prs[i].IsViolative() {
			continue
		}
		flagged = append(flagged, prs[i])
	}
	return flagged
}
69 changes: 69 additions & 0 deletions internal/llmapp/check_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package llmapp

import (
"context"
"strings"
"testing"

"golang.org/x/oscar/internal/llm"
"golang.org/x/oscar/internal/storage"
"golang.org/x/oscar/internal/testutil"
)

func TestWithChecker(t *testing.T) {
lg := testutil.Slogger(t)
g := llm.EchoContentGenerator()
db := storage.MemDB()
checker := badChecker{}
c := NewWithChecker(lg, g, checker, db)

// With violation.
doc1 := &Doc{URL: "https://example.com", Author: "rsc", Title: "title", Text: "some bad text"}
doc2 := &Doc{Text: "some good text 2"}
r, err := c.Overview(context.Background(), doc1, doc2)
if err != nil {
t.Fatal(err)
}
if !r.HasPolicyViolation {
t.Errorf("c.Overview.HasPolicyViolation = false, want true")
}

// Without violation.
r, err = c.Overview(context.Background(), doc2)
if err != nil {
t.Fatal(err)
}
if r.HasPolicyViolation {
t.Errorf("c.Overview.HasPolicyViolation = true, want false")
}
}

// badChecker is a test implementation of [llm.PolicyChecker].
// Its CheckText method reports a policy violation for any text
// containing the string "bad", and no violation otherwise.
// It is an empty struct and carries no state.
type badChecker struct{}

// SetPolicies is a no-op: badChecker ignores any configured policies.
func (badChecker) SetPolicies([]*llm.PolicyConfig) {}

// CheckText returns a single dangerous-content policy result for text.
// The result is violative when text contains "bad" and non-violative
// for any other text. The context and prompts are ignored, and the
// returned error is always nil.
func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) {
	res := &llm.PolicyResult{
		PolicyType:      llm.PolicyTypeDangerousContent,
		ViolationResult: llm.ViolationResultNonViolative,
	}
	if strings.Contains(text, "bad") {
		res.ViolationResult = llm.ViolationResultViolative
	}
	return []*llm.PolicyResult{res}, nil
}
2 changes: 2 additions & 0 deletions internal/llmapp/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ type Result struct {
Cached bool // whether the response was cached
Schema *llm.Schema // the JSON schema used to generate the result (nil if none)
Prompt []llm.Part // the prompt(s) used to generate the result
// TODO(tatianabradley): Store the specific policy results instead of just a boolean.
HasPolicyViolation bool // whether any policy violations were found for the inputs or outputs of the LLM
}
18 changes: 10 additions & 8 deletions internal/llmapp/overview.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ import (

// Client is a client for accessing the LLM application functionality.
type Client struct {
slog *slog.Logger
g llm.ContentGenerator
db storage.DB // cache for LLM responses
slog *slog.Logger
g llm.ContentGenerator
checker llm.PolicyChecker
db storage.DB // cache for LLM responses
}

// New returns a new client.
// g is the underlying LLM content generator to use, and db is the database
// to use as a cache.
func New(lg *slog.Logger, g llm.ContentGenerator, db storage.DB) *Client {
return &Client{slog: lg, g: g, db: db}
return NewWithChecker(lg, g, nil, db)
}

// Overview returns an LLM-generated overview of the given documents,
Expand Down Expand Up @@ -101,10 +102,11 @@ func (c *Client) overview(ctx context.Context, kind docsKind, groups ...*docGrou
return nil, err
}
return &Result{
Response: overview,
Cached: cached,
Schema: schema,
Prompt: prompt,
Response: overview,
Cached: cached,
Schema: schema,
Prompt: prompt,
HasPolicyViolation: c.hasPolicyViolation(ctx, prompt, overview),
}, nil
}

Expand Down

0 comments on commit a315118

Please sign in to comment.