Skip to content

Commit

Permalink
Merge pull request #608 from gravitational/upstream-with-tls-config
Browse files Browse the repository at this point in the history
Add `WithTLSConfig` method to `OracleConnector`
  • Loading branch information
sijms authored Oct 7, 2024
2 parents 4526d69 + 45c605b commit eaedcff
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 1 deletion.
141 changes: 141 additions & 0 deletions examples/mtls/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package main

import (
"crypto/tls"
"crypto/x509"
"database/sql"
"flag"
"fmt"
"log"
"os"

go_ora "github.com/sijms/go-ora/v2"
)

func databaseConn(addr, serviceName string, tlsConfig *tls.Config) error {
var driver go_ora.OracleDriver

dsn := fmt.Sprintf(`oracle://%s/%s?SSL=enabled&AUTH TYPE=TCPS`, addr, serviceName)
conn, err := driver.OpenConnector(dsn)
if err != nil {
return err
}
oc, ok := conn.(*go_ora.OracleConnector)
if !ok {
return fmt.Errorf("failed to cast to OracleConnector")
}
oc.WithTLSConfig(tlsConfig)

dbConn := sql.OpenDB(conn)

err = dbConn.Ping()
if err != nil {
return err
}

var queryResultColumnOne string
row := dbConn.QueryRow("SELECT systimestamp FROM dual")
err = row.Scan(&queryResultColumnOne)
if err != nil {
return err
}

fmt.Println("The time in the database:", queryResultColumnOne)

return nil
}

func printHelp() {
helpMessage := `
Usage: go run ./examples/mtls [OPTIONS]
Options:
-addr string
Database address (e.g., localhost:2484) (required)
-service string
Database service name (required)
-cert string
Path to the user certificate (e.g., alice.crt) (required)
-key string
Path to the user key (e.g., alice.key) (required)
-server-ca-cert string
Path to the server CA certificate file (optional)
-insecure
Skip TLS certificate verification (default: false)
-help
Display help information
`
fmt.Print(helpMessage)
}

func main() {
addr := flag.String("addr", "", "Database address (e.g., localhost:2484)")
serviceName := flag.String("service", "", "Database service name")
certFile := flag.String("cert", "", "Path to the user certificate (e.g., alice.crt)")
keyFile := flag.String("key", "", "Path to the user key (e.g., alice.key)")
serverCaCertFile := flag.String("server-ca-cert", "", "Path to the server CA certificate file (optional)")
insecureSkipVerify := flag.Bool("insecure", false, "Skip TLS certificate verification (default: false)")
help := flag.Bool("help", false, "Display help information")

flag.Parse()

// If help is requested, display the help message and exit.
if *help || len(os.Args) == 1 {
printHelp()
return
}

// Expand any environment variables in the flags
*addr = os.ExpandEnv(*addr)
*serviceName = os.ExpandEnv(*serviceName)
*certFile = os.ExpandEnv(*certFile)
*keyFile = os.ExpandEnv(*keyFile)
*serverCaCertFile = os.ExpandEnv(*serverCaCertFile)

// Check for required flags
if *addr == "" {
log.Fatal("database address is required; flag missing or empty")
}
if *serviceName == "" {
log.Fatal("database service name is required; flag missing or empty")
}
if *certFile == "" {
log.Fatal("path to the user certificate is required; flag missing or empty")
}
if *keyFile == "" {
log.Fatal("path to the user key is required; flag missing or empty")
}

// Load the user certificate
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Fatalf("failed to load certificate: %v", err)
}

// Create a TLS config with the certificate
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: *insecureSkipVerify,
}

// Load server CA certificate if the file is provided
if *serverCaCertFile != "" {
caCert, err := os.ReadFile(*serverCaCertFile)
if err != nil {
log.Fatalf("failed to read server CA certificate file: %v", err)
}

certPool := x509.NewCertPool()

if !certPool.AppendCertsFromPEM(caCert) {
log.Fatalf("failed to append server CA certificate from file: %s", *serverCaCertFile)
}

tlsConfig.RootCAs = certPool
}

// Call the database connection function
if err := databaseConn(*addr, *serviceName, tlsConfig); err != nil {
log.Fatalf("database connection error: %v", err)
}
}
2 changes: 2 additions & 0 deletions v2/configurations/session_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package configurations

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
Expand All @@ -22,6 +23,7 @@ type SessionInfo struct {
Protocol string
SSL bool
SSLVerify bool
TLSConfig *tls.Config
Dialer DialerContext
}

Expand Down
9 changes: 9 additions & 0 deletions v2/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package go_ora

import (
"context"
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
Expand Down Expand Up @@ -118,6 +119,7 @@ type OracleConnector struct {
drv *OracleDriver
connectString string
dialer configurations.DialerContext
tlsConfig *tls.Config
}

func NewConnector(connString string) driver.Connector {
Expand All @@ -144,6 +146,9 @@ func (connector *OracleConnector) Connect(ctx context.Context) (driver.Conn, err
if conn.connOption.Dialer == nil {
conn.connOption.Dialer = connector.dialer
}
if conn.connOption.TLSConfig == nil {
conn.connOption.TLSConfig = connector.tlsConfig
}
err = conn.OpenWithContext(ctx)
if err != nil {
return nil, err
Expand All @@ -163,6 +168,10 @@ func (connector *OracleConnector) Dialer(dialer configurations.DialerContext) {
connector.dialer = dialer
}

func (connector *OracleConnector) WithTLSConfig(config *tls.Config) {
connector.tlsConfig = config
}

// Open return a new open connection
func (driver *OracleDriver) Open(name string) (driver.Conn, error) {
conn, err := NewConnection(name, driver.connOption)
Expand Down
9 changes: 8 additions & 1 deletion v2/network/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,20 @@ func (session *Session) LoadSSLData(certs, keys, certRequests [][]byte) error {
// used to create sslConn object
func (session *Session) negotiate() {
connOption := session.Context.connConfig
host := connOption.GetActiveServer(false)

if tlsConfig := connOption.TLSConfig; tlsConfig != nil {
tlsConfig.ServerName = host.Addr
session.sslConn = tls.Client(session.conn, tlsConfig)
return
}

if session.SSL.roots == nil && len(session.SSL.Certificates) > 0 {
session.SSL.roots = x509.NewCertPool()
for _, cert := range session.SSL.Certificates {
session.SSL.roots.AddCert(cert)
}
}
host := connOption.GetActiveServer(false)
config := &tls.Config{
ServerName: host.Addr,
}
Expand Down

0 comments on commit eaedcff

Please sign in to comment.