Skip to content

Commit

Permalink
fixed concurrent map write issue on ConnectionDetails.Options (#577)
Browse files Browse the repository at this point in the history
  • Loading branch information
sio4 committed Sep 24, 2022
1 parent 902a7d1 commit d71c504
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 37 deletions.
38 changes: 36 additions & 2 deletions connection_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/gobuffalo/pop/v6/internal/defaults"
Expand Down Expand Up @@ -46,8 +47,10 @@ type ConnectionDetails struct {
// Defaults to 0 "unlimited". See https://golang.org/pkg/database/sql/#DB.SetConnMaxIdleTime
ConnMaxIdleTime time.Duration
// Defaults to `false`. See https://godoc.org/github.com/jmoiron/sqlx#DB.Unsafe
Unsafe bool
Options map[string]string
Unsafe bool
// Options stores Connection Details options
Options map[string]string
optionsLock *sync.Mutex
// Query string encoded options from URL. Example: "sslmode=disable"
RawOptions string
// UseInstrumentedDriver if set to true uses a wrapper for the underlying driver which exposes tracing
Expand Down Expand Up @@ -190,3 +193,34 @@ func (cd *ConnectionDetails) OptionsString(s string) string {
}
return strings.TrimLeft(s, "&")
}

// option returns the value stored in ConnecitonDetails.Options with key k.
func (cd *ConnectionDetails) option(k string) string {
if cd.Options == nil {
return ""
}
return defaults.String(cd.Options[k], "")
}

// setOptionWithDefault stores given value v in ConnectionDetails.Options
// with key k. If v is empty string, it stores def instead.
// It uses locking mechanism to make the operation safe.
func (cd *ConnectionDetails) setOptionWithDefault(k, v, def string) {
cd.setOption(k, defaults.String(v, def))
}

// setOption stores given value v in ConnectionDetails.Options with key k.
// It uses locking mechanism to make the operation safe.
func (cd *ConnectionDetails) setOption(k, v string) {
if cd.optionsLock == nil {
cd.optionsLock = &sync.Mutex{}
}

cd.optionsLock.Lock()
if cd.Options == nil { // prevent panic
cd.Options = make(map[string]string)
}

cd.Options[k] = v
cd.optionsLock.Unlock()
}
6 changes: 3 additions & 3 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (p *cockroach) DumpSchema(w io.Writer) error {
cmd := exec.Command("cockroach", "sql", "-e", "SHOW CREATE ALL TABLES", "-d", p.Details().Database, "--format", "raw")

c := p.ConnectionDetails
if defaults.String(c.Options["sslmode"], "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") {
if defaults.String(c.option("sslmode"), "disable") == "disable" || strings.Contains(c.RawOptions, "sslmode=disable") {
cmd.Args = append(cmd.Args, "--insecure")
}
return cockroachDumpSchema(p.Details(), cmd, w)
Expand Down Expand Up @@ -302,13 +302,13 @@ func newCockroach(deets *ConnectionDetails) (dialect, error) {
translateCache: map[string]string{},
mu: sync.Mutex{},
}
d.info.client = deets.Options["application_name"]
d.info.client = deets.option("application_name")
return d, nil
}

func finalizerCockroach(cd *ConnectionDetails) {
appName := filepath.Base(os.Args[0])
cd.Options["application_name"] = defaults.String(cd.Options["application_name"], appName)
cd.setOptionWithDefault("application_name", cd.option("application_name"), appName)
cd.Port = defaults.String(cd.Port, portCockroach)
if cd.URL != "" {
cd.URL = "postgres://" + trimCockroachPrefix(cd.URL)
Expand Down
27 changes: 11 additions & 16 deletions dialect_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ func (m *mysql) CreateDB() error {
return fmt.Errorf("error creating MySQL database %s: %w", deets.Database, err)
}
defer db.Close()
charset := defaults.String(deets.Options["charset"], "utf8mb4")
encoding := defaults.String(deets.Options["collation"], "utf8mb4_general_ci")
charset := defaults.String(deets.option("charset"), "utf8mb4")
encoding := defaults.String(deets.option("collation"), "utf8mb4_general_ci")
query := fmt.Sprintf("CREATE DATABASE `%s` DEFAULT CHARSET `%s` DEFAULT COLLATE `%s`", deets.Database, charset, encoding)
log(logging.SQL, query)

Expand Down Expand Up @@ -242,11 +242,10 @@ func urlParserMySQL(cd *ConnectionDetails) error {
cd.User = cfg.User
cd.Password = cfg.Passwd
cd.Database = cfg.DBName
if cd.Options == nil { // prevent panic
cd.Options = make(map[string]string)
}

// NOTE: use cfg.Params if want to fill options with full parameters
cd.Options["collation"] = cfg.Collation
cd.setOption("collation", cfg.Collation)

if cfg.Net == "unix" {
cd.Port = "socket" // trick. see: `URL()`
cd.Host = cfg.Addr
Expand Down Expand Up @@ -274,20 +273,16 @@ func finalizerMySQL(cd *ConnectionDetails) {
"multiStatements": "true",
}

if cd.Options == nil { // prevent panic
cd.Options = make(map[string]string)
}

for k, v := range defs {
cd.Options[k] = defaults.String(cd.Options[k], v)
for k, def := range defs {
cd.setOptionWithDefault(k, cd.option(k), def)
}

for k, v := range forced {
// respect user specified options but print warning!
cd.Options[k] = defaults.String(cd.Options[k], v)
if cd.Options[k] != v { // when user-defined option exists
log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
cd.setOptionWithDefault(k, cd.option(k), v)
if cd.option(k) != v { // when user-defined option exists
log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k))
log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k))
} // or override with `cd.Options[k] = v`?
if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
Expand Down
4 changes: 2 additions & 2 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,12 @@ func urlParserPostgreSQL(cd *ConnectionDetails) error {
options := []string{"fallback_application_name"}
for i := range options {
if opt, ok := conf.RuntimeParams[options[i]]; ok {
cd.Options[options[i]] = opt
cd.setOption(options[i], opt)
}
}

if conf.TLSConfig == nil {
cd.Options["sslmode"] = "disable"
cd.setOption("sslmode", "disable")
}

return nil
Expand Down
22 changes: 8 additions & 14 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (m *sqlite) Lock(fn func() error) error {
}

func (m *sqlite) locker(l *sync.Mutex, fn func() error) error {
if defaults.String(m.Details().Options["lock"], "true") == "true" {
if defaults.String(m.Details().option("lock"), "true") == "true" {
defer l.Unlock()
l.Lock()
}
Expand Down Expand Up @@ -309,11 +309,8 @@ func urlParserSQLite3(cd *ConnectionDetails) error {
return fmt.Errorf("unable to parse sqlite query: %w", err)
}

if cd.Options == nil { // prevent panic
cd.Options = make(map[string]string)
}
for k := range q {
cd.Options[k] = q.Get(k)
cd.setOption(k, q.Get(k))
}

return nil
Expand All @@ -326,20 +323,17 @@ func finalizerSQLite(cd *ConnectionDetails) {
forced := map[string]string{
"_fk": "true",
}
if cd.Options == nil { // prevent panic
cd.Options = make(map[string]string)
}

for k, v := range defs {
cd.Options[k] = defaults.String(cd.Options[k], v)
for k, def := range defs {
cd.setOptionWithDefault(k, cd.option(k), def)
}

for k, v := range forced {
// respect user specified options but print warning!
cd.Options[k] = defaults.String(cd.Options[k], v)
if cd.Options[k] != v { // when user-defined option exists
log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.Options[k])
log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.Options[k])
cd.setOptionWithDefault(k, cd.option(k), v)
if cd.option(k) != v { // when user-defined option exists
log(logging.Warn, "IMPORTANT! '%s: %s' option is required to work properly but your current setting is '%v: %v'.", k, v, k, cd.option(k))
log(logging.Warn, "It is highly recommended to remove '%v: %v' option from your config!", k, cd.option(k))
} // or override with `cd.Options[k] = v`?
if cd.URL != "" && !strings.Contains(cd.URL, k+"="+v) {
log(logging.Warn, "IMPORTANT! '%s=%s' option is required to work properly. Please add it to the database URL in the config!", k, v)
Expand Down

0 comments on commit d71c504

Please sign in to comment.