Skip to content

Commit

Permalink
Add support for IAM Auth for Google CloudSQL DBs (#22445)
Browse files Browse the repository at this point in the history
  • Loading branch information
kpcraig authored Sep 6, 2023
1 parent 2ca784a commit 2172786
Show file tree
Hide file tree
Showing 11 changed files with 1,024 additions and 41 deletions.
3 changes: 3 additions & 0 deletions changelog/22445.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:feature
**GCP IAM Support**: Adds support for IAM-based authentication to MySQL and PostgreSQL backends using Google Cloud SQL.
```
9 changes: 5 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ replace github.com/hashicorp/vault/api/auth/userpass => ./api/auth/userpass
replace github.com/hashicorp/vault/sdk => ./sdk

require (
cloud.google.com/go/cloudsqlconn v1.4.3
cloud.google.com/go/monitoring v1.15.1
cloud.google.com/go/spanner v1.47.0
cloud.google.com/go/storage v1.30.1
Expand Down Expand Up @@ -50,7 +51,7 @@ require (
github.com/client9/misspell v0.3.4
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
github.com/denisenkom/go-mssqldb v0.12.2
github.com/denisenkom/go-mssqldb v0.12.3
github.com/duosecurity/duo_api_golang v0.0.0-20190308151101-6c680f768e74
github.com/dustin/go-humanize v1.0.1
github.com/fatih/color v1.15.0
Expand All @@ -62,7 +63,7 @@ require (
github.com/go-git/go-git/v5 v5.7.0
github.com/go-jose/go-jose/v3 v3.0.0
github.com/go-ldap/ldap/v3 v3.4.4
github.com/go-sql-driver/mysql v1.6.0
github.com/go-sql-driver/mysql v1.7.1
github.com/go-test/deep v1.1.0
github.com/go-zookeeper/zk v1.0.3
github.com/gocql/gocql v1.0.0
Expand Down Expand Up @@ -217,7 +218,7 @@ require (
golang.org/x/sys v0.12.0
golang.org/x/term v0.12.0
golang.org/x/text v0.13.0
golang.org/x/tools v0.8.0
golang.org/x/tools v0.9.1
google.golang.org/api v0.138.0
google.golang.org/grpc v1.57.0
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0
Expand Down Expand Up @@ -361,7 +362,7 @@ require (
github.com/gofrs/uuid v4.3.0+incompatible // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.0.0 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/snappy v0.0.4 // indirect
Expand Down
206 changes: 200 additions & 6 deletions go.sum

Large diffs are not rendered by default.

101 changes: 100 additions & 1 deletion plugins/database/mysql/connection_producer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sync"
"time"

cloudmysql "cloud.google.com/go/cloudsqlconn/mysql/mysql"
"github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-uuid"
Expand All @@ -21,6 +22,11 @@ import (
"github.com/mitchellh/mapstructure"
)

const (
cloudSQLMySQL = "cloudsql-mysql"
driverMySQL = "mysql"
)

// mySQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type mySQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
Expand All @@ -29,6 +35,8 @@ type mySQLConnectionProducer struct {
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`

TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
TLSCAData []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
Expand All @@ -38,6 +46,10 @@ type mySQLConnectionProducer struct {
// tlsConfigName is a globally unique name that references the TLS config for this instance in the mysql driver
tlsConfigName string

// cloudDriverName is a globally unique name that references the cloud dialer config for this instance of the driver
cloudDriverName string
cloudDialerCleanup func() error

RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
Expand Down Expand Up @@ -110,6 +122,32 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
}

// validate auth_type if provided
authType := c.AuthType
if authType != "" {
if ok := connutil.ValidateAuthType(authType); !ok {
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
}
}

if c.AuthType == connutil.AuthTypeGCPIAM {
c.cloudDriverName, err = uuid.GenerateUUID()
if err != nil {
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
}

// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
// however, the driver might store a credentials file, in which case the state stored by the driver is in
// fact critical to the proper function of the connection. So it needs to be registered here inside the
// ConnectionProducer init.
dialerCleanup, err := registerDriverMySQL(c.cloudDriverName, c.ServiceAccountJSON)
if err != nil {
return nil, err
}

c.cloudDialerCleanup = dialerCleanup
}

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true
Expand Down Expand Up @@ -140,14 +178,33 @@ func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{},
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()

// if IAM authentication was enabled
// ensure open dialer is also closed
if c.AuthType == connutil.AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}

}

driverName := driverMySQL
if c.cloudDriverName != "" {
driverName = c.cloudDriverName
}

