This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit 8ae6359
add: expose config options for textsplitter (+resolve import cycle)
iwilltry42 committed May 10, 2024
1 parent aa4390f commit 8ae6359
Showing 15 changed files with 127 additions and 58 deletions.
1 change: 1 addition & 0 deletions pkg/client/client.go
@@ -12,6 +12,7 @@ type IngestPathsOpts struct {
 	IgnoreExtensions []string
 	Concurrency      int
 	Recursive        bool
+	TextSplitterOpts *datastore.TextSplitterOpts
 }

 type RetrieveOpts struct {
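
Note: a minimal sketch of how a caller might set the new option, assuming a client.Client instance c and a context ctx; the values, dataset ID, and path are illustrative, not part of this commit:

	opts := &client.IngestPathsOpts{
		Concurrency: 2,
		Recursive:   true,
		TextSplitterOpts: &datastore.TextSplitterOpts{
			ChunkSize:    2048,
			ChunkOverlap: 128,
			ModelName:    "gpt-4",
			EncodingName: "cl100k_base",
		},
	}
	ingested, err := c.IngestPaths(ctx, "default", opts, "./docs") // error handling omitted
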
3 changes: 3 additions & 0 deletions pkg/client/default.go
@@ -143,6 +143,9 @@ func (c *DefaultClient) IngestPaths(ctx context.Context, datasetID string, opts
 		},
 		IsDuplicateFuncName: "file_metadata",
 	}
+	if opts != nil {
+		payload.TextSplitterOpts = opts.TextSplitterOpts
+	}
 	_, err = c.Ingest(ctx, datasetID, content, payload)
 	return err
 }
38 changes: 30 additions & 8 deletions pkg/client/standalone.go
@@ -23,15 +23,19 @@ func NewStandaloneClient(ds *datastore.Datastore) (*StandaloneClient, error) {
 }

 func (c *StandaloneClient) CreateDataset(ctx context.Context, datasetID string) (types.Dataset, error) {
-	ds := types.Dataset{
+	ds := index.Dataset{
 		ID:             datasetID,
-		EmbedDimension: nil,
+		EmbedDimension: 0,
 	}
+	r := types.Dataset{
+		ID:             ds.ID,
+		EmbedDimension: z.Pointer(ds.EmbedDimension),
+	}
 	err := c.Datastore.NewDataset(ctx, ds)
 	if err != nil {
-		return ds, err
+		return r, err
 	}
-	return ds, nil
+	return r, nil
 }

 func (c *StandaloneClient) DeleteDataset(ctx context.Context, datasetID string) error {
@@ -43,7 +47,18 @@ func (c *StandaloneClient) GetDataset(ctx context.Context, datasetID string) (*i
 }

 func (c *StandaloneClient) ListDatasets(ctx context.Context) ([]types.Dataset, error) {
-	return c.Datastore.ListDatasets(ctx)
+	ds, err := c.Datastore.ListDatasets(ctx)
+	if err != nil {
+		return nil, err
+	}
+	r := make([]types.Dataset, len(ds))
+	for i, d := range ds {
+		r[i] = types.Dataset{
+			ID:             d.ID,
+			EmbedDimension: z.Pointer(d.EmbedDimension),
+		}
+	}
+	return r, nil
 }

 func (c *StandaloneClient) Ingest(ctx context.Context, datasetID string, data []byte, opts datastore.IngestOpts) ([]string, error) {
@@ -67,7 +82,8 @@ func (c *StandaloneClient) IngestPaths(ctx context.Context, datasetID string, op
 		if err != nil {
 			return fmt.Errorf("failed to open file %s: %w", path, err)
 		}
-		_, err = c.Datastore.Ingest(ctx, datasetID, file, datastore.IngestOpts{
+
+		iopts := datastore.IngestOpts{
 			Filename: z.Pointer(filepath.Base(path)),
 			FileMetadata: &index.FileMetadata{
 				Name: filepath.Base(path),
@@ -76,7 +92,13 @@ func (c *StandaloneClient) IngestPaths(ctx context.Context, datasetID string, op
 				ModifiedAt: finfo.ModTime(),
 			},
 			IsDuplicateFunc: datastore.DedupeByFileMetadata,
-		})
+		}
+
+		if opts != nil {
+			iopts.TextSplitterOpts = opts.TextSplitterOpts
+		}
+
+		_, err = c.Datastore.Ingest(ctx, datasetID, file, iopts)
 		return err
 	}

@@ -94,7 +116,7 @@ func (c *StandaloneClient) DeleteDocuments(ctx context.Context, datasetID string
 }

 func (c *StandaloneClient) Retrieve(ctx context.Context, datasetID string, query string, opts RetrieveOpts) ([]vectorstore.Document, error) {
-	return c.Datastore.Retrieve(ctx, datasetID, types.Query{Prompt: query, TopK: z.Pointer(opts.TopK)})
+	return c.Datastore.Retrieve(ctx, datasetID, query, opts.TopK)
 }

 func (c *StandaloneClient) AskDirectory(ctx context.Context, path string, query string, opts *IngestPathsOpts, ropts *RetrieveOpts) ([]vectorstore.Document, error) {
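
Note: the datastore now deals in index.Dataset (plain int dimension) while the client keeps returning the API-facing types.Dataset (pointer dimension), converting at the boundary; this is what breaks the import cycle named in the commit title. z.Pointer is the helper from github.com/acorn-io/z used for that conversion; a minimal equivalent, shown here only for illustration:

	package z // illustrative stand-in for github.com/acorn-io/z

	// Pointer returns a pointer to any value, handy for populating
	// optional (pointer-typed) API fields like types.Dataset.EmbedDimension.
	func Pointer[T any](v T) *T {
		return &v
	}
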
8 changes: 4 additions & 4 deletions pkg/cmd/client.go
@@ -2,15 +2,15 @@ package cmd

 import (
 	"github.com/gptscript-ai/knowledge/pkg/client"
+	"github.com/gptscript-ai/knowledge/pkg/config"
 	"github.com/gptscript-ai/knowledge/pkg/datastore"
-	"github.com/gptscript-ai/knowledge/pkg/types"
 )

 type Client struct {
 	Server string `usage:"URL of the Knowledge API Server" default:"" env:"KNOW_SERVER_URL"`

-	types.OpenAIConfig
-	types.DatabaseConfig
-	types.VectorDBConfig
+	config.OpenAIConfig
+	config.DatabaseConfig
+	config.VectorDBConfig
 }

 func (s *Client) getClient() (client.Client, error) {
3 changes: 3 additions & 0 deletions pkg/cmd/ingest.go
@@ -3,6 +3,7 @@ package cmd
 import (
 	"fmt"
 	"github.com/gptscript-ai/knowledge/pkg/client"
+	"github.com/gptscript-ai/knowledge/pkg/datastore"
 	"github.com/spf13/cobra"
 	"strings"
 )
@@ -11,6 +12,7 @@ type ClientIngest struct {
 	Client
 	Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
 	ClientIngestOpts
+	datastore.TextSplitterOpts
 }

 type ClientIngestOpts struct {
@@ -38,6 +40,7 @@ func (s *ClientIngest) Run(cmd *cobra.Command, args []string) error {
 		IgnoreExtensions: strings.Split(s.IgnoreExtensions, ","),
 		Concurrency:      s.Concurrency,
 		Recursive:        s.Recursive,
+		TextSplitterOpts: &s.TextSplitterOpts,
 	}

 	filesIngested, err := c.IngestPaths(cmd.Context(), datasetID, ingestOpts, filePath)
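
Note: embedding datastore.TextSplitterOpts in ClientIngest is what exposes the new CLI flags, since the flag and env bindings are derived from the embedded struct's usage and env tags, and Go field promotion makes &s.TextSplitterOpts a ready-made options value. A trimmed-down sketch of the promotion mechanics, with hypothetical types for illustration only:

	package main

	import "fmt"

	type TextSplitterOpts struct { // hypothetical, trimmed-down copy
		ChunkSize int
	}

	type ClientIngest struct {
		TextSplitterOpts // embedded: ChunkSize is promoted onto ClientIngest
	}

	func main() {
		var s ClientIngest
		s.ChunkSize = 2048                        // set through the promoted field
		fmt.Println(s.TextSplitterOpts.ChunkSize) // 2048: same underlying value
	}
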
8 changes: 4 additions & 4 deletions pkg/cmd/server.go
@@ -2,9 +2,9 @@ package cmd

 import (
 	"fmt"
+	"github.com/gptscript-ai/knowledge/pkg/config"
 	"github.com/gptscript-ai/knowledge/pkg/datastore"
 	"github.com/gptscript-ai/knowledge/pkg/server"
-	"github.com/gptscript-ai/knowledge/pkg/types"
 	"github.com/spf13/cobra"
 	"os/signal"
 	"syscall"
@@ -16,9 +16,9 @@ type Server struct {
 	ServerPort    string `usage:"Server port" default:"8000" env:"KNOW_SERVER_PORT"`
 	ServerAPIBase string `usage:"Server API base" default:"/v1" env:"KNOW_SERVER_API_BASE"`

-	types.OpenAIConfig
-	types.DatabaseConfig
-	types.VectorDBConfig
+	config.OpenAIConfig
+	config.DatabaseConfig
+	config.VectorDBConfig
 }

 func (s *Server) Run(cmd *cobra.Command, _ []string) error {
2 changes: 1 addition & 1 deletion pkg/types/config.go → pkg/config/config.go
@@ -1,4 +1,4 @@
-package types
+package config

 type OpenAIConfig struct {
 	APIBase string `usage:"OpenAI API base" default:"https://api.openai.com/v1" env:"OPENAI_BASE_URL"` // clicky-chats
16 changes: 7 additions & 9 deletions pkg/datastore/dataset.go
@@ -4,18 +4,16 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"github.com/acorn-io/z"
 	"github.com/gptscript-ai/knowledge/pkg/index"
-	"github.com/gptscript-ai/knowledge/pkg/types"
 	"github.com/gptscript-ai/knowledge/pkg/types/defaults"
 	"gorm.io/gorm"
 	"log/slog"
 )

-func (s *Datastore) NewDataset(ctx context.Context, dataset types.Dataset) error {
+func (s *Datastore) NewDataset(ctx context.Context, dataset index.Dataset) error {
 	// Set defaults
-	if dataset.EmbedDimension == nil || *dataset.EmbedDimension <= 0 {
-		dataset.EmbedDimension = z.Pointer(defaults.EmbeddingDimension)
+	if dataset.EmbedDimension <= 0 {
+		dataset.EmbedDimension = defaults.EmbeddingDimension
 	}

 	// Create dataset
@@ -36,7 +34,7 @@ func (s *Datastore) NewDataset(ctx context.Context, dataset types.Dataset) error
 func (s *Datastore) DeleteDataset(ctx context.Context, datasetID string) error {
 	// Delete dataset
 	slog.Info("Deleting dataset", "id", datasetID)
-	tx := s.Index.WithContext(ctx).Delete(&types.Dataset{}, "id = ?", datasetID)
+	tx := s.Index.WithContext(ctx).Delete(&index.Dataset{}, "id = ?", datasetID)
 	if tx.Error != nil {
 		return tx.Error
 	}
@@ -63,13 +61,13 @@ func (s *Datastore) GetDataset(ctx context.Context, datasetID string) (*index.Da
 	return dataset, nil
 }

-func (s *Datastore) ListDatasets(ctx context.Context) ([]types.Dataset, error) {
-	tx := s.Index.WithContext(ctx).Find(&[]types.Dataset{})
+func (s *Datastore) ListDatasets(ctx context.Context) ([]index.Dataset, error) {
+	tx := s.Index.WithContext(ctx).Find(&[]index.Dataset{})
 	if tx.Error != nil {
 		return nil, tx.Error
 	}

-	var datasets []types.Dataset
+	var datasets []index.Dataset
 	if err := tx.Scan(&datasets).Error; err != nil {
 		return nil, err
 	}
6 changes: 3 additions & 3 deletions pkg/datastore/datastore.go
@@ -5,8 +5,8 @@ import (
 	"fmt"
 	"github.com/acorn-io/z"
 	"github.com/adrg/xdg"
+	"github.com/gptscript-ai/knowledge/pkg/config"
 	"github.com/gptscript-ai/knowledge/pkg/index"
-	"github.com/gptscript-ai/knowledge/pkg/types"
 	"github.com/gptscript-ai/knowledge/pkg/vectorstore"
 	"github.com/gptscript-ai/knowledge/pkg/vectorstore/chromem"
 	cg "github.com/philippgille/chromem-go"
@@ -41,7 +41,7 @@ func GetDatastorePaths(dsn, vectordbPath string) (string, string, error) {
 	return dsn, vectordbPath, nil
 }

-func NewDatastore(dsn string, automigrate bool, vectorDBPath string, openAIConfig types.OpenAIConfig) (*Datastore, error) {
+func NewDatastore(dsn string, automigrate bool, vectorDBPath string, openAIConfig config.OpenAIConfig) (*Datastore, error) {
 	dsn, vectorDBPath, err := GetDatastorePaths(dsn, vectorDBPath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to determine datastore paths: %w", err)
@@ -80,7 +80,7 @@ func NewDatastore(dsn string, automigrate bool, vectorDBPath string, openAIConfi
 	}

 	if defaultDS == nil {
-		err = ds.NewDataset(context.Background(), types.Dataset{ID: "default", EmbedDimension: nil})
+		err = ds.NewDataset(context.Background(), index.Dataset{ID: "default"})
 		if err != nil {
 			return nil, fmt.Errorf("failed to create default dataset: %w", err)
 		}
22 changes: 14 additions & 8 deletions pkg/datastore/ingest.go
@@ -48,6 +48,7 @@ type IngestOpts struct {
 	FileMetadata        *index.FileMetadata
 	IsDuplicateFuncName string
 	IsDuplicateFunc     IsDuplicateFunc
+	TextSplitterOpts    *TextSplitterOpts
 }

 // Ingest loads a document from a reader and adds it to the dataset.
@@ -118,7 +119,7 @@ func (s *Datastore) Ingest(ctx context.Context, datasetID string, content []byte
 		return nil, nil
 	}

-	docs, err := GetDocuments(ctx, *opts.Filename, filetype, reader)
+	docs, err := GetDocuments(ctx, *opts.Filename, filetype, reader, opts.TextSplitterOpts)
 	if err != nil {
 		slog.Error("Failed to load documents", "error", err)
 		return nil, fmt.Errorf("failed to load documents: %w", err)
@@ -187,7 +188,12 @@ func mimetypeFromReader(reader io.Reader) (string, io.Reader, error) {
 	return mtype.String(), newReader, err
 }

-func GetDocuments(ctx context.Context, filename, filetype string, reader io.Reader) ([]vs.Document, error) {
+func GetDocuments(ctx context.Context, filename, filetype string, reader io.Reader, textSplitterOpts *TextSplitterOpts) ([]vs.Document, error) {
+	if textSplitterOpts == nil {
+		textSplitterOpts = z.Pointer(NewTextSplitterOpts())
+	}
+	lcgoTextSplitter := NewLcgoTextSplitter(*textSplitterOpts)
+
 	/*
 	 * Load documents from the content
 	 * For now, we're using documentloaders from both langchaingo and golc
@@ -227,13 +233,13 @@ func GetDocuments(ctx context.Context, filename, filetype string, reader io.Read
 				Metadata:   rdoc.Metadata,
 			}
 		}
-		lcgodocs, err = lcgosplitter.SplitDocuments(defaultLcgoSplitter, lcgodocs)
+		lcgodocs, err = lcgosplitter.SplitDocuments(lcgoTextSplitter, lcgodocs)
 	case ".html", "text/html":
-		lcgodocs, err = lcgodocloaders.NewHTML(reader).LoadAndSplit(ctx, defaultLcgoSplitter)
+		lcgodocs, err = lcgodocloaders.NewHTML(reader).LoadAndSplit(ctx, lcgoTextSplitter)
 	case ".md", "text/markdown":
-		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, defaultLcgoSplitter)
+		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, lcgoTextSplitter)
 	case ".txt", "text/plain":
-		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, defaultLcgoSplitter)
+		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, lcgoTextSplitter)
 	case ".csv", "text/csv":
 		golcdocs, err = golcdocloaders.NewCSV(reader).Load(ctx)
 		if err != nil && errors.Is(err, csv.ErrBareQuote) {
@@ -248,7 +254,7 @@ func GetDocuments(ctx context.Context, filename, filetype string, reader io.Read
 			}
 		}
 	case ".json", "application/json":
-		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, defaultLcgoSplitter)
+		lcgodocs, err = lcgodocloaders.NewText(reader).LoadAndSplit(ctx, lcgoTextSplitter)
 	case ".ipynb":
 		golcdocs, err = golcdocloaders.NewNotebook(reader).Load(ctx)
 	case ".docx", ".odt", ".rtf", "application/vnd.oasis.opendocument.text", "text/rtf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
@@ -260,7 +266,7 @@ func GetDocuments(ctx context.Context, filename, filetype string, reader io.Read
 		if nerr != nil {
 			return nil, fmt.Errorf("failed to extract text from %s: %w", filetype, nerr)
 		}
-		lcgodocs, err = lcgodocloaders.NewText(strings.NewReader(text)).LoadAndSplit(ctx, defaultLcgoSplitter)
+		lcgodocs, err = lcgodocloaders.NewText(strings.NewReader(text)).LoadAndSplit(ctx, lcgoTextSplitter)
 	default:
 		// TODO(@iwilltry42): Fallback to plaintext reader? Example: Makefile, Dockerfile, Source Files, etc.
 		slog.Error("Unsupported file type", "filename", filename, "type", filetype)
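
Note: existing callers of GetDocuments can pass nil for the new parameter and keep the previous behavior, since nil falls back to NewTextSplitterOpts(). A minimal usage sketch; the file name and path are illustrative:

	package main

	import (
		"context"
		"log"
		"os"

		"github.com/gptscript-ai/knowledge/pkg/datastore"
	)

	func main() {
		f, err := os.Open("README.md") // illustrative input file
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()

		// nil TextSplitterOpts: GetDocuments falls back to the package defaults.
		docs, err := datastore.GetDocuments(context.Background(), "README.md", ".md", f, nil)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("split into %d documents", len(docs))
	}
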
10 changes: 4 additions & 6 deletions pkg/datastore/retrieve.go
@@ -2,20 +2,18 @@ package datastore

 import (
 	"context"
-	"github.com/acorn-io/z"
-	"github.com/gptscript-ai/knowledge/pkg/types"
 	"github.com/gptscript-ai/knowledge/pkg/types/defaults"
 	"github.com/gptscript-ai/knowledge/pkg/vectorstore"
 	"log/slog"
 )

-func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query types.Query) ([]vectorstore.Document, error) {
-	if query.TopK == nil {
-		query.TopK = z.Pointer(defaults.TopK)
+func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string, topk int) ([]vectorstore.Document, error) {
+	if topk <= 0 {
+		topk = defaults.TopK
 	}
 	slog.Debug("Retrieving content from dataset", "dataset", datasetID, "query", query)

-	docs, err := s.Vectorstore.SimilaritySearch(ctx, query.Prompt, *query.TopK, datasetID)
+	docs, err := s.Vectorstore.SimilaritySearch(ctx, query, topk, datasetID)
 	if err != nil {
 		return nil, err
 	}
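
Note: the types.Query struct is gone from this API; callers pass the prompt and TopK directly, and any non-positive TopK now means "use defaults.TopK". For illustration, assuming ds is a *datastore.Datastore and ctx a context.Context:

	// topk <= 0 falls back to defaults.TopK inside Retrieve.
	docs, err := ds.Retrieve(ctx, "default", "how do I tune chunk size?", 0)
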
29 changes: 26 additions & 3 deletions pkg/datastore/textsplitter.go
@@ -2,6 +2,29 @@ package datastore

 import lcgosplitter "github.com/tmc/langchaingo/textsplitter"

-var (
-	defaultLcgoSplitter = lcgosplitter.NewTokenSplitter(lcgosplitter.WithChunkSize(defaultChunkSize), lcgosplitter.WithChunkOverlap(defaultChunkOverlap), lcgosplitter.WithModelName(defaultTokenModel), lcgosplitter.WithEncodingName(defaultTokenEncoding))
-)
+type TextSplitterOpts struct {
+	ChunkSize    int    `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE"`
+	ChunkOverlap int    `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP"`
+	ModelName    string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME"`
+	EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME"`
+}
+
+// NewTextSplitterOpts returns the default options for a text splitter.
+func NewTextSplitterOpts() TextSplitterOpts {
+	return TextSplitterOpts{
+		ChunkSize:    defaultChunkSize,
+		ChunkOverlap: defaultChunkOverlap,
+		ModelName:    defaultTokenModel,
+		EncodingName: defaultTokenEncoding,
+	}
+}
+
+// NewLcgoTextSplitter returns a new langchain-go text splitter.
+func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter {
+	return lcgosplitter.NewTokenSplitter(
+		lcgosplitter.WithChunkSize(opts.ChunkSize),
+		lcgosplitter.WithChunkOverlap(opts.ChunkOverlap),
+		lcgosplitter.WithModelName(opts.ModelName),
+		lcgosplitter.WithEncodingName(opts.EncodingName),
+	)
+}
