Skip to content

Commit 6ab8141

Browse files
committed
[Go] align Embedder API with JS
- An EmbedRequest takes a slice of Documents instead of a single Document.
- An EmbedResponse contains embeddings for each document. The []float32 containing the embedding is inside a struct, to accommodate future additions (and to match the JS).
- The googleai embedder works on multiple documents sequentially. It should be changed to use the BatchEmbed RPC.
- The vertexai embedder always handled multiple "instances". Now an instance is the concatenated text parts of a document; before it was one text part of the sole document. (This is the only behavioral change.)

There is one unrelated change: the prompt of a generation test was changed because the previous prompt is now blocked for the "recitation" reason.
1 parent 19b88d5 commit 6ab8141

File tree

10 files changed

+134
-97
lines changed

10 files changed

+134
-97
lines changed

go/ai/embedder.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,44 @@ import (
2323

2424
// An Embedder is used to convert a document to a
2525
// multidimensional vector.
26-
type Embedder core.Action[*EmbedRequest, []float32, struct{}]
26+
type Embedder core.Action[*EmbedRequest, *EmbedResponse, struct{}]
2727

28-
// EmbedRequest is the data we pass to convert a document
28+
// EmbedRequest is the data we pass to convert one or more documents
2929
// to a multidimensional vector.
3030
type EmbedRequest struct {
31-
Document *Document `json:"input"`
32-
Options any `json:"options,omitempty"`
31+
Documents []*Document `json:"input"`
32+
Options any `json:"options,omitempty"`
33+
}
34+
35+
type EmbedResponse struct {
36+
// One embedding for each Document in the request, in the same order.
37+
Embeddings []*DocumentEmbedding `json:"embeddings"`
38+
}
39+
40+
// DocumentEmbedding holds embedding information about a single document.
41+
type DocumentEmbedding struct {
42+
// The vector for the embedding.
43+
Embedding []float32 `json:"embedding"`
3344
}
3445

3546
// DefineEmbedder registers the given embed function as an action, and returns an
3647
// [Embedder] that runs it.
37-
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) ([]float32, error)) *Embedder {
48+
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) (*EmbedResponse, error)) *Embedder {
3849
return (*Embedder)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
3950
}
4051

4152
// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
4253
// It returns nil if the embedder was not defined.
4354
func LookupEmbedder(provider, name string) *Embedder {
44-
action := core.LookupActionFor[*EmbedRequest, []float32, struct{}](atype.Embedder, provider, name)
55+
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](atype.Embedder, provider, name)
4556
if action == nil {
4657
return nil
4758
}
4859
return (*Embedder)(action)
4960
}
5061

5162
// Embed runs the given [Embedder].
52-
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) ([]float32, error) {
53-
a := (*core.Action[*EmbedRequest, []float32, struct{}])(e)
63+
func (e *Embedder) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
64+
a := (*core.Action[*EmbedRequest, *EmbedResponse, struct{}])(e)
5465
return a.Run(ctx, req, nil)
5566
}

go/internal/fakeembedder/fakeembedder.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,14 @@ func (e *Embedder) Register(d *ai.Document, vals []float32) {
4343
e.registry[d] = vals
4444
}
4545

46-
func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
47-
vals, ok := e.registry[req.Document]
48-
if !ok {
49-
return nil, errors.New("fake embedder called with unregistered document")
46+
func (e *Embedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
47+
res := &ai.EmbedResponse{}
48+
for _, doc := range req.Documents {
49+
vals, ok := e.registry[doc]
50+
if !ok {
51+
return nil, errors.New("fake embedder called with unregistered document")
52+
}
53+
res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: vals})
5054
}
51-
return vals, nil
55+
return res, nil
5256
}

