Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unmarshal unknown extensions into XXX_unrecognized instead of into extenson map #386

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions proto/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ func (p *Buffer) Unmarshal(pb Message) error {
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
var state errorState
required, reqFields := prop.reqCount, uint64(0)
var regExt map[int32]*ExtensionDesc
regExtInit := false

var err error
for err == nil && o.index < len(o.buf) {
Expand Down Expand Up @@ -492,11 +494,22 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
// Maybe it's an extension?
if prop.extendable {
if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
extmap := e.extensionsWrite()
ext := extmap[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
extmap[int32(tag)] = ext
if !regExtInit {
msgType := reflect.Zero(reflect.PtrTo(st)).Interface().(Message)
regExt = RegisteredExtensions(msgType)
Copy link
Contributor Author

@jhump jhump Jun 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's more clear to just do regExt := RegisteredExtensions(msgType) here and skip memoizing from one iteration of the loop to the next. But I'm leary of doing any more work than necessary in something like proto unmarshaling since that is likely to be a hot path in many servers. (Though, admittedly, extensions are likely used infrequently enough that this is not necessary, especially since the work we are saving is just a single map lookup for string key. Should I omit this? Thoughts, suggestions?)

}
extdesc := regExt[int32(tag)]
if extdesc == nil {
// unknown extension
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
} else {
if err = o.skip(st, tag, wire); err == nil {
extmap := e.extensionsWrite()
ext := extmap[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
ext.desc = extdesc
extmap[int32(tag)] = ext
}
}
continue
}
Expand Down
8 changes: 5 additions & 3 deletions proto/extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,10 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
}

// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
// If the message was de-serialized from a stream that referenced unknown extensions (e.g. fields
// with a tag number in an extension range, but not registered), they will not be returned by
// this function. Instead, they can only be found by examining the message's XXX_unrecognized
// data.
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
epb, ok := extendable(pb)
if !ok {
Expand All @@ -512,7 +514,7 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
if desc == nil {
desc = registeredExtensions[extid]
if desc == nil {
desc = &ExtensionDesc{Field: extid}
continue
}
}

Expand Down
71 changes: 62 additions & 9 deletions proto/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ package proto_test
import (
"bytes"
"fmt"
"math"
"reflect"
"sort"
"testing"

"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/testdata"
"golang.org/x/sync/errgroup"
"io"
)

func TestGetExtensionsWithMissingExtensions(t *testing.T) {
Expand All @@ -64,27 +66,40 @@ func TestGetExtensionsWithMissingExtensions(t *testing.T) {
}
}

func TestExtensionDescsWithMissingExtensions(t *testing.T) {
func TestExtensionDescsWithUnrecognizedExtensions(t *testing.T) {
msg := &pb.MyMessage{Count: proto.Int32(0)}
extdesc1 := pb.E_Ext_More
if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
}

ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
extdesc1 := pb.E_Ext_More
if err := proto.SetExtension(msg, extdesc1, &pb.Ext{}); err != nil {
t.Fatalf("Could not set ext1: %s", err)
}
extdesc2 := &proto.ExtensionDesc{
extdesc2 := pb.E_Ext_Number
if err := proto.SetExtension(msg, extdesc2, proto.Int32(int32(101))); err != nil {
t.Fatalf("Could not set ext2: %s", err)
}

unknownExtdesc1 := &proto.ExtensionDesc{
ExtendedType: (*pb.MyMessage)(nil),
ExtensionType: (*bool)(nil),
Field: 123456789,
Name: "a.b",
Tag: "varint,123456789,opt",
}
ext2 := proto.Bool(false)
if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
t.Fatalf("Could not set ext2: %s", err)
if err := proto.SetExtension(msg, unknownExtdesc1, proto.Bool(true)); err != nil {
t.Fatalf("Could not set unknownExtdesc1: %s", err)
}
unknownExtdesc2 := &proto.ExtensionDesc{
ExtendedType: (*pb.MyMessage)(nil),
ExtensionType: (*float64)(nil),
Field: 123456790,
Name: "a.c",
Tag: "fixed64,123456790,opt",
}
if err := proto.SetExtension(msg, unknownExtdesc2, proto.Float64(12.34)); err != nil {
t.Fatalf("Could not set unknownExtdesc2: %s", err)
}

b, err := proto.Marshal(msg)
Expand All @@ -100,10 +115,48 @@ func TestExtensionDescsWithMissingExtensions(t *testing.T) {
t.Fatalf("proto.ExtensionDescs: got error %v", err)
}
sortExtDescs(descs)
wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
wantDescs := []*proto.ExtensionDesc{extdesc1, extdesc2}
if !reflect.DeepEqual(descs, wantDescs) {
t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
}

// make sure the unrecognized fields are serialized correctly
bb := proto.NewBuffer(msg.XXX_unrecognized)

// unrecognized extension #1
expectedTagAndWire := uint64((unknownExtdesc1.Field << 3) | proto.WireVarint)
if tagAndWire, err := bb.DecodeVarint(); err != nil {
t.Fatalf("Could not read unrecognized field tag and wire type: %v", err)
} else if tagAndWire != expectedTagAndWire {
t.Fatalf("Wrong tag and wire type: %d != %d", tagAndWire, expectedTagAndWire)
}
if val, err := bb.DecodeVarint(); err != nil {
t.Fatalf("Could not read unrecognized field value: %v", err)
} else if val != 1 /* varint value of bool "true" */ {
t.Fatalf("Wrong value for unrecognized extension 1: %d != 1", val)
}

// unrecognized extension #2
expectedTagAndWire = uint64((unknownExtdesc2.Field << 3) | proto.WireFixed64)
if tagAndWire, err := bb.DecodeVarint(); err != nil {
t.Fatalf("Could not read unrecognized field tag and wire type: %v", err)
} else if tagAndWire != expectedTagAndWire {
t.Fatalf("Wrong tag and wire type: %d != %d", tagAndWire, expectedTagAndWire)
}
if val, err := bb.DecodeFixed64(); err != nil {
t.Fatalf("Could not read unrecognized field value: %v", err)
} else if math.Float64frombits(val) != 12.34 {
t.Fatalf("Wrong value for unrecognized extension 1: %f != 12.34", math.Float64frombits(val))
}

// we should have reached EOF of the unknown fields
if _, err := bb.DecodeFixed32(); err != io.ErrUnexpectedEOF {
if err == nil {
t.Fatalf("Unexpected unrecognized data after expected extensions")
} else {
t.Fatalf("Unexpected error checking for buffer EOF: %v", err)
}
}
}

type ExtensionDescSlice []*proto.ExtensionDesc
Expand Down