diff --git a/cmdutil/when.go b/cmdutil/when.go index a35db3af2..0d26dde10 100644 --- a/cmdutil/when.go +++ b/cmdutil/when.go @@ -1,6 +1,7 @@ package cmdutil import ( + "fmt" "os" "strings" @@ -46,8 +47,10 @@ func IsAllowedToExecute(when string) (bool, error) { } if got, err := expr.Run(program, whenEnv); err != nil { return false, errors.WithStack(err) + } else if got, ok := got.(bool); !ok { + return false, fmt.Errorf("expected bool, but got %T", got) } else { - return got.(bool), nil + return got, nil } } diff --git a/cmdutil/when_test.go b/cmdutil/when_test.go index 9a987a7b7..55ff3d3d0 100644 --- a/cmdutil/when_test.go +++ b/cmdutil/when_test.go @@ -113,6 +113,13 @@ func TestIsAllowedToExecute(t *testing.T) { want: false, errorContains: "unknown name NoneSuchVariable", }, + { + name: "Expression produces an integer", + envset: map[string]string{}, + when: "123", + want: false, + errorContains: "expected bool, but got int", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -120,8 +127,12 @@ func TestIsAllowedToExecute(t *testing.T) { got, err := IsAllowedToExecute(tt.when) if err != nil { if tt.errorContains != nil { - if !strings.Contains(err.Error(), tt.errorContains.(string)) { - t.Errorf("Error %v does not contain %s", err, tt.errorContains) + if errStr, ok := tt.errorContains.(string); ok { + if !strings.Contains(err.Error(), errStr) { + t.Errorf("Error %v does not contain %s", err, errStr) + } + } else { + t.Errorf("errorContains should be a string, but got %T", tt.errorContains) } } else { t.Error(err) diff --git a/config/yaml.go b/config/yaml.go index 7228c5133..072a613b8 100644 --- a/config/yaml.go +++ b/config/yaml.go @@ -72,7 +72,9 @@ func (f *Format) UnmarshalYAML(data []byte) error { case []interface{}: values := []string{} for _, vv := range v { - values = append(values, vv.(string)) + if str, ok := vv.(string); ok { + values = append(values, str) + } } f.HideColumnsWithoutValues = values } diff --git a/dict/dict.go b/dict/dict.go index a98a64aa6..bd704ebd9 100644 --- a/dict/dict.go +++ b/dict/dict.go @@ -18,7 +18,9 @@ func New() Dict { func (d *Dict) Lookup(k string) string { if v, ok := d.s.Load(k); ok { - return v.(string) + if str, ok := v.(string); ok { + return str + } } return k }