diff --git a/README.md b/README.md index b6d8b03..2417ba9 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,41 @@ Offers a rich configuration file handler. - Read configuration files with ease - Bind CLI flags +- Bind environment variables - Watch file (or files) and get notified if they change +--- + +Uses the following precedence order: + +* `flag` +* `env` +* `toml` + + +| flag | env | toml | result | +|:----:|:-----:|:-------------:|:---:| +| ☑ | ☑ | ☑ | **flag** | +| ☑ | ☑ | ☐ | **flag** | +| ☑ | ☐ | ☑ | **flag** | +| ☐ | ☑ | ☑ | **env** | +| ☑ | ☐ | ☐ | **flag** | +| ☐ | ☑ | ☐ | **env** | +| ☐ | ☐ | ☑ | **toml** | + +If `flag` is set and not given, it will parse `env` or `toml` according to their precedence order (otherwise flag default). + + ## Basic Example Call the `Load()` method to load a config. ```go type MyConfig struct { - Key1 string `toml:"key1"` - Key2 string `toml:"key2"` - Port int `toml:"-" flag:"port"` + Key1 string `toml:"key1"` + Key2 string `toml:"key2"` + Port int `toml:"-" flag:"port"` + Secret string `toml:"-" flag:"-" env:"secret"` } _ = flag.Int("port", 8080, "Port to listen on") // <- notice no variable @@ -28,6 +52,7 @@ Call the `Load()` method to load a config. fmt.Printf("Loaded config: %#v\n", cfg) // Port info is in cfg.Port, parsed from `-port` param + // Secret info is in cfg.Secret, parsed from `secret` environment variable ``` ## File Watching diff --git a/_example/main.go b/_example/main.go index e34b749..5930290 100644 --- a/_example/main.go +++ b/_example/main.go @@ -12,9 +12,10 @@ import ( ) type cfgType struct { - Key1 string `toml:"key1"` - Key2 string `toml:"key2"` - Port int `toml:"-" flag:"port"` + Key1 string `toml:"key1"` + Key2 string `toml:"key2"` + Port int `toml:"-" flag:"port" env:"port"` + Secret string `env:"secret"` } func main() { diff --git a/config.go b/config.go index f6f7081..8b2859d 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ package config import ( "flag" "fmt" + "os" "reflect" "strconv" @@ -13,66 +14,83 @@ import ( "github.com/fatih/structs" ) +const ( + envTag string = "env" + flagTag string = "flag" +) + // Load loads filepath into dst. It also handles "flag" binding. func Load(filepath string, dst interface{}) error { metadata, err := toml.DecodeFile(filepath, dst) - if err != nil { return err } + if err := bindEnvVariables(dst); err != nil { + return err + } + return bindFlags(dst, metadata) } -// bindFlags will bind CLI flags to their respective elements in dst, defined by the struct-tag "flag". -func bindFlags(dst interface{}, metadata toml.MetaData) error { - // Iterate all fields +// bindEnvVariables will bind CLI flags to their respective elements in dst, defined by the struct-tag "env". +func bindEnvVariables(dst interface{}) error { fields := structs.Fields(dst) for _, field := range fields { - tag := field.Tag("flag") + tag := field.Tag(envTag) if tag == "" || tag == "-" { - // Maybe it's nested? + ok, dstElem := isNestedStruct(dst, field) + if !ok { + continue + } - dstElem := reflect.ValueOf(dst).Elem().FieldByName(field.Name()) + if err := bindEnvVariables(dstElem.Addr().Interface()); err != nil { + return err + } - if dstElem.Kind() == reflect.Ptr { - if dstElem.IsNil() { - // Create new non-nil ptr - dstElem.Set(reflect.New(dstElem.Type().Elem())) - } + continue + } - // Dereference - dstElem = dstElem.Elem() - } + fVal, ok := os.LookupEnv(tag) + if !ok { + continue + } + + if err := setDstElem(dst, field, fVal); err != nil { + return err + } + } + return nil +} - if dstElem.Kind() != reflect.Struct { +// bindFlags will bind CLI flags to their respective elements in dst, defined by the struct-tag "flag". +func bindFlags(dst interface{}, metadata toml.MetaData) error { + fields := structs.Fields(dst) + for _, field := range fields { + tag := field.Tag(flagTag) + if tag == "" || tag == "-" { + ok, dstElem := isNestedStruct(dst, field) + if !ok { continue } - err := bindFlags(dstElem.Addr().Interface(), metadata) - if err != nil { + if err := bindFlags(dstElem.Addr().Interface(), metadata); err != nil { return err } continue } - // if config struct has "flag" tag in flags: - // if flag is set, use flag value - // else - // if toml file has key, use toml value - // else use flag default value + // if config struct has "flag" tag: + // if flag is set, use flag value + // else if env has key, use environment value + // else if toml file has key, use toml value + // else use flag default value useFlagDefaultValue := false if !isFlagSet(tag) { - tomlHasKey := false - for _, key := range metadata.Keys() { - if strings.ToLower(key.String()) == strings.ToLower(tag) { - tomlHasKey = true - break - } - } - if tomlHasKey { + _, envHasKey := os.LookupEnv(tag) + if envHasKey || tomlHasKey(metadata, tag) { continue } else { useFlagDefaultValue = true @@ -80,56 +98,87 @@ func bindFlags(dst interface{}, metadata toml.MetaData) error { } // CLI value - if flag.Lookup(tag) == nil { return fmt.Errorf("flag '%v' is not defined but given as flag struct tag in %v.%v", tag, reflect.TypeOf(dst), field.Name()) } - fVal := flag.Lookup(tag).Value.String() + var fVal string if useFlagDefaultValue { fVal = flag.Lookup(tag).DefValue + } else { + fVal = flag.Lookup(tag).Value.String() } - // Destination - dstElem := reflect.ValueOf(dst).Elem().FieldByName(field.Name()) + if err := setDstElem(dst, field, fVal); err != nil { + return err + } + } - // Attempt to convert the flag input depending on type of destination - switch dstElem.Kind().String() { - case "bool": - if p, err := strconv.ParseBool(fVal); err != nil { - return err - } else { - dstElem.SetBool(p) - } - case "int": - if p, err := strconv.ParseInt(fVal, 10, 0); err != nil { - return err - } else { - dstElem.SetInt(p) - } - case "uint": - if p, err := strconv.ParseUint(fVal, 10, 0); err != nil { - return err - } else { - dstElem.SetUint(p) - } - case "float64": - if p, err := strconv.ParseFloat(fVal, 64); err != nil { - return err - } else { - dstElem.SetFloat(p) - } - case "string": - dstElem.SetString(fVal) + return nil +} + +// isNestedStruct will check if destination element or its pointer is struct type +func isNestedStruct(dst interface{}, field *structs.Field) (bool, reflect.Value) { + dstElem := reflect.ValueOf(dst).Elem().FieldByName(field.Name()) + if dstElem.Kind() == reflect.Ptr { + if dstElem.IsNil() { + // Create new non-nil ptr + dstElem.Set(reflect.New(dstElem.Type().Elem())) + } + + // Dereference + dstElem = dstElem.Elem() + } + + if dstElem.Kind() != reflect.Struct { + return false, dstElem + } + + return true, dstElem +} - default: - return fmt.Errorf("Unhandled type %v for elem %v", dstElem.Kind().String(), field.Name()) +// setDstElem will convert tag input to its real type +func setDstElem(dst interface{}, field *structs.Field, fVal string) error { + // Destination + dstElem := reflect.ValueOf(dst).Elem().FieldByName(field.Name()) + + // Attempt to convert the tag input depending on type of destination + switch dstElem.Kind().String() { + case "bool": + if p, err := strconv.ParseBool(fVal); err != nil { + return err + } else { + dstElem.SetBool(p) + } + case "int", "int8", "int16", "int32", "int64": + if p, err := strconv.ParseInt(fVal, 10, 0); err != nil { + return err + } else { + dstElem.SetInt(p) } + case "uint", "uint8", "uint16", "uint32", "uint64", "uintptr": + if p, err := strconv.ParseUint(fVal, 10, 0); err != nil { + return err + } else { + dstElem.SetUint(p) + } + case "float64", "float32": + if p, err := strconv.ParseFloat(fVal, 64); err != nil { + return err + } else { + dstElem.SetFloat(p) + } + case "string": + dstElem.SetString(fVal) + + default: + return fmt.Errorf("unhandled type %v for elem %v", dstElem.Kind().String(), field.Name()) } return nil } +// isFlagSet will check if flag is set func isFlagSet(tag string) bool { flagSet := false flag.Visit(func(fl *flag.Flag) { @@ -139,3 +188,13 @@ func isFlagSet(tag string) bool { }) return flagSet } + +// tomlHasKey will check if the tag presents in toml metadata +func tomlHasKey(metadata toml.MetaData, tag string) bool { + for _, key := range metadata.Keys() { + if strings.ToLower(key.String()) == strings.ToLower(tag) { + return true + } + } + return false +} diff --git a/config_test.go b/config_test.go index cb51364..2e1b8bc 100644 --- a/config_test.go +++ b/config_test.go @@ -98,12 +98,16 @@ port = 7070 func TestLoad_FlagNotGivenWithDefaultValue(t *testing.T) { var cfg struct { - Host string `toml:"host"` - Port int `toml:"port" flag:"port"` + Host string `toml:"host"` + Port int `toml:"port" flag:"port"` + Mode string `toml:"mode" env:"mode" flag:"mode"` + Secret string `env:"secret" flag:"secret"` } fs := flag.NewFlagSet("tmp", flag.ExitOnError) _ = fs.Int("port", 9090, "Port to listen to") + _ = fs.String("mode", "warning", "Log mode") + _ = fs.String("secret", "secret_flag", "Secret variable") flag.CommandLine = fs flag.CommandLine.Parse(nil) // flag not given and has default value @@ -113,11 +117,15 @@ func TestLoad_FlagNotGivenWithDefaultValue(t *testing.T) { _, err := tmp.WriteString(` host = "localhost" port = 1010 +mode = "info" `) if err != nil { t.Fatalf("unexpected error: %v", err) } + os.Setenv("mode", "debug") + os.Setenv("secret", "secret_env") + if err := Load(tmp.Name(), &cfg); err != nil { t.Fatalf("unexpected error %v", err) } @@ -129,12 +137,27 @@ port = 1010 if cfg.Port != 1010 { t.Errorf("got: %v, expected: %v", cfg.Port, 1010) } + + // environment dominant over toml + if cfg.Mode != "debug" { + t.Errorf("got: %v, expected: %v", cfg.Mode, "debug") + } + + // environment dominant over toml + if cfg.Mode != "debug" { + t.Errorf("got: %v, expected: %v", cfg.Mode, "debug") + } + + // environment dominant over toml + if cfg.Secret != "secret_env" { + t.Errorf("got: %v, expected: %v", cfg.Secret, "secret_env") + } } func TestLoad_UseFlagDefaultValueIfKeyNotFoundInConfig(t *testing.T) { var cfg struct { LogLevel string `toml:"logLevel"` - Port int `toml:"-" flag:"port"` + Port int `toml:"-" env:"-" flag:"port"` } tmp, _ := ioutil.TempFile("", "") defer os.Remove(tmp.Name()) @@ -157,17 +180,17 @@ LogLevel = "debug" if cfg.LogLevel != "debug" { t.Errorf("got: %v, expected: %v", cfg.LogLevel, "debug") } + if cfg.Port != 9090 { t.Errorf("got: %v, expected: %v", cfg.Port, 9090) } - } -func TestWithFlagNested(t *testing.T) { +func TestLoad_FlagNested(t *testing.T) { var cfg struct { Server struct { Host string `toml:"host"` - Port int `toml:"port"` + Port int `toml:"-" flag:"port"` } `toml:"server"` } @@ -177,8 +200,12 @@ func TestWithFlagNested(t *testing.T) { _, err := tmp.WriteString(` [server] host = "localhost" -port = 1010 `) + fs := flag.NewFlagSet("tmp", flag.ExitOnError) + _ = fs.Int("port", 9090, "Port to listen to") + flag.CommandLine = fs + flag.CommandLine.Parse([]string{"-port", "1010"}) // flag given + if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -196,7 +223,7 @@ port = 1010 } } -func TestWithFlagNestedPtr(t *testing.T) { +func TestLoad_FlagNestedPtr(t *testing.T) { var cfg struct { Server *struct { Host string `toml:"host"` @@ -229,7 +256,7 @@ port = 1010 } } -func TestLoad_ErrorIfFlagNotSetAndNotGiven(t *testing.T) { +func TestLoad_ErrorIfFlagSetAndNotGiven(t *testing.T) { var cfg struct { LogLevel string `toml:"logLevel"` Port int `toml:"port" flag:"port"` @@ -254,3 +281,313 @@ LogLevel = "debug" t.Fatalf("expected error, got nil") } } + +func TestLoad_EnvGiven(t *testing.T) { + var cfg struct { + Key string `toml:"-" flag:"-" env:"key"` + Secret string `toml:"-" flag:"-" env:"secret"` + } + os.Setenv("key", "some_key") + os.Setenv("secret", "some_secret") + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + _, err := tmp.WriteString(`host = "localhost"`) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + if cfg.Key != "some_key" { + t.Errorf("got: %v, expected: %v", cfg.Key, "some_key") + } + + if cfg.Secret != "some_secret" { + t.Errorf("got: %v, expected: %v", cfg.Secret, "some_secret") + } +} + +func TestLoad_EnvGivenWithNested(t *testing.T) { + var cfg struct { + Db struct { + User string `env:"db_user"` + Password string `env:"db_password"` + } + } + os.Setenv("db_user", "secret_user") + os.Setenv("db_password", "secret_password") + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + if cfg.Db.User != "secret_user" { + t.Errorf("got: %v, expected: %v", cfg.Db.User, "secret_user") + } + + if cfg.Db.Password != "secret_password" { + t.Errorf("got: %v, expected: %v", cfg.Db.Password, "secret_password") + } +} + +func TestLoad_EnvGivenWithNestedPtr(t *testing.T) { + var cfg struct { + Db *struct { + User string `env:"db_user"` + Password string `env:"db_password"` + } + } + os.Setenv("db_user", "secret_user") + os.Setenv("db_password", "secret_password") + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + if cfg.Db.User != "secret_user" { + t.Errorf("got: %v, expected: %v", cfg.Db.User, "secret_user") + } + + if cfg.Db.Password != "secret_password" { + t.Errorf("got: %v, expected: %v", cfg.Db.Password, "secret_password") + } +} + +func TestLoad_ParseOtherTagsIfEnvSetAndNotGiven(t *testing.T) { + var cfg struct { + LogLevel string `env:"logLevel" flag:"logLevel"` + Port int `toml:"port" env:"port"` + Host string `toml:"host" env:"host" flag:"host"` + } + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + _, err := tmp.WriteString(` +port = 7777 +flag = "localhost" +`) + if err != nil { + t.Fatalf("write config file failed: %v", err) + } + + fs := flag.NewFlagSet("tmp", flag.ExitOnError) + _ = fs.String("logLevel", "debug", "Log level") + _ = fs.String("host", "localhost", "Host address") + + flag.CommandLine = fs + flag.CommandLine.Parse([]string{"-logLevel", "debug"}) // flag given + flag.CommandLine.Parse([]string{"-host", "dev.example.com"}) // flag given + + // os.Setenv("port", "9090") // env not set + // os.Setenv("logLevel", "warning") // env not set + // os.Setenv("host", "secret.example.com") // env not set + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + if cfg.Port != 7777 { + t.Errorf("got: %v, expected: %v", cfg.Port, 7777) + } + + if cfg.LogLevel != "debug" { + t.Errorf("got: %v, expected: %v", cfg.LogLevel, "debug") + } + + if cfg.LogLevel != "debug" { + t.Errorf("got: %v, expected: %v", cfg.LogLevel, "debug") + } + + if cfg.Host != "dev.example.com" { + t.Errorf("got: %v, expected: %v", cfg.Host, "dev.example.com") + } +} + +func TestLoad_CheckTagPriorities(t *testing.T) { + var cfg struct { + Key1 string `toml:"key1" flag:"key1"` + Key2 string `toml:"key2" env:"key2"` + Key3 string `flag:"key3" env:"key3"` + Key4 string `toml:"key4" flag:"key4" env:"key4"` + Key5 string `toml:"key5"` + } + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + // toml + _, err := tmp.WriteString(` +key1 = "key1_toml" +key2 = "key2_toml" +key4 = "key4_toml" +key5 = "key5_toml" +`) + + if err != nil { + t.Fatalf("write config file failed: %v", err) + } + + // flag + fs := flag.NewFlagSet("tmp", flag.ExitOnError) + _ = fs.String("key1", "", "") + _ = fs.String("key3", "", "") + _ = fs.String("key4", "", "") + + flag.CommandLine = fs + flag.CommandLine.Parse([]string{"-key1", "key1_flag"}) // flag given + flag.CommandLine.Parse([]string{"-key3", "key3_flag"}) // flag given + flag.CommandLine.Parse([]string{"-key4", "key4_flag"}) // flag given + + // env + os.Setenv("key2", "key2_env") + os.Setenv("key3", "key3_env") + os.Setenv("key4", "key4_env") + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + // priority order + // -- flag > env > toml + + // flag has higher priority than toml + if cfg.Key1 != "key1_flag" { + t.Errorf("got: %v, expected: %v", cfg.Key1, "key1_flag") + } + + // env has higher priority than toml + if cfg.Key2 != "key2_env" { + t.Errorf("got: %v, expected: %v", cfg.Key2, "key2_env") + } + + // flag has higher priority than env + if cfg.Key3 != "key3_flag" { + t.Errorf("got: %v, expected: %v", cfg.Key3, "key3_flag") + } + + // flag has higher priority than both env and toml + if cfg.Key4 != "key4_flag" { + t.Errorf("got: %v, expected: %v", cfg.Key4, "key4_flag") + } + + // toml has lowest priority + if cfg.Key5 != "key5_toml" { + t.Errorf("got: %v, expected: %v", cfg.Key5, "key5_toml") + } +} + +func TestLoad_ErrorIfFlagTypeMismatch(t *testing.T) { + var cfg struct { + Key int `flag:"key1"` + } + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + // flag + fs := flag.NewFlagSet("tmp", flag.ExitOnError) + _ = fs.String("key1", "", "") + + flag.CommandLine = fs + flag.CommandLine.Parse([]string{"-key1", "key1_flag"}) // flag given + + if err := Load(tmp.Name(), &cfg); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestLoad_ErrorIfEnvTypeMismatch(t *testing.T) { + var cfg struct { + KeyFloat float64 `env:"key_float"` + } + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + // env + os.Setenv("key_float", "key_float_env") + + if err := Load(tmp.Name(), &cfg); err == nil { + t.Fatalf("expected error, got nil") + } +} + +func TestLoad_CheckNumericTypes(t *testing.T) { + var cfg struct { + Float32 float32 `flag:"float32"` + Int8 int8 `toml:"int8"` + Int16 int16 `env:"int16"` + Uint32 uint32 `toml:"uint32"` + Uint64 uint64 `env:"uint64"` + UintPtr uintptr `env:"uintptr"` + Bool bool `flag:"bool"` + } + + tmp, _ := ioutil.TempFile("", "") + defer os.Remove(tmp.Name()) + + _, err := tmp.WriteString(` +int8 = -2 +uint32 = 1 +`) + if err != nil { + t.Fatalf("write config file failed: %v", err) + } + + // flag + fs := flag.NewFlagSet("tmp", flag.ExitOnError) + _ = fs.Bool("bool", false, "") + _ = fs.Float64("float32", 0.0, "") + + flag.CommandLine = fs + flag.CommandLine.Parse([]string{"-bool", "true"}) // flag given + flag.CommandLine.Parse([]string{"-float32", "1.3"}) // flag given + + // env + os.Setenv("uint64", "100000000000") + os.Setenv("uintptr", "20") + os.Setenv("int16", "3") + + if err := Load(tmp.Name(), &cfg); err != nil { + t.Fatalf("unexpected error %v", err) + } + + if cfg.Float32 != 1.3 { + t.Errorf("got: %v, expected: %v", cfg.Float32, 1.3) + } + + if cfg.Int8 != -2 { + t.Errorf("got: %v, expected: %v", cfg.Int8, -2) + } + + if cfg.Int16 != 3 { + t.Errorf("got: %v, expected: %v", cfg.Int16, 3) + } + + if cfg.Uint32 != 1 { + t.Errorf("got: %v, expected: %v", cfg.Uint32, 1) + } + + if cfg.Uint64 != 100000000000 { + t.Errorf("got: %v, expected: %v", cfg.Uint64, 100000000000) + } + + if cfg.UintPtr != 20 { + t.Errorf("got: %v, expected: %v", cfg.UintPtr, 20) + } + + if cfg.Bool != true { + t.Errorf("got: %v, expected: %v", cfg.Bool, true) + } +} diff --git a/go.mod b/go.mod index da55d19..822cb39 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,7 @@ require ( github.com/BurntSushi/toml v0.3.0 github.com/fatih/structs v1.0.0 github.com/fsnotify/fsnotify v1.4.7 - golang.org/x/sys v0.0.0-20180313075820-8c0ece68c283 + golang.org/x/sys v0.0.0-20180313075820-8c0ece68c283 // indirect ) + +go 1.11 diff --git a/notify_test.go b/notify_test.go index 8ec817d..6a9795d 100644 --- a/notify_test.go +++ b/notify_test.go @@ -13,8 +13,9 @@ func TestNotify(t *testing.T) { Key string `toml:"key"` } - tmp, _ := ioutil.TempFile("", "") - defer os.Remove(tmp.Name()) + dir, _ := ioutil.TempDir("", "") + tmp, _ := ioutil.TempFile(dir, "") + defer os.RemoveAll(dir) if _, err := tmp.WriteString(`key = "hey"`); err != nil { t.Fatalf("unexpected error: %v", err)