Skip to content

Commit

Permalink
feat: Automatically check for DNS changes periodically. On change, clo…
Browse files Browse the repository at this point in the history
…se all connections and create a new dialer.
  • Loading branch information
hessjcg committed Jul 17, 2024
1 parent 91302e2 commit fca359a
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
39 changes: 38 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func connect() {
// ... etc
}
```
### Using DNS to identify an instance
### Using DNS domain names to identify instances

The connector can be configured to use DNS to look up an instance. This would
allow you to configure your application to connect to a database instance, and
Expand Down Expand Up @@ -291,6 +291,43 @@ func connect() {
}
```

### Automatic Fail-over Using DNS domain names

TODO(After Public Preview): Include a note here about how to use this with
PSA private network fail-over instances.

When you configure the connector using a domain name, the connector will
periodically check if the DNS record for an instance changes. When the connector
detects that the domain name refers to a different instance, the connector will
close all open connections to the old instance. Subsequent connection attempts
will be directed to the new instance.

For example: suppose your application is configured to connect using the
domain name `prod-db.mycompany.example.com`. Initially your corporate DNS
zone has a SRV record with the target `my-project:region:my-instance`. Your
application establishes connections to the `my-project:region:my-instance`
Cloud SQL instance.

Then, you decide to reconfigure your application using a different database
instance: `my-project:other-region:my-instance-2`. You update the DNS record
for `prod-db.mycompany.example.com` with the target
`my-project:other-region:my-instance-2`.

The connector inside your application detects the change to this
DNS entry. Now, when your application connects to its database using the
domain name `prod-db.mycompany.example.com`, it will connect to the
`my-project:other-region:my-instance-2` Cloud SQL instance.

The connector will automatically close all existing connections to
`my-project:region:my-instance`. This may cause errors in your application as
database queries in progress will fail.

The connector will poll for changes to the DNS name every 30 seconds by default.
You may configure the polling frequency using the option
`WithFailoverPeriod(d time.Duration)`. When this is set to 0, the connector will
disable polling and only check if the DNS record changed when it is
creating a new connection.


### Using Options

Expand Down
76 changes: 76 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ type Dialer struct {
// resolver does SRV record DNS lookups when resolving DNS name dialer
// configuration.
resolver NetResolver

// domainNameTicker periodically checks any domain names to see if they
// changed.
domainNameTicker *time.Ticker
}

var (
Expand All @@ -185,6 +189,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: nullLogger{},
useragents: []string{userAgent},
serviceUniverse: "googleapis.com",
failoverPeriod: cloudsql.FailoverPeriod,
}
for _, opt := range opts {
opt(cfg)
Expand Down Expand Up @@ -281,8 +286,23 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
dialFunc: cfg.dialFunc,
resolver: r,
}

// If the failover period is set, start a goroutine to periodically
// check for DNS changes.
if cfg.failoverPeriod > 0 {
d.initFailoverRoutine(ctx, cfg.failoverPeriod)
}

return d, nil
}
// initFailoverRoutine creates the domain name ticker with period p and starts
// a background goroutine that calls pollDomainNames on every tick.
//
// The goroutine exits when d.closed is closed (which Close does). Stopping a
// time.Ticker does not close its channel, so ranging over the ticker channel
// alone would leave the goroutine blocked forever after Close.
func (d *Dialer) initFailoverRoutine(ctx context.Context, p time.Duration) {
	ticker := time.NewTicker(p)
	d.domainNameTicker = ticker
	go func() {
		for {
			select {
			case <-d.closed:
				// The dialer was closed; stop polling.
				return
			case <-ticker.C:
				d.pollDomainNames(ctx)
			}
		}
	}()
}

