Skip to content

Commit

Permalink
Slight improvement to walk package (bufbuild#287)
Browse files Browse the repository at this point in the history
I added benchmarks, in the hopes of finding a more efficient way to
traverse descriptor protos, from `walk.DescriptorProtos`. The main cost
is the way the fully-qualified names are computed/allocated as it
traverses the descriptor hierarchy.

While I did not come up with any meaningful improvements there, I was
able to improve the other walk function (`walk.Descriptors`), by making
fewer interface method calls, memoizing the results of the various
accessors. This improves throughput, consistently taking about 15%
less time per operation.
  • Loading branch information
jhump authored Apr 22, 2024
1 parent 292379e commit 63736ac
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 54 deletions.
109 changes: 55 additions & 54 deletions walk/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,36 +65,18 @@ func Descriptors(file protoreflect.FileDescriptor, fn func(protoreflect.Descript
// The exit function is called using a post-order traversal, where the function
// is called for a descriptor only after it is called for any descendants.
func DescriptorsEnterAndExit(file protoreflect.FileDescriptor, enter, exit func(protoreflect.Descriptor) error) error {
for i := 0; i < file.Messages().Len(); i++ {
msg := file.Messages().Get(i)
if err := messageDescriptor(msg, enter, exit); err != nil {
return err
}
}
for i := 0; i < file.Enums().Len(); i++ {
en := file.Enums().Get(i)
if err := enumDescriptor(en, enter, exit); err != nil {
return err
}
}
for i := 0; i < file.Extensions().Len(); i++ {
ext := file.Extensions().Get(i)
if err := enter(ext); err != nil {
return err
}
if exit != nil {
if err := exit(ext); err != nil {
return err
}
}
if err := walkContainer(file, enter, exit); err != nil {
return err
}
for i := 0; i < file.Services().Len(); i++ {
svc := file.Services().Get(i)
services := file.Services()
for i, length := 0, services.Len(); i < length; i++ {
svc := services.Get(i)
if err := enter(svc); err != nil {
return err
}
for i := 0; i < svc.Methods().Len(); i++ {
mtd := svc.Methods().Get(i)
methods := svc.Methods()
for i, length := 0, methods.Len(); i < length; i++ {
mtd := methods.Get(i)
if err := enter(mtd); err != nil {
return err
}
Expand All @@ -113,12 +95,49 @@ func DescriptorsEnterAndExit(file protoreflect.FileDescriptor, enter, exit func(
return nil
}

type container interface {
Messages() protoreflect.MessageDescriptors
Enums() protoreflect.EnumDescriptors
Extensions() protoreflect.ExtensionDescriptors
}

func walkContainer(container container, enter, exit func(protoreflect.Descriptor) error) error {
messages := container.Messages()
for i, length := 0, messages.Len(); i < length; i++ {
msg := messages.Get(i)
if err := messageDescriptor(msg, enter, exit); err != nil {
return err
}
}
enums := container.Enums()
for i, length := 0, enums.Len(); i < length; i++ {
en := enums.Get(i)
if err := enumDescriptor(en, enter, exit); err != nil {
return err
}
}
exts := container.Extensions()
for i, length := 0, exts.Len(); i < length; i++ {
ext := exts.Get(i)
if err := enter(ext); err != nil {
return err
}
if exit != nil {
if err := exit(ext); err != nil {
return err
}
}
}
return nil
}

func messageDescriptor(msg protoreflect.MessageDescriptor, enter, exit func(protoreflect.Descriptor) error) error {
if err := enter(msg); err != nil {
return err
}
for i := 0; i < msg.Fields().Len(); i++ {
fld := msg.Fields().Get(i)
fields := msg.Fields()
for i, length := 0, fields.Len(); i < length; i++ {
fld := fields.Get(i)
if err := enter(fld); err != nil {
return err
}
Expand All @@ -128,8 +147,9 @@ func messageDescriptor(msg protoreflect.MessageDescriptor, enter, exit func(prot
}
}
}
for i := 0; i < msg.Oneofs().Len(); i++ {
oo := msg.Oneofs().Get(i)
oneofs := msg.Oneofs()
for i, length := 0, oneofs.Len(); i < length; i++ {
oo := oneofs.Get(i)
if err := enter(oo); err != nil {
return err
}
Expand All @@ -139,28 +159,8 @@ func messageDescriptor(msg protoreflect.MessageDescriptor, enter, exit func(prot
}
}
}
for i := 0; i < msg.Messages().Len(); i++ {
nested := msg.Messages().Get(i)
if err := messageDescriptor(nested, enter, exit); err != nil {
return err
}
}
for i := 0; i < msg.Enums().Len(); i++ {
en := msg.Enums().Get(i)
if err := enumDescriptor(en, enter, exit); err != nil {
return err
}
}
for i := 0; i < msg.Extensions().Len(); i++ {
ext := msg.Extensions().Get(i)
if err := enter(ext); err != nil {
return err
}
if exit != nil {
if err := exit(ext); err != nil {
return err
}
}
if err := walkContainer(msg, enter, exit); err != nil {
return err
}
if exit != nil {
if err := exit(msg); err != nil {
Expand All @@ -174,8 +174,9 @@ func enumDescriptor(en protoreflect.EnumDescriptor, enter, exit func(protoreflec
if err := enter(en); err != nil {
return err
}
for i := 0; i < en.Values().Len(); i++ {
enVal := en.Values().Get(i)
vals := en.Values()
for i, length := 0, vals.Len(); i < length; i++ {
enVal := vals.Get(i)
if err := enter(enVal); err != nil {
return err
}
Expand Down
45 changes: 45 additions & 0 deletions walk/walk_benchmark_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2020-2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package walk

import (
"testing"

"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)

func BenchmarkDescriptors(b *testing.B) {
file := (*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile()
for i := 0; i < b.N; i++ {
err := Descriptors(file, func(_ protoreflect.Descriptor) error {
return nil
})
require.NoError(b, err)
}
}

func BenchmarkDescriptorProtos(b *testing.B) {
file := protodesc.ToFileDescriptorProto((*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile())
for i := 0; i < b.N; i++ {
err := DescriptorProtos(file, func(_ protoreflect.FullName, _ proto.Message) error {
return nil
})
require.NoError(b, err)
}
}
78 changes: 78 additions & 0 deletions walk/walk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2020-2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package walk

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
)

func TestDescriptorProtosEnterAndExit(t *testing.T) {
t.Parallel()
file := protodesc.ToFileDescriptorProto((*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile())
nameStack := []string{file.GetPackage()}
err := DescriptorProtosEnterAndExit(
file,
func(fullName protoreflect.FullName, message proto.Message) error {
switch d := message.(type) {
case *descriptorpb.DescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.FieldDescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.OneofDescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.EnumDescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.EnumValueDescriptorProto:
// we look at the NEXT to last item on stack because enums are
// defined not in the enum but in its enclosing scope
expected := joinNames(nameStack[len(nameStack)-2], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.ServiceDescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
case *descriptorpb.MethodDescriptorProto:
expected := joinNames(nameStack[len(nameStack)-1], d.GetName())
assert.Equal(t, expected, string(fullName))
default:
t.Fatalf("unknown descriptor type: %T", d)
}
nameStack = append(nameStack, string(fullName))
return nil
},
func(name protoreflect.FullName, message proto.Message) error {
nameStack = nameStack[:len(nameStack)-1]
return nil
},
)
require.NoError(t, err)
}

func joinNames(prefix, name string) string {
if len(prefix) == 0 {
return name
}
return prefix + "." + name
}

0 comments on commit 63736ac

Please sign in to comment.