Skip to content

Commit

Permalink
refactor(client/v2): refactor of flags (backport #17306) (#17309)
Browse files Browse the repository at this point in the history
Co-authored-by: Julien Robert <julien@rbrt.fr>
  • Loading branch information
mergify[bot] and julienrbrt authored Aug 7, 2023
1 parent c56cd8a commit 6861a06
Show file tree
Hide file tree
Showing 26 changed files with 377 additions and 432 deletions.
6 changes: 3 additions & 3 deletions client/v2/autocli/flag/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

type addressStringType struct{}

func (a addressStringType) NewValue(ctx context.Context, b *Builder) Value {
func (a addressStringType) NewValue(_ context.Context, b *Builder) Value {
return &addressValue{addressCodec: b.AddressCodec}
}

Expand All @@ -27,7 +27,7 @@ func (a addressStringType) DefaultValue() string {

type validatorAddressStringType struct{}

func (a validatorAddressStringType) NewValue(ctx context.Context, b *Builder) Value {
func (a validatorAddressStringType) NewValue(_ context.Context, b *Builder) Value {
return &addressValue{addressCodec: b.ValidatorAddressCodec}
}

Expand Down Expand Up @@ -61,7 +61,7 @@ func (a *addressValue) Set(s string) error {
}

func (a addressValue) Type() string {
return "bech32 account address key name"
return "bech32 account address"
}

type consensusAddressStringType struct{}
Expand Down
2 changes: 1 addition & 1 deletion client/v2/autocli/flag/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type binaryType struct{}

var _ Value = (*fileBinaryValue)(nil)

func (f binaryType) NewValue(_ context.Context, _ *Builder) Value {
func (f binaryType) NewValue(context.Context, *Builder) Value {
return &fileBinaryValue{}
}

Expand Down
303 changes: 303 additions & 0 deletions client/v2/autocli/flag/builder.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
package flag

import (
"context"
"fmt"
"strconv"

cosmos_proto "github.com/cosmos/cosmos-proto"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"

autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
"cosmossdk.io/client/v2/internal/util"
"cosmossdk.io/core/address"
)

Expand Down Expand Up @@ -55,3 +65,296 @@ func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) {
b.init()
b.scalarFlagTypes[scalarName] = flagType
}

func (b *Builder) AddMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) {
return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{})
}

// AddMessageFlags adds flags for each field in the message to the flag set.
func (b *Builder) addMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) {
fields := messageType.Descriptor().Fields()
numFields := fields.Len()
handler := &MessageBinder{
messageType: messageType,
}

isPositional := map[string]bool{}
hasVarargs := false
hasOptional := false
n := len(commandOptions.PositionalArgs)
// positional args are also parsed using a FlagSet so that we can reuse all the same parsers
handler.positionalFlagSet = pflag.NewFlagSet("positional", pflag.ContinueOnError)
for i, arg := range commandOptions.PositionalArgs {
isPositional[arg.ProtoField] = true

field := fields.ByName(protoreflect.Name(arg.ProtoField))
if field == nil {
return nil, fmt.Errorf("can't find field %s on %s", arg.ProtoField, messageType.Descriptor().FullName())
}

if arg.Optional && arg.Varargs {
return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField)
}

if arg.Varargs {
if i != n-1 {
return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField)
}

hasVarargs = true
}

if arg.Optional {
if i != n-1 {
return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField)
}

hasOptional = true
}

_, hasValue, err := b.addFieldFlag(
ctx,
handler.positionalFlagSet,
field,
&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", i)},
namingOptions{},
)
if err != nil {
return nil, err
}

handler.positionalArgs = append(handler.positionalArgs, fieldBinding{
field: field,
hasValue: hasValue,
})
}

if hasVarargs {
handler.CobraArgs = cobra.MinimumNArgs(n - 1)
handler.hasVarargs = true
} else if hasOptional {
handler.CobraArgs = cobra.RangeArgs(n-1, n)
handler.hasOptional = true
} else {
handler.CobraArgs = cobra.ExactArgs(n)
}

// validate flag options
for name := range commandOptions.FlagOptions {
if fields.ByName(protoreflect.Name(name)) == nil {
return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName())
}
}

flagOptsByFlagName := map[string]*autocliv1.FlagOptions{}
for i := 0; i < numFields; i++ {
field := fields.Get(i)
if isPositional[string(field.Name())] {
continue
}

flagOpts := commandOptions.FlagOptions[string(field.Name())]
name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options)
flagOptsByFlagName[name] = flagOpts
if err != nil {
return nil, err
}

handler.flagBindings = append(handler.flagBindings, fieldBinding{
hasValue: hasValue,
field: field,
})
}

