diff --git a/go.mod b/go.mod
index 35975ba..4898932 100644
--- a/go.mod
+++ b/go.mod
@@ -7,10 +7,11 @@ require (
 	github.com/AlecAivazis/survey/v2 v2.3.7
 	github.com/bufbuild/protovalidate-go v0.5.2
 	github.com/google/cel-go v0.20.0
+	github.com/google/go-cmp v0.6.0
 	github.com/google/uuid v1.6.0
 	github.com/kralicky/gpkg v0.0.0-20240119195700-64f32830b14f
-	github.com/kralicky/protocompile v0.0.0-20240221032829-e40f4c19d142
-	github.com/kralicky/tools-lite v0.0.0-20240209234032-93b7eedbea2e
+	github.com/kralicky/protocompile v0.0.0-20240221212304-b5a12d32e33d
+	github.com/kralicky/tools-lite v0.0.0-20240221184119-4cba2183fdda
 	github.com/mattn/go-tty v0.0.5
 	github.com/spf13/cobra v1.8.0
 	golang.org/x/mod v0.15.0
diff --git a/go.sum b/go.sum
index 166fe92..dd3903d 100644
--- a/go.sum
+++ b/go.sum
@@ -34,10 +34,10 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU
 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
 github.com/kralicky/gpkg v0.0.0-20240119195700-64f32830b14f h1:MsNe8A51V+7Fu5OMXSl8SK02erPJ40vFs2zDHn89w1g=
 github.com/kralicky/gpkg v0.0.0-20240119195700-64f32830b14f/go.mod h1:vOkwMjs49XmP/7Xfo9ZL6eg2ei51lmtD/4U/Az5GTq8=
-github.com/kralicky/protocompile v0.0.0-20240221032829-e40f4c19d142 h1:l30nktO2UguhdYw6ylVePYVHLFYnzgSBmoJYDywTosk=
-github.com/kralicky/protocompile v0.0.0-20240221032829-e40f4c19d142/go.mod h1:eIoBteRQ90jYYcBBAL8RNOaVSMmyWFDqAH4t3i1elUc=
-github.com/kralicky/tools-lite v0.0.0-20240209234032-93b7eedbea2e h1:Mic4oZbKrGxJ6l3FmVLax0+RNOxE/J6KnnlN4BH2d58=
-github.com/kralicky/tools-lite v0.0.0-20240209234032-93b7eedbea2e/go.mod h1:NKsdxFI6awifvNvxDwtCU1YCaKRoSSPpbHXkKOMuq24=
+github.com/kralicky/protocompile v0.0.0-20240221212304-b5a12d32e33d h1:gV7uD5ltHzixn1AzbfEs2rQqApmSz6o5yq/dmVZNrGM=
+github.com/kralicky/protocompile v0.0.0-20240221212304-b5a12d32e33d/go.mod h1:eIoBteRQ90jYYcBBAL8RNOaVSMmyWFDqAH4t3i1elUc=
+github.com/kralicky/tools-lite v0.0.0-20240221184119-4cba2183fdda h1:5zLw2UdV/QT50HmgsCBaJIVhrD2D+3AFr+EstHRtZFI=
+github.com/kralicky/tools-lite v0.0.0-20240221184119-4cba2183fdda/go.mod h1:bDe7Unrh7kKhnkf5+csdbu0hSd8RRo1EnnnViKGUewo=
 github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
 github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
 github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
diff --git a/pkg/lsp/cache.go b/pkg/lsp/cache.go
index 2456e5b..2713d6e 100644
--- a/pkg/lsp/cache.go
+++ b/pkg/lsp/cache.go
@@ -3,7 +3,6 @@ package lsp
 import (
 	"context"
 	"fmt"
-	"log/slog"
 	"runtime"
 	"strings"
 	"sync"
@@ -97,9 +96,6 @@ func NewCache(workspace protocol.WorkspaceFolder, opts ...CacheOption) *Cache {
 }
 
 func (c *Cache) LoadFiles(files []string) {
-	slog.Debug("initializing")
-	defer slog.Debug("done initializing")
-
 	created := make([]file.Modification, len(files))
 	for i, f := range files {
 		created[i] = file.Modification{
diff --git a/pkg/lsp/semantic.go b/pkg/lsp/semantic.go
index 2fc968c..1e69f22 100644
--- a/pkg/lsp/semantic.go
+++ b/pkg/lsp/semantic.go
@@ -187,7 +187,7 @@ func semanticTokensRange(cache *Cache, doc protocol.TextDocumentIdentifier, rng
 	return ret, err
 }
 
-const debugCheckOverlappingTokens = false
+var DebugCheckOverlappingTokens = false
 
 func computeSemanticTokens(cache *Cache, e *semanticItems, walkOptions ...ast.WalkOption) {
 	e.inspect(cache, e.AST(), walkOptions...)
@@ -197,7 +197,7 @@ func computeSemanticTokens(cache *Cache, e *semanticItems, walkOptions ...ast.Wa
 		}
 		return e.items[i].start < e.items[j].start
 	})
-	if !debugCheckOverlappingTokens {
+	if !DebugCheckOverlappingTokens {
 		return
 	}
 
@@ -469,6 +469,9 @@ func (s *semanticItems) inspect(cache *Cache, node ast.Node, walkOptions ...ast.
 			s.mktokens(node, tracker.Path(), semanticTypeKeyword, 0)
 		}
 	case *ast.RuneNode:
+		if node.Virtual || node.Rune == 0 {
+			return true
+		}
 		switch node.Rune {
 		case '}', '{', '.', ',', '<', '>', '(', ')', '[', ']', ';', ':':
 			s.mkcomments(node)
diff --git a/pkg/lsp/server.go b/pkg/lsp/server.go
index 4ae43fe..6b2bfd1 100644
--- a/pkg/lsp/server.go
+++ b/pkg/lsp/server.go
@@ -26,15 +26,14 @@ type Server struct {
 
 	client protocol.Client
 
-	trackerMu sync.Mutex
-	tracker   *progress.Tracker
-
-	diagnosticStreamMu     sync.RWMutex
-	diagnosticStreamCancel func()
+	trackerMu    sync.Mutex
+	tracker      *progress.Tracker
+	shutdownOnce sync.Once
 }
 
 type ServerOptions struct {
 	unknownCommandHandlers map[string]UnknownCommandHandler
+	shutdownHooks          []func(context.Context)
 }
 
 type ServerOption func(*ServerOptions)
@@ -56,6 +55,12 @@ func WithUnknownCommandHandler(handler UnknownCommandHandler, cmds ...string) Se
 	}
 }
 
+func WithShutdownHook(hook func(context.Context)) ServerOption {
+	return func(o *ServerOptions) {
+		o.shutdownHooks = append(o.shutdownHooks, hook)
+	}
+}
+
 func NewServer(client protocol.Client, opts ...ServerOption) *Server {
 	var options ServerOptions
 	options.apply(opts...)
@@ -136,6 +141,10 @@ func (s *Server) Initialize(ctx context.Context, params *protocol.ParamInitializ
 		},
 	}
 	slog.Debug("Initialize", "folders", folders)
+	defer s.client.LogMessage(ctx, &protocol.LogMessageParams{
+		Type:    protocol.Info,
+		Message: fmt.Sprintf("initialized workspace folders: %v", folders),
+	})
 	return &protocol.InitializeResult{
 		Capabilities: protocol.ServerCapabilities{
 			TextDocumentSync: protocol.TextDocumentSyncOptions{
@@ -736,10 +745,30 @@ func (s *Server) References(ctx context.Context, params *protocol.ReferenceParam
 }
 
 // Shutdown implements protocol.Server.
-func (*Server) Shutdown(context.Context) error {
+func (s *Server) Shutdown(ctx context.Context) error {
+	s.shutdownOnce.Do(func() { s.shutdown(ctx) })
 	return nil
 }
 
+// Exit implements protocol.Server.
+func (s *Server) Exit(ctx context.Context) error {
+	s.shutdownOnce.Do(func() { s.shutdown(ctx) })
+	return nil
+}
+
+func (s *Server) shutdown(ctx context.Context) {
+	slog.Info("server is shutting down")
+	for path := range s.caches {
+		s.cacheDestroyLocked(path, fmt.Errorf("server is shutting down"))
+	}
+	clear(s.caches)
+	for _, hook := range s.shutdownHooks {
+		// these must be run in a separate goroutine, since closing a jsonrpc conn
+		// from within the hook can trigger a deadlock.
+		go hook(ctx)
+	}
+}
+
 // DidChangeWorkspaceFolders implements protocol.Server.
 func (s *Server) DidChangeWorkspaceFolders(ctx context.Context, params *protocol.DidChangeWorkspaceFoldersParams) error {
 	added := params.Event.Added
@@ -1022,11 +1051,6 @@ func (*Server) DidSaveNotebookDocument(context.Context, *protocol.DidSaveNoteboo
 	return notImplemented("DidSaveNotebookDocument")
 }
 
-// Exit implements protocol.Server.
-func (*Server) Exit(context.Context) error {
-	return notImplemented("Exit")
-}
-
 // FoldingRange implements protocol.Server.
 func (*Server) FoldingRange(context.Context, *protocol.FoldingRangeParams) ([]protocol.FoldingRange, error) {
 	return nil, notImplemented("FoldingRange")
diff --git a/pkg/lsprpc/handler.go b/pkg/lsprpc/handler.go
new file mode 100644
index 0000000..d01bdbc
--- /dev/null
+++ b/pkg/lsprpc/handler.go
@@ -0,0 +1,185 @@
+package lsprpc
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"os"
+	"path"
+	"strings"
+
+	"github.com/kralicky/protocompile/linker"
+	"github.com/kralicky/protols/pkg/lsp"
+	"github.com/kralicky/protols/pkg/util"
+	"github.com/kralicky/protols/sdk/codegen"
+	"github.com/kralicky/protols/sdk/plugin"
+	"github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol"
+	"github.com/kralicky/tools-lite/pkg/event"
+	"github.com/kralicky/tools-lite/pkg/jsonrpc2"
+	"google.golang.org/protobuf/types/descriptorpb"
+)
+
+func NewStreamServer() jsonrpc2.StreamServer {
+	return &streamServer{}
+}
+
+type streamServer struct{}
+
+func (s *streamServer) ServeStream(ctx context.Context, conn jsonrpc2.Conn) error {
+	client := protocol.ClientDispatcher(conn)
+
+	server := lsp.NewServer(client,
+		lsp.WithUnknownCommandHandler(
+			&unknownHandler{Generators: codegen.DefaultGenerators()},
+			"protols/generate",
+			"protols/generateWorkspace",
+		),
+		lsp.WithShutdownHook(func(ctx context.Context) {
+			conn.Close()
+		}),
+	)
+	handler := protocol.CancelHandler(
+		AsyncHandler(
+			jsonrpc2.MustReplyHandler(
+				protocol.ServerHandler(server, jsonrpc2.MethodNotFound))))
+	conn.Go(ctx, handler)
+	<-conn.Done()
+	if err := conn.Err(); err != nil {
+		return fmt.Errorf("server exited with error: %w", err)
+	}
+	return nil
+}
+
+// methods that are intended to be long-lived, and should not hold up the queue
+var streamingRequestMethods = map[string]bool{
+	"workspace/diagnostic":     true,
+	"workspace/executeCommand": true,
+}
+
+func AsyncHandler(handler jsonrpc2.Handler) jsonrpc2.Handler {
+	nextRequest := make(chan struct{})
+	close(nextRequest)
+	return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error {
+		waitForPrevious := nextRequest
+		nextRequest = make(chan struct{})
+		unlockNext := nextRequest
+		if streamingRequestMethods[req.Method()] {
+			close(unlockNext)
+		} else {
+			innerReply := reply
+			reply = func(ctx context.Context, result interface{}, err error) error {
+				close(unlockNext)
+				return innerReply(ctx, result, err)
+			}
+		}
+		_, queueDone := event.Start(ctx, "queued")
+		go func() {
+			<-waitForPrevious
+			queueDone()
+			if err := handler(ctx, reply, req); err != nil {
+				event.Error(ctx, "jsonrpc2 async message delivery failed", err)
+			}
+		}()
+		return nil
+	}
+}
+
+type unknownHandler struct {
+	Generators []codegen.Generator
+}
+
+// Execute implements lsp.UnknownCommandHandler.
+func (h *unknownHandler) Execute(ctx context.Context, uc lsp.UnknownCommand) (any, error) {
+	switch uc.Command {
+	case "protols/generate":
+		var req lsp.GenerateCodeRequest
+		if err := json.Unmarshal(uc.Arguments[0], &req); err != nil {
+			return nil, err
+		}
+		if uc.Cache == nil {
+			return nil, errors.New("no cache available")
+		}
+		return nil, h.doGenerate(ctx, uc.Cache, req.URIs)
+	case "protols/generateWorkspace":
+		if uc.Cache == nil {
+			return nil, errors.New("no cache available")
+		}
+		return nil, h.doGenerate(ctx, uc.Cache, uc.Cache.XListWorkspaceLocalURIs())
+	default:
+		panic("unknown command: " + uc.Command)
+	}
+}
+
+var _ lsp.UnknownCommandHandler = (*unknownHandler)(nil)
+
+func (h *unknownHandler) doGenerate(ctx context.Context, cache *lsp.Cache, uris []protocol.DocumentURI) error {
+	pathMappings := cache.XGetURIPathMappings()
+	roots := make(linker.Files, 0, len(uris))
+	outputDirs := map[string]string{}
+	for _, uri := range uris {
+		res, err := cache.FindResultByURI(uri)
+		if err != nil {
+			return err
+		}
+		if res.Package() == "" {
+			continue
+		}
+		if _, ok := res.AST().Pragma(lsp.PragmaNoGenerate); ok {
+			continue
+		}
+		roots = append(roots, res)
+		p := pathMappings.FilePathsByURI[uri]
+		outputDirs[path.Dir(p)] = path.Dir(uri.Path())
+		if opts := res.Options(); opts.ProtoReflect().IsValid() {
+			if goPkg := opts.(*descriptorpb.FileOptions).GoPackage; goPkg != nil {
+				// if the file has a different go_package than the implicit one, add
+				// it to the output dirs map as well
+				outputDirs[strings.Split(*goPkg, ";")[0]] = path.Dir(uri.Path())
+			}
+		}
+	}
+	closure := linker.ComputeReflexiveTransitiveClosure(roots)
+	closureResults := make([]linker.Result, len(closure))
+	for i, res := range closure {
+		closureResults[i] = res.(linker.Result)
+	}
+
+	plugin, err := plugin.New(roots, closureResults, pathMappings)
+	if err != nil {
+		return err
+	}
+	for _, g := range h.Generators {
+		if err := g.Generate(plugin); err != nil {
+			return err
+		}
+	}
+	response := plugin.Response()
+	if response.Error != nil {
+		return errors.New(response.GetError())
+	}
+	var errs error
+	for _, rf := range response.GetFile() {
+		dir, ok := outputDirs[path.Dir(rf.GetName())]
+		if !ok {
+			errs = errors.Join(errs, fmt.Errorf("cannot write outside of workspace module: %s", rf.GetName()))
+			continue
+		}
+		absPath := path.Join(dir, path.Base(rf.GetName()))
+		if info, err := os.Stat(absPath); err == nil {
+			original, err := os.ReadFile(absPath)
+			if err != nil {
+				return err
+			}
+			updated := rf.GetContent()
+			if err := util.OverwriteFile(absPath, original, []byte(updated), info.Mode().Perm(), info.Size()); err != nil {
+				return err
+			}
+		} else {
+			if err := os.WriteFile(absPath, []byte(rf.GetContent()), 0o644); err != nil {
+				return err
+			}
+		}
+	}
+	return errs
+}
diff --git a/pkg/protols/commands/serve.go b/pkg/protols/commands/serve.go
index feccb10..be5474f 100644
--- a/pkg/protols/commands/serve.go
+++ b/pkg/protols/commands/serve.go
@@ -3,30 +3,17 @@ package commands
 
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
-	"fmt"
 	"log/slog"
 	"net"
-	"os"
-	"path"
-	"strings"
 	"sync"
 
-	"github.com/kralicky/protocompile/linker"
-	"github.com/kralicky/protols/pkg/lsp"
-	"github.com/kralicky/protols/pkg/util"
-	"github.com/kralicky/protols/sdk/codegen"
-	"github.com/kralicky/protols/sdk/plugin"
-	"github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol"
-	"google.golang.org/protobuf/types/descriptorpb"
-
+	"github.com/kralicky/protols/pkg/lsprpc"
 	"github.com/kralicky/tools-lite/pkg/event"
 	"github.com/kralicky/tools-lite/pkg/event/core"
 	"github.com/kralicky/tools-lite/pkg/event/keys"
 	"github.com/kralicky/tools-lite/pkg/event/label"
 	"github.com/kralicky/tools-lite/pkg/jsonrpc2"
-
 	"github.com/spf13/cobra"
 )
@@ -37,19 +24,10 @@ func BuildServeCmd() *cobra.Command {
 		Use:   "serve",
 		Short: "Start the language server",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			cc, err := net.Dial("unix", pipe)
-			if err != nil {
-				return err
-			}
-			stream := jsonrpc2.NewHeaderStream(cc)
-			conn := jsonrpc2.NewConn(stream)
-			client := protocol.ClientDispatcher(conn)
-
 			slog.SetDefault(slog.New(slog.NewTextHandler(cmd.OutOrStderr(), &slog.HandlerOptions{
 				AddSource: true,
 				Level:     slog.LevelDebug,
 			})))
-
 			var eventMu sync.Mutex
 			event.SetExporter(func(ctx context.Context, e core.Event, lm label.Map) context.Context {
 				eventMu.Lock()
@@ -74,23 +52,14 @@ func BuildServeCmd() *cobra.Command {
 				return ctx
 			})
 
-			server := lsp.NewServer(client,
-				lsp.WithUnknownCommandHandler(
-					&unknownHandler{Generators: codegen.DefaultGenerators()},
-					"protols/generate",
-					"protols/generateWorkspace",
-				),
-			)
-			conn.Go(cmd.Context(), protocol.CancelHandler(
-				AsyncHandler(
-					jsonrpc2.MustReplyHandler(
-						protocol.ServerHandler(server, jsonrpc2.MethodNotFound)))))
-
-			<-conn.Done()
-			if err := conn.Err(); err != nil {
-				return fmt.Errorf("server exited with error: %w", err)
+			cc, err := net.Dial("unix", pipe)
+			if err != nil {
+				return err
 			}
-			return nil
+			stream := jsonrpc2.NewHeaderStream(cc)
+			conn := jsonrpc2.NewConn(stream)
+			ss := lsprpc.NewStreamServer()
+			return ss.ServeStream(cmd.Context(), conn)
 		},
 	}
 
@@ -99,136 +68,3 @@ func BuildServeCmd() *cobra.Command {
 
 	return cmd
 }
-
-// methods that are intended to be long-lived, and should not hold up the queue
-var streamingRequestMethods = map[string]bool{
-	"workspace/diagnostic":     true,
-	"workspace/executeCommand": true,
-}
-
-func AsyncHandler(handler jsonrpc2.Handler) jsonrpc2.Handler {
-	nextRequest := make(chan struct{})
-	close(nextRequest)
-	return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error {
-		waitForPrevious := nextRequest
-		nextRequest = make(chan struct{})
-		unlockNext := nextRequest
-		if streamingRequestMethods[req.Method()] {
-			close(unlockNext)
-		} else {
-			innerReply := reply
-			reply = func(ctx context.Context, result interface{}, err error) error {
-				close(unlockNext)
-				return innerReply(ctx, result, err)
-			}
-		}
-		_, queueDone := event.Start(ctx, "queued")
-		go func() {
-			<-waitForPrevious
-			queueDone()
-			if err := handler(ctx, reply, req); err != nil {
-				event.Error(ctx, "jsonrpc2 async message delivery failed", err)
-			}
-		}()
-		return nil
-	}
-}
-
-type unknownHandler struct {
-	Generators []codegen.Generator
-}
-
-// Execute implements lsp.UnknownCommandHandler.
-func (h *unknownHandler) Execute(ctx context.Context, uc lsp.UnknownCommand) (any, error) {
-	switch uc.Command {
-	case "protols/generate":
-		var req lsp.GenerateCodeRequest
-		if err := json.Unmarshal(uc.Arguments[0], &req); err != nil {
-			return nil, err
-		}
-		if uc.Cache == nil {
-			return nil, errors.New("no cache available")
-		}
-		return nil, h.doGenerate(ctx, uc.Cache, req.URIs)
-	case "protols/generateWorkspace":
-		if uc.Cache == nil {
-			return nil, errors.New("no cache available")
-		}
-		return nil, h.doGenerate(ctx, uc.Cache, uc.Cache.XListWorkspaceLocalURIs())
-	default:
-		panic("unknown command: " + uc.Command)
-	}
-}
-
-var _ lsp.UnknownCommandHandler = (*unknownHandler)(nil)
-
-func (h *unknownHandler) doGenerate(ctx context.Context, cache *lsp.Cache, uris []protocol.DocumentURI) error {
-	pathMappings := cache.XGetURIPathMappings()
-	roots := make(linker.Files, 0, len(uris))
-	outputDirs := map[string]string{}
-	for _, uri := range uris {
-		res, err := cache.FindResultByURI(uri)
-		if err != nil {
-			return err
-		}
-		if res.Package() == "" {
-			continue
-		}
-		if _, ok := res.AST().Pragma(lsp.PragmaNoGenerate); ok {
-			continue
-		}
-		roots = append(roots, res)
-		p := pathMappings.FilePathsByURI[uri]
-		outputDirs[path.Dir(p)] = path.Dir(uri.Path())
-		if opts := res.Options(); opts.ProtoReflect().IsValid() {
-			if goPkg := opts.(*descriptorpb.FileOptions).GoPackage; goPkg != nil {
-				// if the file has a different go_package than the implicit one, add
-				// it to the output dirs map as well
-				outputDirs[strings.Split(*goPkg, ";")[0]] = path.Dir(uri.Path())
-			}
-		}
-	}
-	closure := linker.ComputeReflexiveTransitiveClosure(roots)
-	closureResults := make([]linker.Result, len(closure))
-	for i, res := range closure {
-		closureResults[i] = res.(linker.Result)
-	}
-
-	plugin, err := plugin.New(roots, closureResults, pathMappings)
-	if err != nil {
-		return err
-	}
-	for _, g := range h.Generators {
-		if err := g.Generate(plugin); err != nil {
-			return err
-		}
-	}
-	response := plugin.Response()
-	if response.Error != nil {
-		return errors.New(response.GetError())
-	}
-	var errs error
-	for _, rf := range response.GetFile() {
-		dir, ok := outputDirs[path.Dir(rf.GetName())]
-		if !ok {
-			errs = errors.Join(errs, fmt.Errorf("cannot write outside of workspace module: %s", rf.GetName()))
-			continue
-		}
-		absPath := path.Join(dir, path.Base(rf.GetName()))
-		if info, err := os.Stat(absPath); err == nil {
-			original, err := os.ReadFile(absPath)
-			if err != nil {
-				return err
-			}
-			updated := rf.GetContent()
-			if err := util.OverwriteFile(absPath, original, []byte(updated), info.Mode().Perm(), info.Size()); err != nil {
-				return err
-			}
-		} else {
-			if err := os.WriteFile(absPath, []byte(rf.GetContent()), 0o644); err != nil {
-				return err
-			}
-		}
-	}
-	return errs
-}
diff --git a/test/lsp_test.go b/test/lsp_test.go
new file mode 100644
index 0000000..a1193cc
--- /dev/null
+++ b/test/lsp_test.go
@@ -0,0 +1,58 @@
+package test
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/kralicky/protols/pkg/lsp"
+	"github.com/kralicky/tools-lite/gopls/pkg/test/integration"
+	"github.com/kralicky/tools-lite/gopls/pkg/test/integration/fake"
+)
+
+func TestMain(m *testing.M) {
+	lsp.DebugCheckOverlappingTokens = true
+	Main(m)
+}
+
+func TestBasic(t *testing.T) {
+	const src = `
+-- go.mod --
+module example.com
+
+go 1.22
+-- main.go --
+package main
+
+func main() {}
+
+-- test.proto --
+syntax = "proto3";
+
+package test;
+
+message Test {
+  string test = 1;
+}
+
+`
+	Run(t, src, func(t *testing.T, env *integration.Env) {
+		env.OpenFile("test.proto")
+		tokens := env.SemanticTokensFull("test.proto")
+		want := []fake.SemanticToken{
+			{Token: "syntax", TokenType: "keyword"},
+			{Token: "=", TokenType: "operator"},
+			{Token: `"proto3"`, TokenType: "string"},
+			{Token: "package", TokenType: "keyword"},
+			{Token: "test", TokenType: "namespace"},
+			{Token: "message", TokenType: "keyword"},
+			{Token: "Test", TokenType: "type", Mod: "definition"},
+			{Token: "string", TokenType: "type", Mod: "defaultLibrary"},
+			{Token: "test", TokenType: "variable", Mod: "definition"},
+			{Token: "=", TokenType: "operator"},
+			{Token: "1", TokenType: "number"},
+		}
+		if x := cmp.Diff(want, tokens); x != "" {
+			t.Errorf("Semantic tokens do not match (-want +got):\n%s", x)
+		}
+	})
+}
diff --git a/test/runner.go b/test/runner.go
new file mode 100644
index 0000000..bcc4fb1
--- /dev/null
+++ b/test/runner.go
@@ -0,0 +1,231 @@
+package test
+
+import (
+	"bytes"
+	"context"
+	"flag"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/kralicky/protols/pkg/lsprpc"
+	"github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol"
+	"github.com/kralicky/tools-lite/gopls/pkg/test/integration"
+	"github.com/kralicky/tools-lite/gopls/pkg/test/integration/fake"
+	"github.com/kralicky/tools-lite/pkg/jsonrpc2"
+	"github.com/kralicky/tools-lite/pkg/jsonrpc2/servertest"
+	"github.com/kralicky/tools-lite/pkg/testenv"
+)
+
+var runner *Runner
+
+var (
+	printLogs                = flag.Bool("print-logs", false, "whether to print LSP logs")
+	printGoroutinesOnFailure = flag.Bool("print-goroutines", false, "whether to print goroutines info on failure")
+	skipCleanup              = flag.Bool("skip-cleanup", false, "whether to skip cleaning up temp directories")
+)
+
+func Main(m *testing.M) {
+	dir, err := os.MkdirTemp("", "protols-test-")
+	if err != nil {
+		panic(fmt.Errorf("creating temp directory: %v", err))
+	}
+	flag.Parse()
+
+	runner = &Runner{
+		SkipCleanup: *skipCleanup,
+		tempDir:     dir,
+	}
+	var code int
+	defer func() {
+		if err := runner.Close(); err != nil {
+			fmt.Fprintf(os.Stderr, "closing test runner: %v\n", err)
+			os.Exit(1)
+		}
+		os.Exit(code)
+	}()
+	code = m.Run()
+}
+
+func Run(t *testing.T, files string, f TestFunc) {
+	runner.Run(t, files, f)
+}
+
+type Runner struct {
+	SkipCleanup bool
+
+	tempDir string
+	tsOnce  sync.Once
+	ts      *servertest.TCPServer
+}
+
+type (
+	TestFunc  func(t *testing.T, env *integration.Env)
+	runConfig struct {
+		editor  fake.EditorConfig
+		sandbox fake.SandboxConfig
+	}
+)
+
+func defaultConfig() runConfig {
+	return runConfig{
+		editor: fake.EditorConfig{
+			ClientName: "gotest",
+			FileAssociations: map[string]string{
+				"protobuf": `.*\.proto$`,
+			},
+		},
+	}
+}
+
+// Run executes the test function in the default configured gopls execution
+// modes. For each a test run, a new workspace is created containing the
+// un-txtared files specified by filedata.
+func (r *Runner) Run(t *testing.T, files string, test TestFunc, opts ...integration.RunOption) {
+	// TODO(rfindley): this function has gotten overly complicated, and warrants
+	// refactoring.
+	t.Helper()
+
+	config := defaultConfig()
+	t.Run("in-process", func(t *testing.T) {
+		ctx := context.Background()
+		if d, ok := testenv.Deadline(t); ok {
+			timeout := time.Until(d) * 19 / 20 // Leave an arbitrary 5% for cleanup.
+			var cancel context.CancelFunc
+			ctx, cancel = context.WithTimeout(ctx, timeout)
+			defer cancel()
+		}
+		rootDir := filepath.Join(r.tempDir, filepath.FromSlash(t.Name()))
+		if err := os.MkdirAll(rootDir, 0o755); err != nil {
+			t.Fatal(err)
+		}
+		files := fake.UnpackTxt(files)
+		config.sandbox.Files = files
+		config.sandbox.RootDir = rootDir
+		sandbox, err := fake.NewSandbox(&config.sandbox)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer func() {
+			if !r.SkipCleanup {
+				if err := sandbox.Close(); err != nil {
+					t.Errorf("closing the sandbox: %v", err)
+				}
+			}
+		}()
+		ss := lsprpc.NewStreamServer()
+		r.ts = servertest.NewTCPServer(ctx, ss, nil)
+
+		framer := jsonrpc2.NewRawStream
+		ls := &loggingFramer{}
+		framer = ls.framer(jsonrpc2.NewRawStream)
+		ts := servertest.NewPipeServer(ss, framer)
+		awaiter := integration.NewAwaiter(sandbox.Workdir)
+		const skipApplyEdits = false
+		editor, err := fake.NewEditor(sandbox, config.editor).Connect(ctx, ts, awaiter.Hooks(), skipApplyEdits)
+		if err != nil {
+			t.Fatal(err)
+		}
+		env := &integration.Env{
+			T:       t,
+			Ctx:     ctx,
+			Sandbox: sandbox,
+			Editor:  editor,
+			Server:  ts,
+			Awaiter: awaiter,
+		}
+		defer func() {
+			if t.Failed() {
+				ls.printBuffers(t.Name(), os.Stderr)
+			}
+			// For tests that failed due to a timeout, don't fail to shutdown
+			// because ctx is done.
+			//
+			// There is little point to setting an arbitrary timeout for closing
+			// the editor: in general we want to clean up before proceeding to the
+			// next test, and if there is a deadlock preventing closing it will
+			// eventually be handled by the `go test` timeout.
+			if err := editor.Close(context.WithoutCancel(ctx)); err != nil {
+				t.Errorf("error closing editor: %v", err)
+			}
+		}()
+		// Always await the initial workspace load.
+		env.Await(integration.AllOf(
+			integration.LogMatching(protocol.Info, "initialized workspace folders", 1, true),
+		))
+		test(t, env)
+	})
+}
+
+// Close cleans up resource that have been allocated to this workspace.
+func (r *Runner) Close() error {
+	var errmsgs []string
+	if r.ts != nil {
+		if err := r.ts.Close(); err != nil {
+			errmsgs = append(errmsgs, err.Error())
+		}
+	}
+	if !r.SkipCleanup {
+		if err := os.RemoveAll(r.tempDir); err != nil {
+			errmsgs = append(errmsgs, err.Error())
+		}
+	}
+	if len(errmsgs) > 0 {
+		return fmt.Errorf("errors closing the test runner:\n\t%s", strings.Join(errmsgs, "\n\t"))
+	}
+	return nil
+}
+
+type loggingFramer struct {
+	mu  sync.Mutex
+	buf *safeBuffer
+}
+
+// safeBuffer is a threadsafe buffer for logs.
+type safeBuffer struct {
+	mu  sync.Mutex
+	buf bytes.Buffer
+}
+
+func (b *safeBuffer) Write(p []byte) (int, error) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	return b.buf.Write(p)
+}
+
+func (s *loggingFramer) framer(f jsonrpc2.Framer) jsonrpc2.Framer {
+	return func(nc net.Conn) jsonrpc2.Stream {
+		s.mu.Lock()
+		framed := false
+		if s.buf == nil {
+			s.buf = &safeBuffer{buf: bytes.Buffer{}}
+			framed = true
+		}
+		s.mu.Unlock()
+		stream := f(nc)
+		if framed {
+			return protocol.LoggingStream(stream, s.buf)
+		}
+		return stream
+	}
+}
+
+func (s *loggingFramer) printBuffers(testname string, w io.Writer) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.buf == nil {
+		return
+	}
+	fmt.Fprintf(os.Stderr, "#### Start Gopls Test Logs for %q\n", testname)
+	s.buf.mu.Lock()
+	io.Copy(w, &s.buf.buf)
+	s.buf.mu.Unlock()
+	fmt.Fprintf(os.Stderr, "#### End Gopls Test Logs for %q\n", testname)
+}
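Two illustrative sketches follow; neither is part of the change itself. The first shows how the exported lsprpc.NewStreamServer could be hosted over an in-memory pipe, which is the same plumbing test/runner.go uses minus the fake editor. PipeServer.Connect, ServerDispatcher, and the omitted Initialize handshake are assumptions carried over from the upstream gopls test utilities that tools-lite vendors.

// Sketch only, assuming the tools-lite servertest/protocol APIs behave like
// their upstream gopls counterparts. The LSP initialize handshake is omitted.
package example

import (
	"context"

	"github.com/kralicky/protols/pkg/lsprpc"
	"github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol"
	"github.com/kralicky/tools-lite/pkg/jsonrpc2"
	"github.com/kralicky/tools-lite/pkg/jsonrpc2/servertest"
)

func dialInMemory(ctx context.Context) (protocol.Server, func() error) {
	// Bind the stream server to an in-memory pipe instead of the unix socket
	// used by `protols serve`.
	ss := lsprpc.NewStreamServer()
	ts := servertest.NewPipeServer(ss, jsonrpc2.NewRawStream)
	conn := ts.Connect(ctx)
	// Start the client-side read loop; server-to-client requests are rejected.
	conn.Go(ctx, jsonrpc2.MethodNotFound)
	server := protocol.ServerDispatcher(conn)
	cleanup := func() error {
		// Shutdown tears down the caches and runs the shutdown hook registered
		// in ServeStream, which closes the server side of the connection.
		return server.Shutdown(ctx)
	}
	return server, cleanup
}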
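The second sketch drives the newly registered protols/generate command through a standard workspace/executeCommand request. The ExecuteCommandParams layout and the JSON encoding of lsp.GenerateCodeRequest are assumptions; only the command names and the URIs field are taken from the handler above.

// Sketch only: the handler in pkg/lsprpc/handler.go unmarshals Arguments[0]
// back into an lsp.GenerateCodeRequest, so a single raw-JSON argument is sent.
package example

import (
	"context"
	"encoding/json"

	"github.com/kralicky/protols/pkg/lsp"
	"github.com/kralicky/tools-lite/gopls/pkg/lsp/protocol"
)

func generateFile(ctx context.Context, server protocol.Server, uri protocol.DocumentURI) error {
	args, err := json.Marshal(lsp.GenerateCodeRequest{URIs: []protocol.DocumentURI{uri}})
	if err != nil {
		return err
	}
	// "protols/generateWorkspace" takes no arguments and regenerates every
	// workspace-local file instead.
	_, err = server.ExecuteCommand(ctx, &protocol.ExecuteCommandParams{
		Command:   "protols/generate",
		Arguments: []json.RawMessage{args},
	})
	return err
}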