diff --git a/go/cmd/automation_server/automation_server.go b/go/cmd/automation_server/automation_server.go index c34b360dafe..af01e894f1c 100644 --- a/go/cmd/automation_server/automation_server.go +++ b/go/cmd/automation_server/automation_server.go @@ -17,7 +17,6 @@ limitations under the License. package main import ( - "flag" "fmt" "net" "os" @@ -34,13 +33,7 @@ func init() { } func main() { - - flag.Parse() - - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } + servenv.ParseFlags("automation_server") fmt.Println("Automation Server, listening on:", *servenv.Port) diff --git a/go/cmd/l2vtgate/main.go b/go/cmd/l2vtgate/main.go index dd082656c8e..a1da649a7be 100644 --- a/go/cmd/l2vtgate/main.go +++ b/go/cmd/l2vtgate/main.go @@ -19,7 +19,6 @@ package main import ( "flag" "math/rand" - "os" "strings" "time" @@ -55,14 +54,9 @@ func init() { func main() { defer exit.Recover() - flag.Parse() + servenv.ParseFlags("l2vtgate") servenv.Init() - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - ts := topo.Open() defer ts.Close() diff --git a/go/cmd/mysqlctld/mysqlctld.go b/go/cmd/mysqlctld/mysqlctld.go index 2c7e46e8180..4fbef020f9f 100644 --- a/go/cmd/mysqlctld/mysqlctld.go +++ b/go/cmd/mysqlctld/mysqlctld.go @@ -58,12 +58,7 @@ func main() { // mysqlctld only starts and stops mysql, only needs dba. dbconfigFlags := dbconfigs.DbaConfig dbconfigs.RegisterFlags(dbconfigFlags) - flag.Parse() - - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } + servenv.ParseFlags("mysqlctld") // We'll register this OnTerm handler before mysqld starts, so we get notified // if mysqld dies on its own without us (or our RPC client) telling it to. diff --git a/go/cmd/vtcombo/main.go b/go/cmd/vtcombo/main.go index 43729f70d83..09e40e675cc 100644 --- a/go/cmd/vtcombo/main.go +++ b/go/cmd/vtcombo/main.go @@ -23,7 +23,6 @@ package main import ( "flag" - "os" "strings" "time" @@ -67,18 +66,7 @@ func main() { dbconfigs.FilteredConfig | dbconfigs.ReplConfig dbconfigs.RegisterFlags(dbconfigFlags) mysqlctl.RegisterFlags() - flag.Parse() - - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - - if len(flag.Args()) > 0 { - flag.Usage() - log.Errorf("vtcombo doesn't take any positional arguments") - exit.Return(1) - } + servenv.ParseFlags("vtcombo") // parse the input topology tpb := &vttestpb.VTTestTopology{} diff --git a/go/cmd/vtctl/vtctl.go b/go/cmd/vtctl/vtctl.go index eb7414d8fe2..37a1b3eec81 100644 --- a/go/cmd/vtctl/vtctl.go +++ b/go/cmd/vtctl/vtctl.go @@ -69,19 +69,7 @@ func main() { defer exit.RecoverAll() defer logutil.Flush() - flag.Parse() - args := flag.Args() - - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - - if len(args) == 0 { - flag.Usage() - exit.Return(1) - } - + args := servenv.ParseFlagsWithArgs("vtctl") action := args[0] startMsg := fmt.Sprintf("USER=%v SUDO_USER=%v %v", os.Getenv("USER"), os.Getenv("SUDO_USER"), strings.Join(os.Args, " ")) diff --git a/go/cmd/vtctld/main.go b/go/cmd/vtctld/main.go index 3cfcff48ece..82d859a7b2e 100644 --- a/go/cmd/vtctld/main.go +++ b/go/cmd/vtctld/main.go @@ -17,9 +17,6 @@ limitations under the License. package main import ( - "flag" - "os" - "github.com/youtube/vitess/go/vt/servenv" "github.com/youtube/vitess/go/vt/topo" "github.com/youtube/vitess/go/vt/vtctld" @@ -35,15 +32,10 @@ var ( ) func main() { - flag.Parse() + servenv.ParseFlags("vtctld") servenv.Init() defer servenv.Close() - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - ts = topo.Open() defer ts.Close() diff --git a/go/cmd/vtexplain/vtexplain.go b/go/cmd/vtexplain/vtexplain.go index 0a2056c329a..d676bb839d2 100644 --- a/go/cmd/vtexplain/vtexplain.go +++ b/go/cmd/vtexplain/vtexplain.go @@ -20,7 +20,6 @@ import ( "flag" "fmt" "io/ioutil" - "os" log "github.com/golang/glog" "github.com/youtube/vitess/go/exit" @@ -124,19 +123,7 @@ func main() { defer exit.RecoverAll() defer logutil.Flush() - flag.Parse() - - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - - args := flag.Args() - - if len(args) != 0 { - flag.Usage() - exit.Return(1) - } + servenv.ParseFlags("vtexplain") err := parseAndRun() if err != nil { diff --git a/go/cmd/vtgate/vtgate.go b/go/cmd/vtgate/vtgate.go index 3c2d9c33260..c27ffa303eb 100644 --- a/go/cmd/vtgate/vtgate.go +++ b/go/cmd/vtgate/vtgate.go @@ -19,7 +19,6 @@ package main import ( "flag" "math/rand" - "os" "strings" "time" @@ -59,14 +58,9 @@ func init() { func main() { defer exit.Recover() - flag.Parse() + servenv.ParseFlags("vtgate") servenv.Init() - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - if initFakeZK != nil { initFakeZK() } diff --git a/go/cmd/vtgateclienttest/main.go b/go/cmd/vtgateclienttest/main.go index e713168d405..421114b8bf7 100644 --- a/go/cmd/vtgateclienttest/main.go +++ b/go/cmd/vtgateclienttest/main.go @@ -20,9 +20,6 @@ limitations under the License. package main import ( - "flag" - "os" - "github.com/youtube/vitess/go/cmd/vtgateclienttest/services" "github.com/youtube/vitess/go/exit" "github.com/youtube/vitess/go/vt/servenv" @@ -36,14 +33,9 @@ func init() { func main() { defer exit.Recover() - flag.Parse() + servenv.ParseFlags("vtgateclienttest") servenv.Init() - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } - // The implementation chain. servenv.OnRun(func() { s := services.CreateServices() diff --git a/go/cmd/vttablet/vttablet.go b/go/cmd/vttablet/vttablet.go index 7b0a0c9a97d..fec142f4294 100644 --- a/go/cmd/vttablet/vttablet.go +++ b/go/cmd/vttablet/vttablet.go @@ -19,7 +19,6 @@ package main import ( "flag" - "os" log "github.com/golang/glog" "github.com/youtube/vitess/go/vt/dbconfigs" @@ -52,17 +51,9 @@ func main() { dbconfigs.FilteredConfig | dbconfigs.ReplConfig dbconfigs.RegisterFlags(dbconfigFlags) mysqlctl.RegisterFlags() - flag.Parse() - if *servenv.Version { - servenv.AppVersion.Print() - os.Exit(0) - } + servenv.ParseFlags("vttablet") - if len(flag.Args()) > 0 { - flag.Usage() - log.Exit("vttablet doesn't take any positional arguments") - } if err := tabletenv.VerifyConfig(); err != nil { log.Exitf("invalid config: %v", err) } diff --git a/go/vt/servenv/servenv.go b/go/vt/servenv/servenv.go index 259648c5907..350b85064f9 100644 --- a/go/vt/servenv/servenv.go +++ b/go/vt/servenv/servenv.go @@ -33,6 +33,7 @@ import ( "net/url" "os" "runtime" + "strings" "sync" "syscall" "time" @@ -193,3 +194,37 @@ func RegisterDefaultFlags() { func RunDefault() { Run(*Port) } + +// ParseFlags initializes flags and handles the common case when no positional +// arguments are expected. +func ParseFlags(cmd string) { + flag.Parse() + + if *Version { + AppVersion.Print() + os.Exit(0) + } + + args := flag.Args() + if len(args) > 0 { + flag.Usage() + log.Exitf("%s doesn't take any positional arguments, got '%s'", cmd, strings.Join(args, " ")) + } +} + +// ParseFlagsWithArgs initializes flags and returns the positional arguments +func ParseFlagsWithArgs(cmd string) []string { + flag.Parse() + + if *Version { + AppVersion.Print() + os.Exit(0) + } + + args := flag.Args() + if len(args) == 0 { + log.Exitf("%s expected at least one positional argument", cmd) + } + + return args +}