Skip to content

Commit

Permalink
Add TLS arg parsing to ParseDSN().
Browse files Browse the repository at this point in the history
Factor out the TLS cert handling and add it to `configTLS()` via a
`struct` argument.
  • Loading branch information
sean- committed Feb 2, 2018
1 parent d7f24b9 commit 8078930
Showing 1 changed file with 65 additions and 43 deletions.
108 changes: 65 additions & 43 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -703,45 +703,17 @@ func ParseURI(uri string) (ConnConfig, error) {
cp.Dial = d.Dial
}

err = configTLS(url.Query().Get("sslmode"), &cp)
tlsArgs := configTLSArgs{
sslCert: url.Query().Get("sslcert"),
sslKey: url.Query().Get("sslkey"),
sslMode: url.Query().Get("sslmode"),
sslRootCert: url.Query().Get("sslrootcert"),
}
err = configTLS(tlsArgs, &cp)
if err != nil {
return cp, err
}

// Extract optional TLS parameters and reconstruct a coherent tls.Config based
// on the DSN input. Reuse the same keywords found in github.com/lib/pq.
if cp.TLSConfig != nil {
{
caCertPool := x509.NewCertPool()

caPath := url.Query().Get("sslrootcert")
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return cp, errors.Wrapf(err, "unable to read CA file %q", caPath)
}

if !caCertPool.AppendCertsFromPEM(caCert) {
return cp, errors.Wrap(err, "unable to add CA to cert pool")
}

cp.TLSConfig.RootCAs = caCertPool
cp.TLSConfig.ClientCAs = caCertPool
}

sslcert := url.Query().Get("sslcert")
sslkey := url.Query().Get("sslkey")
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return cp, fmt.Errorf(`both "sslcert" and "sslkey" are required`)
}

cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return cp, errors.Wrap(err, "unable to read cert")
}

cp.TLSConfig.Certificates = []tls.Certificate{cert}
}

ignoreKeys := map[string]struct{}{
"connect_timeout": {},
"sslcert": {},
Expand Down Expand Up @@ -783,7 +755,7 @@ func ParseDSN(s string) (ConnConfig, error) {

m := dsnRegexp.FindAllStringSubmatch(s, -1)

var sslmode string
tlsArgs := configTLSArgs{}

cp.RuntimeParams = make(map[string]string)

Expand All @@ -804,7 +776,13 @@ func ParseDSN(s string) (ConnConfig, error) {
case "dbname":
cp.Database = b[2]
case "sslmode":
sslmode = b[2]
tlsArgs.sslMode = b[2]
case "sslrootcert":
tlsArgs.sslRootCert = b[2]
case "sslcert":
tlsArgs.sslCert = b[2]
case "sslkey":
tlsArgs.sslKey = b[2]
case "connect_timeout":
timeout, err := strconv.ParseInt(b[2], 10, 64)
if err != nil {
Expand All @@ -818,7 +796,7 @@ func ParseDSN(s string) (ConnConfig, error) {
}
}

err := configTLS(sslmode, &cp)
err := configTLS(tlsArgs, &cp)
if err != nil {
return cp, err
}
Expand Down Expand Up @@ -898,7 +876,7 @@ func ParseEnvLibpq() (ConnConfig, error) {

sslmode := os.Getenv("PGSSLMODE")

err := configTLS(sslmode, &cc)
err := configTLS(configTLSArgs{sslMode: sslmode}, &cc)
if err != nil {
return cc, err
}
Expand All @@ -913,14 +891,27 @@ func ParseEnvLibpq() (ConnConfig, error) {
return cc, nil
}

func configTLS(sslmode string, cc *ConnConfig) error {
type configTLSArgs struct {
sslMode string
sslRootCert string
sslCert string
sslKey string
}

// configTLS uses lib/pq's TLS parameters to reconstruct a coherent tls.Config.
// Inputs are parsed out and provided by ParseDSN() or ParseURI().
func configTLS(args configTLSArgs, cc *ConnConfig) error {
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
if args.sslMode == "" {
args.sslMode = "prefer"
}

switch sslmode {
switch args.sslMode {
case "disable":
cc.UseFallbackTLS = false
cc.TLSConfig = nil
cc.FallbackTLSConfig = nil
return nil
case "allow":
cc.UseFallbackTLS = true
cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true}
Expand All @@ -938,6 +929,37 @@ func configTLS(sslmode string, cc *ConnConfig) error {
return errors.New("sslmode is invalid")
}

{
caCertPool := x509.NewCertPool()

caPath := args.sslRootCert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return errors.Wrapf(err, "unable to read CA file %q", caPath)
}

if !caCertPool.AppendCertsFromPEM(caCert) {
return errors.Wrap(err, "unable to add CA to cert pool")
}

cc.TLSConfig.RootCAs = caCertPool
cc.TLSConfig.ClientCAs = caCertPool
}

sslcert := args.sslCert
sslkey := args.sslKey

if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return fmt.Errorf(`both "sslcert" and "sslkey" are required`)
}

cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return errors.Wrap(err, "unable to read cert")
}

cc.TLSConfig.Certificates = []tls.Certificate{cert}

return nil
}

Expand Down

0 comments on commit 8078930

Please sign in to comment.