Skip to content

Commit

Permalink
Fix tls=true didn't work with host without port.
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Nov 30, 2017
1 parent cd4cb90 commit c5207bc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
20 changes: 9 additions & 11 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ func (cfg *Config) normalize() error {
cfg.Addr = ensureHavePort(cfg.Addr)
}

if cfg.tls != nil {
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
}
}

return nil
}

Expand Down Expand Up @@ -521,10 +530,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
} else {
cfg.TLSConfig = "false"
}
Expand All @@ -538,13 +543,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
}

if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
tlsConfig.ServerName = host
}
}

cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
Expand Down
28 changes: 28 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) {
DeregisterTLSConfig("utils_test")
}

func TestDSNTLSConfig(t *testing.T) {
expectedServerName := "example.com"
dsn := "tcp(example.com:1234)/?tls=true"

cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
}

dsn = "tcp(example.com)/?tls=true"
cfg, err = ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
}
}

func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
const configKey = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
Expand Down

0 comments on commit c5207bc

Please sign in to comment.