connURL, err := c.addTLStoDSN()
if err != nil {
return nil, err
}

c.db, err = sql.Open("mysql", connURL)
cloudURL, err := c.rewriteProtocolForGCP(connURL)
if err != nil {
return nil, err
}

c.db, err = sql.Open(driverName, cloudURL)
if err != nil {
return nil, err
}
Expand All @@ -174,6 +231,13 @@ func (c *mySQLConnectionProducer) Close() error {
defer c.Unlock()

if c.db != nil {
// if auth_type is IAM, ensure cleanup
// of cloudSQL resources
if c.AuthType == connutil.AuthTypeGCPIAM {
if c.cloudDialerCleanup != nil {
c.cloudDialerCleanup()
}
}
c.db.Close()
}

Expand Down Expand Up @@ -230,3 +294,38 @@ func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
connURL = config.FormatDSN()
return connURL, nil
}

// rewriteProtocolForGCP rewrites the protocol in the DSN to contain the protocol name associated
// with the dialer and therefore driver associated with the provided cloudsqlconn.DialerOpts.
// As a safety/sanity check, it will only do this for protocol "cloudsql-mysql", the name GCP uses in its documentation.
//
// For example, it will rewrite the dsn "user@cloudsql-mysql(zone:region:instance)/ to
// "user@the-uuid-generated(zone:region:instance)/
func (c *mySQLConnectionProducer) rewriteProtocolForGCP(inDSN string) (string, error) {
if c.cloudDriverName == "" {
// unchanged if not cloud
return inDSN, nil
}

config, err := mysql.ParseDSN(inDSN)
if err != nil {
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
}

if config.Net != cloudSQLMySQL {
return "", fmt.Errorf("didn't update net name because it wasn't what we expected as a placeholder: %s", config.Net)
}

config.Net = c.cloudDriverName

return config.FormatDSN(), nil
}

func registerDriverMySQL(driverName, credentials string) (cleanup func() error, err error) {
opts, err := connutil.GetCloudSQLAuthOptions(credentials)
if err != nil {
return nil, err
}

return cloudmysql.RegisterDriver(driverName, opts...)
}
80 changes: 78 additions & 2 deletions plugins/database/mysql/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@ import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"time"

stdmysql "github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/stretchr/testify/require"

mysqlhelper "github.com/hashicorp/vault/helper/testhelpers/mysql"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/stretchr/testify/require"
)

var _ dbplugin.Database = (*MySQL)(nil)
Expand All @@ -44,6 +47,79 @@ func TestMySQL_Initialize(t *testing.T) {
}
}

// TestMySQL_Initialize_CloudGCP validates the proper initialization of a MySQL backend pointing
// to a GCP CloudSQL MySQL instance. This expects some external setup (exact TBD)
func TestMySQL_Initialize_CloudGCP(t *testing.T) {
envConnURL := "CONNECTION_URL"
connURL := os.Getenv(envConnURL)
if connURL == "" {
t.Skipf("env var %s not set, skipping test", envConnURL)
}

credStr := dbtesting.GetGCPTestCredentials(t)

tests := map[string]struct {
req dbplugin.InitializeRequest
wantErr bool
expectedError string
}{
"empty auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "",
},
},
},
"invalid auth type": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": "invalid",
},
},
wantErr: true,
expectedError: "invalid auth_type",
},
"JSON credentials": {
req: dbplugin.InitializeRequest{
Config: map[string]interface{}{
"connection_url": connURL,
"auth_type": connutil.AuthTypeGCPIAM,
"service_account_json": credStr,
},
VerifyConnection: true,
},
},
}

for n, tc := range tests {
t.Run(n, func(t *testing.T) {
db := newMySQL(DefaultUserNameTemplate)
defer dbtesting.AssertClose(t, db)
_, err := db.Initialize(context.Background(), tc.req)

if tc.wantErr {
if err == nil {
t.Fatalf("expected error but received nil")
}

if !strings.Contains(err.Error(), tc.expectedError) {
t.Fatalf("expected error %s, got %s", tc.expectedError, err.Error())
}
} else {
if err != nil {
t.Fatalf("expected no error, received %s", err)
}

if !db.Initialized {
t.Fatal("Database should be initialized")
}
}
})
}
}

func testInitialize(t *testing.T, rootPassword string) {
cleanup, connURL := mysqlhelper.PrepareTestContainer(t, false, rootPassword)
defer cleanup()
Expand Down
Loading

0 comments on commit 2172786

Please sign in to comment.