forked from LDCS/sflag
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sflag.go
234 lines (213 loc) · 7.94 KB
/
sflag.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
// Package sflag is a flag package variant that is 100% DRY, free of fugly pointer syntax and uses clean struct syntax.
//
// Implementation makes use of reflection and struct tags, in manner not dissimilar to other published flag variants.
//
// Limitation: if any boolean flag is defined, there must be no STANDALONE "true" or "false" arguments; use "--Foo=true" syntax instead of "--Foo true".
// This is because the underlying std flag package stops processing flags at the first standalone true/false value
// (it considers the preceding bool flag, --Foo, to have been set by its presence alone).
package sflag
import (
"flag"
"os"
"reflect"
"strconv"
"strings"
)
// visited collects the names of flags reported by a FlagSet visit.
// It must be allocated with make before noteVisited runs, since writing to a
// nil map panics.
var visited map[string]bool

// noteVisited records fl's name in visited. Its signature matches the
// callback accepted by (*flag.FlagSet).Visit.
func noteVisited(fl *flag.Flag) {
	name := fl.Name
	visited[name] = true
}
// Parse populates the struct pointed to by ss from command-line flags.
//
// Each eligible member of the struct is registered with a std flag.FlagSet,
// which does the actual parsing. The flag type comes from the member's type
// (obtained via reflection); the usage text and default value come from the
// member's struct tag.
//
// Tag format: the rightmost pipe char in the tag delineates the Description
// (on its left) from the Default value (on its right). When the tag has no
// delimiter, the member's current value is the default. You can override the
// delimiter with the first char of the tag (after eliminating leading
// whitespace) if that char is not an ASCII letter; that char is then dropped
// from the tag. A non-empty default that cannot be parsed as the member's
// type causes a panic; an empty default means the zero value.
//
// Supported member types are string, int, bool, int64 and float64, and
// pointers to them. In addition:
//   - Members with no tag or a whitespace-only tag are ignored, as are
//     embedded and unexported members and members of other types.
//   - Non-nil pointer members are ignored. Nil pointer members are left nil
//     unless the flag is set on the command line; their tag is not parsed for
//     a default value and they are not listed in Usage.
//   - A member whose name ends in a single underscore yields a flag starting
//     with a lowercase letter: member Foo_ is set by --foo.
//   - If the struct has a string member Usage, Parse overwrites it with a
//     usage text built from the program name, Usage's own struct tag (a brief
//     program description) and the description of every non-pointer flag.
//   - If the struct has a []string member Args, Parse parses its contents
//     instead of os.Args[1:] when it is non-empty on entry, and afterwards
//     stores the arguments not consumed by flags in it.
//
// Parse panics if ss is not a non-nil pointer to a struct, on an invalid
// default as described above, if a standalone "true" or "false" argument is
// present while any bool flag is defined, and on any flag parsing error
// (the FlagSet uses flag.PanicOnError).
func Parse(ss interface{}) {
	rv := reflect.ValueOf(ss)
	if rv.Kind() != reflect.Ptr {
		panic("sflag.Parse was not provided a pointer arg")
	}
	sstype := rv.Type().Elem()
	if sstype.Kind() != reflect.Struct {
		panic("sflag.Parse was not provided a pointer to a struct")
	}
	if rv.IsNil() {
		panic("sflag.Parse was provided a nil pointer")
	}
	ssvalue := rv.Elem()
	stringSlice := reflect.TypeOf([]string(nil))

	progname := os.Args[0]
	args := make([]string, len(os.Args)-1)
	copy(args, os.Args[1:])

	// Args doubles as an input override and as the output slot for
	// arguments left over after flag parsing.
	var argsOut *[]string
	if pp, ok := sstype.FieldByName("Args"); ok && pp.Type == stringSlice {
		argsOut = ssvalue.FieldByName("Args").Addr().Interface().(*[]string)
		if len(*argsOut) > 0 {
			args = make([]string, len(*argsOut))
			copy(args, *argsOut)
		}
	}

	// Keep the *FlagSet returned by NewFlagSet: its Usage func is bound to
	// that pointer, so a copy would print an empty flag list on errors.
	flags := flag.NewFlagSet(progname, flag.PanicOnError)

	// pendingPointer remembers a nil pointer member whose value is only
	// stored back into the struct if its flag is actually set.
	type pendingPointer struct {
		field int           // index of the member within the struct
		value reflect.Value // freshly allocated *T the flag writes into
	}
	pending := map[string]pendingPointer{}

	var usage strings.Builder
	hasBoolArg := false
	for i := 0; i < sstype.NumField(); i++ {
		field := sstype.Field(i)
		fv := ssvalue.Field(i)
		switch {
		case field.Anonymous, field.PkgPath != "":
			continue // embedded or unexported members cannot be set as flags
		case field.Name == "Usage":
			continue // not a flag
		case field.Type == stringSlice:
			continue // Args is handled above; other []string members are ignored
		case field.Type.Kind() == reflect.Ptr && !fv.IsNil():
			continue // non-nil pointers are left alone
		}
		tag := strings.TrimSpace(string(field.Tag))
		if tag == "" {
			continue
		}
		name := flagName(field.Name)
		desc, def, hasDef := splitTag(tag)

		if field.Type.Kind() == reflect.Ptr {
			elem := field.Type.Elem()
			if !isFlagType(elem) {
				continue
			}
			target := reflect.New(elem)
			defineFlag(flags, target, name, "")
			pending[name] = pendingPointer{field: i, value: target}
			// *bool flags suffer the same standalone true/false limitation.
			hasBoolArg = hasBoolArg || elem.Kind() == reflect.Bool
			continue // pointer members are not listed in Usage
		}

		if !isFlagType(field.Type) {
			continue
		}
		if hasDef {
			if err := parseDefault(fv, def); err != nil {
				panic("sflag.Parse: invalid default " + strconv.Quote(def) + " for field " + field.Name + ": " + err.Error())
			}
		}
		typeName := field.Type.String()
		defineFlag(flags, fv.Addr(), name, " <--default, "+typeName+" # "+desc)
		hasBoolArg = hasBoolArg || field.Type.Kind() == reflect.Bool
		usage.WriteString("\n\t--" + name + ": " + def + " <-- Default, " + typeName + " # " + desc)
	}

	if pp, ok := sstype.FieldByName("Usage"); ok && pp.Type.Kind() == reflect.String {
		ssvalue.FieldByName("Usage").SetString("\n Usage of " + progname + " # " + string(pp.Tag) + "\n ARGS:" + usage.String())
	}

	if hasBoolArg {
		for _, arg := range args {
			switch strings.ToLower(arg) {
			case "true", "false":
				panic("Golang flag package requires \"--Foo=bar\" instead of \"--Foo bar\" syntax for bool args")
			}
		}
	}

	// With flag.PanicOnError, Parse panics rather than returning an error.
	_ = flags.Parse(args)

	if argsOut != nil {
		rest := flags.Args()
		*argsOut = make([]string, len(rest))
		copy(*argsOut, rest)
	}

	// Only flags that were actually set are visited; their nil pointer
	// members receive the freshly allocated value.
	flags.Visit(func(f *flag.Flag) {
		if p, ok := pending[f.Name]; ok {
			ssvalue.Field(p.field).Set(p.value)
		}
	})
}

