Skip to content

Commit

Permalink
refactor protoparse, fixes multiple issues (#316)
Browse files Browse the repository at this point in the history
- defers a lot of validation to after the AST is fully constructed
- in particular, defers tag validation to after message options are parsed
  in order to know if it has message_set_wire_format option, which impacts
  allowed tag range
- fixes issues with very large constant numbers (that overflow uint64 or
  underflow int64)
- refactors grammar around option names and fixes issue where extensions
  on message options aren't parsed correctly
- fixes issues related to groups in oneofs
- fixes issues with reserved name validation
- adds some util methods to shrink the boiler-plate for error creation
  and error handling
- breaks up monolithic parser.go into three files: parser.go, validation.go
  and descriptor_protos.go
- adds numerous new validation test cases to catch various issues that
  were fixed
- includes updates to protoprint and builder packages related to
  messageset wire format
  • Loading branch information
jhump authored Apr 27, 2020
1 parent b97137b commit b6666e6
Show file tree
Hide file tree
Showing 34 changed files with 3,720 additions and 3,609 deletions.
7 changes: 6 additions & 1 deletion desc/builder/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ func (flb *FieldBuilder) GetExtendeeTypeName() string {
}
}

func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo) (*dpb.FieldDescriptorProto, error) {
func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo, isMessageSet bool) (*dpb.FieldDescriptorProto, error) {
addCommentsTo(sourceInfo, path, &flb.comments)

var lbl *dpb.FieldDescriptorProto_Label
Expand All @@ -508,6 +508,11 @@ func (flb *FieldBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInfo
def = proto.String(flb.Default)
}

maxTag := internal.GetMaxTag(isMessageSet)
if flb.number > maxTag {
return nil, fmt.Errorf("tag for field %s cannot be above max %d", GetFullyQualifiedName(flb), maxTag)
}

return &dpb.FieldDescriptorProto{
Name: proto.String(flb.name),
Number: proto.Int32(flb.number),
Expand Down
9 changes: 8 additions & 1 deletion desc/builder/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ func (fb *FileBuilder) buildProto(deps []*desc.FileDescriptor) (*dpb.FileDescrip
extensions := make([]*dpb.FieldDescriptorProto, 0, len(fb.extensions))
for _, exb := range fb.extensions {
path := append(path, internal.File_extensionsTag, int32(len(extensions)))
if exd, err := exb.buildProto(path, &sourceInfo); err != nil {
if exd, err := exb.buildProto(path, &sourceInfo, isExtendeeMessageSet(exb)); err != nil {
return nil, err
} else {
extensions = append(extensions, exd)
Expand Down Expand Up @@ -699,6 +699,13 @@ func (fb *FileBuilder) buildProto(deps []*desc.FileDescriptor) (*dpb.FileDescrip
}, nil
}

func isExtendeeMessageSet(flb *FieldBuilder) bool {
if flb.localExtendee != nil {
return flb.localExtendee.Options.GetMessageSetWireFormat()
}
return flb.foreignExtendee.GetMessageOptions().GetMessageSetWireFormat()
}

// Build constructs a file descriptor based on the contents of this file
// builder. If there are any problems constructing the descriptor, including
// resolving symbols referenced by the builder or failing to meet certain
Expand Down
6 changes: 3 additions & 3 deletions desc/builder/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
for _, b := range mb.fieldsAndOneOfs {
if flb, ok := b.(*FieldBuilder); ok {
fldpath := append(path, internal.Message_fieldsTag, int32(len(fields)))
fld, err := flb.buildProto(fldpath, sourceInfo)
fld, err := flb.buildProto(fldpath, sourceInfo, mb.Options.GetMessageSetWireFormat())
if err != nil {
return nil, err
}
Expand All @@ -729,7 +729,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
oneOfs = append(oneOfs, ood)
for _, flb := range oob.choices {
path := append(path, internal.Message_fieldsTag, int32(len(fields)))
fld, err := flb.buildProto(path, sourceInfo)
fld, err := flb.buildProto(path, sourceInfo, mb.Options.GetMessageSetWireFormat())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (mb *MessageBuilder) buildProto(path []int32, sourceInfo *dpb.SourceCodeInf
nestedExtensions := make([]*dpb.FieldDescriptorProto, 0, len(mb.nestedExtensions))
for _, exb := range mb.nestedExtensions {
path := append(path, internal.Message_extensionsTag, int32(len(nestedExtensions)))
if exd, err := exb.buildProto(path, sourceInfo); err != nil {
if exd, err := exb.buildProto(path, sourceInfo, isExtendeeMessageSet(exb)); err != nil {
return nil, err
} else {
nestedExtensions = append(nestedExtensions, exd)
Expand Down
22 changes: 20 additions & 2 deletions desc/internal/util.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package internal

import (
"math"
"unicode"
"unicode/utf8"
)

const (
// MaxTag is the maximum allowed tag number for a field.
MaxTag = 536870911 // 2^29 - 1
// MaxNormalTag is the maximum allowed tag number for a field in a normal message.
MaxNormalTag = 536870911 // 2^29 - 1

// MaxMessageSetTag is the maximum allowed tag number of a field in a message that
// uses the message set wire format.
MaxMessageSetTag = math.MaxInt32 - 1

// MaxTag is the maximum allowed tag number. (It is the same as MaxMessageSetTag
// since that is the absolute highest allowed.)
MaxTag = MaxMessageSetTag

// SpecialReservedStart is the first tag in a range that is reserved and not
// allowed for use in message definitions.
Expand Down Expand Up @@ -268,3 +277,12 @@ func CreatePrefixList(pkg string) []string {

return prefixes
}

// GetMaxTag returns the max tag number allowed, based on whether a message uses
// message set wire format or not.
func GetMaxTag(isMessageSet bool) int32 {
if isMessageSet {
return MaxMessageSetTag
}
return MaxNormalTag
}
69 changes: 65 additions & 4 deletions desc/protoparse/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,33 +373,65 @@ func (n *compoundStringNode) value() interface{} {
return n.val
}

type intLiteral interface {
asInt32(min, max int32) (int32, bool)
value() interface{}
}

type intLiteralNode struct {
basicNode
val uint64
}

var _ intLiteral = (*intLiteralNode)(nil)

func (n *intLiteralNode) value() interface{} {
return n.val
}

func (n *intLiteralNode) asInt32(min, max int32) (int32, bool) {
if (min >= 0 && n.val < uint64(min)) || n.val > uint64(max) {
return 0, false
}
return int32(n.val), true
}

type compoundUintNode struct {
basicCompositeNode
val uint64
}

var _ intLiteral = (*compoundUintNode)(nil)

func (n *compoundUintNode) value() interface{} {
return n.val
}

func (n *compoundUintNode) asInt32(min, max int32) (int32, bool) {
if (min >= 0 && n.val < uint64(min)) || n.val > uint64(max) {
return 0, false
}
return int32(n.val), true
}

type compoundIntNode struct {
basicCompositeNode
val int64
}

var _ intLiteral = (*compoundIntNode)(nil)

func (n *compoundIntNode) value() interface{} {
return n.val
}

func (n *compoundIntNode) asInt32(min, max int32) (int32, bool) {
if n.val < int64(min) || n.val > int64(max) {
return 0, false
}
return int32(n.val), true
}

type floatLiteralNode struct {
basicNode
val float64
Expand Down Expand Up @@ -728,16 +760,45 @@ type extensionRangeNode struct {

type rangeNode struct {
basicCompositeNode
stNode, enNode node
st, en int32
startNode, endNode node
endMax bool
}

func (n *rangeNode) rangeStart() node {
return n.stNode
return n.startNode
}

func (n *rangeNode) rangeEnd() node {
return n.enNode
if n.endNode == nil {
return n.startNode
}
return n.endNode
}

func (n *rangeNode) startValue() interface{} {
return n.startNode.(intLiteral).value()
}

func (n *rangeNode) startValueAsInt32(min, max int32) (int32, bool) {
return n.startNode.(intLiteral).asInt32(min, max)
}

func (n *rangeNode) endValue() interface{} {
l, ok := n.endNode.(intLiteral)
if !ok {
return nil
}
return l.value()
}

func (n *rangeNode) endValueAsInt32(min, max int32) (int32, bool) {
if n.endMax {
return max, true
}
if n.endNode == nil {
return n.startValueAsInt32(min, max)
}
return n.endNode.(intLiteral).asInt32(min, max)
}

type reservedNode struct {
Expand Down
Loading

0 comments on commit b6666e6

Please sign in to comment.