go/internal/fakeembedder/fakeembedder_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,19 @@ func TestFakeEmbedder(t *testing.T) {
3131
embed.Register(d, vals)
3232

3333
req := &ai.EmbedRequest{
34-
Document: d,
34+
Documents: []*ai.Document{d},
3535
}
3636
ctx := context.Background()
37-
got, err := emb.Embed(ctx, req)
37+
res, err := emb.Embed(ctx, req)
3838
if err != nil {
3939
t.Fatal(err)
4040
}
41+
got := res.Embeddings[0].Embedding
4142
if !slices.Equal(got, vals) {
4243
t.Errorf("lookup returned %v, want %v", got, vals)
4344
}
4445

45-
req.Document = ai.DocumentFromText("missing document", nil)
46+
req.Documents[0] = ai.DocumentFromText("missing document", nil)
4647
if _, err = emb.Embed(ctx, req); err == nil {
4748
t.Error("embedding unknown document succeeded unexpectedly")
4849
}

go/plugins/googleai/googleai.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,22 @@ func DefineEmbedder(name string) *ai.Embedder {
162162

163163
// requires state.mu
164164
func defineEmbedder(name string) *ai.Embedder {
165-
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) ([]float32, error) {
165+
return ai.DefineEmbedder(provider, name, func(ctx context.Context, input *ai.EmbedRequest) (*ai.EmbedResponse, error) {
166+
// TODO: use the batch embedding API.
166167
em := state.client.EmbeddingModel(name)
167-
parts, err := convertParts(input.Document.Content)
168-
if err != nil {
169-
return nil, err
170-
}
171-
res, err := em.EmbedContent(ctx, parts...)
172-
if err != nil {
173-
return nil, err
168+
var res ai.EmbedResponse
169+
for _, doc := range input.Documents {
170+
parts, err := convertParts(doc.Content)
171+
if err != nil {
172+
return nil, err
173+
}
174+
eres, err := em.EmbedContent(ctx, parts...)
175+
if err != nil {
176+
return nil, err
177+
}
178+
res.Embeddings = append(res.Embeddings, &ai.DocumentEmbedding{Embedding: eres.Embedding.Values})
174179
}
175-
return res.Embedding.Values, nil
180+
return &res, nil
176181
})
177182
}
178183

go/plugins/googleai/googleai_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ func TestLive(t *testing.T) {
4646
if err != nil {
4747
t.Fatal(err)
4848
}
49-
embedder := googleai.DefineEmbedder("embedding-001")
50-
model, err := googleai.DefineModel("gemini-1.0-pro", nil)
49+
embedder := googleai.Embedder("embedding-001")
50+
model := googleai.Model("gemini-1.0-pro")
5151
if err != nil {
5252
t.Fatal(err)
5353
}
@@ -85,13 +85,13 @@ func TestLive(t *testing.T) {
8585
},
8686
)
8787
t.Run("embedder", func(t *testing.T) {
88-
out, err := embedder.Embed(ctx, &ai.EmbedRequest{
89-
Document: ai.DocumentFromText("yellow banana", nil),
88+
res, err := embedder.Embed(ctx, &ai.EmbedRequest{
89+
Documents: []*ai.Document{ai.DocumentFromText("yellow banana", nil)},
9090
})
9191
if err != nil {
9292
t.Fatal(err)
9393
}
94-
94+
out := res.Embeddings[0].Embedding
9595
// There's not a whole lot we can test about the result.
9696
// Just do a few sanity checks.
9797
if len(out) < 100 {
@@ -137,7 +137,7 @@ func TestLive(t *testing.T) {
137137
Candidates: 1,
138138
Messages: []*ai.Message{
139139
{
140-
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the Golden State Warriors.")},
140+
Content: []*ai.Part{ai.NewTextPart("Write one paragraph about the North Pole.")},
141141
Role: ai.RoleUser,
142142
},
143143
},
@@ -160,7 +160,7 @@ func TestLive(t *testing.T) {
160160
if out != out2 {
161161
t.Errorf("streaming and final should contain the same text.\nstreaming:%s\nfinal:%s", out, out2)
162162
}
163-
const want = "Golden"
163+
const want = "North"
164164
if !strings.Contains(out, want) {
165165
t.Errorf("got %q, expecting it to contain %q", out, want)
166166
}

go/plugins/localvec/localvec.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,19 @@ func newDocStore(dir, name string, embedder *ai.Embedder, embedderOptions any) (
119119

120120
// index indexes the documents in the request.
121121
func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
122-
for _, doc := range req.Documents {
123-
ereq := &ai.EmbedRequest{
124-
Document: doc,
125-
Options: ds.embedderOptions,
126-
}
127-
vals, err := ds.embedder.Embed(ctx, ereq)
128-
if err != nil {
129-
return fmt.Errorf("localvec index embedding failed: %v", err)
130-
}
131-
132-
id, err := docID(doc)
122+
ereq := &ai.EmbedRequest{
123+
Documents: req.Documents,
124+
Options: ds.embedderOptions,
125+
}
126+
eres, err := ds.embedder.Embed(ctx, ereq)
127+
if err != nil {
128+
return fmt.Errorf("localvec index embedding failed: %v", err)
129+
}
130+
for i, de := range eres.Embeddings {
131+
id, err := docID(req.Documents[i])
133132
if err != nil {
134133
return err
135134
}
136-
137135
if _, ok := ds.data[id]; ok {
138136
logger.FromContext(ctx).Debug("localvec skipping document because already present", "id", id)
139137
continue
@@ -144,8 +142,8 @@ func (ds *docStore) index(ctx context.Context, req *ai.IndexerRequest) error {
144142
}
145143

146144
ds.data[id] = dbValue{
147-
Doc: doc,
148-
Embedding: vals,
145+
Doc: req.Documents[i],
146+
Embedding: de.Embedding,
149147
}
150148
}
151149

@@ -183,13 +181,14 @@ func (ds *docStore) retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
183181
// Use the embedder to convert the document we want to
184182
// retrieve into a vector.
185183
ereq := &ai.EmbedRequest{
186-
Document: req.Document,
187-
Options: ds.embedderOptions,
184+
Documents: []*ai.Document{req.Document},
185+
Options: ds.embedderOptions,
188186
}
189-
vals, err := ds.embedder.Embed(ctx, ereq)
187+
eres, err := ds.embedder.Embed(ctx, ereq)
190188
if err != nil {
191189
return nil, fmt.Errorf("localvec retrieve embedding failed: %v", err)
192190
}
191+
vals := eres.Embeddings[0].Embedding
193192

194193
type scoredDoc struct {
195194
score float64

go/plugins/pinecone/genkit.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,16 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
185185

186186
// Use the embedder to convert each Document into a vector.
187187
vecs := make([]vector, 0, len(req.Documents))
188-
for _, doc := range req.Documents {
189-
ereq := &ai.EmbedRequest{
190-
Document: doc,
191-
Options: ds.embedderOptions,
192-
}
193-
vals, err := ds.embedder.Embed(ctx, ereq)
194-
if err != nil {
195-
return fmt.Errorf("pinecone index embedding failed: %v", err)
196-
}
197-
188+
ereq := &ai.EmbedRequest{
189+
Documents: req.Documents,
190+
Options: ds.embedderOptions,
191+
}
192+
eres, err := ds.embedder.Embed(ctx, ereq)
193+
if err != nil {
194+
return fmt.Errorf("pinecone index embedding failed: %v", err)
195+
}
196+
for i, de := range eres.Embeddings {
197+
doc := req.Documents[i]
198198
id, err := docID(doc)
199199
if err != nil {
200200
return err
@@ -216,7 +216,7 @@ func (ds *docStore) Index(ctx context.Context, req *ai.IndexerRequest) error {
216216

217217
v := vector{
218218
ID: id,
219-
Values: vals,
219+
Values: de.Embedding,
220220
Metadata: metadata,
221221
}
222222
vecs = append(vecs, v)
@@ -282,15 +282,15 @@ func (ds *docStore) Retrieve(ctx context.Context, req *ai.RetrieverRequest) (*ai
282282
// Use the embedder to convert the document we want to
283283
// retrieve into a vector.
284284
ereq := &ai.EmbedRequest{
285-
Document: req.Document,
286-
Options: ds.embedderOptions,
285+
Documents: []*ai.Document{req.Document},
286+
Options: ds.embedderOptions,
287287
}
288-
vals, err := ds.embedder.Embed(ctx, ereq)
288+
eres, err := ds.embedder.Embed(ctx, ereq)
289289
if err != nil {
290290
return nil, fmt.Errorf("pinecone retrieve embedding failed: %v", err)
291291
}
292292

293-
results, err := ds.index.query(ctx, vals, count, wantMetadata, namespace)
293+
results, err := ds.index.query(ctx, eres.Embeddings[0].Embedding, count, wantMetadata, namespace)
294294
if err != nil {
295295
return nil, err
296296
}

go/plugins/vertexai/embed.go

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ package vertexai
1616

1717
import (
1818
"context"
19-
"errors"
19+
"fmt"
20+
"strings"
2021

2122
aiplatform "cloud.google.com/go/aiplatform/apiv1"
2223
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
@@ -34,7 +35,7 @@ type EmbedOptions struct {
3435
TaskType string `json:"task_type,omitempty"`
3536
}
3637

37-
func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) ([]float32, error) {
38+
func embed(ctx context.Context, reqEndpoint string, client *aiplatform.PredictionClient, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
3839
preq, err := newPredictRequest(reqEndpoint, req)
3940
if err != nil {
4041
return nil, err
@@ -44,32 +45,34 @@ func embed(ctx context.Context, reqEndpoint string, client *aiplatform.Predictio
4445
return nil, err
4546
}
4647

47-
// TODO(ianlancetaylor): This can return multiple vectors.
48-
// We just use the first one for now.
49-
50-
if len(resp.Predictions) < 1 {
51-
return nil, errors.New("vertexai: embed request returned no values")
48+
if g, w := len(resp.Predictions), len(req.Documents); g != w {
49+
return nil, fmt.Errorf("vertexai: got %d embeddings, expected %d", g, w)
5250
}
5351

54-
values := resp.Predictions[0].GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
55-
ret := make([]float32, len(values))
56-
for i, value := range values {
57-
ret[i] = float32(value.GetNumberValue())
52+
ret := &ai.EmbedResponse{}
53+
for _, pred := range resp.Predictions {
54+
values := pred.GetStructValue().Fields["embeddings"].GetStructValue().Fields["values"].GetListValue().Values
55+
vals := make([]float32, len(values))
56+
for i, value := range values {
57+
vals[i] = float32(value.GetNumberValue())
58+
}
59+
ret.Embeddings = append(ret.Embeddings, &ai.DocumentEmbedding{Embedding: vals})
5860
}
59-
6061
return ret, nil
6162
}
6263

64+
// newPredictRequest creates a PredictRequest from an EmbedRequest.
65+
// Each Document in the EmbedRequest becomes a separate instance in the PredictRequest.
6366
func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.PredictRequest, error) {
6467
var title, taskType string
6568
if options, _ := req.Options.(*EmbedOptions); options != nil {
6669
title = options.Title
6770
taskType = options.TaskType
6871
}
69-
instances := make([]*structpb.Value, 0, len(req.Document.Content))
70-
for _, part := range req.Document.Content {
72+
instances := make([]*structpb.Value, 0, len(req.Documents))
73+
for _, doc := range req.Documents {
7174
fields := map[string]any{
72-
"content": part.Text,
75+
"content": text(doc),
7376
}
7477
if title != "" {
7578
fields["title"] = title
@@ -90,3 +93,15 @@ func newPredictRequest(endpoint string, req *ai.EmbedRequest) (*aiplatformpb.Pre
9093
Instances: instances,
9194
}, nil
9295
}
96+
97+
// text concatenates all the text parts of the document together,
98+
// with no delimiter.
99+
func text(d *ai.Document) string {
100+
var b strings.Builder
101+
for _, p := range d.Content {
102+
if p.IsText() {
103+
b.WriteString(p.Text)
104+
}
105+
}
106+
return b.String()
107+
}

go/plugins/vertexai/vertexai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ func DefineEmbedder(name string) *ai.Embedder {
177177
panic("vertexai.Init not called")
178178
}
179179
fullName := fmt.Sprintf("projects/%s/locations/%s/publishers/google/models/%s", state.projectID, state.location, name)
180-
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) ([]float32, error) {
180+
return ai.DefineEmbedder(provider, name, func(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
181181
return embed(ctx, fullName, state.pclient, req)
182182
})
183183
}

0 commit comments

Comments
 (0)