Skip to content

Commit

Permalink
completion work in progress: relative package names in identifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
kralicky committed Dec 27, 2023
1 parent cf92e61 commit c8c732f
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 57 deletions.
8 changes: 4 additions & 4 deletions pkg/lsp/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (c *Cache) Compile(protos ...string) {
}

func (c *Cache) compileLocked(protos ...string) {
slog.Info("compiling", "protos", len(protos))
slog.Debug("compiling", "protos", len(protos))

resolved := make([]protocompile.ResolvedPath, 0, len(protos))
for _, proto := range protos {
Expand All @@ -347,7 +347,7 @@ func (c *Cache) compileLocked(protos ...string) {
}
// important to lock resultsMu here so that it can be modified in compile hooks
// c.resultsMu.Lock()
slog.Info("done compiling", "protos", len(protos))
slog.Debug("done compiling", "protos", len(protos))
for _, r := range res.Files {
path := r.Path()
found := false
Expand Down Expand Up @@ -1219,7 +1219,7 @@ func (c *Cache) FormatDocument(doc protocol.TextDocumentIdentifier, options prot
return protocol.EditsFromDiffEdits(mapper, edits)
}

func (c *Cache) FindAllDescriptorsByPrefix(ctx context.Context, prefix string, localPackage protoreflect.FullName) []protoreflect.Descriptor {
func (c *Cache) FindAllDescriptorsByPrefix(ctx context.Context, prefix string, localPackage protoreflect.FullName, filter ...func(protoreflect.Descriptor) bool) []protoreflect.Descriptor {
c.resultsMu.RLock()
defer c.resultsMu.RUnlock()
eg, ctx := errgroup.WithContext(ctx)
Expand All @@ -1231,7 +1231,7 @@ func (c *Cache) FindAllDescriptorsByPrefix(ctx context.Context, prefix string, l
if res.Package() == localPackage {
p = string(localPackage) + "." + p
}
resultsByPackage[i], err = res.(linker.Result).FindDescriptorsByPrefix(ctx, p)
resultsByPackage[i], err = res.(linker.Result).FindDescriptorsByPrefix(ctx, p, filter...)
return
})
}
Expand Down
169 changes: 116 additions & 53 deletions pkg/lsp/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto
// if we have a previous link result, use its option index
path, found := findNarrowestEnclosingScope(currentParseRes, tokenAtOffset, adjustColumn(params.Position, columnAdjust))
if found {
c, err := completeOptionNames(path, partialName, maybeCurrentLinkRes, params.Position)
comps, err := c.completeOptionNames(path, partialName, maybeCurrentLinkRes, params.Position)
if err != nil {
return nil, err
}
completions = append(completions, c...)
completions = append(completions, comps...)
}
}
return &protocol.CompletionList{
Expand Down Expand Up @@ -174,7 +174,7 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto
}
}

items, err := completeOptionNamesByScope(fieldScope, path, partialName, maybeCurrentLinkRes, params.Position)
items, err := c.completeOptionNamesByScope(fieldScope, path, partialName, maybeCurrentLinkRes, params.Position)
if err != nil {
return nil, err
}
Expand All @@ -188,21 +188,21 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto
inc := node.Name.(*ast.IncompleteIdentNode)
// complete the partial name
if inc.IncompleteVal != nil {
items, err = completeOptionNamesByScope(fieldScope, path, string(inc.IncompleteVal.AsIdentifier()), maybeCurrentLinkRes, params.Position)
items, err = c.completeOptionNamesByScope(fieldScope, path, string(inc.IncompleteVal.AsIdentifier()), maybeCurrentLinkRes, params.Position)
} else {
// the field is empty
if node.IsExtension() { // note: this only checks for the open paren
items, err = completeExtensionNamesByScope(fieldScope, string(node.Open.Rune), maybeCurrentLinkRes, params.Position)
items, err = c.completeExtensionNamesByScope(fieldScope, string(node.Open.Rune), maybeCurrentLinkRes, params.Position)
} else {
items, err = completeOptionNamesByScope(fieldScope, path, "", maybeCurrentLinkRes, params.Position)
items, err = c.completeOptionNamesByScope(fieldScope, path, "", maybeCurrentLinkRes, params.Position)
}
}
} else {
// there is a non-empty name and both parens (if it is an extension)
if node.IsExtension() {
items, err = completeExtensionNamesByScope(fieldScope, string(node.Name.AsIdentifier()), maybeCurrentLinkRes, params.Position)
items, err = c.completeExtensionNamesByScope(fieldScope, string(node.Name.AsIdentifier()), maybeCurrentLinkRes, params.Position)
} else {
items, err = completeOptionNamesByScope(fieldScope, path, string(node.Name.AsIdentifier()), maybeCurrentLinkRes, params.Position)
items, err = c.completeOptionNamesByScope(fieldScope, path, string(node.Name.AsIdentifier()), maybeCurrentLinkRes, params.Position)
}
}
if err != nil {
Expand All @@ -215,7 +215,7 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto
var completeType string

switch {
case tokenAtOffset == node.Semicolon.Token():
case tokenAtOffset == node.End():
// figure out what the previous token is
switch tokenAtOffset - 1 {
case node.FldType.End():
Expand Down Expand Up @@ -394,30 +394,40 @@ func fieldTypeDetail(fld protoreflect.FieldDescriptor) string {
}
}

var fieldDescType = reflect.TypeOf((*protoreflect.FieldDescriptor)(nil)).Elem()
var adjustIndentationMode = protocol.AdjustIndentation
var snippetMode = protocol.SnippetTextFormat
var (
fieldDescType = reflect.TypeOf((*protoreflect.FieldDescriptor)(nil)).Elem()
adjustIndentationMode = protocol.AdjustIndentation
snippetMode = protocol.SnippetTextFormat
)

func completeOptionNames(path []ast.Node, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
func (c *Cache) completeOptionNames(path []ast.Node, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
var scope completionScope
switch path[len(path)-1].(type) {
case *ast.MessageNode:
scope = messageScope
case *ast.FieldNode:
scope = fieldScope
default:
LOOP:
for i := len(path) - 1; i >= 0; i-- {
switch path[i].(type) {
case *ast.MessageNode:
scope = messageScope
break LOOP
case *ast.FieldNode:
scope = fieldScope
break LOOP
}
}
if scope == 0 {
return nil, nil
}
return completeOptionNamesByScope(scope, path, maybePartialName, linkRes, pos)
return c.completeOptionNamesByScope(scope, path, maybePartialName, linkRes, pos)
}

var msgDescriptorFields = (*descriptorpb.MessageOptions)(nil).ProtoReflect().Descriptor().Fields()
var fieldDescriptorFields = (*descriptorpb.FieldOptions)(nil).ProtoReflect().Descriptor().Fields()
var (
msgDescriptorFields = (*descriptorpb.MessageOptions)(nil).ProtoReflect().Descriptor().Fields()
fieldDescriptorFields = (*descriptorpb.FieldOptions)(nil).ProtoReflect().Descriptor().Fields()
)

type completionScope int

const (
messageScope completionScope = iota
messageScope completionScope = iota + 1
fieldScope
)

Expand All @@ -437,7 +447,7 @@ var defaultFieldCompletions = []protocol.CompletionItem{
newBuiltinScalarOptionCompletionItem(fieldDescriptorFields.ByName("unverified_lazy")),
}

func completeExtensionNamesByScope(scope completionScope, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
func (c *Cache) completeExtensionNamesByScope(scope completionScope, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
wantExtension := strings.HasPrefix(maybePartialName, "(")
if wantExtension {
maybePartialName = strings.TrimPrefix(maybePartialName, "(")
Expand Down Expand Up @@ -468,7 +478,7 @@ func completeExtensionNamesByScope(scope completionScope, maybePartialName strin
case fieldScope:
extName = "google.protobuf.FieldOptions"
}
candidates, err := linkRes.FindDescriptorsByPrefix(context.TODO(), maybePartialName, func(d protoreflect.Descriptor) bool {
candidates := c.FindAllDescriptorsByPrefix(context.TODO(), maybePartialName, linkRes.Package(), func(d protoreflect.Descriptor) bool {
if fd, ok := d.(protoreflect.ExtensionDescriptor); ok {
isExt := fd.IsExtension()
if wantExtension {
Expand All @@ -479,37 +489,59 @@ func completeExtensionNamesByScope(scope completionScope, maybePartialName strin
}
return false
})
if err != nil {
return nil, err
}

items := []protocol.CompletionItem{}
for _, candidate := range candidates {
fd := candidate.(protoreflect.FieldDescriptor)
switch fd.Kind() {
case protoreflect.MessageKind:
if fd.IsExtension() && wantExtension {
items = append(items, newExtensionFieldCompletionItem(fd, false))
items = append(items, newExtensionFieldCompletionItem(fd, linkRes.Package(), false))
} else if !fd.IsExtension() && !wantExtension {
items = append(items, newMessageFieldCompletionItem(fd))
items = append(items, newMessageFieldCompletionItem(fd, linkRes.Package()))
}
default:
if fd.Cardinality() == protoreflect.Repeated {
items = append(items, newNonMessageRepeatedOptionCompletionItem(fd, maybePartialName, pos))
items = append(items, newNonMessageRepeatedOptionCompletionItem(fd, linkRes.Package(), maybePartialName, pos))
} else if fd.IsExtension() && wantExtension {
items = append(items, newExtensionNonMessageFieldCompletionItem(fd, false))
items = append(items, newExtensionNonMessageFieldCompletionItem(fd, linkRes.Package(), false))
} else if !fd.IsExtension() && !wantExtension {
items = append(items, newNonMessageFieldCompletionItem(fd))
items = append(items, newNonMessageFieldCompletionItem(fd, linkRes.Package()))
}
}
}
return items, nil
}

func completeOptionNamesByScope(scope completionScope, path []ast.Node, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
parts := strings.Split(maybePartialName, ".")
// splits an option name into parts, respecting extensions grouped by parens
// ex: "(foo.bar).baz" -> ["(foo.bar)", "baz"]
func splitOptionName(name string) []string {
var parts []string
var currentPart []rune
var inParens bool
for _, rn := range name {
switch rn {
case '(':
inParens = true
case ')':
inParens = false
case '.':
if !inParens {
parts = append(parts, string(currentPart))
currentPart = nil
continue
}
}
currentPart = append(currentPart, rn)
}
parts = append(parts, string(currentPart))
return parts
}

func (c *Cache) completeOptionNamesByScope(scope completionScope, path []ast.Node, maybePartialName string, linkRes linker.Result, pos protocol.Position) ([]protocol.CompletionItem, error) {
parts := splitOptionName(maybePartialName)
if len(parts) == 1 && !strings.HasSuffix(maybePartialName, ")") {
return completeExtensionNamesByScope(scope, maybePartialName, linkRes, pos)
return c.completeExtensionNamesByScope(scope, maybePartialName, linkRes, pos)
} else if len(parts) > 1 {
// walk the options path
var currentContext protoreflect.MessageDescriptor
Expand Down Expand Up @@ -557,6 +589,7 @@ func completeOptionNamesByScope(scope completionScope, path []ast.Node, maybePar
items := []protocol.CompletionItem{}
if isExtension {
exts := currentContext.Extensions()
localPkg := currentContext.ParentFile().Package()
for i, l := 0, exts.Len(); i < l; i++ {
ext := exts.Get(i)
if !strings.Contains(string(ext.Name()), lastPart) {
Expand All @@ -568,17 +601,18 @@ func completeOptionNamesByScope(scope completionScope, path []ast.Node, maybePar
// if the field is actually a map, the completion should insert map syntax
items = append(items, newMapFieldCompletionItem(ext))
} else {
items = append(items, newExtensionFieldCompletionItem(ext, true))
items = append(items, newExtensionFieldCompletionItem(ext, localPkg, true))
}
default:
if strings.Contains(string(ext.Name()), lastPart) {
items = append(items, newNonMessageFieldCompletionItem(ext))
items = append(items, newNonMessageFieldCompletionItem(ext, localPkg))
}
}
}
} else {
// match field names
fields := currentContext.Fields()
localPkg := currentContext.ParentFile().Package()
for i, l := 0, fields.Len(); i < l; i++ {
fld := fields.Get(i)
if !strings.Contains(string(fld.Name()), lastPart) {
Expand All @@ -590,16 +624,16 @@ func completeOptionNamesByScope(scope completionScope, path []ast.Node, maybePar
// if the field is actually a map, the completion should insert map syntax
items = append(items, newMapFieldCompletionItem(fld))
} else if fld.IsExtension() {
items = append(items, newExtensionFieldCompletionItem(fld, false))
items = append(items, newExtensionFieldCompletionItem(fld, localPkg, false))
} else {
items = append(items, newMessageFieldCompletionItem(fld))
items = append(items, newMessageFieldCompletionItem(fld, localPkg))
}
default:
if strings.Contains(string(fld.Name()), lastPart) {
if fld.Cardinality() == protoreflect.Repeated {
items = append(items, newNonMessageRepeatedOptionCompletionItem(fld, maybePartialName, pos))
items = append(items, newNonMessageRepeatedOptionCompletionItem(fld, localPkg, maybePartialName, pos))
} else {
items = append(items, newNonMessageFieldCompletionItem(fld))
items = append(items, newNonMessageFieldCompletionItem(fld, localPkg))
}
}
}
Expand All @@ -620,28 +654,30 @@ func newMapFieldCompletionItem(fld protoreflect.FieldDescriptor) protocol.Comple
}
}

func newMessageFieldCompletionItem(fld protoreflect.FieldDescriptor) protocol.CompletionItem {
func newMessageFieldCompletionItem(fld protoreflect.FieldDescriptor, localPkg protoreflect.FullName) protocol.CompletionItem {
name := relativeFullName(fld.FullName(), localPkg)
return protocol.CompletionItem{
Label: string(fld.Name()),
Label: string(fld.FullName()),
Kind: protocol.StructCompletion,
Detail: fieldTypeDetail(fld),
InsertText: string(fld.Name()),
InsertText: string(name),
CommitCharacters: []string{"."},
}
}

func newExtensionFieldCompletionItem(fld protoreflect.FieldDescriptor, needsLeadingOpenParen bool) protocol.CompletionItem {
func newExtensionFieldCompletionItem(fld protoreflect.FieldDescriptor, localPkg protoreflect.FullName, needsLeadingOpenParen bool) protocol.CompletionItem {
var fmtStr string
if needsLeadingOpenParen {
fmtStr = "(%s)"
fmtStr = "(%s"
} else {
fmtStr = "%s)"
fmtStr = "%s"
}
name := relativeFullName(fld.FullName(), localPkg)
return protocol.CompletionItem{
Label: string(fld.Name()),
Label: string(name),
Kind: protocol.InterfaceCompletion,
Detail: fieldTypeDetail(fld),
InsertText: fmt.Sprintf(fmtStr, fld.Name()),
InsertText: fmt.Sprintf(fmtStr, name),
CommitCharacters: []string{"."},
}
}
Expand Down Expand Up @@ -674,7 +710,7 @@ func newBuiltinScalarOptionCompletionItem(fld protoreflect.FieldDescriptor) prot
}
}

func newNonMessageFieldCompletionItem(fld protoreflect.FieldDescriptor) protocol.CompletionItem {
func newNonMessageFieldCompletionItem(fld protoreflect.FieldDescriptor, localPkg protoreflect.FullName) protocol.CompletionItem {
return protocol.CompletionItem{
Label: string(fld.Name()),
Kind: protocol.ValueCompletion,
Expand All @@ -684,7 +720,7 @@ func newNonMessageFieldCompletionItem(fld protoreflect.FieldDescriptor) protocol
}
}

func newNonMessageRepeatedOptionCompletionItem(fld protoreflect.FieldDescriptor, partialName string, pos protocol.Position) protocol.CompletionItem {
func newNonMessageRepeatedOptionCompletionItem(fld protoreflect.FieldDescriptor, localPkg protoreflect.FullName, partialName string, pos protocol.Position) protocol.CompletionItem {
// If we're completing from a top-level option, e.g. 'option (foo).bar'
// where bar is a repeated field, we need to rewrite the expression to
// 'option (foo) = {bar: []}'. The syntax 'option (foo).bar = []' is
Expand Down Expand Up @@ -718,7 +754,7 @@ func newNonMessageRepeatedOptionCompletionItem(fld protoreflect.FieldDescriptor,
}
}

func newExtensionNonMessageFieldCompletionItem(fld protoreflect.FieldDescriptor, needsLeadingOpenParen bool) protocol.CompletionItem {
func newExtensionNonMessageFieldCompletionItem(fld protoreflect.FieldDescriptor, localPkg protoreflect.FullName, needsLeadingOpenParen bool) protocol.CompletionItem {
var fmtStr string
if needsLeadingOpenParen {
fmtStr = "(%s) = ${0};"
Expand Down Expand Up @@ -790,3 +826,30 @@ func completeTypeNames(cache *Cache, partialName string, linkRes linker.Result)
}
return items
}

// Returns the least qualified name that would be required to refer to 'target'
// relative to the package 'fromPkg'.
// ex:
//
// relativeFullName("foo.bar.baz.A", "foo.bar") => "baz.A"
// relativeFullName("foo.bar.baz.A", "foo") => "bar.baz.A"
// relativeFullName("foo.bar.baz.A", "foo.bar.baz") => "A"
// relativeFullName("foo.bar.baz.A", "x.y") => "foo.bar.baz.A"
func relativeFullName(target, fromPkg protoreflect.FullName) protoreflect.FullName {
targetPkg := target.Parent()
// walk targetPkg up until it matches fromPkg, or if empty, it must be fully qualified
stack := []protoreflect.Name{}
for {
if targetPkg == fromPkg {
for i := len(stack) - 1; i >= 0; i-- {
targetPkg = targetPkg.Append(stack[i])
}
return targetPkg.Append(target.Name())
}
if targetPkg == "" {
return target
}
stack = append(stack, targetPkg.Name())
targetPkg = targetPkg.Parent()
}
}

0 comments on commit c8c732f

Please sign in to comment.