diff --git a/editors/vscode/client/src/extension.ts b/editors/vscode/client/src/extension.ts
index 795e329..6301329 100644
--- a/editors/vscode/client/src/extension.ts
+++ b/editors/vscode/client/src/extension.ts
@@ -73,6 +73,10 @@ export function buildLanguageClient(
       synchronize: {fileEvents: workspace.createFileSystemWatcher('**/*.proto')},
       revealOutputChannelOn: RevealOutputChannelOn.Never,
       outputChannel: vscode.window.createOutputChannel('Protobuf Language Server'),
+      markdown: {
+        isTrusted: true,
+        supportHtml: true,
+      }
     } as LanguageClientOptions,
   );
   return c;
diff --git a/pkg/lsp/cache.go b/pkg/lsp/cache.go
index 18f2ec9..3248330 100644
--- a/pkg/lsp/cache.go
+++ b/pkg/lsp/cache.go
@@ -1075,53 +1075,468 @@ func (c *Cache) GetSyntheticFileContents(ctx context.Context, uri string) (strin
 	}
 }
 
+type list[T protoreflect.Descriptor] interface {
+	Len() int
+	Get(i int) T
+}
+
+func findByName[T protoreflect.Descriptor](l list[T], name string) (entry T) {
+	isFullName := strings.Contains(name, ".")
+	for i := 0; i < l.Len(); i++ {
+		if isFullName {
+			if string(l.Get(i).FullName()) == name {
+				entry = l.Get(i)
+				break
+			}
+		} else {
+			if string(l.Get(i).Name()) == name {
+				entry = l.Get(i)
+				break
+			}
+		}
+	}
+	return
+}
+
+type stackEntry struct {
+	node ast.Node
+	desc protoreflect.Descriptor
+	prev *stackEntry
+}
+
+func (s *stackEntry) isResolved() bool {
+	return s.desc != nil
+}
+
+func (s *stackEntry) nextResolved() *stackEntry {
+	res := s
+	for {
+		if res == nil {
+			panic("bug: stackEntry.nextResolved() called with no resolved entry")
+		}
+		if res.isResolved() {
+			return res
+		}
+		res = res.prev
+	}
+}
+
+type stack []*stackEntry
+
+func (s *stack) push(node ast.Node, desc protoreflect.Descriptor) {
+	e := &stackEntry{
+		node: node,
+		desc: desc,
+	}
+	if len(*s) > 0 {
+		(*s)[len(*s)-1].prev = e
+	}
+	*s = append(*s, e)
+}
+
 func (c *Cache) FindTypeDescriptorAtLocation(params protocol.TextDocumentPositionParams) (protoreflect.Descriptor, protocol.Range, error) {
-	enc, err := computeSemanticTokens(c, params.TextDocument, &protocol.Range{
-		Start: params.Position,
-		End:   params.Position,
-	})
+	parseRes, err := c.FindParseResultByURI(params.TextDocument.URI.SpanURI())
+	if err != nil {
+		return nil, protocol.Range{}, err
+	}
+	linkRes, err := c.FindResultByURI(params.TextDocument.URI.SpanURI())
 	if err != nil {
 		return nil, protocol.Range{}, err
 	}
-	item, found := findNarrowestSemanticToken(enc.items, params.Position)
-	if !found {
-		return nil, protocol.Range{}, nil
+	mapper, err := c.GetMapper(params.TextDocument.URI.SpanURI())
+	if err != nil {
+		return nil, protocol.Range{}, err
 	}
-	parseRes, err := c.FindParseResultByURI(params.TextDocument.URI.SpanURI())
+
+	enc := semanticItems{
+		parseRes: parseRes,
+		linkRes:  linkRes,
+	}
+	offset, err := mapper.PositionOffset(params.Position)
 	if err != nil {
 		return nil, protocol.Range{}, err
 	}
+	root := parseRes.AST()
+
+	token := root.TokenAtOffset(offset)
+	computeSemanticTokens(c, &enc, ast.WithIntersection(token))
 
-	switch node := item.node.(type) {
-	case ast.IdentValueNode:
-		rng := toRange(parseRes.AST().NodeInfo(node))
-		name := string(node.AsIdentifier())
-		var desc protoreflect.Descriptor
-		var unqualifiedName protoreflect.Name
-		var pkg protoreflect.FullName
-		if strings.Contains(name, ".") {
-			// treat as fully qualified name
-			fn := protoreflect.FullName(name)
-			unqualifiedName = fn.Name()
-			pkg = fn.Parent()
+	item, found := findNarrowestSemanticToken(parseRes, enc.items, params.Position)
+	if !found {
+		return nil, protocol.Range{}, nil
+	}
+
+	// traverse the path backwards to find the closest top-level mapped descriptor,
+	// then traverse forwards to find the deeply nested descriptor for the original
+	// ast node
+	stack := stack{}
+	// var haveDescriptor protoreflect.Descriptor
+
+	for i := len(item.path) - 1; i >= 0; i-- {
+		currentNode := item.path[i]
+		nodeDescriptor := parseRes.Descriptor(currentNode)
+		if nodeDescriptor == nil {
+			// this node does not directly map to a descriptor. push it on the stack
+			// and go up one level
+			stack.push(currentNode, nil)
 		} else {
-			unqualifiedName = protoreflect.Name(name)
-			pkg = protoreflect.FullName(parseRes.FileDescriptorProto().GetPackage())
+			// this node does directly map to a descriptor.
+			var desc protoreflect.Descriptor
+			switch nodeDescriptor := nodeDescriptor.(type) {
+			case *descriptorpb.FileDescriptorProto:
+				desc = linkRes.ParentFile()
+			case *descriptorpb.DescriptorProto:
+				var typeName string
+				// check if it's a synthetic map field
+				if nodeDescriptor.GetOptions().GetMapEntry() {
+					// if it is, we're looking for the value message
+					typeName = strings.TrimPrefix(nodeDescriptor.Field[1].GetTypeName(), ".")
+				} else {
+					typeName = nodeDescriptor.GetName()
+				}
+				desc = findByName[protoreflect.MessageDescriptor](linkRes.Messages(), typeName)
+			case *descriptorpb.EnumDescriptorProto:
+				desc = findByName[protoreflect.EnumDescriptor](linkRes.Enums(), nodeDescriptor.GetName())
+			case *descriptorpb.ServiceDescriptorProto:
+				desc = linkRes.Services().ByName(protoreflect.Name(nodeDescriptor.GetName()))
+			case *descriptorpb.UninterpretedOption_NamePart:
+				desc = linkRes.FindOptionNameFieldDescriptor(nodeDescriptor)
+			case *descriptorpb.UninterpretedOption:
+				desc = linkRes.FindOptionMessageDescriptor(nodeDescriptor)
+			// case *descriptorpb.FieldDescriptorProto:
+
+			default:
+				// not a top-level descriptor. push it on the stack and go up one level
+				stack.push(currentNode, nil)
+				continue
+			}
+			if desc == nil {
+				return nil, protocol.Range{}, fmt.Errorf("could not find descriptor for %T", nodeDescriptor)
+			}
+			stack.push(currentNode, desc)
+			break
 		}
-		if desc == nil {
-			descs := c.FindAllDescriptorsByPrefix(context.TODO(), string(unqualifiedName), pkg)
-			if len(descs) > 0 {
-				desc = descs[0]
+	}
+
+	// fmt.Printf("descriptor: [%T] %v\n", haveDescriptor, haveDescriptor.FullName())
+
+	// fast path: the node is directly mapped to a resolved top-level descriptor
+	if len(stack) == 1 && stack[0].desc != nil {
+		return stack[0].desc, toRange(root.NodeInfo(stack[0].node)), nil
+	}
+
+	for i := len(stack) - 1; i >= 0; i-- {
+		want := stack[i]
+		if want.isResolved() {
+			continue
+		}
+		have := want.nextResolved()
+		switch haveDesc := have.desc.(type) {
+		case protoreflect.FileDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.FileElement:
+				switch wantNode := wantNode.(type) {
+				case *ast.OptionNode:
+					want.desc = haveDesc.Options().(*descriptorpb.FileOptions).ProtoReflect().Descriptor()
+				case *ast.ImportNode:
+					want.desc = findByName[protoreflect.FileImport](haveDesc.Imports(), wantNode.Name.AsString())
+				case *ast.MessageNode:
+					want.desc = findByName[protoreflect.MessageDescriptor](haveDesc.Messages(), string(wantNode.Name.AsIdentifier()))
+				case *ast.EnumNode:
+					want.desc = findByName[protoreflect.EnumDescriptor](haveDesc.Enums(), string(wantNode.Name.AsIdentifier()))
+				case *ast.ExtendNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Extensions(), string(wantNode.Extendee.AsIdentifier()))
+				case *ast.ServiceNode:
+					want.desc = findByName[protoreflect.ServiceDescriptor](haveDesc.Services(), string(wantNode.Name.AsIdentifier()))
+				}
+			}
+		case protoreflect.MessageDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.MessageElement:
+				switch wantNode := wantNode.(type) {
+				case *ast.OptionNode:
+					want.desc = haveDesc.Options().(*descriptorpb.MessageOptions).ProtoReflect().Descriptor()
+				case *ast.FieldNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.MapFieldNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.OneofNode:
+					want.desc = findByName[protoreflect.OneofDescriptor](haveDesc.Oneofs(), string(wantNode.Name.AsIdentifier()))
+				case *ast.GroupNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.MessageNode:
+					want.desc = findByName[protoreflect.MessageDescriptor](haveDesc.Messages(), string(wantNode.Name.AsIdentifier()))
+				case *ast.EnumNode:
+					want.desc = findByName[protoreflect.EnumDescriptor](haveDesc.Enums(), string(wantNode.Name.AsIdentifier()))
+				case *ast.ExtendNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Extensions(), string(wantNode.Extendee.AsIdentifier()))
+				case *ast.ExtensionRangeNode:
+				case *ast.ReservedNode:
+				}
+			case *ast.MapTypeNode:
+				want.desc = haveDesc
+			case *ast.FieldReferenceNode:
+				if wantNode.IsAnyTypeReference() {
+					want.desc = linkRes.FindMessageDescriptorByTypeReferenceURLNode(wantNode)
+				} else {
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Fields(), string(wantNode.Name.AsIdentifier()))
+				}
+			case *ast.MessageLiteralNode:
+				want.desc = haveDesc
+			case *ast.MessageFieldNode:
+				name := wantNode.Name
+				if name.IsAnyTypeReference() {
+					want.desc = linkRes.FindMessageDescriptorByTypeReferenceURLNode(name)
+				} else {
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Fields(), string(wantNode.Name.Value()))
+				}
+			case ast.IdentValueNode:
+				want.desc = haveDesc
+			}
+		case protoreflect.ExtensionTypeDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.IdentValueNode:
+				id := wantNode.AsIdentifier()
+				exts := haveDesc.ParentFile().Extensions()
+				for i := 0; i < exts.Len(); i++ {
+					ext := exts.Get(i)
+					if ext.FullName() == protoreflect.FullName(id) {
+						want.desc = ext
+						break
+					}
+				}
+			}
+		case protoreflect.FieldDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.FieldDeclNode:
+				switch wantNode := wantNode.(type) {
+				case *ast.FieldNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.GroupNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.MapFieldNode:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Name.AsIdentifier()))
+				case *ast.SyntheticMapField:
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Ident.AsIdentifier()))
+				}
+			case *ast.FieldReferenceNode:
+				want.desc = haveDesc
+				// if wantNode.IsAnyTypeReference() {
+				// 	want.desc = linkRes.FindMessageDescriptorByTypeReferenceURLNode(wantNode)
+				// } else {
+				// 	want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Name.AsIdentifier()))
+				// }
+			case *ast.MessageLiteralNode:
+				want.desc = haveDesc
+			case *ast.MessageFieldNode:
+				name := wantNode.Name
+				if name.IsAnyTypeReference() {
+					want.desc = linkRes.FindMessageDescriptorByTypeReferenceURLNode(name)
+				} else {
+					want.desc = findByName[protoreflect.FieldDescriptor](haveDesc.Message().Fields(), string(wantNode.Name.Value()))
+				}
+			case *ast.CompactOptionsNode:
+				want.desc = haveDesc.Options().(*descriptorpb.FieldOptions).ProtoReflect().Descriptor()
+			case ast.IdentValueNode:
+				// need to disambiguate
+				switch haveNode := have.node.(type) {
+				case *ast.FieldReferenceNode:
+					want.desc = haveDesc
+				case ast.FieldDeclNode:
+					switch want.node {
+					case haveNode.FieldType():
+						switch {
+						case haveDesc.IsExtension():
+							// keep the field descriptor
+						case haveDesc.IsMap():
+							want.desc = haveDesc.MapValue()
+						case haveDesc.Kind() == protoreflect.MessageKind:
+							want.desc = haveDesc.Message()
+						case haveDesc.Kind() == protoreflect.EnumKind:
+							want.desc = haveDesc.Enum()
+						}
+					case haveNode.FieldName():
+						// keep the field descriptor
+					}
+				}
+			}
+		case protoreflect.EnumDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.EnumElement:
+				switch wantNode := wantNode.(type) {
+				case *ast.OptionNode:
+					want.desc = haveDesc.Options().(*descriptorpb.EnumOptions).ProtoReflect().Descriptor()
+				case *ast.EnumValueNode:
+					want.desc = findByName[protoreflect.EnumValueDescriptor](haveDesc.Values(), string(wantNode.Name.AsIdentifier()))
+				case *ast.ReservedNode:
+				}
+			// default:
+			// 	if enumNode, ok := parentNode.(*ast.EnumValueNode); ok {
+			// 		if want == enumNode.Name {
+			// 			want.resolved = have.Values().ByName(protoreflect.Name(enumNode.Name.AsIdentifier()))
+			// 		}
+			// 	}
+			case ast.IdentValueNode:
+				want.desc = haveDesc
+			}
+		case protoreflect.EnumValueDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.EnumValueDeclNode:
+				switch wantNode.(type) {
+				case *ast.EnumValueNode:
+					want.desc = haveDesc // ??
+				case ast.NoSourceNode:
+				}
+			case *ast.CompactOptionsNode:
+				want.desc = haveDesc.Options().(*descriptorpb.EnumValueOptions).ProtoReflect().Descriptor()
+			case ast.IdentValueNode:
+				want.desc = haveDesc
 			}
