Skip to content

Commit

Permalink
The global config support set the global config once per web server (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mstmdev authored Mar 29, 2023
1 parent 957a367 commit 0ae616b
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 20 deletions.
3 changes: 2 additions & 1 deletion cmd/gofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func runWithConfig(c conf.Config, result result.Result) {
}()

cp := &c
conf.GlobalConfig = cp

if err = parseConfigFile(cp); err != nil {
result.InitDoneWithError(err)
Expand Down Expand Up @@ -91,6 +90,8 @@ func runWithConfig(c conf.Config, result result.Result) {
return
}

log.ErrorIf(conf.SetGlobalConfig(cp), "set global config error => %s", cp.FileServerAddr)

// kill parent process
if c.KillPPid {
daemon.KillPPid()
Expand Down
50 changes: 48 additions & 2 deletions conf/global.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,52 @@
package conf

import (
"errors"
"sync"
)

var (
globalConfigSet = &configSet{
m: make(map[string]*Config),
}
)

var (
// GlobalConfig the global config of the program, initial by flags or config file
GlobalConfig *Config
errConfigIsNil = errors.New("the config is nil")
errConfigExist = errors.New("the config exists")
)

type configSet struct {
m map[string]*Config
mu sync.RWMutex
}

func (cs *configSet) setGlobalConfig(c *Config) error {
if c == nil {
return errConfigIsNil
}
addr := c.FileServerAddr
cs.mu.Lock()
defer cs.mu.Unlock()
if _, ok := cs.m[addr]; ok {
return errConfigExist
}
cs.m[addr] = c
return nil
}

func (cs *configSet) getGlobalConfig(addr string) *Config {
cs.mu.RLock()
defer cs.mu.RUnlock()
return cs.m[addr]
}

// SetGlobalConfig set the global config once per web server
func SetGlobalConfig(c *Config) error {
return globalConfigSet.setGlobalConfig(c)
}

// GetGlobalConfig get the global config by web server address
func GetGlobalConfig(addr string) *Config {
return globalConfigSet.getGlobalConfig(addr)
}
52 changes: 52 additions & 0 deletions conf/global_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package conf

import (
"testing"
)

func TestSetGlobalConfig(t *testing.T) {
testCases := []struct {
name string
config *Config
err error
}{
{"normal address", &Config{FileServerAddr: ":8080"}, nil},
{"empty address", &Config{FileServerAddr: ""}, nil},
{"nil config", nil, errConfigIsNil},
{"config exists", &Config{FileServerAddr: ":8080"}, errConfigExist},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := SetGlobalConfig(tc.config)
if err != tc.err {
t.Errorf("expect to get error %v, but get %v", tc.err, err)
}
})
}
}

func TestGetGlobalConfig(t *testing.T) {
if err := SetGlobalConfig(&Config{FileServerAddr: ":8088"}); err != nil {
t.Errorf("call SetGlobalConfig error")
return
}
testCases := []struct {
name string
address string
exist bool
}{
{"normal address", ":8088", true},
{"not exist config", ":8000", false},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
config := GetGlobalConfig(tc.address)
exist := config != nil
if exist != tc.exist {
t.Errorf("expect to get config %v, but get %v", tc.exist, exist)
}
})
}
}
41 changes: 25 additions & 16 deletions server/handler/manage_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,44 @@ import (

type manageHandler struct {
logger log.Logger
addr string
}

// NewManageHandlerFunc returns a gin.HandlerFunc that shows the application config
func NewManageHandlerFunc(logger log.Logger) gin.HandlerFunc {
func NewManageHandlerFunc(logger log.Logger, addr string) gin.HandlerFunc {
return (&manageHandler{
logger: logger,
addr: addr,
}).Handle
}

func (h *manageHandler) Handle(c *gin.Context) {
format := strings.ToLower(c.Query(server.ParamFormat))
var result server.ApiResult
// copy the config and mask the user info for security
config := *conf.GlobalConfig
mask := "******"
if len(config.Users) > 0 {
config.Users = mask
}
if len(config.SessionConnection) > 0 {
config.SessionConnection = mask
}
if len(config.EncryptSecret) > 0 {
config.EncryptSecret = mask
}
if len(config.DecryptSecret) > 0 {
config.DecryptSecret = mask
cp := conf.GetGlobalConfig(h.addr)
if cp == nil {
result = server.NewErrorApiResult(contract.NotFound, contract.NotFoundDesc)
} else {
config := *cp
mask := "******"
if len(config.Users) > 0 {
config.Users = mask
}
if len(config.SessionConnection) > 0 {
config.SessionConnection = mask
}
if len(config.EncryptSecret) > 0 {
config.EncryptSecret = mask
}
if len(config.DecryptSecret) > 0 {
config.DecryptSecret = mask
}
result = server.NewApiResult(contract.Success, contract.SuccessDesc, config)
}
if format == conf.YamlFormat.Name() {
c.YAML(http.StatusOK, server.NewApiResult(contract.Success, contract.SuccessDesc, config))
c.YAML(http.StatusOK, result)
} else {
c.PureJSON(http.StatusOK, server.NewApiResult(contract.Success, contract.SuccessDesc, config))
c.PureJSON(http.StatusOK, result)
}
}
2 changes: 1 addition & 1 deletion server/httpfs/file_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func initManageRoute(opt server.Option, logger log.Logger, manageGroup *gin.Rout
manageGroup.Use(middleware.NewPrivateAccessHandlerFunc(logger))
}
pprof.RouteRegister(manageGroup, server.PProfRoutePrefix)
manageGroup.GET(server.ManageConfigRoute, handler.NewManageHandlerFunc(logger))
manageGroup.GET(server.ManageConfigRoute, handler.NewManageHandlerFunc(logger, opt.Addr))
if opt.EnableReport {
manageGroup.GET(server.ManageReportRoute, handler.NewReportHandlerFunc(logger))
report.GlobalReporter.Enable(true)
Expand Down

0 comments on commit 0ae616b

Please sign in to comment.