diff --git a/protols/cache.go b/protols/cache.go index d0679a8..581f9fc 100644 --- a/protols/cache.go +++ b/protols/cache.go @@ -9,19 +9,26 @@ import ( "os" "path" "path/filepath" + "runtime" + "strings" "sync" + "time" "github.com/bmatcuk/doublestar" "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/ast" "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/parser" + "github.com/bufbuild/protocompile/protoutil" "github.com/bufbuild/protocompile/reporter" + "github.com/bufbuild/protocompile/walk" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/desc/protoprint" + gsync "github.com/kralicky/gpkg/sync" "github.com/kralicky/ragu" "go.uber.org/zap" "golang.org/x/exp/maps" + "golang.org/x/tools/gopls/pkg/lsp/cache" "golang.org/x/tools/gopls/pkg/lsp/protocol" "golang.org/x/tools/gopls/pkg/lsp/source" "golang.org/x/tools/gopls/pkg/span" @@ -49,6 +56,9 @@ type Cache struct { fileURIsByPath map[string]span.URI // canonical file path (go package + file name) -> URI todoModLock sync.Mutex + + inflightTasksInvalidate gsync.Map[string, time.Time] + inflightTasksCompile gsync.Map[string, time.Time] } // FindDescriptorByName implements linker.Resolver. @@ -273,11 +283,62 @@ func (o *Overlay) Get(path string) (*protocol.Mapper, error) { return nil, os.ErrNotExist } -// NewCache creates a new cache. +var requiredGoEnvVars = []string{"GO111MODULE", "GOFLAGS", "GOINSECURE", "GOMOD", "GOMODCACHE", "GONOPROXY", "GONOSUMDB", "GOPATH", "GOPROXY", "GOROOT", "GOSUMDB", "GOWORK"} + func NewCache(workdir string, lg *zap.Logger) *Cache { + synthesizer := NewProtoSourceSynthesizer(workdir) + fmt.Println("SourceAccessor", workdir) + memoizedFs := cache.NewMemoizedFS() + + tryReadFromFs := func(importName string) (_ io.ReadCloser, _err error) { + fmt.Println("tryReadFromOverlay", importName) + defer func() { fmt.Println("tryReadFromOverlay done", importName, _err) }() + fh, err := memoizedFs.ReadFile(context.TODO(), span.URIFromPath(importName)) + if err == nil { + content, err := fh.Content() + if err != nil { + return nil, err + } + if content != nil { + return io.NopCloser(bytes.NewReader(content)), nil + } + } + return nil, err + } + resolverFn := protocompile.ResolverFunc(func(importName string) (result protocompile.SearchResult, _ error) { + if strings.HasPrefix(importName, "google/") { + return protocompile.SearchResult{}, os.ErrNotExist + } + if strings.HasPrefix(importName, "gogoproto/") { + importName = "github.com/gogo/protobuf/" + importName + } + + rc, err := tryReadFromFs(importName) + if err == nil { + result.Source = rc + return + } + f, dir, err := synthesizer.ImportFromGoModule(importName) + if err == nil { + result.Source, _ = os.Open(f) + return + } + if dir != "" { + if synthesized, err := synthesizer.SynthesizeFromGoSource(importName, dir); err == nil { + result.Proto = synthesized + return + } else { + return protocompile.SearchResult{}, fmt.Errorf("failed to synthesize %s: %w", importName, err) + } + } + + return protocompile.SearchResult{}, fmt.Errorf("failed to resolve %s: %w", importName, err) + }) + + // NewCache creates a new cache. diagHandler := NewDiagnosticHandler() reporter := reporter.NewReporter(diagHandler.HandleError, diagHandler.HandleWarning) - accessor := ragu.SourceAccessor(nil) + accessor := tryReadFromFs overlay := &Overlay{ baseAccessor: accessor, sources: map[string]*protocol.Mapper{}, @@ -286,9 +347,7 @@ func NewCache(workdir string, lg *zap.Logger) *Cache { &protocompile.SourceResolver{ Accessor: overlay.Accessor, }, - &protocompile.SourceResolver{ - Accessor: accessor, - }, + resolverFn, protocompile.ResolverFunc(func(path string) (protocompile.SearchResult, error) { fd, err := desc.LoadFileDescriptor(path) if err != nil { @@ -297,10 +356,11 @@ func NewCache(workdir string, lg *zap.Logger) *Cache { return protocompile.SearchResult{Desc: fd.UnwrapFile()}, nil }), } + compiler := &Compiler{ Compiler: &protocompile.Compiler{ Resolver: resolver, - MaxParallelism: -1, + MaxParallelism: runtime.NumCPU() * 4, Reporter: reporter, SourceInfoMode: protocompile.SourceInfoExtraComments | protocompile.SourceInfoExtraOptionLocations, RetainResults: true, @@ -331,21 +391,32 @@ func NewCache(workdir string, lg *zap.Logger) *Cache { func (c *Cache) preInvalidateHook(path string, reason string) { fmt.Printf("invalidating %s (%s)\n", path, reason) - + c.inflightTasksInvalidate.Store(path, time.Now()) c.diagHandler.ClearDiagnosticsForPath(path) } func (c *Cache) postInvalidateHook(path string) { - fmt.Printf("done invalidating %s\n", path) + startTime, ok := c.inflightTasksInvalidate.LoadAndDelete(path) + if ok { + fmt.Printf("invalidated %s (took %s)\n", path, time.Since(startTime)) + } else { + fmt.Printf("invalidated %s\n", path) + } } func (c *Cache) preCompile(path string) { fmt.Printf("compiling %s\n", path) + c.inflightTasksCompile.Store(path, time.Now()) delete(c.partialResults, path) } func (c *Cache) postCompile(path string) { - fmt.Printf("done compiling %s\n", path) + startTime, ok := c.inflightTasksCompile.LoadAndDelete(path) + if ok { + fmt.Printf("compiled %s (took %s)\n", path, time.Since(startTime)) + } else { + fmt.Printf("compiled %s\n", path) + } } func (c *Cache) Reindex() { @@ -570,6 +641,8 @@ func (c *Cache) ComputeDiagnosticReports(uri span.URI) ([]*protocol.Diagnostic, Range: rng, Severity: rawReport.Severity, Message: rawReport.Error.Error(), + Tags: rawReport.Tags, + Source: "protols", }) } @@ -671,8 +744,6 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier, a := res.AST() startOff, endOff, _ := mapper.RangeOffsets(rng) - startToken := a.TokenAtOffset(startOff) - endToken := a.TokenAtOffset(endOff) optionsByNode := make(map[*ast.OptionNode][]protoreflect.ExtensionType) @@ -681,6 +752,11 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier, opt := opt if len(opt.Name.Parts) == 1 { info := a.NodeInfo(opt.Name) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } part := opt.Name.Parts[0] if wellKnownType, ok := wellKnownFileOptions[part.Value()]; ok { hints = append(hints, protocol.InlayHint{ @@ -706,12 +782,24 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier, // collect all options for _, svc := range fdp.GetService() { for _, decl := range res.ServiceNode(svc).(*ast.ServiceNode).Decls { + info := a.NodeInfo(decl) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } if opt, ok := decl.(*ast.OptionNode); ok { collectOptions[*descriptorpb.ServiceOptions](opt, svc, optionsByNode) } } for _, method := range svc.GetMethod() { for _, decl := range res.MethodNode(method).(*ast.RPCNode).Decls { + info := a.NodeInfo(decl) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } if opt, ok := decl.(*ast.OptionNode); ok { collectOptions[*descriptorpb.MethodOptions](opt, method, optionsByNode) } @@ -720,26 +808,44 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier, } for _, msg := range fdp.GetMessageType() { for _, decl := range res.MessageNode(msg).(*ast.MessageNode).Decls { + info := a.NodeInfo(decl) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } if opt, ok := decl.(*ast.OptionNode); ok { collectOptions[*descriptorpb.MessageOptions](opt, msg, optionsByNode) } - for _, field := range msg.GetField() { - fieldNode := res.FieldNode(field) - switch fieldNode := fieldNode.(type) { - case *ast.FieldNode: - for _, opt := range fieldNode.GetOptions().GetElements() { - collectOptions[*descriptorpb.FieldOptions](opt, field, optionsByNode) - } - case *ast.MapFieldNode: - for _, opt := range fieldNode.GetOptions().GetElements() { - collectOptions[*descriptorpb.FieldOptions](opt, field, optionsByNode) - } + } + for _, field := range msg.GetField() { + fieldNode := res.FieldNode(field) + info := a.NodeInfo(fieldNode) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } + switch fieldNode := fieldNode.(type) { + case *ast.FieldNode: + for _, opt := range fieldNode.GetOptions().GetElements() { + collectOptions[*descriptorpb.FieldOptions](opt, field, optionsByNode) + } + case *ast.MapFieldNode: + for _, opt := range fieldNode.GetOptions().GetElements() { + collectOptions[*descriptorpb.FieldOptions](opt, field, optionsByNode) } } } } for _, enum := range fdp.GetEnumType() { for _, decl := range res.EnumNode(enum).(*ast.EnumNode).Decls { + info := a.NodeInfo(decl) + if info.End().Offset <= startOff { + continue + } else if info.Start().Offset >= endOff { + break + } if opt, ok := decl.(*ast.OptionNode); ok { collectOptions[*descriptorpb.EnumOptions](opt, enum, optionsByNode) } @@ -750,18 +856,20 @@ func (c *Cache) computeMessageLiteralHints(doc protocol.TextDocumentIdentifier, } } } - for _, ext := range fdp.GetExtension() { - for _, opt := range res.FieldNode(ext).(*ast.FieldNode).GetOptions().GetElements() { - collectOptions[*descriptorpb.FieldOptions](opt, ext, optionsByNode) - } - } + // for _, ext := range fdp.GetExtension() { + // for _, opt := range res.FieldNode(ext).(*ast.FieldNode).GetOptions().GetElements() { + // collectOptions[*descriptorpb.FieldOptions](opt, ext, optionsByNode) + // } + // } allNodes := a.Children() for _, node := range allNodes { // only look at the decls that overlap the range - start, end := node.Start(), node.End() - if end <= startToken || start >= endToken { + info := a.NodeInfo(node) + if info.End().Offset <= startOff { continue + } else if info.Start().Offset >= endOff { + break } ast.Walk(node, &ast.SimpleVisitor{ DoVisitOptionNode: func(n *ast.OptionNode) error { @@ -987,15 +1095,15 @@ func buildMessageLiteralHints(lit *ast.MessageLiteralNode, msg protoreflect.Mess } fieldHint.PaddingLeft = false } else { - info := a.NodeInfo(field.Sep) - fieldHint.Position = protocol.Position{ - Line: uint32(info.Start().Line) - 1, - Character: uint32(info.Start().Col) - 1, - } - fieldHint.Label = append(fieldHint.Label, protocol.InlayHintLabelPart{ - Value: kind.String(), - }) - fieldHint.PaddingRight = false + // info := a.NodeInfo(field.Sep) + // fieldHint.Position = protocol.Position{ + // Line: uint32(info.Start().Line) - 1, + // Character: uint32(info.Start().Col) - 1, + // } + // fieldHint.Label = append(fieldHint.Label, protocol.InlayHintLabelPart{ + // Value: kind.String(), + // }) + // fieldHint.PaddingRight = false } hints = append(hints, fieldHint) } @@ -1008,9 +1116,10 @@ func makeTooltip(d protoreflect.Descriptor) *protocol.OrPTooltipPLabel { return nil } printer := protoprint.Printer{ - SortElements: true, - Indent: " ", - Compact: true, + SortElements: true, + CustomSortFunction: SortElements, + Indent: " ", + Compact: protoprint.CompactDefault, } str, err := printer.PrintProtoToString(wrap) if err != nil { @@ -1023,3 +1132,91 @@ func makeTooltip(d protoreflect.Descriptor) *protocol.OrPTooltipPLabel { }, } } + +func (c *Cache) FormatDocument(doc protocol.TextDocumentIdentifier, options protocol.FormattingOptions, maybeRange ...protocol.Range) ([]protocol.TextEdit, error) { + printer := protoprint.Printer{ + SortElements: true, + CustomSortFunction: SortElements, + Indent: " ", // todo: tabs break semantic tokens + Compact: protoprint.CompactDefault, + } + path, err := c.URIToPath(doc.URI.SpanURI()) + if err != nil { + return nil, err + } + mapper, err := c.compiler.overlay.Get(path) + if err != nil { + return nil, err + } + res, err := c.FindResultByURI(doc.URI.SpanURI()) + if err != nil { + return nil, err + } + + if len(maybeRange) == 1 { + rng := maybeRange[0] + // format range + start, end, err := mapper.RangeOffsets(rng) + if err != nil { + return nil, err + } + + // Try to map the range to a single top-level element. If the range overlaps + // multiple top level elements, we'll just format the whole file. + + targetDesc, err := findDescriptorWithinRangeOffsets(res, start, end) + if err != nil { + return nil, err + } + splicedBuffer := bytes.NewBuffer(bytes.Clone(mapper.Content[:start])) + + wrap, err := desc.WrapDescriptor(targetDesc) + if err != nil { + return nil, err + } + + err = printer.PrintProto(wrap, splicedBuffer) + if err != nil { + return nil, err + } + splicedBuffer.Write(mapper.Content[end:]) + spliced := splicedBuffer.Bytes() + fmt.Printf("old:\n%s\nnew:\n%s\n", string(mapper.Content), string(spliced)) + + edits := diff.Bytes(mapper.Content, spliced) + return source.ToProtocolEdits(mapper, edits) + } + + wrap, err := desc.WrapFile(res) + if err != nil { + return nil, err + } + // format whole file + buf := bytes.NewBuffer(make([]byte, 0, len(mapper.Content))) + err = printer.PrintProtoFile(wrap, buf) + if err != nil { + return nil, err + } + + edits := diff.Bytes(mapper.Content, buf.Bytes()) + return source.ToProtocolEdits(mapper, edits) +} + +func findDescriptorWithinRangeOffsets(res linker.Result, start, end int) (output protoreflect.Descriptor, err error) { + ast := res.AST() + + err = walk.Descriptors(res, func(d protoreflect.Descriptor) error { + node := res.Node(protoutil.ProtoFromDescriptor(d)) + tokenStart := ast.TokenInfo(node.Start()) + tokenEnd := ast.TokenInfo(node.End()) + if tokenStart.Start().Offset >= start && tokenEnd.End().Offset <= end { + output = d + return sentinel + } + return nil + }) + if err == sentinel { + err = nil + } + return +} diff --git a/protols/diagnostics.go b/protols/diagnostics.go index 605fbd9..deddc4f 100644 --- a/protols/diagnostics.go +++ b/protols/diagnostics.go @@ -1,11 +1,13 @@ package main import ( + "errors" "fmt" "os" "sync" "github.com/bufbuild/protocompile/ast" + "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/reporter" "golang.org/x/tools/gopls/pkg/lsp/protocol" ) @@ -14,6 +16,7 @@ type ProtoDiagnostic struct { Pos ast.SourcePosInfo Severity protocol.DiagnosticSeverity Error error + Tags []protocol.DiagnosticTag } func NewDiagnosticHandler() *DiagnosticHandler { @@ -27,6 +30,15 @@ type DiagnosticHandler struct { diagnostics map[string][]*ProtoDiagnostic } +func tagsForError(err error) []protocol.DiagnosticTag { + switch errors.Unwrap(err).(type) { + case linker.ErrorUnusedImport: + return []protocol.DiagnosticTag{protocol.Unnecessary} + default: + return nil + } +} + func (dr *DiagnosticHandler) HandleError(err reporter.ErrorWithPos) error { if err == nil { return nil @@ -44,6 +56,7 @@ func (dr *DiagnosticHandler) HandleError(err reporter.ErrorWithPos) error { Pos: pos, Severity: protocol.SeverityError, Error: err.Unwrap(), + Tags: tagsForError(err), }) return nil // allow the compiler to continue @@ -66,6 +79,7 @@ func (dr *DiagnosticHandler) HandleWarning(err reporter.ErrorWithPos) { Pos: pos, Severity: protocol.SeverityWarning, Error: err.Unwrap(), + Tags: tagsForError(err), }) } diff --git a/protols/gen/gen.go b/protols/gen/gen.go deleted file mode 100644 index d409e88..0000000 --- a/protols/gen/gen.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: MPL-2.0 - -//go:build generate -// +build generate - -package main - -import ( - "fmt" - "io/ioutil" - "log" - "net/http" - "os" - "path/filepath" -) - -const ( - goplsRef = "gopls/v0.10.0" - urlFmt = "https://raw.githubusercontent.com/golang/tools" + - "/%s/gopls/internal/lsp/protocol/%s" -) - -func main() { - args := os.Args[1:] - if len(args) > 1 && args[0] == "--" { - args = args[1:] - } - - if len(args) != 2 { - log.Fatalf("expected exactly 2 arguments (source filename & target path), given: %q", args) - } - - sourceFilename := args[0] - - targetFilename, err := filepath.Abs(args[1]) - if err != nil { - log.Fatal(err) - } - - url := fmt.Sprintf(urlFmt, goplsRef, sourceFilename) - - resp, err := http.Get(url) - if err != nil { - log.Fatal(err) - } - - if resp.StatusCode != 200 { - log.Fatalf("status code: %d (%s)", resp.StatusCode, url) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Fatalf("failed reading body: %s", err) - } - - f, err := os.Create(targetFilename) - if err != nil { - log.Fatalf("failed to create file: %s", err) - } - - n, err := f.Write(b) - - fmt.Printf("%d bytes written to %s\n", n, targetFilename) -} diff --git a/protols/server.go b/protols/server.go index dd60087..3508f0e 100644 --- a/protols/server.go +++ b/protols/server.go @@ -77,6 +77,12 @@ func (s *Server) Initialize(ctx context.Context, params *protocol.ParamInitializ }, InlayHintProvider: true, DocumentLinkProvider: &protocol.DocumentLinkOptions{}, + DocumentFormattingProvider: &protocol.Or_ServerCapabilities_documentFormattingProvider{ + Value: protocol.DocumentFormattingOptions{}, + }, + DocumentRangeFormattingProvider: &protocol.Or_ServerCapabilities_documentRangeFormattingProvider{ + Value: protocol.DocumentRangeFormattingOptions{}, + }, // DeclarationProvider: &protocol.Or_ServerCapabilities_declarationProvider{Value: true}, // TypeDefinitionProvider: true, // ReferencesProvider: true, @@ -248,9 +254,10 @@ func (s *Server) Hover(ctx context.Context, params *protocol.HoverParams) (resul return nil, err } printer := protoprint.Printer{ - SortElements: true, - Indent: " ", - Compact: true, + SortElements: true, + CustomSortFunction: SortElements, + Indent: " ", + Compact: protoprint.CompactDefault, } str, err := printer.PrintProtoToString(wrap) if err != nil { @@ -475,8 +482,8 @@ func (*Server) FoldingRange(context.Context, *protocol.FoldingRangeParams) ([]pr } // Formatting implements protocol.Server. -func (*Server) Formatting(context.Context, *protocol.DocumentFormattingParams) ([]protocol.TextEdit, error) { - return nil, jsonrpc2.ErrMethodNotFound +func (s *Server) Formatting(ctx context.Context, params *protocol.DocumentFormattingParams) ([]protocol.TextEdit, error) { + return s.c.FormatDocument(params.TextDocument, params.Options) } // Implementation implements protocol.Server. @@ -555,8 +562,8 @@ func (*Server) Progress(context.Context, *protocol.ProgressParams) error { } // RangeFormatting implements protocol.Server. -func (*Server) RangeFormatting(context.Context, *protocol.DocumentRangeFormattingParams) ([]protocol.TextEdit, error) { - return nil, jsonrpc2.ErrMethodNotFound +func (s *Server) RangeFormatting(ctx context.Context, params *protocol.DocumentRangeFormattingParams) ([]protocol.TextEdit, error) { + return s.c.FormatDocument(params.TextDocument, params.Options, params.Range) } // References implements protocol.Server. diff --git a/protols/sort.go b/protols/sort.go new file mode 100644 index 0000000..b4a13ce --- /dev/null +++ b/protols/sort.go @@ -0,0 +1,38 @@ +package main + +import "github.com/jhump/protoreflect/desc/protoprint" + +// Sort logic w.r.t. the file: +// 1. Package is always first, and imports are always second +// 2. Builtin options are always third. Custom options are not sorted. +// 3. Messages, enums, services, and extensions are not sorted. + +// Sort logic w.r.t. individual elements: +// 1. If the elements have a number, sort by number w.r.t. other numbered elements. +// 2. Otherwise, they are not sorted (i.e. they are left in the order they appear in the file). + +func SortElements(a, b protoprint.Element) (less bool) { + // First, sort by kind of element. The "less" function will return true if a should come before b. + if a.Kind() != b.Kind() && a.Kind() <= protoprint.KindOption && b.Kind() <= protoprint.KindOption { + return a.Kind() < b.Kind() + } + // At this point, a and b are of the same kind. We apply different sorting rules based on the kind. + switch a.Kind() { + case protoprint.KindOption: + // Builtin options come before custom options. + if a.IsCustomOption() != b.IsCustomOption() { + return !a.IsCustomOption() && b.IsCustomOption() + } + // If both are builtin or both are custom, do not sort. + return false + case protoprint.KindField, protoprint.KindExtension, protoprint.KindEnumValue: + // Sort by number. + return a.Number() < b.Number() + case protoprint.KindImport: + // Sort by path. + return a.Name() < b.Name() + default: + // For all other kinds, do not sort. + return false + } +} diff --git a/protols/source.go b/protols/source.go index 8945bb1..8a128f8 100644 --- a/protols/source.go +++ b/protols/source.go @@ -32,6 +32,7 @@ func findNodeAtSourcePos(file *ast.FileNode, pos ast.SourcePos) []ast.Node { if err != nil && err != sentinel { return nil } + return path } diff --git a/protols/synthesis.go b/protols/synthesis.go new file mode 100644 index 0000000..3006e2a --- /dev/null +++ b/protols/synthesis.go @@ -0,0 +1,225 @@ +package main + +import ( + "bytes" + "compress/gzip" + "fmt" + goast "go/ast" + goparser "go/parser" + "go/token" + "io" + "io/fs" + "os" + "path/filepath" + "strconv" + "strings" + + "go.uber.org/zap" + "golang.org/x/mod/module" + "golang.org/x/tools/pkg/diff" + "golang.org/x/tools/pkg/gocommand" + "golang.org/x/tools/pkg/imports" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/types/descriptorpb" +) + +// creates proto files out of thin air +type ProtoSourceSynthesizer struct { + processEnv *imports.ProcessEnv + moduleResolver *imports.ModuleResolver + resolver protodesc.Resolver + knownAlternativePackages [][]diff.Edit +} + +func NewProtoSourceSynthesizer(workdir string) *ProtoSourceSynthesizer { + env := map[string]string{} + for _, key := range requiredGoEnvVars { + if v, ok := os.LookupEnv(key); ok { + env[key] = v + } + } + procEnv := &imports.ProcessEnv{ + GocmdRunner: &gocommand.Runner{}, + Env: env, + ModFile: filepath.Join(workdir, "go.mod"), + ModFlag: "readonly", + WorkingDir: workdir, + Logf: zap.S().Debugf, + } + res, err := procEnv.GetResolver() + if err != nil { + panic(err) + } + resolver := res.(*imports.ModuleResolver) + + resolver.ClearForNewMod() + + return &ProtoSourceSynthesizer{ + processEnv: procEnv, + moduleResolver: resolver, + } +} + +func (s *ProtoSourceSynthesizer) SetResolver(resolver protodesc.Resolver) { + // needs to be called afterwards, since this *is* part of the resolver + s.resolver = resolver +} + +func (s *ProtoSourceSynthesizer) ImportFromGoModule(importName string) (_str string, _dir string, _err error) { + fmt.Println("tryGoImport", importName) + defer func() { fmt.Println("tryGoImport done", _str, _err) }() + + last := strings.LastIndex(importName, "/") + if last == -1 { + return "", "", fmt.Errorf("%w: %s", os.ErrNotExist, "not a go import") + } + filename := importName[last+1:] + if !strings.HasSuffix(filename, ".proto") { + return "", "", fmt.Errorf("%w: %s", os.ErrNotExist, "not a .proto file") + } + + // check if the path (excluding the filename) is a well-formed go module + importPath := importName[:last] + if err := module.CheckImportPath(importPath); err != nil { + return "", "", fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + + pkgData, dir := s.moduleResolver.FindPackage(importPath) + if pkgData == nil || dir == "" { + for _, edits := range s.knownAlternativePackages { + edited, err := diff.Apply(importPath, edits) + fmt.Printf("tryGoImport > %q not found, trying %q instead based on previously detected patterns\n", importPath, edited) + if err == nil { + pkgData, dir = s.moduleResolver.FindPackage(edited) + if pkgData != nil && dir != "" { + fmt.Println("tryGoImport > successfully found", edited) + goto edit_success + } + } + } + return "", "", fmt.Errorf("%w: %s", os.ErrNotExist, "no packages found") + } +edit_success: + fmt.Println("tryGoImport > pkgData", pkgData) + + // We now have a valid go package. First check if there's a .proto file in the package. + // If there is, we're done. + if _, err := os.Stat(filepath.Join(dir, filename)); err == nil { + // thank god + return filepath.Join(dir, filename), dir, nil + } + return "", dir, fmt.Errorf("%w: %s", os.ErrNotExist, "no .proto file found") +} + +func (s *ProtoSourceSynthesizer) SynthesizeFromGoSource(importName string, dir string) (desc *descriptorpb.FileDescriptorProto, _err error) { + // buckle up + fset := token.NewFileSet() + packages, err := goparser.ParseDir(fset, dir, func(fi fs.FileInfo) bool { + if strings.HasSuffix(fi.Name(), "_test.go") { + return false + } + return strings.HasSuffix(fi.Name(), ".pb.go") && !strings.HasSuffix(fi.Name(), "_grpc.pb.go") + }, goparser.ParseComments) + if err != nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + if len(packages) != 1 { + return nil, fmt.Errorf("wrong number of packages found: %d", len(packages)) + } + var rawDescByteArray *goast.Object + fmt.Println(">> [OK] found packages:", packages) +PACKAGES: + for _, pkg := range packages { + // we're looking for the byte array that contains the raw file descriptor + // it's named "file__rawDesc" where is the import path + // used when compiling the generated code, with slashes replaced by underscores. + // e.g. file_example_com_foo_bar_baz_proto_rawDesc => "example.com/foo/bar/baz.proto" + // only one catch: the go package path is not necessarily the same as the import path. + // luckily, there's a comment at the top of the file that tells us what the import path is. + // it looks like "// source: example.com/foo/bar/baz.proto" + for _, f := range pkg.Files { + for _, comment := range f.Comments { + text := comment.Text() + _, path, ok := strings.Cut(text, "source: ") + path = strings.TrimSpace(path) + if !ok || !strings.HasSuffix(path, ".proto") { + continue + } + + // found a possible match, check if there's a symbol with the right name + symbolName := fmt.Sprintf("file_%s_rawDesc", strings.ReplaceAll(strings.ReplaceAll(path, "/", "_"), ".", "_")) + object := f.Scope.Lookup(symbolName) + if object != nil && object.Kind == goast.Var { + // found it! + rawDescByteArray = object + break PACKAGES + } + } + } + } + if rawDescByteArray == nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, "could not find file descriptor in package") + } + fmt.Println(">> [OK] found ast object") + // we have the raw descriptor byte array, which is just a bunch of hex numbers in a slice + // which we can decode from the ast. + // The ast for the byte array will look like: + // *ast.Object { + // Kind: var + // Name: "file__rawDesc" + // Decl: *ast.ValueSpec { + // Values: []ast.Expr (len = 1) { + // 0: *ast.CompositeLit { + // Elts: []ast.Expr (len = {len}) { + // 0: *ast.BasicLit { + // Value: "0x0a" + // } + // 1: *ast.BasicLit { + // Value: "0x2c" + // } + // ... + elements := rawDescByteArray.Decl.(*goast.ValueSpec).Values[0].(*goast.CompositeLit).Elts + buf := bytes.NewBuffer(make([]byte, 0, 4096)) + for _, b := range elements { + str := b.(*goast.BasicLit).Value + i, err := strconv.ParseUint(str, 0, 8) + if err != nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + buf.WriteByte(byte(i)) + } + fmt.Println(">> [OK] decoded byte array") + + // now we have a byte array containing the raw file descriptor, which we can unmarshal + // into a FileDescriptorProto. + // the buffer may or may not be gzipped, so we need to check that first. + var reader io.Reader = buf + if bytes.HasPrefix(buf.Bytes(), []byte{0x1f, 0x8b}) { + reader, err = gzip.NewReader(buf) + if err != nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + } + decompressedBytes, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + + fd := &descriptorpb.FileDescriptorProto{} + if err := proto.Unmarshal(decompressedBytes, fd); err != nil { + return nil, fmt.Errorf("%w: %s", os.ErrNotExist, err) + } + fmt.Println(">> [OK] decoded raw file descriptor") + if fd.GetName() != importName { + // this package uses an alternate import path. we need to keep track of this + // in case any of its dependencies use a similar path structure. + alternateImportPath := fd.GetName() + resolvedImportPath := importName + edits := diff.Strings(alternateImportPath, resolvedImportPath) + s.knownAlternativePackages = append(s.knownAlternativePackages, edits) + + *fd.Name = importName + } + return fd, nil +}