Skip to content

Commit

Permalink
lsp: Workspace eval, return rule head locations (#985)
Browse files Browse the repository at this point in the history
* lsp: Workspace eval, return rule head locations

This allows for more reliable highlighting in the UI.

* Address PR comments

Correct rule_heads name, tidy rego test

Signed-off-by: Charlie Egan <charlie@styra.com>

* lsp: Fix linter

Signed-off-by: Charlie Egan <charlie@styra.com>

---------

Signed-off-by: Charlie Egan <charlie@styra.com>
  • Loading branch information
charlieegan3 authored Aug 13, 2024
1 parent fc0dc04 commit c5aa188
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 11 deletions.
22 changes: 22 additions & 0 deletions bundle/regal/ast/rule_heads.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package regal.ast

import rego.v1

# METADATA
# description: |
# For a given rule head name, this rule contains a list of locations where
# there is a rule head with that name.
rule_head_locations[name] contains info if {
some rule in input.rules

name := concat(".", [
"data",
package_name,
ref_static_to_string(rule.head.ref),
])

info := {
"row": rule.head.location.row,
"col": rule.head.location.col,
}
}
35 changes: 35 additions & 0 deletions bundle/regal/ast/rule_heads_test.rego
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package regal.ast_test

import rego.v1

import data.regal.ast

test_rule_head_locations if {
policy := `package policy
import rego.v1
default allow := false
allow if true
reasons contains "foo"
reasons contains "bar"
default my_func(_) := false
my_func(1) := true
ref_rule[foo] := true if {
some foo in [1,2,3]
}
`

result := ast.rule_head_locations with input as regal.parse_module("p.rego", policy)

result == {
"data.policy.allow": {{"col": 9, "row": 5}, {"col": 1, "row": 7}},
"data.policy.reasons": {{"col": 1, "row": 9}, {"col": 1, "row": 10}},
"data.policy.my_func": {{"col": 9, "row": 12}, {"col": 1, "row": 13}},
"data.policy.ref_rule": {{"col": 1, "row": 15}},
}
}
75 changes: 64 additions & 11 deletions internal/lsp/rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type KeywordUse struct {
Location KeywordUseLocation `json:"location"`
}

type RuleHeads map[string][]*ast.Location

type KeywordUseLocation struct {
Row uint `json:"row"`
Col uint `json:"col"`
Expand Down Expand Up @@ -94,39 +96,54 @@ func AllBuiltinCalls(module *ast.Module) []BuiltInCall {
}

//nolint:gochecknoglobals
var keywordPreparedQuery *rego.PreparedEvalQuery
var keywordsPreparedQuery *rego.PreparedEvalQuery

//nolint:gochecknoglobals
var ruleHeadLocationsPreparedQuery *rego.PreparedEvalQuery

//nolint:gochecknoglobals
var keywordPreparedQueryInitOnce sync.Once
var preparedQueriesInitOnce sync.Once

func initialize() {
regalRules := rio.MustLoadRegalBundleFS(rbundle.Bundle)

regoArgs := []func(*rego.Rego){
rego.ParsedBundle("regal", &regalRules),
rego.Query("data.regal.ast.keywords"),
rego.Function2(builtins.RegalParseModuleMeta, builtins.RegalParseModule),
rego.Function1(builtins.RegalLastMeta, builtins.RegalLast),
createArgs := func(args ...func(*rego.Rego)) []func(*rego.Rego) {
return append([]func(*rego.Rego){
rego.ParsedBundle("regal", &regalRules),
rego.Function2(builtins.RegalParseModuleMeta, builtins.RegalParseModule),
rego.Function1(builtins.RegalLastMeta, builtins.RegalLast),
}, args...)
}

keywordRegoArgs := createArgs(rego.Query("data.regal.ast.keywords"))

kwpq, err := rego.New(keywordRegoArgs...).PrepareForEval(context.Background())
if err != nil {
panic(err)
}

preparedQuery, err := rego.New(regoArgs...).PrepareForEval(context.Background())
keywordsPreparedQuery = &kwpq

ruleHeadLocationsRegoArgs := createArgs(rego.Query("data.regal.ast.rule_head_locations"))

rhlpq, err := rego.New(ruleHeadLocationsRegoArgs...).PrepareForEval(context.Background())
if err != nil {
panic(err)
}

keywordPreparedQuery = &preparedQuery
ruleHeadLocationsPreparedQuery = &rhlpq
}

// AllKeywords returns all keywords in the module.
func AllKeywords(ctx context.Context, fileName, contents string, module *ast.Module) (map[string][]KeywordUse, error) {
keywordPreparedQueryInitOnce.Do(initialize)
preparedQueriesInitOnce.Do(initialize)

enhancedInput, err := parse.PrepareAST(fileName, contents, module)
if err != nil {
return nil, fmt.Errorf("failed enhancing input: %w", err)
}

rs, err := keywordPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput))
rs, err := keywordsPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput))
if err != nil {
return nil, fmt.Errorf("failed evaluating keywords: %w", err)
}
Expand All @@ -149,6 +166,42 @@ func AllKeywords(ctx context.Context, fileName, contents string, module *ast.Mod
return result, nil
}

