Skip to content

Commit

Permalink
Merge pull request #28 from joyme123/fix-include-parse
Browse files Browse the repository at this point in the history
fix include parse
  • Loading branch information
joyme123 committed Nov 16, 2023
2 parents 28505c3 + 6f4e85b commit 7d46dff
Show file tree
Hide file tree
Showing 7 changed files with 304 additions and 31 deletions.
19 changes: 6 additions & 13 deletions lsp/codejump/definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package codejump
import (
"context"
"errors"
"strings"

"github.com/joyme123/thrift-ls/lsp/cache"
"github.com/joyme123/thrift-ls/lsp/lsputils"
Expand Down Expand Up @@ -61,11 +60,9 @@ func serviceDefinition(ctx context.Context, ss *cache.Snapshot, file uri.URI, as
func ServiceDefinitionIdentifier(ctx context.Context, ss *cache.Snapshot, file uri.URI, ast *parser.Document, targetNode parser.Node) (uri.URI, *parser.Identifier, string, error) {
identifierName := targetNode.(*parser.IdentifierName)

include, identifier, found := strings.Cut(identifierName.Text, ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, identifierName.Text)
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down Expand Up @@ -113,11 +110,9 @@ func TypeNameDefinitionIdentifier(ctx context.Context, ss *cache.Snapshot, file
return "", nil, "", nil
}

include, identifier, found := strings.Cut(typeV, ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, typeV)
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down Expand Up @@ -183,11 +178,9 @@ func ConstValueTypeDefinitionIdentifier(ctx context.Context, ss *cache.Snapshot,
return "", nil, nil
}

include, identifier, found := strings.Cut(constValue.Value.(string), ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, constValue.Value.(string))
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down
79 changes: 79 additions & 0 deletions lsp/codejump/definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ service Demo {
list<user.UserType> UserTypes(1:user.Test3 arg1=user.Test3.TWO, 2:string arg2=user.DefaultName)
}`

// user.extra.thrift
file3 := `struct Test {}`

file4 := `include "user.extra.thrift"
include "user.thrift"
struct Person {
1: required user.extra.Test field1,
2: required user.Test field2,
}`

ss := cache.BuildSnapshotForTest([]*cache.FileChange{
{
URI: "file:///tmp/user.thrift",
Expand All @@ -59,6 +70,18 @@ service Demo {
Content: []byte(file2),
From: cache.FileChangeTypeDidOpen,
},
{
URI: "file:///tmp/user.extra.thrift",
Version: 0,
Content: []byte(file3),
From: cache.FileChangeTypeDidOpen,
},
{
URI: "file:///tmp/app.thrift",
Version: 0,
Content: []byte(file4),
From: cache.FileChangeTypeDidOpen,
},
})

type args struct {
Expand Down Expand Up @@ -269,6 +292,62 @@ service Demo {
},
assertion: assert.NoError,
},
{
name: "case include 1",
args: args{
ctx: context.TODO(),
ss: ss,
file: "file:///tmp/app.thrift",
pos: protocol.Position{
Line: 4,
Character: 25,
},
},
want: []protocol.Location{
{
URI: "file:///tmp/user.extra.thrift",
Range: protocol.Range{
Start: protocol.Position{
Line: 0,
Character: 7,
},
End: protocol.Position{
Line: 0,
Character: 11,
},
},
},
},
assertion: assert.NoError,
},
{
name: "case include 2",
args: args{
ctx: context.TODO(),
ss: ss,
file: "file:///tmp/app.thrift",
pos: protocol.Position{
Line: 5,
Character: 19,
},
},
want: []protocol.Location{
{
URI: "file:///tmp/user.thrift",
Range: protocol.Range{
Start: protocol.Position{
Line: 0,
Character: 7,
},
End: protocol.Position{
Line: 0,
Character: 11,
},
},
},
},
assertion: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
19 changes: 6 additions & 13 deletions lsp/codejump/hover.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package codejump
import (
"context"
"errors"
"strings"

"github.com/joyme123/thrift-ls/format"
"github.com/joyme123/thrift-ls/lsp/cache"
Expand Down Expand Up @@ -50,11 +49,9 @@ func Hover(ctx context.Context, ss *cache.Snapshot, file uri.URI, pos protocol.P
func hoverService(ctx context.Context, ss *cache.Snapshot, file uri.URI, ast *parser.Document, targetNode parser.Node) (string, error) {
identifierName := targetNode.(*parser.IdentifierName)
name := identifierName.Text
include, identifier, found := strings.Cut(name, ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, name)
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down Expand Up @@ -89,11 +86,9 @@ func hoverDefinition(ctx context.Context, ss *cache.Snapshot, file uri.URI, ast
return "", nil
}

include, identifier, found := strings.Cut(typeV, ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, typeV)
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down Expand Up @@ -144,11 +139,9 @@ func hoverConstValue(ctx context.Context, ss *cache.Snapshot, file uri.URI, ast
return "", nil
}

include, identifier, found := strings.Cut(constValue.Value.(string), ".")
include, identifier := lsputils.ParseIdent(file, ast.Includes, constValue.Value.(string))
var astFile uri.URI
if !found {
identifier = include
include = ""
if include == "" {
astFile = file
} else {
path := lsputils.GetIncludePath(ast, include)
Expand Down
2 changes: 1 addition & 1 deletion lsp/codejump/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func Reference(ctx context.Context, ss *cache.Snapshot, file uri.URI, pos protoc
if !strings.Contains(svcName, ".") {
svcName = fmt.Sprintf("%s.%s", lsputils.GetIncludeName(file), svcName)
} else {
include, _, _ := strings.Cut(svcName, ".")
include, _ := lsputils.ParseIdent(file, pf.AST().Includes, svcName)
path := lsputils.GetIncludePath(pf.AST(), include)
if path != "" { // doesn't match any include path
file = lsputils.IncludeURI(file, path)
Expand Down
2 changes: 1 addition & 1 deletion lsp/codejump/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func Rename(ctx context.Context, ss *cache.Snapshot, file uri.URI, pos protocol.
if !strings.Contains(svcName, ".") {
svcName = fmt.Sprintf("%s.%s", lsputils.GetIncludeName(file), svcName)
} else {
include, _, _ := strings.Cut(svcName, ".")
include, _ := lsputils.ParseIdent(file, pf.AST().Includes, svcName)
path := lsputils.GetIncludePath(pf.AST(), include)
if path != "" { // doesn't match any include path
file = lsputils.IncludeURI(file, path)
Expand Down
48 changes: 45 additions & 3 deletions lsp/lsputils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lsputils

import (
"path/filepath"
"sort"
"strings"

"github.com/joyme123/thrift-ls/parser"
Expand Down Expand Up @@ -43,15 +44,15 @@ func GetIncludeName(file uri.URI) string {
// if doesn't match, return empty string
func GetIncludePath(ast *parser.Document, includeName string) string {
for _, include := range ast.Includes {
if include.BadNode || include.Path == nil || include.Path.BadNode {
if include.BadNode || include.Path == nil || include.Path.BadNode || include.Path.Value == nil {
continue
}
items := strings.Split(include.Path.Value.Text, "/")
path := items[len(items)-1]
name, _, found := strings.Cut(path, ".")
if !found {
if !strings.HasSuffix(path, ".thrift") {
continue
}
name := strings.TrimSuffix(path, ".thrift")
if name == includeName {
return include.Path.Value.Text
}
Expand All @@ -71,3 +72,44 @@ func IncludeURI(cur uri.URI, includePath string) uri.URI {

return uri.File(path)
}

// ParseIdent parse an identifier. identifier format:
// 1. identifier
// 2. include.identifier
//
// it returns include, ident
func ParseIdent(cur uri.URI, includes []*parser.Include, identifier string) (include, ident string) {
includeNames := IncludeNames(cur, includes)
// parse include from includeNames

sort.SliceStable(includeNames, func(i, j int) bool {
// sort by string length, make sure longest include match early
// examples:
// user.extra
// user
return len(includeNames[i]) > len(includeNames[j])
})

for _, incName := range includeNames {
prefix := incName + "."
if strings.HasPrefix(identifier, prefix) {
return incName, strings.TrimPrefix(identifier, prefix)
}
}

return "", identifier
}

// IncludeNames returns include names from include ast nodes
func IncludeNames(cur uri.URI, includes []*parser.Include) (includeNames []string) {
for _, inc := range includes {
if inc.Path != nil && inc.Path.Value != nil {
path := inc.Path.Value.Text
u := IncludeURI(cur, path)
includeName := GetIncludeName(u)
includeNames = append(includeNames, includeName)
}
}

return includeNames
}
Loading

0 comments on commit 7d46dff

Please sign in to comment.