// flagName maps a struct member name to its flag name: a trailing underscore
// is dropped and the first letter lowered (Foo_ -> foo); other names are
// used unchanged.
func flagName(fieldName string) string {
	if n := len(fieldName); n > 1 && fieldName[n-1] == '_' {
		return strings.ToLower(fieldName[:1]) + fieldName[1:n-1]
	}
	return fieldName
}

// splitTag splits a trimmed, non-empty tag into description and default at
// the rightmost delimiter ("|", or the tag's first byte when that byte is not
// an ASCII letter, in which case it is removed). hasDefault reports whether a
// delimiter was found.
func splitTag(tag string) (desc, def string, hasDefault bool) {
	sep := "|"
	if c := tag[0]; !('a' <= c && c <= 'z' || 'A' <= c && c <= 'Z') {
		sep = tag[:1]
		tag = tag[1:]
	}
	i := strings.LastIndex(tag, sep)
	if i < 0 {
		return strings.TrimSpace(tag), "", false
	}
	return strings.TrimSpace(tag[:i]), strings.TrimSpace(tag[i+len(sep):]), true
}

// isFlagType reports whether t is one of the predeclared types Parse binds to
// a flag: string, int, bool, int64 or float64.
func isFlagType(t reflect.Type) bool {
	switch t {
	case reflect.TypeOf(""), reflect.TypeOf(0), reflect.TypeOf(false),
		reflect.TypeOf(int64(0)), reflect.TypeOf(float64(0)):
		return true
	}
	return false
}

// parseDefault parses text as a value of v's kind and stores it in the
// settable v. Empty text stores the zero value for non-string kinds.
func parseDefault(v reflect.Value, text string) error {
	if text == "" && v.Kind() != reflect.String {
		v.Set(reflect.Zero(v.Type()))
		return nil
	}
	switch v.Kind() {
	case reflect.String:
		v.SetString(text)
	case reflect.Int, reflect.Int64:
		// Bits() respects the platform size of int.
		n, err := strconv.ParseInt(text, 10, v.Type().Bits())
		if err != nil {
			return err
		}
		v.SetInt(n)
	case reflect.Bool:
		b, err := strconv.ParseBool(text)
		if err != nil {
			return err
		}
		v.SetBool(b)
	case reflect.Float64:
		f, err := strconv.ParseFloat(text, 64)
		if err != nil {
			return err
		}
		v.SetFloat(f)
	}
	return nil
}

// defineFlag registers on fs a flag called name that writes through target,
// a reflect.Value holding a *string, *int, *bool, *int64 or *float64; the
// value currently pointed to becomes the flag's default.
func defineFlag(fs *flag.FlagSet, target reflect.Value, name, usage string) {
	switch p := target.Interface().(type) {
	case *string:
		fs.StringVar(p, name, *p, usage)
	case *int:
		fs.IntVar(p, name, *p, usage)
	case *bool:
		fs.BoolVar(p, name, *p, usage)
	case *int64:
		fs.Int64Var(p, name, *p, usage)
	case *float64:
		fs.Float64Var(p, name, *p, usage)
	}
}