Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

add: -w/--keyword flag for retrieve/askdir + allow selecting multiple datasets #68

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ toolchain go1.22.4
replace (
github.com/hupe1980/golc => github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 // nbformat extension
github.com/ledongthuc/pdf => github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 // fix for reading some PDFs: https://github.com/ledongthuc/pdf/pull/36 + https://github.com/iwilltry42/pdf/pull/2
github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 // OpenAI Compat Fixes
github.com/philippgille/chromem-go => github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 // OpenAI Compat Fixes
github.com/tmc/langchaingo => github.com/StrongMonkey/langchaingo v0.0.0-20240617180437-9af4bee04c8b // Context-Aware Markdown Splitting
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ github.com/hupe1980/go-tiktoken v0.0.9 h1:qNs/XGTe7UHDUaFkU+jAPbhGzyi9BusOpxrNC8
github.com/hupe1980/go-tiktoken v0.0.9/go.mod h1:NME6d8hrE+Jo+kLUZHhXShYV8e40hYkm4BbSLQKtvAo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583 h1:xTsr6cysGZGpu9xYaLiYItFu47Lh54jC49OwYX7fE2M=
github.com/iwilltry42/chromem-go v0.0.0-20240813194839-d838df05b583/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8 h1:Tob2qUvv7zEeVNDb4kNhAmboaj0zUYUlZ+fcJg/ru14=
github.com/iwilltry42/chromem-go v0.0.0-20240814135107-86b4f217a8e8/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7 h1:2AzzbKVW1iP2F+ovqJKq801l6tgxYPt9m2zFKbs+i/Y=
github.com/iwilltry42/golc v0.0.113-0.20240802113826-d065a3c5b0c7/go.mod h1:w692KzkSTSvXROfyu+jYauNXB4YaL1s8zHPDMnNW88o=
github.com/iwilltry42/pdf v0.0.0-20240517145113-99fbaebc5dd3 h1:rCVwFT7Q+HxpijWfSzKTYX4pCDMS7oy/I/WzU30VXyI=
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Client interface {
AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
PrunePath(ctx context.Context, datasetID string, path string, keep []string) ([]index.File, error)
DeleteDocuments(ctx context.Context, datasetID string, documentIDs ...string) error
Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error)
ExportDatasets(ctx context.Context, path string, datasets ...string) error
ImportDatasets(ctx context.Context, path string, datasets ...string) error
UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error)
Expand Down
2 changes: 1 addition & 1 deletion pkg/client/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func AskDir(ctx context.Context, c Client, path string, query string, opts *Inge
slog.Debug("Ingested files", "count", ingested, "path", abspath)

// retrieve documents
return c.Retrieve(ctx, datasetID, query, *ropts)
return c.Retrieve(ctx, []string{datasetID}, query, *ropts)
}

