-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
internal/llmapp: allow policy checker in overview functions
Add an optional policy checker to the overviews client. When a policy checker is configured, all LLM inputs and outputs will be checked for safety against the configured policy. Not yet used by Gaby or anywhere else. For #70 Change-Id: I8d48048eae9651499ec937a8804ab554baca2316 Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637977 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Hyang-Ah Hana Kim <hyangah@gmail.com>
- Loading branch information
Showing
4 changed files
with
165 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright 2024 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package llmapp | ||
|
||
import ( | ||
"context" | ||
"log/slog" | ||
|
||
"golang.org/x/oscar/internal/llm" | ||
"golang.org/x/oscar/internal/storage" | ||
) | ||
|
||
// NewWithChecker is like [New], but it configures the Client to use | ||
// the given checker to check the inputs to and outputs of the LLM against | ||
// safety policies. | ||
// | ||
// When any of the Overview functions are called, the prompts and outputs of the LLM | ||
// will be checked for safety violations. | ||
func NewWithChecker(lg *slog.Logger, g llm.ContentGenerator, checker llm.PolicyChecker, db storage.DB) *Client { | ||
return &Client{slog: lg, g: g, checker: checker, db: db} | ||
} | ||
|
||
// hasPolicyViolation invokes the policy checker on the given prompts and LLM output and | ||
// logs its results. It reports whether any policy violations were found. | ||
// TODO(tatianabradley): Cache calls to policy checker. | ||
func (c *Client) hasPolicyViolation(ctx context.Context, prompts []llm.Part, output string) bool { | ||
if c.checker == nil { | ||
return false | ||
} | ||
foundViolation := false | ||
for _, p := range prompts { | ||
switch v := p.(type) { | ||
case llm.Text: | ||
if c.logCheck(ctx, string(v), nil) { | ||
foundViolation = true | ||
} | ||
default: | ||
// Other types are not supported for checks yet. | ||
c.slog.Info("llmapp: can't check policy for prompt part (unsupported type)", "prompt part", v) | ||
} | ||
} | ||
if c.logCheck(ctx, output, prompts) { | ||
return true | ||
} | ||
return foundViolation | ||
} | ||
|
||
// logCheck invokes the policy checker on the give text (with optional prompts) | ||
// and logs its results. | ||
// It reports whether any policy violations were found. | ||
func (c *Client) logCheck(ctx context.Context, text string, prompts []llm.Part) bool { | ||
prs, err := c.checker.CheckText(ctx, text, prompts...) | ||
if err != nil { | ||
c.slog.Error("llmapp: error checking for policy violations", "err", err) | ||
return false | ||
} | ||
c.slog.Info("llmapp: found policy results", "text", text, "prompts", prompts, "results", toStrings(prs)) | ||
if vs := violations(prs); len(vs) > 0 { | ||
c.slog.Warn("llmapp: found policy violations for LLM output", "text", text, "prompts", prompts, "violations", toStrings(vs)) | ||
return true | ||
} | ||
return false | ||
} | ||
|
||
func toStrings(prs []*llm.PolicyResult) []string { | ||
var ss []string | ||
for _, pr := range prs { | ||
ss = append(ss, pr.String()) | ||
} | ||
return ss | ||
} | ||
|
||
// violations returns the policies in prs that are in violation. | ||
func violations(prs []*llm.PolicyResult) []*llm.PolicyResult { | ||
var vs []*llm.PolicyResult | ||
for _, pr := range prs { | ||
if pr.IsViolative() { | ||
vs = append(vs, pr) | ||
} | ||
} | ||
return vs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright 2024 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package llmapp | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"testing" | ||
|
||
"golang.org/x/oscar/internal/llm" | ||
"golang.org/x/oscar/internal/storage" | ||
"golang.org/x/oscar/internal/testutil" | ||
) | ||
|
||
func TestWithChecker(t *testing.T) { | ||
lg := testutil.Slogger(t) | ||
g := llm.EchoContentGenerator() | ||
db := storage.MemDB() | ||
checker := badChecker{} | ||
c := NewWithChecker(lg, g, checker, db) | ||
|
||
// With violation. | ||
doc1 := &Doc{URL: "https://example.com", Author: "rsc", Title: "title", Text: "some bad text"} | ||
doc2 := &Doc{Text: "some good text 2"} | ||
r, err := c.Overview(context.Background(), doc1, doc2) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if !r.HasPolicyViolation { | ||
t.Errorf("c.Overview.HasPolicyViolation = false, want true") | ||
} | ||
|
||
// Without violation. | ||
r, err = c.Overview(context.Background(), doc2) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if r.HasPolicyViolation { | ||
t.Errorf("c.Overview.HasPolicyViolation = true, want false") | ||
} | ||
} | ||
|
||
// badChecker is a test implementation of [llm.PolicyChecker] that | ||
// always returns a policy violation for text containing the string "bad", | ||
// and no violations otherwise. | ||
type badChecker struct{} | ||
|
||
// no-op | ||
func (badChecker) SetPolicies(_ []*llm.PolicyConfig) {} | ||
|
||
// return violation for text containing "bad" and no violation for any other text. | ||
func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) { | ||
if strings.Contains(text, "bad") { | ||
return []*llm.PolicyResult{ | ||
{ | ||
PolicyType: llm.PolicyTypeDangerousContent, | ||
ViolationResult: llm.ViolationResultViolative, | ||
}, | ||
}, nil | ||
} | ||
return []*llm.PolicyResult{ | ||
{ | ||
PolicyType: llm.PolicyTypeDangerousContent, | ||
ViolationResult: llm.ViolationResultNonViolative, | ||
}, | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters