Skip to content

Commit

Permalink
refactor: improve edge case handling for recursion limits (#22988)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex | Skip <alex@skip.money>
(cherry picked from commit 93282e1)

# Conflicts:
#	CHANGELOG.md
#	x/tx/decode/unknown.go
  • Loading branch information
haiyizxx authored and mergify[bot] committed Jan 6, 2025
1 parent 8e710b7 commit 0d5055a
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 3 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ Ref: https://keepachangelog.com/en/1.0.0/

Every module contains its own CHANGELOG.md. Please refer to the module you are interested in.

<<<<<<< HEAD
=======
### Features

* (baseapp) [#20291](https://github.com/cosmos/cosmos-sdk/pull/20291) Simulate nested messages.
* (client/keys) [#21829](https://github.com/cosmos/cosmos-sdk/pull/21829) Add support for importing hex key using standard input.
* (x/auth/ante) [#23128](https://github.com/cosmos/cosmos-sdk/pull/23128) Allow custom verifyIsOnCurve when validate tx for public key like ethsecp256k1.

### Improvements

* (codec) [#22988](https://github.com/cosmos/cosmos-sdk/pull/22988) Improve edge case handling for recursion limits.

>>>>>>> 93282e101 (refactor: improve edge case handling for recursion limits (#22988))
### Bug Fixes

* (x/auth/tx) [#23148](https://github.com/cosmos/cosmos-sdk/pull/23148) Avoid panic from intoAnyV2 when v1.PublicKey is optional.
Expand Down
4 changes: 2 additions & 2 deletions codec/types/interface_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker {
// UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer.
// It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required.
func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error {
if r.maxDepth == 0 {
if r.maxDepth <= 0 {
return errors.New("max depth exceeded")
}
if r.maxCalls.count == 0 {
if r.maxCalls.count <= 0 {
return errors.New("call limit exceeded")
}
// here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding
Expand Down
2 changes: 1 addition & 1 deletion codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func doRejectUnknownFields(
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit == 0 {
if recursionLimit <= 0 {
return false, errors.New("recursion limit reached")
}

Expand Down
197 changes: 197 additions & 0 deletions x/tx/decode/unknown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package decode

import (
"errors"
"fmt"
"strings"

"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
)

const bit11NonCritical = 1 << 10

var (
anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor()
anyFullName = anyDesc.FullName()
)

// RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown
// non-critical fields are encountered.
func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error {
_, err := RejectUnknownFields(bz, msg, false, resolver)
return err
}

// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
// used to treat a message with non-critical field different in different security contexts (such as transaction signing).
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) {
// recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28
return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000)
}

func doRejectUnknownFields(
bz []byte,
desc protoreflect.MessageDescriptor,
allowUnknownNonCriticals bool,
resolver protodesc.Resolver,
recursionLimit int,
) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
}
if recursionLimit <= 0 {
return false, errors.New("recursion limit reached")
}

fields := desc.Fields()

for len(bz) > 0 {
tagNum, wireType, m := protowire.ConsumeTag(bz)
if m < 0 {
return hasUnknownNonCriticals, errors.New("invalid length")
}

fieldDesc := fields.ByNumber(tagNum)
if fieldDesc == nil {
isCriticalField := tagNum&bit11NonCritical == 0

if !isCriticalField {
hasUnknownNonCriticals = true
}

if isCriticalField || !allowUnknownNonCriticals {
// The tag is critical, so report it.
return hasUnknownNonCriticals, ErrUnknownField.Wrapf(

Check failure on line 72 in x/tx/decode/unknown.go

View workflow job for this annotation

GitHub Actions / dependency-review

undefined: ErrUnknownField

Check failure on line 72 in x/tx/decode/unknown.go

View workflow job for this annotation

GitHub Actions / golangci-lint

undefined: ErrUnknownField (typecheck)
"%s: {TagNum: %d, WireType:%q}",
desc.FullName(), tagNum, WireTypeToString(wireType))
}
}

// Skip over the bytes that store fieldNumber and wireType bytes.
bz = bz[m:]
n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
if n < 0 {
err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
tagNum, WireTypeToString(wireType), protowire.ParseError(n))
return hasUnknownNonCriticals, err
}
fieldBytes := bz[:n]
bz = bz[n:]

// An unknown but non-critical field
if fieldDesc == nil {
continue
}

fieldMessage := fieldDesc.Message()
// not message or group kind
if fieldMessage == nil {
continue
}
// if a message descriptor is a placeholder resolve it using the injected resolver.
// this can happen when a descriptor has been registered in the
// "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto".
// fixes: https://github.com/cosmos/cosmos-sdk/issues/22574
if fieldMessage.IsPlaceholder() {
gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName())
if err != nil {
return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err)
}
fieldMessage = gogoDesc.(protoreflect.MessageDescriptor)
}

// consume length prefix of nested message
_, o := protowire.ConsumeVarint(fieldBytes)
if o < 0 {
err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w",
fieldMessage, protowire.ParseError(o))
return hasUnknownNonCriticals, err
} else if o > len(fieldBytes) {
err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage)
return hasUnknownNonCriticals, err
}

fieldBytes = fieldBytes[o:]

var err error

if fieldMessage.FullName() == anyFullName {
// Firstly typecheck types.Any to ensure nothing snuck in.
hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
var a anypb.Any
if err = proto.Unmarshal(fieldBytes, &a); err != nil {
return hasUnknownNonCriticals, err
}

msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/"))
msgDesc, err := resolver.FindDescriptorByName(msgName)
if err != nil {
return hasUnknownNonCriticals, err
}

fieldMessage = msgDesc.(protoreflect.MessageDescriptor)
fieldBytes = a.Value
}

hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
}

return hasUnknownNonCriticals, nil
}

// errUnknownField represents an error indicating that we encountered
// a field that isn't available in the target proto.Message.
type errUnknownField struct {
Desc protoreflect.MessageDescriptor
TagNum protowire.Number
WireType protowire.Type
}

// String implements fmt.Stringer.
func (twt *errUnknownField) String() string {
return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType))
}

// Error implements the error interface.
func (twt *errUnknownField) Error() string {
return twt.String()
}

var _ error = (*errUnknownField)(nil)

// WireTypeToString returns a string representation of the given protowire.Type.
func WireTypeToString(wt protowire.Type) string {
switch wt {
case 0:
return "varint"
case 1:
return "fixed64"
case 2:
return "bytes"
case 3:
return "start_group"
case 4:
return "end_group"
case 5:
return "fixed32"
default:
return fmt.Sprintf("unknown type: %d", wt)
}
}

0 comments on commit 0d5055a

Please sign in to comment.