diff --git a/go.mod b/go.mod index 573a90041..3e375d519 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect + github.com/pgvector/pgvector-go v0.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect diff --git a/go.sum b/go.sum index a542393ac..80650e5d6 100644 --- a/go.sum +++ b/go.sum @@ -420,6 +420,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pgvector/pgvector-go v0.1.1 h1:kqJigGctFnlWvskUiYIvJRNwUtQl/aMSUZVs0YWQe+g= +github.com/pgvector/pgvector-go v0.1.1/go.mod h1:wLJgD/ODkdtd2LJK4l6evHXTuG+8PxymYAVomKHOWac= github.com/pinecone-io/go-pinecone v0.3.0 h1:+t0CiYaaA+JN6YM9QRNlvfLEr2kkGzcVEj/xNmSAON4= github.com/pinecone-io/go-pinecone v0.3.0/go.mod h1:VdSieE1r4jT3XydjFi+iL5w9qsGRz/x8LxWach2Hnv8= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= diff --git a/vectorstores/chroma/doc.go b/vectorstores/chroma/doc.go index 9833449bf..46f969440 100644 --- a/vectorstores/chroma/doc.go +++ b/vectorstores/chroma/doc.go @@ -1,2 +1,2 @@ -// Package chroma contains an implementation of the vectorStore interface that connects to an external Chroma database. +// Package chroma contains an implementation of the VectorStore interface that connects to an external Chroma database. 
package chroma diff --git a/vectorstores/pgvector/doc.go b/vectorstores/pgvector/doc.go new file mode 100644 index 000000000..2f0c57a80 --- /dev/null +++ b/vectorstores/pgvector/doc.go @@ -0,0 +1,3 @@ +// Package pgvector contains an implementation of the VectorStore +// interface using pgvector. +package pgvector diff --git a/vectorstores/pgvector/options.go b/vectorstores/pgvector/options.go new file mode 100644 index 000000000..702a82ccd --- /dev/null +++ b/vectorstores/pgvector/options.go @@ -0,0 +1,92 @@ +package pgvector + +import ( + "errors" + "fmt" + "os" + + "github.com/jackc/pgx/v5" + "github.com/tmc/langchaingo/embeddings" +) + +const ( + DefaultCollectionName = "langchain" + DefaultPreDeleteCollection = false + DefaultEmbeddingStoreTableName = "langchain_pg_embedding" + DefaultCollectionStoreTableName = "langchain_pg_collection" +) + +// ErrInvalidOptions is returned when the options given are invalid. +var ErrInvalidOptions = errors.New("invalid options") + +// Option is a function type that can be used to modify the client. +type Option func(p *Store) + +// WithEmbedder is an option for setting the embedder to use. Must be set. +func WithEmbedder(e embeddings.Embedder) Option { + return func(p *Store) { + p.embedder = e + } +} + +// WithConnectionURL is an option for specifying the Postgres connection URL. Must be set. +func WithConnectionURL(connectionURL string) Option { + return func(p *Store) { + p.postgresConnectionURL = connectionURL + } +} + +// WithPreDeleteCollection is an option for setting if the collection should be deleted before creating. +func WithPreDeleteCollection(preDelete bool) Option { + return func(p *Store) { + p.preDeleteCollection = preDelete + } +} + +// WithCollectionName is an option for specifying the collection name. +func WithCollectionName(name string) Option { + return func(p *Store) { + p.collectionName = name + } +} + +// WithEmbeddingTableName is an option for specifying the embedding table name. 
+func WithEmbeddingTableName(name string) Option {
+	return func(p *Store) {
+		// Quote the identifier: it is later spliced into SQL text.
+		p.embeddingTableName = pgx.Identifier{name}.Sanitize()
+	}
+}
+
+// WithCollectionTableName is an option for specifying the collection table name.
+func WithCollectionTableName(name string) Option {
+	return func(p *Store) {
+		// Quote the identifier: it is later spliced into SQL text.
+		p.collectionTableName = pgx.Identifier{name}.Sanitize()
+	}
+}
+
+// applyClientOptions builds a Store from the package defaults and the given
+// options. The connection URL falls back to the PGVECTOR_CONNECTION_STRING
+// environment variable. An error wrapping ErrInvalidOptions is returned when
+// the connection URL or the embedder is missing.
+func applyClientOptions(opts ...Option) (Store, error) {
+	o := &Store{
+		collectionName:      DefaultCollectionName,
+		preDeleteCollection: DefaultPreDeleteCollection,
+		embeddingTableName:  DefaultEmbeddingStoreTableName,
+		collectionTableName: DefaultCollectionStoreTableName,
+	}
+
+	for _, opt := range opts {
+		opt(o)
+	}
+
+	if o.postgresConnectionURL == "" {
+		o.postgresConnectionURL = os.Getenv("PGVECTOR_CONNECTION_STRING")
+	}
+
+	if o.postgresConnectionURL == "" {
+		return Store{}, fmt.Errorf("%w: missing postgresConnectionURL", ErrInvalidOptions)
+	}
+
+	if o.embedder == nil {
+		return Store{}, fmt.Errorf("%w: missing embedder", ErrInvalidOptions)
+	}
+
+	return *o, nil
+}
diff --git a/vectorstores/pgvector/pgvector.go b/vectorstores/pgvector/pgvector.go
new file mode 100644
index 000000000..2c5d8811c
--- /dev/null
+++ b/vectorstores/pgvector/pgvector.go
@@ -0,0 +1,335 @@
+package pgvector
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strings"
+
+	"github.com/google/uuid"
+	"github.com/jackc/pgx/v5"
+	"github.com/pgvector/pgvector-go"
+	"github.com/tmc/langchaingo/embeddings"
+	"github.com/tmc/langchaingo/schema"
+	"github.com/tmc/langchaingo/vectorstores"
+)
+
+const (
+	// pgLockIDEmbeddingTable is used for the advisory lock that guards against
+	// concurrent creation of the embedding table. The same value represents the same lock.
+	pgLockIDEmbeddingTable = 1573678846307946494
+	// pgLockIDCollectionTable is used for the advisory lock that guards against
+	// concurrent creation of the collection table. The same value represents the same lock.
+	pgLockIDCollectionTable = 1573678846307946495
+	// pgLockIDExtension is used for the advisory lock that guards against concurrent creation
+	// of the vector extension. The value is deliberately set to the same as python langchain
+	// https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167
+	pgLockIDExtension = 1573678846307946496
+)
+
+var (
+	ErrEmbedderWrongNumberVectors = errors.New("number of vectors from embedder does not match number of documents")
+	ErrInvalidScoreThreshold      = errors.New("score threshold must be between 0 and 1")
+	ErrInvalidFilters             = errors.New("invalid filters")
+	ErrUnsupportedOptions         = errors.New("unsupported options")
+)
+
+// Store is a wrapper around the pgvector client.
+type Store struct {
+	embedder              embeddings.Embedder
+	conn                  *pgx.Conn
+	postgresConnectionURL string
+	embeddingTableName    string
+	collectionTableName   string
+	collectionName        string
+	collectionUUID        string
+	collectionMetadata    map[string]any
+	preDeleteCollection   bool
+}
+
+var _ vectorstores.VectorStore = Store{}
+
+// New creates a new Store with options. It connects to Postgres, makes sure
+// the vector extension and both tables exist, optionally removes the
+// configured collection, and finally creates or looks up that collection.
+// If any step after connecting fails, the connection is closed before the
+// error is returned.
+func New(ctx context.Context, opts ...Option) (Store, error) {
+	store, err := applyClientOptions(opts...)
+	if err != nil {
+		return Store{}, err
+	}
+	store.conn, err = pgx.Connect(ctx, store.postgresConnectionURL)
+	if err != nil {
+		return Store{}, err
+	}
+	if err = store.setup(ctx); err != nil {
+		// Best effort: the setup error is the one worth reporting.
+		_ = store.conn.Close(ctx)
+		return Store{}, err
+	}
+	return store, nil
+}
+
+// setup runs the schema bootstrap steps on an already connected Store.
+func (s *Store) setup(ctx context.Context) error {
+	if err := s.conn.Ping(ctx); err != nil {
+		return err
+	}
+	if err := s.createVectorExtensionIfNotExists(ctx); err != nil {
+		return err
+	}
+	if err := s.createCollectionTableIfNotExists(ctx); err != nil {
+		return err
+	}
+	if err := s.createEmbeddingTableIfNotExists(ctx); err != nil {
+		return err
+	}
+	if s.preDeleteCollection {
+		if err := s.RemoveCollection(ctx); err != nil {
+			return err
+		}
+	}
+	return s.createOrGetCollection(ctx)
+}
+
+// execLocked runs stmt in its own transaction after taking the
+// transaction-scoped advisory lock lockID, so that concurrent callers
+// serialize on the same DDL statement.
+//
+// inspired by
+// https://github.com/langchain-ai/langchain/blob/v0.0.340/libs/langchain/langchain/vectorstores/pgvector.py#L167
+// https://github.com/langchain-ai/langchain/issues/12933
+// For more information see:
+// https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS
+func (s Store) execLocked(ctx context.Context, lockID int64, stmt string) error {
+	tx, err := s.conn.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	// Release the transaction on every error path. After a successful Commit
+	// this only yields pgx.ErrTxClosed, which is deliberately ignored.
+	defer func() { _ = tx.Rollback(ctx) }()
+
+	if _, err := tx.Exec(ctx, "SELECT pg_advisory_xact_lock($1)", lockID); err != nil {
+		return err
+	}
+	if _, err := tx.Exec(ctx, stmt); err != nil {
+		return err
+	}
+	return tx.Commit(ctx)
+}
+
+// createVectorExtensionIfNotExists installs the "vector" extension under the
+// extension advisory lock.
+func (s Store) createVectorExtensionIfNotExists(ctx context.Context) error {
+	return s.execLocked(ctx, pgLockIDExtension, "CREATE EXTENSION IF NOT EXISTS vector")
+}
+
+// createCollectionTableIfNotExists creates the collection table under the
+// collection-table advisory lock.
+func (s Store) createCollectionTableIfNotExists(ctx context.Context) error {
+	sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
+	name varchar,
+	cmetadata json,
+	"uuid" uuid NOT NULL,
+	PRIMARY KEY (uuid))`, s.collectionTableName)
+	return s.execLocked(ctx, pgLockIDCollectionTable, sql)
+}
+
+// createEmbeddingTableIfNotExists creates the embedding table under the
+// embedding-table advisory lock. Rows reference the collection table and are
+// removed with their collection (ON DELETE CASCADE).
+func (s Store) createEmbeddingTableIfNotExists(ctx context.Context) error {
+	sql := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
+	collection_id uuid,
+	embedding vector,
+	document varchar,
+	cmetadata json,
+	custom_id varchar,
+	"uuid" uuid NOT NULL,
+	CONSTRAINT langchain_pg_embedding_collection_id_fkey
+	FOREIGN KEY (collection_id) REFERENCES %s (uuid) ON DELETE CASCADE,
+PRIMARY KEY (uuid))`, s.embeddingTableName, s.collectionTableName)
+	return s.execLocked(ctx, pgLockIDEmbeddingTable, sql)
+}
+
+// AddDocuments embeds docs and inserts them into the Store's collection in a
+// single batch. Only the Embedder option is honored; a score threshold,
+// filters or a namespace yield ErrUnsupportedOptions. An empty docs slice is
+// a no-op.
+func (s Store) AddDocuments(ctx context.Context, docs []schema.Document, options ...vectorstores.Option) error {
+	opts := s.getOptions(options...)
+	if opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
+		return ErrUnsupportedOptions
+	}
+	if len(docs) == 0 {
+		// Nothing to embed or insert; skip the embedder round trip.
+		return nil
+	}
+
+	texts := make([]string, 0, len(docs))
+	for _, doc := range docs {
+		texts = append(texts, doc.PageContent)
+	}
+
+	embedder := s.embedder
+	if opts.Embedder != nil {
+		embedder = opts.Embedder
+	}
+	vectors, err := embedder.EmbedDocuments(ctx, texts)
+	if err != nil {
+		return err
+	}
+
+	if len(vectors) != len(docs) {
+		return ErrEmbedderWrongNumberVectors
+	}
+	// One custom_id is shared by every row written by this call.
+	customID := uuid.New().String()
+	b := &pgx.Batch{}
+	sql := fmt.Sprintf(`INSERT INTO %s (uuid, document, embedding, cmetadata, custom_id, collection_id)
+		VALUES($1, $2, $3, $4, $5, $6)`, s.embeddingTableName)
+	for docIdx, doc := range docs {
+		id := uuid.New().String()
+		b.Queue(sql, id, doc.PageContent, pgvector.NewVector(vectors[docIdx]), doc.Metadata, customID, s.collectionUUID)
+	}
+	return s.conn.SendBatch(ctx, b).Close()
+}
+
+// SimilaritySearch embeds query and returns up to numDocuments documents from
+// the selected collection, ordered by ascending distance to the query vector.
+// The NameSpace option selects another collection by name, ScoreThreshold
+// keeps rows whose distance is below 1-threshold, and Filters (a
+// map[string]any) keeps rows whose metadata value, rendered as text, equals
+// the filter value for every key.
+//
+// All caller-supplied values are sent as bind parameters; only the table
+// names, which are set through the Store options, are spliced into the SQL.
+//
+//nolint:cyclop
+func (s Store) SimilaritySearch(
+	ctx context.Context,
+	query string,
+	numDocuments int,
+	options ...vectorstores.Option,
+) ([]schema.Document, error) {
+	opts := s.getOptions(options...)
+	collectionName := s.getNameSpace(opts)
+	scoreThreshold, err := s.getScoreThreshold(opts)
+	if err != nil {
+		return nil, err
+	}
+	filter, err := s.getFilters(opts)
+	if err != nil {
+		return nil, err
+	}
+	embedder := s.embedder
+	if opts.Embedder != nil {
+		embedder = opts.Embedder
+	}
+	embedderData, err := embedder.EmbedQuery(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+
+	// $1 = query vector, $2 = row limit, $3 = collection name. Threshold and
+	// filter placeholders are numbered after these as they are appended.
+	args := []any{pgvector.NewVector(embedderData), numDocuments, collectionName}
+	conditions := make([]string, 0, len(filter)+1)
+	if scoreThreshold != 0 {
+		args = append(args, float64(1-scoreThreshold))
+		conditions = append(conditions, fmt.Sprintf("data.distance < $%d::float8", len(args)))
+	}
+	for k, v := range filter {
+		// ->> yields text, so the filter value is compared in its text form.
+		args = append(args, k, fmt.Sprint(v))
+		conditions = append(conditions,
+			fmt.Sprintf("(data.cmetadata ->> $%d::text) = $%d::text", len(args)-1, len(args)))
+	}
+	whereQuery := strings.Join(conditions, " AND ")
+	if len(whereQuery) == 0 {
+		whereQuery = "TRUE"
+	}
+	sql := fmt.Sprintf(`SELECT
+	data.document,
+	data.cmetadata,
+	data.distance
+FROM (
+	SELECT
+		%s.*,
+		embedding <=> $1 AS distance
+	FROM
+		%s
+		JOIN %s ON %s.collection_id=%s.uuid WHERE %s.name=$3) AS data
+WHERE %s
+ORDER BY
+	data.distance
+LIMIT $2`, s.embeddingTableName,
+		s.embeddingTableName,
+		s.collectionTableName, s.embeddingTableName, s.collectionTableName, s.collectionTableName,
+		whereQuery)
+	rows, err := s.conn.Query(ctx, sql, args...)
+	if err != nil {
+		return nil, err
+	}
+	// Free the connection even when Scan fails part-way through.
+	defer rows.Close()
+
+	docs := make([]schema.Document, 0)
+	for rows.Next() {
+		doc := schema.Document{}
+		// NOTE(review): Score receives the distance (smaller is closer), not a
+		// similarity - confirm this is what callers expect.
+		if err := rows.Scan(&doc.PageContent, &doc.Metadata, &doc.Score); err != nil {
+			return nil, err
+		}
+		docs = append(docs, doc)
+	}
+	// Next returning false can also mean the read failed.
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return docs, nil
+}
+
+// Close closes the connection.
+func (s Store) Close(ctx context.Context) error {
+	return s.conn.Close(ctx)
+}
+
+// DropTables drops the embedding table and the collection table if they
+// exist. The embedding table goes first: its foreign key references the
+// collection table, so the collection table cannot be dropped while the
+// embedding table still exists.
+func (s Store) DropTables(ctx context.Context) error {
+	if _, err := s.conn.Exec(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS %s`, s.embeddingTableName)); err != nil {
+		return err
+	}
+	if _, err := s.conn.Exec(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS %s`, s.collectionTableName)); err != nil {
+		return err
+	}
+	return nil
+}
+
+// RemoveCollection deletes every collection row named after the Store's
+// collection. Embedding rows that reference those collections are removed by
+// the ON DELETE CASCADE foreign key declared on the embedding table.
+func (s Store) RemoveCollection(ctx context.Context) error {
+	_, err := s.conn.Exec(ctx, fmt.Sprintf(`DELETE FROM %s WHERE name = $1`, s.collectionTableName), s.collectionName)
+	return err
+}
+
+// createOrGetCollection stores in s.collectionUUID the UUID of the collection
+// named s.collectionName, inserting the collection row only when no row with
+// that name exists yet. The collection table has no unique constraint on
+// name, so an unconditional insert would add a duplicate row on every call.
+//
+// NOTE(review): the lookup and the insert are not atomic; two stores creating
+// the same new collection at the same moment can still both insert.
+func (s *Store) createOrGetCollection(ctx context.Context) error {
+	selectSQL := fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.collectionTableName)
+	err := s.conn.QueryRow(ctx, selectSQL, s.collectionName).Scan(&s.collectionUUID)
+	if err == nil {
+		return nil
+	}
+	if !errors.Is(err, pgx.ErrNoRows) {
+		return err
+	}
+
+	id := uuid.New().String()
+	insertSQL := fmt.Sprintf(`INSERT INTO %s (uuid, name, cmetadata)
+		VALUES($1, $2, $3)`, s.collectionTableName)
+	if _, err := s.conn.Exec(ctx, insertSQL, id, s.collectionName, s.collectionMetadata); err != nil {
+		return err
+	}
+	s.collectionUUID = id
+	return nil
+}
+
+// getOptions applies given options to default Options and returns it
+// This uses options pattern so clients can easily pass options without changing function signature.
+func (s Store) getOptions(options ...vectorstores.Option) vectorstores.Options {
+	opts := vectorstores.Options{}
+	for _, opt := range options {
+		opt(&opts)
+	}
+	return opts
+}
+
+// getNameSpace returns the collection name to search: the NameSpace option
+// when it is set, otherwise the Store's own collection name.
+func (s Store) getNameSpace(opts vectorstores.Options) string {
+	if opts.NameSpace != "" {
+		return opts.NameSpace
+	}
+	return s.collectionName
+}
+
+// getScoreThreshold returns the ScoreThreshold option, or
+// ErrInvalidScoreThreshold when it lies outside [0, 1].
+func (s Store) getScoreThreshold(opts vectorstores.Options) (float32, error) {
+	if opts.ScoreThreshold < 0 || opts.ScoreThreshold > 1 {
+		return 0, ErrInvalidScoreThreshold
+	}
+	return opts.ScoreThreshold, nil
+}
+
+// getFilters return metadata filters, now only support map[key]value pattern
+// TODO: should support more types like {"key1": {"key2":"values2"}} or {"key": ["value1", "values2"]}.
+//
+// A nil Filters option yields an empty map; any value that is not a
+// map[string]any yields ErrInvalidFilters.
+func (s Store) getFilters(opts vectorstores.Options) (map[string]any, error) {
+	if opts.Filters == nil {
+		return map[string]any{}, nil
+	}
+	filters, ok := opts.Filters.(map[string]any)
+	if !ok {
+		return nil, ErrInvalidFilters
+	}
+	return filters, nil
+}
diff --git a/vectorstores/pgvector/pgvector_test.go b/vectorstores/pgvector/pgvector_test.go
new file mode 100644
index 000000000..5349a7bc4
--- /dev/null
+++ b/vectorstores/pgvector/pgvector_test.go
@@ -0,0 +1,409 @@
+package pgvector_test
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/require"
+	"github.com/tmc/langchaingo/chains"
+	"github.com/tmc/langchaingo/embeddings"
+	"github.com/tmc/langchaingo/llms/openai"
+	"github.com/tmc/langchaingo/schema"
+	"github.com/tmc/langchaingo/vectorstores"
+	"github.com/tmc/langchaingo/vectorstores/pgvector"
+)
+
+// preCheckEnvSetting skips the calling test unless both the Postgres
+// connection string and the OpenAI API key are present in the environment.
+func preCheckEnvSetting(t *testing.T) {
+	t.Helper()
+
+	pgvectorURL := os.Getenv("PGVECTOR_CONNECTION_STRING")
+	if pgvectorURL == "" {
+		t.Skip("Must set PGVECTOR_CONNECTION_STRING to run test")
+	}
+
+	if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" {
+		t.Skip("OPENAI_API_KEY not set")
+	}
+}
+
+// makeNewCollectionName returns a unique collection name so that parallel
+// tests never share rows.
+func makeNewCollectionName() string {
+	return fmt.Sprintf("test-collection-%s", uuid.New().String())
+}
+
+// cleanupTestArtifacts removes the store's collection and closes its connection.
+func cleanupTestArtifacts(ctx context.Context, t *testing.T, s pgvector.Store) {
+	t.Helper()
+	require.NoError(t, s.RemoveCollection(ctx))
+	require.NoError(t, s.Close(ctx))
+}
+
+// newTestStore returns a Store backed by an OpenAI embedder and bound to a
+// fresh, uniquely named collection. Cleanup is registered on t.
+func newTestStore(ctx context.Context, t *testing.T) pgvector.Store {
+	t.Helper()
+
+	llm, err := openai.New()
+	require.NoError(t, err)
+	e, err := embeddings.NewEmbedder(llm)
+	require.NoError(t, err)
+
+	store, err := pgvector.New(
+		ctx,
+		pgvector.WithEmbedder(e),
+		pgvector.WithPreDeleteCollection(true),
+		pgvector.WithCollectionName(makeNewCollectionName()),
+	)
+	require.NoError(t, err)
+
+	t.Cleanup(func() { cleanupTestArtifacts(ctx, t, store) })
+	return store
+}
+
+// cityDocuments returns the ten-city fixture shared by the score threshold
+// tests: six Japanese cities followed by four non-Japanese ones.
+func cityDocuments() []schema.Document {
+	return []schema.Document{
+		{PageContent: "Tokyo"},
+		{PageContent: "Yokohama"},
+		{PageContent: "Osaka"},
+		{PageContent: "Nagoya"},
+		{PageContent: "Sapporo"},
+		{PageContent: "Fukuoka"},
+		{PageContent: "Dublin"},
+		{PageContent: "Paris"},
+		{PageContent: "London "},
+		{PageContent: "New York"},
+	}
+}
+
+func TestPgvectorStoreRest(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+	store := newTestStore(ctx, t)
+
+	err := store.AddDocuments(ctx, []schema.Document{
+		{PageContent: "tokyo", Metadata: map[string]any{
+			"country": "japan",
+		}},
+		{PageContent: "potato"},
+	})
+	require.NoError(t, err)
+
+	docs, err := store.SimilaritySearch(ctx, "japan", 1)
+	require.NoError(t, err)
+	require.Len(t, docs, 1)
+	require.Equal(t, "tokyo", docs[0].PageContent)
+	require.Equal(t, "japan", docs[0].Metadata["country"])
+}
+
+func TestPgvectorStoreRestWithScoreThreshold(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+	store := newTestStore(ctx, t)
+
+	err := store.AddDocuments(ctx, cityDocuments())
+	require.NoError(t, err)
+
+	// test with a score threshold of 0.8, expected 6 documents
+	docs, err := store.SimilaritySearch(
+		ctx,
+		"Which of these are cities in Japan",
+		10,
+		vectorstores.WithScoreThreshold(0.8),
+	)
+	require.NoError(t, err)
+	require.Len(t, docs, 6)
+
+	// test with a score threshold of 0, expected all 10 documents
+	docs, err = store.SimilaritySearch(
+		ctx,
+		"Which of these are cities in Japan",
+		10,
+		vectorstores.WithScoreThreshold(0))
+	require.NoError(t, err)
+	require.Len(t, docs, 10)
+}
+
+func TestSimilaritySearchWithInvalidScoreThreshold(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+	store := newTestStore(ctx, t)
+
+	err := store.AddDocuments(ctx, cityDocuments())
+	require.NoError(t, err)
+
+	_, err = store.SimilaritySearch(
+		ctx,
+		"Which of these are cities in Japan",
+		10,
+		vectorstores.WithScoreThreshold(-0.8),
+	)
+	require.Error(t, err)
+
+	_, err = store.SimilaritySearch(
+		ctx,
+		"Which of these are cities in Japan",
+		10,
+		vectorstores.WithScoreThreshold(1.8),
+	)
+	require.Error(t, err)
+}
+
+func TestPgvectorAsRetriever(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+
+	llm, err := openai.New()
+	require.NoError(t, err)
+	store := newTestStore(ctx, t)
+
+	err = store.AddDocuments(
+		ctx,
+		[]schema.Document{
+			{PageContent: "The color of the house is blue."},
+			{PageContent: "The color of the car is red."},
+			{PageContent: "The color of the desk is orange."},
+		},
+	)
+	require.NoError(t, err)
+
+	result, err := chains.Run(
+		ctx,
+		chains.NewRetrievalQAFromLLM(
+			llm,
+			vectorstores.ToRetriever(store, 1),
+		),
+		"What color is the desk?",
+	)
+	require.NoError(t, err)
+	require.True(t, strings.Contains(result, "orange"), "expected orange in result")
+}
+
+func TestPgvectorAsRetrieverWithScoreThreshold(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+
+	llm, err := openai.New()
+	require.NoError(t, err)
+	store := newTestStore(ctx, t)
+
+	err = store.AddDocuments(
+		ctx,
+		[]schema.Document{
+			{PageContent: "The color of the house is blue."},
+			{PageContent: "The color of the car is red."},
+			{PageContent: "The color of the desk is orange."},
+			{PageContent: "The color of the lamp beside the desk is black."},
+			{PageContent: "The color of the chair beside the desk is beige."},
+		},
+	)
+	require.NoError(t, err)
+
+	result, err := chains.Run(
+		ctx,
+		chains.NewRetrievalQAFromLLM(
+			llm,
+			vectorstores.ToRetriever(store, 5, vectorstores.WithScoreThreshold(0.8)),
+		),
+		"What colors is each piece of furniture next to the desk?",
+	)
+	require.NoError(t, err)
+
+	require.Contains(t, result, "orange", "expected orange in result")
+	require.Contains(t, result, "black", "expected black in result")
+	require.Contains(t, result, "beige", "expected beige in result")
+}
+
+func TestPgvectorAsRetrieverWithMetadataFilterNotSelected(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+
+	llm, err := openai.New()
+	require.NoError(t, err)
+	store := newTestStore(ctx, t)
+
+	err = store.AddDocuments(
+		ctx,
+		[]schema.Document{
+			{
+				PageContent: "The color of the lamp beside the desk is black.",
+				Metadata: map[string]any{
+					"location": "kitchen",
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is blue.",
+				Metadata: map[string]any{
+					"location": "bedroom",
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is orange.",
+				Metadata: map[string]any{
+					"location": "office",
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is purple.",
+				Metadata: map[string]any{
+					"location": "sitting room",
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is yellow.",
+				Metadata: map[string]any{
+					"location": "patio",
+				},
+			},
+		},
+	)
+	require.NoError(t, err)
+
+	result, err := chains.Run(
+		ctx,
+		chains.NewRetrievalQAFromLLM(
+			llm,
+			vectorstores.ToRetriever(store, 5),
+		),
+		"What color is the lamp in each room?",
+	)
+	require.NoError(t, err)
+
+	require.Contains(t, result, "black", "expected black in result")
+	require.Contains(t, result, "blue", "expected blue in result")
+	require.Contains(t, result, "orange", "expected orange in result")
+	require.Contains(t, result, "purple", "expected purple in result")
+	require.Contains(t, result, "yellow", "expected yellow in result")
+}
+
+func TestPgvectorAsRetrieverWithMetadataFilters(t *testing.T) {
+	t.Parallel()
+	preCheckEnvSetting(t)
+	ctx := context.Background()
+
+	llm, err := openai.New()
+	require.NoError(t, err)
+	store := newTestStore(ctx, t)
+
+	err = store.AddDocuments(
+		ctx,
+		[]schema.Document{
+			{
+				PageContent: "The color of the lamp beside the desk is orange.",
+				Metadata: map[string]any{
+					"location":    "office",
+					"square_feet": 100,
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is purple.",
+				Metadata: map[string]any{
+					"location":    "sitting room",
+					"square_feet": 400,
+				},
+			},
+			{
+				PageContent: "The color of the lamp beside the desk is yellow.",
+				Metadata: map[string]any{
+					"location":    "patio",
+					"square_feet": 800,
+				},
+			},
+		},
+	)
+	require.NoError(t, err)
+
+	filter := map[string]any{"location": "sitting room"}
+
+	result, err := chains.Run(
+		ctx,
+		chains.NewRetrievalQAFromLLM(
+			llm,
+			vectorstores.ToRetriever(store,
+				5,
+				vectorstores.WithFilters(filter))),
+		"What color is the lamp in each room?",
+	)
+	require.NoError(t, err)
+	require.Contains(t, result, "purple", "expected purple in result")
+	require.NotContains(t, result, "orange", "expected not orange in result")
+	require.NotContains(t, result, "yellow", "expected not yellow in result")
+}
diff --git a/vectorstores/pinecone/doc.go b/vectorstores/pinecone/doc.go
index de6aeb60c..123235cc0 100644
--- a/vectorstores/pinecone/doc.go
+++ b/vectorstores/pinecone/doc.go
@@ -1,3 +1,3 @@
-// Package pinecone contains an implementation of the vectorStore
+// Package pinecone contains an implementation of the VectorStore
 // interface using pinecone.
package pinecone diff --git a/vectorstores/weaviate/doc.go b/vectorstores/weaviate/doc.go index edf6c70ae..82b826862 100644 --- a/vectorstores/weaviate/doc.go +++ b/vectorstores/weaviate/doc.go @@ -1,3 +1,3 @@ -// Package weaviate contains an implementation of the vectorStore +// Package weaviate contains an implementation of the VectorStore // interface using weaviate. package weaviate