Skip to content

Commit

Permalink
import hints and semantic token improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
kralicky committed Oct 27, 2023
1 parent 3cf5f70 commit fa105bb
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 25 deletions.
116 changes: 115 additions & 1 deletion pkg/lsp/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ func (c *Cache) ComputeInlayHints(doc protocol.TextDocumentIdentifier, rng proto

hints := []protocol.InlayHint{}
hints = append(hints, c.computeMessageLiteralHints(doc, rng)...)
hints = append(hints, c.computeImportHints(doc, rng)...)
return hints, nil
}

Expand Down Expand Up @@ -1031,6 +1032,81 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier,
// return hints
}

// computeImportHints returns inlay hints for import statements whose declared
// path resolves to a different file path on disk, displaying the resolved path
// inline after the import string. The rng parameter is currently unused; all
// imports in the file are considered.
func (c *Cache) computeImportHints(doc protocol.TextDocumentIdentifier, rng protocol.Range) []protocol.InlayHint {
	// show inlay hints for imports that resolve to different paths
	var hints []protocol.InlayHint
	c.resultsMu.RLock()
	defer c.resultsMu.RUnlock()

	res, err := c.FindParseResultByURI(doc.URI.SpanURI())
	if err != nil {
		return nil
	}
	resAst := res.AST()
	if resAst == nil {
		return nil
	}
	// collect the import statements in source order
	var imports []*ast.ImportNode
	for _, decl := range resAst.Decls {
		if imp, ok := decl.(*ast.ImportNode); ok {
			imports = append(imports, imp)
		}
	}

	// If the AST doesn't contain "google/protobuf/descriptor.proto" but the
	// file descriptor lists it as a dependency (it can be added implicitly),
	// drop the first such entry so imports and dependencies line up by index.
	hasDescriptorImport := false
	for _, imp := range imports {
		if imp.Name.AsString() == "google/protobuf/descriptor.proto" {
			hasDescriptorImport = true
			break
		}
	}
	// Filter into a copy: the Dependency slice is owned by the
	// FileDescriptorProto and must not be mutated in place.
	deps := res.FileDescriptorProto().Dependency
	dependencyPaths := make([]string, 0, len(deps))
	skippedImplicit := false
	for _, dep := range deps {
		if !hasDescriptorImport && !skippedImplicit && dep == "google/protobuf/descriptor.proto" {
			skippedImplicit = true
			continue
		}
		dependencyPaths = append(dependencyPaths, dep)
	}

	for i, imp := range imports {
		if i >= len(dependencyPaths) {
			// Defensive: if the import list and dependency list are out of
			// sync (e.g. partial parse results), don't index out of range.
			break
		}
		importPath := imp.Name.AsString()
		resolvedPath := dependencyPaths[i]
		if resolvedPath == importPath {
			continue
		}
		nameInfo := resAst.NodeInfo(imp.Name)
		hints = append(hints, protocol.InlayHint{
			Kind:         protocol.Type,
			PaddingLeft:  true,
			PaddingRight: false,
			// Place the hint just past the closing quote of the import path.
			Position: protocol.Position{
				Line:      uint32(nameInfo.Start().Line) - 1,
				Character: uint32(nameInfo.End().Col) + 2,
			},
			TextEdits: []protocol.TextEdit{
				{
					// Replace only the text between the quotes.
					Range:   adjustColumns(toRange(nameInfo), +1, -1),
					NewText: resolvedPath,
				},
			},
			Label: []protocol.InlayHintLabelPart{
				{
					Tooltip: &protocol.OrPTooltipPLabel{
						Value: fmt.Sprintf("Import resolves to %s", resolvedPath),
					},
					Value: resolvedPath,
				},
			},
		})
	}

	return hints
}

