From e1afd44944ddc75cc7b3ad62824121c7a6646e93 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:12:00 -0400 Subject: [PATCH] Remove canonical bytes feature (#262) This was never actually used in the buf CLI. And it's rather complicated and kind of a lot to maintain for an effectively unused feature. It also had several bugs, so instead of merging fixes for those bugs, we're just removing the whole feature. --- internal/benchmarks/benchmark_test.go | 31 +- internal/prototest/util.go | 27 +- internal/tags.go | 5 + internal/testdata/options/options.proto | 2 +- internal/util.go | 6 + linker/descriptors.go | 140 ---- linker/linker.go | 25 - linker/linker_test.go | 2 +- options/options.go | 993 ++++++------------------ options/options_test.go | 68 +- sourceinfo/source_code_info.go | 61 +- sourceinfo/source_code_info_test.go | 7 +- 12 files changed, 312 insertions(+), 1055 deletions(-) diff --git a/internal/benchmarks/benchmark_test.go b/internal/benchmarks/benchmark_test.go index a953b961..3092f8f2 100644 --- a/internal/benchmarks/benchmark_test.go +++ b/internal/benchmarks/benchmark_test.go @@ -46,7 +46,6 @@ import ( "github.com/bufbuild/protocompile" "github.com/bufbuild/protocompile/ast" "github.com/bufbuild/protocompile/internal/protoc" - "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/parser" "github.com/bufbuild/protocompile/parser/fastscan" "github.com/bufbuild/protocompile/protoutil" @@ -235,7 +234,7 @@ func downloadAndExpand(url, targetDir string) (e error) { } func BenchmarkGoogleapisProtocompile(b *testing.B) { - benchmarkGoogleapisProtocompile(b, false, func() *protocompile.Compiler { + benchmarkGoogleapisProtocompile(b, func() *protocompile.Compiler { return &protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -246,20 +245,8 @@ func BenchmarkGoogleapisProtocompile(b *testing.B) { }) 
} -func BenchmarkGoogleapisProtocompileCanonical(b *testing.B) { - benchmarkGoogleapisProtocompile(b, true, func() *protocompile.Compiler { - return &protocompile.Compiler{ - Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ - ImportPaths: []string{googleapisDir}, - }), - SourceInfoMode: protocompile.SourceInfoStandard, - // leave MaxParallelism unset to let it use all cores available - } - }) -} - func BenchmarkGoogleapisProtocompileNoSourceInfo(b *testing.B) { - benchmarkGoogleapisProtocompile(b, false, func() *protocompile.Compiler { + benchmarkGoogleapisProtocompile(b, func() *protocompile.Compiler { return &protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -270,23 +257,19 @@ func BenchmarkGoogleapisProtocompileNoSourceInfo(b *testing.B) { }) } -func benchmarkGoogleapisProtocompile(b *testing.B, canonicalBytes bool, factory func() *protocompile.Compiler) { +func benchmarkGoogleapisProtocompile(b *testing.B, factory func() *protocompile.Compiler) { for i := 0; i < b.N; i++ { - benchmarkProtocompile(b, factory(), googleapisSources, canonicalBytes) + benchmarkProtocompile(b, factory(), googleapisSources) } } -func benchmarkProtocompile(b *testing.B, c *protocompile.Compiler, sources []string, canonicalBytes bool) { +func benchmarkProtocompile(b *testing.B, c *protocompile.Compiler, sources []string) { fds, err := c.Compile(context.Background(), sources...) 
require.NoError(b, err) var fdSet descriptorpb.FileDescriptorSet fdSet.File = make([]*descriptorpb.FileDescriptorProto, len(fds)) for i, fd := range fds { - if canonicalBytes { - fdSet.File[i] = fd.(linker.Result).CanonicalProto() - } else { - fdSet.File[i] = protoutil.ProtoFromFileDescriptor(fd) - } + fdSet.File[i] = protoutil.ProtoFromFileDescriptor(fd) } // protoc is writing output to file descriptor set, so we should, too writeToNull(b, &fdSet) @@ -484,7 +467,7 @@ func BenchmarkGoogleapisProtocompileSingleThreaded(b *testing.B) { // need to run a single-threaded compile MaxParallelism: 1, } - benchmarkProtocompile(b, c, googleapisSources, false) + benchmarkProtocompile(b, c, googleapisSources) } }) } diff --git a/internal/prototest/util.go b/internal/prototest/util.go index a21fca7c..675b226b 100644 --- a/internal/prototest/util.go +++ b/internal/prototest/util.go @@ -15,7 +15,6 @@ package prototest import ( - "fmt" "os" "testing" @@ -73,27 +72,11 @@ func findFileInSet(fps *descriptorpb.FileDescriptorSet, name string) *descriptor return nil } -func AssertMessagesEqual(t *testing.T, exp, act proto.Message, msgAndArgs ...interface{}) { +func AssertMessagesEqual(t *testing.T, exp, act proto.Message, description string) bool { t.Helper() - AssertMessagesEqualWithOptions(t, exp, act, nil, msgAndArgs...) -} - -func AssertMessagesEqualWithOptions(t *testing.T, exp, act proto.Message, opts []cmp.Option, msgAndArgs ...interface{}) { - t.Helper() - cmpOpts := []cmp.Option{protocmp.Transform()} - cmpOpts = append(cmpOpts, opts...) - if diff := cmp.Diff(exp, act, cmpOpts...); diff != "" { - var prefix string - if len(msgAndArgs) == 1 { - if msg, ok := msgAndArgs[0].(string); ok { - prefix = msg + ": " - } else { - prefix = fmt.Sprintf("%+v: ", msgAndArgs[0]) - } - } else if len(msgAndArgs) > 1 { - prefix = fmt.Sprintf(msgAndArgs[0].(string)+": ", msgAndArgs[1:]...) 
- } - - t.Errorf("%smessage mismatch (-want +got):\n%v", prefix, diff) + if diff := cmp.Diff(exp, act, protocmp.Transform()); diff != "" { + t.Errorf("%s: message mismatch (-want, +got):\n%s", description, diff) + return false } + return true } diff --git a/internal/tags.go b/internal/tags.go index 869f9bdb..0f3960f9 100644 --- a/internal/tags.go +++ b/internal/tags.go @@ -243,4 +243,9 @@ const ( // UninterpretedNameNameTag is the tag number of the name element in an // uninterpreted option name proto. UninterpretedNameNameTag = 1 + + // AnyTypeURLTag is the tag number of the type_url field of the Any proto. + AnyTypeURLTag = 1 + // AnyValueTag is the tag number of the value field of the Any proto. + AnyValueTag = 2 ) diff --git a/internal/testdata/options/options.proto b/internal/testdata/options/options.proto index f5a58c90..7be198f2 100644 --- a/internal/testdata/options/options.proto +++ b/internal/testdata/options/options.proto @@ -227,4 +227,4 @@ option (any) = { pr_i32: [0,1,2,3] str: "foo" } -}; \ No newline at end of file +}; diff --git a/internal/util.go b/internal/util.go index a590e9b7..569cb3f1 100644 --- a/internal/util.go +++ b/internal/util.go @@ -230,6 +230,12 @@ func CanPack(k protoreflect.Kind) bool { } } +func ClonePath(path protoreflect.SourcePath) protoreflect.SourcePath { + clone := make(protoreflect.SourcePath, len(path)) + copy(clone, path) + return clone +} + func reverse(p protoreflect.SourcePath) protoreflect.SourcePath { for i, j := 0, len(p)-1; i < j; i, j = i+1, j-1 { p[i], p[j] = p[j], p[i] diff --git a/linker/descriptors.go b/linker/descriptors.go index d1343e33..8616c171 100644 --- a/linker/descriptors.go +++ b/linker/descriptors.go @@ -151,11 +151,6 @@ type result struct { // interpreting options. usedImports map[string]struct{} - // A map of descriptor options messages to their pre-serialized bytes (using - // a canonical serialization format based on how protoc renders options to - // bytes). 
- optionBytes map[proto.Message][]byte - // A map of AST nodes that represent identifiers in ast.FieldReferenceNodes // to their fully-qualified name. The identifiers are for field names in // message literals (in option values) that are extension fields. These names @@ -316,141 +311,6 @@ func asSourceLocations(srcInfoProtos []*descriptorpb.SourceCodeInfo_Location) [] return locs } -// AddOptionBytes associates the given opts (an options message encoded in the -// binary format) with the given options protobuf message. The protobuf message -// should exist in the hierarchy of this result's FileDescriptorProto. This -// allows the FileDescriptorProto to be marshaled to bytes in a way that -// preserves the way options are defined in source (just as is done by protoc, -// but not possible when only using the generated Go types and standard -// marshaling APIs in the protobuf runtime). -func (r *result) AddOptionBytes(pm proto.Message, opts []byte) { - if r.optionBytes == nil { - r.optionBytes = map[proto.Message][]byte{} - } - r.optionBytes[pm] = append(r.optionBytes[pm], opts...) -} - -func (r *result) CanonicalProto() *descriptorpb.FileDescriptorProto { - origFd := r.FileDescriptorProto() - // make a copy that we can mutate - fd := proto.Clone(origFd).(*descriptorpb.FileDescriptorProto) //nolint:errcheck - - r.storeOptionBytesInFile(fd, origFd) - - return fd -} - -func (r *result) storeOptionBytes(opts, origOpts proto.Message) { - optionBytes := r.optionBytes[origOpts] - if len(optionBytes) == 0 { - // If we don't know about this options message, leave it alone. 
- return - } - proto.Reset(opts) - opts.ProtoReflect().SetUnknown(optionBytes) -} - -func (r *result) storeOptionBytesInFile(fd, origFd *descriptorpb.FileDescriptorProto) { - if fd.Options != nil { - r.storeOptionBytes(fd.Options, origFd.Options) - } - - for i, md := range fd.MessageType { - origMd := origFd.MessageType[i] - r.storeOptionBytesInMessage(md, origMd) - } - - for i, ed := range fd.EnumType { - origEd := origFd.EnumType[i] - r.storeOptionBytesInEnum(ed, origEd) - } - - for i, exd := range fd.Extension { - origExd := origFd.Extension[i] - r.storeOptionBytesInField(exd, origExd) - } - - for i, sd := range fd.Service { - origSd := origFd.Service[i] - if sd.Options != nil { - r.storeOptionBytes(sd.Options, origSd.Options) - } - - for j, mtd := range sd.Method { - origMtd := origSd.Method[j] - if mtd.Options != nil { - r.storeOptionBytes(mtd.Options, origMtd.Options) - } - } - } -} - -func (r *result) storeOptionBytesInMessage(md, origMd *descriptorpb.DescriptorProto) { - if md.GetOptions().GetMapEntry() { - // Map entry messages are synthesized. They won't have any option bytes - // since they don't actually appear in the source and thus have any option - // declarations in the source. 
- return - } - - if md.Options != nil { - r.storeOptionBytes(md.Options, origMd.Options) - } - - for i, fld := range md.Field { - origFld := origMd.Field[i] - r.storeOptionBytesInField(fld, origFld) - } - - for i, ood := range md.OneofDecl { - origOod := origMd.OneofDecl[i] - if ood.Options != nil { - r.storeOptionBytes(ood.Options, origOod.Options) - } - } - - for i, exr := range md.ExtensionRange { - origExr := origMd.ExtensionRange[i] - if exr.Options != nil { - r.storeOptionBytes(exr.Options, origExr.Options) - } - } - - for i, nmd := range md.NestedType { - origNmd := origMd.NestedType[i] - r.storeOptionBytesInMessage(nmd, origNmd) - } - - for i, ed := range md.EnumType { - origEd := origMd.EnumType[i] - r.storeOptionBytesInEnum(ed, origEd) - } - - for i, exd := range md.Extension { - origExd := origMd.Extension[i] - r.storeOptionBytesInField(exd, origExd) - } -} - -func (r *result) storeOptionBytesInEnum(ed, origEd *descriptorpb.EnumDescriptorProto) { - if ed.Options != nil { - r.storeOptionBytes(ed.Options, origEd.Options) - } - - for i, evd := range ed.Value { - origEvd := origEd.Value[i] - if evd.Options != nil { - r.storeOptionBytes(evd.Options, origEvd.Options) - } - } -} - -func (r *result) storeOptionBytesInField(fld, origFld *descriptorpb.FieldDescriptorProto) { - if fld.Options != nil { - r.storeOptionBytes(fld.Options, origFld.Options) - } -} - type fileImports struct { protoreflect.FileImports files []protoreflect.FileImport diff --git a/linker/linker.go b/linker/linker.go index 8304cde1..0097339d 100644 --- a/linker/linker.go +++ b/linker/linker.go @@ -18,7 +18,6 @@ import ( "fmt" "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/descriptorpb" "github.com/bufbuild/protocompile/ast" "github.com/bufbuild/protocompile/parser" @@ -128,30 +127,6 @@ type Result interface { // interpreting options (which is done after linking). 
PopulateSourceCodeInfo() - // CanonicalProto returns the file descriptor proto in a form that - // will be serialized in a canonical way. The "canonical" way matches - // the way that "protoc" emits option values, which is a way that - // mostly matches the way options are defined in source, including - // ordering and de-structuring. Unlike the FileDescriptorProto() method, - // this method is more expensive and results in a new descriptor proto - // being constructed with each call. - // - // The returned value will have all options (fields of the various - // descriptorpb.*Options message types) represented via unrecognized - // fields. So the returned value will serialize as desired, but it - // is otherwise not useful since all option values are treated as - // unknown. - // - // Note that CanonicalProto is a no-op if the options in this file - // were not interpreted by this module (e.g. the underlying descriptor - // proto was provided, with options already interpreted, instead of - // parsed from source). If the underlying descriptor proto was provided, - // but with a mix of interpreted and uninterpreted options, this method - // will effectively clear the already-interpreted fields and only the - // options actually interpreted by the compile operation will be - // retained. - CanonicalProto() *descriptorpb.FileDescriptorProto - // RemoveAST drops the AST information from this result. 
RemoveAST() } diff --git a/linker/linker_test.go b/linker/linker_test.go index 24790078..a7a20020 100644 --- a/linker/linker_test.go +++ b/linker/linker_test.go @@ -2213,7 +2213,7 @@ func TestLinkerValidation(t *testing.T) { } `, }, - expectedErr: `test.proto:3:18: feature "enum_type" is allowed on [enum,file], not on field`, + expectedErr: `test.proto:3:27: feature "enum_type" is allowed on [enum,file], not on field`, }, "failure_editions_feature_on_wrong_target_type_msg_literal": { input: map[string]string{ diff --git a/options/options.go b/options/options.go index feb6c5cc..103531af 100644 --- a/options/options.go +++ b/options/options.go @@ -30,11 +30,9 @@ import ( "errors" "fmt" "math" - "sort" "strings" "google.golang.org/protobuf/encoding/prototext" - "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -62,11 +60,11 @@ var ( type interpreter struct { file file resolver linker.Resolver - container optionsContainer overrideDescriptorProto linker.File lenient bool reporter *reporter.Handler index sourceinfo.OptionIndex + pathBuffer []int32 } type file interface { @@ -138,13 +136,13 @@ func InterpretUnlinkedOptions(parsed parser.Result, opts ...InterpreterOption) ( func interpretOptions(lenient bool, file file, res linker.Resolver, handler *reporter.Handler, interpOpts []InterpreterOption) (sourceinfo.OptionIndex, error) { interp := interpreter{ - file: file, - resolver: res, - lenient: lenient, - reporter: handler, - index: sourceinfo.OptionIndex{}, + file: file, + resolver: res, + lenient: lenient, + reporter: handler, + index: sourceinfo.OptionIndex{}, + pathBuffer: make([]int32, 0, 16), } - interp.container, _ = file.(optionsContainer) for _, opt := range interpOpts { opt(&interp) } @@ -536,326 +534,6 @@ func (interp *interpreter) interpretEnumOptions(fqn string, ed *descriptorpb.Enu return nil } -// interpretedOption 
represents the result of interpreting an option. -// This includes metadata that allows the option to be serialized to -// bytes in a way that is deterministic and can preserve the structure -// of the source (the way the options are de-structured and the order in -// which options appear). -type interpretedOption struct { - unknown bool - pathPrefix []interpretedEnclosingTag - interpretedField -} - -func (o *interpretedOption) toSourceInfo() *sourceinfo.OptionSourceInfo { - var path []int32 - if len(o.pathPrefix) > 0 { - path = make([]int32, len(o.pathPrefix)) - for i := range o.pathPrefix { - path[i] = o.pathPrefix[i].tag - } - } - return o.interpretedField.toSourceInfo(path) -} - -func (o *interpretedOption) appendOptionBytes(b []byte) ([]byte, error) { - return o.appendOptionBytesWithPath(b, o.pathPrefix) -} - -func (o *interpretedOption) appendOptionBytesWithPath(b []byte, path []interpretedEnclosingTag) ([]byte, error) { - if len(path) == 0 { - return appendOptionBytesSingle(b, &o.interpretedField) - } - // NB: if we add functions to compute sizes of the options first, we could - // allocate precisely sized slice up front, which would be more efficient than - // repeated creation/growing/concatenation. - enclosed, err := o.appendOptionBytesWithPath(nil, path[1:]) - if err != nil { - return nil, err - } - if path[0].kind == protoreflect.GroupKind { - b = protowire.AppendTag(b, protowire.Number(path[0].tag), protowire.StartGroupType) - b = append(b, enclosed...) - b = protowire.AppendTag(b, protowire.Number(path[0].tag), protowire.EndGroupType) - } else { - b = protowire.AppendTag(b, protowire.Number(path[0].tag), protowire.BytesType) - b = protowire.AppendBytes(b, enclosed) - } - return b, nil -} - -type interpretedEnclosingTag struct { - tag int32 - kind protoreflect.Kind // indicates whether we need to use group encoding or not -} - -// interpretedField represents a field in an options message that is the -// result of interpreting an option. 
This is used for the option value -// itself as well as for subfields when an option value is a message -// literal. -type interpretedField struct { - // the AST node for this field -- an [*ast.OptionNode] for top-level options, - // an [*ast.MessageFieldNode] for fields in a message literal, or nil for - // synthetic field values (for keys or values in map entries that were - // omitted from source). - node ast.Node - // field number - number int32 - // index of this element inside a repeated field; only set if repeated == true - index int32 - // true if this is a repeated field - repeated bool - packed bool - kind protoreflect.Kind - - value interpretedFieldValue -} - -func (f *interpretedField) path(prefix []int32) []int32 { - path := make([]int32, 0, len(prefix)+2) - path = append(path, prefix...) - path = append(path, f.number) - if f.repeated { - path = append(path, f.index) - } - return path -} - -func (f *interpretedField) toSourceInfo(prefix []int32) *sourceinfo.OptionSourceInfo { - path := f.path(prefix) - var children sourceinfo.OptionChildrenSourceInfo - if len(f.value.msgListVal) > 0 { - elements := make([]sourceinfo.OptionSourceInfo, len(f.value.msgListVal)) - for i, msgVal := range f.value.msgListVal { - // With an array literal, the index in path is that of the first element. - elementPath := append(([]int32)(nil), path...) 
- elementPath[len(elementPath)-1] += int32(i) - elements[i].Path = elementPath - elements[i].Children = msgSourceInfo(elementPath, msgVal) - } - children = &sourceinfo.ArrayLiteralSourceInfo{Elements: elements} - } else if len(f.value.msgVal) > 0 { - children = msgSourceInfo(path, f.value.msgVal) - } - return &sourceinfo.OptionSourceInfo{ - Path: path, - Children: children, - } -} - -func msgSourceInfo(prefix []int32, fields []*interpretedField) *sourceinfo.MessageLiteralSourceInfo { - fieldInfo := map[*ast.MessageFieldNode]*sourceinfo.OptionSourceInfo{} - for _, field := range fields { - msgFieldNode, ok := field.node.(*ast.MessageFieldNode) - if !ok { - continue - } - fieldInfo[msgFieldNode] = field.toSourceInfo(prefix) - } - return &sourceinfo.MessageLiteralSourceInfo{Fields: fieldInfo} -} - -// interpretedFieldValue is a wrapper around protoreflect.Value that -// includes extra metadata. -type interpretedFieldValue struct { - // the bytes for this field value if already pre-serialized - // (when this is set, the other fields are ignored) - preserialized []byte - - // the field value - val protoreflect.Value - // if true, this value is a list of values, not a singular value - isList bool - // non-nil for singular message values - msgVal []*interpretedField - // non-nil for non-empty lists of message values - msgListVal [][]*interpretedField -} - -func appendOptionBytes(b []byte, flds []*interpretedField) ([]byte, error) { - // protoc emits messages sorted by field number - if len(flds) > 1 { - sort.SliceStable(flds, func(i, j int) bool { - return flds[i].number < flds[j].number - }) - } - - for i := 0; i < len(flds); i++ { - f := flds[i] - if f.value.preserialized != nil { - b = append(b, f.value.preserialized...) 
- continue - } - switch { - case f.packed && internal.CanPack(f.kind): - // for packed repeated numeric fields, all runs of values are merged into one packed list - num := f.number - j := i - for j < len(flds) && flds[j].number == num { - j++ - } - // now flds[i:j] is the range of contiguous fields for the same field number - enclosed, err := appendOptionBytesPacked(nil, f.kind, flds[i:j]) - if err != nil { - return nil, err - } - b = protowire.AppendTag(b, protowire.Number(f.number), protowire.BytesType) - b = protowire.AppendBytes(b, enclosed) - // skip over the other subsequent fields we just serialized - i = j - 1 - case f.value.isList: - // if not packed, then emit one value at a time - single := *f - single.value.isList = false - single.value.msgListVal = nil - l := f.value.val.List() - for i := 0; i < l.Len(); i++ { - single.value.val = l.Get(i) - if f.kind == protoreflect.MessageKind || f.kind == protoreflect.GroupKind { - single.value.msgVal = f.value.msgListVal[i] - } - var err error - b, err = appendOptionBytesSingle(b, &single) - if err != nil { - return nil, err - } - } - default: - // simple singular value - var err error - b, err = appendOptionBytesSingle(b, f) - if err != nil { - return nil, err - } - } - } - - return b, nil -} - -func appendOptionBytesPacked(b []byte, k protoreflect.Kind, flds []*interpretedField) ([]byte, error) { - for i := range flds { - val := flds[i].value - if val.isList { - l := val.val.List() - var err error - b, err = appendNumericValueBytesPacked(b, k, l) - if err != nil { - return nil, err - } - } else { - var err error - b, err = appendNumericValueBytes(b, k, val.val) - if err != nil { - return nil, err - } - } - } - return b, nil -} - -func appendOptionBytesSingle(b []byte, f *interpretedField) ([]byte, error) { - if f.value.preserialized != nil { - return append(b, f.value.preserialized...), nil - } - num := protowire.Number(f.number) - switch f.kind { - case protoreflect.MessageKind: - enclosed, err := 
appendOptionBytes(nil, f.value.msgVal) - if err != nil { - return nil, err - } - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendBytes(b, enclosed), nil - - case protoreflect.GroupKind: - b = protowire.AppendTag(b, num, protowire.StartGroupType) - var err error - b, err = appendOptionBytes(b, f.value.msgVal) - if err != nil { - return nil, err - } - return protowire.AppendTag(b, num, protowire.EndGroupType), nil - - case protoreflect.StringKind: - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendString(b, f.value.val.String()), nil - - case protoreflect.BytesKind: - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendBytes(b, f.value.val.Bytes()), nil - - case protoreflect.Int32Kind, protoreflect.Int64Kind, protoreflect.Uint32Kind, protoreflect.Uint64Kind, - protoreflect.Sint32Kind, protoreflect.Sint64Kind, protoreflect.EnumKind, protoreflect.BoolKind: - b = protowire.AppendTag(b, num, protowire.VarintType) - return appendNumericValueBytes(b, f.kind, f.value.val) - - case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: - b = protowire.AppendTag(b, num, protowire.Fixed32Type) - return appendNumericValueBytes(b, f.kind, f.value.val) - - case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: - b = protowire.AppendTag(b, num, protowire.Fixed64Type) - return appendNumericValueBytes(b, f.kind, f.value.val) - - default: - return nil, fmt.Errorf("unknown field kind: %v", f.kind) - } -} - -func appendNumericValueBytesPacked(b []byte, k protoreflect.Kind, l protoreflect.List) ([]byte, error) { - for i := 0; i < l.Len(); i++ { - var err error - b, err = appendNumericValueBytes(b, k, l.Get(i)) - if err != nil { - return nil, err - } - } - return b, nil -} - -func appendNumericValueBytes(b []byte, k protoreflect.Kind, v protoreflect.Value) ([]byte, error) { - switch k { - case protoreflect.Int32Kind, protoreflect.Int64Kind: - return 
protowire.AppendVarint(b, uint64(v.Int())), nil - case protoreflect.Uint32Kind, protoreflect.Uint64Kind: - return protowire.AppendVarint(b, v.Uint()), nil - case protoreflect.Sint32Kind, protoreflect.Sint64Kind: - return protowire.AppendVarint(b, protowire.EncodeZigZag(v.Int())), nil - case protoreflect.Fixed32Kind: - return protowire.AppendFixed32(b, uint32(v.Uint())), nil - case protoreflect.Fixed64Kind: - return protowire.AppendFixed64(b, v.Uint()), nil - case protoreflect.Sfixed32Kind: - return protowire.AppendFixed32(b, uint32(v.Int())), nil - case protoreflect.Sfixed64Kind: - return protowire.AppendFixed64(b, uint64(v.Int())), nil - case protoreflect.FloatKind: - return protowire.AppendFixed32(b, math.Float32bits(float32(v.Float()))), nil - case protoreflect.DoubleKind: - return protowire.AppendFixed64(b, math.Float64bits(v.Float())), nil - case protoreflect.BoolKind: - return protowire.AppendVarint(b, protowire.EncodeBool(v.Bool())), nil - case protoreflect.EnumKind: - return protowire.AppendVarint(b, uint64(v.Enum())), nil - default: - return nil, fmt.Errorf("unknown field kind: %v", k) - } -} - -// optionsContainer may be optionally implemented by a linker.Result. It is -// not part of the linker.Result interface as it is meant only for internal use. -// This allows the option interpreter step to store extra metadata about the -// serialized structure of options. -type optionsContainer interface { - // AddOptionBytes adds the given pre-serialized option bytes to a file, - // associated with the given options message. The type of the given message - // should be an options message, for example *descriptorpb.MessageOptions. - // This value should be part of the message hierarchy whose root is the - // *descriptorpb.FileDescriptorProto that corresponds to this result. 
- AddOptionBytes(pm proto.Message, opts []byte) -} - func interpretElementOptions[Elem elementType[OptsStruct, Opts], OptsStruct any, Opts optionsType[OptsStruct]]( interp *interpreter, fqn string, @@ -908,8 +586,7 @@ func (interp *interpreter) interpretOptions( ElementType: descriptorType(element), } var remain []*descriptorpb.UninterpretedOption - results := make([]*interpretedOption, 0, len(uninterpreted)) - var featuresInfo []*interpretedOption + var features []*ast.OptionNode for _, uo := range uninterpreted { if uo.Name[0].GetIsExtension() != customOpts { // We're not looking at these this phase. @@ -928,7 +605,7 @@ func (interp *interpreter) interpretOptions( } } mc.Option = uo - res, err := interp.interpretField(mc, msg, uo, 0, nil) + srcInfo, err := interp.interpretField(mc, msg, uo, 0, interp.pathBuffer) if err != nil { if interp.lenient { remain = append(remain, uo) @@ -936,18 +613,17 @@ func (interp *interpreter) interpretOptions( } return nil, err } - res.unknown = !isKnownField(optsDesc, res) - results = append(results, res) - if !uo.Name[0].GetIsExtension() && uo.Name[0].GetNamePart() == featuresFieldName { - featuresInfo = append(featuresInfo, res) - } if optn, ok := node.(*ast.OptionNode); ok { - si := res.toSourceInfo() - interp.index[optn] = si + if !uo.Name[0].GetIsExtension() && uo.Name[0].GetNamePart() == featuresFieldName { + features = append(features, optn) + } + if srcInfo != nil { + interp.index[optn] = srcInfo + } } } - if err := interp.validateFeatures(targetType, msg, featuresInfo); err != nil && !interp.lenient { + if err := interp.validateFeatures(targetType, msg, features); err != nil && !interp.lenient { return nil, err } @@ -966,12 +642,6 @@ func (interp *interpreter) interpretOptions( proto.Reset(opts) proto.Merge(opts, optsClone) - if interp.container != nil { - if err := interp.setOptionBytes(mc, opts, results); err != nil { - return nil, err - } - } - return remain, nil } @@ -988,28 +658,13 @@ func (interp *interpreter) 
interpretOptions( return nil, interp.reporter.HandleError(reporter.Error(interp.nodeInfo(node), err)) } - if interp.container != nil { - if err := interp.setOptionBytes(mc, opts, results); err != nil { - return nil, err - } - } - return remain, nil } -func (interp *interpreter) setOptionBytes(mc *internal.MessageContext, opts proto.Message, values []*interpretedOption) error { - b, err := interp.toOptionBytes(mc, values) - if err != nil { - return err - } - interp.container.AddOptionBytes(opts, b) - return nil -} - func (interp *interpreter) validateFeatures( targetType descriptorpb.FieldOptions_OptionTargetType, opts protoreflect.Message, - featuresInfo []*interpretedOption, + features []*ast.OptionNode, ) error { fld := opts.Descriptor().Fields().ByName(featuresFieldName) if fld == nil { @@ -1021,9 +676,9 @@ func (interp *interpreter) validateFeatures( // TODO: should this return an error? return nil } - features := opts.Get(fld).Message() + featureSet := opts.Get(fld).Message() var err error - features.Range(func(featureField protoreflect.FieldDescriptor, _ protoreflect.Value) bool { + featureSet.Range(func(featureField protoreflect.FieldDescriptor, _ protoreflect.Value) bool { opts, ok := featureField.Options().(*descriptorpb.FieldOptions) if !ok { return true @@ -1041,7 +696,7 @@ func (interp *interpreter) validateFeatures( for i, t := range opts.Targets { allowedTypes[i] = targetTypeString(t) } - pos := interp.positionOfFeature(featuresInfo, fld.Number(), featureField.Number()) + pos := interp.positionOfFeature(features, featuresFieldName, featureField.Name()) if len(opts.Targets) == 1 && opts.Targets[0] == descriptorpb.FieldOptions_TARGET_TYPE_UNKNOWN { err = interp.reporter.HandleErrorf(pos, "feature field %q may not be used explicitly", featureField.Name()) } else { @@ -1053,55 +708,51 @@ func (interp *interpreter) validateFeatures( return err } -func (interp *interpreter) positionOfFeature(featuresInfo []*interpretedOption, fieldNumbers 
...protoreflect.FieldNumber) ast.SourceSpan { +func (interp *interpreter) positionOfFeature(features []*ast.OptionNode, fieldNames ...protoreflect.Name) ast.SourceSpan { if interp.file.AST() == nil { return ast.UnknownSpan(interp.file.FileDescriptorProto().GetName()) } - for _, info := range featuresInfo { - matched, remainingNumbers, node := matchInterpretedOption(info, fieldNumbers) + for _, feature := range features { + matched, remainingNames, nodePos, nodeValue := matchInterpretedOption(feature, fieldNames) if !matched { continue } - if len(remainingNumbers) > 0 { - node = findInterpretedFieldForFeature(&(info.interpretedField), remainingNumbers) + if len(remainingNames) > 0 { + nodePos = findInterpretedFieldForFeature(nodePos, nodeValue, remainingNames) } - if node != nil { - return interp.file.FileNode().NodeInfo(node) + if nodePos != nil { + return interp.file.FileNode().NodeInfo(nodePos) } } return ast.UnknownSpan(interp.file.FileDescriptorProto().GetName()) } -func matchInterpretedOption(info *interpretedOption, path []protoreflect.FieldNumber) (bool, []protoreflect.FieldNumber, ast.Node) { - for i := 0; i < len(path) && i < len(info.pathPrefix); i++ { - if info.pathPrefix[i].tag != int32(path[i]) { - return false, nil, nil - } - } - if len(path) <= len(info.pathPrefix) { - // no more path elements to match - node := info.node - if optsNode, ok := node.(*ast.OptionNode); ok { - // Do we need to check this? It should always be true... 
- if len(optsNode.Name.Parts) == len(info.pathPrefix)+1 { - node = optsNode.Name.Parts[len(path)-1] - } +func matchInterpretedOption(node *ast.OptionNode, path []protoreflect.Name) (bool, []protoreflect.Name, ast.Node, ast.ValueNode) { + for i := 0; i < len(path) && i < len(node.Name.Parts); i++ { + part := node.Name.Parts[i] + if !part.IsExtension() && protoreflect.Name(part.Name.AsIdentifier()) != path[i] { + return false, nil, nil, nil } - return true, nil, node } - if info.number != int32(path[len(info.pathPrefix)]) { - return false, nil, nil + if len(path) <= len(node.Name.Parts) { + // No more path elements to match. Report location + // of the final element of path inside option name. + return true, nil, node.Name.Parts[len(path)-1], node.Val } - return true, path[len(info.pathPrefix)+1:], info.node + return true, path[len(node.Name.Parts):], node.Name.Parts[len(node.Name.Parts)-1], node.Val } -func findInterpretedFieldForFeature(opt *interpretedField, path []protoreflect.FieldNumber) ast.Node { +func findInterpretedFieldForFeature(nodePos ast.Node, nodeValue ast.ValueNode, path []protoreflect.Name) ast.Node { if len(path) == 0 { - return opt.node + return nodePos + } + msgNode, ok := nodeValue.(*ast.MessageLiteralNode) + if !ok { + return nil } - for _, fld := range opt.value.msgVal { - if fld.number == int32(path[0]) { - if res := findInterpretedFieldForFeature(fld, path[1:]); res != nil { + for _, fldNode := range msgNode.Elements { + if fldNode.Name.Open == nil && protoreflect.Name(fldNode.Name.Name.AsIdentifier()) == path[0] { + if res := findInterpretedFieldForFeature(fldNode.Name, fldNode.Val, path[1:]); res != nil { return res } } @@ -1109,68 +760,6 @@ func findInterpretedFieldForFeature(opt *interpretedField, path []protoreflect.F return nil } -// isKnownField returns true if the given option is for a known field of the -// given options message descriptor and will be serialized using the expected -// wire type for that known field. 
-func isKnownField(desc protoreflect.MessageDescriptor, opt *interpretedOption) bool { - var num int32 - if len(opt.pathPrefix) > 0 { - num = opt.pathPrefix[0].tag - } else { - num = opt.number - } - fd := desc.Fields().ByNumber(protoreflect.FieldNumber(num)) - if fd == nil { - return false - } - - // Before the full wire type check, we do a quick check that will usually pass - // and allow us to short-circuit the logic below. - if fd.IsList() == opt.repeated && fd.Kind() == opt.kind { - return true - } - - // We figure out the wire type this interpreted field will use when serialized. - var wireType protowire.Type - switch { - case len(opt.pathPrefix) > 0: - // If path prefix exists, this field is nested inside a message. - // And messages use bytes wire type. - wireType = protowire.BytesType - case opt.repeated && opt.packed && internal.CanPack(opt.kind): - // Packed repeated numeric scalars use bytes wire type. - wireType = protowire.BytesType - default: - wireType = wireTypeForKind(opt.kind) - } - - // And then we see if the wire type we just determined is compatible with - // the field descriptor we found. - if fd.IsList() && internal.CanPack(fd.Kind()) && wireType == protowire.BytesType { - // Even if fd.IsPacked() is false, bytes type is still accepted for - // repeated scalar numerics, so that changing a repeated field from - // packed to not-packed (or vice versa) is a compatible change. 
- return true - } - return wireType == wireTypeForKind(fd.Kind()) -} - -func wireTypeForKind(kind protoreflect.Kind) protowire.Type { - switch kind { - case protoreflect.StringKind, protoreflect.BytesKind, protoreflect.MessageKind: - return protowire.BytesType - case protoreflect.GroupKind: - return protowire.StartGroupType - case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: - return protowire.Fixed32Type - case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: - return protowire.Fixed64Type - default: - // everything else uses varint - return protowire.VarintType - } -} - func targetTypeString(t descriptorpb.FieldOptions_OptionTargetType) string { return strings.ToLower(strings.ReplaceAll(strings.TrimPrefix(t.String(), "TARGET_TYPE_"), "_", " ")) } @@ -1192,35 +781,6 @@ func cloneInto(dest proto.Message, src proto.Message, res linker.Resolver) error return proto.UnmarshalOptions{Resolver: res}.Unmarshal(data, dest) } -func (interp *interpreter) toOptionBytes(mc *internal.MessageContext, results []*interpretedOption) ([]byte, error) { - // protoc emits non-custom options in tag order and then - // the rest are emitted in the order they are defined in source - sort.SliceStable(results, func(i, j int) bool { - if !results[i].unknown && results[j].unknown { - return true - } - if !results[i].unknown && !results[j].unknown { - return results[i].number < results[j].number - } - return false - }) - var b []byte - for _, res := range results { - var err error - b, err = res.appendOptionBytes(b) - if err != nil { - if _, ok := err.(reporter.ErrorWithPos); !ok { - span := ast.UnknownSpan(interp.file.AST().Name()) - err = reporter.Errorf(span, "%sfailed to encode options: %w", mc, err) - } - if err := interp.reporter.HandleError(err); err != nil { - return nil, err - } - } - } - return b, nil -} - func validateRecursive(msg protoreflect.Message, prefix string) error { flds := msg.Descriptor().Fields() var 
missingFields []string @@ -1280,7 +840,13 @@ func validateRecursive(msg protoreflect.Message, prefix string) error { // msg must be an options message. For nameIndex > 0, msg is a nested message inside of the // options message. The given pathPrefix is the path (sequence of field numbers and indices // with a FileDescriptorProto as the start) up to but not including the given nameIndex. -func (interp *interpreter) interpretField(mc *internal.MessageContext, msg protoreflect.Message, opt *descriptorpb.UninterpretedOption, nameIndex int, pathPrefix []interpretedEnclosingTag) (*interpretedOption, error) { +func (interp *interpreter) interpretField( + mc *internal.MessageContext, + msg protoreflect.Message, + opt *descriptorpb.UninterpretedOption, + nameIndex int, + pathPrefix []int32, +) (*sourceinfo.OptionSourceInfo, error) { var fld protoreflect.FieldDescriptor nm := opt.GetName()[nameIndex] node := interp.file.OptionNamePartNode(nm) @@ -1311,6 +877,7 @@ func (interp *interpreter) interpretField(mc *internal.MessageContext, msg proto mc, nm.GetNamePart(), msg.Descriptor().FullName()) } } + pathPrefix = append(pathPrefix, int32(fld.Number())) if len(opt.GetName()) > nameIndex+1 { nextnm := opt.GetName()[nameIndex+1] @@ -1344,228 +911,191 @@ func (interp *interpreter) interpretField(mc *internal.MessageContext, msg proto msg.Set(fld, fldVal) } // recurse to set next part of name - enclosing := interpretedEnclosingTag{ - tag: int32(fld.Number()), - kind: fld.Kind(), - } - return interp.interpretField(mc, fdm, opt, nameIndex+1, append(pathPrefix, enclosing)) + return interp.interpretField(mc, fdm, opt, nameIndex+1, pathPrefix) } optNode := interp.file.OptionNode(opt) optValNode := optNode.GetValue() - var val interpretedFieldValue - var index int + var srcInfo *sourceinfo.OptionSourceInfo var err error if optValNode.Value() == nil { - // We don't have an AST, so we get the value from the uninterpreted option proto. 
- // It's okay that we don't populate index as it is used to populate source code info, - // which can't be done without an AST. - val, err = interp.setOptionFieldFromProto(mc, msg, fld, node, opt, optValNode) + err = interp.setOptionFieldFromProto(mc, msg, fld, node, opt, optValNode) + srcInfoVal := newSrcInfo(pathPrefix, nil) + srcInfo = &srcInfoVal } else { - val, index, err = interp.setOptionField(mc, msg, fld, node, optValNode, false) + srcInfo, err = interp.setOptionField(mc, msg, fld, node, optValNode, false, pathPrefix) } if err != nil { return nil, interp.reporter.HandleError(err) } - return &interpretedOption{ - pathPrefix: pathPrefix, - interpretedField: interpretedField{ - node: optNode, - number: int32(fld.Number()), - index: int32(index), - kind: fld.Kind(), - repeated: fld.Cardinality() == protoreflect.Repeated, - value: val, - // NB: don't set packed here in a top-level option - // (only values in message literals will be serialized - // in packed format) - }, - }, nil + return srcInfo, nil } // setOptionField sets the value for field fld in the given message msg to the value represented // by AST node val. The given name is the AST node that corresponds to the name of fld. On success, // it returns additional metadata about the field that was set. 
-func (interp *interpreter) setOptionField(mc *internal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, name ast.Node, val ast.ValueNode, insideMsgLiteral bool) (interpretedFieldValue, int, error) { +func (interp *interpreter) setOptionField( + mc *internal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + name ast.Node, + val ast.ValueNode, + insideMsgLiteral bool, + pathPrefix []int32, +) (*sourceinfo.OptionSourceInfo, error) { v := val.Value() if sl, ok := v.([]ast.ValueNode); ok { // handle slices a little differently than the others if fld.Cardinality() != protoreflect.Repeated { - return interpretedFieldValue{}, 0, reporter.Errorf(interp.nodeInfo(val), "%vvalue is an array but field is not repeated", mc) + return nil, reporter.Errorf(interp.nodeInfo(val), "%vvalue is an array but field is not repeated", mc) } origPath := mc.OptAggPath defer func() { mc.OptAggPath = origPath }() - var resVal listValue - var resMsgVals [][]*interpretedField + childVals := make([]sourceinfo.OptionSourceInfo, len(sl)) var firstIndex int + if fld.IsMap() { + firstIndex = msg.Get(fld).Map().Len() + } else { + firstIndex = msg.Get(fld).List().Len() + } for index, item := range sl { mc.OptAggPath = fmt.Sprintf("%s[%d]", origPath, index) - value, err := interp.fieldValue(mc, msg, fld, item, insideMsgLiteral) + value, srcInfo, err := interp.fieldValue(mc, msg, fld, item, insideMsgLiteral, append(pathPrefix, int32(firstIndex+index))) if err != nil { - return interpretedFieldValue{}, 0, err + return nil, err } if fld.IsMap() { mv := msg.Mutable(fld).Map() - if index == 0 { - firstIndex = mv.Len() - } - setMapEntry(fld, msg, mv, &value) + setMapEntry(fld, msg, mv, value.Message()) } else { lv := msg.Mutable(fld).List() - if index == 0 { - firstIndex = lv.Len() - } - lv.Append(value.val) - } - resVal = append(resVal, value.val) - if value.msgVal != nil { - resMsgVals = append(resMsgVals, value.msgVal) + lv.Append(value) } + 
childVals[index] = srcInfo } - return interpretedFieldValue{ - isList: true, - val: protoreflect.ValueOfList(&resVal), - msgListVal: resMsgVals, - }, firstIndex, nil + srcInfo := newSrcInfo(append(pathPrefix, int32(firstIndex)), &sourceinfo.ArrayLiteralSourceInfo{Elements: childVals}) + return &srcInfo, nil } - value, err := interp.fieldValue(mc, msg, fld, val, insideMsgLiteral) + if fld.IsMap() { + pathPrefix = append(pathPrefix, int32(msg.Get(fld).Map().Len())) + } else if fld.IsList() { + pathPrefix = append(pathPrefix, int32(msg.Get(fld).List().Len())) + } + + value, srcInfo, err := interp.fieldValue(mc, msg, fld, val, insideMsgLiteral, pathPrefix) if err != nil { - return interpretedFieldValue{}, 0, err + return nil, err } if ood := fld.ContainingOneof(); ood != nil { existingFld := msg.WhichOneof(ood) if existingFld != nil && existingFld.Number() != fld.Number() { - return interpretedFieldValue{}, 0, reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) + return nil, reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) } } - var index int switch { case fld.IsMap(): mv := msg.Mutable(fld).Map() - index = mv.Len() - setMapEntry(fld, msg, mv, &value) + setMapEntry(fld, msg, mv, value.Message()) case fld.IsList(): lv := msg.Mutable(fld).List() - index = lv.Len() - lv.Append(value.val) + lv.Append(value) default: if msg.Has(fld) { - return interpretedFieldValue{}, 0, reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) + return nil, reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) } - msg.Set(fld, value.val) + msg.Set(fld, value) } - return value, index, nil + return &srcInfo, nil } // setOptionFieldFromProto sets the value for field fld in the given message msg to the value // represented by the given uninterpreted option. 
The given ast.Node, if non-nil, will be used // to report source positions in error messages. On success, it returns additional metadata // about the field that was set. -func (interp *interpreter) setOptionFieldFromProto(mc *internal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, name ast.Node, opt *descriptorpb.UninterpretedOption, node ast.Node) (interpretedFieldValue, error) { +func (interp *interpreter) setOptionFieldFromProto( + mc *internal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + name ast.Node, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) error { k := fld.Kind() - var value interpretedFieldValue + var value protoreflect.Value switch k { case protoreflect.EnumKind: num, _, err := interp.enumFieldValueFromProto(mc, fld.Enum(), opt, node) if err != nil { - return interpretedFieldValue{}, err + return err } - value = interpretedFieldValue{val: protoreflect.ValueOfEnum(num)} + value = protoreflect.ValueOfEnum(num) case protoreflect.MessageKind, protoreflect.GroupKind: if opt.AggregateValue == nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(node), "%vexpecting message, got %s", mc, optionValueKind(opt)) + return reporter.Errorf(interp.nodeInfo(node), "%vexpecting message, got %s", mc, optionValueKind(opt)) } // We must parse the text format from the aggregate value string - fmd := fld.Message() - tmpMsg := dynamicpb.NewMessage(fmd) + var elem protoreflect.Message + switch { + case fld.IsMap(): + elem = dynamicpb.NewMessage(fld.Message()) + case fld.IsList(): + elem = msg.Get(fld).List().NewElement().Message() + default: + elem = msg.NewField(fld).Message() + } err := prototext.UnmarshalOptions{ Resolver: &msgLiteralResolver{interp: interp, pkg: fld.ParentFile().Package()}, AllowPartial: true, - }.Unmarshal([]byte(opt.GetAggregateValue()), tmpMsg) - if err != nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(node), "%vfailed to parse 
message literal %w", mc, err) - } - msgData, err := proto.MarshalOptions{ - AllowPartial: true, - }.Marshal(tmpMsg) + }.Unmarshal([]byte(opt.GetAggregateValue()), elem.Interface()) if err != nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(node), "%vfailed to serialize data from message literal %w", mc, err) + return reporter.Errorf(interp.nodeInfo(node), "%vfailed to parse message literal %w", mc, err) } - var data []byte - if k == protoreflect.GroupKind { - data = protowire.AppendTag(data, fld.Number(), protowire.StartGroupType) - data = append(data, msgData...) - data = protowire.AppendTag(data, fld.Number(), protowire.EndGroupType) - } else { - data = protowire.AppendTag(nil, fld.Number(), protowire.BytesType) - data = protowire.AppendBytes(data, msgData) - } - // NB: At this point, the serialized fields may no longer be in the same - // order as in the text literal. So for this case, the linker result's - // CanonicalProto won't be in *quite* the right order. 
¯\_(ツ)_/¯ - value = interpretedFieldValue{preserialized: data} + value = protoreflect.ValueOfMessage(elem) default: v, err := interp.scalarFieldValueFromProto(mc, descriptorpb.FieldDescriptorProto_Type(k), opt, node) if err != nil { - return interpretedFieldValue{}, err + return err } - value = interpretedFieldValue{val: protoreflect.ValueOf(v)} + value = protoreflect.ValueOf(v) } if ood := fld.ContainingOneof(); ood != nil { existingFld := msg.WhichOneof(ood) if existingFld != nil && existingFld.Number() != fld.Number() { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) + return reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) } } switch { - case value.preserialized != nil: - if !fld.IsList() && !fld.IsMap() && msg.Has(fld) { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) - } - // We have to merge the bytes for this field into the message. - // TODO: if a map field, error if key for this entry already set? 
- err := proto.UnmarshalOptions{ - Resolver: &msgLiteralResolver{interp: interp, pkg: fld.ParentFile().Package()}, - AllowPartial: true, - Merge: true, - }.Unmarshal(value.preserialized, msg.Interface()) - if err != nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(name), "%v failed to set value for field %v: %w", mc, fieldName(fld), err) - } + case fld.IsMap(): + mv := msg.Mutable(fld).Map() + setMapEntry(fld, msg, mv, value.Message()) case fld.IsList(): - msg.Mutable(fld).List().Append(value.val) + msg.Mutable(fld).List().Append(value) default: if msg.Has(fld) { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) + return reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) } - msg.Set(fld, value.val) + msg.Set(fld, value) } - return value, nil + return nil } -func setMapEntry(fld protoreflect.FieldDescriptor, msg protoreflect.Message, mapVal protoreflect.Map, value *interpretedFieldValue) { - entry := value.val.Message() +func setMapEntry( + fld protoreflect.FieldDescriptor, + msg protoreflect.Message, + mapVal protoreflect.Map, + entry protoreflect.Message, +) { keyFld, valFld := fld.MapKey(), fld.MapValue() - // if an entry is missing a key or value, we add in an explicit - // zero value to msgVals to match protoc (which also odds these - // in even if not present in source) - if !entry.Has(keyFld) { - // put key before value - value.msgVal = append(append(([]*interpretedField)(nil), zeroValue(keyFld)), value.msgVal...) 
- } - if !entry.Has(valFld) { - value.msgVal = append(value.msgVal, zeroValue(valFld)) - } key := entry.Get(keyFld) val := entry.Get(valFld) if fld.MapValue().Kind() == protoreflect.MessageKind { @@ -1594,84 +1124,6 @@ func setMapEntry(fld protoreflect.FieldDescriptor, msg protoreflect.Message, map mapVal.Set(key.MapKey(), val) } -// zeroValue returns the zero value for the field types as a *interpretedField. -// The given fld must NOT be a repeated field. -func zeroValue(fld protoreflect.FieldDescriptor) *interpretedField { - var val protoreflect.Value - var msgVal []*interpretedField - switch fld.Kind() { - case protoreflect.MessageKind, protoreflect.GroupKind: - // needs to be non-nil, but empty - msgVal = []*interpretedField{} - msg := dynamicpb.NewMessage(fld.Message()) - val = protoreflect.ValueOfMessage(msg) - case protoreflect.EnumKind: - val = protoreflect.ValueOfEnum(0) - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: - val = protoreflect.ValueOfInt32(0) - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: - val = protoreflect.ValueOfUint32(0) - case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - val = protoreflect.ValueOfInt64(0) - case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - val = protoreflect.ValueOfUint64(0) - case protoreflect.BoolKind: - val = protoreflect.ValueOfBool(false) - case protoreflect.FloatKind: - val = protoreflect.ValueOfFloat32(0) - case protoreflect.DoubleKind: - val = protoreflect.ValueOfFloat64(0) - case protoreflect.BytesKind: - val = protoreflect.ValueOfBytes(nil) - case protoreflect.StringKind: - val = protoreflect.ValueOfString("") - } - return &interpretedField{ - number: int32(fld.Number()), - kind: fld.Kind(), - value: interpretedFieldValue{ - val: val, - msgVal: msgVal, - }, - } -} - -type listValue []protoreflect.Value - -var _ protoreflect.List = (*listValue)(nil) - -func (l *listValue) Len() int { - return len(*l) -} - -func (l *listValue) 
Get(i int) protoreflect.Value { - return (*l)[i] -} - -func (l *listValue) Set(i int, value protoreflect.Value) { - (*l)[i] = value -} - -func (l *listValue) Append(value protoreflect.Value) { - *l = append(*l, value) -} - -func (l *listValue) AppendMutable() protoreflect.Value { - panic("AppendMutable not supported") -} - -func (l *listValue) Truncate(i int) { - *l = (*l)[:i] -} - -func (l *listValue) NewElement() protoreflect.Value { - panic("NewElement not supported") -} - -func (l *listValue) IsValid() bool { - return true -} - type msgLiteralResolver struct { interp *interpreter pkg protoreflect.FullName @@ -1774,15 +1226,22 @@ func optionValueKind(opt *descriptorpb.UninterpretedOption) string { // fieldValue computes a compile-time value (constant or list or message literal) for the given // AST node val. The value in val must be assignable to the field fld. -func (interp *interpreter) fieldValue(mc *internal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, val ast.ValueNode, insideMsgLiteral bool) (interpretedFieldValue, error) { +func (interp *interpreter) fieldValue( + mc *internal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + val ast.ValueNode, + insideMsgLiteral bool, + pathPrefix []int32, +) (protoreflect.Value, sourceinfo.OptionSourceInfo, error) { k := fld.Kind() switch k { case protoreflect.EnumKind: num, _, err := interp.enumFieldValue(mc, fld.Enum(), val, insideMsgLiteral) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - return interpretedFieldValue{val: protoreflect.ValueOfEnum(num)}, nil + return protoreflect.ValueOfEnum(num), newSrcInfo(pathPrefix, nil), nil case protoreflect.MessageKind, protoreflect.GroupKind: v := val.Value() @@ -1800,22 +1259,28 @@ func (interp *interpreter) fieldValue(mc *internal.MessageContext, msg protorefl // Normal message field childMsg = msg.NewField(fld).Message() } - return 
interp.messageLiteralValue(mc, aggs, childMsg) + return interp.messageLiteralValue(mc, aggs, childMsg, pathPrefix) } - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(val), "%vexpecting message, got %s", mc, valueKind(v)) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(val), "%vexpecting message, got %s", mc, valueKind(v)) default: v, err := interp.scalarFieldValue(mc, descriptorpb.FieldDescriptorProto_Type(k), val, insideMsgLiteral) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - return interpretedFieldValue{val: protoreflect.ValueOf(v)}, nil + return protoreflect.ValueOf(v), newSrcInfo(pathPrefix, nil), nil } } // enumFieldValue resolves the given AST node val as an enum value descriptor. If the given // value is not a valid identifier (or number if allowed), an error is returned instead. -func (interp *interpreter) enumFieldValue(mc *internal.MessageContext, ed protoreflect.EnumDescriptor, val ast.ValueNode, allowNumber bool) (protoreflect.EnumNumber, protoreflect.Name, error) { +func (interp *interpreter) enumFieldValue( + mc *internal.MessageContext, + ed protoreflect.EnumDescriptor, + val ast.ValueNode, + allowNumber bool, +) (protoreflect.EnumNumber, protoreflect.Name, error) { v := val.Value() var num protoreflect.EnumNumber switch v := v.(type) { @@ -1858,7 +1323,12 @@ func (interp *interpreter) enumFieldValue(mc *internal.MessageContext, ed protor // enumFieldValueFromProto resolves the given uninterpreted option value as an enum value descriptor. // If the given value is not a valid identifier, an error is returned instead. 
-func (interp *interpreter) enumFieldValueFromProto(mc *internal.MessageContext, ed protoreflect.EnumDescriptor, opt *descriptorpb.UninterpretedOption, node ast.Node) (protoreflect.EnumNumber, protoreflect.Name, error) { +func (interp *interpreter) enumFieldValueFromProto( + mc *internal.MessageContext, + ed protoreflect.EnumDescriptor, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) (protoreflect.EnumNumber, protoreflect.Name, error) { // We don't have to worry about allowing numbers because numbers are never allowed // in uninterpreted values; they are only allowed inside aggregate values (i.e. // message literals). @@ -1877,7 +1347,12 @@ func (interp *interpreter) enumFieldValueFromProto(mc *internal.MessageContext, // scalarFieldValue resolves the given AST node val as a value whose type is assignable to a // field with the given fldType. -func (interp *interpreter) scalarFieldValue(mc *internal.MessageContext, fldType descriptorpb.FieldDescriptorProto_Type, val ast.ValueNode, insideMsgLiteral bool) (interface{}, error) { +func (interp *interpreter) scalarFieldValue( + mc *internal.MessageContext, + fldType descriptorpb.FieldDescriptorProto_Type, + val ast.ValueNode, + insideMsgLiteral bool, +) (interface{}, error) { v := val.Value() switch fldType { case descriptorpb.FieldDescriptorProto_TYPE_BOOL: @@ -2010,7 +1485,12 @@ func (interp *interpreter) scalarFieldValue(mc *internal.MessageContext, fldType // scalarFieldValue resolves the given uninterpreted option value as a value whose type is // assignable to a field with the given fldType. 
-func (interp *interpreter) scalarFieldValueFromProto(mc *internal.MessageContext, fldType descriptorpb.FieldDescriptorProto_Type, opt *descriptorpb.UninterpretedOption, node ast.Node) (interface{}, error) { +func (interp *interpreter) scalarFieldValueFromProto( + mc *internal.MessageContext, + fldType descriptorpb.FieldDescriptorProto_Type, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) (interface{}, error) { switch fldType { case descriptorpb.FieldDescriptorProto_TYPE_BOOL: if opt.IdentifierValue != nil { @@ -2159,16 +1639,18 @@ func descriptorType(m proto.Message) string { } } -func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fieldNodes []*ast.MessageFieldNode, msg protoreflect.Message) (interpretedFieldValue, error) { +func (interp *interpreter) messageLiteralValue( + mc *internal.MessageContext, + fieldNodes []*ast.MessageFieldNode, + msg protoreflect.Message, + pathPrefix []int32, +) (protoreflect.Value, sourceinfo.OptionSourceInfo, error) { fmd := msg.Descriptor() origPath := mc.OptAggPath defer func() { mc.OptAggPath = origPath }() - // NB: we don't want to leave this nil, even if the - // message is empty, because that indicates to - // caller that the result is not a message - flds := make([]*interpretedField, 0, len(fieldNodes)) + flds := make(map[*ast.MessageFieldNode]*sourceinfo.OptionSourceInfo, len(fieldNodes)) for _, fieldNode := range fieldNodes { if origPath == "" { mc.OptAggPath = fieldNode.Name.Value() @@ -2177,10 +1659,12 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel } if fieldNode.Name.IsAnyTypeReference() { if len(fieldNodes) > 1 { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vany type references cannot be repeated or mixed with other fields", mc) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vany type references cannot be repeated or 
mixed with other fields", mc) } if fmd.FullName() != "google.protobuf.Any" { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vtype references are only allowed for google.protobuf.Any, but this type is %s", mc, fmd.FullName()) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vtype references are only allowed for google.protobuf.Any, but this type is %s", mc, fmd.FullName()) } urlPrefix := fieldNode.Name.URLPrefix.AsIdentifier() msgName := fieldNode.Name.Name.AsIdentifier() @@ -2193,61 +1677,45 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel // file's transitive closure to find the named message, since that // is what protoc does. if urlPrefix != "type.googleapis.com" && urlPrefix != "type.googleprod.com" { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vcould not resolve type reference %s", mc, fullURL) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vcould not resolve type reference %s", mc, fullURL) } anyFields, ok := fieldNode.Val.Value().([]*ast.MessageFieldNode) if !ok { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vtype references for google.protobuf.Any must have message literal value", mc) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vtype references for google.protobuf.Any must have message literal value", mc) } anyMd := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, string(msgName)) if anyMd == nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vcould not resolve type reference %s", mc, fullURL) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + 
reporter.Errorf(interp.nodeInfo(fieldNode.Name.URLPrefix), "%vcould not resolve type reference %s", mc, fullURL) } // parse the message value - msgVal, err := interp.messageLiteralValue(mc, anyFields, dynamicpb.NewMessage(anyMd)) + msgVal, valueSrcInfo, err := interp.messageLiteralValue(mc, anyFields, dynamicpb.NewMessage(anyMd), append(pathPrefix, internal.AnyValueTag)) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - // Any is defined with two fields: - // string type_url = 1 - // bytes value = 2 - typeURLDescriptor := fmd.Fields().ByNumber(1) + typeURLDescriptor := fmd.Fields().ByNumber(internal.AnyTypeURLTag) if typeURLDescriptor == nil || typeURLDescriptor.Kind() != protoreflect.StringKind { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set type_url string field on Any: %w", mc, err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set type_url string field on Any: %w", mc, err) } typeURLVal := protoreflect.ValueOfString(fullURL) msg.Set(typeURLDescriptor, typeURLVal) - valueDescriptor := fmd.Fields().ByNumber(2) + valueDescriptor := fmd.Fields().ByNumber(internal.AnyValueTag) if valueDescriptor == nil || valueDescriptor.Kind() != protoreflect.BytesKind { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set value bytes field on Any: %w", mc, err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set value bytes field on Any: %w", mc, err) } - b, err := appendOptionBytes(nil, msgVal.msgVal) + b, err := (proto.MarshalOptions{Deterministic: true}).Marshal(msgVal.Message().Interface()) if err != nil { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vfailed to serialize message value: %w", mc, err) + return 
protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vfailed to serialize message value: %w", mc, err) } msg.Set(valueDescriptor, protoreflect.ValueOfBytes(b)) - msgVal.preserialized = asMessageBytes(2, b) // no need to recompute this for canonical option bytes - flds = []*interpretedField{ - { - node: fieldNode.Name, - number: 1, - kind: protoreflect.StringKind, - value: interpretedFieldValue{ - val: typeURLVal, - }, - }, - { - node: fieldNode.Val, - number: 2, - // Technically this field is a "bytes" kind. But the actual - // byte output is the same, and this way we can defer the - // computation of those bytes until later if needed. - kind: protoreflect.MessageKind, - value: msgVal, - }, - } + flds[fieldNode] = &valueSrcInfo } else { var ffld protoreflect.FieldDescriptor var err error @@ -2278,7 +1746,8 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel // We only fail when this really looks like a group since we need to be // able to use the field name for fields in editions files that use the // delimited message encoding and don't use proto2 group naming. 
- return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfield %s not found (did you mean the group named %s?)", mc, fieldNode.Name.Value(), ffld.Message().Name()) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfield %s not found (did you mean the group named %s?)", mc, fieldNode.Name.Value(), ffld.Message().Name()) } if ffld == nil { err = protoregistry.NotFound @@ -2295,39 +1764,33 @@ func (interp *interpreter) messageLiteralValue(mc *internal.MessageContext, fiel } } if errors.Is(err, protoregistry.NotFound) { - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Name), - "%vfield %s not found", mc, string(fieldNode.Name.Name.AsIdentifier())) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfield %s not found", mc, string(fieldNode.Name.Name.AsIdentifier())) } else if err != nil { - return interpretedFieldValue{}, reporter.Error(interp.nodeInfo(fieldNode.Name), err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Error(interp.nodeInfo(fieldNode.Name), err) } if fieldNode.Sep == nil && ffld.Message() == nil { // If there is no separator, the field type should be a message. - // Otherwise it is an error in the text format. - return interpretedFieldValue{}, reporter.Errorf(interp.nodeInfo(fieldNode.Val), "syntax error: unexpected value, expecting ':'") + // Otherwise, it is an error in the text format. 
+ return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "syntax error: unexpected value, expecting ':'") } - res, index, err := interp.setOptionField(mc, msg, ffld, fieldNode.Name, fieldNode.Val, true) + srcInfo, err := interp.setOptionField(mc, msg, ffld, fieldNode.Name, fieldNode.Val, true, append(pathPrefix, int32(ffld.Number()))) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - flds = append(flds, &interpretedField{ - node: fieldNode, - number: int32(ffld.Number()), - index: int32(index), - kind: ffld.Kind(), - repeated: ffld.Cardinality() == protoreflect.Repeated, - packed: ffld.IsPacked(), - value: res, - }) - } - } - return interpretedFieldValue{ - val: protoreflect.ValueOfMessage(msg), - msgVal: flds, - }, nil + flds[fieldNode] = srcInfo + } + } + return protoreflect.ValueOfMessage(msg), + newSrcInfo(pathPrefix, &sourceinfo.MessageLiteralSourceInfo{Fields: flds}), + nil } -func asMessageBytes(tag protowire.Number, data []byte) []byte { - result := make([]byte, 0, protowire.SizeTag(tag)+protowire.SizeBytes(len(data))) - result = protowire.AppendTag(result, tag, protowire.BytesType) - return protowire.AppendBytes(result, data) +func newSrcInfo(path []int32, children sourceinfo.OptionChildrenSourceInfo) sourceinfo.OptionSourceInfo { + return sourceinfo.OptionSourceInfo{ + Path: internal.ClonePath(path), + Children: children, + } } diff --git a/options/options_test.go b/options/options_test.go index 60ae1677..e297feee 100644 --- a/options/options_test.go +++ b/options/options_test.go @@ -25,14 +25,12 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" - 
"google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/descriptorpb" "github.com/bufbuild/protocompile" @@ -41,6 +39,7 @@ import ( "github.com/bufbuild/protocompile/linker" "github.com/bufbuild/protocompile/options" "github.com/bufbuild/protocompile/parser" + "github.com/bufbuild/protocompile/protoutil" "github.com/bufbuild/protocompile/reporter" ) @@ -386,40 +385,33 @@ func TestOptionsEncoding(t *testing.T) { fdset := prototest.LoadDescriptorSet(t, descriptorSetFile, linker.ResolverFromFile(fds[0])) prototest.CheckFiles(t, res, fdset, false) - canonicalProto := res.CanonicalProto() actualFdset := &descriptorpb.FileDescriptorSet{ - File: []*descriptorpb.FileDescriptorProto{canonicalProto}, + File: []*descriptorpb.FileDescriptorProto{protoutil.ProtoFromFileDescriptor(res)}, } - actualData, err := proto.Marshal(actualFdset) - require.NoError(t, err) - // semantic check that unmarshalling the "canonical bytes" results - // in the same proto as when not using "canonical bytes" - protoData, err := proto.Marshal(canonicalProto) + // drum roll... make sure the descriptors we produce are semantically equivalent + // to those produced by protoc + expectedData, err := os.ReadFile(descriptorSetFile) require.NoError(t, err) - proto.Reset(canonicalProto) + expectedFdset := &descriptorpb.FileDescriptorSet{} uOpts := proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(fds[0])} - err = uOpts.Unmarshal(protoData, canonicalProto) - require.NoError(t, err) - require.Empty(t, cmp.Diff(res.FileDescriptorProto(), canonicalProto, protocmp.Transform()), "canonical proto != proto") - - // drum roll... 
make sure the bytes match the protoc output - expectedData, err := os.ReadFile(descriptorSetFile) + err = uOpts.Unmarshal(expectedData, expectedFdset) require.NoError(t, err) - if !bytes.Equal(actualData, expectedData) { + if !prototest.AssertMessagesEqual(t, expectedFdset, actualFdset, file) { outputDescriptorSetFile := strings.ReplaceAll(descriptorSetFile, ".proto", ".actual.proto") + actualData, err := proto.Marshal(actualFdset) + require.NoError(t, err) err = os.WriteFile(outputDescriptorSetFile, actualData, 0644) if err != nil { - t.Log("failed to write actual to file") + t.Logf("failed to write actual to file: %v", err) + } else { + t.Logf("wrote actual contents to %s", outputDescriptorSetFile) } - - t.Fatalf("descriptor set bytes not equal (created file %q with actual bytes)", outputDescriptorSetFile) } }) } } -//nolint:errcheck func TestInterpretOptionsWithoutAST(t *testing.T) { t.Parallel() @@ -464,22 +456,7 @@ func TestInterpretOptionsWithoutAST(t *testing.T) { fd := file.(linker.Result).FileDescriptorProto() fdFromNoAST := fromNoAST.(linker.Result).FileDescriptorProto() // final protos, with options interpreted, match - diff := cmp.Diff(fd, fdFromNoAST, protocmp.Transform()) - require.Empty(t, diff) - } - - // Also make sure the canonical bytes are correct - for _, file := range filesFromNoAST { - res := file.(linker.Result) - canonicalFd := res.CanonicalProto() - data, err := proto.Marshal(canonicalFd) - require.NoError(t, err) - fromCanonical := &descriptorpb.FileDescriptorProto{} - err = proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(file)}.Unmarshal(data, fromCanonical) - require.NoError(t, err) - origFd := res.FileDescriptorProto() - diff := cmp.Diff(origFd, fromCanonical, protocmp.Transform()) - require.Empty(t, diff) + prototest.AssertMessagesEqual(t, fd, fdFromNoAST, file.Path()) } } @@ -529,21 +506,6 @@ func TestInterpretOptionsWithoutASTNoOp(t *testing.T) { fd := file.(linker.Result).FileDescriptorProto() fdFromNoAST := 
fromNoAST.(linker.Result).FileDescriptorProto() // final protos, with options interpreted, match - diff := cmp.Diff(fd, fdFromNoAST, protocmp.Transform()) - require.Empty(t, diff) - } - - // Also make sure the canonical bytes are correct - for _, file := range filesFromNoAST { - res := file.(linker.Result) - canonicalFd := res.CanonicalProto() - data, err := proto.Marshal(canonicalFd) - require.NoError(t, err) - fromCanonical := &descriptorpb.FileDescriptorProto{} - err = proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(file)}.Unmarshal(data, fromCanonical) - require.NoError(t, err) - origFd := res.FileDescriptorProto() - diff := cmp.Diff(origFd, fromCanonical, protocmp.Transform()) - require.Empty(t, diff) + prototest.AssertMessagesEqual(t, fd, fdFromNoAST, file.Path()) } } diff --git a/sourceinfo/source_code_info.go b/sourceinfo/source_code_info.go index 65bfb5ad..d3620ce4 100644 --- a/sourceinfo/source_code_info.go +++ b/sourceinfo/source_code_info.go @@ -45,6 +45,11 @@ type OptionSourceInfo struct { // The source info path to this element. If this element represents a // declaration with an array-literal value, the last element of the // path is the index of the first item in the array. + // If the first element is negative, it indicates the number of path + // components to remove from the path to the relevant options. This is + // used for field pseudo-options, so that the path indicates a field on + // the descriptor, which is a parent of the options message (since that + // is how the pseudo-options are actually stored). 
Path []int32 // Children can be an *ArrayLiteralSourceInfo, a *MessageLiteralSourceInfo, // or nil, depending on whether the option's value is an @@ -132,7 +137,7 @@ func (e extraOptionLocationsOption) apply(info *sourceCodeInfo) { } func generateSourceInfoForFile(opts OptionIndex, sci *sourceCodeInfo, file *ast.FileNode) { - path := make([]int32, 0, 10) + path := make([]int32, 0, 16) sci.newLocWithoutComments(file, nil) @@ -168,7 +173,10 @@ func generateSourceInfoForFile(opts OptionIndex, sci *sourceCodeInfo, file *ast. generateSourceCodeInfoForEnum(opts, sci, child, append(path, internal.FileEnumsTag, enumIndex)) enumIndex++ case *ast.ExtendNode: - generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &msgIndex, append(path, internal.FileExtensionsTag), append(dup(path), internal.FileMessagesTag)) + extsPath := append(path, internal.FileExtensionsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here so that append can't mutate extsPath, since they may share storage + msgsPath := append(internal.ClonePath(path), internal.FileMessagesTag) + generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &msgIndex, extsPath, msgsPath) case *ast.ServiceNode: generateSourceCodeInfoForService(opts, sci, child, append(path, internal.FileServicesTag, svcIndex)) svcIndex++ @@ -257,6 +265,18 @@ func generateSourceInfoForOptionChildren(sci *sourceCodeInfo, n ast.ValueNode, p continue } fullPath := combinePathsForOption(pathPrefix, fieldInfo.Path) + locationNode := ast.Node(fieldNode) + if fieldNode.Name.IsAnyTypeReference() && fullPath[len(fullPath)-1] == internal.AnyValueTag { + // This is a special expanded Any. So also insert a location + // for the type URL field. 
+ typeURLPath := make([]int32, len(fullPath)) + copy(typeURLPath, fullPath) + typeURLPath[len(typeURLPath)-1] = internal.AnyTypeURLTag + sci.newLoc(fieldNode.Name, typeURLPath) + // And create the next location so it's just the value, + // not the full field definition. + locationNode = fieldNode.Val + } _, isArrayLiteral := fieldNode.Val.(*ast.ArrayLiteralNode) if !isArrayLiteral { // We don't include this with an array literal since the path @@ -264,7 +284,7 @@ func generateSourceInfoForOptionChildren(sci *sourceCodeInfo, n ast.ValueNode, p // it would be redundant with the child info we add next, and // it wouldn't be entirely correct since it only indicates the // index of the first element in the array (and not the others). - sci.newLoc(fieldNode, fullPath) + sci.newLoc(locationNode, fullPath) } generateSourceInfoForOptionChildren(sci, fieldNode.Val, pathPrefix, fullPath, fieldInfo.Children) } @@ -317,18 +337,24 @@ func generateSourceCodeInfoForMessage(opts OptionIndex, sci *sourceCodeInfo, n a generateSourceCodeInfoForField(opts, sci, child, append(path, internal.MessageFieldsTag, fieldIndex)) fieldIndex++ case *ast.GroupNode: - fldPath := path - fldPath = append(fldPath, internal.MessageFieldsTag, fieldIndex) + fldPath := append(path, internal.MessageFieldsTag, fieldIndex) //nolint:gocritic // intentionally creating new slice var generateSourceCodeInfoForField(opts, sci, child, fldPath) fieldIndex++ - generateSourceCodeInfoForMessage(opts, sci, child, fldPath, append(dup(path), internal.MessageNestedMessagesTag, nestedMsgIndex)) + // we clone the path here so that append can't mutate fldPath, since they may share storage + msgPath := append(internal.ClonePath(path), internal.MessageNestedMessagesTag, nestedMsgIndex) + generateSourceCodeInfoForMessage(opts, sci, child, fldPath, msgPath) nestedMsgIndex++ case *ast.MapFieldNode: generateSourceCodeInfoForField(opts, sci, child, append(path, internal.MessageFieldsTag, fieldIndex)) fieldIndex++ nestedMsgIndex++ case
*ast.OneofNode: - generateSourceCodeInfoForOneof(opts, sci, child, &fieldIndex, &nestedMsgIndex, append(path, internal.MessageFieldsTag), append(dup(path), internal.MessageNestedMessagesTag), append(dup(path), internal.MessageOneofsTag, oneofIndex)) + fldsPath := append(path, internal.MessageFieldsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here and below so that append ops can't mutate + // fldsPath or msgsPath, since they may otherwise share storage + msgsPath := append(internal.ClonePath(path), internal.MessageNestedMessagesTag) + ooPath := append(internal.ClonePath(path), internal.MessageOneofsTag, oneofIndex) + generateSourceCodeInfoForOneof(opts, sci, child, &fieldIndex, &nestedMsgIndex, fldsPath, msgsPath, ooPath) oneofIndex++ case *ast.MessageNode: generateSourceCodeInfoForMessage(opts, sci, child, nil, append(path, internal.MessageNestedMessagesTag, nestedMsgIndex)) @@ -337,7 +363,10 @@ func generateSourceCodeInfoForMessage(opts OptionIndex, sci *sourceCodeInfo, n a generateSourceCodeInfoForEnum(opts, sci, child, append(path, internal.MessageEnumsTag, nestedEnumIndex)) nestedEnumIndex++ case *ast.ExtendNode: - generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &nestedMsgIndex, append(path, internal.MessageExtensionsTag), append(dup(path), internal.MessageNestedMessagesTag)) + extsPath := append(path, internal.MessageExtensionsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here so that append can't mutate extsPath, since they may share storage + msgsPath := append(internal.ClonePath(path), internal.MessageNestedMessagesTag) + generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &nestedMsgIndex, extsPath, msgsPath) case *ast.ExtensionRangeNode: generateSourceCodeInfoForExtensionRanges(opts, sci, child, &extRangeIndex, append(path, internal.MessageExtensionRangesTag)) case *ast.ReservedNode: @@ -604,8 +633,6 @@ type sourceCodeInfo struct { } func
(sci *sourceCodeInfo) newLocWithoutComments(n ast.Node, path []int32) { - dup := make([]int32, len(path)) - copy(dup, path) var start, end ast.SourcePos if n == sci.file { // For files, we don't want to consider trailing EOF token @@ -628,7 +655,7 @@ func (sci *sourceCodeInfo) newLocWithoutComments(n ast.Node, path []int32) { start, end = info.Start(), info.End() } sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ - Path: dup, + Path: internal.ClonePath(path), Span: makeSpan(start, end), }) } @@ -636,11 +663,9 @@ func (sci *sourceCodeInfo) newLocWithoutComments(n ast.Node, path []int32) { func (sci *sourceCodeInfo) newLoc(n ast.Node, path []int32) { info := sci.file.NodeInfo(n) if !sci.extraComments { - dup := make([]int32, len(path)) - copy(dup, path) start, end := info.Start(), info.End() sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ - Path: dup, + Path: internal.ClonePath(path), Span: makeSpan(start, end), }) } else { @@ -701,13 +726,11 @@ func (sci *sourceCodeInfo) newLocWithGivenComments(nodeInfo ast.NodeInfo, detach detached[i] = sci.combineComments(cmts) } - dup := make([]int32, len(path)) - copy(dup, path) sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ LeadingDetachedComments: detached, LeadingComments: lead, TrailingComments: trail, - Path: dup, + Path: internal.ClonePath(path), Span: makeSpan(nodeInfo.Start(), nodeInfo.End()), }) } @@ -933,7 +956,3 @@ func (sci *sourceCodeInfo) combineComments(comments comments) string { } return buf.String() } - -func dup(p []int32) []int32 { - return append(([]int32)(nil), p...) 
-} diff --git a/sourceinfo/source_code_info_test.go b/sourceinfo/source_code_info_test.go index 431a5372..293bd937 100644 --- a/sourceinfo/source_code_info_test.go +++ b/sourceinfo/source_code_info_test.go @@ -166,7 +166,8 @@ func TestSourceCodeInfoOptions(t *testing.T) { // set to true to re-generate golden output file const regenerateGoldenOutputFile = false - generateSourceInfoText := func(filename string, mode protocompile.SourceInfoMode) string { + generateSourceInfoText := func(t *testing.T, filename string, mode protocompile.SourceInfoMode) string { + t.Helper() compiler := protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{"../internal/testdata"}, @@ -206,14 +207,14 @@ func TestSourceCodeInfoOptions(t *testing.T) { testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() - output := generateSourceInfoText(testCase.filename, testCase.mode) + output := generateSourceInfoText(t, testCase.filename, testCase.mode) baseName := strings.TrimSuffix(testCase.filename, ".proto") if regenerateGoldenOutputFile { err := os.WriteFile(fmt.Sprintf("testdata/%s.%s.txt", baseName, testCase.name), []byte(output), 0644) require.NoError(t, err) // also create a file with standard comments, as a useful demonstration of the differences - output := generateSourceInfoText(testCase.filename, protocompile.SourceInfoStandard) + output := generateSourceInfoText(t, testCase.filename, protocompile.SourceInfoStandard) err = os.WriteFile(fmt.Sprintf("testdata/%s.standard.txt", baseName), []byte(output), 0644) require.NoError(t, err) return