+		case protoreflect.ServiceDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.ServiceElement:
+				switch wantNode := wantNode.(type) {
+				case *ast.OptionNode:
+					want.desc = haveDesc.Options().(*descriptorpb.ServiceOptions).ProtoReflect().Descriptor()
+				case *ast.RPCNode:
+					want.desc = findByName[protoreflect.MethodDescriptor](haveDesc.Methods(), string(wantNode.Name.AsIdentifier()))
+				}
+			case ast.IdentValueNode:
+				want.desc = haveDesc
+			}
+		case protoreflect.MethodDescriptor:
+			switch wantNode := want.node.(type) {
+			case ast.RPCElement:
+				switch wantNode.(type) {
+				case *ast.OptionNode:
+					want.desc = haveDesc.Options().(*descriptorpb.MethodOptions).ProtoReflect().Descriptor()
+				default:
+				}
+			case *ast.RPCTypeNode:
+				if haveNode, ok := have.node.(*ast.RPCNode); ok {
+					switch want.node {
+					case haveNode.Input:
+						want.desc = haveDesc.Input()
+					case haveNode.Output:
+						want.desc = haveDesc.Output()
+					}
+				}
+			case *ast.CompactOptionsNode:
+				want.desc = haveDesc.Options().(*descriptorpb.MethodOptions).ProtoReflect().Descriptor()
+			case ast.IdentValueNode:
+				want.desc = haveDesc
+			}
+		default:
+			return nil, protocol.Range{}, fmt.Errorf("unknown descriptor type %T", want.desc)
 		}
