diff --git a/protols/cache.go b/protols/cache.go index 3c02c2f..9731b42 100644 --- a/protols/cache.go +++ b/protols/cache.go @@ -62,7 +62,7 @@ func (c *Cache) Reindex(ctx context.Context) error { c.sourceFilenames[source] = path.Join(goPkg, path.Base(source)) } accessor := ragu.SourceAccessor(c.sourcePackages) - res := ragu.NewResolver(accessor) + res := protocompile.WithStandardImports(ragu.NewResolver(accessor)) compiler := protocompile.Compiler{ Resolver: res, MaxParallelism: -1, diff --git a/protols/server.go b/protols/server.go index 542cbf0..1989f5d 100644 --- a/protols/server.go +++ b/protols/server.go @@ -221,8 +221,8 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (resul } return &protocol.Hover{ Contents: protocol.MarkupContent{ - Kind: protocol.PlainText, - Value: str, + Kind: protocol.Markdown, + Value: fmt.Sprintf("```proto\n%s\n```", str), }, }, nil } diff --git a/protols/source.go b/protols/source.go index e00a829..3f26a1a 100644 --- a/protols/source.go +++ b/protols/source.go @@ -10,6 +10,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" ) @@ -106,24 +107,56 @@ func findRelevantDescriptorAtLocation(params *protocol.TextDocumentPositionParam // special cases: + switch descriptor := descriptor.(type) { // 1. Imports, which resolve to ambiguous file descriptors - switch descriptor.(type) { case *descriptorpb.FileDescriptorProto: fileNode := path[closestIdentifiableIndex+1] var filename string switch fileNode := fileNode.(type) { case *ast.ImportNode: filename = fileNode.Name.AsString() - case *ast.StringLiteralNode: - filename = fileNode.AsString() + case *ast.FileNode: + filename = fileNode.Name() + case *ast.PackageNode: + return nil, fmt.Errorf("no descriptor available for package node") default: return nil, fmt.Errorf("unexpected node type %T", fileNode) } f, err := cache.files.AsResolver().FindFileByPath(filename) if err != nil { - return nil, fmt.Errorf("could not find file %q: %w", filename, err) + f, err = protoregistry.GlobalFiles.FindFileByPath(filename) + if err != nil { + return nil, fmt.Errorf("could not find file %q: %w", filename, err) + } } descriptor = protodesc.ToFileDescriptorProto(f) + // 2. Self references + case *descriptorpb.DescriptorProto: + return fd.Messages().ByName(protoreflect.Name(descriptor.GetName())), nil + case *descriptorpb.EnumDescriptorProto: + return fd.Enums().ByName(protoreflect.Name(descriptor.GetName())), nil + case *descriptorpb.ServiceDescriptorProto: + return fd.Services().ByName(protoreflect.Name(descriptor.GetName())), nil + case *descriptorpb.MethodDescriptorProto: + // go up one more level + rpcNode := path[closestIdentifiableIndex-1] + if svcNode, ok := rpcNode.(*ast.ServiceNode); ok { + return fd.Services().ByName(protoreflect.Name(svcNode.Name.AsIdentifier())).Methods().ByName(protoreflect.Name(descriptor.GetName())), nil + } + // 3. Fields (cursor is over the field name, not the type) + case *descriptorpb.FieldDescriptorProto: + if ident, ok := path[len(path)-1].(*ast.IdentNode); ok { + if ident.Val == descriptor.GetName() { + // go up one more level + msgNode := path[closestIdentifiableIndex-1] + switch msgNode := msgNode.(type) { + case *ast.MessageNode: + if field := fd.Messages().ByName(protoreflect.Name(msgNode.Name.AsIdentifier())).Fields().ByName(protoreflect.Name(descriptor.GetName())); field != nil { + return field, nil + } + } + } + } } } @@ -195,14 +228,21 @@ func findRelevantDescriptorAtLocation(params *protocol.TextDocumentPositionParam break } } - default: return nil, fmt.Errorf("unimplemented descriptor type %T", desc) } desc, err := cache.files.AsResolver().FindDescriptorByName(definitionFullName) if err != nil { - return nil, fmt.Errorf("failed to find descriptor for %q: %w", definitionFullName, err) + if msg, err := protoregistry.GlobalTypes.FindMessageByName(definitionFullName); err == nil { + desc = msg.Descriptor() + } else if msg, err := protoregistry.GlobalTypes.FindEnumByName(definitionFullName); err == nil { + desc = msg.Descriptor() + } else if msg, err := protoregistry.GlobalTypes.FindExtensionByName(definitionFullName); err == nil { + desc = msg.TypeDescriptor() + } else { + return nil, fmt.Errorf("failed to find descriptor for %q: %w", definitionFullName, err) + } } return desc, nil