-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: improve edge case handling for recursion limits (#22988)
Co-authored-by: Alex | Skip <alex@skip.money> (cherry picked from commit 93282e1) # Conflicts: # CHANGELOG.md # x/tx/decode/unknown.go
- Loading branch information
1 parent
8e710b7
commit 0d5055a
Showing
4 changed files
with
213 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
package decode | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"google.golang.org/protobuf/encoding/protowire" | ||
"google.golang.org/protobuf/proto" | ||
"google.golang.org/protobuf/reflect/protodesc" | ||
"google.golang.org/protobuf/reflect/protoreflect" | ||
"google.golang.org/protobuf/types/known/anypb" | ||
) | ||
|
||
const bit11NonCritical = 1 << 10 | ||
|
||
var ( | ||
anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor() | ||
anyFullName = anyDesc.FullName() | ||
) | ||
|
||
// RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown | ||
// non-critical fields are encountered. | ||
func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error { | ||
_, err := RejectUnknownFields(bz, msg, false, resolver) | ||
return err | ||
} | ||
|
||
// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an | ||
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the | ||
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be | ||
// used to treat a message with non-critical field different in different security contexts (such as transaction signing). | ||
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. | ||
// An AnyResolver must be provided for traversing inside google.protobuf.Any's. | ||
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) { | ||
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28 | ||
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000) | ||
} | ||
|
||
func doRejectUnknownFields( | ||
bz []byte, | ||
desc protoreflect.MessageDescriptor, | ||
allowUnknownNonCriticals bool, | ||
resolver protodesc.Resolver, | ||
recursionLimit int, | ||
) (hasUnknownNonCriticals bool, err error) { | ||
if len(bz) == 0 { | ||
return hasUnknownNonCriticals, nil | ||
} | ||
if recursionLimit <= 0 { | ||
return false, errors.New("recursion limit reached") | ||
} | ||
|
||
fields := desc.Fields() | ||
|
||
for len(bz) > 0 { | ||
tagNum, wireType, m := protowire.ConsumeTag(bz) | ||
if m < 0 { | ||
return hasUnknownNonCriticals, errors.New("invalid length") | ||
} | ||
|
||
fieldDesc := fields.ByNumber(tagNum) | ||
if fieldDesc == nil { | ||
isCriticalField := tagNum&bit11NonCritical == 0 | ||
|
||
if !isCriticalField { | ||
hasUnknownNonCriticals = true | ||
} | ||
|
||
if isCriticalField || !allowUnknownNonCriticals { | ||
// The tag is critical, so report it. | ||
return hasUnknownNonCriticals, ErrUnknownField.Wrapf( | ||
Check failure on line 72 in x/tx/decode/unknown.go GitHub Actions / dependency-review
|
||
"%s: {TagNum: %d, WireType:%q}", | ||
desc.FullName(), tagNum, WireTypeToString(wireType)) | ||
} | ||
} | ||
|
||
// Skip over the bytes that store fieldNumber and wireType bytes. | ||
bz = bz[m:] | ||
n := protowire.ConsumeFieldValue(tagNum, wireType, bz) | ||
if n < 0 { | ||
err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", | ||
tagNum, WireTypeToString(wireType), protowire.ParseError(n)) | ||
return hasUnknownNonCriticals, err | ||
} | ||
fieldBytes := bz[:n] | ||
bz = bz[n:] | ||
|
||
// An unknown but non-critical field | ||
if fieldDesc == nil { | ||
continue | ||
} | ||
|
||
fieldMessage := fieldDesc.Message() | ||
// not message or group kind | ||
if fieldMessage == nil { | ||
continue | ||
} | ||
// if a message descriptor is a placeholder resolve it using the injected resolver. | ||
// this can happen when a descriptor has been registered in the | ||
// "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto". | ||
// fixes: https://github.com/cosmos/cosmos-sdk/issues/22574 | ||
if fieldMessage.IsPlaceholder() { | ||
gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName()) | ||
if err != nil { | ||
return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err) | ||
} | ||
fieldMessage = gogoDesc.(protoreflect.MessageDescriptor) | ||
} | ||
|
||
// consume length prefix of nested message | ||
_, o := protowire.ConsumeVarint(fieldBytes) | ||
if o < 0 { | ||
err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w", | ||
fieldMessage, protowire.ParseError(o)) | ||
return hasUnknownNonCriticals, err | ||
} else if o > len(fieldBytes) { | ||
err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage) | ||
return hasUnknownNonCriticals, err | ||
} | ||
|
||
fieldBytes = fieldBytes[o:] | ||
|
||
var err error | ||
|
||
if fieldMessage.FullName() == anyFullName { | ||
// Firstly typecheck types.Any to ensure nothing snuck in. | ||
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1) | ||
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild | ||
if err != nil { | ||
return hasUnknownNonCriticals, err | ||
} | ||
var a anypb.Any | ||
if err = proto.Unmarshal(fieldBytes, &a); err != nil { | ||
return hasUnknownNonCriticals, err | ||
} | ||
|
||
msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/")) | ||
msgDesc, err := resolver.FindDescriptorByName(msgName) | ||
if err != nil { | ||
return hasUnknownNonCriticals, err | ||
} | ||
|
||
fieldMessage = msgDesc.(protoreflect.MessageDescriptor) | ||
fieldBytes = a.Value | ||
} | ||
|
||
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1) | ||
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild | ||
if err != nil { | ||
return hasUnknownNonCriticals, err | ||
} | ||
} | ||
|
||
return hasUnknownNonCriticals, nil | ||
} | ||
|
||
// errUnknownField represents an error indicating that we encountered | ||
// a field that isn't available in the target proto.Message. | ||
type errUnknownField struct { | ||
Desc protoreflect.MessageDescriptor | ||
TagNum protowire.Number | ||
WireType protowire.Type | ||
} | ||
|
||
// String implements fmt.Stringer. | ||
func (twt *errUnknownField) String() string { | ||
return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", | ||
twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType)) | ||
} | ||
|
||
// Error implements the error interface. | ||
func (twt *errUnknownField) Error() string { | ||
return twt.String() | ||
} | ||
|
||
var _ error = (*errUnknownField)(nil) | ||
|
||
// WireTypeToString returns a string representation of the given protowire.Type. | ||
func WireTypeToString(wt protowire.Type) string { | ||
switch wt { | ||
case 0: | ||
return "varint" | ||
case 1: | ||
return "fixed64" | ||
case 2: | ||
return "bytes" | ||
case 3: | ||
return "start_group" | ||
case 4: | ||
return "end_group" | ||
case 5: | ||
return "fixed32" | ||
default: | ||
return fmt.Sprintf("unknown type: %d", wt) | ||
} | ||
} |