From 16fceb3d06b0c99353346c173a6dccfb11e41376 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Thu, 15 Feb 2024 14:14:02 -0500 Subject: [PATCH] Implement refactor code actions for inlining messages and renumbering fields --- README.md | 3 +- pkg/lsp/cache.go | 2 +- pkg/lsp/completion.go | 2 +- pkg/lsp/hover.go | 2 +- pkg/lsp/packages.go | 4 +- pkg/lsp/refactor.go | 250 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 254 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8957821..188fc6a 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,8 @@ Features in progress: - [x] Simplify repeated message literal fields - [ ] Simplify map literal fields - [x] Extract fields to new message - - [ ] Inline fields from message + - [x] Inline fields from message + - [x] Renumber message fields - [x] Code Lens - [x] Generate file/package/workspace - [x] Inlay hints diff --git a/pkg/lsp/cache.go b/pkg/lsp/cache.go index 4768de2..8e9a921 100644 --- a/pkg/lsp/cache.go +++ b/pkg/lsp/cache.go @@ -279,7 +279,7 @@ func (c *Cache) FindTypeDescriptorAtLocation(params protocol.TextDocumentPositio root := enc.AST() token, comment := root.ItemAtOffset(offset) - if token == ast.TokenError && comment.IsValid() { + if comment.IsValid() { return nil, protocol.Range{}, nil } diff --git a/pkg/lsp/completion.go b/pkg/lsp/completion.go index df16b34..52293d3 100644 --- a/pkg/lsp/completion.go +++ b/pkg/lsp/completion.go @@ -93,7 +93,7 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto searchTarget = currentParseRes } tokenAtOffset, comment := searchTarget.AST().ItemAtOffset(posOffset) - if tokenAtOffset == ast.TokenError && comment.IsValid() { + if comment.IsValid() { // don't complete within comments return nil, nil } diff --git a/pkg/lsp/hover.go b/pkg/lsp/hover.go index c5d0333..4db2775 100644 --- a/pkg/lsp/hover.go +++ b/pkg/lsp/hover.go @@ -38,7 +38,7 @@ func (c *Cache) ComputeHover(params protocol.TextDocumentPositionParams) (*proto } tokenAtOffset, comment := parseRes.AST().ItemAtOffset(offset) - if tokenAtOffset == ast.TokenError && comment.IsValid() { + if tokenAtOffset == ast.TokenError || comment.IsValid() { return nil, nil } diff --git a/pkg/lsp/packages.go b/pkg/lsp/packages.go index 2d7d27d..dae61c6 100644 --- a/pkg/lsp/packages.go +++ b/pkg/lsp/packages.go @@ -30,7 +30,7 @@ func (c *Cache) TryFindPackageReferences(params protocol.TextDocumentPositionPar fileNode := parseRes.AST() tokenAtOffset, comment := fileNode.ItemAtOffset(offset) - if tokenAtOffset == ast.TokenError && comment.IsValid() { + if tokenAtOffset == ast.TokenError || comment.IsValid() { return nil } @@ -148,7 +148,7 @@ func (c *Cache) tryHoverPackageNode(params protocol.TextDocumentPositionParams) fileNode := parseRes.AST() tokenAtOffset, comment := fileNode.ItemAtOffset(offset) - if tokenAtOffset == ast.TokenError && comment.IsValid() { + if tokenAtOffset == ast.TokenError || comment.IsValid() { return nil } diff --git a/pkg/lsp/refactor.go b/pkg/lsp/refactor.go index 0536305..e7d1657 100644 --- a/pkg/lsp/refactor.go +++ b/pkg/lsp/refactor.go @@ -24,10 +24,14 @@ var analyzers = map[protocol.CodeActionKind][]Analyzer{ protocol.RefactorRewrite: { simplifyRepeatedOptions, simplifyRepeatedFieldLiterals, + renumberFields, }, protocol.RefactorExtract: { extractFields, }, + protocol.RefactorInline: { + inlineMessageFields, + }, } type pendingCodeAction struct { @@ -429,7 +433,7 @@ func extractFields(ctx context.Context, request *protocol.CodeActionParams, link newMsgFields = append(newMsgFields, newFld) } newMsgName := findNewUnusedMessageName(desc) - newFieldName := findNewUnusedFieldName(desc) + newFieldName := findNewUnusedFieldName(desc, "newField") newMessage := ast.NewMessageNode(&ast.KeywordNode{Val: "message"}, &ast.IdentNode{Val: newMsgName}, &ast.RuneNode{Rune: '{'}, newMsgFields, &ast.RuneNode{Rune: '}'}) var label *ast.KeywordNode @@ -481,6 +485,244 @@ func extractFields(ctx context.Context, request *protocol.CodeActionParams, link } } +func inlineMessageFields(ctx context.Context, request *protocol.CodeActionParams, linkRes linker.Result, mapper *protocol.Mapper, results chan<- protocol.CodeAction) { + if request.Range == (protocol.Range{}) || request.Range.Start != request.Range.End { + return + } + fileNode := linkRes.AST() + offset, err := mapper.PositionOffset(request.Range.Start) + if err != nil { + return + } + token, comment := linkRes.AST().ItemAtOffset(offset) + if token == ast.TokenError || comment.IsValid() { + return + } + path, ok := findPathIntersectingToken(linkRes, token, request.Range.Start) + if !ok { + return + } + + if len(path) < 3 { + // the path must be at least 3 nodes long (file, message, field) + return + } + + desc, _, err := deepPathSearch(path, linkRes, linkRes) + if err != nil { + return + } + fieldDesc, ok := desc.(protoreflect.FieldDescriptor) + if !ok { + return + } + if fieldDesc.Kind() != protoreflect.MessageKind || fieldDesc.IsMap() || fieldDesc.IsList() { + return + } + + fieldToInline, ok := path[len(path)-1].(*ast.FieldNode) + if !ok { + return + } + containingMessage, ok := path[len(path)-2].(*ast.MessageNode) + if !ok { + return + } + + results <- actionQueue.enqueue("Inline nested message", protocol.RefactorInline, func(ca *protocol.CodeAction) error { + containingMessageDesc := fieldDesc.ContainingMessage() + msgDesc := fieldDesc.Message() + if containingMessageDesc == msgDesc { + // technically nothing stopping us from doing this, but it's almost + // certainly not what the user intended and would just mess up their code + return fmt.Errorf("cannot inline a recursive message into itself") + } + existingFieldDescs := fieldDesc.ContainingMessage().Fields() + newFieldDescs := fieldDesc.Message().Fields() + updatedMsgFields := make([]ast.MessageElement, len(containingMessage.Decls)) + var largestFieldNumber uint64 + for i, decl := range containingMessage.Decls { + if fld, ok := decl.(*ast.FieldNode); ok { + if fld.Tag.Val > largestFieldNumber { + largestFieldNumber = fld.Tag.Val + } + if fld == fieldToInline { + continue + } + } + updatedMsgFields[i] = decl + } + messageInfo := fileNode.NodeInfo(containingMessage) + messageRange := positionsToRange(messageInfo.Start(), fileNode.NodeInfo(containingMessage.CloseBrace).End()) + + startNumber := largestFieldNumber + 1 + mask := map[ast.Node]ast.NodeInfo{ + containingMessage.Keyword: {}, + } + for i, desc := range updatedMsgFields { + if desc != nil { + continue + } + var label *ast.KeywordNode + if isProto2(fileNode) { + label = &ast.KeywordNode{Val: "optional"} + } + toInsert := make([]ast.MessageElement, 0, newFieldDescs.Len()) + for j := range newFieldDescs.Len() { + fld := newFieldDescs.Get(j) + var fieldType string + switch fld.Kind() { + case protoreflect.MessageKind: + fieldType = relativeFullName(fld.Message().FullName(), linkRes.Package()) + case protoreflect.EnumKind: + fieldType = relativeFullName(fld.Enum().FullName(), linkRes.Package()) + default: + fieldType = fld.Kind().String() + } + fldName := fld.Name() + if existingFieldDescs.ByName(fldName) != nil { + fldName = protoreflect.Name(findNewUnusedFieldName(containingMessageDesc, string(fldName))) + } + fldNumber := startNumber + uint64(j) + newField := ast.NewFieldNode(label, &ast.IdentNode{Val: fieldType}, &ast.IdentNode{Val: string(fldName)}, &ast.RuneNode{Rune: '='}, &ast.UintLiteralNode{Val: fldNumber}, nil) + newField.AddSemicolon(&ast.RuneNode{Rune: ';'}) + toInsert = append(toInsert, newField) + mask[newField] = ast.NodeInfo{} + } + + updatedMsgFields[i] = toInsert[0] + if len(toInsert) > 1 { + updatedMsgFields = slices.Insert(updatedMsgFields, i+1, toInsert[1:]...) + } + break + } + + updatedMessage := ast.NewMessageNode( + containingMessage.Keyword, + containingMessage.Name, + containingMessage.OpenBrace, + updatedMsgFields, + containingMessage.CloseBrace, + ) + + updatedMessageText, err := format.PrintNode(format.NodeInfoOverlay(fileNode, mask), updatedMessage) + if err != nil { + return fmt.Errorf("error formatting updated message: %v", err) + } + + ca.Edit = &protocol.WorkspaceEdit{ + Changes: map[protocol.DocumentURI][]protocol.TextEdit{ + request.TextDocument.URI: { + { + Range: messageRange, + NewText: indentTextHanging(updatedMessageText, int(messageRange.Start.Character)), + }, + }, + }, + } + return nil + }) +} + +func renumberFields(ctx context.Context, request *protocol.CodeActionParams, linkRes linker.Result, mapper *protocol.Mapper, results chan<- protocol.CodeAction) { + if request.Range == (protocol.Range{}) || request.Range.Start != request.Range.End { + return + } + fileNode := linkRes.AST() + offset, err := mapper.PositionOffset(request.Range.Start) + if err != nil { + return + } + token, comment := linkRes.AST().ItemAtOffset(offset) + if token == ast.TokenError || comment.IsValid() { + return + } + path, ok := findPathIntersectingToken(linkRes, token, request.Range.Start) + if !ok { + return + } + + desc, _, err := deepPathSearch(path, linkRes, linkRes) + if err != nil { + return + } + msgNode, ok := path[len(path)-1].(*ast.MessageNode) + if !ok { + return + } + if token < msgNode.Name.Start() || token > msgNode.Name.End() { + return + } + + var canRenumber bool + msgDesc, ok := desc.(protoreflect.MessageDescriptor) + if ok { + fields := msgDesc.Fields() + for i := range fields.Len() { + if fields.Get(i).Number() != protowire.Number(i+1) { + canRenumber = true + break + } + } + } + + if !canRenumber { + return + } + + results <- actionQueue.enqueue("Renumber fields", protocol.RefactorRewrite, func(ca *protocol.CodeAction) error { + parentInfo := fileNode.NodeInfo(msgNode) + parentRange := positionsToRange(parentInfo.Start(), fileNode.NodeInfo(msgNode.CloseBrace).End()) + updatedMsgFields := make([]ast.MessageElement, len(msgNode.Decls)) + number := 1 + mask := map[ast.Node]ast.NodeInfo{ + msgNode.Keyword: {}, + } + for i, decl := range msgNode.Decls { + if fld, ok := decl.(*ast.FieldNode); ok { + newFld := ast.NewFieldNode( + fld.Label.KeywordNode, + fld.FldType, + fld.Name, + fld.Equals, + &ast.UintLiteralNode{Val: uint64(number)}, + fld.Options, + ) + mask[fld.Tag] = fileNode.NodeInfo(fld.Tag) + newFld.AddSemicolon(fld.Semicolon) + updatedMsgFields[i] = newFld + number++ + } else { + updatedMsgFields[i] = decl + } + } + updatedMessage := ast.NewMessageNode( + msgNode.Keyword, + msgNode.Name, + msgNode.OpenBrace, + updatedMsgFields, + msgNode.CloseBrace, + ) + + updatedMessageText, err := format.PrintNode(format.NodeInfoOverlay(fileNode, mask), updatedMessage) + if err != nil { + return fmt.Errorf("error formatting updated message: %v", err) + } + + ca.Edit = &protocol.WorkspaceEdit{ + Changes: map[protocol.DocumentURI][]protocol.TextEdit{ + request.TextDocument.URI: { + { + Range: parentRange, + NewText: indentTextHanging(updatedMessageText, int(parentRange.Start.Character)), + }, + }, + }, + } + return nil + }) +} + func findNewUnusedMessageName(desc protoreflect.MessageDescriptor) string { parent := desc.Parent() prefix := "NewMessage" @@ -498,9 +740,11 @@ func findNewUnusedMessageName(desc protoreflect.MessageDescriptor) string { return name } -func findNewUnusedFieldName(desc protoreflect.MessageDescriptor) string { - prefix := "newField" +func findNewUnusedFieldName(desc protoreflect.MessageDescriptor, prefix string) string { name := prefix + if last := prefix[len(prefix)-1]; last >= '0' && last <= '9' { + prefix += "_" + } for i := 1; i < 100; i++ { if desc.Fields().ByName(protoreflect.Name(name)) == nil { return name