This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

modified idf index logic to remove tokenization and edited matching logic for term expansion #64505

Open · wants to merge 4 commits into base: main
9 changes: 9 additions & 0 deletions cmd/frontend/graphqlbackend/repository_reindex.go
@@ -2,9 +2,11 @@ package graphqlbackend

import (
"context"
"fmt"

"github.com/graph-gophers/graphql-go"

"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
"github.com/sourcegraph/sourcegraph/internal/auth"
"github.com/sourcegraph/sourcegraph/internal/search/zoekt"
)
@@ -13,6 +15,9 @@ import (
func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct {
Repository graphql.ID
}) (*EmptyResponse, error) {
// MARK(beyang): this is triggered by the "Reindex now" button on a page like https://sourcegraph.test:3443/github.com/hashicorp/errwrap/-/settings/index
fmt.Printf("# schemaResolver.ReindexRepository\n")

// 🚨 SECURITY: There is no reason why non-site-admins would need to run this operation.
if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.db); err != nil {
return nil, err
@@ -23,6 +28,10 @@ func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct {
return nil, err
}

if err := idf.Update(ctx, repo.RepoName()); err != nil {
return nil, err
}

err = zoekt.Reindex(ctx, repo.RepoName(), repo.IDInt32())
if err != nil {
return nil, err
65 changes: 62 additions & 3 deletions cmd/frontend/internal/codycontext/context.go
@@ -3,15 +3,19 @@ package codycontext
import (
"context"
"fmt"
"sort"
"strings"
"sync"

lg "log"

"github.com/grafana/regexp"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/log"
"go.opentelemetry.io/otel/attribute"

"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
@@ -82,6 +86,7 @@ type CodyContextClient struct {

type GetContextArgs struct {
Repos []types.RepoIDName
RepoStats map[api.RepoName]*idf.StatsProvider
Query string
CodeResultsCount int32
TextResultsCount int32
@@ -138,13 +143,15 @@ func (c *CodyContextClient) GetCodyContext(ctx context.Context, args GetContextA

embeddingsArgs := GetContextArgs{
Repos: embeddingRepos,
RepoStats: args.RepoStats,
Query: args.Query,
CodeResultsCount: int32(float32(args.CodeResultsCount) * embeddingsResultRatio),
TextResultsCount: int32(float32(args.TextResultsCount) * embeddingsResultRatio),
}
keywordArgs := GetContextArgs{
Repos: keywordRepos,
RepoStats: args.RepoStats,
Query: args.Query,
// Assign the remaining result budget to keyword search
CodeResultsCount: args.CodeResultsCount - embeddingsArgs.CodeResultsCount,
TextResultsCount: args.TextResultsCount - embeddingsArgs.TextResultsCount,
@@ -277,7 +284,11 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte
// mini-HACK: pass in the scope using repo: filters. In an ideal world, we
// would not be using query text manipulation for this and would be using
// the job structs directly.
var maxTermsPerWord = 5
transformedQuery := getTransformedQuery(args, maxTermsPerWord)
lg.Printf("# userQuery -> transformedQuery: %q -> %q", args.Query, transformedQuery)
fmt.Printf("# userQuery -> transformedQuery: %q -> %q\n", args.Query, transformedQuery)
keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), transformedQuery)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -371,3 +382,51 @@ func fileMatchToContextMatch(fm *result.FileMatch) FileChunkContext {
StartLine: startLine,
}
}

// getTransformedQuery expands each sufficiently long word of the user query
// into the highest-IDF vocabulary terms that contain it, falling back to the
// raw query when any repo lacks an IDF table.
func getTransformedQuery(args GetContextArgs, maxTermsPerWord int) string {
if args.RepoStats == nil {
lg.Printf("# no stats set")
return args.Query
}

for _, repo := range args.Repos {
if _, ok := args.RepoStats[repo.Name]; !ok {
// Don't transform query if one of the repositories lacks an IDF table
lg.Printf("# didn't find stats for repo %s", repo.Name)
return args.Query
}
}

// TODO(rishabh): currently we just pick the top-k vocabulary terms by IDF score; a better semantic ranking of terms is possible.
// Matching is also fairly limited (plain substring matching); stemming/lemmatization could be considered.

var filteredToks []string

type termScore struct {
term string
score float32
}

for _, word := range strings.Fields(args.Query) {
if len(word) < 4 {
continue
}
var matches []termScore
for _, stats := range args.RepoStats {
for term, score := range stats.GetTerms() {
if strings.Contains(term, word) && len(term) > 4 && score > 3 {
matches = append(matches, termScore{term: term, score: score})
}
}
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].score > matches[j].score
})
for i := 0; i < min(maxTermsPerWord, len(matches)); i++ {
filteredToks = append(filteredToks, matches[i].term)
}
}

return strings.Join(filteredToks, " ")
}
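
To make the new term-expansion logic concrete, here is a small, self-contained sketch of the same matching applied to toy data. The vocabulary and IDF scores are invented for illustration, and the sketch assumes Go 1.21+ for the built-in min:

package main

import (
	"fmt"
	"sort"
	"strings"
)

// Toy IDF table standing in for one repo's StatsProvider.GetTerms().
var toyTerms = map[string]float32{
	"errwrap":    6.1,
	"errwrapf":   5.2,
	"wrapped":    4.8, // contains no query word below
	"wrap":       2.0, // filtered out: len(term) <= 4
	"parseError": 3.5,
	"error":      1.1, // filtered out: score <= 3
}

func main() {
	const maxTermsPerWord = 5
	var expanded []string
	for _, word := range strings.Fields("errwrap parse") {
		if len(word) < 4 {
			continue
		}
		type termScore struct {
			term  string
			score float32
		}
		var matches []termScore
		for term, score := range toyTerms {
			// Same filter as getTransformedQuery: substring match,
			// term longer than 4 bytes, IDF score above 3.
			if strings.Contains(term, word) && len(term) > 4 && score > 3 {
				matches = append(matches, termScore{term, score})
			}
		}
		sort.Slice(matches, func(i, j int) bool { return matches[i].score > matches[j].score })
		for i := 0; i < min(maxTermsPerWord, len(matches)); i++ {
			expanded = append(expanded, matches[i].term)
		}
	}
	// Prints "errwrap errwrapf parseError": "errwrap" expands to the two
	// high-IDF terms containing it, and "parse" to "parseError".
	fmt.Println(strings.Join(expanded, " "))
}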
12 changes: 12 additions & 0 deletions cmd/frontend/internal/context/resolvers/context.go
@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"io"
lg "log"
"net/http"
"time"

@@ -13,9 +14,11 @@
"github.com/sourcegraph/conc/iter"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/log"

"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext"
"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
@@ -183,6 +186,7 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
}

repoNameIDs := make([]types.RepoIDName, len(repoIDs))
repoStats := make(map[api.RepoName]*idf.StatsProvider)
for i, repoID := range repoIDs {
repo, ok := repos[repoID]
if !ok {
@@ -191,10 +195,18 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
}

repoNameIDs[i] = types.RepoIDName{ID: repoID, Name: repo.Name}

stats, err := idf.Get(ctx, repo.Name)
if err != nil {
lg.Printf("Unexpected error getting IDF index for repo %v: %v", repoID, err)
continue
}
if stats == nil {
// No IDF table cached for this repo; leave it out of the map so that
// getTransformedQuery falls back to the raw query instead of
// dereferencing a nil *StatsProvider later.
continue
}
repoStats[repo.Name] = stats
}

fileChunks, err := r.contextClient.GetCodyContext(ctx, codycontext.GetContextArgs{
Repos: repoNameIDs,
RepoStats: repoStats,
Query: args.Query,
CodeResultsCount: args.CodeResultsCount,
TextResultsCount: args.TextResultsCount,
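A note for reviewers on the lookup contract the loop above relies on: idf.Get (in idf.go below) reports a cache miss as (nil, nil) rather than as an error, so callers have to treat a nil *StatsProvider as "no table yet". A minimal illustration of that calling convention; the helper name hasIDF is hypothetical:

// hasIDF reports whether a repo has a cached IDF table. It illustrates the
// (nil, nil) cache-miss contract of idf.Get: a non-nil error means the lookup
// itself failed, while a nil provider just means no table has been built yet.
func hasIDF(ctx context.Context, name api.RepoName) bool {
	stats, err := idf.Get(ctx, name)
	return err == nil && stats != nil
}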
164 changes: 164 additions & 0 deletions cmd/frontend/internal/search/idf/idf.go
@@ -0,0 +1,164 @@
// Package idf computes and stores the inverse document frequency (IDF) of a set of repositories.
//
// TODO(beyang): should probably move this elsewhere
package idf

import (
"archive/tar"
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"math"
"path"
"strings"
"unicode"

"github.com/sourcegraph/sourcegraph/internal/api"
"github.com/sourcegraph/sourcegraph/internal/gitserver"
"github.com/sourcegraph/sourcegraph/internal/rcache"
"github.com/sourcegraph/sourcegraph/internal/redispool"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

var redisCache = rcache.NewWithTTL(redispool.Cache, "idf-index", 10*24*60*60) // 10-day TTL in seconds

// Update builds an IDF table for repoName by scanning a tar archive of HEAD
// and stores it in the Redis cache.
func Update(ctx context.Context, repoName api.RepoName) error {
fmt.Printf("# idf.Update(%v)\n", repoName)

stats := NewStatsAggregator()

git := gitserver.NewClient("idf-indexer")
r, err := git.ArchiveReader(ctx, repoName, gitserver.ArchiveOptions{Treeish: "HEAD", Format: gitserver.ArchiveFormatTar})
if err != nil {
return errors.Wrap(err, "idf.Update: failed to open git archive reader")
}

permissibleExtensions := map[string]bool{
".py": true, ".js": true, ".ts": true, ".java": true, ".cpp": true,
".c": true, ".cs": true, ".go": true, ".rb": true, ".rs": true,
".php": true, ".html": true, ".css": true, ".scss": true, ".md": true,
".sh": true, ".swift": true, ".kt": true, ".m": true,
}

tr := tar.NewReader(r)
for {
header, err := tr.Next()
if err == io.EOF {
break // End of archive
}
if err != nil {
log.Printf("Error reading next tar header: %v", err)
continue
}

// Skip directories
if header.Typeflag == tar.TypeDir {
continue
}

// Check if the file has a permissible extension
ext := strings.ToLower(path.Ext(header.Name))

if !permissibleExtensions[ext] {
continue
}

// Read the first line of the file
scanner := bufio.NewScanner(tr)
if scanner.Scan() {
stats.ProcessDoc(scanner.Text())
} else if err := scanner.Err(); err != nil {
log.Printf("Error reading file content: %v", err)
}
}

statsP := stats.EvalProvider()
statsBytes, err := json.Marshal(statsP)
if err != nil {
return errors.Wrap(err, "idf.Update: failed to marshal IDF table")
}

log.Printf("# storing stats: %s", string(statsBytes))

redisCache.Set(fmt.Sprintf("repo:%v", repoName), statsBytes)
return nil
}

// Get returns the cached IDF table for repoName. A cache miss returns
// (nil, nil), so callers must check for a nil *StatsProvider.
func Get(ctx context.Context, repoName api.RepoName) (*StatsProvider, error) {
fmt.Printf("# idf.Get(%v)\n", repoName)
b, ok := redisCache.Get(fmt.Sprintf("repo:%v", repoName))
if !ok {
return nil, nil
}

var stats StatsProvider
if err := json.Unmarshal(b, &stats); err != nil {
return nil, errors.Wrap(err, "idf.Get: failed to unmarshal IDF table")
}

log.Printf("# fetching stats: %v", stats)

return &stats, nil
}

type StatsAggregator struct {
TermToDocCt map[string]int
DocCt int
}

func NewStatsAggregator() *StatsAggregator {
return &StatsAggregator{
TermToDocCt: make(map[string]int),
}
}

func isValidWord(word string) bool {
if len(word) < 3 || len(word) > 20 {
return false
}
hasLetter := false
for _, char := range word {
if !unicode.IsLetter(char) && !unicode.IsNumber(char) {
return false
}
if unicode.IsLetter(char) {
hasLetter = true
}
}
return hasLetter
}

func (s *StatsAggregator) ProcessDoc(text string) {
words := strings.Fields(text)
for _, word := range words {
// word = strings.ToLower(word)
if isValidWord(word) {
s.TermToDocCt[word]++
}
}
s.DocCt++
}

func (s *StatsAggregator) EvalProvider() StatsProvider {
idf := make(map[string]float32)
for term, docCt := range s.TermToDocCt {
idf[term] = float32(math.Log(float64(s.DocCt) / (1.0 + float64(docCt))))
}
return StatsProvider{IDF: idf}
}

type StatsProvider struct {
IDF map[string]float32
}

func (s *StatsProvider) GetIDF(term string) float32 {
// Terms are stored case-sensitively (ProcessDoc does not lowercase),
// so look them up as-is.
return s.IDF[term]
}

func (s *StatsProvider) GetTerms() map[string]float32 {
return s.IDF
}
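
As a sanity check on EvalProvider's formula, idf(t) = log(DocCt / (1 + docCt(t))): with 100 documents, a term appearing in 9 of them scores log(100/10) ≈ 2.30, while a term appearing in 99 scores log(100/100) = 0. A minimal sketch with invented counts:

package main

import (
	"fmt"
	"math"
)

func main() {
	docCt := 100 // total documents processed
	for _, termDocCt := range []int{9, 99} {
		idf := math.Log(float64(docCt) / (1.0 + float64(termDocCt)))
		fmt.Printf("term in %d/%d docs -> idf %.2f\n", termDocCt, docCt, idf)
	}
	// term in 9/100 docs -> idf 2.30
	// term in 99/100 docs -> idf 0.00
}

Note that the score > 3 filter in getTransformedQuery admits only terms with log(N/(1+df)) > 3, i.e. terms appearing in fewer than roughly N/20 of the scanned files.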