// AllRuleHeadLocations returns mapping of rules names to the head locations.
func AllRuleHeadLocations(ctx context.Context, fileName, contents string, module *ast.Module) (RuleHeads, error) {
preparedQueriesInitOnce.Do(initialize)

enhancedInput, err := parse.PrepareAST(fileName, contents, module)
if err != nil {
return nil, fmt.Errorf("failed enhancing input: %w", err)
}

rs, err := ruleHeadLocationsPreparedQuery.Eval(ctx, rego.EvalInput(enhancedInput))
if err != nil {
return nil, fmt.Errorf("failed evaluating keywords: %w", err)
}

if len(rs) == 0 {
return nil, errors.New("no results returned from evaluation")
}

if len(rs) != 1 {
return nil, errors.New("expected exactly one result from evaluation")
}

if len(rs[0].Expressions) != 1 {
return nil, errors.New("expected exactly one expression in result")
}

var result RuleHeads

err = rio.JSONRoundTrip(rs[0].Expressions[0].Value, &result)
if err != nil {
return nil, fmt.Errorf("failed unmarshaling keywords: %w", err)
}

return result, nil
}

// ToInput prepares a module with Regal additions to be used as input for evaluation.
func ToInput(
fileURI string,
Expand Down
35 changes: 35 additions & 0 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/styrainc/regal/internal/lsp/examples"
"github.com/styrainc/regal/internal/lsp/hover"
"github.com/styrainc/regal/internal/lsp/opa/oracle"
"github.com/styrainc/regal/internal/lsp/rego"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/uri"
rparse "github.com/styrainc/regal/internal/parse"
Expand Down Expand Up @@ -478,6 +479,30 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) {
break
}

currentModule, ok := l.cache.GetModule(file)
if !ok {
l.logError(fmt.Errorf("failed to get module for file %q", file))

break
}

currentContents, ok := l.cache.GetFileContents(file)
if !ok {
l.logError(fmt.Errorf("failed to get contents for file %q", file))

break
}

allRuleHeadLocations, err := rego.AllRuleHeadLocations(ctx, filepath.Base(file), currentContents, currentModule)
if err != nil {
l.logError(fmt.Errorf("failed to get rule head locations: %w", err))

break
}

// if there are none, then it's a package evaluation
ruleHeadLocations := allRuleHeadLocations[path]

workspacePath := uri.ToPath(l.clientIdentifier, l.workspaceRootURI)
input := FindInput(uri.ToPath(l.clientIdentifier, file), workspacePath)

Expand All @@ -498,10 +523,20 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) {
break
}

target := "package"
if len(ruleHeadLocations) > 0 {
target = strings.TrimPrefix(path, currentModule.Package.Path.String()+".")
}

if l.clientIdentifier == clients.IdentifierVSCode {
responseParams := map[string]any{
"result": result,
"line": line,
"target": target,
// only used when the target is 'package'
"package": strings.TrimPrefix(currentModule.Package.Path.String(), "data."),
// only used when the target is a rule
"rule_head_locations": ruleHeadLocations,
}

responseResult := map[string]any{}
Expand Down

0 comments on commit c5aa188

Please sign in to comment.