Skip to content

Commit

Permalink
chore: Expose the refresh strategy UseIAMAuthN() value to the dialer.
Browse files Browse the repository at this point in the history
Part of #842

chore: Add domain name to the cloudsql.ConnName struct

feat: Check for DNS changes on connect. On change, close all connections and create a new dialer.

feat: Automatially check for DNS changes periodically. On change, close all connections and create a new dialer.

wip: eno changes

wip: eno interface cleanup

wip: convert monitoredInstance to *monitoredInstance
  • Loading branch information
hessjcg committed Sep 10, 2024
1 parent 589f9e6 commit 406e1cb
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 20 deletions.
38 changes: 37 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ func connect() {
// ... etc
}
```
### Using DNS to identify an instance

### Using DNS domain names to identify instances

The connector can be configured to use DNS to look up an instance. This would
allow you to configure your application to connect to a database instance, and
Expand Down Expand Up @@ -292,6 +293,41 @@ func connect() {
}
```

### Automatic fail-over using DNS domain names

When the connector is configured using a domain name, the connector will
periodically check if the DNS record for an instance changes. When the connector
detects that the domain name refers to a different instance, the connector will
close all open connections to the old instance. Subsequent connection attempts
will be directed to the new instance.

For example: suppose application is configured to connect using the
domain name `prod-db.mycompany.example.com`. Initially the corporate DNS
zone has a TXT record with the value `my-project:region:my-instance`. The
application establishes connections to the `my-project:region:my-instance`
Cloud SQL instance.

Then, to reconfigure the application using a different database
instance: `my-project:other-region:my-instance-2`. You update the DNS record
for `prod-db.mycompany.example.com` with the target
`my-project:other-region:my-instance-2`

The connector inside the application detects the change to this
DNS entry. Now, when the application connects to its database using the
domain name `prod-db.mycompany.example.com`, it will connect to the
`my-project:other-region:my-instance-2` Cloud SQL instance.

The connector will automatically close all existing connections to
`my-project:region:my-instance`. This will force the connection pools to
establish new connections. Also, it may cause database queries in progress
to fail.