func getOrCreateDataset(ctx context.Context, c Client, datasetID string, create bool) (*index.Dataset, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/client/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (c *DefaultClient) DeleteDocuments(_ context.Context, datasetID string, doc
return nil
}

func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
func (c *DefaultClient) Retrieve(_ context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
q := types.Query{Prompt: query}

if opts.TopK != 0 {
Expand All @@ -185,7 +185,8 @@ func (c *DefaultClient) Retrieve(_ context.Context, datasetID string, query stri
return nil, err
}

resp, err := c.request(http.MethodPost, fmt.Sprintf("/datasets/%s/retrieve", datasetID), bytes.NewBuffer(data))
// TODO: change to allow for multiple datasets
resp, err := c.request(http.MethodPost, fmt.Sprintf("/datasets/%s/retrieve", datasetIDs), bytes.NewBuffer(data))
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/client/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ func (c *StandaloneClient) DeleteDocuments(ctx context.Context, datasetID string
return nil
}

func (c *StandaloneClient) Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
return c.Datastore.Retrieve(ctx, datasetID, query, opts)
func (c *StandaloneClient) Retrieve(ctx context.Context, datasetIDs []string, query string, opts datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
return c.Datastore.Retrieve(ctx, datasetIDs, query, opts)
}

func (c *StandaloneClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *datastore.RetrieveOpts) (*dstypes.RetrievalResponse, error) {
Expand Down
3 changes: 2 additions & 1 deletion pkg/cmd/askdir.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func (s *ClientAskDir) Run(cmd *cobra.Command, args []string) error {
}

retrieveOpts := &datastore.RetrieveOpts{
TopK: s.TopK,
TopK: s.TopK,
Keywords: s.Keywords,
}

if s.FlowsFile != "" {
Expand Down
38 changes: 25 additions & 13 deletions pkg/cmd/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import (

type ClientRetrieve struct {
Client
Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
Archive string `usage:"Path to the archive file"`
Datasets []string `usage:"Target Dataset IDs" short:"d" default:"default" env:"KNOW_TARGET_DATASETS" name:"dataset"`
Archive string `usage:"Path to the archive file"`
ClientRetrieveOpts
ClientFlowsConfig
}

type ClientRetrieveOpts struct {
TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"`
TopK int `usage:"Number of sources to retrieve" short:"k" default:"10"`
Keywords []string `usage:"Keywords that retrieved documents must contain" short:"w" name:"keyword" env:"KNOW_RETRIEVE_KEYWORDS"`
}

func (s *ClientRetrieve) Customize(cmd *cobra.Command) {
Expand All @@ -35,15 +36,19 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {
return err
}

datasetID := s.Dataset
datasetIDs := s.Datasets
if len(s.Datasets) == 0 {
datasetIDs = []string{"default"}
}
query := args[0]

retrieveOpts := datastore.RetrieveOpts{
TopK: s.TopK,
TopK: s.TopK,
Keywords: s.Keywords,
}

if s.FlowsFile != "" {
slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Debug("Loading retrieval flows from config", "flows_file", s.FlowsFile, "dataset", datasetIDs)
flowCfg, err := flowconfig.FromFile(s.FlowsFile)
if err != nil {
return err
Expand All @@ -55,29 +60,36 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {
return err
}
} else {
flow, err = flowCfg.ForDataset(datasetID) // get flow for the dataset
if err != nil {
return err
if len(datasetIDs) == 1 {
flow, err = flowCfg.ForDataset(datasetIDs[0]) // get flow for the dataset
if err != nil {
return err
}
} else {
flow, err = flowCfg.GetDefaultFlowConfigEntry()
if err != nil {
return err
}
}
}

if flow.Retrieval == nil {
slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Info("No retrieval config in assigned flow", "flows_file", s.FlowsFile, "dataset", datasetIDs)
} else {
rf, err := flow.Retrieval.AsRetrievalFlow()
if err != nil {
return err
}
retrieveOpts.RetrievalFlow = rf
slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetID)
slog.Debug("Loaded retrieval flow from config", "flows_file", s.FlowsFile, "dataset", datasetIDs)
}
}

retrievalResp, err := c.Retrieve(cmd.Context(), datasetID, query, retrieveOpts)
retrievalResp, err := c.Retrieve(cmd.Context(), datasetIDs, query, retrieveOpts)
if err != nil {
// An empty collection is not a hard error - the LLM session can "recover" from it
if errors.Is(err, vserr.ErrCollectionEmpty) {
fmt.Printf("Dataset %q does not contain any documents\n", datasetID)
fmt.Printf("Dataset %q does not contain any documents\n", datasetIDs)
return nil
}
return err
Expand Down
44 changes: 39 additions & 5 deletions pkg/datastore/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package datastore
import (
"context"
"github.com/gptscript-ai/knowledge/pkg/datastore/types"
"github.com/philippgille/chromem-go"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
Expand All @@ -12,11 +13,12 @@ import (

type RetrieveOpts struct {
TopK int
Keywords []string
RetrievalFlow *flows.RetrievalFlow
}

func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) {
slog.Debug("Retrieving content from dataset", "dataset", datasetID, "query", query)
func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query string, opts RetrieveOpts) (*types.RetrievalResponse, error) {
slog.Debug("Retrieving content from dataset", "dataset", datasetIDs, "query", query)

retrievalFlow := opts.RetrievalFlow
if retrievalFlow == nil {
Expand All @@ -28,9 +30,41 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string
}
retrievalFlow.FillDefaults(topK)

return retrievalFlow.Run(ctx, s, query, datasetID)
var whereDocs []chromem.WhereDocument
if len(opts.Keywords) > 0 {
whereDoc := chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorOr,
WhereDocuments: []chromem.WhereDocument{},
}
whereDocNot := chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorAnd,
WhereDocuments: []chromem.WhereDocument{},
}
for _, kw := range opts.Keywords {
if kw[0] == '-' {
whereDocNot.WhereDocuments = append(whereDocNot.WhereDocuments, chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorNotContains,
Value: kw[1:],
})
} else {
whereDoc.WhereDocuments = append(whereDoc.WhereDocuments, chromem.WhereDocument{
Operator: chromem.WhereDocumentOperatorContains,
Value: kw,
})
}
}
if len(whereDoc.WhereDocuments) > 0 {
whereDocs = append(whereDocs, whereDoc)
}
if len(whereDocNot.WhereDocuments) > 0 {
whereDocs = append(whereDocs, whereDocNot)
}

}

return retrievalFlow.Run(ctx, s, query, datasetIDs, &flows.RetrievalFlowOpts{Where: nil, WhereDocument: whereDocs})
}

func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string) ([]vectorstore.Document, error) {
return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID)
func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vectorstore.Document, error) {
return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, where, whereDocument)
}
27 changes: 27 additions & 0 deletions pkg/datastore/retrievers/keyword.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package retrievers

import (
"regexp"
"strings"
)

// doubleQuotePattern matches a double-quoted span and captures the text
// between the two quote characters (which may be empty).
var doubleQuotePattern = regexp.MustCompile(`"([^"]*)"`)

// ExtractQuotedSubstrings returns the contents of every double-quoted span
// found in input, in order of appearance. Leading and trailing whitespace is
// trimmed from each captured value, and values that are empty after trimming
// are dropped. The result is nil when no non-empty quoted value is found.
func ExtractQuotedSubstrings(input string) []string {
	var quoted []string
	for _, groups := range doubleQuotePattern.FindAllStringSubmatch(input, -1) {
		// groups[0] is the full match including quotes; groups[1] is the capture.
		if len(groups) < 2 {
			continue
		}
		if inner := strings.TrimSpace(groups[1]); inner != "" {
			quoted = append(quoted, inner)
		}
	}
	return quoted
}
19 changes: 16 additions & 3 deletions pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import (
"context"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/philippgille/chromem-go"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

type Retriever interface {
Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error)
Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
Name() string
}

Expand Down Expand Up @@ -42,11 +43,23 @@ func (r *BasicRetriever) Name() string {
return BasicRetrieverName
}

func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {

if len(datasetIDs) > 1 {
return nil, fmt.Errorf("basic retriever does not support querying multiple datasets")
}

var datasetID string
if len(datasetIDs) == 0 {
datasetID = "default"
} else {
datasetID = datasetIDs[0]
}

log := slog.With("retriever", r.Name())
if r.TopK <= 0 {
log.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK)
r.TopK = defaults.TopK
}
return store.SimilaritySearch(ctx, query, r.TopK, datasetID)
return store.SimilaritySearch(ctx, query, r.TopK, datasetID, where, whereDocument)
}
8 changes: 5 additions & 3 deletions pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
"log/slog"
)

Expand Down Expand Up @@ -35,10 +36,11 @@ type routingResp struct {
Result string `json:"result"`
}

func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {
log := slog.With("component", "RoutingRetriever")

log.Debug("Ignoring input datasetID in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetID)
// TODO: properly handle the datasetIDs input
log.Debug("Ignoring input datasetIDs in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetIDs)

if r.TopK <= 0 {
log.Debug("TopK not set, using default", "default", defaults.TopK)
Expand Down Expand Up @@ -91,5 +93,5 @@ func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, quer

slog.Debug("Routing query to dataset", "query", query, "dataset", resp.Result)

return store.SimilaritySearch(ctx, query, r.TopK, resp.Result)
return store.SimilaritySearch(ctx, query, r.TopK, resp.Result, where, whereDocument)
}
17 changes: 15 additions & 2 deletions pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
"log/slog"
"strings"
)
Expand Down Expand Up @@ -37,7 +38,19 @@ type subqueryResp struct {
Results []string `json:"results"`
}

func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {

if len(datasetIDs) > 1 {
return nil, fmt.Errorf("basic retriever does not support querying multiple datasets")
}

var datasetID string
if len(datasetIDs) == 0 {
datasetID = "default"
} else {
datasetID = datasetIDs[0]
}

m, err := llm.NewFromConfig(s.Model)
if err != nil {
return nil, err
Expand Down Expand Up @@ -72,7 +85,7 @@ func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, quer

var resultDocs []vs.Document
for _, q := range queries {
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID)
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID, where, whereDocument)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/datastore/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"context"
"github.com/gptscript-ai/knowledge/pkg/index"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"github.com/philippgille/chromem-go"
)

type Store interface {
ListDatasets(ctx context.Context) ([]index.Dataset, error)
GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error)
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error)
SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
}
Loading