From fde5ec990c67eba548f5878b59480aebf117a420 Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Wed, 29 May 2024 14:04:02 +0100 Subject: [PATCH] lsp: Provide data ref completions in rules Fixes: https://github.com/StyraInc/regal/issues/753 Signed-off-by: Charlie Egan --- internal/lsp/completions/manager.go | 2 +- .../lsp/completions/providers/packagerefs.go | 54 +---- .../lsp/completions/providers/rulerefs.go | 188 +++++++++++------- .../completions/providers/rulerefs_test.go | 28 ++- internal/lsp/completions/providers/utils.go | 36 +++- 5 files changed, 183 insertions(+), 125 deletions(-) diff --git a/internal/lsp/completions/manager.go b/internal/lsp/completions/manager.go index 2cb6d80f..2747f2eb 100644 --- a/internal/lsp/completions/manager.go +++ b/internal/lsp/completions/manager.go @@ -34,7 +34,7 @@ func NewDefaultManager(c *cache.Cache) *Manager { m.RegisterProvider(&providers.BuiltIns{}) m.RegisterProvider(&providers.RegoV1{}) m.RegisterProvider(&providers.PackageRefs{}) - m.RegisterProvider(&providers.RuleFromImportedPackageRefs{}) + m.RegisterProvider(&providers.RuleRefs{}) return m } diff --git a/internal/lsp/completions/providers/packagerefs.go b/internal/lsp/completions/providers/packagerefs.go index ea5c91d7..f61c322b 100644 --- a/internal/lsp/completions/providers/packagerefs.go +++ b/internal/lsp/completions/providers/packagerefs.go @@ -1,7 +1,6 @@ package providers import ( - "slices" "strings" "github.com/styrainc/regal/internal/lsp/cache" @@ -29,7 +28,7 @@ func (*PackageRefs) Run(c *cache.Cache, params types.CompletionParams, _ *Option lastWord := words[len(words)-1] thisFileReferences := c.GetFileRefs(fileURI) - otherFilePackages := make(map[string]types.Ref) + refsForContext := make(map[string]types.Ref) // filter out the packages that have the last word as a prefix. for file, refs := range c.GetAllFileRefs() { @@ -54,24 +53,22 @@ func (*PackageRefs) Run(c *cache.Cache, params types.CompletionParams, _ *Option continue } - otherFilePackages[key] = ref + refsForContext[key] = ref } } - // partialRefs is a generated list of package names generated from - // longer names. For example, if the packages data.foo.bar and data.foo.baz + // refsForContext is now supplemented with a generated list of package names + // from longer packages. For example, if the packages data.foo.bar and data.foo.baz // are defined, an author should still be able to import data.foo. - partialRefs := make(map[string]types.Ref) - - for key := range otherFilePackages { + for key := range refsForContext { parts := strings.Split(key, ".") // starting at 1 to skip 'data' for i := 1; i < len(parts)-1; i++ { partialKey := strings.Join(parts[:i+1], ".") // only insert the new partial key if there is no full package // ref that matches it - if _, ok := partialRefs[partialKey]; !ok { - partialRefs[partialKey] = types.Ref{ + if _, ok := refsForContext[partialKey]; !ok { + refsForContext[partialKey] = types.Ref{ Label: partialKey, Description: "See sub packages for more information", } @@ -82,42 +79,7 @@ func (*PackageRefs) Run(c *cache.Cache, params types.CompletionParams, _ *Option // refs are grouped by 'depth', where depth is the number of dots in the // ref string. This is a simplification, but allows shorted, higher level // refs to be suggested first. - byDepth := make(map[int]map[string]types.Ref) - - for _, item := range otherFilePackages { - depth := strings.Count(item.Label, ".") - - if _, ok := byDepth[depth]; !ok { - byDepth[depth] = make(map[string]types.Ref) - } - - byDepth[depth][item.Label] = item - } - - // add partial refs to the byDepth map in case they are not defined - // as full refs in files. - for _, item := range partialRefs { - depth := strings.Count(item.Label, ".") - - if _, ok := byDepth[depth]; !ok { - byDepth[depth] = make(map[string]types.Ref) - } - - // only add partial refs where no top level ref exists. - if _, ok := byDepth[depth][item.Label]; ok { - continue - } - - byDepth[depth][item.Label] = item - } - - // items will be shown in order from shallowest to deepest - depths := make([]int, 0) - for k := range byDepth { - depths = append(depths, k) - } - - slices.Sort(depths) + depths, byDepth := groupKeyedRefsByDepth(refsForContext) items := make([]types.CompletionItem, 0) for _, depth := range depths { diff --git a/internal/lsp/completions/providers/rulerefs.go b/internal/lsp/completions/providers/rulerefs.go index b255dff0..d7475745 100644 --- a/internal/lsp/completions/providers/rulerefs.go +++ b/internal/lsp/completions/providers/rulerefs.go @@ -8,11 +8,14 @@ import ( "github.com/styrainc/regal/internal/lsp/types/completion" ) -// RuleFromImportedPackageRefs is a completion provider that returns completions for -// rules found in already imported packages. -type RuleFromImportedPackageRefs struct{} - -func (*RuleFromImportedPackageRefs) Run( +// RuleRefs is a completion provider that returns completions for +// rules found in: +// - the current file +// - imported packages +// - any other files in the workspace under data. +type RuleRefs struct{} + +func (*RuleRefs) Run( c *cache.Cache, params types.CompletionParams, _ *Options, @@ -33,6 +36,9 @@ func (*RuleFromImportedPackageRefs) Run( return nil, nil } + words := patternWhiteSpace.Split(currentLine, -1) + lastWord := words[len(words)-1] + // some version of a parsed mod is needed here to filter refs to suggest // based on import statements mod, ok := c.GetModule(fileURI) @@ -40,101 +46,149 @@ func (*RuleFromImportedPackageRefs) Run( return nil, nil } - refsFromImports := make(map[string]types.Ref) + refsForContext := make(map[string]types.Ref) - for file, refs := range c.GetAllFileRefs() { - if file == fileURI { - continue - } - - for key, ref := range refs { - // we are not interested in packages here, only the rules + for uri, refs := range c.GetAllFileRefs() { + for _, ref := range refs { + // we are not interested in packages here, only rules if ref.Kind == types.Package { continue } - // don't suggest refs of "private" rules, even - // if only just by naming convention + // don't suggest refs of "private" rules, even if this + // is only just a naming convention if strings.Contains(ref.Label, "._") { continue } - isFromImportedPackage := false - + // for refs from imported packages, we need to strip the start of the + // package string, e.g. data.foo.bar -> bar + key := ref.Label for _, i := range mod.Imports { - if strings.HasPrefix(key, i.Path.String()) { - isFromImportedPackage = true + if k := attemptToStripImportPrefix(key, i.Path.String()); k != "" { + key = k break } } - if !isFromImportedPackage { + // suggest rules from the current file without any package prefix + if uri == fileURI { + parts := strings.Split(key, ".") + key = parts[len(parts)-1] + } + + // only suggest refs that match the last word the user has typed. + if !strings.HasPrefix(key, lastWord) { continue } - refsFromImports[key] = ref + refsForContext[key] = ref + } + } + + // Generate a list of package names from longer rule names. + // For example, if the rules data.foo.bar and data.foo.baz + // are defined, an author should still see data.foo suggested + // as a partial path leading to the rules bar and baz. + for key := range refsForContext { + parts := strings.Split(key, ".") + // starting at 1 to skip 'data' + for i := 1; i < len(parts)-1; i++ { + partialKey := strings.Join(parts[:i+1], ".") + // only insert the new partial key if there is no full package + // ref that matches it + if _, ok := refsForContext[partialKey]; !ok { + refsForContext[partialKey] = types.Ref{ + Label: partialKey, + Description: "Partial", + } + } } } + // refs are grouped by 'depth', where depth is the number of dots in the + // ref string. This is a simplification, but allows shorter, higher level + // refs to be suggested first. + depths, byDepth := groupKeyedRefsByDepth(refsForContext) + items := make([]types.CompletionItem, 0) + for _, depth := range depths { + // items are added in groups of depth until there more then 10 items. + if len(items) > 10 { + continue + } - for _, ref := range refsFromImports { - symbol := completion.Variable - detail := "Rule" - - switch { - case ref.Kind == types.ConstantRule: - symbol = completion.Constant - detail = "Constant Rule" - case ref.Kind == types.Function: - symbol = completion.Function - detail = "Function" + itemsForDepth, ok := byDepth[depth] + if !ok { + continue } - packageAndRule := labelToPackageAndRule(ref.Label) - - startChar := params.Position.Character - - uint(len(strings.Split(currentLine, " ")[len(strings.Split(currentLine, " "))-1])) - - items = append(items, types.CompletionItem{ - Label: packageAndRule, - Kind: symbol, - Detail: detail, - Documentation: &types.MarkupContent{ - Kind: "markdown", - Value: ref.Description, - }, - TextEdit: &types.TextEdit{ - Range: types.Range{ - Start: types.Position{ - Line: params.Position.Line, - Character: startChar, - }, - End: types.Position{ - Line: params.Position.Line, - Character: uint(len(currentLine)), + for key, ref := range itemsForDepth { + symbol := completion.Variable + detail := "Rule" + + switch { + case ref.Kind == types.ConstantRule: + symbol = completion.Constant + detail = "Constant Rule" + case ref.Kind == types.Function: + symbol = completion.Function + detail = "Function" + case ref.Description == "Partial": + detail = "Partial path suggestion, continue typing for more suggestions." + symbol = completion.Module + ref.Description = "" + } + + items = append(items, types.CompletionItem{ + Label: key, + Kind: symbol, + Detail: detail, + Documentation: &types.MarkupContent{ + Kind: "markdown", + Value: ref.Description, + }, + TextEdit: &types.TextEdit{ + Range: types.Range{ + Start: types.Position{ + Line: params.Position.Line, + Character: params.Position.Character - uint(len(lastWord)), + }, + End: types.Position{ + Line: params.Position.Line, + Character: uint(len(currentLine)), + }, }, + NewText: key, }, - NewText: packageAndRule, - }, - }) + }) + } } return items, nil } -func labelToPackageAndRule(label string) string { - parts := strings.Split(label, ".") - partCount := len(parts) +func attemptToStripImportPrefix(key, importKey string) string { + // we can only strip the import prefix if the rule key starts with the + // import key + if !strings.HasPrefix(key, importKey) { + return "" + } + + importKeyParts := strings.Split(importKey, ".") - // a ref should be at least three parts, data.package.rule_from_package - // if it is not, then we can't provide a valid new text and return the - // full label as a fallback - if partCount < 3 { - return label + // if for some reason we have a key shorter than 'data.foo', we don't know + // how to handle this + if len(importKeyParts) < 2 { + return "" } - // take the last two parts of the ref, package and rule - return parts[partCount-2] + "." + parts[partCount-1] + // strippablePrefix is all but the last part of the module path + // for example 'data.foo.bar.baz' -> 'data.foo.bar.'. + // This is what the author would need to type to access rules + // from the imported package. + strippablePrefix := strings.Join(importKeyParts[0:len(importKeyParts)-1], ".") + "." + + return strings.TrimPrefix(key, strippablePrefix) } diff --git a/internal/lsp/completions/providers/rulerefs_test.go b/internal/lsp/completions/providers/rulerefs_test.go index 1b58fee0..3a187453 100644 --- a/internal/lsp/completions/providers/rulerefs_test.go +++ b/internal/lsp/completions/providers/rulerefs_test.go @@ -2,6 +2,7 @@ package providers import ( "slices" + "strings" "testing" "github.com/styrainc/regal/internal/lsp/cache" @@ -17,9 +18,11 @@ func TestRuleFromImportedPackageRefs(t *testing.T) { currentlyEditingFileContents := `package example - import data.foo - import data.bar - import data.baz` +import data.foo +import data.bar +import data.baz + +local_rule := true` regoFiles := map[string]string{ "file:///foo/foo.rego": `package foo @@ -60,14 +63,14 @@ deny := false c.SetFileContents("file:///example.rego", currentlyEditingFileContents+"\n\nallow if ") - p := &RuleFromImportedPackageRefs{} + p := &RuleRefs{} completionParams := types.CompletionParams{ TextDocument: types.TextDocumentIdentifier{ URI: "file:///example.rego", }, Position: types.Position{ - Line: 6, + Line: 8, Character: 8, }, } @@ -77,7 +80,14 @@ deny := false t.Fatalf("Unexpected error: %v", err) } - expectedRefs := []string{"foo.bar", "bar.allow", "baz.funkyfunc"} + expectedRefs := []string{ + "data.notimported", // 'partial', based on data.notimported.deny + "data.notimported.deny", + "foo.bar", + "bar.allow", + "baz.funkyfunc", + "local_rule", + } slices.Sort(expectedRefs) foundRefs := make([]string, len(completions)) @@ -89,6 +99,10 @@ deny := false slices.Sort(foundRefs) if !slices.Equal(expectedRefs, foundRefs) { - t.Fatalf("Expected completions to be %v, got: %v", expectedRefs, foundRefs) + t.Fatalf( + "Expected completions to be\n%s\ngot:\n%s", + strings.Join(expectedRefs, "\n"), + strings.Join(foundRefs, "\n"), + ) } } diff --git a/internal/lsp/completions/providers/utils.go b/internal/lsp/completions/providers/utils.go index bc54a60f..944e94a8 100644 --- a/internal/lsp/completions/providers/utils.go +++ b/internal/lsp/completions/providers/utils.go @@ -2,11 +2,19 @@ package providers import ( "regexp" + "slices" "strings" "github.com/styrainc/regal/internal/lsp/cache" + "github.com/styrainc/regal/internal/lsp/types" ) +//nolint:gochecknoglobals +var patternRuleBody = regexp.MustCompile(`^\s+`) + +//nolint:gochecknoglobals +var patternWhiteSpace = regexp.MustCompile(`\s+`) + // completionLineHelper returns the lines of a file and the current line for a given index. This // function is used by multiple completion providers. func completionLineHelper(c *cache.Cache, fileURI string, currentLineNumber uint) ([]string, string) { @@ -25,8 +33,28 @@ func completionLineHelper(c *cache.Cache, fileURI string, currentLineNumber uint return strings.Split(fileContents, "\n"), currentLine } -//nolint:gochecknoglobals -var patternRuleBody = regexp.MustCompile(`^\s+`) +// groupKeyedRefsByDepth groups refs by their 'depth', where depth is the number of dots in the key. +// This is helpful when attempting to show shorter, higher level keys before longer, lower level keys. +func groupKeyedRefsByDepth(refs map[string]types.Ref) ([]int, map[int]map[string]types.Ref) { + byDepth := make(map[int]map[string]types.Ref) -//nolint:gochecknoglobals -var patternWhiteSpace = regexp.MustCompile(`\s+`) + for key, item := range refs { + depth := strings.Count(key, ".") + + if _, ok := byDepth[depth]; !ok { + byDepth[depth] = make(map[string]types.Ref) + } + + byDepth[depth][key] = item + } + + depths := make([]int, 0) + for k := range byDepth { + depths = append(depths, k) + } + + // items from higher depths should be shown first + slices.Sort(depths) + + return depths, byDepth +}