The connector will poll for changes to the DNS name every 30 seconds by default.
You may configure the frequency of the connections using the option
`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will
disable polling and only check if the DNS record changed when it is
creating a new connection.


### Using Options

Expand Down
129 changes: 121 additions & 8 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,17 @@ type connectionInfoCache interface {
ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
UpdateRefresh(*bool)
ForceRefresh()
UseIAMAuthN() bool
io.Closer
}

// monitoredCache is a wrapper around a connectionInfoCache that tracks the
// number of connections to the associated instance.
type monitoredCache struct {
openConns *uint64
openConnsCount *uint64

mu sync.Mutex
openConns []*instrumentedConn

connectionInfoCache
}
Expand All @@ -122,6 +126,16 @@ func (c *monitoredCache) Close() error {
if c == nil || c.connectionInfoCache == nil {
return nil
}

if atomic.LoadUint64(c.openConnsCount) > 0 {
for _, socket := range c.openConns {
if !socket.isClosed() {
_ = socket.Close() // force socket closed, ok to ignore error.
}
}
atomic.StoreUint64(c.openConnsCount, 0)
}

return c.connectionInfoCache.Close()
}

Expand All @@ -145,6 +159,21 @@ func (c *monitoredCache) ConnectionInfo(ctx context.Context) (cloudsql.Connectio
return c.connectionInfoCache.ConnectionInfo(ctx)
}

func (c *monitoredCache) purgeClosedConns() {
if c == nil || c.connectionInfoCache == nil {
return
}
c.mu.Lock()
var open []*instrumentedConn
for _, s := range c.openConns {
if !s.isClosed() {
open = append(open, s)
}
}
c.openConns = open
c.mu.Unlock()
}

// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
Expand Down Expand Up @@ -182,6 +211,10 @@ type Dialer struct {

// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver

// domainNameTicker periodically checks any domain names to see if they
// changed.
domainNameTicker *time.Ticker
}

var (
Expand All @@ -205,6 +238,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: nullLogger{},
useragents: []string{userAgent},
serviceUniverse: "googleapis.com",
failoverPeriod: cloudsql.FailoverPeriod,
}
for _, opt := range opts {
opt(cfg)
Expand All @@ -218,6 +252,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
if cfg.setIAMAuthNTokenSource && !cfg.useIAMAuthN {
return nil, errUseTokenSource
}

// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

Expand All @@ -231,7 +266,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
}
ud, err := c.GetUniverseDomain()
if err != nil {
return nil, fmt.Errorf("failed to getOrAdd universe domain: %v", err)
return nil, fmt.Errorf("failed to get universe domain: %v", err)
}
cfg.credentialsUniverse = ud
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
Expand Down Expand Up @@ -301,8 +336,28 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
dialFunc: cfg.dialFunc,
resolver: r,
}

// If the failover period is set, start a goroutine to periodically
// check for DNS changes.
if cfg.failoverPeriod > 0 {
d.initFailoverRoutine(ctx, cfg.failoverPeriod)
}

return d, nil
}
func (d *Dialer) initFailoverRoutine(ctx context.Context, p time.Duration) {
d.domainNameTicker = time.NewTicker(p)
go func() {
for {
select {
case <-d.domainNameTicker.C:
d.pollDomainNames(ctx)
case <-d.closed:
return
}
}
}()
}

// Dial returns a net.Conn connected to the specified Cloud SQL instance. The
// icn argument must be the instance's connection name, which is in the format
Expand Down Expand Up @@ -406,16 +461,23 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(c.openConns, 1)
n := atomic.AddUint64(c.openConnsCount, 1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, cn.String())
trace.RecordDialLatency(ctx, icn, d.dialerID, latency)
}()

iConn := newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(c.openConns, ^uint64(0))
n := atomic.AddUint64(c.openConnsCount, ^uint64(0))
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, cn.String())
}, d.dialerID, cn.String())

// If this connection was opened using a Domain Name, then store it for later
// in case it needs to be forcibly closed.
if cn.DomainName() != "" {
c.mu.Lock()
c.openConns = append(c.openConns, iConn)
c.mu.Unlock()
}
return iConn, nil
}

Expand Down Expand Up @@ -520,6 +582,7 @@ func newInstrumentedConn(conn net.Conn, closeFunc func(), dialerID, connName str
type instrumentedConn struct {
net.Conn
closeFunc func()
mu sync.RWMutex
closed bool
dialerID string
connName string
Expand All @@ -545,6 +608,13 @@ func (i *instrumentedConn) Write(b []byte) (int, error) {
return bytesWritten, err
}

// isClosed returns true if this connection is closing or is already closed.
func (i *instrumentedConn) isClosed() bool {
i.mu.RLock()
defer i.mu.RUnlock()
return i.closed
}

// Close delegates to the underlying net.Conn interface and reports the close
// to the provided closeFunc only when Close returns no error.
func (i *instrumentedConn) Close() error {
Expand All @@ -568,13 +638,56 @@ func (d *Dialer) Close() error {
}
close(d.closed)

d.cache.replaceAll(func(cn instance.ConnName, c monitoredCache) (instance.ConnName, monitoredCache) {
c.Close() // close the monitoredCache
return instance.ConnName{}, monitoredCache{} // Remove from cache
if d.domainNameTicker != nil {
d.domainNameTicker.Stop()
}

d.cache.replaceAll(func(cn instance.ConnName, c *monitoredCache) (instance.ConnName, *monitoredCache) {

Check failure on line 645 in dialer.go

View workflow job for this annotation

GitHub Actions / Run lint

unused-parameter: parameter 'cn' seems to be unused, consider removing or renaming it as _ (revive)
c.Close() // close the monitoredCache
return instance.ConnName{}, nil // Remove from cache
})
return nil
}

func (d *Dialer) pollDomainNames(ctx context.Context) {
d.cache.replaceAll(func(cn instance.ConnName, cache *monitoredCache) (instance.ConnName, *monitoredCache) {
if cn.DomainName() == "" {
return cn, cache
}

// Resolve the domain name.
newCn, err := d.resolver.Resolve(ctx, cn.DomainName())

if err != nil {
// the domain name no longer resolves to a valid instance
d.logger.Debugf(ctx, "[failover] unable to resolve DNS for instance %s: %v", cn.DomainName(), err)
cache.Close()
return instance.ConnName{}, nil
} else if newCn != cn {
d.logger.Debugf(ctx, "domain name %s changed from old instance %s to new instance %s",
cn.DomainName(), cn.String(), newCn.String())

useIamAuthn := cache.UseIAMAuthN()
// The domain name points to a different instance.
cache.Close()

newC, err := d.createConnectionInfoCache(ctx, cn, &useIamAuthn)
if err != nil {
d.logger.Debugf(ctx, "error connecting to new instance %s, %s: %v",
cn.DomainName(), newCn.String(), err)
return instance.ConnName{}, nil
}
return newCn, newC
}

// Remove closed sockets from cache.openConns
cache.purgeClosedConns()
return cn, cache

})

}

// connectionInfoCache is a helper function for returning the appropriate
// connection info Cache in a threadsafe way. It will create a new cache,
// modify the existing one, or leave it unchanged as needed.
Expand Down Expand Up @@ -624,7 +737,7 @@ func (d *Dialer) createConnectionInfoCache(
)
}
c := &monitoredCache{
openConns: new(uint64),
openConnsCount: new(uint64),
connectionInfoCache: cache,
}

Expand Down
Loading

0 comments on commit 406e1cb

Please sign in to comment.