-		if desc == nil {
-			return nil, protocol.Range{}, nil
+		if want.desc == nil {
+			return nil, protocol.Range{}, fmt.Errorf("failed to find descriptor for %T/%T", want.desc, want.node)
 		}
-		return desc, rng, nil
-	default:
-		return nil, protocol.Range{}, nil
+	}
+
+	return stack[0].desc, toRange(parseRes.AST().NodeInfo(stack[0].node)), nil
+
+	// rng := toRange(parseRes.AST().NodeInfo(item.node))
+	// switch item.typ {
+	// case semanticTypeType, semanticTypeClass, semanticTypeEnum, semanticTypeEnumMember,
+	// 	semanticTypeInterface, semanticTypeStruct, semanticTypeTypeParameter:
+	// 	switch node := item.node.(type) {
+	// 	case ast.IdentValueNode:
+	// 		name := string(node.AsIdentifier())
+	// 		var desc protoreflect.Descriptor
+	// 		var unqualifiedName protoreflect.Name
+	// 		var pkg protoreflect.FullName
+	// 		if strings.Contains(name, ".") {
+	// 			// treat as fully qualified name
+	// 			fn := protoreflect.FullName(name)
+	// 			unqualifiedName = fn.Name()
+	// 			pkg = fn.Parent()
+	// 		} else {
+	// 			unqualifiedName = protoreflect.Name(name)
+	// 			pkg = protoreflect.FullName(parseRes.FileDescriptorProto().GetPackage())
+	// 		}
+	// 		if desc == nil {
+	// 			descs := c.FindAllDescriptorsByPrefix(context.TODO(), string(unqualifiedName), pkg)
+	// 			for _, d := range descs {
+	// 				// only show if the name matches exactly
+	// 				if strings.HasSuffix(string(d.FullName()), name) {
+	// 					desc = d
+	// 					break
+	// 				}
+	// 			}
+	// 		}
+	// 		if desc == nil {
+	// 			return nil, protocol.Range{}, nil
+	// 		}
+	// 		switch desc := desc.(type) {
+	// 		case protoreflect.FieldDescriptor:
+	// 			// For field descriptors that are part of options, show the underlying message type
+	// 			if desc.IsExtension() && desc.Kind() == protoreflect.MessageKind {
+	// 				m := desc.Message()
+	// 				if m != nil {
+	// 					return m, rng, nil
+	// 				}
+	// 			}
+	// 		}
+	// 		return desc, rng, nil
+	// 	default:
+	// 		return nil, protocol.Range{}, nil
+	// 	}
+
+	// case semanticTypeProperty:
+	// 	d := parseRes.Descriptor(item.node)
+	// 	if d == nil {
+	// 		return nil, protocol.Range{}, nil
+	// 	}
+	// 	if item.path == nil {
+	// 		return nil, protocol.Range{}, nil
+	// 	}
+
+	// 	switch d := d.(type) {
+	// 	case *descriptorpb.UninterpretedOption:
+	// 		optionNode, ok := parseRes.OptionNode(d).(*ast.OptionNode)
+	// 		if !ok {
+	// 			return nil, protocol.Range{}, nil
+	// 		}
+	// 		// we're hovering over an option. to continue, we need linker results for this file
+	// 		linkRes, err := c.FindResultByURI(params.TextDocument.URI.SpanURI())
+	// 		if err != nil {
+	// 			return nil, protocol.Range{}, err
+	// 		}
+	// 		for i := len(item.path) - 2; i >= 0; i-- {
+	// 			parent := item.path[i]
+	// 			if fieldDescNode, ok := parent.(ast.FieldDeclNode); ok {
+	// 				fieldDescriptor := linkRes.FieldDescriptor(fieldDescNode)
+	// 				if fieldDescriptor != nil {
+	// 					opts := fieldDescriptor.GetOptions()
+	// 					opts = opts
+	// 				}
+	// 			}
+	// 		}
+	// 		srcInfo := linkRes.FindOptionSourceInfo(optionNode)
+	// 		if srcInfo == nil {
+	// 			return nil, protocol.Range{}, nil
+	// 		}
+	// 		// the field we are looking for is one of the child elements of this option
+	// 		switch info := srcInfo.Children.(type) {
+	// 		case *sourceinfo.ArrayLiteralSourceInfo:
+	// 		case *sourceinfo.MessageLiteralSourceInfo:
+	// 			for fieldNode, srcInfo := range info.Fields {
+	// 				if fieldNode.Name.Name == item.node {
+	// 					srcInfo = srcInfo
+	// 				}
+	// 			}
+	// 		}
+	// 	}
+
+	// 	return nil, protocol.Range{}, nil
+	// default:
+	// 	return nil, protocol.Range{}, nil
+	// }
 }
 
 func (c *Cache) FindDefinitionForTypeDescriptor(desc protoreflect.Descriptor) ([]protocol.Location, error) {
@@ -1144,7 +1559,17 @@ func (c *Cache) FindDefinitionForTypeDescriptor(desc protoreflect.Descriptor) ([
 	case protoreflect.MethodDescriptor:
 		node = containingFileResolver.MethodNode(protoutil.ProtoFromMethodDescriptor(desc))
 	case protoreflect.FieldDescriptor:
-		node = containingFileResolver.FieldNode(protoutil.ProtoFromFieldDescriptor(desc))
+		if !desc.IsExtension() {
+			node = containingFileResolver.FieldNode(protoutil.ProtoFromFieldDescriptor(desc))
+		} else {
+			exts := desc.ParentFile().Extensions()
+			for i := 0; i < exts.Len(); i++ {
+				ext := exts.Get(i)
+				if ext.FullName() == desc.FullName() {
+					node = containingFileResolver.FieldNode(protoutil.ProtoFromFieldDescriptor(ext))
+				}
+			}
+		}
 	case protoreflect.EnumValueDescriptor:
 		node = containingFileResolver.EnumValueNode(protoutil.ProtoFromEnumValueDescriptor(desc))
 	case protoreflect.OneofDescriptor:
@@ -1155,6 +1580,9 @@ func (c *Cache) FindDefinitionForTypeDescriptor(desc protoreflect.Descriptor) ([
 	default:
 		return nil, fmt.Errorf("unexpected descriptor type %T", desc)
 	}
+	if node == nil {
+		return nil, fmt.Errorf("failed to find node for %q", desc.FullName())
+	}
 	info := containingFileResolver.AST().NodeInfo(node)
 	uri, err := c.resolver.PathToURI(containingFileResolver.Path())
diff --git a/pkg/lsp/semantic.go b/pkg/lsp/semantic.go
index 5ee4807..9bde21c 100644
--- a/pkg/lsp/semantic.go
+++ b/pkg/lsp/semantic.go
@@ -7,6 +7,7 @@ import (
 	"github.com/bufbuild/protocompile/ast"
 	"github.com/bufbuild/protocompile/linker"
 	"github.com/bufbuild/protocompile/parser"
+	"golang.org/x/exp/slices"
 	"golang.org/x/tools/gopls/pkg/lsp/protocol"
 )
 
@@ -52,106 +53,77 @@ const (
 	semanticModifierDefaultLibrary
 )
 
-type semItem struct {
-	line, start uint32
+type semanticItem struct {
+	line, start uint32 // 0-indexed
 	len         uint32
 	typ         tokenType
 	mods        tokenModifier
 
 	// An AST node associated with this token. Used for hover, definitions, etc.
 	node ast.Node
+
+	path []ast.Node
 }
 
-type encoded struct {
+type semanticItems struct {
 	// the generated data
-	items []semItem
+	items []semanticItem
 
-	parseRes parser.Result
-	res      linker.Result
-	mapper   *protocol.Mapper
-	rng      *protocol.Range
-	start, end ast.Token
+	parseRes parser.Result // cannot be nil
+	linkRes  linker.Result // can be nil if there are no linker results available
 }
 
 func semanticTokensFull(cache *Cache, doc protocol.TextDocumentIdentifier) (*protocol.SemanticTokens, error) {
-	enc, err := computeSemanticTokens(cache, doc, nil)
+	parseRes, err := cache.FindParseResultByURI(doc.URI.SpanURI())
 	if err != nil {
 		return nil, err
 	}
-	ret := &protocol.SemanticTokens{
-		Data: enc.Data(),
-	}
-	return ret, err
-}
+	maybeLinkRes, _ := cache.FindResultByURI(doc.URI.SpanURI())
 
-func semanticTokensRange(cache *Cache, doc protocol.TextDocumentIdentifier, rng protocol.Range) (*protocol.SemanticTokens, error) {
-	enc, err := computeSemanticTokens(cache, doc, &rng)
-	if err != nil {
-		return nil, err
+	enc := semanticItems{
+		parseRes: parseRes,
+		linkRes:  maybeLinkRes,
 	}
+	computeSemanticTokens(cache, &enc)
+
 	ret := &protocol.SemanticTokens{
 		Data: enc.Data(),
 	}
 	return ret, err
 }
 
-func computeSemanticTokens(cache *Cache, td protocol.TextDocumentIdentifier, rng *protocol.Range) (*encoded, error) {
-	parseRes, err := cache.FindParseResultByURI(td.URI.SpanURI())
+func semanticTokensRange(cache *Cache, doc protocol.TextDocumentIdentifier, rng protocol.Range) (*protocol.SemanticTokens, error) {
+	parseRes, err := cache.FindParseResultByURI(doc.URI.SpanURI())
 	if err != nil {
 		return nil, err
 	}
+	maybeLinkRes, _ := cache.FindResultByURI(doc.URI.SpanURI())
 
-	mapper, err := cache.GetMapper(td.URI.SpanURI())
+	mapper, err := cache.GetMapper(doc.URI.SpanURI())
 	if err != nil {
 		return nil, err
 	}
 	a := parseRes.AST()
-	var startToken, endToken ast.Token
-	if rng == nil {
-		startToken = a.Start()
-		endToken = a.End()
-	} else {
-		startOff, endOff, _ := mapper.RangeOffsets(*rng)
-		startToken = a.TokenAtOffset(startOff)
-		endToken = a.TokenAtOffset(endOff)
-	}
+	startOff, endOff, _ := mapper.RangeOffsets(rng)
+	startToken := a.TokenAtOffset(startOff)
+	endToken := a.TokenAtOffset(endOff)
 
-	e := &encoded{
-		rng:      rng,
+	enc := semanticItems{
 		parseRes: parseRes,
-		mapper:   mapper,
-		start:    startToken,
-		end:      endToken,
-	}
-
-	if res, err := cache.FindResultByURI(td.URI.SpanURI()); err == nil {
-		e.res = res
+		linkRes:  maybeLinkRes,
 	}
-	if a.Syntax != nil {
-		start, end := a.Syntax.Start(), a.Syntax.End()
-		if end >= e.start && start <= e.end {
-			e.mkcomments(a.Syntax)
-			e.mktokens(a.Syntax.Keyword, semanticTypeKeyword, 0)
-			e.mktokens(a.Syntax.Equals, semanticTypeOperator, 0)
-			e.mktokens(a.Syntax.Syntax, semanticTypeString, 0)
-		}
-	}
-	for _, node := range a.Decls {
-		// only look at the decls that overlap the range
-		start, end := node.Start(), node.End()
-		if end < e.start || start > e.end {
-			continue
-		}
-		e.inspect(cache, node)
-	}
-	if endToken == a.End() {
-		e.mkcomments(a.EOF)
+	computeSemanticTokens(cache, &enc, ast.WithRange(startToken, endToken))
+	ret := &protocol.SemanticTokens{
+		Data: enc.Data(),
 	}
+	return ret, err
+}
 
-	return e, nil
+func computeSemanticTokens(cache *Cache, e *semanticItems, walkOptions ...ast.WalkOption) {
+	e.inspect(cache, e.parseRes.AST(), walkOptions...)
 }
 
-func findNarrowestSemanticToken(tokens []semItem, pos protocol.Position) (narrowest semItem, found bool) {
+func findNarrowestSemanticToken(parseRes parser.Result, tokens []semanticItem, pos protocol.Position) (narrowest semanticItem, found bool) {
 	// find the narrowest token that contains the position and also has a node
 	// associated with it. The set of tokens will contain all the tokens that
 	// contain the position, scoped to the narrowest top-level declaration (message, service, etc.)
@@ -159,22 +131,30 @@ func findNarrowestSemanticToken(tokens []semItem, pos protocol.Position) (narrow
 	for _, token := range tokens {
 		if pos.Line != token.line {
+			if token.line > pos.Line {
+				// Stop searching once we've passed the line
+				break
+			}
 			continue // Skip tokens not on the same line
 		}
 		if pos.Character < token.start || pos.Character > token.start+token.len {
 			continue // Skip tokens that don't contain the position
 		}
 		if token.len < narrowestLen {
-			// Found a narrower token, update narrowest and narrowestLen
+			// Found a narrower token
 			narrowest, narrowestLen = token, token.len
 			found = true
+
+			if _, isTerminal := token.node.(ast.TerminalNode); isTerminal {
+				break
+			}
 		}
 	}
 	return
 }
 
-func (s *encoded) mktokens(node ast.Node, tt tokenType, mods tokenModifier) {
+func (s *semanticItems) mktokens(node ast.Node, path []ast.Node, tt tokenType, mods tokenModifier) {
 	info := s.parseRes.AST().NodeInfo(node)
 	if !info.IsValid() {
 		return
@@ -182,25 +162,26 @@ func (s *encoded) mktokens(node ast.Node, tt tokenType, mods tokenModifier) {
 	length := (info.End().Col - 1) - (info.Start().Col - 1)
 
-	nodeTk := semItem{
+	nodeTk := semanticItem{
 		line:  uint32(info.Start().Line - 1),
 		start: uint32(info.Start().Col - 1),
 		len:   uint32(length),
 		typ:   tt,
 		mods:  mods,
 		node:  node,
+		path:  slices.Clone(path),
 	}
 	s.items = append(s.items, nodeTk)
 	s.mkcomments(node)
 }
 
-func (s *encoded) mkcomments(node ast.Node) {
+func (s *semanticItems) mkcomments(node ast.Node) {
 	info := s.parseRes.AST().NodeInfo(node)
 	leadingComments := info.LeadingComments()
 	for i := 0; i < leadingComments.Len(); i++ {
 		comment := leadingComments.Index(i)
-		commentTk := semItem{
+		commentTk := semanticItem{
 			line:  uint32(comment.Start().Line - 1),
 			start: uint32(comment.Start().Col - 1),
 			len:   uint32((comment.End().Col) - (comment.Start().Col - 1)),
@@ -212,7 +193,7 @@ func (s *encoded) mkcomments(node ast.Node) {
 	trailingComments := info.TrailingComments()
 	for i := 0; i < trailingComments.Len(); i++ {
 		comment := trailingComments.Index(i)
-		commentTk := semItem{
+		commentTk := semanticItem{
 			line:  uint32(comment.Start().Line - 1),
 			start: uint32(comment.Start().Col - 1),
 			len:   uint32((comment.End().Col) - (comment.Start().Col - 1)),
@@ -222,26 +203,42 @@ func (s *encoded) mkcomments(node ast.Node) {
 	}
 }
 
-func (s *encoded) inspect(cache *Cache, node ast.Node) {
+func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.WalkOption) {
+	tracker := &ast.AncestorTracker{}
+	walkOptions = append(walkOptions, tracker.AsWalkOptions()...)
+	// NB: when calling mktokens in composite node visitors:
+	// - ensure node paths are manually adjusted if creating tokens for a child node
+	// - ensure tokens for child nodes are created in the correct order
 	ast.Walk(node, &ast.SimpleVisitor{
+		DoVisitSyntaxNode: func(node *ast.SyntaxNode) error {
+			s.mkcomments(node.Syntax)
+			s.mktokens(node.Keyword, nil, semanticTypeKeyword, 0)
+			s.mktokens(node.Equals, nil, semanticTypeOperator, 0)
+			s.mktokens(node.Syntax, nil, semanticTypeString, 0)
+			return nil
+		},
+		DoVisitFileNode: func(node *ast.FileNode) error {
+			s.mkcomments(node.EOF)
+			return nil
+		},
 		DoVisitStringLiteralNode: func(node *ast.StringLiteralNode) error {
-			s.mktokens(node, semanticTypeString, 0)
+			s.mktokens(node, tracker.Path(), semanticTypeString, 0)
 			return nil
 		},
 		DoVisitUintLiteralNode: func(node *ast.UintLiteralNode) error {
-			s.mktokens(node, semanticTypeNumber, 0)
+			s.mktokens(node, tracker.Path(), semanticTypeNumber, 0)
 			return nil
 		},
 		DoVisitFloatLiteralNode: func(node *ast.FloatLiteralNode) error {
-			s.mktokens(node, semanticTypeNumber, 0)
+			s.mktokens(node, tracker.Path(), semanticTypeNumber, 0)
 			return nil
 		},
 		DoVisitSpecialFloatLiteralNode: func(node *ast.SpecialFloatLiteralNode) error {
-			s.mktokens(node, semanticTypeNumber, 0)
+			s.mktokens(node, tracker.Path(), semanticTypeNumber, 0)
 			return nil
 		},
 		DoVisitKeywordNode: func(node *ast.KeywordNode) error {
-			s.mktokens(node, semanticTypeKeyword, 0)
+			s.mktokens(node, tracker.Path(), semanticTypeKeyword, 0)
 			return nil
 		},
 		DoVisitRuneNode: func(node *ast.RuneNode) error {
@@ -249,75 +246,84 @@
 			case '}', ';', '{', '.', ',', '<', '>', '(', ')':
 				s.mkcomments(node)
 			default:
-				s.mktokens(node, semanticTypeOperator, 0)
+				s.mktokens(node, tracker.Path(), semanticTypeOperator, 0)
 			}
 			return nil
 		},
 		DoVisitOneofNode: func(node *ast.OneofNode) error {
-			s.mktokens(node.Name, semanticTypeClass, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeClass, 0)
 			return nil
 		},
 		DoVisitMessageNode: func(node *ast.MessageNode) error {
-			s.mktokens(node.Name, semanticTypeClass, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeClass, 0)
 			return nil
 		},
 		DoVisitFieldNode: func(node *ast.FieldNode) error {
-			s.mktokens(node.Name, semanticTypeProperty, 0)
-			s.mktokens(node.FldType, semanticTypeType, 0)
+			s.mktokens(node.FldType, append(tracker.Path(), node.FldType), semanticTypeType, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeProperty, 0)
 			return nil
 		},
 		DoVisitFieldReferenceNode: func(node *ast.FieldReferenceNode) error {
-			if node.IsExtension() {
-				if node.IsAnyTypeReference() {
-					s.mktokens(node.URLPrefix, semanticTypeType, 0)
-					s.mktokens(node.Slash, semanticTypeType, 0)
-					s.mktokens(node.Name, semanticTypeType, 0)
-				} else {
-					s.mktokens(node.Name, semanticTypeType, 0)
-				}
+			if node.IsAnyTypeReference() {
+				s.mktokens(node.URLPrefix, append(tracker.Path(), node.URLPrefix), semanticTypeNamespace, 0)
+				s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeType, 0)
+			} else if node.IsExtension() {
+				s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeType, 0)
 			} else {
-				s.mktokens(node.Name, semanticTypeProperty, 0)
+				s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeProperty, 0)
 			}
 			return nil
 		},
+		// DoVisitOptionNode: func(node *ast.OptionNode) error {
+		// 	return nil
+		// },
+		// DoVisitMessageFieldNode: func(node *ast.MessageFieldNode) error {
+		// 	// :
+		// 	return nil
+		// },
+		// DoVisitMessageLiteralNode: func(node *ast.MessageLiteralNode) error {
+		// 	s.mktokens(node.Open, semanticTypeOperator, 0)
+		// 	s.mktokens(node.Close, semanticTypeOperator, 0)
+		// 	return nil
+		// },
 		DoVisitMapFieldNode: func(node *ast.MapFieldNode) error {
-			s.mktokens(node.Name, semanticTypeProperty, 0)
-			s.mktokens(node.MapType.KeyType, semanticTypeType, 0)
-			s.mktokens(node.MapType.ValueType, semanticTypeType, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeProperty, 0)
+			s.mktokens(node.MapType.KeyType, append(tracker.Path(), node.MapType, node.MapType.KeyType), semanticTypeType, 0)
+			s.mktokens(node.MapType.ValueType, append(tracker.Path(), node.MapType, node.MapType.ValueType), semanticTypeType, 0)
 			return nil
 		},
 		DoVisitRPCTypeNode: func(node *ast.RPCTypeNode) error {
-			s.mktokens(node.MessageType, semanticTypeType, 0)
+			s.mktokens(node.MessageType, append(tracker.Path(), node.MessageType), semanticTypeType, 0)
 			return nil
 		},
 		DoVisitRPCNode: func(node *ast.RPCNode) error {
-			s.mktokens(node.Name, semanticTypeFunction, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeFunction, 0)
 			return nil
 		},
-		DoVisitServiceNode: func(sn *ast.ServiceNode) error {
-			s.mktokens(sn.Name, semanticTypeClass, 0)
+		DoVisitServiceNode: func(node *ast.ServiceNode) error {
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeClass, 0)
 			return nil
 		},
 		DoVisitPackageNode: func(node *ast.PackageNode) error {
-			s.mktokens(node.Name, semanticTypeNamespace, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeNamespace, 0)
 			return nil
 		},
 		DoVisitEnumNode: func(node *ast.EnumNode) error {
-			s.mktokens(node.Name, semanticTypeClass, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeClass, 0)
 			return nil
 		},
 		DoVisitEnumValueNode: func(node *ast.EnumValueNode) error {
-			s.mktokens(node.Name, semanticTypeEnumMember, 0)
+			s.mktokens(node.Name, append(tracker.Path(), node.Name), semanticTypeEnumMember, 0)
 			return nil
 		},
 		DoVisitTerminalNode: func(node ast.TerminalNode) error {
 			s.mkcomments(node)
 			return nil
 		},
-	})
+	}, walkOptions...)
 }
 
-func (e *encoded) Data() []uint32 {
+func (e *semanticItems) Data() []uint32 {
 	// binary operators, at least, will be out of order
 	sort.Slice(e.items, func(i, j int) bool {
 		if e.items[i].line != e.items[j].line {
@@ -329,7 +335,7 @@ func (e *encoded) Data() []uint32 {
 	// (see Integer Encoding for Tokens in the LSP spec)
 	x := make([]uint32, 5*len(e.items))
 	var j int
-	var last semItem
+	var last semanticItem
 	for i := 0; i < len(e.items); i++ {
 		item := e.items[i]
 		if j == 0 {
diff --git a/pkg/lsp/wellknown.go b/pkg/lsp/wellknown.go
index 406d4ac..c2f4d6d 100644
--- a/pkg/lsp/wellknown.go
+++ b/pkg/lsp/wellknown.go
@@ -1,8 +1,10 @@
 package lsp
 
-// import all well-known types
+// import some extra well-known types
 import (
 	_ "google.golang.org/genproto/googleapis/api/annotations"
+	_ "google.golang.org/genproto/googleapis/api/httpbody"
+	_ "google.golang.org/genproto/googleapis/api/label"
 	_ "google.golang.org/genproto/googleapis/rpc/code"
 	_ "google.golang.org/genproto/googleapis/rpc/context"
 	_ "google.golang.org/genproto/googleapis/rpc/context/attribute_context"
diff --git a/protocompile b/protocompile
index 8d7f8ad..948f048 160000
--- a/protocompile
+++ b/protocompile
@@ -1 +1 @@
-Subproject commit 8d7f8ad6b0fa248be7f0a04a31c82522ac4b244d
+Subproject commit 948f048a67fde94db3bbc8955b8e8613a56451e9
diff --git a/protoreflect b/protoreflect
index 6393c39..cd79ce6 160000
--- a/protoreflect
+++ b/protoreflect
@@ -1 +1 @@
-Subproject commit 6393c39ef464e6aeee865e203e8a416d4054d94b
+Subproject commit cd79ce667f5e5bc5791f10fabcde05fb46474711