Skip to content

Commit

Permalink
Implement refactor code actions for inlining messages and renumbering…
Browse files Browse the repository at this point in the history
… fields
  • Loading branch information
kralicky committed Feb 15, 2024
1 parent 21f77c8 commit 16fceb3
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 9 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Features in progress:
- [x] Simplify repeated message literal fields
- [ ] Simplify map literal fields
- [x] Extract fields to new message
- [ ] Inline fields from message
- [x] Inline fields from message
- [x] Renumber message fields
- [x] Code Lens
- [x] Generate file/package/workspace
- [x] Inlay hints
Expand Down
2 changes: 1 addition & 1 deletion pkg/lsp/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func (c *Cache) FindTypeDescriptorAtLocation(params protocol.TextDocumentPositio
root := enc.AST()

token, comment := root.ItemAtOffset(offset)
if token == ast.TokenError && comment.IsValid() {
if comment.IsValid() {
return nil, protocol.Range{}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/lsp/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto
searchTarget = currentParseRes
}
tokenAtOffset, comment := searchTarget.AST().ItemAtOffset(posOffset)
if tokenAtOffset == ast.TokenError && comment.IsValid() {
if comment.IsValid() {
// don't complete within comments
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/lsp/hover.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (c *Cache) ComputeHover(params protocol.TextDocumentPositionParams) (*proto
}

tokenAtOffset, comment := parseRes.AST().ItemAtOffset(offset)
if tokenAtOffset == ast.TokenError && comment.IsValid() {
if tokenAtOffset == ast.TokenError || comment.IsValid() {
return nil, nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/lsp/packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (c *Cache) TryFindPackageReferences(params protocol.TextDocumentPositionPar
fileNode := parseRes.AST()

tokenAtOffset, comment := fileNode.ItemAtOffset(offset)
if tokenAtOffset == ast.TokenError && comment.IsValid() {
if tokenAtOffset == ast.TokenError || comment.IsValid() {
return nil
}

Expand Down Expand Up @@ -148,7 +148,7 @@ func (c *Cache) tryHoverPackageNode(params protocol.TextDocumentPositionParams)
fileNode := parseRes.AST()

tokenAtOffset, comment := fileNode.ItemAtOffset(offset)
if tokenAtOffset == ast.TokenError && comment.IsValid() {
if tokenAtOffset == ast.TokenError || comment.IsValid() {
return nil
}

Expand Down
250 changes: 247 additions & 3 deletions pkg/lsp/refactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ var analyzers = map[protocol.CodeActionKind][]Analyzer{
protocol.RefactorRewrite: {
simplifyRepeatedOptions,
simplifyRepeatedFieldLiterals,
renumberFields,
},
protocol.RefactorExtract: {
extractFields,
},
protocol.RefactorInline: {
inlineMessageFields,
},
}

type pendingCodeAction struct {
Expand Down Expand Up @@ -429,7 +433,7 @@ func extractFields(ctx context.Context, request *protocol.CodeActionParams, link
newMsgFields = append(newMsgFields, newFld)
}
newMsgName := findNewUnusedMessageName(desc)
newFieldName := findNewUnusedFieldName(desc)
newFieldName := findNewUnusedFieldName(desc, "newField")
newMessage := ast.NewMessageNode(&ast.KeywordNode{Val: "message"}, &ast.IdentNode{Val: newMsgName}, &ast.RuneNode{Rune: '{'}, newMsgFields, &ast.RuneNode{Rune: '}'})

var label *ast.KeywordNode
Expand Down Expand Up @@ -481,6 +485,244 @@ func extractFields(ctx context.Context, request *protocol.CodeActionParams, link
}
}

func inlineMessageFields(ctx context.Context, request *protocol.CodeActionParams, linkRes linker.Result, mapper *protocol.Mapper, results chan<- protocol.CodeAction) {
if request.Range == (protocol.Range{}) || request.Range.Start != request.Range.End {
return
}
fileNode := linkRes.AST()
offset, err := mapper.PositionOffset(request.Range.Start)
if err != nil {
return
}
token, comment := linkRes.AST().ItemAtOffset(offset)
if token == ast.TokenError || comment.IsValid() {
return
}
path, ok := findPathIntersectingToken(linkRes, token, request.Range.Start)
if !ok {
return
}

if len(path) < 3 {
// the path must be at least 3 nodes long (file, message, field)
return
}

desc, _, err := deepPathSearch(path, linkRes, linkRes)
if err != nil {
return
}
fieldDesc, ok := desc.(protoreflect.FieldDescriptor)
if !ok {
return
}
if fieldDesc.Kind() != protoreflect.MessageKind || fieldDesc.IsMap() || fieldDesc.IsList() {
return
}

fieldToInline, ok := path[len(path)-1].(*ast.FieldNode)
if !ok {
return
}
containingMessage, ok := path[len(path)-2].(*ast.MessageNode)
if !ok {
return
}

results <- actionQueue.enqueue("Inline nested message", protocol.RefactorInline, func(ca *protocol.CodeAction) error {
containingMessageDesc := fieldDesc.ContainingMessage()
msgDesc := fieldDesc.Message()
if containingMessageDesc == msgDesc {
// technically nothing stopping us from doing this, but it's almost
// certainly not what the user intended and would just mess up their code
return fmt.Errorf("cannot inline a recursive message into itself")
}
existingFieldDescs := fieldDesc.ContainingMessage().Fields()
newFieldDescs := fieldDesc.Message().Fields()
updatedMsgFields := make([]ast.MessageElement, len(containingMessage.Decls))
var largestFieldNumber uint64
for i, decl := range containingMessage.Decls {
if fld, ok := decl.(*ast.FieldNode); ok {
if fld.Tag.Val > largestFieldNumber {
largestFieldNumber = fld.Tag.Val
}
if fld == fieldToInline {
continue
}
}
updatedMsgFields[i] = decl
}
messageInfo := fileNode.NodeInfo(containingMessage)
messageRange := positionsToRange(messageInfo.Start(), fileNode.NodeInfo(containingMessage.CloseBrace).End())

startNumber := largestFieldNumber + 1
mask := map[ast.Node]ast.NodeInfo{
containingMessage.Keyword: {},
}
for i, desc := range updatedMsgFields {
if desc != nil {
continue
}
var label *ast.KeywordNode
if isProto2(fileNode) {
label = &ast.KeywordNode{Val: "optional"}
}
toInsert := make([]ast.MessageElement, 0, newFieldDescs.Len())
for j := range newFieldDescs.Len() {
fld := newFieldDescs.Get(j)
var fieldType string
switch fld.Kind() {
case protoreflect.MessageKind:
fieldType = relativeFullName(fld.Message().FullName(), linkRes.Package())
case protoreflect.EnumKind:
fieldType = relativeFullName(fld.Enum().FullName(), linkRes.Package())
default:
fieldType = fld.Kind().String()
}
fldName := fld.Name()
if existingFieldDescs.ByName(fldName) != nil {
fldName = protoreflect.Name(findNewUnusedFieldName(containingMessageDesc, string(fldName)))
}
fldNumber := startNumber + uint64(j)
newField := ast.NewFieldNode(label, &ast.IdentNode{Val: fieldType}, &ast.IdentNode{Val: string(fldName)}, &ast.RuneNode{Rune: '='}, &ast.UintLiteralNode{Val: fldNumber}, nil)
newField.AddSemicolon(&ast.RuneNode{Rune: ';'})
toInsert = append(toInsert, newField)
mask[newField] = ast.NodeInfo{}
}

updatedMsgFields[i] = toInsert[0]
if len(toInsert) > 1 {
updatedMsgFields = slices.Insert(updatedMsgFields, i+1, toInsert[1:]...)
}
break
}

updatedMessage := ast.NewMessageNode(
containingMessage.Keyword,
containingMessage.Name,
containingMessage.OpenBrace,
updatedMsgFields,
containingMessage.CloseBrace,
)

updatedMessageText, err := format.PrintNode(format.NodeInfoOverlay(fileNode, mask), updatedMessage)
if err != nil {
return fmt.Errorf("error formatting updated message: %v", err)
}

ca.Edit = &protocol.WorkspaceEdit{
Changes: map[protocol.DocumentURI][]protocol.TextEdit{
request.TextDocument.URI: {
{
Range: messageRange,
NewText: indentTextHanging(updatedMessageText, int(messageRange.Start.Character)),
},
},
},
}
return nil
})
}

func renumberFields(ctx context.Context, request *protocol.CodeActionParams, linkRes linker.Result, mapper *protocol.Mapper, results chan<- protocol.CodeAction) {
if request.Range == (protocol.Range{}) || request.Range.Start != request.Range.End {
return
}
fileNode := linkRes.AST()
offset, err := mapper.PositionOffset(request.Range.Start)
if err != nil {
return
}
token, comment := linkRes.AST().ItemAtOffset(offset)
if token == ast.TokenError || comment.IsValid() {
return
}
path, ok := findPathIntersectingToken(linkRes, token, request.Range.Start)
if !ok {
return
}

desc, _, err := deepPathSearch(path, linkRes, linkRes)
if err != nil {
return
}
msgNode, ok := path[len(path)-1].(*ast.MessageNode)
if !ok {
return
}
if token < msgNode.Name.Start() || token > msgNode.Name.End() {
return
}

var canRenumber bool
msgDesc, ok := desc.(protoreflect.MessageDescriptor)
if ok {
fields := msgDesc.Fields()
for i := range fields.Len() {
if fields.Get(i).Number() != protowire.Number(i+1) {
canRenumber = true
break
}
}
}

if !canRenumber {
return
}

results <- actionQueue.enqueue("Renumber fields", protocol.RefactorRewrite, func(ca *protocol.CodeAction) error {
parentInfo := fileNode.NodeInfo(msgNode)
parentRange := positionsToRange(parentInfo.Start(), fileNode.NodeInfo(msgNode.CloseBrace).End())
updatedMsgFields := make([]ast.MessageElement, len(msgNode.Decls))
number := 1
mask := map[ast.Node]ast.NodeInfo{
msgNode.Keyword: {},
}
for i, decl := range msgNode.Decls {
if fld, ok := decl.(*ast.FieldNode); ok {
newFld := ast.NewFieldNode(
fld.Label.KeywordNode,
fld.FldType,
fld.Name,
fld.Equals,
&ast.UintLiteralNode{Val: uint64(number)},
fld.Options,
)
mask[fld.Tag] = fileNode.NodeInfo(fld.Tag)
newFld.AddSemicolon(fld.Semicolon)
updatedMsgFields[i] = newFld
number++
} else {
updatedMsgFields[i] = decl
}
}
updatedMessage := ast.NewMessageNode(
msgNode.Keyword,
msgNode.Name,
msgNode.OpenBrace,
updatedMsgFields,
msgNode.CloseBrace,
)

updatedMessageText, err := format.PrintNode(format.NodeInfoOverlay(fileNode, mask), updatedMessage)
if err != nil {
return fmt.Errorf("error formatting updated message: %v", err)
}

ca.Edit = &protocol.WorkspaceEdit{
Changes: map[protocol.DocumentURI][]protocol.TextEdit{
request.TextDocument.URI: {
{
Range: parentRange,
NewText: indentTextHanging(updatedMessageText, int(parentRange.Start.Character)),
},
},
},
}
return nil
})
}

func findNewUnusedMessageName(desc protoreflect.MessageDescriptor) string {
parent := desc.Parent()
prefix := "NewMessage"
Expand All @@ -498,9 +740,11 @@ func findNewUnusedMessageName(desc protoreflect.MessageDescriptor) string {
return name
}

func findNewUnusedFieldName(desc protoreflect.MessageDescriptor) string {
prefix := "newField"
func findNewUnusedFieldName(desc protoreflect.MessageDescriptor, prefix string) string {
name := prefix
if last := prefix[len(prefix)-1]; last >= '0' && last <= '9' {
prefix += "_"
}
for i := 1; i < 100; i++ {
if desc.Fields().ByName(protoreflect.Name(name)) == nil {
return name
Expand Down

0 comments on commit 16fceb3

Please sign in to comment.