diff --git a/editors/vscode/client/src/astviewer.ts b/editors/vscode/client/src/astviewer.ts
index 66ca01c..719c245 100644
--- a/editors/vscode/client/src/astviewer.ts
+++ b/editors/vscode/client/src/astviewer.ts
@@ -1,22 +1,32 @@
 import * as vscode from "vscode"
+type ASTFetcher = (
+  uri: vscode.Uri,
+  version: number,
+  token: vscode.CancellationToken,
+) => Promise<string>
+
 export class ASTViewer implements vscode.TextDocumentContentProvider {
   private virtualUrisByFile = new Map<string, vscode.Uri>()
+  private documentVersionsByUri = new Map<string, number>()
   private closeListener: vscode.Disposable
   private changeListener: vscode.Disposable
-  private fetchAST: (uri: vscode.Uri) => Promise<string>
+  private fetchAST: ASTFetcher
-  constructor(fetchAST: (uri: vscode.Uri) => Promise<string>) {
+  constructor(fetchAST: ASTFetcher) {
     this.fetchAST = fetchAST
     this.closeListener = vscode.workspace.onDidCloseTextDocument((doc) => {
       switch (doc.uri.scheme) {
         case "file": {
           this.virtualUrisByFile.delete(doc.uri.toString())
+          this.documentVersionsByUri.delete(doc.uri.toString())
           break
         }
         case "protoast2":
         case "protoast": {
-          this.virtualUrisByFile.delete(fromProtoAstUri(doc.uri).toString())
+          const uri = fromProtoAstUri(doc.uri)
+          this.virtualUrisByFile.delete(uri.toString())
+          this.documentVersionsByUri.delete(uri.toString())
           break
         }
       }
@@ -24,6 +34,10 @@ export class ASTViewer implements vscode.TextDocumentContentProvider {
     this.changeListener = vscode.workspace.onDidChangeTextDocument((e) => {
       const virtualUri = this.virtualUrisByFile.get(e.document.uri.toString())
       if (virtualUri) {
+        this.documentVersionsByUri.set(
+          virtualUri.toString(),
+          e.document.version,
+        )
         this.refresh(virtualUri)
       }
     })
@@ -63,7 +77,8 @@ export class ASTViewer implements vscode.TextDocumentContentProvider {
     if (!this.virtualUrisByFile.has(fileUri.toString())) {
       return ""
     }
-    return this.fetchAST(fileUri)
+    const version = this.documentVersionsByUri.get(uri.toString()) ?? 0
+    return this.fetchAST(fileUri, version, token)
   }
   public dispose() {
diff --git a/editors/vscode/client/src/client.ts b/editors/vscode/client/src/client.ts
index 712ff86..0b22411 100644
--- a/editors/vscode/client/src/client.ts
+++ b/editors/vscode/client/src/client.ts
@@ -28,10 +28,15 @@ export class ProtolsLanguageClient
     uri: vscode.Uri,
     token: vscode.CancellationToken,
   ): vscode.ProviderResult<string> {
-    return this.sendRequest("workspace/executeCommand", {
-      command: "protols/synthetic-file-contents",
-      arguments: [{ uri: uri.toString() }],
-    }).then((result: string) => {
+    // look up the document version for this uri
+    return this.sendRequest(
+      "workspace/executeCommand",
+      {
+        command: "protols/synthetic-file-contents",
+        arguments: [{ uri: uri.toString() }],
+      },
+      token,
+    ).then((result: string) => {
       return result
     })
   }
diff --git a/editors/vscode/client/src/extension.ts b/editors/vscode/client/src/extension.ts
index bb23aaa..daf40a5 100644
--- a/editors/vscode/client/src/extension.ts
+++ b/editors/vscode/client/src/extension.ts
@@ -12,16 +12,27 @@ export async function activate(context: vscode.ExtensionContext) {
   // Start the client.
This will also launch the server client.start() - const astViewer = new ASTViewer((uri) => { - return client - .sendRequest("workspace/executeCommand", { - command: "protols/ast", - arguments: [{ uri: fromProtoAstUri(uri).toString() }], - }) - .then((result: string) => { - return result - }) - }) + const astViewer = new ASTViewer( + (uri: vscode.Uri, version: number, token: vscode.CancellationToken) => { + return client + .sendRequest( + "workspace/executeCommand", + { + command: "protols/ast", + arguments: [ + { + uri: fromProtoAstUri(uri).toString(), + version, + }, + ], + }, + token, + ) + .then((result: string) => { + return result + }) + }, + ) context.subscriptions.push( astViewer, vscode.workspace.registerTextDocumentContentProvider("protoast", astViewer), diff --git a/go.mod b/go.mod index 8c362ed..409f460 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,21 @@ module github.com/kralicky/protols go 1.21.4 require ( - buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20230914171853-63dfe56cc2c4.1 + buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.32.0-20231115204500-e097f827e652.1 github.com/AlecAivazis/survey/v2 v2.3.7 - github.com/bufbuild/protovalidate-go v0.3.4 + github.com/bufbuild/protovalidate-go v0.4.3 github.com/google/cel-go v0.18.2 github.com/kralicky/gpkg v0.0.0-20231114180450-2f4bff8c5588 - github.com/kralicky/protocompile v0.0.0-20240113031314-24e69108897d + github.com/kralicky/protocompile v0.0.0-20240114032708-6ff5d8987df3 github.com/kralicky/tools-lite v0.0.0-20240104191314-c259ddd5a342 github.com/mattn/go-tty v0.0.5 github.com/spf13/cobra v1.8.0 golang.org/x/mod v0.14.0 - golang.org/x/sync v0.5.0 - google.golang.org/genproto v0.0.0-20231212172506-995d672761c0 - google.golang.org/genproto/googleapis/api v0.0.0-20231212172506-995d672761c0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20231212172506-995d672761c0 - google.golang.org/protobuf v1.31.1-0.20231027082548-f4a6c1f6e5c1 + golang.org/x/sync v0.6.0 + google.golang.org/genproto v0.0.0-20240108191215-35c7eff3a6b1 + google.golang.org/genproto/googleapis/api v0.0.0-20240108191215-35c7eff3a6b1 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240108191215-35c7eff3a6b1 + google.golang.org/protobuf v1.32.0 ) require ( diff --git a/go.sum b/go.sum index 411bb2b..061c0af 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20230914171853-63dfe56cc2c4.1 h1:2gmp+PRca1fqQHf/WMKOgu9inVb0R0N07TucgY3QZCQ= -buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.31.0-20230914171853-63dfe56cc2c4.1/go.mod h1:xafc+XIsTxTy76GJQ1TKgvJWsSugFBqMaN27WhUblew= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.32.0-20231115204500-e097f827e652.1 h1:u0olL4yf2p7Tl5jfsAK5keaFi+JFJuv1CDHrbiXkxkk= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.32.0-20231115204500-e097f827e652.1/go.mod h1:tiTMKD8j6Pd/D2WzREoweufjzaJKHZg35f/VGcZ2v3I= cloud.google.com/go/dlp v1.11.1 h1:OFlXedmPP/5//X1hBEeq3D9kUVm9fb6ywYANlpv/EsQ= cloud.google.com/go/dlp v1.11.1/go.mod h1:/PA2EnioBeXTL/0hInwgj0rfsQb3lpE3R8XUJxqUNKI= github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= @@ -8,8 +8,8 @@ github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63n github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 
v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= -github.com/bufbuild/protovalidate-go v0.3.4 h1:FrHcBBShspitvmC9Nkwu7BNs/EXWjkEQqrgFnWxYH60= -github.com/bufbuild/protovalidate-go v0.3.4/go.mod h1:Au57xmLypglbQAF0GzuDDYbYIct7SZ9QnwJlaPolyFw= +github.com/bufbuild/protovalidate-go v0.4.3 h1:1Xsm3qhkwioxLDEtxWgtn0Ch71xBP/sBauT/FZnn76A= +github.com/bufbuild/protovalidate-go v0.4.3/go.mod h1:RcgJ+onKVv4OkAVtzkRUxkocb8stcUAMK0EoqR4fuZE= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= @@ -32,8 +32,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kralicky/gpkg v0.0.0-20231114180450-2f4bff8c5588 h1:chw4znRXk7AA+AlKcrUZzH1Vupl54KcS4W6wkXCX3lU= github.com/kralicky/gpkg v0.0.0-20231114180450-2f4bff8c5588/go.mod h1:vOkwMjs49XmP/7Xfo9ZL6eg2ei51lmtD/4U/Az5GTq8= -github.com/kralicky/protocompile v0.0.0-20240113031314-24e69108897d h1:DdKNJNMQRs2MBIikDaqWm75w2Yud4xtjRCmmhs/d1qw= -github.com/kralicky/protocompile v0.0.0-20240113031314-24e69108897d/go.mod h1:QKlDXp/yojhlpqgJfUHWhqzvD9gCD/baEPFvq89cpgE= +github.com/kralicky/protocompile v0.0.0-20240114032708-6ff5d8987df3 h1:ZFXct43FfQYouVBjIaHUsDUpKm/o/RQItocyVKhOj+g= +github.com/kralicky/protocompile v0.0.0-20240114032708-6ff5d8987df3/go.mod h1:QKlDXp/yojhlpqgJfUHWhqzvD9gCD/baEPFvq89cpgE= github.com/kralicky/tools-lite v0.0.0-20240104191314-c259ddd5a342 h1:lZLWHXKHmOhTrs3oSZoCRtb8Y9a0mqUwCsaKut+Y1eU= github.com/kralicky/tools-lite v0.0.0-20240104191314-c259ddd5a342/go.mod h1:NKsdxFI6awifvNvxDwtCU1YCaKRoSSPpbHXkKOMuq24= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -82,8 +82,8 @@ golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -109,19 +109,18 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto v0.0.0-20231212172506-995d672761c0 h1:YJ5pD9rF8o9Qtta0Cmy9rdBwkSjrTCT6XTiUQVOtIos= -google.golang.org/genproto 
v0.0.0-20231212172506-995d672761c0/go.mod h1:l/k7rMz0vFTBPy+tFSGvXEd3z+BcoG1k7EHbqm+YBsY= -google.golang.org/genproto/googleapis/api v0.0.0-20231212172506-995d672761c0 h1:s1w3X6gQxwrLEpxnLd/qXTVLgQE2yXwaOaoa6IlY/+o= -google.golang.org/genproto/googleapis/api v0.0.0-20231212172506-995d672761c0/go.mod h1:CAny0tYF+0/9rmDB9fahA9YLzX3+AEVl1qXbv5hhj6c= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231212172506-995d672761c0 h1:/jFB8jK5R3Sq3i/lmeZO0cATSzFfZaJq1J2Euan3XKU= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231212172506-995d672761c0/go.mod h1:FUoWkonphQm3RhTS+kOEhF8h0iDpm4tdXolVCeZ9KKA= +google.golang.org/genproto v0.0.0-20240108191215-35c7eff3a6b1 h1:/IWabOtPziuXTEtI1KYCpM6Ss7vaAkeMxk+uXV/xvZs= +google.golang.org/genproto v0.0.0-20240108191215-35c7eff3a6b1/go.mod h1:+Rvu7ElI+aLzyDQhpHMFMMltsD6m7nqpuWDd2CwJw3k= +google.golang.org/genproto/googleapis/api v0.0.0-20240108191215-35c7eff3a6b1 h1:OPXtXn7fNMaXwO3JvOmF1QyTc00jsSFFz1vXXBOdCDo= +google.golang.org/genproto/googleapis/api v0.0.0-20240108191215-35c7eff3a6b1/go.mod h1:B5xPO//w8qmBDjGReYLpR6UJPnkldGkCSMoH/2vxJeg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240108191215-35c7eff3a6b1 h1:gphdwh0npgs8elJ4T6J+DQJHPVF7RsuJHCfwztUb4J4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240108191215-35c7eff3a6b1/go.mod h1:daQN87bsDqDoe316QbbvX60nMoJQa4r6Ds0ZuoAe5yA= google.golang.org/grpc v1.60.1 h1:26+wFr+cNqSGFcOXcabYC0lUVJVRa2Sb2ortSK7VrEU= google.golang.org/grpc v1.60.1/go.mod h1:OlCHIeLYqSSsLi6i49B5QGdzaMZK9+M7LXN2FKz4eGM= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.31.1-0.20231027082548-f4a6c1f6e5c1 h1:fk72uXZyuZiTtW5tgd63jyVK6582lF61nRC/kGv6vCA= -google.golang.org/protobuf v1.31.1-0.20231027082548-f4a6c1f6e5c1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= +google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/format/ast.go b/pkg/format/ast.go index 2f49fe2..f6e2d84 100644 --- a/pkg/format/ast.go +++ b/pkg/format/ast.go @@ -94,6 +94,12 @@ func (v *dumpVisitor) VisitCompoundStringLiteralNode(node *ast.CompoundStringLit } func (v *dumpVisitor) VisitEmptyDeclNode(node *ast.EmptyDeclNode) error { + v.buf.WriteString("\n") + return nil +} + +func (v *dumpVisitor) VisitErrorNode(*ast.ErrorNode) error { + v.buf.WriteString("\n") return nil } @@ -132,7 +138,11 @@ func (v *dumpVisitor) VisitFieldReferenceNode(node *ast.FieldReferenceNode) erro } func (v *dumpVisitor) VisitFileNode(node *ast.FileNode) error { - v.buf.WriteString(fmt.Sprintf("syntax=%q #decls=%d\n", maybe(node.Syntax).Syntax.AsString(), len(node.Decls))) + if node.Syntax == nil { + v.buf.WriteString(fmt.Sprintf("!syntax #decls=%d\n", len(node.Decls))) + } else { + v.buf.WriteString(fmt.Sprintf("syntax=%q #decls=%d\n", maybe(node.Syntax).Syntax.AsString(), len(node.Decls))) + } return nil } diff --git a/pkg/lsp/cache.go b/pkg/lsp/cache.go index e76394a..0556f25 100644 --- 
a/pkg/lsp/cache.go +++ b/pkg/lsp/cache.go @@ -50,6 +50,8 @@ type Cache struct { inflightTasksInvalidate gsync.Map[protocompile.ResolvedPath, time.Time] inflightTasksCompile gsync.Map[protocompile.ResolvedPath, time.Time] pragmas gsync.Map[protocompile.ResolvedPath, *pragmaMap] + + documentVersions *documentVersionQueue } // FindDescriptorByName implements linker.Resolver. @@ -215,11 +217,12 @@ func NewCache(workspace protocol.WorkspaceFolder, opts ...CacheOption) *Cache { workdir: protocol.DocumentURI(workspace.URI).Path(), } cache := &Cache{ - workspace: workspace, - compiler: compiler, - resolver: resolver, - diagHandler: diagHandler, - unlinkedResults: make(map[protocompile.ResolvedPath]parser.Result), + workspace: workspace, + compiler: compiler, + resolver: resolver, + diagHandler: diagHandler, + unlinkedResults: make(map[protocompile.ResolvedPath]parser.Result), + documentVersions: newDocumentVersionQueue(), } compiler.Hooks = protocompile.CompilerHooks{ PreInvalidate: cache.preInvalidateHook, @@ -243,9 +246,7 @@ func (c *Cache) LoadFiles(files []string) { } } - if err := c.DidModifyFiles(context.TODO(), created); err != nil { - slog.Error("failed to index files", "error", err) - } + c.DidModifyFiles(context.TODO(), created) } func (r *Cache) GetMapper(uri protocol.DocumentURI) (*protocol.Mapper, error) { @@ -267,7 +268,7 @@ func (r *Cache) GetMapper(uri protocol.DocumentURI) (*protocol.Mapper, error) { return protocol.NewMapper(uri, content), nil } -func (s *Cache) ChangedText(ctx context.Context, uri protocol.DocumentURI, changes []protocol.TextDocumentContentChangeEvent) ([]byte, error) { +func (s *Cache) ChangedText(ctx context.Context, uri protocol.VersionedTextDocumentIdentifier, changes []protocol.TextDocumentContentChangeEvent) ([]byte, error) { if len(changes) == 0 { return nil, fmt.Errorf("%w: no content changes provided", jsonrpc2.ErrInternal) } @@ -278,7 +279,7 @@ func (s *Cache) ChangedText(ctx context.Context, uri protocol.DocumentURI, chang return []byte(changes[0].Text), nil } - m, err := s.GetMapper(uri) + m, err := s.GetMapper(uri.URI) if err != nil { return nil, err } @@ -400,8 +401,9 @@ func (c *Cache) compileLocked(protos ...string) { c.compileLocked(syntheticFiles...) } -func (c *Cache) DidModifyFiles(ctx context.Context, modifications []file.Modification) error { +func (c *Cache) DidModifyFiles(ctx context.Context, modifications []file.Modification) { c.resolver.UpdateURIPathMappings(modifications) + defer c.documentVersions.Update(modifications...) var toRecompile []string for _, m := range modifications { @@ -427,12 +429,11 @@ func (c *Cache) DidModifyFiles(ctx context.Context, modifications []file.Modific } } if err := c.compiler.fs.UpdateOverlays(ctx, modifications); err != nil { - return err + panic(fmt.Errorf("internal protocol error: %w", err)) } if len(toRecompile) > 0 { c.Compile(toRecompile...) 
} - return nil } func (c *Cache) ComputeSemanticTokens(doc protocol.TextDocumentIdentifier) ([]uint32, error) { @@ -1198,3 +1199,9 @@ func (c *Cache) FindPragmasByPath(path protocompile.ResolvedPath) (Pragmas, bool p, ok := c.pragmas.Load(path) return p, ok } + +func (c *Cache) WaitDocumentVersion(ctx context.Context, uri protocol.DocumentURI, version int32) error { + ctx, ca := context.WithTimeout(ctx, 2*time.Second) + defer ca() + return c.documentVersions.Wait(ctx, uri, version) +} diff --git a/pkg/lsp/commands.go b/pkg/lsp/commands.go index f30db0d..69ef843 100644 --- a/pkg/lsp/commands.go +++ b/pkg/lsp/commands.go @@ -34,7 +34,8 @@ type SyntheticFileContentsRequest struct { type DocumentASTRequest struct { // The URI of the file to retrieve the AST for. - URI string `json:"uri"` + URI string `json:"uri"` + Version int32 `json:"version"` } type ReindexWorkspacesRequest struct{} diff --git a/pkg/lsp/completion.go b/pkg/lsp/completion.go index b7c6512..9fc6284 100644 --- a/pkg/lsp/completion.go +++ b/pkg/lsp/completion.go @@ -93,265 +93,298 @@ func (c *Cache) GetCompletions(params *protocol.CompletionParams) (result *proto tokenAtOffset := searchTarget.AST().TokenAtOffset(posOffset) path, found := findNarrowestEnclosingScope(searchTarget, tokenAtOffset, params.Position) - if found && maybeCurrentLinkRes != nil { - completions := []protocol.CompletionItem{} - desc, _, _ := deepPathSearch(path, searchTarget, maybeCurrentLinkRes) - - scope, existingOpts := findCompletionScopeAndExistingOptions(path, maybeCurrentLinkRes) - - switch node := path[len(path)-1].(type) { - case *ast.MessageNode: - partialName := strings.TrimSpace(textPrecedingCursor) - if !isProto2(searchTarget.AST()) { - // complete message types - completions = append(completions, completeTypeNames(c, partialName, "", maybeCurrentLinkRes, desc.FullName(), params.Position)...) + if !found { + var completions []protocol.CompletionItem + if len(searchTarget.AST().Children()) == 1 { // only EOF + // empty file + completions = append(completions, syntaxSnippets()...) + } else { + // top-level keywords + completions = append(completions, fileKeywordCompletions(searchTarget.AST(), "", "", params.Position)...) + } + return &protocol.CompletionList{ + Items: completions, + }, nil + } + if maybeCurrentLinkRes == nil { + return nil, nil + } + + completions := []protocol.CompletionItem{} + desc, _, _ := deepPathSearch(path, searchTarget, maybeCurrentLinkRes) + + scope, existingOpts := findCompletionScopeAndExistingOptions(path, maybeCurrentLinkRes) + + switch node := path[len(path)-1].(type) { + case *ast.MessageNode: + var partialName, partialNameSuffix string + fileNode := searchTarget.AST() + if node.Name != nil && tokenAtOffset >= node.Name.Start() && tokenAtOffset <= node.Name.End() { + // complete message names + var err error + partialName, partialNameSuffix, err = findPartialNames(fileNode, node.Name, mapper, posOffset) + if err != nil { + return nil, err } - completions = append(completions, messageKeywordCompletions(searchTarget, partialName)...) - case *ast.MessageLiteralNode: - if desc == nil { + } + if !isProto2(fileNode) { + // complete message types + completions = append(completions, completeTypeNames(c, partialName, partialNameSuffix, maybeCurrentLinkRes, desc.FullName(), params.Position)...) + } + completions = append(completions, messageKeywordCompletions(fileNode, partialName, partialNameSuffix, params.Position)...) 
+ case *ast.MessageLiteralNode: + if desc == nil { + break + } + switch desc := desc.(type) { + case protoreflect.MessageDescriptor: + // complete field names + // filter out fields that are already present + existingFieldNames := []string{} + + for _, elem := range node.Elements { + name := string(elem.Name.Name.AsIdentifier()) + if fd := desc.Fields().ByName(protoreflect.Name(name)); fd != nil { + existingFieldNames = append(existingFieldNames, name) + } + } + for i, l := 0, desc.Fields().Len(); i < l; i++ { + fld := desc.Fields().Get(i) + if slices.Contains(existingFieldNames, string(fld.Name())) { + continue + } + insertPos := protocol.Range{ + Start: params.Position, + End: params.Position, + } + completions = append(completions, fieldCompletion(fld, insertPos, messageLiteralStyle)) + } + + case protoreflect.ExtensionTypeDescriptor: + if _, ok := path[len(path)-2].(*ast.ExtendNode); ok { + // this is a field of an extend node, not a message literal break } - switch desc := desc.(type) { - case protoreflect.MessageDescriptor: + switch desc.Kind() { + case protoreflect.MessageKind: + msg := desc.Message() // complete field names - // filter out fields that are already present - existingFieldNames := []string{} - - for _, elem := range node.Elements { - name := string(elem.Name.Name.AsIdentifier()) - if fd := desc.Fields().ByName(protoreflect.Name(name)); fd != nil { - existingFieldNames = append(existingFieldNames, name) - } - } - for i, l := 0, desc.Fields().Len(); i < l; i++ { - fld := desc.Fields().Get(i) - if slices.Contains(existingFieldNames, string(fld.Name())) { - continue - } + for i, l := 0, msg.Fields().Len(); i < l; i++ { + fld := msg.Fields().Get(i) insertPos := protocol.Range{ Start: params.Position, End: params.Position, } - completions = append(completions, fieldCompletion(fld, insertPos, messageLiteralStyle)) + completions = append(completions, fieldCompletion(fld, insertPos, compactOptionsStyle)) } - - case protoreflect.ExtensionTypeDescriptor: - if _, ok := path[len(path)-2].(*ast.ExtendNode); ok { - // this is a field of an extend node, not a message literal - break + } + } + case *ast.CompactOptionsNode: + completions = append(completions, + c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), nil, 0, maybeCurrentLinkRes, existingOpts, mapper, posOffset, params.Position)...) + case *ast.OptionNode: + completions = append(completions, + c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), nil, 0, maybeCurrentLinkRes, existingOpts, mapper, posOffset, params.Position)...) 
+ case *ast.FieldReferenceNode: + nodeIdx := -1 + scope := scope + switch prev := path[len(path)-2].(type) { + case *ast.OptionNameNode: + nodeIdx = slices.Index(prev.Parts, node) + case *ast.MessageFieldNode: + if desc == nil { + if desc, _, _ := deepPathSearch(path[:len(path)-2], searchTarget, maybeCurrentLinkRes); desc != nil { + nodeIdx = 0 + scope = desc } - switch desc.Kind() { - case protoreflect.MessageKind: - msg := desc.Message() - // complete field names - for i, l := 0, msg.Fields().Len(); i < l; i++ { - fld := msg.Fields().Get(i) - insertPos := protocol.Range{ - Start: params.Position, - End: params.Position, - } - completions = append(completions, fieldCompletion(fld, insertPos, compactOptionsStyle)) - } + } else { + nodeIdx = 0 + if fd, ok := desc.(protoreflect.FieldDescriptor); ok { + scope = fd.ContainingMessage() } } - case *ast.CompactOptionsNode: - completions = append(completions, - c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), nil, 0, maybeCurrentLinkRes, existingOpts, mapper, posOffset, params.Position)...) - case *ast.OptionNode: - completions = append(completions, - c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), nil, 0, maybeCurrentLinkRes, existingOpts, mapper, posOffset, params.Position)...) - case *ast.FieldReferenceNode: - nodeIdx := -1 - scope := scope - switch prev := path[len(path)-2].(type) { - case *ast.OptionNameNode: - nodeIdx = slices.Index(prev.Parts, node) - case *ast.MessageFieldNode: - if desc == nil { - if desc, _, _ := deepPathSearch(path[:len(path)-2], searchTarget, maybeCurrentLinkRes); desc != nil { - nodeIdx = 0 - scope = desc - } - } else { - nodeIdx = 0 - if fd, ok := desc.(protoreflect.FieldDescriptor); ok { - scope = fd.ContainingMessage() - } + } + if nodeIdx == -1 { + break + } + existingFields := map[string]struct{}{} + if messageLitNode, ok := path[len(path)-3].(*ast.MessageLiteralNode); ok { + for _, elem := range messageLitNode.Elements { + if elem.IsIncomplete() { + continue } + existingFields[string(scope.FullName().Append(protoreflect.Name(elem.Name.Name.AsIdentifier())))] = struct{}{} } - if nodeIdx == -1 { - break + } + completions = append(completions, + c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), node, nodeIdx, maybeCurrentLinkRes, existingFields, mapper, posOffset, params.Position)...) + case *ast.OptionNameNode: + // this can be the closest node in a few cases, such as after a trailing dot + nodeChildren := node.Children() + var lastPart *ast.FieldReferenceNode + for i := 0; i < len(nodeChildren); i++ { + if nodeChildren[i].Start() != tokenAtOffset { + continue } - existingFields := map[string]struct{}{} - if messageLitNode, ok := path[len(path)-3].(*ast.MessageLiteralNode); ok { - for _, elem := range messageLitNode.Elements { - if elem.IsIncomplete() { - continue - } - existingFields[string(scope.FullName().Append(protoreflect.Name(elem.Name.Name.AsIdentifier())))] = struct{}{} + for j := i; j >= 0; j-- { + if frn, ok := nodeChildren[j].(*ast.FieldReferenceNode); ok { + lastPart = frn + break } } - completions = append(completions, - c.completeOptionOrExtensionName(scope, path, searchTarget.AST(), node, nodeIdx, maybeCurrentLinkRes, existingFields, mapper, posOffset, params.Position)...) 
- case *ast.OptionNameNode: - // this can be the closest node in a few cases, such as after a trailing dot - nodeChildren := node.Children() - var lastPart *ast.FieldReferenceNode - for i := 0; i < len(nodeChildren); i++ { - if nodeChildren[i].Start() != tokenAtOffset { - continue - } - for j := i; j >= 0; j-- { - if frn, ok := nodeChildren[j].(*ast.FieldReferenceNode); ok { - lastPart = frn - break + } + if lastPart == nil { + break + } + prevFd := maybeCurrentLinkRes.FindFieldDescriptorByFieldReferenceNode(lastPart) + if prevFd == nil { + break + } + items, err := c.deepCompleteOptionNames(prevFd, "", "", maybeCurrentLinkRes, nil, lastPart, params.Position) + if err != nil { + return nil, err + } + completions = append(completions, items...) + + case *ast.FieldNode: + // check if we are completing a type name + var shouldCompleteType bool + var shouldCompleteKeywords bool + var completeType string + var completeTypeSuffix string + + switch { + case tokenAtOffset == node.End(): + // figure out what the previous token is + switch tokenAtOffset - 1 { + case node.FldType.End(): + // complete the field name + switch fldType := node.FldType.(type) { + case *ast.IncompleteIdentNode: + if fldType.IncompleteVal != nil { + completeType = string(fldType.IncompleteVal.AsIdentifier()) + } else { + completeType = "" } + case *ast.CompoundIdentNode: + completeType = string(fldType.AsIdentifier()) + case *ast.IdentNode: + completeType = string(fldType.AsIdentifier()) } + shouldCompleteType = true + case node.Name.Token(): + case node.Equals.Token(): + case node.Tag.Token(): + case node.Options.End(): } - if lastPart == nil { - break - } - prevFd := maybeCurrentLinkRes.FindFieldDescriptorByFieldReferenceNode(lastPart) - if prevFd == nil { - break - } - items, err := c.deepCompleteOptionNames(prevFd, "", "", maybeCurrentLinkRes, nil, lastPart, params.Position) + case tokenAtOffset >= node.FldType.Start() && tokenAtOffset <= node.FldType.End(): + // complete within the field type + pos := searchTarget.AST().NodeInfo(node.FldType).Start() + startOffset, err := mapper.PositionOffset(protocol.Position{ + Line: uint32(pos.Line - 1), + Character: uint32(pos.Col - 1), + }) if err != nil { return nil, err } - completions = append(completions, items...) 
- - case *ast.FieldNode: - // check if we are completing a type name - var shouldCompleteType bool - var shouldCompleteKeywords bool - var completeType string - var completeTypeSuffix string - - switch { - case tokenAtOffset == node.End(): - // figure out what the previous token is - switch tokenAtOffset - 1 { - case node.FldType.End(): - // complete the field name - switch fldType := node.FldType.(type) { - case *ast.IncompleteIdentNode: - if fldType.IncompleteVal != nil { - completeType = string(fldType.IncompleteVal.AsIdentifier()) - } else { - completeType = "" - } - case *ast.CompoundIdentNode: - completeType = string(fldType.AsIdentifier()) - case *ast.IdentNode: - completeType = string(fldType.AsIdentifier()) - } - shouldCompleteType = true - case node.Name.Token(): - case node.Equals.Token(): - case node.Tag.Token(): - case node.Options.End(): - } - case tokenAtOffset >= node.FldType.Start() && tokenAtOffset <= node.FldType.End(): - if desc == nil { - break - } - // complete within the field type - pos := searchTarget.AST().NodeInfo(node.FldType).Start() - startOffset, err := mapper.PositionOffset(protocol.Position{ - Line: uint32(pos.Line - 1), - Character: uint32(pos.Col - 1), - }) - if err != nil { - return nil, err - } - cursorIndexIntoType := posOffset - startOffset - completeType = string(node.FldType.AsIdentifier()) - if len(completeType) >= cursorIndexIntoType { - completeTypeSuffix = completeType[cursorIndexIntoType:] - completeType = completeType[:cursorIndexIntoType] - shouldCompleteType = true - } - if node.Label.IsPresent() && node.Label.Start() == node.FldType.Start() { - // handle empty *ast.IncompleteIdentNodes, such as in 'optional ' - completeType = "" - shouldCompleteType = true - } else { - // complete keywords - shouldCompleteKeywords = true - } - + cursorIndexIntoType := posOffset - startOffset + completeType = string(node.FldType.AsIdentifier()) + if len(completeType) >= cursorIndexIntoType { + completeTypeSuffix = completeType[cursorIndexIntoType:] + completeType = completeType[:cursorIndexIntoType] + shouldCompleteType = true } - if shouldCompleteType { - fmt.Println("completing type", completeType) - var scope protoreflect.FullName - if len(path) > 1 { - if desc, _, err := deepPathSearch(path[:len(path)-1], searchTarget, maybeCurrentLinkRes); err == nil { - scope = desc.FullName() - } + if node.Label.IsPresent() && node.Label.Start() == node.FldType.Start() { + // handle empty *ast.IncompleteIdentNodes, such as in 'optional ' + completeType = "" + shouldCompleteType = true + } else { + // complete keywords + shouldCompleteKeywords = true + } + } + if shouldCompleteType { + fmt.Println("completing type", completeType) + var scope protoreflect.FullName + if len(path) > 1 { + if desc, _, err := deepPathSearch(path[:len(path)-1], searchTarget, maybeCurrentLinkRes); err == nil { + scope = desc.FullName() } - completions = append(completions, completeTypeNames(c, completeType, completeTypeSuffix, maybeCurrentLinkRes, scope, params.Position)...) } - if shouldCompleteKeywords { - completions = append(completions, messageKeywordCompletions(searchTarget, completeType)...) + completions = append(completions, completeTypeNames(c, completeType, completeTypeSuffix, maybeCurrentLinkRes, scope, params.Position)...) + } + if shouldCompleteKeywords { + completions = append(completions, messageKeywordCompletions(searchTarget.AST(), completeType, completeTypeSuffix, params.Position)...) 
+ } + case *ast.ImportNode: + // complete import paths + quoteIdx := strings.IndexRune(textPrecedingCursor, '"') + if quoteIdx != -1 { + partialPath := strings.TrimSpace(textPrecedingCursor[quoteIdx+1:]) + endQuoteIdx := strings.IndexRune(textFollowingCursor, '"') + var partialPathSuffix string + if endQuoteIdx != -1 { + partialPathSuffix = strings.TrimSpace(textFollowingCursor[:endQuoteIdx]) } - case *ast.ImportNode: - // complete import paths - quoteIdx := strings.IndexRune(textPrecedingCursor, '"') - if quoteIdx != -1 { - partialPath := strings.TrimSpace(textPrecedingCursor[quoteIdx+1:]) - endQuoteIdx := strings.IndexRune(textFollowingCursor, '"') - var partialPathSuffix string - if endQuoteIdx != -1 { - partialPathSuffix = strings.TrimSpace(textFollowingCursor[:endQuoteIdx]) - } - existingImportPaths := []string{} - if desc != nil { - imports := maybeCurrentLinkRes.Imports() - for i, l := 0, imports.Len(); i < l; i++ { - // don't include the current import in the existing imports list - imp := imports.Get(i) - if imp == desc { - continue - } - existingImportPaths = append(existingImportPaths, imp.Path()) + existingImportPaths := []string{} + if desc != nil { + imports := maybeCurrentLinkRes.Imports() + for i, l := 0, imports.Len(); i < l; i++ { + // don't include the current import in the existing imports list + imp := imports.Get(i) + if imp == desc { + continue } + existingImportPaths = append(existingImportPaths, imp.Path()) } - completions = append(completions, completeImports(c, partialPath, partialPathSuffix, existingImportPaths, params.Position)...) - } else { - if strings.TrimSpace(textPrecedingCursor) == "import" && node.Public == nil { - completions = append(completions, protocol.CompletionItem{ - Label: "public", - Kind: protocol.KeywordCompletion, - }) - } } - case *ast.SyntaxNode: - // complete syntax versions - quoteIdx := strings.IndexRune(textPrecedingCursor, '"') - if quoteIdx != -1 { - partialVersion := strings.TrimSpace(textPrecedingCursor[quoteIdx+1:]) - endQuoteIdx := strings.IndexRune(textFollowingCursor, '"') - var partialVersionSuffix string - if endQuoteIdx != -1 { - partialVersionSuffix = strings.TrimSpace(textFollowingCursor[:endQuoteIdx]) + completions = append(completions, completeImports(c, partialPath, partialPathSuffix, existingImportPaths, params.Position)...) + } else { + if strings.TrimSpace(textPrecedingCursor) == "import" && node.Public == nil { + completions = append(completions, protocol.CompletionItem{ + Label: "public", + Kind: protocol.KeywordCompletion, + }) + } + } + case *ast.SyntaxNode: + // complete syntax versions + quoteIdx := strings.IndexRune(textPrecedingCursor, '"') + if quoteIdx != -1 { + partialVersion := strings.TrimSpace(textPrecedingCursor[quoteIdx+1:]) + endQuoteIdx := strings.IndexRune(textFollowingCursor, '"') + var partialVersionSuffix string + if endQuoteIdx != -1 { + partialVersionSuffix = strings.TrimSpace(textFollowingCursor[:endQuoteIdx]) + } + completions = append(completions, completeSyntaxVersions(partialVersion, partialVersionSuffix, params.Position)...) + } + case *ast.PackageNode: + // complete package names + completions = append(completions, + c.completePackageNames(node, path, searchTarget.AST(), maybeCurrentLinkRes, mapper, posOffset, params.Position)...) 
+ case *ast.ErrorNode: + switch prev := path[len(path)-2].(type) { + case *ast.FileNode: + var partialName, partialNameSuffix string + if len(node.Children()) == 1 { + if ident, ok := node.Children()[0].(*ast.IdentNode); ok { + // complete partial top-level keywords + var err error + partialName, partialNameSuffix, err = findPartialNames(prev, ident, mapper, posOffset) + if err != nil { + return nil, err + } } - completions = append(completions, completeSyntaxVersions(partialVersion, partialVersionSuffix, params.Position)...) } - case *ast.PackageNode: - // complete package names - completions = append(completions, - c.completePackageNames(node, path, searchTarget.AST(), maybeCurrentLinkRes, mapper, posOffset, params.Position)...) - + completions = append(completions, fileKeywordCompletions(prev, partialName, partialNameSuffix, params.Position)...) } - - return &protocol.CompletionList{ - Items: completions, - }, nil } - return nil, nil + return &protocol.CompletionList{ + Items: completions, + }, nil } func (c *Cache) completeOptionOrExtensionName( @@ -729,13 +762,26 @@ func (c *Cache) deepCompleteOptionNames( return items, nil } -func completeKeywords(keywords ...string) []protocol.CompletionItem { - items := []protocol.CompletionItem{} +func completeKeywords(keywords []string, partialName, partialNameSuffix string, pos protocol.Position) []protocol.CompletionItem { + var items []protocol.CompletionItem + replaceRange := protocol.Range{ + Start: adjustColumn(pos, -len(partialName)), + End: adjustColumn(pos, len(partialNameSuffix)), + } for _, keyword := range keywords { + if !strings.HasPrefix(keyword, partialName) { + continue + } items = append(items, protocol.CompletionItem{ - Label: keyword, - Kind: protocol.KeywordCompletion, - InsertText: fmt.Sprintf("%s ", keyword), + Label: keyword, + Kind: protocol.KeywordCompletion, + TextEdit: &protocol.Or_CompletionItem_textEdit{ + Value: protocol.InsertReplaceEdit{ + NewText: keyword, + Insert: replaceRange, + Replace: replaceRange, + }, + }, }) } return items @@ -990,18 +1036,56 @@ func isProto2(f *ast.FileNode) bool { return f.Syntax.Syntax.AsString() == "proto2" } -func messageKeywordCompletions(searchTarget parser.Result, partialName string) []protocol.CompletionItem { +func messageKeywordCompletions(fileNode *ast.FileNode, partialName, partialNameSuffix string, pos protocol.Position) []protocol.CompletionItem { // add keyword completions for messages possibleKeywords := []string{"option", "optional", "repeated", "enum", "message", "reserved"} - if isProto2(searchTarget.AST()) { + if isProto2(fileNode) { possibleKeywords = append(possibleKeywords, "required", "extend", "group") } - if len(partialName) > 0 { - possibleKeywords = slices.DeleteFunc(possibleKeywords, func(s string) bool { - return !strings.HasPrefix(s, partialName) - }) + return completeKeywords(possibleKeywords, partialName, partialNameSuffix, pos) +} + +func fileKeywordCompletions(fileNode *ast.FileNode, partialName, partialNameSuffix string, pos protocol.Position) []protocol.CompletionItem { + possibleKeywords := make([]string, 0, 8) + // TODO(editions): add edition keyword + var completions []protocol.CompletionItem + if fileNode.Syntax == nil { + if strings.HasPrefix("syntax", partialName) { + completions = append(completions, syntaxSnippets()...) 
+ } + possibleKeywords = append(possibleKeywords, "syntax") + } + hasPkgNode := false + for _, pkg := range fileNode.Decls { + if _, ok := pkg.(*ast.PackageNode); ok { + hasPkgNode = true + } + } + if !hasPkgNode { + possibleKeywords = append(possibleKeywords, "package") + } + possibleKeywords = append(possibleKeywords, "import", "option", "message", "enum", "service", "extend") + + return append(completions, completeKeywords(possibleKeywords, partialName, partialNameSuffix, pos)...) +} + +func syntaxSnippets() []protocol.CompletionItem { + return []protocol.CompletionItem{ + { + Label: "syntax: proto3", + Kind: protocol.SnippetCompletion, + InsertTextFormat: &snippetMode, + InsertText: "syntax = \"proto3\";\n", + Preselect: true, + }, + { + Label: "syntax: proto2", + Kind: protocol.SnippetCompletion, + InsertTextFormat: &snippetMode, + InsertText: "syntax = \"proto2\";\n", + }, + // TODO(editions): add edition keyword } - return completeKeywords(possibleKeywords...) } // sort by distance from local package diff --git a/pkg/lsp/search.go b/pkg/lsp/search.go index e4c3b17..fc2359f 100644 --- a/pkg/lsp/search.go +++ b/pkg/lsp/search.go @@ -588,37 +588,51 @@ func findNarrowestEnclosingScope(parseRes parser.Result, tokenAtOffset ast.Token info := fileNode.NodeInfo(node) return protocol.Intersect(toRange(info), protocol.Range{Start: location, End: location}) } + intersectsLocationExclusive := func(node, end ast.Node) bool { + if end == nil { + return intersectsLocation(node) + } + if rn, ok := end.(*ast.RuneNode); ok && rn.Virtual { + return intersectsLocation(node) + } + nodeInfo := fileNode.NodeInfo(node) + endSourcePos := fileNode.NodeInfo(end).End() + if protocol.Intersect(positionsToRange(nodeInfo.Start(), endSourcePos), protocol.Range{Start: location, End: location}) { + return int(location.Line) < endSourcePos.Line-1 || int(location.Character) < endSourcePos.Col-1 + } + return false + } opts := tracker.AsWalkOptions() if tokenAtOffset != ast.TokenError { opts = append(opts, ast.WithIntersection(tokenAtOffset)) } ast.Walk(parseRes.AST(), &ast.SimpleVisitor{ DoVisitImportNode: func(node *ast.ImportNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Semicolon) { paths = append(paths, slices.Clone(tracker.Path())) } return nil }, DoVisitSyntaxNode: func(node *ast.SyntaxNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Semicolon) { paths = append(paths, slices.Clone(tracker.Path())) } return nil }, DoVisitMessageNode: func(node *ast.MessageNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.CloseBrace) { paths = append(paths, slices.Clone(tracker.Path())) } return nil }, DoVisitOptionNode: func(node *ast.OptionNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Semicolon) { paths = append(paths, slices.Clone(tracker.Path())) } return nil }, DoVisitMessageLiteralNode: func(node *ast.MessageLiteralNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Close) { paths = append(paths, slices.Clone(tracker.Path())) } return nil @@ -641,13 +655,13 @@ func findNarrowestEnclosingScope(parseRes parser.Result, tokenAtOffset ast.Token return nil }, DoVisitCompactOptionsNode: func(node *ast.CompactOptionsNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.CloseBracket) { paths = append(paths, slices.Clone(tracker.Path())) } return nil }, DoVisitFieldNode: func(node *ast.FieldNode) 
error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Semicolon) { paths = append(paths, slices.Clone(tracker.Path())) } return nil @@ -659,7 +673,13 @@ func findNarrowestEnclosingScope(parseRes parser.Result, tokenAtOffset ast.Token return nil }, DoVisitPackageNode: func(node *ast.PackageNode) error { - if intersectsLocation(node) { + if intersectsLocationExclusive(node, node.Semicolon) { + paths = append(paths, slices.Clone(tracker.Path())) + } + return nil + }, + DoVisitErrorNode: func(en *ast.ErrorNode) error { + if intersectsLocation(en) { paths = append(paths, slices.Clone(tracker.Path())) } return nil diff --git a/pkg/lsp/server.go b/pkg/lsp/server.go index d403e1e..e542543 100644 --- a/pkg/lsp/server.go +++ b/pkg/lsp/server.go @@ -233,7 +233,7 @@ func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocume if !uri.IsFile() { return nil } - return c.DidModifyFiles(ctx, []file.Modification{ + c.DidModifyFiles(ctx, []file.Modification{ { URI: uri, Action: file.Open, @@ -242,6 +242,7 @@ func (s *Server) DidOpen(ctx context.Context, params *protocol.DidOpenTextDocume LanguageID: params.TextDocument.LanguageID, }, }) + return nil } // DidClose implements protocol.Server. @@ -255,7 +256,7 @@ func (s *Server) DidClose(ctx context.Context, params *protocol.DidCloseTextDocu if !uri.IsFile() { return nil } - return c.DidModifyFiles(ctx, []file.Modification{ + c.DidModifyFiles(ctx, []file.Modification{ { URI: uri, Action: file.Close, @@ -263,6 +264,7 @@ func (s *Server) DidClose(ctx context.Context, params *protocol.DidCloseTextDocu Text: nil, }, }) + return nil } // DidChange implements protocol.Server. @@ -276,21 +278,22 @@ func (s *Server) DidChange(ctx context.Context, params *protocol.DidChangeTextDo if !uri.IsFile() { return nil } - text, err := c.ChangedText(ctx, uri, params.ContentChanges) + text, err := c.ChangedText(ctx, params.TextDocument, params.ContentChanges) if err != nil { return err } - return c.DidModifyFiles(ctx, []file.Modification{{ + c.DidModifyFiles(ctx, []file.Modification{{ URI: uri, Action: file.Change, Version: params.TextDocument.Version, Text: text, }}) + return nil } // DidChangeWatchedFiles implements protocol.Server. func (s *Server) DidChangeWatchedFiles(ctx context.Context, params *protocol.DidChangeWatchedFilesParams) error { - mods := map[*Cache][]file.Modification{} + modsByCache := map[*Cache][]file.Modification{} for _, change := range params.Changes { uri := change.URI if !uri.IsFile() { @@ -300,16 +303,15 @@ func (s *Server) DidChangeWatchedFiles(ctx context.Context, params *protocol.Did if err != nil { continue } - mods[cache] = append(mods[cache], file.Modification{ - URI: uri, - Action: changeTypeToFileAction(change.Type), - OnDisk: true, + modsByCache[cache] = append(modsByCache[cache], file.Modification{ + URI: uri, + Action: changeTypeToFileAction(change.Type), + Version: -1, + OnDisk: true, }) } - for c, mods := range mods { - if err := c.DidModifyFiles(ctx, mods); err != nil { - slog.Error("failed to update files", "error", err) - } + for c, mods := range modsByCache { + c.DidModifyFiles(ctx, mods) } return nil } @@ -416,7 +418,8 @@ func (s *Server) DidSave(ctx context.Context, params *protocol.DidSaveTextDocume if params.Text != nil { mod.Text = []byte(*params.Text) } - return c.DidModifyFiles(ctx, []file.Modification{mod}) + c.DidModifyFiles(ctx, []file.Modification{mod}) + return nil } // SemanticTokensFull implements protocol.Server. 
@@ -656,7 +659,7 @@ func (s *Server) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCom } return c.GetSyntheticFileContents(ctx, protocol.DocumentURI(req.URI)) case "protols/ast": - var req SyntheticFileContentsRequest + var req DocumentASTRequest if err := json.Unmarshal(params.Arguments[0], &req); err != nil { return nil, err } @@ -664,6 +667,9 @@ func (s *Server) ExecuteCommand(ctx context.Context, params *protocol.ExecuteCom if err != nil { return nil, err } + if err := c.WaitDocumentVersion(ctx, protocol.DocumentURI(req.URI), req.Version); err != nil { + return nil, err + } parseRes, err := c.FindParseResultByURI(protocol.DocumentURI(req.URI)) if err != nil { return nil, err diff --git a/pkg/lsp/versions.go b/pkg/lsp/versions.go new file mode 100644 index 0000000..e42ffa7 --- /dev/null +++ b/pkg/lsp/versions.go @@ -0,0 +1,60 @@ +package lsp + +import ( + "context" + "sync" + + "github.com/kralicky/tools-lite/gopls/pkg/file" + "github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol" +) + +type documentVersionQueue struct { + mu sync.Mutex + versions map[protocol.DocumentURI]int32 + queue map[protocol.DocumentURI]map[int32]chan struct{} +} + +func newDocumentVersionQueue() *documentVersionQueue { + return &documentVersionQueue{ + versions: make(map[protocol.DocumentURI]int32), + queue: make(map[protocol.DocumentURI]map[int32]chan struct{}), + } +} + +func (t *documentVersionQueue) Wait(ctx context.Context, uri protocol.DocumentURI, version int32) error { + t.mu.Lock() + if currentVersion, ok := t.versions[uri]; ok && currentVersion >= version { + t.mu.Unlock() + return nil + } + if _, ok := t.queue[uri]; !ok { + t.queue[uri] = make(map[int32]chan struct{}) + } + qc, ok := t.queue[uri][version] + if !ok { + qc = make(chan struct{}) + t.queue[uri][version] = qc + } + t.mu.Unlock() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-qc: + return nil + } +} + +func (t *documentVersionQueue) Update(modifications ...file.Modification) { + t.mu.Lock() + defer t.mu.Unlock() + for _, mod := range modifications { + t.versions[mod.URI] = int32(mod.Version) + for v, qc := range t.queue[mod.URI] { + if mod.Version == -1 || v <= int32(mod.Version) { + delete(t.queue[mod.URI], v) + close(qc) + } + } + } +} diff --git a/pkg/protols/commands/serve.go b/pkg/protols/commands/serve.go index 9a8f051..c889687 100644 --- a/pkg/protols/commands/serve.go +++ b/pkg/protols/commands/serve.go @@ -89,7 +89,8 @@ func BuildServeCmd() *cobra.Command { // methods that are intended to be long-lived, and should not hold up the queue var streamingRequestMethods = map[string]bool{ - "workspace/diagnostic": true, + "workspace/diagnostic": true, + "workspace/executeCommand": true, } func AsyncHandler(handler jsonrpc2.Handler) jsonrpc2.Handler {