From a31511835b17c559601780230331aa92875bd275 Mon Sep 17 00:00:00 2001 From: Tatiana Bradley Date: Thu, 19 Dec 2024 14:12:12 -1000 Subject: [PATCH] internal/llmapp: allow policy checker in overview functions Add an optional policy checker to the overviews client. When a policy checker is configured, all LLM inputs and outputs will be checked for safety against the configured policy. Not yet used by Gaby or anywhere else. For golang/oscar#70 Change-Id: I8d48048eae9651499ec937a8804ab554baca2316 Reviewed-on: https://go-review.googlesource.com/c/oscar/+/637977 LUCI-TryBot-Result: Go LUCI Reviewed-by: Hyang-Ah Hana Kim --- internal/llmapp/check.go | 84 +++++++++++++++++++++++++++++++++++ internal/llmapp/check_test.go | 69 ++++++++++++++++++++++++++++ internal/llmapp/data.go | 2 + internal/llmapp/overview.go | 18 ++++---- 4 files changed, 165 insertions(+), 8 deletions(-) create mode 100644 internal/llmapp/check.go create mode 100644 internal/llmapp/check_test.go diff --git a/internal/llmapp/check.go b/internal/llmapp/check.go new file mode 100644 index 0000000..03d1fa1 --- /dev/null +++ b/internal/llmapp/check.go @@ -0,0 +1,84 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package llmapp + +import ( + "context" + "log/slog" + + "golang.org/x/oscar/internal/llm" + "golang.org/x/oscar/internal/storage" +) + +// NewWithChecker is like [New], but it configures the Client to use +// the given checker to check the inputs to and outputs of the LLM against +// safety policies. +// +// When any of the Overview functions are called, the prompts and outputs of the LLM +// will be checked for safety violations. 
+func NewWithChecker(lg *slog.Logger, g llm.ContentGenerator, checker llm.PolicyChecker, db storage.DB) *Client { + return &Client{slog: lg, g: g, checker: checker, db: db} +} + +// hasPolicyViolation invokes the policy checker on the given prompts and LLM output and +// logs its results. It reports whether any policy violations were found. +// TODO(tatianabradley): Cache calls to policy checker. +func (c *Client) hasPolicyViolation(ctx context.Context, prompts []llm.Part, output string) bool { + if c.checker == nil { + return false + } + foundViolation := false + for _, p := range prompts { + switch v := p.(type) { + case llm.Text: + if c.logCheck(ctx, string(v), nil) { + foundViolation = true + } + default: + // Other types are not supported for checks yet. + c.slog.Info("llmapp: can't check policy for prompt part (unsupported type)", "prompt part", v) + } + } + if c.logCheck(ctx, output, prompts) { + return true + } + return foundViolation +} + +// logCheck invokes the policy checker on the given text (with optional prompts) +// and logs its results. +// It reports whether any policy violations were found. +func (c *Client) logCheck(ctx context.Context, text string, prompts []llm.Part) bool { + prs, err := c.checker.CheckText(ctx, text, prompts...) + if err != nil { + c.slog.Error("llmapp: error checking for policy violations", "err", err) + return false + } + c.slog.Info("llmapp: found policy results", "text", text, "prompts", prompts, "results", toStrings(prs)) + if vs := violations(prs); len(vs) > 0 { + c.slog.Warn("llmapp: found policy violations for LLM output", "text", text, "prompts", prompts, "violations", toStrings(vs)) + return true + } + return false +} + +func toStrings(prs []*llm.PolicyResult) []string { + var ss []string + for _, pr := range prs { + ss = append(ss, pr.String()) + } + return ss +} + +// violations returns the policies in prs that are in violation. 
+func violations(prs []*llm.PolicyResult) []*llm.PolicyResult { + var vs []*llm.PolicyResult + for _, pr := range prs { + if pr.IsViolative() { + vs = append(vs, pr) + } + } + return vs +} diff --git a/internal/llmapp/check_test.go b/internal/llmapp/check_test.go new file mode 100644 index 0000000..e78c937 --- /dev/null +++ b/internal/llmapp/check_test.go @@ -0,0 +1,69 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package llmapp + +import ( + "context" + "strings" + "testing" + + "golang.org/x/oscar/internal/llm" + "golang.org/x/oscar/internal/storage" + "golang.org/x/oscar/internal/testutil" +) + +func TestWithChecker(t *testing.T) { + lg := testutil.Slogger(t) + g := llm.EchoContentGenerator() + db := storage.MemDB() + checker := badChecker{} + c := NewWithChecker(lg, g, checker, db) + + // With violation. + doc1 := &Doc{URL: "https://example.com", Author: "rsc", Title: "title", Text: "some bad text"} + doc2 := &Doc{Text: "some good text 2"} + r, err := c.Overview(context.Background(), doc1, doc2) + if err != nil { + t.Fatal(err) + } + if !r.HasPolicyViolation { + t.Errorf("c.Overview.HasPolicyViolation = false, want true") + } + + // Without violation. + r, err = c.Overview(context.Background(), doc2) + if err != nil { + t.Fatal(err) + } + if r.HasPolicyViolation { + t.Errorf("c.Overview.HasPolicyViolation = true, want false") + } +} + +// badChecker is a test implementation of [llm.PolicyChecker] that +// always returns a policy violation for text containing the string "bad", +// and no violations otherwise. +type badChecker struct{} + +// no-op +func (badChecker) SetPolicies(_ []*llm.PolicyConfig) {} + +// return violation for text containing "bad" and no violation for any other text. 
+func (badChecker) CheckText(_ context.Context, text string, prompts ...llm.Part) ([]*llm.PolicyResult, error) { + if strings.Contains(text, "bad") { + return []*llm.PolicyResult{ + { + PolicyType: llm.PolicyTypeDangerousContent, + ViolationResult: llm.ViolationResultViolative, + }, + }, nil + } + return []*llm.PolicyResult{ + { + PolicyType: llm.PolicyTypeDangerousContent, + ViolationResult: llm.ViolationResultNonViolative, + }, + }, nil +} diff --git a/internal/llmapp/data.go b/internal/llmapp/data.go index f228536..30c877d 100644 --- a/internal/llmapp/data.go +++ b/internal/llmapp/data.go @@ -27,4 +27,6 @@ type Result struct { Cached bool // whether the response was cached Schema *llm.Schema // the JSON schema used to generate the result (nil if none) Prompt []llm.Part // the prompt(s) used to generate the result + // TODO(tatianabradley): Store the specific policy results instead of just a boolean. + HasPolicyViolation bool // whether any policy violations were found for the inputs or outputs of the LLM } diff --git a/internal/llmapp/overview.go b/internal/llmapp/overview.go index 7c2f4ec..8586c79 100644 --- a/internal/llmapp/overview.go +++ b/internal/llmapp/overview.go @@ -33,16 +33,17 @@ import ( // Client is a client for accessing the LLM application functionality. type Client struct { - slog *slog.Logger - g llm.ContentGenerator - db storage.DB // cache for LLM responses + slog *slog.Logger + g llm.ContentGenerator + checker llm.PolicyChecker + db storage.DB // cache for LLM responses } // New returns a new client. // g is the underlying LLM content generator to use, and db is the database // to use as a cache. 
func New(lg *slog.Logger, g llm.ContentGenerator, db storage.DB) *Client { - return &Client{slog: lg, g: g, db: db} + return NewWithChecker(lg, g, nil, db) } // Overview returns an LLM-generated overview of the given documents, @@ -101,10 +102,11 @@ func (c *Client) overview(ctx context.Context, kind docsKind, groups ...*docGrou return nil, err } return &Result{ - Response: overview, - Cached: cached, - Schema: schema, - Prompt: prompt, + Response: overview, + Cached: cached, + Schema: schema, + Prompt: prompt, + HasPolicyViolation: c.hasPolicyViolation(ctx, prompt, overview), }, nil }