diff --git a/examples/custom/config.yaml b/examples/custom/config.yaml new file mode 100644 index 0000000..96713bf --- /dev/null +++ b/examples/custom/config.yaml @@ -0,0 +1,7 @@ +app: + environment: dev + +server: + port: 443 + read_timeout: 1m + diff --git a/examples/custom/custom_test.go b/examples/custom/custom_test.go new file mode 100644 index 0000000..6054d02 --- /dev/null +++ b/examples/custom/custom_test.go @@ -0,0 +1,73 @@ +package custom + +import ( + "fmt" + "strings" + + "github.com/kkyr/fig" +) + +type ListenerType uint + +const ( + ListenerUnix ListenerType = iota + ListenerTCP + ListenerTLS +) + +type Config struct { + App struct { + Environment string `fig:"environment" validate:"required"` + } `fig:"app"` + Server struct { + Host string `fig:"host" default:"0.0.0.0"` + Port int `fig:"port" default:"80"` + Listener ListenerType `fig:"listener_type" default:"tcp"` + } `fig:"server"` +} + +func ExampleLoad() { + var cfg Config + err := fig.Load(&cfg) + if err != nil { + panic(err) + } + + fmt.Println(cfg.App.Environment) + fmt.Println(cfg.Server.Host) + fmt.Println(cfg.Server.Port) + fmt.Println(cfg.Server.Listener) + + // Output: + // dev + // 0.0.0.0 + // 443 + // tcp +} + +func (l *ListenerType) UnmarshalString(v string) error { + switch strings.ToLower(v) { + case "unix": + *l = ListenerUnix + case "tcp": + *l = ListenerTCP + case "tls": + *l = ListenerTLS + default: + return fmt.Errorf("unknown listener type: %s", v) + } + return nil +} + +func (l ListenerType) String() string { + switch l { + case ListenerUnix: + return "unix" + case ListenerTCP: + return "tcp" + case ListenerTLS: + return "tls" + default: + return "unknown" + } +} diff --git a/fig.go b/fig.go index 4f6232a..3da9d60 100644 --- a/fig.go +++ b/fig.go @@ -28,6 +28,42 @@ const ( DefaultTimeLayout = time.RFC3339 ) +// StringUnmarshaler is an interface for custom unmarshaling of strings +// +// If a field with a local type asignment satisfies this interface, it allows the user +// to implment their own custom type unmarshaling method. +// +// Example: +// +// type ListenerType uint +// +// const ( +// ListenerUnix ListenerType = iota +// ListenerTCP +// ListenerTLS +// ) +// +// type Config struct { +// Listener ListenerType `fig:"listener_type" default:"unix"` +// } +// +// func (l *ListenerType) UnmarshalType(v string) error { +// switch strings.ToLower(v) { +// case "unix": +// *l = ListenerUnix +// case "tcp": +// *l = ListenerTCP +// case "tls": +// *l = ListenerTLS +// default: +// return fmt.Errorf("unknown listener type: %s", v) +// } +// return nil +// } +type StringUnmarshaler interface { + UnmarshalString(s string) error +} + // Load reads a configuration file and loads it into the given struct. The // parameter `cfg` must be a pointer to a struct. // @@ -158,6 +194,7 @@ func (f *fig) decodeMap(m map[string]interface{}, result interface{}) error { mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeHookFunc(f.timeLayout), stringToRegexpHookFunc(), + stringToStringUnmarshalerHook(), ), }) if err != nil { @@ -183,6 +220,36 @@ func stringToRegexpHookFunc() mapstructure.DecodeHookFunc { } } +// stringToStringUnmarshalerHook returns a DecodeHookFunc that executes a custom method which +// satisfies the StringUnmarshaler interface on custom types. +func stringToStringUnmarshalerHook() mapstructure.DecodeHookFunc { + return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + if f.Kind() != reflect.String { + return data, nil + } + + ds, ok := data.(string) + if !ok { + return data, nil + } + + if reflect.PointerTo(t).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) { + val := reflect.New(t).Interface() + + if unmarshaler, ok := val.(StringUnmarshaler); ok { + err := unmarshaler.UnmarshalString(ds) + if err != nil { + return nil, err + } + + return reflect.ValueOf(val).Elem().Interface(), nil + } + } + + return data, nil + } +} + // processCfg processes a cfg struct after it has been loaded from // the config file, by validating required fields and setting defaults // where applicable. @@ -257,9 +324,24 @@ func (f *fig) setDefaultValue(fv reflect.Value, val string) error { // setValue sets fv to val. it attempts to convert val to the correct // type based on the field's kind. if conversion fails an error is -// returned. +// returned. If fv satisfies the StringUnmarshaler interface it will +// execute the corresponding StringUnmarshaler.UnmarshalString method +// on the value. // fv must be settable else this panics. func (f *fig) setValue(fv reflect.Value, val string) error { + if fv.IsValid() && reflect.PointerTo(fv.Type()).Implements(reflect.TypeOf((*StringUnmarshaler)(nil)).Elem()) { + vi := reflect.New(fv.Type()).Interface() + if unmarshaler, ok := vi.(StringUnmarshaler); ok { + err := unmarshaler.UnmarshalString(val) + if err != nil { + return fmt.Errorf("could not unmarshal string %q: %w", val, err) + } + fv.Set(reflect.ValueOf(vi).Elem()) + return nil + } + return fmt.Errorf("unexpected error while trying to unmarshal string") + } + switch fv.Kind() { case reflect.Ptr: if fv.IsNil() { diff --git a/fig_test.go b/fig_test.go index 737ab64..36e0573 100644 --- a/fig_test.go +++ b/fig_test.go @@ -84,6 +84,28 @@ type Item struct { Path string `fig:"path" validate:"required"` } +type ListenerType uint + +const ( + ListenerUnix ListenerType = iota + ListenerTCP + ListenerTLS +) + +func (l *ListenerType) UnmarshalString(v string) error { + switch strings.ToLower(v) { + case "unix": + *l = ListenerUnix + case "tcp": + *l = ListenerTCP + case "tls": + *l = ListenerTLS + default: + return fmt.Errorf("unknown listener type: %s", v) + } + return nil +} + func validPodConfig() Pod { var pod Pod @@ -249,6 +271,7 @@ func Test_fig_Load_Defaults(t *testing.T) { Application struct { BuildDate time.Time `fig:"build_date" default:"2020-01-01T12:00:00Z"` } + Listener ListenerType `fig:"listener_type" default:"unix"` } var want Server @@ -259,6 +282,7 @@ func Test_fig_Load_Defaults(t *testing.T) { want.Logger.Production = false want.Logger.Metadata.Keys = []string{"ts"} want.Application.BuildDate = time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC) + want.Listener = ListenerUnix var cfg Server err := Load(&cfg, File(f), Dirs(filepath.Join("testdata", "valid"))) @@ -590,8 +614,9 @@ func Test_fig_decodeMap(t *testing.T) { "log_level": "debug", "severity": "5", "server": map[string]interface{}{ - "ports": []int{443, 80}, - "secure": 1, + "ports": []int{443, 80}, + "secure": 1, + "listener_type": "tls", }, } @@ -599,8 +624,9 @@ func Test_fig_decodeMap(t *testing.T) { Level string `fig:"log_level"` Severity int `fig:"severity" validate:"required"` Server struct { - Ports []string `fig:"ports" default:"[443]"` - Secure bool + Ports []string `fig:"ports" default:"[443]"` + Secure bool + Listener ListenerType `fig:"listener_type" default:"unix"` } `fig:"server"` } @@ -623,6 +649,10 @@ func Test_fig_decodeMap(t *testing.T) { if cfg.Server.Secure == false { t.Error("cfg.Server.Secure == false") } + + if cfg.Server.Listener != ListenerTLS { + t.Errorf("cfg.Server.Listener: want: %d, got: %d", ListenerTLS, cfg.Server.Listener) + } } func Test_fig_processCfg(t *testing.T) { diff --git a/testdata/valid/server.toml b/testdata/valid/server.toml index a5ba440..ad60113 100644 --- a/testdata/valid/server.toml +++ b/testdata/valid/server.toml @@ -1,4 +1,4 @@ host = "0.0.0.0" [logger] -log_level = "debug" \ No newline at end of file +log_level = "debug"