flagSet.VisitAll(func(flag *pflag.Flag) {
opts := flagOptsByFlagName[flag.Name]
if opts != nil {
// This is a bit of hacking around the pflag API, but
// we need to set these options here using Flag.VisitAll because the flag
// constructors that pflag gives us (StringP, Int32P, etc.) do not
// actually return the *Flag instance
flag.Deprecated = opts.Deprecated
flag.ShorthandDeprecated = opts.ShorthandDeprecated
flag.Hidden = opts.Hidden
}
})

return handler, nil
}

// bindPageRequest create a flag for pagination
func (b *Builder) bindPageRequest(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
return b.addMessageFlags(
ctx,
flagSet,
util.ResolveMessageType(b.TypeResolver, field.Message()),
&autocliv1.RpcCommandOptions{},
namingOptions{Prefix: "page-"},
)
}

// namingOptions specifies internal naming options for flags.
type namingOptions struct {
// Prefix is a prefix to prepend to all flags.
Prefix string
}

// addFieldFlag adds a flag for the provided field to the flag set.
func (b *Builder) addFieldFlag(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) {
if opts == nil {
opts = &autocliv1.FlagOptions{}
}

if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" {
hasValue, err := b.bindPageRequest(ctx, flagSet, field)
return "", hasValue, err
}

name = opts.Name
if name == "" {
name = options.Prefix + util.DescriptorKebabName(field)
}

usage := opts.Usage
if usage == "" {
usage = util.DescriptorDocs(field)
}

shorthand := opts.Shorthand
defaultValue := opts.DefaultValue

if typ := b.resolveFlagType(field); typ != nil {
if defaultValue == "" {
defaultValue = typ.DefaultValue()
}

val := typ.NewValue(ctx, b)
flagSet.AddFlag(&pflag.Flag{
Name: name,
Shorthand: shorthand,
Usage: usage,
DefValue: defaultValue,
Value: val,
})
return name, val, nil
}

// use the built-in pflag StringP, Int32P, etc. functions
var val HasValue

if field.IsList() {
val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage)
} else if field.IsMap() {
keyKind := field.MapKey().Kind()
valKind := field.MapValue().Kind()
val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage)
} else {
val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage)
}

// This is a bit of hacking around the pflag API, but the
// defaultValue is set in this way because this is much easier than trying
// to parse the string into the types that StringSliceP, Int32P, etc. expect
if defaultValue != "" {
err = flagSet.Set(name, defaultValue)
}
return name, val, err
}

func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type {
typ := b.resolveFlagTypeBasic(field)
if field.IsList() {
if typ != nil {
return compositeListType{simpleType: typ}
}
return nil
}
if field.IsMap() {
keyKind := field.MapKey().Kind()
valType := b.resolveFlagType(field.MapValue())
if valType != nil {
switch keyKind {
case protoreflect.StringKind:
ct := new(compositeMapType[string])
ct.keyValueResolver = func(s string) (string, error) { return s, nil }
ct.valueType = valType
ct.keyType = "string"
return ct
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
ct := new(compositeMapType[int32])
ct.keyValueResolver = func(s string) (int32, error) {
i, err := strconv.ParseInt(s, 10, 32)
return int32(i), err
}
ct.valueType = valType
ct.keyType = "int32"
return ct
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
ct := new(compositeMapType[int64])
ct.keyValueResolver = func(s string) (int64, error) {
i, err := strconv.ParseInt(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "int64"
return ct
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
ct := new(compositeMapType[uint32])
ct.keyValueResolver = func(s string) (uint32, error) {
i, err := strconv.ParseUint(s, 10, 32)
return uint32(i), err
}
ct.valueType = valType
ct.keyType = "uint32"
return ct
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
ct := new(compositeMapType[uint64])
ct.keyValueResolver = func(s string) (uint64, error) {
i, err := strconv.ParseUint(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "uint64"
return ct
case protoreflect.BoolKind:
ct := new(compositeMapType[bool])
ct.keyValueResolver = strconv.ParseBool
ct.valueType = valType
ct.keyType = "bool"
return ct
}
return nil

}
return nil
}

return typ
}

func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type {
scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar)
if scalar != nil {
b.init()
if typ, ok := b.scalarFlagTypes[scalar.(string)]; ok {
return typ
}
}

switch field.Kind() {
case protoreflect.BytesKind:
return binaryType{}
case protoreflect.EnumKind:
return enumType{enum: field.Enum()}
case protoreflect.MessageKind:
b.init()
if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok {
return flagType
}
return jsonMessageFlagType{
messageDesc: field.Message(),
}
default:
return nil
}
}
2 changes: 1 addition & 1 deletion client/v2/autocli/flag/coin.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type coinValue struct {
value *basev1beta1.Coin
}

func (c coinType) NewValue(_ context.Context, _ *Builder) Value {
func (c coinType) NewValue(context.Context, *Builder) Value {
return &coinValue{}
}

Expand Down
Loading

0 comments on commit 6861a06

Please sign in to comment.