func (d *Dialer) resolveInstanceName(ctx context.Context, icn string) (instance.ConnName, error) {
cn, err := instance.ParseConnName(icn)
Expand Down Expand Up @@ -547,12 +567,68 @@ func (d *Dialer) Close() error {
close(d.closed)
d.lock.Lock()
defer d.lock.Unlock()
if d.domainNameTicker != nil {
d.domainNameTicker.Stop()
}
for _, i := range d.cache {
i.Close()
}
return nil
}

// pollDomainNames re-resolves the domain name of every cache entry that was
// opened by domain name. When a domain name now resolves to a different
// instance, the old cache entry is closed via closeDomainNameChanged and a
// cache entry for the new instance is preloaded.
//
// If a domain name cannot be resolved, the failure is logged and the entry is
// left untouched: a failed lookup yields no usable instance name, so it must
// not be treated as "the instance changed".
func (d *Dialer) pollDomainNames(ctx context.Context) {
	type cacheEntry struct {
		cn    instance.ConnName
		cache monitoredCache
	}

	// Snapshot the domain-name entries under the read lock so that the DNS
	// lookups below run without holding d.lock.
	d.lock.RLock()
	caches := make([]cacheEntry, 0, len(d.cache))
	for cn, cache := range d.cache {

		// Ignore cache entries that were not opened by domain name.
		if cn.DomainName() == "" {
			continue
		}

		caches = append(caches, cacheEntry{cn: cn, cache: cache})

		// Clean up the list of openSockets - remove closed sockets
		// NOTE(review): this assigns to the loop variable's openSockets
		// field; confirm that d.cache stores values for which this write is
		// visible to other users of the entry.
		cache.lock.Lock()
		var newOpenSockets []*instrumentedConn
		for _, s := range cache.openSockets {
			if !s.closed {
				newOpenSockets = append(newOpenSockets, s)
			}
		}
		cache.openSockets = newOpenSockets
		cache.lock.Unlock()
	}
	d.lock.RUnlock()

	for _, entry := range caches {
		newCn, err := d.resolveInstanceName(ctx, entry.cn.DomainName())
		if err != nil {
			// The lookup failed, so newCn is not meaningful. Skip this entry
			// rather than closing healthy connections and preloading a cache
			// entry for an empty instance name.
			d.logger.Debugf(ctx, "[failover] unable to resolve DNS for instance %s: %v", entry.cn.DomainName(), err)
			continue
		}

		// The domain name still points to the same instance.
		if newCn == entry.cn {
			continue
		}

		// The domain name points to a different instance.
		d.logger.Debugf(ctx, "domain name %s changed from old instance %s to new instance %s",
			entry.cn.DomainName(), entry.cn.String(), newCn.String())

		d.closeDomainNameChanged(ctx, entry.cn, entry.cache,
			fmt.Errorf("domain name %s changed from old instance %s to new instance %s",
				entry.cn.DomainName(), entry.cn.String(), newCn.String()))
		// preload the new cache entry
		b := entry.cache.UseIAMAuthN()
		d.connectionInfoCache(ctx, newCn, &b)
	}
}

func (d *Dialer) closeDomainNameChanged(ctx context.Context, cn instance.ConnName, cache monitoredCache, err error) {
d.removeCached(ctx, cn, cache, err)
if atomic.LoadUint64(cache.openConns) > 0 {
Expand Down
43 changes: 43 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1195,3 +1195,46 @@ func TestDialerUpdatesOnDialAfterDnsChange(t *testing.T) {
"update.example.com", "my-instance2",
)
}

// TestDialerUpdatesAutomaticallyAfterDnsChange dials an instance by domain
// name, switches the resolver to its next stage, and then waits so the
// dialer's periodic failover check (configured here at 10ms) has time to run
// without any further call to Dial().
func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
	// At first, the resolver will resolve update.example.com to
	// "my-instance". After the stage is advanced, the resolver will resolve
	// the same domain name to "my-instance2".
	first := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
	second := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance2")
	resolver := &changingResolver{stage: new(int32)}

	// One metadata request and one ephemeral certificate request are
	// expected for each of the two instances.
	wantReqs := []*mock.Request{
		mock.InstanceGetSuccess(first, 1),
		mock.CreateEphemeralSuccess(first, 1),
		mock.InstanceGetSuccess(second, 1),
		mock.CreateEphemeralSuccess(second, 1),
	}
	opts := []Option{
		WithResolver(resolver),
		WithFailoverPeriod(10 * time.Millisecond),
		WithTokenSource(mock.EmptyTokenSource{}),
	}

	d := setupDialer(t, setupConfig{
		testInstance:  first,
		reqs:          wantReqs,
		dialerOptions: opts,
	})

	// Start the proxy for instance 1
	testSuccessfulDial(context.Background(), t, d, "update.example.com")

	// Advance the resolver to its next stage, then give the background
	// poller time to notice the change.
	atomic.StoreInt32(resolver.stage, 1)
	time.Sleep(1 * time.Second)
	// The dialer should preload details for the second instance. If it
	// doesn't, then this test will fail because it didn't make enough API
	// calls as defined in the setupConfig{}
}
5 changes: 5 additions & 0 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ const (
// refreshInterval.
RefreshTimeout = 60 * time.Second

// FailoverPeriod is the frequency with which the dialer will check
// if the DNS record has changed for connections configured using
// a DNS name.
FailoverPeriod = 30 * time.Second

// refreshBurst is the initial burst allowed by the rate limiter.
refreshBurst = 2
)
Expand Down
11 changes: 11 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type dialerConfig struct {
setTokenSource bool
setIAMAuthNTokenSource bool
resolver NetResolver
failoverPeriod time.Duration
// err tracks any dialer options that may have failed.
err error
}
Expand Down Expand Up @@ -253,6 +254,16 @@ func WithResolver(r NetResolver) Option {
}
}

// WithFailoverPeriod sets how often the connector re-checks the SRV DNS
// records of instances configured using DNS names. By default, this is 30
// seconds. If this is set to 0, the connector will only check for domain name
// changes when establishing a new connection.
func WithFailoverPeriod(f time.Duration) Option {
	setPeriod := func(cfg *dialerConfig) {
		cfg.failoverPeriod = f
	}
	return setPeriod
}

type debugLoggerWithoutContext struct {
logger debug.Logger
}
Expand Down

0 comments on commit fca359a

Please sign in to comment.