Skip to content

Commit

Permalink
Modify testcases and update README
Browse files Browse the repository at this point in the history
  • Loading branch information
apoorvdeshmukh committed Aug 31, 2023
1 parent 63ba4b4 commit 18ea4dc
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

* Always Encrypted encryption and decryption with 2 hour key cache (#116)
* 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers
* TDS8 can now be used for connections by setting encrypt="strict"

## 1.5.0

Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ Other supported formats are listed below.
* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts.
* `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout.
* `encrypt`
* `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16).
* `disable` - Data send between client and server is not encrypted.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux))

Expand Down Expand Up @@ -56,7 +57,7 @@ Other supported formats are listed below.
* `TrustServerCertificate`
* false - Server certificate is checked. Default is false if encrypt is specified.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
Expand Down Expand Up @@ -468,6 +469,7 @@ Constrain the provider to an allowed list of key vaults by appending vault host
* Always Encrypted
- `MSSQL_CERTIFICATE_STORE` provider on Windows
- `pfx` provider on Linux and Windows

## Tests

`go test` is used for testing. A running instance of MSSQL server is required.
Expand Down
35 changes: 32 additions & 3 deletions azuread/azuread_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package azuread
import (
"bufio"
"database/sql"
"encoding/hex"
"io"
"os"
"testing"
Expand All @@ -14,7 +15,7 @@ import (
)

func TestAzureSqlAuth(t *testing.T) {
mssqlConfig := testConnParams(t)
mssqlConfig := testConnParams(t, "")

conn, err := newConnectorConfig(mssqlConfig)
if err != nil {
Expand All @@ -35,9 +36,31 @@ func TestAzureSqlAuth(t *testing.T) {

}

func TestTDS8ConnWithAzureSqlAuth(t *testing.T) {
mssqlConfig := testConnParams(t, ";encrypt=strict;TrustServerCertificate=false;tlsmin=1.2")
conn, err := newConnectorConfig(mssqlConfig)
if err != nil {
t.Fatalf("Unable to get a connector: %v", err)
}
db := sql.OpenDB(conn)
row := db.QueryRow("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID")
if err != nil {
t.Fatal("Prepare failed:", err.Error())
}
var protocolName string
var tdsver []byte
var clientAddress string
err = row.Scan(&protocolName, &tdsver, &clientAddress)
if err != nil {
t.Fatal("Scan failed:", err.Error())
}
assertEqual(t, "TSQL", protocolName)
assertEqual(t, "0x08000000", hex.EncodeToString(tdsver))
}

// returns parsed connection parameters derived from
// environment variables
func testConnParams(t testing.TB) *azureFedAuthConfig {
func testConnParams(t testing.TB, dsnParams string) *azureFedAuthConfig {
dsn := os.Getenv("AZURESERVER_DSN")
const logFlags = 127
if dsn == "" {
Expand All @@ -54,7 +77,7 @@ func testConnParams(t testing.TB) *azureFedAuthConfig {
if dsn == "" {
t.Skip("no azure database connection string. set AZURESERVER_DSN environment variable or create .azureconnstr file")
}
config, err := parse(dsn)
config, err := parse(dsn + dsnParams)
if err != nil {
t.Skip("error parsing connection string ")
}
Expand All @@ -64,3 +87,9 @@ func testConnParams(t testing.TB) *azureFedAuthConfig {
config.mssqlConfig.LogFlags = logFlags
return config
}

func assertEqual(t *testing.T, expected interface{}, actual interface{}) {
if expected != actual {
t.Fatalf("Expected %v, got %v", expected, actual)
}
}
50 changes: 41 additions & 9 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ type (
BrowserMsg byte
)

const (
DsnTypeUrl = 1
DsnTypeOdbc = 2
DsnTypeAdo = 3
)

const (
EncryptionOff = 0
EncryptionRequired = 1
Expand Down Expand Up @@ -192,6 +198,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
certificate := params[Certificate]
if encryption != EncryptionDisabled {
tlsMin := params[TLSMin]
if encrypt == "strict" {
trustServerCert = false
}
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
Expand All @@ -203,28 +212,51 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e

var skipSetup = errors.New("skip setting up TLS")

func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
Protocols: []string{},
func getDsnType(dsn string) int {
if strings.HasPrefix(dsn, "sqlserver://") {
return DsnTypeUrl
}
if strings.HasPrefix(dsn, "odbc:") {
return DsnTypeOdbc
}
return DsnTypeAdo
}

func getDsnParams(dsn string) (map[string]string, error) {

var params map[string]string
var err error
if strings.HasPrefix(dsn, "odbc:") {

switch getDsnType(dsn) {
case DsnTypeOdbc:
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
return params, err
}
} else if strings.HasPrefix(dsn, "sqlserver://") {
case DsnTypeUrl:
params, err = splitConnectionStringURL(dsn)
if err != nil {
return p, err
return params, err
}
} else {
default:
params = splitConnectionString(dsn)
}
return params, nil
}

func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
Protocols: []string{},
}

var params map[string]string
var err error

params, err = getDsnParams(dsn)
if err != nil {
return p, err
}
p.Parameters = params

strlog, ok := params[LogParam]
Expand Down
16 changes: 9 additions & 7 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return l, nil
}

func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error) {
func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
Expand All @@ -1145,17 +1145,17 @@ func getTLSConn(conn *timeoutConn, p msdsn.Config) (tlsConn *tls.Conn, err error
}
}
//Set ALPN Sequence
config.NextProtos = []string{"tds/8.0"}
config.NextProtos = []string{alpnSeq}
tlsConn = tls.Client(conn.c, config)
err = tlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
return nil, fmt.Errorf("TLS Handshake failed: %w", err)
}
return tlsConn, nil
}

func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {

isTransportEncrypted := false
// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
// both instance name and port specified
Expand Down Expand Up @@ -1198,10 +1198,11 @@ initiate_connection:
outbuf := newTdsBuffer(packetSize, toconn)

if p.Encryption == msdsn.EncryptionStrict {
outbuf.transport, err = getTLSConn(toconn, p)
outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0")
if err != nil {
return nil, err
}
isTransportEncrypted = true
}
sess := tdsSession{
buf: outbuf,
Expand Down Expand Up @@ -1238,7 +1239,8 @@ initiate_connection:
return nil, err
}

if p.Encryption != msdsn.EncryptionStrict {
//We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict)
if isTransportEncrypted {
if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
Expand Down Expand Up @@ -1278,7 +1280,7 @@ initiate_connection:
}
}

} //p.Encryption != msdsn.EncryptionStrict
}

auth, err := integratedauth.GetIntegratedAuthenticator(p)
if err != nil {
Expand Down
40 changes: 9 additions & 31 deletions tds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,47 +665,25 @@ func TestSecureConnection(t *testing.T) {
}
}

func TestTDS8Connection(t *testing.T) {
func TestTDS8ConnFailure(t *testing.T) {
checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

dsn := makeConnStr(t)
config := testConnParams(t)
dsn := config.URL()
if !strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip()
}
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "strict")
dsnParams.Set("TrustServerCertificate", "false")
dsnParams.Set("tlsmin", "1.2")
dsnParams.Set(msdsn.TrustServerCertificate, "true")
dsnParams.Set(msdsn.Encrypt, "strict")
dsnParams.Set(msdsn.TLSMin, "1.2")
dsn.RawQuery = dsnParams.Encode()

conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
stmt, err := conn.Prepare("SELECT protocol_type, CONVERT(varbinary(9),protocol_version),client_net_address from sys.dm_exec_connections where session_id=@@SPID")
if err != nil {
t.Fatal("Prepare failed:", err.Error())
}
defer stmt.Close()
row := stmt.QueryRow()
var protocolName string
var tdsver []byte
var clientAddress string
err = row.Scan(&protocolName, &tdsver, &clientAddress)
if err != nil {
t.Fatal("Scan failed:", err.Error())
}
assertEqual(t, "TSQL", protocolName)
assertEqual(t, "0x08000000", hex.EncodeToString(tdsver))
}

func assertEqual(t *testing.T, expected interface{}, actual interface{}) {
if expected != actual {
t.Fatalf("Expected %v, got %v", expected, actual)
_, err := sql.Open("mssql", dsn.String())
if err == nil {
t.Fatal("Connection did not fail for unknown CA certificate with encrypt=strict")
}
}

Expand Down

0 comments on commit 18ea4dc

Please sign in to comment.