func (c *Cache) DocumentSymbolsForFile(doc protocol.TextDocumentIdentifier) ([]protocol.DocumentSymbol, error) {
c.resultsMu.RLock()
defer c.resultsMu.RUnlock()
Expand Down Expand Up @@ -1287,6 +1363,23 @@ func (c *Cache) FindTypeDescriptorAtLocation(params protocol.TextDocumentPositio

for i := len(item.path) - 1; i >= 0; i-- {
currentNode := item.path[i]
switch currentNode.(type) {
// short-circuit for some nodes that we know don't map to descriptors -
// keywords and numbers
case *ast.KeywordNode,
*ast.SyntaxNode,
*ast.PackageNode,
*ast.EmptyDeclNode,
*ast.RuneNode,
*ast.UintLiteralNode,
*ast.PositiveUintLiteralNode,
*ast.NegativeIntLiteralNode,
*ast.FloatLiteralNode,
*ast.SpecialFloatLiteralNode,
*ast.SignedFloatLiteralNode,
*ast.StringLiteralNode, *ast.CompoundStringLiteralNode: // TODO: this could change in the future
return nil, protocol.Range{}, nil
}
nodeDescriptor := parseRes.Descriptor(currentNode)
if nodeDescriptor == nil {
// this node does not directly map to a descriptor. push it on the stack
Expand Down Expand Up @@ -1395,6 +1488,14 @@ func (c *Cache) FindTypeDescriptorAtLocation(params protocol.TextDocumentPositio
want.desc = linkRes.FindExtendeeDescriptorByName(protoreflect.FullName(wantNode.AsIdentifier()))
}
}
case *ast.StringLiteralNode:
if fd, ok := have.desc.(protoreflect.FileImport); ok {
if fd.FileDescriptor == nil {
// nothing to do
return nil, protocol.Range{}, nil
}
want.desc = fd.FileDescriptor
}
}
case protoreflect.MessageDescriptor:
switch wantNode := want.node.(type) {
Expand Down Expand Up @@ -1629,7 +1730,20 @@ func (c *Cache) FindDefinitionForTypeDescriptor(desc protoreflect.Descriptor) ([
node = containingFileResolver.MethodNode(desc.(protoutil.DescriptorProtoWrapper).AsProto().(*descriptorpb.MethodDescriptorProto)).GetName()
case protoreflect.FieldDescriptor:
if !desc.IsExtension() {
node = containingFileResolver.FieldNode(desc.(protoutil.DescriptorProtoWrapper).AsProto().(*descriptorpb.FieldDescriptorProto))
switch desc.(type) {
case protoutil.DescriptorProtoWrapper:
node = containingFileResolver.FieldNode(desc.(protoutil.DescriptorProtoWrapper).AsProto().(*descriptorpb.FieldDescriptorProto))
default:
// these can be internal filedesc.Field descriptors for e.g. builtin file options
containingFileResolver.RangeFieldReferenceNodesWithDescriptors(func(n ast.Node, fd protoreflect.FieldDescriptor) bool {
// TODO: this is a workaround, figure out why the linker wrapper types aren't being used here
if desc.FullName() == fd.FullName() {
node = n
return false
}
return true
})
}
} else {
exts := desc.ParentFile().Extensions()
for i := 0; i < exts.Len(); i++ {
Expand Down
13 changes: 13 additions & 0 deletions pkg/lsp/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ func toRange[T ranger](t T) protocol.Range {
return positionsToRange(t.Start(), t.End())
}

// adjustColumns returns a copy of r with the start column shifted by
// leftAdjust and the end column shifted by rightAdjust. Negative adjustments
// are clamped at column 0 rather than wrapping the unsigned character offset
// (the original uint32(adjust) conversion underflowed to ~4e9 when a negative
// delta exceeded the column value).
func adjustColumns(r protocol.Range, leftAdjust int, rightAdjust int) protocol.Range {
	return protocol.Range{
		Start: protocol.Position{
			Line:      r.Start.Line,
			Character: shiftColumn(r.Start.Character, leftAdjust),
		},
		End: protocol.Position{
			Line:      r.End.Line,
			Character: shiftColumn(r.End.Character, rightAdjust),
		},
	}
}

// shiftColumn offsets a zero-based column by delta, clamping at 0 so the
// result cannot wrap around the uint32 range.
func shiftColumn(col uint32, delta int) uint32 {
	v := int64(col) + int64(delta)
	if v < 0 {
		return 0
	}
	return uint32(v)
}

func positionsToRange(start, end ast.SourcePos) protocol.Range {
return protocol.Range{
Start: protocol.Position{
Expand Down
5 changes: 2 additions & 3 deletions pkg/lsp/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -320,15 +319,15 @@ func (r *Resolver) checkWellKnownImportPath(path string) (protocompile.SearchRes
return protocompile.SearchResult{}, os.ErrNotExist
}

const largeFileThreshold = 100 * 1024 // 100KB
const largeFileThreshold = 1024 * 1024 // 1MB

func (r *Resolver) checkFS(path string, whence protocompile.ImportContext) (protocompile.SearchResult, error) {
uri, ok := r.fileURIsByPath[path]
if ok {
if fh, err := r.ReadFile(context.TODO(), uri); err == nil {
content, err := fh.Content()
if len(content) > largeFileThreshold {
return protocompile.SearchResult{}, fmt.Errorf("refusing to load file %q larger than 100KB", path)
return protocompile.SearchResult{}, fmt.Errorf("refusing to load file %q larger than 1MB", path)
}
if err == nil && content != nil {
return protocompile.SearchResult{
Expand Down
35 changes: 14 additions & 21 deletions pkg/lsp/semantic.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,6 @@ func (s *semanticItems) mktokens(node ast.Node, path []ast.Node, tt tokenType, m
func (s *semanticItems) mktokens_cel(str *ast.StringLiteralNode, start, end int32, tt tokenType, mods tokenModifier) {
lineInfo := s.parseRes.AST().NodeInfo(str)
lineStart := lineInfo.Start()
// lineEnd := lineInfo.End()

// quoteStartTk := semanticItem{
// lang: tokenLanguageCel,
// line: uint32(lineStart.Line - 1),
// start: uint32(lineStart.Col - 1),
// len: 1,
// typ: semanticTypeOperator,
// mods: 0,
// }
// quoteEndTk := semanticItem{
// lang: tokenLanguageCel,
// line: uint32(lineStart.Line - 1),
// start: uint32(lineEnd.Col - 1),
// len: 1,
// typ: semanticTypeOperator,
// mods: 0,
// }

nodeTk := semanticItem{
lang: tokenLanguageCel,
Expand All @@ -238,7 +220,6 @@ func (s *semanticItems) mktokens_cel(str *ast.StringLiteralNode, start, end int3
typ: tt,
mods: mods,
}
// s.items = append(s.items, quoteStartTk, nodeTk, quoteEndTk)
s.items = append(s.items, nodeTk)
}

Expand Down Expand Up @@ -345,6 +326,7 @@ func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.
},
DoVisitStringLiteralNode: func(node *ast.StringLiteralNode) error {
if _, ok := embeddedStringLiterals[node]; ok {
s.mkcomments(node)
return nil
}
s.mktokens(node, tracker.Path(), semanticTypeString, 0)
Expand All @@ -368,7 +350,7 @@ func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.
},
DoVisitRuneNode: func(node *ast.RuneNode) error {
switch node.Rune {
case '}', '{', '.', ',', '<', '>', '(', ')', '[', ']', ';':
case '}', '{', '.', ',', '<', '>', '(', ')', '[', ']', ';', ':':
s.mkcomments(node)
default:
s.mktokens(node, tracker.Path(), semanticTypeOperator, 0)
Expand Down Expand Up @@ -454,7 +436,8 @@ func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.
}
}
if hasExpressionField && hasIdField {
for _, lit := range s.inspectCelExpr(node) {
tokens := s.inspectCelExpr(node)
for _, lit := range tokens {
embeddedStringLiterals[lit] = struct{}{}
}
}
Expand Down Expand Up @@ -491,6 +474,14 @@ func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.
return nil
},
DoVisitTerminalNode: func(node ast.TerminalNode) error {
// handle bool ident nodes here, since this is the lowest precedence visitor.
// DoVisitIdentNode matches too many things we have specific visitors for,
// and bool values aren't their own node type
if ident, ok := node.(*ast.IdentNode); ok {
if ident.Val == "true" || ident.Val == "false" {
s.mktokens(ident, append(tracker.Path(), ident), semanticTypeKeyword, 0)
}
}
s.mkcomments(node)
return nil
},
Expand All @@ -517,6 +508,8 @@ func (s *semanticItems) inspectCelExpr(messageLit *ast.MessageLiteralNode) []*as
}
celExpr = strings.Join(lines, "\n")
}
// escape backslashes that would have been un-escaped by the parser
celExpr = strings.ReplaceAll(celExpr, `\`, `\\`)
parsed, issues := celEnv.Parse(celExpr)
if issues != nil && issues.Err() != nil {
return nil
Expand Down

0 comments on commit fa105bb

Please sign in to comment.