Skip to content

Commit

Permalink
Merge pull request #1088 from kelvich/sni_support
Browse files Browse the repository at this point in the history
  • Loading branch information
rafiss authored Sep 6, 2022
2 parents d65e6ae + 957fc0b commit d5affd5
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
4 changes: 3 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ func isDriverSetting(key string) bool {
return true
case "password":
return true
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline":
case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni":
return true
case "fallback_application_name":
return true
Expand Down Expand Up @@ -2020,6 +2020,8 @@ func parseEnviron(env []string) (out map[string]string) {
accrue("sslkey")
case "PGSSLROOTCERT":
accrue("sslrootcert")
case "PGSSLSNI":
accrue("sslsni")
case "PGREQUIRESSL", "PGSSLCRL":
unsupported()
case "PGREQUIREPEER":
Expand Down
11 changes: 11 additions & 0 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"os/user"
"path/filepath"
"strings"
)

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
Expand Down Expand Up @@ -50,6 +51,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
}

// Set Server Name Indication (SNI), if enabled by connection parameters.
// By default SNI is on, any value which is not starting with "1" disables
// SNI -- that is the same check vanilla libpq uses.
if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") {
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4
// or IPv6). This check is coded already crypto.tls.hostnameInSNI, so
// just always set ServerName here and let crypto/tls do the filtering.
tlsConf.ServerName = o["host"]
}

err := sslClientCertificates(&tlsConf, o)
if err != nil {
return nil, err
Expand Down
139 changes: 139 additions & 0 deletions ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@ package pq
// This file contains SSL tests

import (
"bytes"
_ "crypto/sha256"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
)

func maybeSkipSSLTests(t *testing.T) {
Expand Down Expand Up @@ -280,3 +287,135 @@ func TestSSLClientCertificates(t *testing.T) {
}
}
}

// Check that clint sends SNI data when `sslsni` is not disabled
func TestSNISupport(t *testing.T) {
t.Parallel()
tests := []struct {
name string
conn_param string
hostname string
expected_sni string
}{
{
name: "SNI is set by default",
conn_param: "",
hostname: "localhost",
expected_sni: "localhost",
},
{
name: "SNI is passed when asked for",
conn_param: "sslsni=1",
hostname: "localhost",
expected_sni: "localhost",
},
{
name: "SNI is not passed when disabled",
conn_param: "sslsni=0",
hostname: "localhost",
expected_sni: "",
},
{
name: "SNI is not set for IPv4",
conn_param: "",
hostname: "127.0.0.1",
expected_sni: "",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Start mock postgres server on OS-provided port
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatal(err)
}
serverErrChan := make(chan error, 1)
serverSNINameChan := make(chan string, 1)
go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)

defer listener.Close()
defer close(serverErrChan)
defer close(serverSNINameChan)

// Try to establish a connection with the mock server. Connection will error out after TLS
// clientHello, but it is enough to catch SNI data on the server side
port := strings.Split(listener.Addr().String(), ":")[1]
connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)

// We are okay to skip this error as we are polling serverErrChan and we'll get an error
// or timeout from the server side in case of problems here.
db, _ := sql.Open("postgres", connStr)
_, _ = db.Exec("SELECT 1")

// Check SNI data
select {
case sniHost := <-serverSNINameChan:
if sniHost != tt.expected_sni {
t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
}
case err = <-serverErrChan:
t.Fatalf("mock server failed with error: %+v", err)
case <-time.After(time.Second):
t.Fatal("exceeded connection timeout without erroring out")
}
})
}
}

// Make a postgres mock server to test TLS SNI
//
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
// While reading clientHello catch passed SNI data and report it to nameChan.
func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
var sniHost string

conn, err := listener.Accept()
if err != nil {
errChan <- err
return
}
defer conn.Close()

err = conn.SetDeadline(time.Now().Add(time.Second))
if err != nil {
errChan <- err
return
}

// Receive StartupMessage with SSL Request
startupMessage := make([]byte, 8)
if _, err := io.ReadFull(conn, startupMessage); err != nil {
errChan <- err
return
}
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
return
}

// Respond with SSLOk
_, err = conn.Write([]byte("S"))
if err != nil {
errChan <- err
return
}

// Set up TLS context to catch clientHello. It will always error out during handshake
// as no certificate is set.
srv := tls.Server(conn, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
sniHost = argHello.ServerName
return nil, nil
},
})
defer srv.Close()

// Do the TLS handshake ignoring errors
_ = srv.Handshake()

nameChan <- sniHost
}

0 comments on commit d5affd5

Please sign in to comment.