diff --git a/Makefile b/Makefile index 45a3ae02..70b0684d 100644 --- a/Makefile +++ b/Makefile @@ -168,6 +168,9 @@ internal/testdata/options/options2/test.protoset: $(PROTOC) internal/testdata/op internal/testdata/options/test_proto3.protoset: $(PROTOC) internal/testdata/options/test_proto3.proto cd $(@D) && $(PROTOC) --descriptor_set_out=$(@F) -I. $(filter-out protoc,$(^F)) +internal/testdata/options/test_editions.protoset: $(PROTOC) internal/testdata/options/test_editions.proto + cd $(@D) && $(PROTOC) --experimental_editions --descriptor_set_out=$(@F) -I. $(filter-out protoc,$(^F)) + .PHONY: test-descriptors test-descriptors: internal/testdata/all.protoset test-descriptors: internal/testdata/desc_test_complex.protoset @@ -177,5 +180,7 @@ test-descriptors: internal/testdata/descriptor_impl_tests.protoset test-descriptors: internal/testdata/descriptor_editions_impl_tests.protoset test-descriptors: internal/testdata/editions/all.protoset test-descriptors: internal/testdata/source_info.protoset +test-descriptors: internal/testdata/options/options.protoset test-descriptors: internal/testdata/options/test.protoset test-descriptors: internal/testdata/options/test_proto3.protoset +test-descriptors: internal/testdata/options/test_editions.protoset diff --git a/internal/benchmarks/benchmark_test.go b/internal/benchmarks/benchmark_test.go index 7710ef7a..53017de3 100644 --- a/internal/benchmarks/benchmark_test.go +++ b/internal/benchmarks/benchmark_test.go @@ -46,7 +46,6 @@ import ( "github.com/kralicky/protocompile" "github.com/kralicky/protocompile/ast" "github.com/kralicky/protocompile/internal/protoc" - "github.com/kralicky/protocompile/linker" "github.com/kralicky/protocompile/parser" "github.com/kralicky/protocompile/parser/fastscan" "github.com/kralicky/protocompile/protoutil" @@ -235,7 +234,7 @@ func downloadAndExpand(url, targetDir string) (e error) { } func BenchmarkGoogleapisProtocompile(b *testing.B) { - benchmarkGoogleapisProtocompile(b, false, func() *protocompile.Compiler { + benchmarkGoogleapisProtocompile(b, func() *protocompile.Compiler { return &protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -246,20 +245,8 @@ func BenchmarkGoogleapisProtocompile(b *testing.B) { }) } -func BenchmarkGoogleapisProtocompileCanonical(b *testing.B) { - benchmarkGoogleapisProtocompile(b, true, func() *protocompile.Compiler { - return &protocompile.Compiler{ - Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ - ImportPaths: []string{googleapisDir}, - }), - SourceInfoMode: protocompile.SourceInfoStandard, - // leave MaxParallelism unset to let it use all cores available - } - }) -} - func BenchmarkGoogleapisProtocompileNoSourceInfo(b *testing.B) { - benchmarkGoogleapisProtocompile(b, false, func() *protocompile.Compiler { + benchmarkGoogleapisProtocompile(b, func() *protocompile.Compiler { return &protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{googleapisDir}, @@ -270,13 +257,13 @@ func BenchmarkGoogleapisProtocompileNoSourceInfo(b *testing.B) { }) } -func benchmarkGoogleapisProtocompile(b *testing.B, canonicalBytes bool, factory func() *protocompile.Compiler) { +func benchmarkGoogleapisProtocompile(b *testing.B, factory func() *protocompile.Compiler) { for i := 0; i < b.N; i++ { - benchmarkProtocompile(b, factory(), googleapisSources, canonicalBytes) + benchmarkProtocompile(b, factory(), googleapisSources) } } -func benchmarkProtocompile(b *testing.B, c *protocompile.Compiler, sources []string, canonicalBytes bool) { +func benchmarkProtocompile(b *testing.B, c *protocompile.Compiler, sources []string) { resolvedSources := make([]protocompile.ResolvedPath, 0, len(sources)) for _, source := range sources { resolvedSources = append(resolvedSources, protocompile.ResolvedPath(source)) @@ -286,11 +273,7 @@ func benchmarkProtocompile(b *testing.B, c *protocompile.Compiler, sources []str var fdSet descriptorpb.FileDescriptorSet fdSet.File = make([]*descriptorpb.FileDescriptorProto, len(fds.Files)) for i, fd := range fds.Files { - if canonicalBytes { - fdSet.File[i] = fd.(linker.Result).CanonicalProto() - } else { - fdSet.File[i] = protoutil.ProtoFromFileDescriptor(fd) - } + fdSet.File[i] = protoutil.ProtoFromFileDescriptor(fd) } // protoc is writing output to file descriptor set, so we should, too writeToNull(b, &fdSet) @@ -488,7 +471,7 @@ func BenchmarkGoogleapisProtocompileSingleThreaded(b *testing.B) { // need to run a single-threaded compile MaxParallelism: 1, } - benchmarkProtocompile(b, c, googleapisSources, false) + benchmarkProtocompile(b, c, googleapisSources) } }) } diff --git a/internal/testdata/all.protoset b/internal/testdata/all.protoset new file mode 100644 index 00000000..6a41ac30 Binary files /dev/null and b/internal/testdata/all.protoset differ diff --git a/internal/testdata/formatting_tests.proto b/internal/testdata/formatting_tests.proto index a350be3b..65e4e4aa 100644 --- a/internal/testdata/formatting_tests.proto +++ b/internal/testdata/formatting_tests.proto @@ -2,8 +2,8 @@ syntax = "proto2"; package testprotos; -import "github.com/kralicky/protocompile/internal/testdata/desc_test_complex.proto"; -import "github.com/kralicky/protocompile/internal/testdata/desc_test_options.proto"; +import "desc_test_complex.proto"; +import "desc_test_options.proto"; import "google/protobuf/descriptor.proto"; option go_package = "github.com/kralicky/protocompile/internal/testprotos"; @@ -66,7 +66,7 @@ extend google.protobuf.FieldOptions { message KeywordCollisionOptions { optional uint64 id = 1 [ - (foo.bar.float) = inf, + // (foo.bar.float) = inf, (foo.bar.syntax) = true, (foo.bar.import) = true, (foo.bar.public) = true, diff --git a/internal/testdata/options/options.proto b/internal/testdata/options/options.proto index 1bc19383..b3aa48f4 100644 --- a/internal/testdata/options/options.proto +++ b/internal/testdata/options/options.proto @@ -7,8 +7,32 @@ syntax = "proto2"; package bufbuild.protocompile.test3; +import "google/protobuf/any.proto"; import "google/protobuf/descriptor.proto"; +option (any) = { + [type.googleapis.com/bufbuild.protocompile.test3.AllTypes]: { + pr_i32: [0, 1, 2, 3], + str: "foo", + }, +}; + +option (filegroup) = { + array: [1, 2, 3, 4, 5, 6, 7, 8], +}; +option (filegroup).i32 = -123; +option (filegroup).s = "abc"; + +option (filegroups) = { + s: "abc", + i32: 123, +}; + +option (filegroups) = { + s: "xyz", + i32: 456, +}; + message Extendable { optional string foo = 1; optional int32 bar = 2; @@ -186,3 +210,21 @@ extend google.protobuf.MethodOptions { optional AllTypes rpc = 3001; repeated int32 rpc_i = 3002 [packed = true]; } + +// Also test encoding of options where option is defined in +// the same file as the usage. This makes sure we correctly +// defer computation of option bytes until we know enough +// to do so correctly (since option bytes encoding could +// depend on interpretation of other options in this file). +extend google.protobuf.FileOptions { + optional group FileGroup = 1003 { + optional string s = 1; + optional int32 i32 = 2; + repeated int32 array = 3 [packed = true]; + } + repeated group FileGroups = 1004 { + optional string s = 1; + optional int32 i32 = 2; + } + optional google.protobuf.Any any = 1005; +} diff --git a/internal/testdata/options/options.protoset b/internal/testdata/options/options.protoset new file mode 100644 index 00000000..0b4b52f8 Binary files /dev/null and b/internal/testdata/options/options.protoset differ diff --git a/internal/testdata/options/test_editions.proto b/internal/testdata/options/test_editions.proto new file mode 100644 index 00000000..d9c38138 --- /dev/null +++ b/internal/testdata/options/test_editions.proto @@ -0,0 +1,83 @@ +edition = "2023"; + +package bufbuild.protocompile.test3.editions; + +import "google/protobuf/any.proto"; +import "google/protobuf/descriptor.proto"; + +option (delimited).foo.children = { + name: "abc-1", +}; +option (delimited).foo.children = { + name: "abc-2", +}; +option (delimited).foo.name = "abc"; +option (delimited).foo.val = VAL1; + +option (delimited).name = "123"; +option (delimited).other.name = "xyz"; +option (delimited).other.val = VAL0; + +option (delimiteds) = { + name: "ABC", + val: 1, +}; +option (delimiteds) = { + name: "XYZ", + val: 1, +}; +option (delimiteds) = { + name: "1234", + val: 0, +}; + +option (other) = { + name: "123", + val: VAL0, + Foo: , + // NOTE: We can't currently refer to children in here + // because referring to delimited-encoded fields whose + // name != lower-case(type-name) inside a message + // literal is currently broken in protoc :( + // https://github.com/protocolbuffers/protobuf/issues/16239 +}; + +option (others) = { + name: "123", + val: 0, +}; + +message Foo { + string name = 1; + Foo foo = 2 [ + (any) = { + [type.googleapis.com/bufbuild.protocompile.test3.editions.Foo]: { + Foo: { + name: "abc", + Foo: {name: "xyz"}, + }, + }, + }, + features.message_encoding = DELIMITED + ]; + Foo other = 3; + Val val = 4; + repeated Foo children = 5 [features.message_encoding = DELIMITED]; +} + +enum Val { + option features.enum_type = CLOSED; + VAL0 = 0; + VAL1 = 1; +} + +extend google.protobuf.FileOptions { + Foo delimited = 10101 [features.message_encoding = DELIMITED]; + Foo other = 10102; + repeated Foo delimiteds = 10103 [features.message_encoding = DELIMITED]; + repeated Foo others = 10104; +} + +extend google.protobuf.FieldOptions { + google.protobuf.Any any = 10101; +} diff --git a/internal/testdata/options/test_editions.protoset b/internal/testdata/options/test_editions.protoset new file mode 100644 index 00000000..4c3d39ca Binary files /dev/null and b/internal/testdata/options/test_editions.protoset differ diff --git a/internal/testdata/options/test_proto3.proto b/internal/testdata/options/test_proto3.proto index a34aba0b..df8cda39 100644 --- a/internal/testdata/options/test_proto3.proto +++ b/internal/testdata/options/test_proto3.proto @@ -15,51 +15,51 @@ syntax = "proto3"; package bufbuild.protocompile.test3; -import "options.proto"; import "google/protobuf/descriptor.proto"; +import "options.proto"; option go_package = "foo"; option java_package = "bar"; option (file) = { - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" - - oo_u32: 9876 - oo_f32: 1234 - oo_fl32: 1.2345e100 - oo_b: "\x00\x01\x02\x03" + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", + + oo_u32: 9876, + oo_f32: 1234, + oo_fl32: 1.2345e100, + oo_b: "\x00\x01\x02\x03", }; option (file).b = "\x00\x01\x02\x03"; option (file).flag = true; option (file).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (file).m_i32 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (file).m_i32 = { - key: -234 - value: 1 + key: -234, + value: 1, }; option (file).m_u32 = { - key: 234 - value: 1 + key: 234, + value: 1, }; option (file).m_u32 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (file).msg.(t) = { r_i32: [ @@ -67,53 +67,53 @@ option (file).msg.(t) = { 1, 2, 3 - ] + ], pr_i32: [ 0, 1, 2, 3 - ] + ], m_i32: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], r_u32: [ 0, 1, 2, 3 - ] + ], pr_u32: [ 0, 1, 2, 3 - ] + ], m_u32: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], }; option (file).msg.(t).msg.(t) = { - r_i32: 1 - r_i32: 2 - pr_i32: 1 - pr_i32: 2 - m_i32: {value: 0} - m_i32: {key: 123, value: 1} - m_i32: {key: 234, value: 2} - m_i32: {key: -345} - r_u32: 1 - r_u32: 2 - pr_u32: 1 - pr_u32: 2 - m_u32: {} - m_u32: {key: 234, value: 2} + r_i32: 1, + r_i32: 2, + pr_i32: 1, + pr_i32: 2, + m_i32: {value: 0}, + m_i32: {key: 123, value: 1}, + m_i32: {key: 234, value: 2}, + m_i32: {key: -345}, + r_u32: 1, + r_u32: 2, + pr_u32: 1, + pr_u32: 2, + m_u32: {}, + m_u32: {key: 234, value: 2}, }; option (file).msg.(t).msg.(t).msg.(t) = { - pr_i32: 1 - pr_u32: 1 + pr_i32: 1, + pr_u32: 1, }; option (file).pr_i32 = 1; option (file).pr_i32 = 0; @@ -122,14 +122,14 @@ option (file).pr_u32 = 1; option (file).r_i32 = 1; option (file).r_i32 = 0; option (file).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (file).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (file).r_u32 = 1; option (file).r_u32 = 0; @@ -140,23 +140,23 @@ option (file3).msg = { 1, 2, 3 - ] + ], u32: [ 0, 1, 2, 3 - ] + ], }; option (file3).msg.msg = { - i32: 1 - i32: 2 - u32: 1 - u32: 2 + i32: 1, + i32: 2, + u32: 1, + u32: 2, }; option (file3).msg.msg.msg = { - i32: 1 - u32: 1 + i32: 1, + u32: 1, }; option (file3).u32 = 0; option (file_i) = 1; @@ -242,24 +242,24 @@ message TestMessage { (fld_i) = 2, (fld_i) = 3, (fld) = { - i64: 0 - i32: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf64: 6 - sf32: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" - - oo_i32: -9876 - oo_f32: 1234 - oo_fl32: 1.2345e100 - oo_b: "\x00\x01\x02\x03" + i64: 0, + i32: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf64: 6, + sf32: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", + + oo_i32: -9876, + oo_f32: 1234, + oo_fl32: 1.2345e100, + oo_b: "\x00\x01\x02\x03", }, (fld).r_s32 = 0, @@ -270,19 +270,19 @@ message TestMessage { (fld3).s32 = 1, (fld).m_s32 = { - key: 123 - value: 0 + key: 123, + value: 0, }, (fld).m_s32 = { - key: -234 - value: 1 + key: -234, + value: 1, }, (fld).flag = true, (fld).b = "\x00\x01\x02\x03", (fld).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }, (fld).r_fl32 = 0, @@ -293,24 +293,24 @@ message TestMessage { (fld3).fl32 = 1, (fld).m_fl32 = { - key: "abc" - value: 0 + key: "abc", + value: 0, }, (fld).m_fl32 = { - key: "def" - value: 1 + key: "def", + value: 1, }, (fld).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }, (fld).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }, (fld).msg.(t) = { @@ -319,33 +319,33 @@ message TestMessage { 1, 2, 3 - ] + ], pr_s32: [ 0, 1, 2, 3 - ] + ], m_s32: [ {key: 123, value: 1}, {key: -234, value: 2} - ] + ], r_fl32: [ 0, 1, 2, 3 - ] + ], pr_fl32: [ 0, 1, 2, 3 - ] + ], m_fl32: [ {key: "foo", value: 1}, {key: "bar", value: 2} - ] + ], }, (fld3).msg = { s32: [ @@ -353,43 +353,43 @@ message TestMessage { 1, 2, 3 - ] + ], fl32: [ 0, 1, 2, 3 - ] + ], }, (fld).msg.(t).msg.(t) = { - r_s32: 1 - r_s32: 2 - pr_s32: 1 - pr_s32: 2 - m_s32: {value: 0} - m_s32: {key: 123, value: 1} - m_s32: {key: 234, value: 2} - m_s32: {key: -345} - r_fl32: 1 - r_fl32: 2 - pr_fl32: 1 - pr_fl32: 2 - m_fl32: {} - m_fl32: {key: "bar", value: -2.2222} + r_s32: 1, + r_s32: 2, + pr_s32: 1, + pr_s32: 2, + m_s32: {value: 0}, + m_s32: {key: 123, value: 1}, + m_s32: {key: 234, value: 2}, + m_s32: {key: -345}, + r_fl32: 1, + r_fl32: 2, + pr_fl32: 1, + pr_fl32: 2, + m_fl32: {}, + m_fl32: {key: "bar", value: -2.2222}, }, (fld3).msg.msg = { - s32: 1 - s32: 2 - fl32: 1 - fl32: 2 + s32: 1, + s32: 2, + fl32: 1, + fl32: 2, }, (fld).msg.(t).msg.(t).msg.(t) = { - pr_s32: 1 - pr_fl32: 2 + pr_s32: 1, + pr_fl32: 2, }, (fld3).msg.msg.msg = { - s32: 1 - fl32: 1 + s32: 1, + fl32: 1, } ]; @@ -403,19 +403,19 @@ message TestMessage { sint32 ss = 12; option (oo) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", }; option (oo).oo_i64 = -9876; @@ -431,19 +431,19 @@ message TestMessage { option (oo3).i64 = 0; option (oo).m_i64 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (oo).m_i64 = { - key: -234 - value: 1 + key: -234, + value: 1, }; option (oo).flag = true; option (oo).b = "\x00\x01\x02\x03"; option (oo).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (oo).r_u64 = 0; @@ -454,24 +454,24 @@ message TestMessage { option (oo3).u64 = 0; option (oo).m_u64 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (oo).m_u64 = { - key: 234 - value: 1 + key: 234, + value: 1, }; option (oo).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (oo).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (oo).msg.(t) = { @@ -480,33 +480,33 @@ message TestMessage { 1, 2, 3 - ] + ], pr_i64: [ 0, 1, 2, 3 - ] + ], m_i64: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], r_u64: [ 0, 1, 2, 3 - ] + ], pr_u64: [ 0, 1, 2, 3 - ] + ], m_u64: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], }; option (oo3).msg = { i64: [ @@ -514,60 +514,60 @@ message TestMessage { 1, 2, 3 - ] + ], u64: [ 0, 1, 2, 3 - ] + ], }; option (oo).msg.(t).msg.(t) = { - r_i64: 1 - r_i64: 2 - pr_i64: 1 - pr_i64: 2 - m_i64: {value: 0} - m_i64: {key: 123, value: 1} - m_i64: {key: 234, value: 2} - m_i64: {key: -345} - r_u64: 1 - r_u64: 2 - pr_u64: 1 - pr_u64: 2 - m_u64: {} - m_u64: {key: 234, value: 2} + r_i64: 1, + r_i64: 2, + pr_i64: 1, + pr_i64: 2, + m_i64: {value: 0}, + m_i64: {key: 123, value: 1}, + m_i64: {key: 234, value: 2}, + m_i64: {key: -345}, + r_u64: 1, + r_u64: 2, + pr_u64: 1, + pr_u64: 2, + m_u64: {}, + m_u64: {key: 234, value: 2}, }; option (oo3).msg.msg = { - i64: 1 - i64: 2 - u64: 1 - u64: 2 + i64: 1, + i64: 2, + u64: 1, + u64: 2, }; option (oo).msg.(t).msg.(t).msg.(t) = { - pr_i64: 1 - pr_u64: 1 + pr_i64: 1, + pr_u64: 1, }; option (oo3).msg.msg.msg = { - i64: 1 - u64: 2 + i64: 1, + u64: 2, }; } option (msg) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", }; option (msg).oo_s64 = -9876; @@ -583,19 +583,19 @@ message TestMessage { option (msg3).f32 = 0; option (msg).m_f32 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (msg).m_f32 = { - key: 234 - value: 1 + key: 234, + value: 1, }; option (msg).flag = true; option (msg).b = "\x00\x01\x02\x03"; option (msg).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (msg).r_sf32 = 0; @@ -606,24 +606,24 @@ message TestMessage { option (msg3).sf32 = 0; option (msg).m_sf32 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (msg).m_sf32 = { - key: -234 - value: 1 + key: -234, + value: 1, }; option (msg).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (msg).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (msg).msg.(t) = { @@ -632,33 +632,33 @@ message TestMessage { 1, 2, 3 - ] + ], pr_f32: [ 0, 1, 2, 3 - ] + ], m_f32: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], r_sf32: [ 0, 1, 2, 3 - ] + ], pr_sf32: [ 0, 1, 2, 3 - ] + ], m_sf32: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], }; option (msg3).msg = { f32: [ @@ -666,43 +666,43 @@ message TestMessage { 1, 2, 3 - ] + ], sf32: [ 0, 1, 2, 3 - ] + ], }; option (msg).msg.(t).msg.(t) = { - r_f32: 1 - r_f32: 2 - pr_f32: 1 - pr_f32: 2 - m_f32: {value: 0} - m_f32: {key: 123, value: 1} - m_f32: {key: 234, value: 2} - m_f32: {key: 345} - r_sf32: 1 - r_sf32: 2 - pr_sf32: 1 - pr_sf32: 2 - m_sf32: {} - m_sf32: {key: -234, value: -2} + r_f32: 1, + r_f32: 2, + pr_f32: 1, + pr_f32: 2, + m_f32: {value: 0}, + m_f32: {key: 123, value: 1}, + m_f32: {key: 234, value: 2}, + m_f32: {key: 345}, + r_sf32: 1, + r_sf32: 2, + pr_sf32: 1, + pr_sf32: 2, + m_sf32: {}, + m_sf32: {key: -234, value: -2}, }; option (msg3).msg.msg = { - f32: 1 - f32: 2 - sf32: 1 - sf32: 2 + f32: 1, + f32: 2, + sf32: 1, + sf32: 2, }; option (msg).msg.(t).msg.(t).msg.(t) = { - pr_f32: 1 - pr_sf32: 1 + pr_f32: 1, + pr_sf32: 1, }; option (msg3).msg.msg.msg = { - f32: 1 - sf32: 1 + f32: 1, + sf32: 1, }; option message_set_wire_format = false; @@ -726,19 +726,19 @@ enum TestEnum { (env_i) = 3, (env) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", }, (env).oo_u32 = 9876, @@ -754,19 +754,19 @@ enum TestEnum { (env3).s64 = 0, (env).m_s64 = { - key: 123 - value: 0 + key: 123, + value: 0, }, (env).m_s64 = { - key: -234 - value: 1 + key: -234, + value: 1, }, (env).flag = true, (env).b = "\x00\x01\x02\x03", (env).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }, (env).r_fl64 = 0, @@ -777,24 +777,24 @@ enum TestEnum { (env3).fl64 = 0, (env).m_fl64 = { - key: "abc" - value: 0 + key: "abc", + value: 0, }, (env).m_fl64 = { - key: "def" - value: 1 + key: "def", + value: 1, }, (env).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }, (env).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }, (env).msg.(t) = { @@ -803,33 +803,33 @@ enum TestEnum { 1, 2, 3 - ] + ], pr_s64: [ 0, 1, 2, 3 - ] + ], m_s64: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], r_fl64: [ 0, 1, 2, 3 - ] + ], pr_fl64: [ 0, 1, 2, 3 - ] + ], m_fl64: [ {key: "foo", value: 1}, {key: "bar", value: 2} - ] + ], }, (env3).msg = { s64: [ @@ -837,65 +837,65 @@ enum TestEnum { 1, 2, 3 - ] + ], fl64: [ 0, 1, 2, 3 - ] + ], }, (env).msg.(t).msg.(t) = { - r_s64: 1 - r_s64: 2 - pr_s64: 1 - pr_s64: 2 - m_s64: {value: 0} - m_s64: {key: 123, value: 1} - m_s64: {key: 234, value: 2} - m_s64: {key: -345} - r_fl64: 1 - r_fl64: 2 - pr_fl64: 1 - pr_fl64: 2 - m_fl64: {} - m_fl64: {key: "bar", value: 2} + r_s64: 1, + r_s64: 2, + pr_s64: 1, + pr_s64: 2, + m_s64: {value: 0}, + m_s64: {key: 123, value: 1}, + m_s64: {key: 234, value: 2}, + m_s64: {key: -345}, + r_fl64: 1, + r_fl64: 2, + pr_fl64: 1, + pr_fl64: 2, + m_fl64: {}, + m_fl64: {key: "bar", value: 2}, }, (env3).msg.msg = { - s64: 1 - s64: 2 - fl64: 1 - fl64: 2 + s64: 1, + s64: 2, + fl64: 1, + fl64: 2, }, (env).msg.(t).msg.(t).msg.(t) = { - pr_s64: 1 - pr_fl64: 1 + pr_s64: 1, + pr_fl64: 1, }, (env3).msg.msg.msg = { - s64: 1 - fl64: 1 + s64: 1, + fl64: 1, } ]; option (en) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" - - oo_u64: 9876 - oo_f32: 1234 - oo_fl32: 1.2345e100 - oo_b: "\x00\x01\x02\x03" + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", + + oo_u64: 9876, + oo_f32: 1234, + oo_fl32: 1.2345e100, + oo_b: "\x00\x01\x02\x03", }; option (en).r_f64 = 0; @@ -906,19 +906,19 @@ enum TestEnum { option (en3).f64 = 0; option (en).m_f64 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (en).m_f64 = { - key: 234 - value: 1 + key: 234, + value: 1, }; option (en).flag = true; option (en).b = "\x00\x01\x02\x03"; option (en).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (en).r_sf64 = 0; @@ -929,24 +929,24 @@ enum TestEnum { option (en3).sf64 = 0; option (en).m_sf64 = { - key: 123 - value: 0 + key: 123, + value: 0, }; option (en).m_sf64 = { - key: -234 - value: 1 + key: -234, + value: 1, }; option (en).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (en).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (en).msg.(t) = { @@ -955,33 +955,33 @@ enum TestEnum { 1, 2, 3 - ] + ], pr_f64: [ 0, 1, 2, 3 - ] + ], m_f64: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], r_sf64: [ 0, 1, 2, 3 - ] + ], pr_sf64: [ 0, 1, 2, 3 - ] + ], m_sf64: [ {key: 123, value: 1}, {key: 234, value: 2} - ] + ], }; option (en3).msg = { f64: [ @@ -989,43 +989,43 @@ enum TestEnum { 1, 2, 3 - ] + ], sf64: [ 0, 1, 2, 3 - ] + ], }; option (en).msg.(t).msg.(t) = { - r_f64: 1 - r_f64: 2 - pr_f64: 1 - pr_f64: 2 - m_f64: {value: 0} - m_f64: {key: 123, value: 1} - m_f64: {key: 234, value: 2} - m_f64: {key: 345} - r_sf64: 1 - r_sf64: 2 - pr_sf64: 1 - pr_sf64: 2 - m_sf64: {} - m_sf64: {key: -234, value: -2} + r_f64: 1, + r_f64: 2, + pr_f64: 1, + pr_f64: 2, + m_f64: {value: 0}, + m_f64: {key: 123, value: 1}, + m_f64: {key: 234, value: 2}, + m_f64: {key: 345}, + r_sf64: 1, + r_sf64: 2, + pr_sf64: 1, + pr_sf64: 2, + m_sf64: {}, + m_sf64: {key: -234, value: -2}, }; option (en3).msg.msg = { - f64: 1 - f64: 2 - sf64: 1 - sf64: 2 + f64: 1, + f64: 2, + sf64: 1, + sf64: 2, }; option (en).msg.(t).msg.(t).msg.(t) = { - pr_f64: 1 - pr_sf64: 1 + pr_f64: 1, + pr_sf64: 1, }; option (en3).msg.msg.msg = { - f64: 1 - sf64: 1 + f64: 1, + sf64: 1, }; } @@ -1040,24 +1040,24 @@ service TestService { option (rpc_i) = 3; option (rpc) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" - - oo_i32: -9876 - oo_f32: 1234 - oo_fl32: 1.2345e100 - OO_Grp + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", + + oo_i32: -9876, + oo_f32: 1234, + oo_fl32: 1.2345e100, + OO_Grp: , }; option (rpc).r_en = ZED; @@ -1068,43 +1068,43 @@ service TestService { option (rpc3).en = BAR; option (rpc).m_en = { - key: "abc" - value: ZED + key: "abc", + value: ZED, }; option (rpc).m_en = { - key: "def" - value: UNO + key: "def", + value: UNO, }; option (rpc).flag = true; option (rpc).b = "\x00\x01\x02\x03"; option (rpc).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (rpc).r_str = "abc"; option (rpc).r_str = "def"; option (rpc).m_str = { - key: "abc" - value: "zero" + key: "abc", + value: "zero", }; option (rpc).m_str = { - key: "def" - value: "one" + key: "def", + value: "one", }; option (rpc).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (rpc).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (rpc).msg.(t) = { @@ -1112,71 +1112,71 @@ service TestService { ZED, UNO, DOS - ] + ], pr_en: [ ZED, UNO, DOS - ] + ], m_en: [ {key: "foo", value: UNO}, {key: "bar", value: DOS} - ] + ], r_str: [ "abc", "def", "mno", "xyz" - ] + ], m_str: [ {key: "foo", value: "one"}, {key: "bar", value: "two"} - ] + ], }; option (rpc3).msg = { en: [ BAR, BAZ - ] + ], }; option (rpc).msg.(t).msg.(t) = { - r_en: UNO - r_en: DOS - pr_en: UNO - pr_en: DOS - m_en: {key: "foo", value: UNO} - m_en: {key: "bar", value: DOS} - r_str: "abc" - r_str: "def" - m_str: {key: "foo", value: "one"} - m_str: {key: "bar", value: "two"} + r_en: UNO, + r_en: DOS, + pr_en: UNO, + pr_en: DOS, + m_en: {key: "foo", value: UNO}, + m_en: {key: "bar", value: DOS}, + r_str: "abc", + r_str: "def", + m_str: {key: "foo", value: "one"}, + m_str: {key: "bar", value: "two"}, }; option (rpc3).msg.msg = { - en: BAR - en: BAZ + en: BAR, + en: BAZ, }; option (rpc).msg.(t).msg.(t).msg.(t) = { - pr_en: UNO + pr_en: UNO, }; option (rpc3).msg.msg.msg = { - en: BAR + en: BAR, }; } option (svc) = { - i32: 0 - i64: 1 - u32: 2 - u64: 3 - f32: 4 - f64: 5 - sf32: 6 - sf64: 7 - fl32: 8.9 - fl64: 9.101 - s32: -10 - s64: -11 - str: "file" + i32: 0, + i64: 1, + u32: 2, + u64: 3, + f32: 4, + f64: 5, + sf32: 6, + sf64: 7, + fl32: 8.9, + fl64: 9.101, + s32: -10, + s64: -11, + str: "file", }; option (svc).oo_i32 = -9876; @@ -1192,51 +1192,51 @@ service TestService { option (svc3).flag = true; option (svc).m_flag = { - key: "abc" - value: true + key: "abc", + value: true, }; option (svc).m_flag = { - key: "def" - value: false + key: "def", + value: false, }; option (svc).flag = true; option (svc).b = "\x00\x01\x02\x03"; option (svc).grp = { - foo: "abc" - bar: 999 + foo: "abc", + bar: 999, }; option (svc).r_b = "\x00\x01"; option (svc).r_b = "\x02\x03"; option (svc).r_grp = { - foo: "foo" - bar: 1 + foo: "foo", + bar: 1, }; option (svc).r_grp = { - foo: "bar" - bar: 2 + foo: "bar", + bar: 2, }; option (svc).m_b = { - key: "abc" - value: "\x00\x01" + key: "abc", + value: "\x00\x01", }; option (svc).m_b = { - key: "def" - value: "\x02\x03" + key: "def", + value: "\x02\x03", }; option (svc).r_msg = { - foo: "filefoo" - bar: 99 - baz: false + foo: "filefoo", + bar: 99, + baz: false, }; option (svc).r_msg = { - foo: "filefoo2" - bar: 98 - baz: true + foo: "filefoo2", + bar: 98, + baz: true, }; option (svc).msg.(t) = { @@ -1245,27 +1245,27 @@ service TestService { true, false, false - ] + ], pr_flag: [ false, false, true, true - ] + ], m_flag: [ {key: "foo", value: true}, {key: "bar", value: false} - ] + ], r_b: [ "abc", "def", "mno", "xyz" - ] + ], m_b: [ {key: "foo", value: "abc"}, {key: "bar", value: "def"} - ] + ], }; option (svc3).msg = { flag: [ @@ -1273,29 +1273,29 @@ service TestService { false, true, true - ] + ], }; option (svc).msg.(t).msg.(t) = { - r_flag: true - r_flag: false - pr_flag: true - pr_flag: false - m_flag: {key: "foo", value: true} - m_flag: {key: "bar", value: false} - r_b: "abc" - r_b: "def" - m_b: {key: "foo", value: "abc"} - m_b: {key: "bar", value: "def"} + r_flag: true, + r_flag: false, + pr_flag: true, + pr_flag: false, + m_flag: {key: "foo", value: true}, + m_flag: {key: "bar", value: false}, + r_b: "abc", + r_b: "def", + m_b: {key: "foo", value: "abc"}, + m_b: {key: "bar", value: "def"}, }; option (svc3).msg.msg = { - flag: true - flag: false + flag: true, + flag: false, }; option (svc).msg.(t).msg.(t).msg.(t) = { - pr_flag: true + pr_flag: true, }; option (svc3).msg.msg.msg = { - flag: true + flag: true, }; option deprecated = true; diff --git a/internal/testdata/options/test_proto3.protoset b/internal/testdata/options/test_proto3.protoset deleted file mode 100644 index fb1925e8..00000000 Binary files a/internal/testdata/options/test_proto3.protoset and /dev/null differ diff --git a/linker/descriptors.go b/linker/descriptors.go index 298c82a7..9de30154 100644 --- a/linker/descriptors.go +++ b/linker/descriptors.go @@ -154,11 +154,6 @@ type result struct { // interpreting options. usedImports map[string]struct{} - // A map of descriptor options messages to their pre-serialized bytes (using - // a canonical serialization format based on how protoc renders options to - // bytes). - optionBytes map[proto.Message][]byte - // A map of AST nodes that represent identifiers in ast.FieldReferenceNodes // to their fully-qualified name. The identifiers are for field names in // message literals (in option values) that are extension fields. These names @@ -434,141 +429,6 @@ func asSourceLocations(srcInfoProtos []*descriptorpb.SourceCodeInfo_Location) [] return locs } -// AddOptionBytes associates the given opts (an options message encoded in the -// binary format) with the given options protobuf message. The protobuf message -// should exist in the hierarchy of this result's FileDescriptorProto. This -// allows the FileDescriptorProto to be marshaled to bytes in a way that -// preserves the way options are defined in source (just as is done by protoc, -// but not possible when only using the generated Go types and standard -// marshaling APIs in the protobuf runtime). -func (r *result) AddOptionBytes(pm proto.Message, opts []byte) { - if r.optionBytes == nil { - r.optionBytes = map[proto.Message][]byte{} - } - r.optionBytes[pm] = append(r.optionBytes[pm], opts...) -} - -func (r *result) CanonicalProto() *descriptorpb.FileDescriptorProto { - origFd := r.FileDescriptorProto() - // make a copy that we can mutate - fd := proto.Clone(origFd).(*descriptorpb.FileDescriptorProto) //nolint:errcheck - - r.storeOptionBytesInFile(fd, origFd) - - return fd -} - -func (r *result) storeOptionBytes(opts, origOpts proto.Message) { - optionBytes := r.optionBytes[origOpts] - if len(optionBytes) == 0 { - // If we don't know about this options message, leave it alone. - return - } - proto.Reset(opts) - opts.ProtoReflect().SetUnknown(optionBytes) -} - -func (r *result) storeOptionBytesInFile(fd, origFd *descriptorpb.FileDescriptorProto) { - if fd.Options != nil { - r.storeOptionBytes(fd.Options, origFd.Options) - } - - for i, md := range fd.MessageType { - origMd := origFd.MessageType[i] - r.storeOptionBytesInMessage(md, origMd) - } - - for i, ed := range fd.EnumType { - origEd := origFd.EnumType[i] - r.storeOptionBytesInEnum(ed, origEd) - } - - for i, exd := range fd.Extension { - origExd := origFd.Extension[i] - r.storeOptionBytesInField(exd, origExd) - } - - for i, sd := range fd.Service { - origSd := origFd.Service[i] - if sd.Options != nil { - r.storeOptionBytes(sd.Options, origSd.Options) - } - - for j, mtd := range sd.Method { - origMtd := origSd.Method[j] - if mtd.Options != nil { - r.storeOptionBytes(mtd.Options, origMtd.Options) - } - } - } -} - -func (r *result) storeOptionBytesInMessage(md, origMd *descriptorpb.DescriptorProto) { - if md.GetOptions().GetMapEntry() { - // Map entry messages are synthesized. They won't have any option bytes - // since they don't actually appear in the source and thus have any option - // declarations in the source. - return - } - - if md.Options != nil { - r.storeOptionBytes(md.Options, origMd.Options) - } - - for i, fld := range md.Field { - origFld := origMd.Field[i] - r.storeOptionBytesInField(fld, origFld) - } - - for i, ood := range md.OneofDecl { - origOod := origMd.OneofDecl[i] - if ood.Options != nil { - r.storeOptionBytes(ood.Options, origOod.Options) - } - } - - for i, exr := range md.ExtensionRange { - origExr := origMd.ExtensionRange[i] - if exr.Options != nil { - r.storeOptionBytes(exr.Options, origExr.Options) - } - } - - for i, nmd := range md.NestedType { - origNmd := origMd.NestedType[i] - r.storeOptionBytesInMessage(nmd, origNmd) - } - - for i, ed := range md.EnumType { - origEd := origMd.EnumType[i] - r.storeOptionBytesInEnum(ed, origEd) - } - - for i, exd := range md.Extension { - origExd := origMd.Extension[i] - r.storeOptionBytesInField(exd, origExd) - } -} - -func (r *result) storeOptionBytesInEnum(ed, origEd *descriptorpb.EnumDescriptorProto) { - if ed.Options != nil { - r.storeOptionBytes(ed.Options, origEd.Options) - } - - for i, evd := range ed.Value { - origEvd := origEd.Value[i] - if evd.Options != nil { - r.storeOptionBytes(evd.Options, origEvd.Options) - } - } -} - -func (r *result) storeOptionBytesInField(fld, origFld *descriptorpb.FieldDescriptorProto) { - if fld.Options != nil { - r.storeOptionBytes(fld.Options, origFld.Options) - } -} - type fileImports struct { protoreflect.FileImports files []protoreflect.FileImport diff --git a/linker/linker.go b/linker/linker.go index 73aa8cc9..a61fb80a 100644 --- a/linker/linker.go +++ b/linker/linker.go @@ -213,30 +213,6 @@ type Result interface { FindExtendeeDescriptorByName(fqn protoreflect.FullName) protoreflect.MessageDescriptor FindExtensionsByMessage(fqn protoreflect.FullName) []protoreflect.ExtensionDescriptor - // CanonicalProto returns the file descriptor proto in a form that - // will be serialized in a canonical way. The "canonical" way matches - // the way that "protoc" emits option values, which is a way that - // mostly matches the way options are defined in source, including - // ordering and de-structuring. Unlike the FileDescriptorProto() method, - // this method is more expensive and results in a new descriptor proto - // being constructed with each call. - // - // The returned value will have all options (fields of the various - // descriptorpb.*Options message types) represented via unrecognized - // fields. So the returned value will serialize as desired, but it - // is otherwise not useful since all option values are treated as - // unknown. - // - // Note that CanonicalProto is a no-op if the options in this file - // were not interpreted by this module (e.g. the underlying descriptor - // proto was provided, with options already interpreted, instead of - // parsed from source). If the underlying descriptor proto was provided, - // but with a mix of interpreted and uninterpreted options, this method - // will effectively clear the already-interpreted fields and only the - // options actually interpreted by the compile operation will be - // retained. - CanonicalProto() *descriptorpb.FileDescriptorProto - // RemoveAST drops the AST information from this result. RemoveAST() } diff --git a/linker/linker_test.go b/linker/linker_test.go index ebcaaf4d..2f74aa46 100644 --- a/linker/linker_test.go +++ b/linker/linker_test.go @@ -461,35 +461,19 @@ func TestLinkerValidation(t *testing.T) { "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to max; } extend Foo { optional Foo bar = 536870912; }", }, }, - "failure_message_set_wire_format_scalar2": { - input: map[string]string{ - "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { optional int32 bar = 1; }", - }, - expectedErr: "foo.proto:1:99: messages with message-set wire format cannot contain scalar extensions, only messages", - }, - "success_message_set_wire_format2": { - input: map[string]string{ - "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { optional Foo bar = 1; }", - }, - }, "failure_message_set_wire_format_repeated": { input: map[string]string{ "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to 100; } extend Foo { repeated Foo bar = 1; }", }, expectedErr: "foo.proto:1:90: messages with message-set wire format cannot contain repeated extensions, only optional", }, - "success_large_extension_message_set_wire_format": { - input: map[string]string{ - "foo.proto": "message Foo { option message_set_wire_format = true; extensions 1 to max; } extend Foo { optional Foo bar = 536870912; }", - }, - }, - "failure_string_value_leading_dot": { + "failure_resolve_first_part_of_name": { input: map[string]string{ "foo.proto": `syntax = "proto3"; package com.google; import "google/protobuf/wrappers.proto"; message Foo { google.protobuf.StringValue str = 1; }`, }, expectedErr: "foo.proto:1:95: field com.google.Foo.str: unknown type google.protobuf.StringValue; resolved to com.google.protobuf.StringValue which is not defined; consider using a leading dot", }, - "success_group_message_extension": { + "success_group_in_custom_option": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -501,7 +485,7 @@ func TestLinkerValidation(t *testing.T) { message Baz { option (foo).bar.name = "abc"; }`, }, }, - "failure_group_extension_not_exist": { + "failure_group_in_custom_option_referred_by_type_name": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -514,7 +498,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:7:28: message Baz: option (foo).Bar.name: field Bar of Foo does not exist", }, - "success_group_extension": { + "success_group_custom_option": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -525,7 +509,7 @@ func TestLinkerValidation(t *testing.T) { message Bar { option (foo).name = "abc"; }`, }, }, - "failure_group_not_extension": { + "failure_group_custom_option_referred_by_type_name": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -537,7 +521,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:6:22: message Bar: invalid extension: Foo is a message, not an extension", }, - "success_group_custom_option": { + "success_group_in_custom_option_msg_literal": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -549,7 +533,7 @@ func TestLinkerValidation(t *testing.T) { message Baz { option (foo) = { Bar< name: "abc" > }; }`, }, }, - "failure_group_custom_option": { + "failure_group_in_custom_option_msg_literal_referred_by_field_name": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -562,7 +546,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:7:32: message Baz: option (foo): field bar not found (did you mean the group named Bar?)", }, - "success_group_custom_option2": { + "success_group_extension_in_custom_option_msg_literal": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -573,7 +557,7 @@ func TestLinkerValidation(t *testing.T) { message Baz { option (foo) = { [bar]< name: "abc" > }; }`, }, }, - "failure_group_extension_field_not_found": { + "failure_group_extension_in_custom_option_msg_literal_referred_by_type_name": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -585,7 +569,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:6:33: message Baz: option (foo): invalid extension: Bar is a message, not an extension", }, - "failure_oneof_extension_already_set": { + "failure_oneof_extension_already_set_msg_literal": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -596,7 +580,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `foo.proto:5:43: message Baz: option (foo): oneof "bar" already has field "baz" set`, }, - "failure_oneof_extension_already_set2": { + "failure_oneof_extension_already_set": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -613,7 +597,7 @@ func TestLinkerValidation(t *testing.T) { // TODO: This is a bug of protoc (https://github.com/protocolbuffers/protobuf/issues/9125). // Difference is expected in the test before it is fixed. }, - "failure_oneof_extension_already_set3": { + "failure_oneof_extension_already_set_implied_by_destructured_option": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -630,7 +614,7 @@ func TestLinkerValidation(t *testing.T) { // TODO: This is a bug of protoc (https://github.com/protocolbuffers/protobuf/issues/9125). // Difference is expected in the test before it is fixed. }, - "failure_oneof_extension_already_set4": { + "failure_oneof_extension_already_set_implied_by_deeply_nested_destructured_option": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -647,7 +631,7 @@ func TestLinkerValidation(t *testing.T) { // TODO: This is a bug of protoc (https://github.com/protocolbuffers/protobuf/issues/9125). // Difference is expected in the test before it is fixed. }, - "success_repeated_extensions": { + "success_empty_array_literal_no_leading_colon_if_msg": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -660,7 +644,7 @@ func TestLinkerValidation(t *testing.T) { };`, }, }, - "failure_repeated_primitive_no_leading_colon": { + "failure_empty_array_literal_require_leading_colon_if_scalar": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -674,7 +658,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `foo.proto:6:8: syntax error: unexpected value, expecting ':'`, }, - "success_extension_repeated_field_values": { + "success_array_literal": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -687,7 +671,7 @@ func TestLinkerValidation(t *testing.T) { };`, }, }, - "failure_extension_unexpected_string_literal": { + "failure_array_literal_require_leading_colon_if_scalar": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -701,7 +685,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `foo.proto:6:9: syntax error: unexpected string literal, expecting '{' or '<' or ']'`, }, - "failure_extension_enum_value_not_message": { + "failure_scoping_resolves_to_sibling_not_parent": { input: map[string]string{ "foo.proto": ` package foo.bar; @@ -714,7 +698,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `foo.proto:6:10: extendee is invalid: foo.bar.M.M is an enum value, not a message`, }, - "failure_json_name_extension": { + "failure_json_name_on_extension": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -725,7 +709,8 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:4:26: field foobar: option json_name is not allowed on extensions", }, - "success_json_name_extension_default": { + "success_json_name_on_extension_ok_if_default": { + // Unclear if this should really be valid... But it's what protoc does. input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -745,7 +730,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:3:36: field Foo.foobar: option json_name value cannot start with '[' and end with ']'; that is reserved for representing extensions", }, - "success_json_name_not_quite_extension": { + "success_json_name_not_quite_extension_okay": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -778,7 +763,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:4:3: field Foo.e: google.protobuf.Struct.FieldsEntry is a synthetic map entry and may not be referenced explicitly", }, - "failure_proto3_extend_add_field": { + "failure_proto3_can_only_extend_options": { input: map[string]string{ "foo.proto": ` syntax = "proto2"; @@ -821,7 +806,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "foo.proto:6:3: syntax error: unexpected ';'", }, - "failure_oneof_field_conflict": { + "failure_oneof_conflicts_with_contained_field": { input: map[string]string{ "a.proto": ` syntax = "proto3"; @@ -833,7 +818,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `a.proto:4:15: symbol "m.z" already defined at a.proto:3:9`, }, - "failure_oneof_field_conflict2": { + "failure_oneof_conflicts_with_adjacent_field": { input: map[string]string{ "a.proto": ` syntax="proto3"; @@ -844,7 +829,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `a.proto:4:9: symbol "m.z" already defined at a.proto:3:10`, }, - "failure_oneof_conflicts": { + "failure_oneof_conflicts_with_other_oneof": { input: map[string]string{ "a.proto": ` syntax="proto3"; @@ -855,7 +840,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: `a.proto:4:9: symbol "m.z" already defined at a.proto:3:9`, }, - "success_message_literals": { + "success_custom_option_enums_look_like_msg_literal_keywords": { input: map[string]string{ "foo.proto": ` syntax = "proto3"; @@ -968,7 +953,11 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "test.proto:9:10: extendee is invalid: foo.bar.c.b is an extension, not a message", }, - "failure_extension_resolution_unknown": { + "failure_msg_literal_scoping_rules_limited": { + // This is due to an unfortunate way of how message literals are actually implemented + // in protoc. It just uses the text format, so parsing the text format has different + // (and much more limited) resolution/scoping rules for relative references than other + // references in protobuf language. input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -987,7 +976,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension c.i", }, - "failure_extension_resolution_unknown2": { + "failure_msg_literal_scoping_rules_limited2": { input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -1006,7 +995,11 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "test.proto:11:6: message foo.bar.b: option (foo.bar.msga): unknown extension i", }, - "failure_extension_resolution_unknown3": { + "failure_option_scoping_rules_limited": { + // This is an unfortunate side effect of having no language spec and so accidental + // quirks in the implementation end up as part of the language :( + // In this case, names in the option can't resolve to siblings, but must resolve + // to a scope at least one level higher. input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -1023,7 +1016,7 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "test.proto:10:17: message foo.bar.b: unknown extension c.f", }, - "failure_extension_resolution_unknown4": { + "failure_option_scoping_rules_limited2": { input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -1040,7 +1033,9 @@ func TestLinkerValidation(t *testing.T) { }, expectedErr: "test.proto:10:17: message foo.bar.b: unknown extension f", }, - "success_nested_extension_resolution_custom_options": { + "success_option_and_msg_literal_scoping_rules": { + // This demonstrates all the ways one can successfully refer to extensions + // in option names and in message literals. input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -1067,7 +1062,7 @@ func TestLinkerValidation(t *testing.T) { }`, }, }, - "failure_extension_resolution_unknown_nested": { + "failure_msg_literal_scoping_rules_limited3": { input: map[string]string{ "test.proto": ` syntax="proto2"; @@ -1205,7 +1200,7 @@ func TestLinkerValidation(t *testing.T) { }; }`, }, - expectedErr: "foo.proto:13:6: message foo.bar.Baz: option (foo.bar.any): multiple any type references are not allowed", + expectedErr: "foo.proto:9:6: message foo.bar.Baz: option (foo.bar.any): any type references cannot be repeated or mixed with other fields", }, "failure_scope_type_name": { input: map[string]string{ @@ -2220,7 +2215,7 @@ func TestLinkerValidation(t *testing.T) { } `, }, - expectedErr: `test.proto:3:18: feature "enum_type" is allowed on [enum,file], not on field`, + expectedErr: `test.proto:3:27: feature "enum_type" is allowed on [enum,file], not on field`, }, "failure_editions_feature_on_wrong_target_type_msg_literal": { input: map[string]string{ diff --git a/linker/resolve.go b/linker/resolve.go index cb0333ed..a4404dbe 100644 --- a/linker/resolve.go +++ b/linker/resolve.go @@ -557,7 +557,7 @@ opts: // also resolve any extension names found inside message literals in option values mc.Option = opt optNode := r.OptionNode(opt) - if optNode.IsIncomplete() { + if optNode == nil || optNode.IsIncomplete() { continue } if err := r.resolveOptionValue(handler, mc, optNode.Val, scopes); err != nil { diff --git a/options/options.go b/options/options.go index 77e67633..0a3092c7 100644 --- a/options/options.go +++ b/options/options.go @@ -30,11 +30,9 @@ import ( "errors" "fmt" "math" - "sort" "strings" "google.golang.org/protobuf/encoding/prototext" - "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -62,19 +60,12 @@ var ( type interpreter struct { file file resolver linker.Resolver - container optionsContainer overrideDescriptorProto linker.File lenient bool handler *reporter.Handler index sourceinfo.OptionIndex + pathBuffer []int32 descriptorIndex sourceinfo.OptionDescriptorIndex - neededOpenEnums map[protoreflect.EnumDescriptor][]enumValRef -} - -type enumValRef struct { - number protoreflect.EnumNumber - value ast.Node - mc *protointernal.MessageContext } type file interface { @@ -157,69 +148,70 @@ func interpretOptions(file file, res linker.Resolver, handler *reporter.Handler, handler: handler, index: sourceinfo.OptionIndex{}, descriptorIndex: sourceinfo.NewOptionDescriptorIndex(), + pathBuffer: make([]int32, 0, 16), } - interp.container, _ = file.(optionsContainer) for _, opt := range interpOpts { opt(&interp) } + // We have to do this in two phases. First we interpret non-custom options. + // This allows us to handle standard options and features that may needed to + // correctly reference the custom options in the second phase. + if err := interp.interpretFileOptions(file, false); err != nil { + return nil, sourceinfo.OptionDescriptorIndex{}, err + } + // Now we can do custom options. + if err := interp.interpretFileOptions(file, true); err != nil { + return nil, sourceinfo.OptionDescriptorIndex{}, err + } + return interp.index, interp.descriptorIndex, nil +} +func (interp *interpreter) interpretFileOptions(file file, customOpts bool) error { fd := file.FileDescriptorProto() prefix := fd.GetPackage() if prefix != "" { prefix += "." } - err := interpretElementOptions(&interp, fd.GetName(), targetTypeFile, fd) + err := interpretElementOptions(interp, fd.GetName(), targetTypeFile, fd, customOpts) if err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err + return err } for _, md := range fd.GetMessageType() { fqn := prefix + md.GetName() - if err := interp.interpretMessageOptions(fqn, md); err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err + if err := interp.interpretMessageOptions(fqn, md, customOpts); err != nil { + return err } } for _, fld := range fd.GetExtension() { fqn := prefix + fld.GetName() - if err := interp.interpretFieldOptions(fqn, fld); err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err + if err := interp.interpretFieldOptions(fqn, fld, customOpts); err != nil { + return err } } for _, ed := range fd.GetEnumType() { fqn := prefix + ed.GetName() - if err := interp.interpretEnumOptions(fqn, ed); err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err + if err := interp.interpretEnumOptions(fqn, ed, customOpts); err != nil { + return err } } for _, sd := range fd.GetService() { fqn := prefix + sd.GetName() - err := interpretElementOptions(&interp, fqn, targetTypeService, sd) + err := interpretElementOptions(interp, fqn, targetTypeService, sd, customOpts) if err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err + return err } for _, mtd := range sd.GetMethod() { mtdFqn := fqn + "." + mtd.GetName() - err := interpretElementOptions(&interp, mtdFqn, targetTypeMethod, mtd) + err := interpretElementOptions(interp, mtdFqn, targetTypeMethod, mtd, customOpts) if err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err - } - } - } - // Now that we're done, we need to go back and check any enum value references - // that couldn't be validated. - for ed, refs := range interp.neededOpenEnums { - if ed.IsClosed() { - for _, ref := range refs { - err := handler.HandleErrorf(interp.nodeInfo(ref.value), "%vclosed enum %s has no value with number %d", ref.mc, ed.FullName(), ref.number) - if err != nil { - return nil, sourceinfo.OptionDescriptorIndex{}, err - } + return err } } } - return interp.index, interp.descriptorIndex, nil + return nil } -func resolveDescriptor[T protoreflect.Descriptor](res linker.Resolver, name string) T { +func resolveDescriptor[T protoreflect.Descriptor](res linker.Resolver, name protoreflect.FullName) T { var zero T if res == nil { return zero @@ -227,7 +219,7 @@ func resolveDescriptor[T protoreflect.Descriptor](res linker.Resolver, name stri if len(name) > 0 && name[0] == '.' { name = name[1:] } - desc, _ := res.FindDescriptorByName(protoreflect.FullName(name)) + desc, _ := res.FindDescriptorByName(name) typedDesc, ok := desc.(T) if ok { return typedDesc @@ -249,7 +241,7 @@ func (interp *interpreter) resolveExtensionType(name string) (protoreflect.Exten return ext.TypeDescriptor(), nil } -func (interp *interpreter) resolveOptionsType(name string) protoreflect.MessageDescriptor { +func (interp *interpreter) resolveOptionsType(name protoreflect.FullName) protoreflect.MessageDescriptor { md := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, name) if md != nil { return md @@ -260,7 +252,7 @@ func (interp *interpreter) resolveOptionsType(name string) protoreflect.MessageD if len(name) > 0 && name[0] == '.' { name = name[1:] } - desc := interp.overrideDescriptorProto.FindDescriptorByName(protoreflect.FullName(name)) + desc := interp.overrideDescriptorProto.FindDescriptorByName(name) if md, ok := desc.(protoreflect.MessageDescriptor); ok { return md } @@ -271,46 +263,46 @@ func (interp *interpreter) nodeInfo(n ast.Node) ast.NodeInfo { return interp.file.FileNode().NodeInfo(n) } -func (interp *interpreter) interpretMessageOptions(fqn string, md *descriptorpb.DescriptorProto) error { - err := interpretElementOptions(interp, fqn, targetTypeMessage, md) +func (interp *interpreter) interpretMessageOptions(fqn string, md *descriptorpb.DescriptorProto, customOpts bool) error { + err := interpretElementOptions(interp, fqn, targetTypeMessage, md, customOpts) if err != nil { return err } for _, fld := range md.GetField() { fldFqn := fqn + "." + fld.GetName() - if err := interp.interpretFieldOptions(fldFqn, fld); err != nil { + if err := interp.interpretFieldOptions(fldFqn, fld, customOpts); err != nil { return err } } for _, ood := range md.GetOneofDecl() { oodFqn := fqn + "." + ood.GetName() - err := interpretElementOptions(interp, oodFqn, targetTypeOneof, ood) + err := interpretElementOptions(interp, oodFqn, targetTypeOneof, ood, customOpts) if err != nil { return err } } for _, fld := range md.GetExtension() { fldFqn := fqn + "." + fld.GetName() - if err := interp.interpretFieldOptions(fldFqn, fld); err != nil { + if err := interp.interpretFieldOptions(fldFqn, fld, customOpts); err != nil { return err } } for _, er := range md.GetExtensionRange() { erFqn := fmt.Sprintf("%s.%d-%d", fqn, er.GetStart(), er.GetEnd()) - err := interpretElementOptions(interp, erFqn, targetTypeExtensionRange, er) + err := interpretElementOptions(interp, erFqn, targetTypeExtensionRange, er, customOpts) if err != nil { return err } } for _, nmd := range md.GetNestedType() { nmdFqn := fqn + "." + nmd.GetName() - if err := interp.interpretMessageOptions(nmdFqn, nmd); err != nil { + if err := interp.interpretMessageOptions(nmdFqn, nmd, customOpts); err != nil { return err } } for _, ed := range md.GetEnumType() { edFqn := fqn + "." + ed.GetName() - if err := interp.interpretEnumOptions(edFqn, ed); err != nil { + if err := interp.interpretEnumOptions(edFqn, ed, customOpts); err != nil { return err } } @@ -354,12 +346,12 @@ func (interp *interpreter) interpretMessageOptions(fqn string, md *descriptorpb. var emptyFieldOptions = &descriptorpb.FieldOptions{} -func (interp *interpreter) interpretFieldOptions(fqn string, fld *descriptorpb.FieldDescriptorProto) error { +func (interp *interpreter) interpretFieldOptions(fqn string, fld *descriptorpb.FieldDescriptorProto, customOpts bool) error { opts := fld.GetOptions() emptyOptionsAlreadyPresent := opts != nil && len(opts.GetUninterpretedOption()) == 0 - // First process pseudo-options - if len(opts.GetUninterpretedOption()) > 0 { + // For non-custom phase, first process pseudo-options + if len(opts.GetUninterpretedOption()) > 0 && !customOpts { if err := interp.interpretFieldPseudoOptions(fqn, fld, opts); err != nil { return err } @@ -377,7 +369,7 @@ func (interp *interpreter) interpretFieldOptions(fqn string, fld *descriptorpb.F } // Then process actual options. - return interpretElementOptions(interp, fqn, targetTypeField, fld) + return interpretElementOptions(interp, fqn, targetTypeField, fld, customOpts) } func (interp *interpreter) interpretFieldPseudoOptions(fqn string, fld *descriptorpb.FieldDescriptorProto, opts *descriptorpb.FieldOptions) error { @@ -412,7 +404,7 @@ func (interp *interpreter) interpretFieldPseudoOptions(fqn string, fld *descript if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") { return interp.HandleOptionValueErrorf(nil, optNode.GetVal(), "%s: option json_name value cannot start with '[' and end with ']'; that is reserved for representing extensions", scope) } - interp.descriptorIndex.OptionsToFieldDescriptors[opt] = resolveDescriptor[protoreflect.FieldDescriptor](interp.resolver, fqn) + interp.descriptorIndex.OptionsToFieldDescriptors[opt] = resolveDescriptor[protoreflect.FieldDescriptor](interp.resolver, protoreflect.FullName(fqn)) fld.JsonName = proto.String(jsonName) } @@ -420,7 +412,7 @@ func (interp *interpreter) interpretFieldPseudoOptions(fqn string, fld *descript if index, err := interp.processDefaultOption(scope, fqn, fld, uo); err != nil && !interp.lenient { return err } else if index >= 0 { - fldDesc := resolveDescriptor[protoreflect.FieldDescriptor](interp.resolver, fqn) + fldDesc := resolveDescriptor[protoreflect.FieldDescriptor](interp.resolver, protoreflect.FullName(fqn)) interp.descriptorIndex.OptionsToFieldDescriptors[uo[index]] = fldDesc nm := interp.file.OptionNamePartNode(uo[index].Name[0]) interp.descriptorIndex.FieldReferenceNodesToFieldDescriptors[nm] = fldDesc @@ -428,7 +420,7 @@ func (interp *interpreter) interpretFieldPseudoOptions(fqn string, fld *descript optNode := interp.file.OptionNode(uo[index]) interp.index[optNode] = &sourceinfo.OptionSourceInfo{Path: []int32{-1, protointernal.FieldDefaultTag}} - if fldDesc.Kind() == protoreflect.EnumKind { + if optNode != nil && fldDesc != nil && fldDesc.Kind() == protoreflect.EnumKind { interp.indexEnumValueRef(fldDesc, optNode.Val) } uo = protointernal.RemoveOption(uo, index) @@ -507,7 +499,7 @@ func (interp *interpreter) defaultValue(mc *protointernal.MessageContext, fld *d return -1, interp.HandleOptionForbiddenErrorf(mc, val, "default value cannot be a message") } if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM { - ed := resolveDescriptor[protoreflect.EnumDescriptor](interp.resolver, fld.GetTypeName()) + ed := resolveDescriptor[protoreflect.EnumDescriptor](interp.resolver, protoreflect.FullName(fld.GetTypeName())) if ed == nil { return -1, interp.HandleOptionValueErrorf(mc, val, "unable to resolve enum type %q for field %q", fld.GetTypeName(), fld.GetName()) } @@ -525,7 +517,7 @@ func (interp *interpreter) defaultValueFromProto(mc *protointernal.MessageContex return -1, interp.HandleOptionValueErrorf(mc, node, "default value cannot be a message") } if fld.GetType() == descriptorpb.FieldDescriptorProto_TYPE_ENUM { - ed := resolveDescriptor[protoreflect.EnumDescriptor](interp.resolver, fld.GetTypeName()) + ed := resolveDescriptor[protoreflect.EnumDescriptor](interp.resolver, protoreflect.FullName(fld.GetTypeName())) if ed == nil { return -1, interp.HandleOptionValueErrorf(mc, node, "unable to resolve enum type %q for field %q", fld.GetTypeName(), fld.GetName()) } @@ -544,14 +536,14 @@ func encodeDefaultBytes(b []byte) string { return buf.String() } -func (interp *interpreter) interpretEnumOptions(fqn string, ed *descriptorpb.EnumDescriptorProto) error { - err := interpretElementOptions(interp, fqn, targetTypeEnum, ed) +func (interp *interpreter) interpretEnumOptions(fqn string, ed *descriptorpb.EnumDescriptorProto, customOpts bool) error { + err := interpretElementOptions(interp, fqn, targetTypeEnum, ed, customOpts) if err != nil { return err } for _, evd := range ed.GetValue() { evdFqn := fqn + "." + evd.GetName() - err := interpretElementOptions(interp, evdFqn, targetTypeEnumValue, evd) + err := interpretElementOptions(interp, evdFqn, targetTypeEnumValue, evd, customOpts) if err != nil { return err } @@ -559,319 +551,17 @@ func (interp *interpreter) interpretEnumOptions(fqn string, ed *descriptorpb.Enu return nil } -// interpretedOption represents the result of interpreting an option. -// This includes metadata that allows the option to be serialized to -// bytes in a way that is deterministic and can preserve the structure -// of the source (the way the options are de-structured and the order in -// which options appear). -type interpretedOption struct { - unknown bool - pathPrefix []int32 - interpretedField -} - -func (o *interpretedOption) toSourceInfo() *sourceinfo.OptionSourceInfo { - return o.interpretedField.toSourceInfo(o.pathPrefix) -} - -func (o *interpretedOption) appendOptionBytes(b []byte) ([]byte, error) { - return o.appendOptionBytesWithPath(b, o.pathPrefix) -} - -func (o *interpretedOption) appendOptionBytesWithPath(b []byte, path []int32) ([]byte, error) { - if len(path) == 0 { - return appendOptionBytesSingle(b, &o.interpretedField) - } - // NB: if we add functions to compute sizes of the options first, we could - // allocate precisely sized slice up front, which would be more efficient than - // repeated creation/growing/concatenation. - enclosed, err := o.appendOptionBytesWithPath(nil, path[1:]) - if err != nil { - return nil, err - } - b = protowire.AppendTag(b, protowire.Number(path[0]), protowire.BytesType) - return protowire.AppendBytes(b, enclosed), nil -} - -// interpretedField represents a field in an options message that is the -// result of interpreting an option. This is used for the option value -// itself as well as for subfields when an option value is a message -// literal. -type interpretedField struct { - // the AST node for this field -- an [*ast.OptionNode] for top-level options, - // an [*ast.MessageFieldNode] for fields in a message literal, or nil for - // synthetic field values (for keys or values in map entries that were - // omitted from source). - node ast.Node - // field number - number int32 - // index of this element inside a repeated field; only set if repeated == true - index int32 - // true if this is a repeated field - repeated bool - // true if this is a repeated field that stores scalar values in packed form - packed bool - // the field's kind - kind protoreflect.Kind - - value interpretedFieldValue -} - -func (f *interpretedField) path(prefix []int32) []int32 { - path := make([]int32, 0, len(prefix)+2) - path = append(path, prefix...) - path = append(path, f.number) - if f.repeated { - path = append(path, f.index) - } - return path -} - -func (f *interpretedField) toSourceInfo(prefix []int32) *sourceinfo.OptionSourceInfo { - path := f.path(prefix) - var children sourceinfo.OptionChildrenSourceInfo - if len(f.value.msgListVal) > 0 { - elements := make([]sourceinfo.OptionSourceInfo, len(f.value.msgListVal)) - for i, msgVal := range f.value.msgListVal { - // With an array literal, the index in path is that of the first element. - elementPath := append(([]int32)(nil), path...) - elementPath[len(elementPath)-1] += int32(i) - elements[i].Path = elementPath - elements[i].Children = msgSourceInfo(elementPath, msgVal) - } - children = &sourceinfo.ArrayLiteralSourceInfo{Elements: elements} - } else if len(f.value.msgVal) > 0 { - children = msgSourceInfo(path, f.value.msgVal) - } - return &sourceinfo.OptionSourceInfo{ - Path: path, - Children: children, - } -} - -func msgSourceInfo(prefix []int32, fields []*interpretedField) *sourceinfo.MessageLiteralSourceInfo { - fieldInfo := map[*ast.MessageFieldNode]*sourceinfo.OptionSourceInfo{} - for _, field := range fields { - msgFieldNode, ok := field.node.(*ast.MessageFieldNode) - if !ok { - continue - } - fieldInfo[msgFieldNode] = field.toSourceInfo(prefix) - } - return &sourceinfo.MessageLiteralSourceInfo{Fields: fieldInfo} -} - -// interpretedFieldValue is a wrapper around protoreflect.Value that -// includes extra metadata. -type interpretedFieldValue struct { - // the bytes for this field value if already pre-serialized - // (when this is set, the other fields are ignored) - preserialized []byte - - // the field value - val protoreflect.Value - // if true, this value is a list of values, not a singular value - isList bool - // non-nil for singular message values - msgVal []*interpretedField - // non-nil for non-empty lists of message values - msgListVal [][]*interpretedField -} - -func appendOptionBytes(b []byte, flds []*interpretedField) ([]byte, error) { - // protoc emits messages sorted by field number - if len(flds) > 1 { - sort.SliceStable(flds, func(i, j int) bool { - return flds[i].number < flds[j].number - }) - } - - for i := 0; i < len(flds); i++ { - f := flds[i] - if f.value.preserialized != nil { - b = append(b, f.value.preserialized...) - continue - } - switch { - case f.packed && protointernal.CanPack(f.kind): - // for packed repeated numeric fields, all runs of values are merged into one packed list - num := f.number - j := i - for j < len(flds) && flds[j].number == num { - j++ - } - // now flds[i:j] is the range of contiguous fields for the same field number - enclosed, err := appendOptionBytesPacked(nil, f.kind, flds[i:j]) - if err != nil { - return nil, err - } - b = protowire.AppendTag(b, protowire.Number(f.number), protowire.BytesType) - b = protowire.AppendBytes(b, enclosed) - // skip over the other subsequent fields we just serialized - i = j - 1 - case f.value.isList: - // if not packed, then emit one value at a time - single := *f - single.value.isList = false - single.value.msgListVal = nil - l := f.value.val.List() - for i := 0; i < l.Len(); i++ { - single.value.val = l.Get(i) - if f.kind == protoreflect.MessageKind || f.kind == protoreflect.GroupKind { - single.value.msgVal = f.value.msgListVal[i] - } - var err error - b, err = appendOptionBytesSingle(b, &single) - if err != nil { - return nil, err - } - } - default: - // simple singular value - var err error - b, err = appendOptionBytesSingle(b, f) - if err != nil { - return nil, err - } - } - } - - return b, nil -} - -func appendOptionBytesPacked(b []byte, k protoreflect.Kind, flds []*interpretedField) ([]byte, error) { - for i := range flds { - val := flds[i].value - if val.isList { - l := val.val.List() - var err error - b, err = appendNumericValueBytesPacked(b, k, l) - if err != nil { - return nil, err - } - } else { - var err error - b, err = appendNumericValueBytes(b, k, val.val) - if err != nil { - return nil, err - } - } - } - return b, nil -} - -func appendOptionBytesSingle(b []byte, f *interpretedField) ([]byte, error) { - if f.value.preserialized != nil { - return append(b, f.value.preserialized...), nil - } - num := protowire.Number(f.number) - switch f.kind { - case protoreflect.MessageKind: - enclosed, err := appendOptionBytes(nil, f.value.msgVal) - if err != nil { - return nil, err - } - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendBytes(b, enclosed), nil - - case protoreflect.GroupKind: - b = protowire.AppendTag(b, num, protowire.StartGroupType) - var err error - b, err = appendOptionBytes(b, f.value.msgVal) - if err != nil { - return nil, err - } - return protowire.AppendTag(b, num, protowire.EndGroupType), nil - - case protoreflect.StringKind: - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendString(b, f.value.val.String()), nil - - case protoreflect.BytesKind: - b = protowire.AppendTag(b, num, protowire.BytesType) - return protowire.AppendBytes(b, f.value.val.Bytes()), nil - - case protoreflect.Int32Kind, protoreflect.Int64Kind, protoreflect.Uint32Kind, protoreflect.Uint64Kind, - protoreflect.Sint32Kind, protoreflect.Sint64Kind, protoreflect.EnumKind, protoreflect.BoolKind: - b = protowire.AppendTag(b, num, protowire.VarintType) - return appendNumericValueBytes(b, f.kind, f.value.val) - - case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: - b = protowire.AppendTag(b, num, protowire.Fixed32Type) - return appendNumericValueBytes(b, f.kind, f.value.val) - - case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: - b = protowire.AppendTag(b, num, protowire.Fixed64Type) - return appendNumericValueBytes(b, f.kind, f.value.val) - - default: - return nil, fmt.Errorf("unknown field kind: %v", f.kind) - } -} - -func appendNumericValueBytesPacked(b []byte, k protoreflect.Kind, l protoreflect.List) ([]byte, error) { - for i := 0; i < l.Len(); i++ { - var err error - b, err = appendNumericValueBytes(b, k, l.Get(i)) - if err != nil { - return nil, err - } - } - return b, nil -} - -func appendNumericValueBytes(b []byte, k protoreflect.Kind, v protoreflect.Value) ([]byte, error) { - switch k { - case protoreflect.Int32Kind, protoreflect.Int64Kind: - return protowire.AppendVarint(b, uint64(v.Int())), nil - case protoreflect.Uint32Kind, protoreflect.Uint64Kind: - return protowire.AppendVarint(b, v.Uint()), nil - case protoreflect.Sint32Kind, protoreflect.Sint64Kind: - return protowire.AppendVarint(b, protowire.EncodeZigZag(v.Int())), nil - case protoreflect.Fixed32Kind: - return protowire.AppendFixed32(b, uint32(v.Uint())), nil - case protoreflect.Fixed64Kind: - return protowire.AppendFixed64(b, v.Uint()), nil - case protoreflect.Sfixed32Kind: - return protowire.AppendFixed32(b, uint32(v.Int())), nil - case protoreflect.Sfixed64Kind: - return protowire.AppendFixed64(b, uint64(v.Int())), nil - case protoreflect.FloatKind: - return protowire.AppendFixed32(b, math.Float32bits(float32(v.Float()))), nil - case protoreflect.DoubleKind: - return protowire.AppendFixed64(b, math.Float64bits(v.Float())), nil - case protoreflect.BoolKind: - return protowire.AppendVarint(b, protowire.EncodeBool(v.Bool())), nil - case protoreflect.EnumKind: - return protowire.AppendVarint(b, uint64(v.Enum())), nil - default: - return nil, fmt.Errorf("unknown field kind: %v", k) - } -} - -// optionsContainer may be optionally implemented by a linker.Result. It is -// not part of the linker.Result interface as it is meant only for internal use. -// This allows the option interpreter step to store extra metadata about the -// serialized structure of options. -type optionsContainer interface { - // AddOptionBytes adds the given pre-serialized option bytes to a file, - // associated with the given options message. The type of the given message - // should be an options message, for example *descriptorpb.MessageOptions. - // This value should be part of the message hierarchy whose root is the - // *descriptorpb.FileDescriptorProto that corresponds to this result. - AddOptionBytes(pm proto.Message, opts []byte) -} - func interpretElementOptions[Elem elementType[OptsStruct, Opts], OptsStruct any, Opts optionsType[OptsStruct]]( interp *interpreter, fqn string, target *targetType[Elem, OptsStruct, Opts], elem Elem, + customOpts bool, ) error { opts := elem.GetOptions() uo := opts.GetUninterpretedOption() if len(uo) > 0 { - remain, err := interp.interpretOptions(fqn, target.t, elem, opts, uo) + remain, err := interp.interpretOptions(fqn, target.t, elem, opts, uo, customOpts) if err != nil { return err } @@ -890,9 +580,10 @@ func (interp *interpreter) interpretOptions( targetType descriptorpb.FieldOptions_OptionTargetType, element, opts proto.Message, uninterpreted []*descriptorpb.UninterpretedOption, + customOpts bool, ) ([]*descriptorpb.UninterpretedOption, error) { optsDesc := opts.ProtoReflect().Descriptor() - optsFqn := string(optsDesc.FullName()) + optsFqn := optsDesc.FullName() var msg protoreflect.Message // see if the parse included an override copy for these options if md := interp.resolveOptionsType(optsFqn); md != nil { @@ -912,12 +603,16 @@ func (interp *interpreter) interpretOptions( ElementType: descriptorType(element), } var remain []*descriptorpb.UninterpretedOption - results := make([]*interpretedOption, 0, len(uninterpreted)) - var featuresInfo []*interpretedOption + var features []*ast.OptionNode for _, uo := range uninterpreted { if len(uo.Name) == 0 { continue } + if uo.Name[0].GetIsExtension() != customOpts { + // We're not looking at these this phase. + remain = append(remain, uo) + continue + } node := interp.file.OptionNode(uo) if !uo.Name[0].GetIsExtension() && uo.Name[0].GetNamePart() == "uninterpreted_option" { if interp.lenient { @@ -931,7 +626,7 @@ func (interp *interpreter) interpretOptions( continue } mc.Option = uo - res, err := interp.interpretField(mc, msg, uo, 0, nil) + srcInfo, err := interp.interpretField(mc, msg, uo, 0, interp.pathBuffer) if err != nil { if interp.lenient { remain = append(remain, uo) @@ -939,27 +634,15 @@ func (interp *interpreter) interpretOptions( } return nil, err } - if res == nil { - if interp.lenient { - remain = append(remain, uo) - continue - } - return nil, interp.handler.Error() - // if err := interp.reporter.HandleErrorf(interp.nodeInfo(node.GetName()), "%vunknown option", mc); err != nil { - // return nil, err - // } - // continue - } - res.unknown = !isKnownField(optsDesc, res) - results = append(results, res) if !uo.Name[0].GetIsExtension() && uo.Name[0].GetNamePart() == featuresFieldName { - featuresInfo = append(featuresInfo, res) + features = append(features, node) + } + if srcInfo != nil { + interp.index[node] = srcInfo } - si := res.toSourceInfo() - interp.index[node] = si } - if err := interp.validateFeatures(targetType, msg, featuresInfo); err != nil && !interp.lenient { + if err := interp.validateFeatures(targetType, msg, features); err != nil && !interp.lenient { return nil, err } @@ -978,14 +661,6 @@ func (interp *interpreter) interpretOptions( proto.Reset(opts) proto.Merge(opts, optsClone) - if interp.container != nil { - b, err := interp.toOptionBytes(mc, results) - if err != nil { - return nil, err - } - interp.container.AddOptionBytes(opts, b) - } - return remain, nil } @@ -1000,21 +675,13 @@ func (interp *interpreter) interpretOptions( return nil, interp.HandleOptionValueErrorf(nil, node, "error in %s options: %w", descriptorType(element), err) } - if interp.container != nil { - b, err := interp.toOptionBytes(mc, results) - if err != nil { - return nil, err - } - interp.container.AddOptionBytes(opts, b) - } - - return nil, nil + return remain, nil } func (interp *interpreter) validateFeatures( targetType descriptorpb.FieldOptions_OptionTargetType, opts protoreflect.Message, - featuresInfo []*interpretedOption, + features []*ast.OptionNode, ) error { fld := opts.Descriptor().Fields().ByName(featuresFieldName) if fld == nil { @@ -1026,9 +693,9 @@ func (interp *interpreter) validateFeatures( // TODO: should this return an error? return nil } - features := opts.Get(fld).Message() + featureSet := opts.Get(fld).Message() var err error - features.Range(func(featureField protoreflect.FieldDescriptor, _ protoreflect.Value) bool { + featureSet.Range(func(featureField protoreflect.FieldDescriptor, _ protoreflect.Value) bool { opts, ok := featureField.Options().(*descriptorpb.FieldOptions) if !ok { return true @@ -1046,7 +713,7 @@ func (interp *interpreter) validateFeatures( for i, t := range opts.Targets { allowedTypes[i] = targetTypeString(t) } - node := interp.positionOfFeature(featuresInfo, fld.Number(), featureField.Number()) + node := interp.positionOfFeature(features, featuresFieldName, featureField.Name()) if len(opts.Targets) == 1 && opts.Targets[0] == descriptorpb.FieldOptions_TARGET_TYPE_UNKNOWN { err = interp.HandleOptionForbiddenErrorf(nil, node, "feature field %q may not be used explicitly", featureField.Name()) } else { @@ -1058,52 +725,52 @@ func (interp *interpreter) validateFeatures( return err } -func (interp *interpreter) positionOfFeature(featuresInfo []*interpretedOption, fieldNumbers ...protoreflect.FieldNumber) ast.Node { +func (interp *interpreter) positionOfFeature(features []*ast.OptionNode, fieldNames ...protoreflect.Name) ast.Node { if interp.file.AST() == nil { return &ast.NoSourceNode{Filename: interp.file.FileDescriptorProto().GetName()} } - for _, info := range featuresInfo { - matched, remainingNumbers, node := matchInterpretedOption(info, fieldNumbers) + for _, feature := range features { + matched, remainingNames, nodePos, nodeValue := matchInterpretedOption(feature, fieldNames) if !matched { continue } - if len(remainingNumbers) > 0 { - node = findInterpretedFieldForFeature(&(info.interpretedField), remainingNumbers) + if len(remainingNames) > 0 { + nodePos = findInterpretedFieldForFeature(nodePos, nodeValue, remainingNames) } - if node != nil { - return node + if nodePos != nil { + return nodePos } } return &ast.NoSourceNode{Filename: interp.file.FileDescriptorProto().GetName()} } -func matchInterpretedOption(info *interpretedOption, path []protoreflect.FieldNumber) (bool, []protoreflect.FieldNumber, ast.Node) { - for i := 0; i < len(path) && i < len(info.pathPrefix); i++ { - if info.pathPrefix[i] != int32(path[i]) { - return false, nil, nil +func matchInterpretedOption(node *ast.OptionNode, path []protoreflect.Name) (bool, []protoreflect.Name, ast.Node, *ast.ValueNode) { + parts := node.Name.FilterFieldReferences() + for i := 0; i < len(path) && i < len(parts); i++ { + part := parts[i] + if !part.IsExtension() && protoreflect.Name(part.Name.AsIdentifier()) != path[i] { + return false, nil, nil, nil } } - if len(path) <= len(info.pathPrefix) { - // no more path elements to match - node := info.node - if optsNode, ok := node.(*ast.OptionNode); ok { - node = optsNode.Name.Parts[len(path)-1].Unwrap() - } - return true, nil, node - } - if info.number != int32(path[len(info.pathPrefix)]) { - return false, nil, nil + if len(path) <= len(node.Name.Parts) { + // No more path elements to match. Report location + // of the final element of path inside option name. + return true, nil, node.Name.Parts[len(path)-1], node.Val } - return true, path[len(info.pathPrefix)+1:], info.node + return true, path[len(node.Name.Parts):], node.Name.Parts[len(node.Name.Parts)-1], node.Val } -func findInterpretedFieldForFeature(opt *interpretedField, path []protoreflect.FieldNumber) ast.Node { +func findInterpretedFieldForFeature(nodePos ast.Node, nodeValue *ast.ValueNode, path []protoreflect.Name) ast.Node { if len(path) == 0 { - return opt.node + return nodePos + } + msgNode := nodeValue.GetMessageLiteral() + if msgNode == nil { + return nil } - for _, fld := range opt.value.msgVal { - if fld.number == int32(path[0]) { - if res := findInterpretedFieldForFeature(fld, path[1:]); res != nil { + for _, fldNode := range msgNode.Elements { + if fldNode.Name.Open == nil && protoreflect.Name(fldNode.Name.Name.AsIdentifier()) == path[0] { + if res := findInterpretedFieldForFeature(fldNode.Name, fldNode.Val, path[1:]); res != nil { return res } } @@ -1111,68 +778,6 @@ func findInterpretedFieldForFeature(opt *interpretedField, path []protoreflect.F return nil } -// isKnownField returns true if the given option is for a known field of the -// given options message descriptor and will be serialized using the expected -// wire type for that known field. -func isKnownField(desc protoreflect.MessageDescriptor, opt *interpretedOption) bool { - var num int32 - if len(opt.pathPrefix) > 0 { - num = opt.pathPrefix[0] - } else { - num = opt.number - } - fd := desc.Fields().ByNumber(protoreflect.FieldNumber(num)) - if fd == nil { - return false - } - - // Before the full wire type check, we do a quick check that will usually pass - // and allow us to short-circuit the logic below. - if fd.IsList() == opt.repeated && fd.Kind() == opt.kind { - return true - } - - // We figure out the wire type this interpreted field will use when serialized. - var wireType protowire.Type - switch { - case len(opt.pathPrefix) > 0: - // If path prefix exists, this field is nested inside a message. - // And messages use bytes wire type. - wireType = protowire.BytesType - case opt.repeated && opt.packed && protointernal.CanPack(opt.kind): - // Packed repeated numeric scalars use bytes wire type. - wireType = protowire.BytesType - default: - wireType = wireTypeForKind(opt.kind) - } - - // And then we see if the wire type we just determined is compatible with - // the field descriptor we found. - if fd.IsList() && protointernal.CanPack(fd.Kind()) && wireType == protowire.BytesType { - // Even if fd.IsPacked() is false, bytes type is still accepted for - // repeated scalar numerics, so that changing a repeated field from - // packed to not-packed (or vice versa) is a compatible change. - return true - } - return wireType == wireTypeForKind(fd.Kind()) -} - -func wireTypeForKind(kind protoreflect.Kind) protowire.Type { - switch kind { - case protoreflect.StringKind, protoreflect.BytesKind, protoreflect.MessageKind: - return protowire.BytesType - case protoreflect.GroupKind: - return protowire.StartGroupType - case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: - return protowire.Fixed32Type - case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: - return protowire.Fixed64Type - default: - // everything else uses varint - return protowire.VarintType - } -} - func targetTypeString(t descriptorpb.FieldOptions_OptionTargetType) string { return strings.ToLower(strings.ReplaceAll(strings.TrimPrefix(t.String(), "TARGET_TYPE_"), "_", " ")) } @@ -1194,36 +799,6 @@ func cloneInto(dest proto.Message, src proto.Message, res linker.Resolver) error return proto.UnmarshalOptions{Resolver: res}.Unmarshal(data, dest) } -func (interp *interpreter) toOptionBytes(mc *protointernal.MessageContext, results []*interpretedOption) ([]byte, error) { - // protoc emits non-custom options in tag order and then - // the rest are emitted in the order they are defined in source - sort.SliceStable(results, func(i, j int) bool { - if !results[i].unknown && results[j].unknown { - return true - } - if !results[i].unknown && !results[j].unknown { - return results[i].number < results[j].number - } - return false - }) - var b []byte - for _, res := range results { - var err error - b, err = res.appendOptionBytes(b) - if err != nil { - if _, ok := err.(reporter.ErrorWithPos); !ok { - span := ast.UnknownSpan(interp.file.AST().Name()) - // TODO: this error is unusual - err = reporter.Errorf(span, "%sfailed to encode options: %w", mc, err) - } - if err := interp.handler.HandleError(err); err != nil { - return nil, err - } - } - } - return b, nil -} - func validateRecursive(msg protoreflect.Message, prefix string) error { flds := msg.Descriptor().Fields() var missingFields []string @@ -1283,7 +858,13 @@ func validateRecursive(msg protoreflect.Message, prefix string) error { // msg must be an options message. For nameIndex > 0, msg is a nested message inside of the // options message. The given pathPrefix is the path (sequence of field numbers and indices // with a FileDescriptorProto as the start) up to but not including the given nameIndex. -func (interp *interpreter) interpretField(mc *protointernal.MessageContext, msg protoreflect.Message, opt *descriptorpb.UninterpretedOption, nameIndex int, pathPrefix []int32) (*interpretedOption, error) { +func (interp *interpreter) interpretField( + mc *protointernal.MessageContext, + msg protoreflect.Message, + opt *descriptorpb.UninterpretedOption, + nameIndex int, + pathPrefix []int32, +) (*sourceinfo.OptionSourceInfo, error) { var fld protoreflect.FieldDescriptor nm := opt.GetName()[nameIndex] node := interp.file.OptionNamePartNode(nm) @@ -1309,6 +890,7 @@ func (interp *interpreter) interpretField(mc *protointernal.MessageContext, msg interp.descriptorIndex.UninterpretedNameDescriptorsToFieldDescriptors[nm] = fld interp.descriptorIndex.FieldReferenceNodesToFieldDescriptors[node] = fld interp.descriptorIndex.OptionsToFieldDescriptors[opt] = fld + pathPrefix = append(pathPrefix, int32(fld.Number())) if len(opt.GetName()) > nameIndex+1 { nextnm := opt.GetName()[nameIndex+1] @@ -1336,277 +918,194 @@ func (interp *interpreter) interpretField(mc *protointernal.MessageContext, msg msg.Set(fld, fldVal) } // recurse to set next part of name - return interp.interpretField(mc, fdm, opt, nameIndex+1, append(pathPrefix, int32(fld.Number()))) + return interp.interpretField(mc, fdm, opt, nameIndex+1, pathPrefix) } optNode := interp.file.OptionNode(opt) optValNode := optNode.GetVal() - var val interpretedFieldValue - var index int + + var srcInfo *sourceinfo.OptionSourceInfo var err error - if optValNode == nil { - // We don't have an AST, so we get the value from the uninterpreted option proto. - // It's okay that we don't populate index as it is used to populate source code info, - // which can't be done without an AST. - val, err = interp.setOptionFieldFromProto(mc, msg, fld, node, opt, optNode) + if optValNode.Value() == nil { + err = interp.setOptionFieldFromProto(mc, msg, fld, node, opt, optValNode) + srcInfoVal := newSrcInfo(pathPrefix, nil) + srcInfo = &srcInfoVal } else { - val, index, err = interp.setOptionField(mc, msg, fld, node, optValNode, false) - - // index enum value references used as extension values: - // option (foo) = SomeEnumValue; - if fld.Kind() == protoreflect.EnumKind { - interp.indexEnumValueRef(fld, optValNode) - } + srcInfo, err = interp.setOptionField(mc, msg, fld, node, optValNode, false, pathPrefix) } - // err is returned along with the interpretedOption, the caller decides - // what to do with it depending on its leniency setting - - for _, interpretedField := range val.msgVal { - if interpretedField.node == nil { - continue - } - fields := fld.Message().Fields() - fieldDesc := fields.ByNumber(protowire.Number(interpretedField.number)) - if fieldDesc == nil { - continue - } - // recursively index field references; this only happens once per top level - // message literal (if the field is a message literal) - interp.indexInterpretedFieldsRecursive(interpretedField, fieldDesc) - - // index enum value references used as field values: - // option (foo) = { - // bar: SomeEnumValue - // baz: SomeOtherEnumValue - // }; - if fieldDesc.Kind() == protoreflect.EnumKind { - switch v := interpretedField.node.(type) { - case *ast.MessageFieldNode: - interp.indexEnumValueRef(fieldDesc, v.Val) - } - } - } - - return &interpretedOption{ - pathPrefix: pathPrefix, - interpretedField: interpretedField{ - node: optNode, - number: int32(fld.Number()), - index: int32(index), - kind: fld.Kind(), - repeated: fld.Cardinality() == protoreflect.Repeated, - value: val, - // NB: don't set packed here in a top-level option - // (only values in message literals will be serialized - // in packed format) - }, - }, err -} - -func (interp *interpreter) indexEnumValueRef(fld protoreflect.FieldDescriptor, optValNode *ast.ValueNode) { - enumDesc := fld.Enum() - switch v := optValNode.Unwrap().(type) { - case *ast.IdentNode: - interp.descriptorIndex.EnumValueIdentNodesToEnumValueDescriptors[v] = enumDesc.Values().ByName(protoreflect.Name(v.AsIdentifier())) - } -} - -func (interp *interpreter) indexInterpretedFieldsRecursive(interpretedField *interpretedField, fieldDesc protoreflect.FieldDescriptor) { - if interpretedField.node == nil || fieldDesc == nil { - return - } - interp.descriptorIndex.FieldReferenceNodesToFieldDescriptors[interpretedField.node] = fieldDesc - for _, f := range interpretedField.value.msgVal { - interp.indexInterpretedFieldsRecursive(f, fieldDesc.Message().Fields().ByNumber(protowire.Number(f.number))) + if err != nil { + return nil, interp.handler.HandleError(err) } + return srcInfo, nil } // setOptionField sets the value for field fld in the given message msg to the value represented // by AST node val. The given name is the AST node that corresponds to the name of fld. On success, // it returns additional metadata about the field that was set. -func (interp *interpreter) setOptionField(mc *protointernal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, name ast.Node, val *ast.ValueNode, insideMsgLiteral bool) (interpretedFieldValue, int, error) { +func (interp *interpreter) setOptionField( + mc *protointernal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + name ast.Node, + val *ast.ValueNode, + insideMsgLiteral bool, + pathPrefix []int32, +) (*sourceinfo.OptionSourceInfo, error) { v := val.Value() if sl, ok := v.([]*ast.ValueNode); ok { // handle slices a little differently than the others if fld.Cardinality() != protoreflect.Repeated { - return interpretedFieldValue{}, 0, interp.HandleOptionForbiddenErrorf(mc, val, "value is an array but field is not repeated") + return nil, reporter.Errorf(interp.nodeInfo(val), "%vvalue is an array but field is not repeated", mc) } origPath := mc.OptAggPath defer func() { mc.OptAggPath = origPath }() - var resVal listValue - var resMsgVals [][]*interpretedField + childVals := make([]sourceinfo.OptionSourceInfo, len(sl)) var firstIndex int + if fld.IsMap() { + firstIndex = msg.Get(fld).Map().Len() + } else { + firstIndex = msg.Get(fld).List().Len() + } for index, item := range sl { mc.OptAggPath = fmt.Sprintf("%s[%d]", origPath, index) - value, err := interp.fieldValue(mc, msg, fld, item, insideMsgLiteral) + value, srcInfo, err := interp.fieldValue(mc, msg, fld, item, insideMsgLiteral, append(pathPrefix, int32(firstIndex+index))) if err != nil { - return interpretedFieldValue{}, 0, err + return nil, err } if fld.IsMap() { mv := msg.Mutable(fld).Map() - if index == 0 { - firstIndex = mv.Len() - } - setMapEntry(fld, msg, mv, &value) + setMapEntry(fld, msg, mv, value.Message()) } else { lv := msg.Mutable(fld).List() - if index == 0 { - firstIndex = lv.Len() - } - lv.Append(value.val) - } - resVal = append(resVal, value.val) - if value.msgVal != nil { - resMsgVals = append(resMsgVals, value.msgVal) + lv.Append(value) } + childVals[index] = srcInfo } - return interpretedFieldValue{ - isList: true, - val: protoreflect.ValueOfList(&resVal), - msgListVal: resMsgVals, - }, firstIndex, nil + srcInfo := newSrcInfo(append(pathPrefix, int32(firstIndex)), &sourceinfo.ArrayLiteralSourceInfo{Elements: childVals}) + return &srcInfo, nil } - value, err := interp.fieldValue(mc, msg, fld, val, insideMsgLiteral) + if fld.IsMap() { + pathPrefix = append(pathPrefix, int32(msg.Get(fld).Map().Len())) + } else if fld.IsList() { + pathPrefix = append(pathPrefix, int32(msg.Get(fld).List().Len())) + } + + value, srcInfo, err := interp.fieldValue(mc, msg, fld, val, insideMsgLiteral, pathPrefix) if err != nil { - return interpretedFieldValue{}, 0, err + return nil, err } - if !value.val.IsValid() { - return interpretedFieldValue{}, 0, interp.HandleOptionValueErrorf(mc, val, "invalid value") + if !value.IsValid() { + return nil, interp.HandleOptionValueErrorf(mc, val, "invalid value") } if ood := fld.ContainingOneof(); ood != nil { existingFld := msg.WhichOneof(ood) if existingFld != nil && existingFld.Number() != fld.Number() { - return interpretedFieldValue{}, 0, interp.HandleOptionForbiddenErrorf(mc, name, "oneof %q already has field %q set", ood.Name(), fieldName(existingFld)) + return nil, reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) } } - var index int switch { case fld.IsMap(): mv := msg.Mutable(fld).Map() - index = mv.Len() - setMapEntry(fld, msg, mv, &value) + setMapEntry(fld, msg, mv, value.Message()) case fld.IsList(): lv := msg.Mutable(fld).List() - index = lv.Len() - lv.Append(value.val) + lv.Append(value) default: if msg.Has(fld) { - return interpretedFieldValue{}, 0, interp.HandleOptionForbiddenErrorf(mc, name, "non-repeated option field %s already set", fieldName(fld)) + return nil, reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) } - msg.Set(fld, value.val) + msg.Set(fld, value) } - return value, index, nil + return &srcInfo, nil } // setOptionFieldFromProto sets the value for field fld in the given message msg to the value // represented by the given uninterpreted option. The given ast.Node, if non-nil, will be used // to report source positions in error messages. On success, it returns additional metadata // about the field that was set. -func (interp *interpreter) setOptionFieldFromProto(mc *protointernal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, name ast.Node, opt *descriptorpb.UninterpretedOption, node ast.Node) (interpretedFieldValue, error) { +func (interp *interpreter) setOptionFieldFromProto( + mc *protointernal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + name ast.Node, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) error { k := fld.Kind() - var value interpretedFieldValue + var value protoreflect.Value switch k { case protoreflect.EnumKind: num, _, err := interp.enumFieldValueFromProto(mc, fld.Enum(), opt, node) if err != nil { - return interpretedFieldValue{}, err + return err } - value = interpretedFieldValue{val: protoreflect.ValueOfEnum(num)} + value = protoreflect.ValueOfEnum(num) case protoreflect.MessageKind, protoreflect.GroupKind: if opt.AggregateValue == nil { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, node, "expecting message, got %s", optionValueKind(opt)) + return reporter.Errorf(interp.nodeInfo(node), "%vexpecting message, got %s", mc, optionValueKind(opt)) } // We must parse the text format from the aggregate value string - fmd := fld.Message() - tmpMsg := dynamicpb.NewMessage(fmd) + var elem protoreflect.Message + switch { + case fld.IsMap(): + elem = dynamicpb.NewMessage(fld.Message()) + case fld.IsList(): + elem = msg.Get(fld).List().NewElement().Message() + default: + elem = msg.NewField(fld).Message() + } err := prototext.UnmarshalOptions{ Resolver: &msgLiteralResolver{interp: interp, pkg: fld.ParentFile().Package()}, AllowPartial: true, - }.Unmarshal([]byte(opt.GetAggregateValue()), tmpMsg) + }.Unmarshal([]byte(opt.GetAggregateValue()), elem.Interface()) if err != nil { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, node, "failed to parse message literal: %w", err) + return reporter.Errorf(interp.nodeInfo(node), "%vfailed to parse message literal %w", mc, err) } - msgData, err := proto.MarshalOptions{ - AllowPartial: true, - }.Marshal(tmpMsg) - if err != nil { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, node, "failed to serialize data from message literal: %w", err) - } - var data []byte - if k == protoreflect.GroupKind { - data = protowire.AppendTag(data, fld.Number(), protowire.StartGroupType) - data = append(data, msgData...) - data = protowire.AppendTag(data, fld.Number(), protowire.EndGroupType) - } else { - data = protowire.AppendTag(nil, fld.Number(), protowire.BytesType) - data = protowire.AppendBytes(data, msgData) - } - // NB: At this point, the serialized fields may no longer be in the same - // order as in the text literal. So for this case, the linker result's - // CanonicalProto won't be in *quite* the right order. ¯\_(ツ)_/¯ - value = interpretedFieldValue{preserialized: data} - + value = protoreflect.ValueOfMessage(elem) default: v, err := interp.scalarFieldValueFromProto(mc, descriptorpb.FieldDescriptorProto_Type(k), opt, node) if err != nil { - return interpretedFieldValue{}, err + return err } - value = interpretedFieldValue{val: protoreflect.ValueOf(v)} + value = protoreflect.ValueOf(v) } if ood := fld.ContainingOneof(); ood != nil { existingFld := msg.WhichOneof(ood) if existingFld != nil && existingFld.Number() != fld.Number() { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, name, "oneof %q already has field %q set", ood.Name(), fieldName(existingFld)) + return reporter.Errorf(interp.nodeInfo(name), "%voneof %q already has field %q set", mc, ood.Name(), fieldName(existingFld)) } } switch { - case value.preserialized != nil: - if !fld.IsList() && !fld.IsMap() && msg.Has(fld) { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, name, "non-repeated option field %s already set", fieldName(fld)) - } - // We have to merge the bytes for this field into the message. - // TODO: if a map field, error if key for this entry already set? - err := proto.UnmarshalOptions{ - Resolver: &msgLiteralResolver{interp: interp, pkg: fld.ParentFile().Package()}, - AllowPartial: true, - Merge: true, - }.Unmarshal(value.preserialized, msg.Interface()) - if err != nil { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, name, "failed to set value for field %v: %w", fieldName(fld), err) - } + case fld.IsMap(): + mv := msg.Mutable(fld).Map() + setMapEntry(fld, msg, mv, value.Message()) case fld.IsList(): - msg.Mutable(fld).List().Append(value.val) + msg.Mutable(fld).List().Append(value) default: if msg.Has(fld) { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, name, "non-repeated option field %s already set", fieldName(fld)) + return reporter.Errorf(interp.nodeInfo(name), "%vnon-repeated option field %s already set", mc, fieldName(fld)) } - msg.Set(fld, value.val) + msg.Set(fld, value) } - return value, nil + return nil } -func setMapEntry(fld protoreflect.FieldDescriptor, msg protoreflect.Message, mapVal protoreflect.Map, value *interpretedFieldValue) { - entry := value.val.Message() +func setMapEntry( + fld protoreflect.FieldDescriptor, + msg protoreflect.Message, + mapVal protoreflect.Map, + entry protoreflect.Message, +) { keyFld, valFld := fld.MapKey(), fld.MapValue() - // if an entry is missing a key or value, we add in an explicit - // zero value to msgVals to match protoc (which also odds these - // in even if not present in source) - if !entry.Has(keyFld) { - // put key before value - value.msgVal = append(append(([]*interpretedField)(nil), zeroValue(keyFld)), value.msgVal...) - } - if !entry.Has(valFld) { - value.msgVal = append(value.msgVal, zeroValue(valFld)) - } key := entry.Get(keyFld) val := entry.Get(valFld) if fld.MapValue().Kind() == protoreflect.MessageKind { @@ -1635,84 +1134,6 @@ func setMapEntry(fld protoreflect.FieldDescriptor, msg protoreflect.Message, map mapVal.Set(key.MapKey(), val) } -// zeroValue returns the zero value for the field types as a *interpretedField. -// The given fld must NOT be a repeated field. -func zeroValue(fld protoreflect.FieldDescriptor) *interpretedField { - var val protoreflect.Value - var msgVal []*interpretedField - switch fld.Kind() { - case protoreflect.MessageKind, protoreflect.GroupKind: - // needs to be non-nil, but empty - msgVal = []*interpretedField{} - msg := dynamicpb.NewMessage(fld.Message()) - val = protoreflect.ValueOfMessage(msg) - case protoreflect.EnumKind: - val = protoreflect.ValueOfEnum(0) - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: - val = protoreflect.ValueOfInt32(0) - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: - val = protoreflect.ValueOfUint32(0) - case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - val = protoreflect.ValueOfInt64(0) - case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - val = protoreflect.ValueOfUint64(0) - case protoreflect.BoolKind: - val = protoreflect.ValueOfBool(false) - case protoreflect.FloatKind: - val = protoreflect.ValueOfFloat32(0) - case protoreflect.DoubleKind: - val = protoreflect.ValueOfFloat64(0) - case protoreflect.BytesKind: - val = protoreflect.ValueOfBytes(nil) - case protoreflect.StringKind: - val = protoreflect.ValueOfString("") - } - return &interpretedField{ - number: int32(fld.Number()), - kind: fld.Kind(), - value: interpretedFieldValue{ - val: val, - msgVal: msgVal, - }, - } -} - -type listValue []protoreflect.Value - -var _ protoreflect.List = (*listValue)(nil) - -func (l listValue) Len() int { - return len(l) -} - -func (l listValue) Get(i int) protoreflect.Value { - return l[i] -} - -func (l listValue) Set(i int, value protoreflect.Value) { - l[i] = value -} - -func (l *listValue) Append(value protoreflect.Value) { - *l = append(*l, value) -} - -func (l listValue) AppendMutable() protoreflect.Value { - panic("AppendMutable not supported") -} - -func (l *listValue) Truncate(i int) { - *l = (*l)[:i] -} - -func (l listValue) NewElement() protoreflect.Value { - panic("NewElement not supported") -} - -func (l listValue) IsValid() bool { - return true -} - type msgLiteralResolver struct { interp *interpreter pkg protoreflect.FullName @@ -1815,15 +1236,22 @@ func optionValueKind(opt *descriptorpb.UninterpretedOption) string { // fieldValue computes a compile-time value (constant or list or message literal) for the given // AST node val. The value in val must be assignable to the field fld. -func (interp *interpreter) fieldValue(mc *protointernal.MessageContext, msg protoreflect.Message, fld protoreflect.FieldDescriptor, val *ast.ValueNode, insideMsgLiteral bool) (interpretedFieldValue, error) { +func (interp *interpreter) fieldValue( + mc *protointernal.MessageContext, + msg protoreflect.Message, + fld protoreflect.FieldDescriptor, + val *ast.ValueNode, + insideMsgLiteral bool, + pathPrefix []int32, +) (protoreflect.Value, sourceinfo.OptionSourceInfo, error) { k := fld.Kind() switch k { case protoreflect.EnumKind: num, _, err := interp.enumFieldValue(mc, fld.Enum(), val, insideMsgLiteral) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - return interpretedFieldValue{val: protoreflect.ValueOfEnum(num)}, nil + return protoreflect.ValueOfEnum(num), newSrcInfo(pathPrefix, nil), nil case protoreflect.MessageKind, protoreflect.GroupKind: v := val.Value() @@ -1841,22 +1269,28 @@ func (interp *interpreter) fieldValue(mc *protointernal.MessageContext, msg prot // Normal message field childMsg = msg.NewField(fld).Message() } - return interp.messageLiteralValue(mc, aggs, childMsg) + return interp.messageLiteralValue(mc, aggs, childMsg, pathPrefix) } - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, val, "expecting message, got %s", valueKind(v)) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(val), "%vexpecting message, got %s", mc, valueKind(v)) default: v, err := interp.scalarFieldValue(mc, descriptorpb.FieldDescriptorProto_Type(k), val, insideMsgLiteral) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - return interpretedFieldValue{val: protoreflect.ValueOf(v)}, nil + return protoreflect.ValueOf(v), newSrcInfo(pathPrefix, nil), nil } } // enumFieldValue resolves the given AST node val as an enum value descriptor. If the given // value is not a valid identifier (or number if allowed), an error is returned instead. -func (interp *interpreter) enumFieldValue(mc *protointernal.MessageContext, ed protoreflect.EnumDescriptor, val *ast.ValueNode, allowNumber bool) (protoreflect.EnumNumber, protoreflect.Name, error) { +func (interp *interpreter) enumFieldValue( + mc *protointernal.MessageContext, + ed protoreflect.EnumDescriptor, + val *ast.ValueNode, + allowNumber bool, +) (protoreflect.EnumNumber, protoreflect.Name, error) { v := val.Value() var num protoreflect.EnumNumber switch v := v.(type) { @@ -1890,40 +1324,21 @@ func (interp *interpreter) enumFieldValue(mc *protointernal.MessageContext, ed p if ev != nil { return num, ev.Name(), nil } - // NB: We have to look at the syntax instead of directly using ed.IsClosed because we - // may still be interpreting options that would be used by the implementation of IsClosed. - if ed.Syntax() != protoreflect.Proto3 { - if ed.Syntax() == protoreflect.Editions && interp.file == any(ed.ParentFile()) { - // Oof. We are still interpreting options for this file. Yet we need - // options in order to decide if this enum value is allowed (only if - // the enum is open). - // - // So, for now, we will assume it's valid and then report an error - // later, after all options in the file are interpreted, to check if - // it was allowed. - interp.needOpenEnum(mc, ed, val, num) - return num, "", nil - } - return num, "", interp.HandleOptionValueErrorf(mc, val, "closed enum %s has no value with number %d", ed.FullName(), num) + if ed.IsClosed() { + return num, "", reporter.Errorf(interp.nodeInfo(val), "%vclosed enum %s has no value with number %d", mc, ed.FullName(), num) } // unknown value, but enum is open, so we allow it and return blank name return num, "", nil } -func (interp *interpreter) needOpenEnum(mc *protointernal.MessageContext, ed protoreflect.EnumDescriptor, val ast.Node, number protoreflect.EnumNumber) { - if interp.neededOpenEnums == nil { - interp.neededOpenEnums = map[protoreflect.EnumDescriptor][]enumValRef{} - } - interp.neededOpenEnums[ed] = append(interp.neededOpenEnums[ed], enumValRef{ - number: number, - value: val, - mc: mc, - }) -} - // enumFieldValueFromProto resolves the given uninterpreted option value as an enum value descriptor. // If the given value is not a valid identifier, an error is returned instead. -func (interp *interpreter) enumFieldValueFromProto(mc *protointernal.MessageContext, ed protoreflect.EnumDescriptor, opt *descriptorpb.UninterpretedOption, node ast.Node) (protoreflect.EnumNumber, protoreflect.Name, error) { +func (interp *interpreter) enumFieldValueFromProto( + mc *protointernal.MessageContext, + ed protoreflect.EnumDescriptor, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) (protoreflect.EnumNumber, protoreflect.Name, error) { // We don't have to worry about allowing numbers because numbers are never allowed // in uninterpreted values; they are only allowed inside aggregate values (i.e. // message literals). @@ -1940,9 +1355,22 @@ func (interp *interpreter) enumFieldValueFromProto(mc *protointernal.MessageCont } } +func (interp *interpreter) indexEnumValueRef(fld protoreflect.FieldDescriptor, optValNode *ast.ValueNode) { + enumDesc := fld.Enum() + switch v := optValNode.Unwrap().(type) { + case *ast.IdentNode: + interp.descriptorIndex.EnumValueIdentNodesToEnumValueDescriptors[v] = enumDesc.Values().ByName(protoreflect.Name(v.AsIdentifier())) + } +} + // scalarFieldValue resolves the given AST node val as a value whose type is assignable to a // field with the given fldType. -func (interp *interpreter) scalarFieldValue(mc *protointernal.MessageContext, fldType descriptorpb.FieldDescriptorProto_Type, val *ast.ValueNode, insideMsgLiteral bool) (interface{}, error) { +func (interp *interpreter) scalarFieldValue( + mc *protointernal.MessageContext, + fldType descriptorpb.FieldDescriptorProto_Type, + val *ast.ValueNode, + insideMsgLiteral bool, +) (interface{}, error) { v := val.Value() switch fldType { case descriptorpb.FieldDescriptorProto_TYPE_BOOL: @@ -2075,7 +1503,12 @@ func (interp *interpreter) scalarFieldValue(mc *protointernal.MessageContext, fl // scalarFieldValue resolves the given uninterpreted option value as a value whose type is // assignable to a field with the given fldType. -func (interp *interpreter) scalarFieldValueFromProto(mc *protointernal.MessageContext, fldType descriptorpb.FieldDescriptorProto_Type, opt *descriptorpb.UninterpretedOption, node ast.Node) (interface{}, error) { +func (interp *interpreter) scalarFieldValueFromProto( + mc *protointernal.MessageContext, + fldType descriptorpb.FieldDescriptorProto_Type, + opt *descriptorpb.UninterpretedOption, + node ast.Node, +) (interface{}, error) { switch fldType { case descriptorpb.FieldDescriptorProto_TYPE_BOOL: if opt.IdentifierValue != nil { @@ -2224,31 +1657,34 @@ func descriptorType(m proto.Message) string { } } -func (interp *interpreter) messageLiteralValue(mc *protointernal.MessageContext, fieldNodes []*ast.MessageFieldNode, msg protoreflect.Message) (interpretedFieldValue, error) { +func (interp *interpreter) messageLiteralValue( + mc *protointernal.MessageContext, + fieldNodes []*ast.MessageFieldNode, + msg protoreflect.Message, + pathPrefix []int32, +) (protoreflect.Value, sourceinfo.OptionSourceInfo, error) { fmd := msg.Descriptor() origPath := mc.OptAggPath defer func() { mc.OptAggPath = origPath }() - // NB: we don't want to leave this nil, even if the - // message is empty, because that indicates to - // caller that the result is not a message - flds := make([]*interpretedField, 0, len(fieldNodes)) - var foundAnyNode bool + flds := make(map[*ast.MessageFieldNode]*sourceinfo.OptionSourceInfo, len(fieldNodes)) for _, fieldNode := range fieldNodes { if origPath == "" { mc.OptAggPath = fieldNode.Name.Value() } else { mc.OptAggPath = origPath + "." + fieldNode.Name.Value() } - if fieldNode.Name.IsAnyTypeReference() && !fieldNode.IsIncomplete() { - if fmd.FullName() != "google.protobuf.Any" { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, fieldNode.Name.UrlPrefix, "type references are only allowed for google.protobuf.Any, but this type is %s", fmd.FullName()) + if fieldNode.Name.IsAnyTypeReference() { + if len(fieldNodes) > 1 { + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.UrlPrefix), "%vany type references cannot be repeated or mixed with other fields", mc) } - if foundAnyNode { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, fieldNode.Name.UrlPrefix, "multiple any type references are not allowed") + if fmd.FullName() != "google.protobuf.Any" { + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.UrlPrefix), "%vtype references are only allowed for google.protobuf.Any, but this type is %s", mc, fmd.FullName()) } - foundAnyNode = true + urlPrefix := fieldNode.Name.UrlPrefix.AsIdentifier() msgName := fieldNode.Name.Name.AsIdentifier() fullURL := fmt.Sprintf("%s/%s", urlPrefix, msgName) @@ -2257,43 +1693,51 @@ func (interp *interpreter) messageLiteralValue(mc *protointernal.MessageContext, // URLs into message descriptors. The default resolver would be // implemented as below, only accepting "type.googleapis.com" and // "type.googleprod.com" as hosts/prefixes and using the compiled - // file's transitive closure to find the named message. + // file's transitive closure to find the named message, since that + // is what protoc does. if urlPrefix != "type.googleapis.com" && urlPrefix != "type.googleprod.com" { - return interpretedFieldValue{}, interp.HandleOptionNotFoundErrorf(mc, fieldNode.Name.UrlPrefix, "could not resolve type reference %s", fullURL) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name.UrlPrefix), "%vcould not resolve type reference %s", mc, fullURL) } anyFields, ok := fieldNode.Val.Value().([]*ast.MessageFieldNode) if !ok { - return interpretedFieldValue{}, interp.HandleOptionForbiddenErrorf(mc, fieldNode.Val, "type references for google.protobuf.Any must have message literal value") + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vtype references for google.protobuf.Any must have message literal value", mc) } - anyMd := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, string(msgName)) + + anyMd := resolveDescriptor[protoreflect.MessageDescriptor](interp.resolver, protoreflect.FullName(msgName)) if anyMd == nil { - return interpretedFieldValue{}, interp.HandleOptionNotFoundErrorf(mc, fieldNode.Name.UrlPrefix, "could not resolve type reference %s", fullURL) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vcould not resolve type reference %s", mc, fullURL) } // parse the message value - msgVal, err := interp.messageLiteralValue(mc, anyFields, dynamicpb.NewMessage(anyMd)) + msgVal, valueSrcInfo, err := interp.messageLiteralValue(mc, anyFields, dynamicpb.NewMessage(anyMd), append(pathPrefix, protointernal.AnyValueTag)) if err != nil { - return interpretedFieldValue{}, err + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err } - // Any is defined with two fields: - // string type_url = 1 - // bytes value = 2 - typeURLDescriptor := fmd.Fields().ByNumber(1) + typeURLDescriptor := fmd.Fields().ByNumber(protointernal.AnyTypeURLTag) if typeURLDescriptor == nil || typeURLDescriptor.Kind() != protoreflect.StringKind { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, fieldNode.Name, "failed to set type_url string field on Any: %w", err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set type_url string field on Any: %w", mc, err) } - msg.Set(typeURLDescriptor, protoreflect.ValueOfString(fullURL)) - valueDescriptor := fmd.Fields().ByNumber(2) + typeURLVal := protoreflect.ValueOfString(fullURL) + msg.Set(typeURLDescriptor, typeURLVal) + valueDescriptor := fmd.Fields().ByNumber(protointernal.AnyValueTag) if valueDescriptor == nil || valueDescriptor.Kind() != protoreflect.BytesKind { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, fieldNode.Name, "failed to set value bytes field on Any: %w", err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfailed to set value bytes field on Any: %w", mc, err) } - b, err := proto.MarshalOptions{Deterministic: true}.Marshal(msgVal.val.Message().Interface()) + + b, err := (proto.MarshalOptions{Deterministic: true}).Marshal(msgVal.Message().Interface()) if err != nil { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, fieldNode.Val, "failed to serialize message value: %w", err) + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "%vfailed to serialize message value: %w", mc, err) } interp.descriptorIndex.TypeReferenceURLsToMessageDescriptors[fieldNode.Name] = anyMd msg.Set(valueDescriptor, protoreflect.ValueOfBytes(b)) + flds[fieldNode] = &valueSrcInfo } else { var ffld protoreflect.FieldDescriptor var err error @@ -2317,9 +1761,15 @@ func (interp *interpreter) messageLiteralValue(mc *protointernal.MessageContext, // Groups are indicated in the text format by the group name (which is // camel-case), NOT the field name (which is lower-case). // ...but only regular fields, not extensions that are groups... - if ffld != nil && ffld.Kind() == protoreflect.GroupKind && ffld.Message().Name() != protoreflect.Name(fieldNode.Name.Value()) { - // this is kind of silly to fail here, but this mimics protoc behavior - return interpretedFieldValue{}, interp.HandleOptionNotFoundErrorf(mc, fieldNode.Name, "field %s not found (did you mean the group named %s?)", fieldNode.Name.Value(), ffld.Message().Name()) + if ffld != nil && ffld.Kind() == protoreflect.GroupKind && + string(ffld.Name()) == strings.ToLower(string(ffld.Message().Name())) && + ffld.Message().Name() != protoreflect.Name(fieldNode.Name.Value()) { + // This is kind of silly to fail here, but this mimics protoc behavior. + // We only fail when this really looks like a group since we need to be + // able to use the field name for fields in editions files that use the + // delimited message encoding and don't use proto2 group naming. + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfield %s not found (did you mean the group named %s?)", mc, fieldNode.Name.Value(), ffld.Message().Name()) } if ffld == nil { err = protoregistry.NotFound @@ -2335,45 +1785,38 @@ func (interp *interpreter) messageLiteralValue(mc *protointernal.MessageContext, } } } - if err != nil { - return interpretedFieldValue{}, interp.HandleOptionNotFoundErrorf(mc, fieldNode.Name, "field %s not found", string(fieldNode.Name.Name.AsIdentifier())) - } - if fieldNode.IsIncomplete() { - // we can't save the incomplete field, but if we get this far, index the - // field descriptor so it can be queried from the incomplete field node. - interp.descriptorIndex.FieldReferenceNodesToFieldDescriptors[fieldNode] = ffld - continue + if errors.Is(err, protoregistry.NotFound) { + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Name), "%vfield %s not found", mc, string(fieldNode.Name.Name.AsIdentifier())) + } else if err != nil { + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Error(interp.nodeInfo(fieldNode.Name), err) } if fieldNode.Sep == nil && ffld.Message() == nil { // If there is no separator, the field type should be a message. - // Otherwise it is an error in the text format. - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, fieldNode, "unexpected non-message value (did you forget a ':'?)") + // Otherwise, it is an error in the text format. + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, + reporter.Errorf(interp.nodeInfo(fieldNode.Val), "syntax error: unexpected value, expecting ':'") } - res, index, err := interp.setOptionField(mc, msg, ffld, fieldNode.Name, fieldNode.Val, true) - + srcInfo, err := interp.setOptionField(mc, msg, ffld, fieldNode.Name, fieldNode.Val, true, append(pathPrefix, int32(ffld.Number()))) + if err != nil { + return protoreflect.Value{}, sourceinfo.OptionSourceInfo{}, err + } + interp.descriptorIndex.FieldReferenceNodesToFieldDescriptors[fieldNode] = ffld if ffld.Kind() == protoreflect.EnumKind { interp.indexEnumValueRef(ffld, fieldNode.Val) } - if err != nil { - return interpretedFieldValue{}, err - } - flds = append(flds, &interpretedField{ - node: fieldNode, - number: int32(ffld.Number()), - index: int32(index), - kind: ffld.Kind(), - repeated: ffld.Cardinality() == protoreflect.Repeated, - packed: ffld.IsPacked(), - value: res, - }) - } - } - val := protoreflect.ValueOfMessage(msg) - if !val.IsValid() { - return interpretedFieldValue{}, interp.HandleOptionValueErrorf(mc, fieldNodes[0], "could not resolve message literal") - } - return interpretedFieldValue{ - val: val, - msgVal: flds, - }, nil + flds[fieldNode] = srcInfo + } + } + return protoreflect.ValueOfMessage(msg), + newSrcInfo(pathPrefix, &sourceinfo.MessageLiteralSourceInfo{Fields: flds}), + nil +} + +func newSrcInfo(path []int32, children sourceinfo.OptionChildrenSourceInfo) sourceinfo.OptionSourceInfo { + return sourceinfo.OptionSourceInfo{ + Path: protointernal.ClonePath(path), + Children: children, + } } diff --git a/options/options_test.go b/options/options_test.go index 5e25f761..90d2886e 100644 --- a/options/options_test.go +++ b/options/options_test.go @@ -25,24 +25,31 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/descriptorpb" "github.com/kralicky/protocompile" "github.com/kralicky/protocompile/linker" "github.com/kralicky/protocompile/options" "github.com/kralicky/protocompile/parser" + "github.com/kralicky/protocompile/protointernal" "github.com/kralicky/protocompile/protointernal/prototest" + "github.com/kralicky/protocompile/protoutil" "github.com/kralicky/protocompile/reporter" ) +func TestMain(m *testing.M) { + // Enable just for tests. + protointernal.AllowEditions = true + status := m.Run() + os.Exit(status) +} + type ( ident string aggregate string @@ -347,13 +354,15 @@ func qualify(qualifier, name string) string { func TestOptionsEncoding(t *testing.T) { t.Parallel() testCases := map[string]string{ - "proto2": "options/options2/test.proto", - "proto3": "options/test_proto3.proto", - "defaults": "desc_test_defaults.proto", + "proto2-opts-same-file": "options/options.proto", + "proto2": "options/options2/test.proto", + "proto3": "options/test_proto3.proto", + "editions": "options/test_editions.proto", + "defaults": "desc_test_defaults.proto", } - for syntax, file := range testCases { + for testCaseName, file := range testCases { file := file // must not capture loop variable below, for thread safety - t.Run(syntax, func(t *testing.T) { + t.Run(testCaseName, func(t *testing.T) { t.Parallel() fileToCompile := strings.TrimPrefix(file, "options/") compiler := protocompile.Compiler{ @@ -378,42 +387,33 @@ func TestOptionsEncoding(t *testing.T) { fdset := prototest.LoadDescriptorSet(t, descriptorSetFile, linker.ResolverFromFile(fds.Files[0])) prototest.CheckFiles(t, res, fdset, false) - canonicalProto := res.CanonicalProto() actualFdset := &descriptorpb.FileDescriptorSet{ - File: []*descriptorpb.FileDescriptorProto{canonicalProto}, + File: []*descriptorpb.FileDescriptorProto{protoutil.ProtoFromFileDescriptor(res)}, } - actualData, err := proto.Marshal(actualFdset) - require.NoError(t, err) - // semantic check that unmarshalling the "canonical bytes" results - // in the same proto as when not using "canonical bytes" - protoData, err := proto.Marshal(canonicalProto) + // drum roll... make sure the descriptors we produce are semantically equivalent + // to those produced by protoc + expectedData, err := os.ReadFile(descriptorSetFile) require.NoError(t, err) - proto.Reset(canonicalProto) + expectedFdset := &descriptorpb.FileDescriptorSet{} uOpts := proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(fds.Files[0])} - err = uOpts.Unmarshal(protoData, canonicalProto) - require.NoError(t, err) - if !proto.Equal(res.FileDescriptorProto(), canonicalProto) { - t.Fatal("canonical proto != proto") - } - - // drum roll... make sure the bytes match the protoc output - expectedData, err := os.ReadFile(descriptorSetFile) + err = uOpts.Unmarshal(expectedData, expectedFdset) require.NoError(t, err) - if !bytes.Equal(actualData, expectedData) { + if !prototest.AssertMessagesEqual(t, expectedFdset, actualFdset, file) { outputDescriptorSetFile := strings.ReplaceAll(descriptorSetFile, ".proto", ".actual.proto") + actualData, err := proto.Marshal(actualFdset) + require.NoError(t, err) err = os.WriteFile(outputDescriptorSetFile, actualData, 0o644) if err != nil { - t.Log("failed to write actual to file") + t.Logf("failed to write actual to file: %v", err) + } else { + t.Logf("wrote actual contents to %s", outputDescriptorSetFile) } - - t.Fatalf("descriptor set bytes not equal (created file %q with actual bytes)", outputDescriptorSetFile) } }) } } -//nolint:errcheck func TestInterpretOptionsWithoutAST(t *testing.T) { t.Parallel() @@ -445,6 +445,7 @@ func TestInterpretOptionsWithoutAST(t *testing.T) { return res, err } res.Proto = parseResult.FileDescriptorProto() + res.ResolvedPath = protocompile.ResolvedPath(name) return res, nil }, )), @@ -458,22 +459,7 @@ func TestInterpretOptionsWithoutAST(t *testing.T) { fd := file.(linker.Result).FileDescriptorProto() fdFromNoAST := fromNoAST.(linker.Result).FileDescriptorProto() // final protos, with options interpreted, match - diff := cmp.Diff(fd, fdFromNoAST, protocmp.Transform()) - require.Empty(t, diff) - } - - // Also make sure the canonical bytes are correct - for _, file := range filesFromNoAST.Files { - res := file.(linker.Result) - canonicalFd := res.CanonicalProto() - data, err := proto.Marshal(canonicalFd) - require.NoError(t, err) - fromCanonical := &descriptorpb.FileDescriptorProto{} - err = proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(file)}.Unmarshal(data, fromCanonical) - require.NoError(t, err) - origFd := res.FileDescriptorProto() - diff := cmp.Diff(origFd, fromCanonical, protocmp.Transform()) - require.Empty(t, diff) + prototest.AssertMessagesEqual(t, fd, fdFromNoAST, file.Path()) } } @@ -523,21 +509,6 @@ func TestInterpretOptionsWithoutASTNoOp(t *testing.T) { fd := file.(linker.Result).FileDescriptorProto() fdFromNoAST := fromNoAST.(linker.Result).FileDescriptorProto() // final protos, with options interpreted, match - diff := cmp.Diff(fd, fdFromNoAST, protocmp.Transform()) - require.Empty(t, diff) - } - - // Also make sure the canonical bytes are correct - for _, file := range resultsFromNoAST.Files { - res := file.(linker.Result) - canonicalFd := res.CanonicalProto() - data, err := proto.Marshal(canonicalFd) - require.NoError(t, err) - fromCanonical := &descriptorpb.FileDescriptorProto{} - err = proto.UnmarshalOptions{Resolver: linker.ResolverFromFile(file)}.Unmarshal(data, fromCanonical) - require.NoError(t, err) - origFd := res.FileDescriptorProto() - diff := cmp.Diff(origFd, fromCanonical, protocmp.Transform()) - require.Empty(t, diff) + prototest.AssertMessagesEqual(t, fd, fdFromNoAST, file.Path()) } } diff --git a/protointernal/prototest/util.go b/protointernal/prototest/util.go index b7117b75..1d2ce47c 100644 --- a/protointernal/prototest/util.go +++ b/protointernal/prototest/util.go @@ -15,11 +15,11 @@ package prototest import ( - "fmt" "os" "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" @@ -73,26 +73,11 @@ func findFileInSet(fps *descriptorpb.FileDescriptorSet, name string) *descriptor return nil } -func AssertMessagesEqual(t *testing.T, exp, act proto.Message, msgAndArgs ...interface{}) { +func AssertMessagesEqual(t *testing.T, exp, act proto.Message, description string) bool { t.Helper() - AssertMessagesEqualWithOptions(t, exp, act, nil, msgAndArgs...) -} - -func AssertMessagesEqualWithOptions(t *testing.T, exp, act proto.Message, opts []cmp.Option, msgAndArgs ...interface{}) { - t.Helper() - cmpOpts := []cmp.Option{protocmp.Transform()} - cmpOpts = append(cmpOpts, opts...) - if diff := cmp.Diff(exp, act, cmpOpts...); diff != "" { - var prefix string - if len(msgAndArgs) == 1 { - if msg, ok := msgAndArgs[0].(string); ok { - prefix = msg + ": " - } else { - prefix = fmt.Sprintf("%+v: ", msgAndArgs[0]) - } - } else if len(msgAndArgs) > 1 { - prefix = fmt.Sprintf(msgAndArgs[0].(string)+": ", msgAndArgs[1:]...) - } - t.Errorf("%smessage mismatch (-want +got):\n%v", prefix, diff) + if diff := cmp.Diff(exp, act, protocmp.Transform(), cmpopts.EquateNaNs()); diff != "" { + t.Errorf("%s: message mismatch (-want, +got):\n%s", description, diff) + return false } + return true } diff --git a/protointernal/tags.go b/protointernal/tags.go index 2c7a5f6d..74002953 100644 --- a/protointernal/tags.go +++ b/protointernal/tags.go @@ -243,4 +243,9 @@ const ( // UninterpretedNameNameTag is the tag number of the name element in an // uninterpreted option name proto. UninterpretedNameNameTag = 1 + + // AnyTypeURLTag is the tag number of the type_url field of the Any proto. + AnyTypeURLTag = 1 + // AnyValueTag is the tag number of the value field of the Any proto. + AnyValueTag = 2 ) diff --git a/protointernal/util.go b/protointernal/util.go index 4a677a4a..7902a60f 100644 --- a/protointernal/util.go +++ b/protointernal/util.go @@ -231,3 +231,9 @@ func CanPack(k protoreflect.Kind) bool { return true } } + +func ClonePath(path protoreflect.SourcePath) protoreflect.SourcePath { + clone := make(protoreflect.SourcePath, len(path)) + copy(clone, path) + return clone +} diff --git a/sourceinfo/source_code_info.go b/sourceinfo/source_code_info.go index 07c8bfea..2e2be1b6 100644 --- a/sourceinfo/source_code_info.go +++ b/sourceinfo/source_code_info.go @@ -65,6 +65,11 @@ type OptionSourceInfo struct { // The source info path to this element. If this element represents a // declaration with an array-literal value, the last element of the // path is the index of the first item in the array. + // If the first element is negative, it indicates the number of path + // components to remove from the path to the relevant options. This is + // used for field pseudo-options, so that the path indicates a field on + // the descriptor, which is a parent of the options message (since that + // is how the pseudo-options are actually stored). Path []int32 // Children can be an *ArrayLiteralSourceInfo, a *MessageLiteralSourceInfo, // or nil, depending on whether the option's value is an @@ -179,7 +184,7 @@ func generateSourceInfoForFile(opts OptionIndex, sci *sourceCodeInfo) { if sci.protocCompatMode { proto.GetExtension(sci.file, ast.E_FileInfo).(*ast.FileInfo).PositionEncoding = ast.FileInfo_PositionEncodingProtocCompatible } - path := make([]int32, 0, 10) + path := make([]int32, 0, 16) sci.newLocWithoutComments(sci.file, nil) if sci.file.Syntax != nil { @@ -214,7 +219,10 @@ func generateSourceInfoForFile(opts OptionIndex, sci *sourceCodeInfo) { generateSourceCodeInfoForEnum(opts, sci, child, append(path, protointernal.FileEnumsTag, enumIndex)) enumIndex++ case *ast.ExtendNode: - generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &msgIndex, append(path, protointernal.FileExtensionsTag), append(dup(path), protointernal.FileMessagesTag)) + extsPath := append(path, protointernal.FileExtensionsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here so that append can't mutate extsPath, since they may share storage + msgsPath := append(protointernal.ClonePath(path), protointernal.FileMessagesTag) + generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &msgIndex, extsPath, msgsPath) case *ast.ServiceNode: generateSourceCodeInfoForService(opts, sci, child, append(path, protointernal.FileServicesTag, svcIndex)) svcIndex++ @@ -305,6 +313,18 @@ func generateSourceInfoForOptionChildren(sci *sourceCodeInfo, n *ast.ValueNode, continue } fullPath := combinePathsForOption(pathPrefix, fieldInfo.Path) + locationNode := ast.Node(fieldNode) + if fieldNode.Name.IsAnyTypeReference() && fullPath[len(fullPath)-1] == protointernal.AnyValueTag { + // This is a special expanded Any. So also insert a location + // for the type URL field. + typeURLPath := make([]int32, len(fullPath)) + copy(typeURLPath, fullPath) + typeURLPath[len(typeURLPath)-1] = protointernal.AnyTypeURLTag + sci.newLoc(fieldNode.Name, fullPath) + // And create the next location so it's just the value, + // not the full field definition. + locationNode = fieldNode.Val + } arrayLiteralVal := fieldNode.GetVal().GetArrayLiteral() if arrayLiteralVal != nil { // We don't include this with an array literal since the path @@ -312,7 +332,7 @@ func generateSourceInfoForOptionChildren(sci *sourceCodeInfo, n *ast.ValueNode, // it would be redundant with the child info we add next, and // it wouldn't be entirely correct since it only indicates the // index of the first element in the array (and not the others). - sci.newLoc(fieldNode, fullPath) + sci.newLoc(locationNode, fullPath) } generateSourceInfoForOptionChildren(sci, fieldNode.Val, pathPrefix, fullPath, fieldInfo.Children) } @@ -365,18 +385,24 @@ func generateSourceCodeInfoForMessage(opts OptionIndex, sci *sourceCodeInfo, n a generateSourceCodeInfoForField(opts, sci, child, append(path, protointernal.MessageFieldsTag, fieldIndex)) fieldIndex++ case *ast.GroupNode: - fldPath := path - fldPath = append(fldPath, protointernal.MessageFieldsTag, fieldIndex) + fldPath := append(path, protointernal.MessageFieldsTag, fieldIndex) //nolint:gocritic // intentionally creating new slice var generateSourceCodeInfoForField(opts, sci, child, fldPath) fieldIndex++ - generateSourceCodeInfoForMessage(opts, sci, child, fldPath, append(dup(path), protointernal.MessageNestedMessagesTag, nestedMsgIndex)) + // we clone the path here so that append can't mutate fldPath, since they may share storage + msgPath := append(protointernal.ClonePath(path), protointernal.MessageNestedMessagesTag, nestedMsgIndex) + generateSourceCodeInfoForMessage(opts, sci, child, fldPath, msgPath) nestedMsgIndex++ case *ast.MapFieldNode: generateSourceCodeInfoForField(opts, sci, child, append(path, protointernal.MessageFieldsTag, fieldIndex)) fieldIndex++ nestedMsgIndex++ case *ast.OneofNode: - generateSourceCodeInfoForOneof(opts, sci, child, &fieldIndex, &nestedMsgIndex, append(path, protointernal.MessageFieldsTag), append(dup(path), protointernal.MessageNestedMessagesTag), append(dup(path), protointernal.MessageOneofsTag, oneofIndex)) + fldsPath := append(path, protointernal.MessageFieldsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here and below so that append ops can't mutate + // fldPath or msgsPath, since they may otherwise share storage + msgsPath := append(protointernal.ClonePath(path), protointernal.MessageNestedMessagesTag) + ooPath := append(protointernal.ClonePath(path), protointernal.MessageOneofsTag, oneofIndex) + generateSourceCodeInfoForOneof(opts, sci, child, &fieldIndex, &nestedMsgIndex, fldsPath, msgsPath, ooPath) oneofIndex++ case *ast.MessageNode: generateSourceCodeInfoForMessage(opts, sci, child, nil, append(path, protointernal.MessageNestedMessagesTag, nestedMsgIndex)) @@ -385,7 +411,10 @@ func generateSourceCodeInfoForMessage(opts OptionIndex, sci *sourceCodeInfo, n a generateSourceCodeInfoForEnum(opts, sci, child, append(path, protointernal.MessageEnumsTag, nestedEnumIndex)) nestedEnumIndex++ case *ast.ExtendNode: - generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &nestedMsgIndex, append(path, protointernal.MessageExtensionsTag), append(dup(path), protointernal.MessageNestedMessagesTag)) + extsPath := append(path, protointernal.MessageExtensionsTag) //nolint:gocritic // intentionally creating new slice var + // we clone the path here so that append can't mutate extsPath, since they may share storage + msgsPath := append(protointernal.ClonePath(path), protointernal.MessageNestedMessagesTag) + generateSourceCodeInfoForExtensions(opts, sci, child, &extendIndex, &nestedMsgIndex, extsPath, msgsPath) case *ast.ExtensionRangeNode: generateSourceCodeInfoForExtensionRanges(opts, sci, child, &extRangeIndex, append(path, protointernal.MessageExtensionRangesTag)) case *ast.ReservedNode: @@ -665,8 +694,6 @@ type sourceCodeInfo struct { } func (sci *sourceCodeInfo) newLocWithoutComments(n ast.Node, path []int32) { - dup := make([]int32, len(path)) - copy(dup, path) var start, end ast.SourcePos if n == sci.file { // For files, we don't want to consider trailing EOF token @@ -686,7 +713,7 @@ func (sci *sourceCodeInfo) newLocWithoutComments(n ast.Node, path []int32) { start, end = info.Start(), info.End() } sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ - Path: dup, + Path: protointernal.ClonePath(path), Span: makeSpan(start, end), }) } @@ -697,11 +724,9 @@ func (sci *sourceCodeInfo) newLoc(n ast.Node, path []int32) { } info := sci.file.NodeInfo(n) if !sci.extraComments { - dup := make([]int32, len(path)) - copy(dup, path) start, end := info.Start(), info.End() sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ - Path: dup, + Path: protointernal.ClonePath(path), Span: makeSpan(start, end), }) } else { @@ -762,13 +787,11 @@ func (sci *sourceCodeInfo) newLocWithGivenComments(nodeInfo ast.NodeInfo, detach detached[i] = sci.combineComments(cmts) } - dup := make([]int32, len(path)) - copy(dup, path) sci.locs = append(sci.locs, &descriptorpb.SourceCodeInfo_Location{ LeadingDetachedComments: detached, LeadingComments: lead, TrailingComments: trail, - Path: dup, + Path: protointernal.ClonePath(path), Span: makeSpan(nodeInfo.Start(), nodeInfo.End()), }) } @@ -999,7 +1022,3 @@ func (sci *sourceCodeInfo) combineComments(comments comments) string { } return buf.String() } - -func dup(p []int32) []int32 { - return append(([]int32)(nil), p...) -} diff --git a/sourceinfo/source_code_info_test.go b/sourceinfo/source_code_info_test.go index 23013c66..76b87504 100644 --- a/sourceinfo/source_code_info_test.go +++ b/sourceinfo/source_code_info_test.go @@ -187,7 +187,8 @@ func TestSourceCodeInfoOptions(t *testing.T) { // set to true to re-generate golden output file const regenerateGoldenOutputFile = true - generateSourceInfoText := func(filename string, mode protocompile.SourceInfoMode) string { + generateSourceInfoText := func(t *testing.T, filename string, mode protocompile.SourceInfoMode) string { + t.Helper() compiler := protocompile.Compiler{ Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ ImportPaths: []string{"../internal/testdata"}, @@ -227,14 +228,14 @@ func TestSourceCodeInfoOptions(t *testing.T) { testCase := testCase t.Run(testCase.name, func(t *testing.T) { t.Parallel() - output := generateSourceInfoText(testCase.filename, testCase.mode) + output := generateSourceInfoText(t, testCase.filename, testCase.mode) baseName := strings.TrimSuffix(testCase.filename, ".proto") if regenerateGoldenOutputFile { err := os.WriteFile(fmt.Sprintf("testdata/%s.%s.txt", baseName, testCase.name), []byte(output), 0o644) require.NoError(t, err) // also create a file with standard comments, as a useful demonstration of the differences - output := generateSourceInfoText(testCase.filename, protocompile.SourceInfoStandard|protocompile.SourceInfoProtocCompatible) + output := generateSourceInfoText(t, testCase.filename, protocompile.SourceInfoStandard|protocompile.SourceInfoProtocCompatible) err = os.WriteFile(fmt.Sprintf("testdata/%s.standard.txt", baseName), []byte(output), 0o644) require.NoError(t, err) return diff --git a/sourceinfo/testdata/desc_test_complex.extra_option_locations.txt b/sourceinfo/testdata/desc_test_complex.extra_option_locations.txt index 10a08290..63d25878 100644 --- a/sourceinfo/testdata/desc_test_complex.extra_option_locations.txt +++ b/sourceinfo/testdata/desc_test_complex.extra_option_locations.txt @@ -795,20 +795,11 @@ desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[ Span: 97:13 -> 97:14 desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[1]: - Span: 97:14 -> 97:15 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[2]: Span: 97:16 -> 97:17 -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[3]: - Span: 97:17 -> 97:18 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[4]: +desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[2]: Span: 97:19 -> 97:20 -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > array[5]: - Span: 97:20 -> 97:20 - desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[0] > r[0]: Span: 98:5 -> 98:49 @@ -834,20 +825,11 @@ desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[ Span: 103:13 -> 103:14 desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[1]: - Span: 103:14 -> 103:15 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[2]: Span: 103:16 -> 103:17 -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[3]: - Span: 103:17 -> 103:18 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[4]: +desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[2]: Span: 103:19 -> 103:20 -desc_test_complex.proto > message_type[4] > options > (foo.bar.rept)[1] > array[5]: - Span: 103:20 -> 103:20 - desc_test_complex.proto > message_type[4] > options: Span: 107:3 -> 107:34 @@ -879,29 +861,17 @@ desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums Span: 114:7 -> 114:9 desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[1]: - Span: 114:9 -> 114:10 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[2]: Span: 115:7 -> 115:9 -desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[3]: - Span: 115:9 -> 115:10 - -desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[4]: +desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[2]: Span: 116:7 -> 116:9 -desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > enums[5]: - Span: 117:5 -> 117:5 - desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > bools[0]: Span: 118:5 -> 118:19 desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > bools[0]: Span: 118:13 -> 118:18 -desc_test_complex.proto > message_type[4] > options > (foo.bar.a) > test > bools[1]: - Span: 118:18 -> 118:18 - desc_test_complex.proto > message_type[4] > options: Span: 120:3 -> 120:74