Skip to content
This repository was archived by the owner on Jul 12, 2025. It is now read-only.

Commit cdd2cc4

Browse files
yunxu111jackc
authored andcommitted
EC-2198 change for sslpassword
1 parent 7402796 commit cdd2cc4

File tree

3 files changed

+66
-56
lines changed

3 files changed

+66
-56
lines changed

config.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -709,22 +709,31 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
709709
}
710710
block, _ := pem.Decode(buf)
711711
var pemKey []byte
712+
var decryptedKey []byte
713+
var decryptedError error
712714
// If PEM is encrypted, attempt to decrypt using pass phrase
713715
if x509.IsEncryptedPEMBlock(block) {
714-
if sslpassword == "" {
716+
// Attempt decryption with pass phrase
717+
// NOTE: only supports RSA (PKCS#1)
718+
if(sslpassword != ""){
719+
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
720+
}
721+
//if sslpassword not provided or has decryption error when use it
722+
//try to find sslpassword with callback function
723+
if (sslpassword == "" || decryptedError!= nil) {
715724
if(parseConfigOptions.GetSSLPassword != nil){
716725
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
717-
}else{
726+
}
727+
if(sslpassword == ""){
718728
return nil, fmt.Errorf("unable to find sslpassword")
719729
}
720730
}
721-
// Attempt decryption with pass phrase
722-
// NOTE: only supports RSA (PKCS#1)
723-
decryptedKey, err := x509.DecryptPEMBlock(block, []byte(sslpassword))
731+
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
724732
// Should we also provide warning for PKCS#1 needed?
725-
if err != nil {
733+
if decryptedError != nil {
726734
return nil, fmt.Errorf("unable to decrypt key: %w", err)
727735
}
736+
728737
pemBytes := pem.Block{
729738
Type: "RSA PRIVATE KEY",
730739
Bytes: decryptedKey,

pgconn.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,18 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) {
109109
return ConnectConfig(ctx, config)
110110
}
111111

112+
// Connect establishes a connection to a PostgreSQL server using the environment
113+
// and connString (in URL or DSN format) and ParseConfigOptions
114+
// to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt.
115+
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) {
116+
config, err := ParseConfigWithOptions(connString, parseConfigOptions)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
return ConnectConfig(ctx, config)
122+
}
123+
112124
// Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with
113125
// ParseConfig. ctx can be used to cancel a connect attempt.
114126
//

pgconn_test.go

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package pgconn_test
22

33
import (
4-
"bufio"
54
"bytes"
65
"compress/gzip"
76
"context"
@@ -54,71 +53,56 @@ func TestConnect(t *testing.T) {
5453
}
5554
}
5655

57-
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
58-
// connection.
59-
func TestConnectTLS(t *testing.T) {
60-
t.Parallel()
61-
62-
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
63-
if connString == "" {
64-
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
65-
}
66-
67-
var conn *pgconn.PgConn
68-
var err error
69-
70-
isSslPasswrodEmpty := strings.HasSuffix(connString, "sslpassword=")
71-
72-
if isSslPasswrodEmpty {
73-
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
74-
require.Nil(t, err)
75-
76-
conn, err = pgconn.ConnectConfig(context.Background(), config)
77-
require.NoError(t, err)
78-
} else {
79-
conn, err = pgconn.Connect(context.Background(), connString)
80-
require.NoError(t, err)
81-
}
82-
83-
if _, ok := conn.Conn().(*tls.Conn); !ok {
84-
t.Error("not a TLS connection")
56+
func TestConnectWithOption(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
env string
60+
}{
61+
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
62+
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
63+
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
64+
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
65+
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
8566
}
8667

87-
closeConn(t, conn)
88-
}
68+
for _, tt := range tests {
69+
tt := tt
70+
t.Run(tt.name, func(t *testing.T) {
71+
connString := os.Getenv(tt.env)
72+
if connString == "" {
73+
t.Skipf("Skipping due to missing environment variable %v", tt.env)
74+
}
75+
var sslOptions pgconn.ParseConfigOptions
76+
sslOptions.GetSSLPassword = GetSSLPassword
77+
conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
78+
require.NoError(t, err)
8979

90-
func GetSslPassword() string {
91-
readFile, err := os.Open("data.txt")
92-
if err != nil {
93-
fmt.Println(err)
94-
}
95-
fileScanner := bufio.NewScanner(readFile)
96-
fileScanner.Split(bufio.ScanLines)
97-
for fileScanner.Scan() {
98-
line := fileScanner.Text()
99-
if strings.HasPrefix(line, "sslpassword=") {
100-
index := len("sslpassword=")
101-
line := line[index:]
102-
return line
103-
}
80+
closeConn(t, conn)
81+
})
10482
}
105-
return ""
10683
}
10784

108-
func TestConnectTLSCallback(t *testing.T) {
85+
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
86+
// connection.
87+
func TestConnectTLS(t *testing.T) {
10988
t.Parallel()
11089

11190
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
11291
if connString == "" {
11392
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
11493
}
11594

116-
config, err := pgconn.ParseConfigWithSslPasswordCallback(connString, GetSslPassword)
95+
var conn *pgconn.PgConn
96+
var err error
97+
98+
var sslOptions pgconn.ParseConfigOptions
99+
sslOptions.GetSSLPassword = GetSSLPassword
100+
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
117101
require.Nil(t, err)
118102

119-
conn, err := pgconn.ConnectConfig(context.Background(), config)
103+
conn, err = pgconn.ConnectConfig(context.Background(), config)
120104
require.NoError(t, err)
121-
105+
122106
if _, ok := conn.Conn().(*tls.Conn); !ok {
123107
t.Error("not a TLS connection")
124108
}
@@ -2180,3 +2164,8 @@ func Example() {
21802164
// 3
21812165
// SELECT 3
21822166
}
2167+
2168+
func GetSSLPassword(ctx context.Context) string {
2169+
connString := os.Getenv("PGX_SSL_PASSWORD")
2170+
return connString
2171+
}

0 commit comments

Comments
 (0)