From 0d5055a52a611df4a6291517457d0d0144cdae20 Mon Sep 17 00:00:00 2001 From: haiyizxx Date: Mon, 6 Jan 2025 03:26:46 -0500 Subject: [PATCH 1/2] refactor: improve edge case handling for recursion limits (#22988) Co-authored-by: Alex | Skip (cherry picked from commit 93282e101d3804c59716bc3b30f7d43221ee8c43) # Conflicts: # CHANGELOG.md # x/tx/decode/unknown.go --- CHANGELOG.md | 13 ++ codec/types/interface_registry.go | 4 +- codec/unknownproto/unknown_fields.go | 2 +- x/tx/decode/unknown.go | 197 +++++++++++++++++++++++++++ 4 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 x/tx/decode/unknown.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 42ca411e2f97..cef82d21ef78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,19 @@ Ref: https://keepachangelog.com/en/1.0.0/ Every module contains its own CHANGELOG.md. Please refer to the module you are interested in. +<<<<<<< HEAD +======= +### Features + +* (baseapp) [#20291](https://github.com/cosmos/cosmos-sdk/pull/20291) Simulate nested messages. +* (client/keys) [#21829](https://github.com/cosmos/cosmos-sdk/pull/21829) Add support for importing hex key using standard input. +* (x/auth/ante) [#23128](https://github.com/cosmos/cosmos-sdk/pull/23128) Allow custom verifyIsOnCurve when validate tx for public key like ethsecp256k1. + +### Improvements + +* (codec) [#22988](https://github.com/cosmos/cosmos-sdk/pull/22988) Improve edge case handling for recursion limits. + +>>>>>>> 93282e101 (refactor: improve edge case handling for recursion limits (#22988)) ### Bug Fixes * (x/auth/tx) [#23148](https://github.com/cosmos/cosmos-sdk/pull/23148) Avoid panic from intoAnyV2 when v1.PublicKey is optional. diff --git a/codec/types/interface_registry.go b/codec/types/interface_registry.go index 34d59bd33a46..68ed8c885d9f 100644 --- a/codec/types/interface_registry.go +++ b/codec/types/interface_registry.go @@ -274,10 +274,10 @@ func (r statefulUnpacker) cloneForRecursion() *statefulUnpacker { // UnpackAny deserializes a protobuf Any message into the provided interface, ensuring the interface is a pointer. // It applies stateful constraints such as max depth and call limits, and unpacks interfaces if required. func (r *statefulUnpacker) UnpackAny(any *Any, iface interface{}) error { - if r.maxDepth == 0 { + if r.maxDepth <= 0 { return errors.New("max depth exceeded") } - if r.maxCalls.count == 0 { + if r.maxCalls.count <= 0 { return errors.New("call limit exceeded") } // here we gracefully handle the case in which `any` itself is `nil`, which may occur in message decoding diff --git a/codec/unknownproto/unknown_fields.go b/codec/unknownproto/unknown_fields.go index 17b8f7e424ee..a60f2f9caac8 100644 --- a/codec/unknownproto/unknown_fields.go +++ b/codec/unknownproto/unknown_fields.go @@ -54,7 +54,7 @@ func doRejectUnknownFields( if len(bz) == 0 { return hasUnknownNonCriticals, nil } - if recursionLimit == 0 { + if recursionLimit <= 0 { return false, errors.New("recursion limit reached") } diff --git a/x/tx/decode/unknown.go b/x/tx/decode/unknown.go new file mode 100644 index 000000000000..ce608b32a4ba --- /dev/null +++ b/x/tx/decode/unknown.go @@ -0,0 +1,197 @@ +package decode + +import ( + "errors" + "fmt" + "strings" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" +) + +const bit11NonCritical = 1 << 10 + +var ( + anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor() + anyFullName = anyDesc.FullName() +) + +// RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown +// non-critical fields are encountered. +func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error { + _, err := RejectUnknownFields(bz, msg, false, resolver) + return err +} + +// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an +// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the +// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be +// used to treat a message with non-critical field different in different security contexts (such as transaction signing). +// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. +// An AnyResolver must be provided for traversing inside google.protobuf.Any's. +func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) { + // recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28 + return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000) +} + +func doRejectUnknownFields( + bz []byte, + desc protoreflect.MessageDescriptor, + allowUnknownNonCriticals bool, + resolver protodesc.Resolver, + recursionLimit int, +) (hasUnknownNonCriticals bool, err error) { + if len(bz) == 0 { + return hasUnknownNonCriticals, nil + } + if recursionLimit <= 0 { + return false, errors.New("recursion limit reached") + } + + fields := desc.Fields() + + for len(bz) > 0 { + tagNum, wireType, m := protowire.ConsumeTag(bz) + if m < 0 { + return hasUnknownNonCriticals, errors.New("invalid length") + } + + fieldDesc := fields.ByNumber(tagNum) + if fieldDesc == nil { + isCriticalField := tagNum&bit11NonCritical == 0 + + if !isCriticalField { + hasUnknownNonCriticals = true + } + + if isCriticalField || !allowUnknownNonCriticals { + // The tag is critical, so report it. + return hasUnknownNonCriticals, ErrUnknownField.Wrapf( + "%s: {TagNum: %d, WireType:%q}", + desc.FullName(), tagNum, WireTypeToString(wireType)) + } + } + + // Skip over the bytes that store fieldNumber and wireType bytes. + bz = bz[m:] + n := protowire.ConsumeFieldValue(tagNum, wireType, bz) + if n < 0 { + err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", + tagNum, WireTypeToString(wireType), protowire.ParseError(n)) + return hasUnknownNonCriticals, err + } + fieldBytes := bz[:n] + bz = bz[n:] + + // An unknown but non-critical field + if fieldDesc == nil { + continue + } + + fieldMessage := fieldDesc.Message() + // not message or group kind + if fieldMessage == nil { + continue + } + // if a message descriptor is a placeholder resolve it using the injected resolver. + // this can happen when a descriptor has been registered in the + // "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto". + // fixes: https://github.com/cosmos/cosmos-sdk/issues/22574 + if fieldMessage.IsPlaceholder() { + gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName()) + if err != nil { + return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err) + } + fieldMessage = gogoDesc.(protoreflect.MessageDescriptor) + } + + // consume length prefix of nested message + _, o := protowire.ConsumeVarint(fieldBytes) + if o < 0 { + err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w", + fieldMessage, protowire.ParseError(o)) + return hasUnknownNonCriticals, err + } else if o > len(fieldBytes) { + err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage) + return hasUnknownNonCriticals, err + } + + fieldBytes = fieldBytes[o:] + + var err error + + if fieldMessage.FullName() == anyFullName { + // Firstly typecheck types.Any to ensure nothing snuck in. + hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1) + hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild + if err != nil { + return hasUnknownNonCriticals, err + } + var a anypb.Any + if err = proto.Unmarshal(fieldBytes, &a); err != nil { + return hasUnknownNonCriticals, err + } + + msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/")) + msgDesc, err := resolver.FindDescriptorByName(msgName) + if err != nil { + return hasUnknownNonCriticals, err + } + + fieldMessage = msgDesc.(protoreflect.MessageDescriptor) + fieldBytes = a.Value + } + + hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1) + hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild + if err != nil { + return hasUnknownNonCriticals, err + } + } + + return hasUnknownNonCriticals, nil +} + +// errUnknownField represents an error indicating that we encountered +// a field that isn't available in the target proto.Message. +type errUnknownField struct { + Desc protoreflect.MessageDescriptor + TagNum protowire.Number + WireType protowire.Type +} + +// String implements fmt.Stringer. +func (twt *errUnknownField) String() string { + return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", + twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType)) +} + +// Error implements the error interface. +func (twt *errUnknownField) Error() string { + return twt.String() +} + +var _ error = (*errUnknownField)(nil) + +// WireTypeToString returns a string representation of the given protowire.Type. +func WireTypeToString(wt protowire.Type) string { + switch wt { + case 0: + return "varint" + case 1: + return "fixed64" + case 2: + return "bytes" + case 3: + return "start_group" + case 4: + return "end_group" + case 5: + return "fixed32" + default: + return fmt.Sprintf("unknown type: %d", wt) + } +} From 3afb5643edd53c77520bbe5d2bf7bfba4a34b582 Mon Sep 17 00:00:00 2001 From: Julien Robert Date: Mon, 6 Jan 2025 09:37:57 +0100 Subject: [PATCH 2/2] imp --- CHANGELOG.md | 9 -- x/tx/decode/unknown.go | 197 ----------------------------------------- 2 files changed, 206 deletions(-) delete mode 100644 x/tx/decode/unknown.go diff --git a/CHANGELOG.md b/CHANGELOG.md index cef82d21ef78..d5a73d9aa3e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,19 +40,10 @@ Ref: https://keepachangelog.com/en/1.0.0/ Every module contains its own CHANGELOG.md. Please refer to the module you are interested in. -<<<<<<< HEAD -======= -### Features - -* (baseapp) [#20291](https://github.com/cosmos/cosmos-sdk/pull/20291) Simulate nested messages. -* (client/keys) [#21829](https://github.com/cosmos/cosmos-sdk/pull/21829) Add support for importing hex key using standard input. -* (x/auth/ante) [#23128](https://github.com/cosmos/cosmos-sdk/pull/23128) Allow custom verifyIsOnCurve when validate tx for public key like ethsecp256k1. - ### Improvements * (codec) [#22988](https://github.com/cosmos/cosmos-sdk/pull/22988) Improve edge case handling for recursion limits. ->>>>>>> 93282e101 (refactor: improve edge case handling for recursion limits (#22988)) ### Bug Fixes * (x/auth/tx) [#23148](https://github.com/cosmos/cosmos-sdk/pull/23148) Avoid panic from intoAnyV2 when v1.PublicKey is optional. diff --git a/x/tx/decode/unknown.go b/x/tx/decode/unknown.go deleted file mode 100644 index ce608b32a4ba..000000000000 --- a/x/tx/decode/unknown.go +++ /dev/null @@ -1,197 +0,0 @@ -package decode - -import ( - "errors" - "fmt" - "strings" - - "google.golang.org/protobuf/encoding/protowire" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" - "google.golang.org/protobuf/types/known/anypb" -) - -const bit11NonCritical = 1 << 10 - -var ( - anyDesc = (&anypb.Any{}).ProtoReflect().Descriptor() - anyFullName = anyDesc.FullName() -) - -// RejectUnknownFieldsStrict operates by the same rules as RejectUnknownFields, but returns an error if any unknown -// non-critical fields are encountered. -func RejectUnknownFieldsStrict(bz []byte, msg protoreflect.MessageDescriptor, resolver protodesc.Resolver) error { - _, err := RejectUnknownFields(bz, msg, false, resolver) - return err -} - -// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an -// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the -// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be -// used to treat a message with non-critical field different in different security contexts (such as transaction signing). -// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. -// An AnyResolver must be provided for traversing inside google.protobuf.Any's. -func RejectUnknownFields(bz []byte, desc protoreflect.MessageDescriptor, allowUnknownNonCriticals bool, resolver protodesc.Resolver) (hasUnknownNonCriticals bool, err error) { - // recursion limit with same default as https://github.com/protocolbuffers/protobuf-go/blob/v1.35.2/encoding/protowire/wire.go#L28 - return doRejectUnknownFields(bz, desc, allowUnknownNonCriticals, resolver, 10_000) -} - -func doRejectUnknownFields( - bz []byte, - desc protoreflect.MessageDescriptor, - allowUnknownNonCriticals bool, - resolver protodesc.Resolver, - recursionLimit int, -) (hasUnknownNonCriticals bool, err error) { - if len(bz) == 0 { - return hasUnknownNonCriticals, nil - } - if recursionLimit <= 0 { - return false, errors.New("recursion limit reached") - } - - fields := desc.Fields() - - for len(bz) > 0 { - tagNum, wireType, m := protowire.ConsumeTag(bz) - if m < 0 { - return hasUnknownNonCriticals, errors.New("invalid length") - } - - fieldDesc := fields.ByNumber(tagNum) - if fieldDesc == nil { - isCriticalField := tagNum&bit11NonCritical == 0 - - if !isCriticalField { - hasUnknownNonCriticals = true - } - - if isCriticalField || !allowUnknownNonCriticals { - // The tag is critical, so report it. - return hasUnknownNonCriticals, ErrUnknownField.Wrapf( - "%s: {TagNum: %d, WireType:%q}", - desc.FullName(), tagNum, WireTypeToString(wireType)) - } - } - - // Skip over the bytes that store fieldNumber and wireType bytes. - bz = bz[m:] - n := protowire.ConsumeFieldValue(tagNum, wireType, bz) - if n < 0 { - err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", - tagNum, WireTypeToString(wireType), protowire.ParseError(n)) - return hasUnknownNonCriticals, err - } - fieldBytes := bz[:n] - bz = bz[n:] - - // An unknown but non-critical field - if fieldDesc == nil { - continue - } - - fieldMessage := fieldDesc.Message() - // not message or group kind - if fieldMessage == nil { - continue - } - // if a message descriptor is a placeholder resolve it using the injected resolver. - // this can happen when a descriptor has been registered in the - // "google.golang.org/protobuf" registry but not in "github.com/cosmos/gogoproto". - // fixes: https://github.com/cosmos/cosmos-sdk/issues/22574 - if fieldMessage.IsPlaceholder() { - gogoDesc, err := resolver.FindDescriptorByName(fieldMessage.FullName()) - if err != nil { - return hasUnknownNonCriticals, fmt.Errorf("could not resolve placeholder descriptor: %v: %w", fieldMessage, err) - } - fieldMessage = gogoDesc.(protoreflect.MessageDescriptor) - } - - // consume length prefix of nested message - _, o := protowire.ConsumeVarint(fieldBytes) - if o < 0 { - err = fmt.Errorf("could not consume length prefix fieldBytes for nested message: %v: %w", - fieldMessage, protowire.ParseError(o)) - return hasUnknownNonCriticals, err - } else if o > len(fieldBytes) { - err = fmt.Errorf("length prefix > len(fieldBytes) for nested message: %v", fieldMessage) - return hasUnknownNonCriticals, err - } - - fieldBytes = fieldBytes[o:] - - var err error - - if fieldMessage.FullName() == anyFullName { - // Firstly typecheck types.Any to ensure nothing snuck in. - hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, anyDesc, allowUnknownNonCriticals, resolver, recursionLimit-1) - hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild - if err != nil { - return hasUnknownNonCriticals, err - } - var a anypb.Any - if err = proto.Unmarshal(fieldBytes, &a); err != nil { - return hasUnknownNonCriticals, err - } - - msgName := protoreflect.FullName(strings.TrimPrefix(a.TypeUrl, "/")) - msgDesc, err := resolver.FindDescriptorByName(msgName) - if err != nil { - return hasUnknownNonCriticals, err - } - - fieldMessage = msgDesc.(protoreflect.MessageDescriptor) - fieldBytes = a.Value - } - - hasUnknownNonCriticalsChild, err := doRejectUnknownFields(fieldBytes, fieldMessage, allowUnknownNonCriticals, resolver, recursionLimit-1) - hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild - if err != nil { - return hasUnknownNonCriticals, err - } - } - - return hasUnknownNonCriticals, nil -} - -// errUnknownField represents an error indicating that we encountered -// a field that isn't available in the target proto.Message. -type errUnknownField struct { - Desc protoreflect.MessageDescriptor - TagNum protowire.Number - WireType protowire.Type -} - -// String implements fmt.Stringer. -func (twt *errUnknownField) String() string { - return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", - twt.Desc.FullName(), twt.TagNum, WireTypeToString(twt.WireType)) -} - -// Error implements the error interface. -func (twt *errUnknownField) Error() string { - return twt.String() -} - -var _ error = (*errUnknownField)(nil) - -// WireTypeToString returns a string representation of the given protowire.Type. -func WireTypeToString(wt protowire.Type) string { - switch wt { - case 0: - return "varint" - case 1: - return "fixed64" - case 2: - return "bytes" - case 3: - return "start_group" - case 4: - return "end_group" - case 5: - return "fixed32" - default: - return fmt.Sprintf("unknown type: %d", wt) - } -}