diff --git a/README.md b/README.md index 19e69d0f..313b90e7 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,64 @@ func connect() { // ... etc } ``` +### Using DNS to identify an instance + +The connector can be configured to use DNS to look up an instance. This would +allow you to configure your application to connect to a database instance, and +centrally configure which instance in your DNS zone. + +#### Configure your DNS Records + +Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server +or a private Google Cloud DNS Zone used by your application. + +**Note:** You are strongly discouraged from adding DNS records for your +Cloud SQL instances to a public DNS server. This would allow anyone on the +internet to discover the Cloud SQL instance name. + +For example: suppose you wanted to use the domain name +`prod-db.mycompany.example.com` to connect to your database instance +`my-project:region:my-instance`. You would create the following DNS record: + +- Record type: `TXT` +- Name: `prod-db.mycompany.example.com` – This is the domain name used by the application +- Value: `my-project:region:my-instance` – This is the instance name + +#### Configure the connector + +Configure the connector as described above, replacing the conenctor ID with +the DNS name. + +Adapting the MySQL + database/sql example above: + +```go +package main + +import ( + "database/sql" + + "cloud.google.com/go/cloudsqlconn" + "cloud.google.com/go/cloudsqlconn/mysql/mysql" +) + +func connect() { + cleanup, err := mysql.RegisterDriver("cloudsql-mysql", + cloudsqlconn.WithDNSResolver(), + cloudsqlconn.WithCredentialsFile("key.json")) + if err != nil { + // ... handle error + } + // call cleanup when you're done with the database connection + defer cleanup() + + db, err := sql.Open( + "cloudsql-mysql", + "myuser:mypass@cloudsql-mysql(prod-db.mycompany.example.com)/mydb", + ) + // ... etc +} +``` + ### Using Options diff --git a/dialer.go b/dialer.go index b1f859fd..d5b0ff9b 100644 --- a/dialer.go +++ b/dialer.go @@ -153,6 +153,9 @@ type Dialer struct { // iamTokenSource supplies the OAuth2 token used for IAM DB Authn. iamTokenSource oauth2.TokenSource + + // resolver converts instance names into DNS names. + resolver instance.ConnectionNameResolver } var ( @@ -253,6 +256,11 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { if err != nil { return nil, err } + var r instance.ConnectionNameResolver = cloudsql.DefaultResolver + if cfg.resolver != nil { + r = cfg.resolver + } + d := &Dialer{ closed: make(chan struct{}), cache: make(map[instance.ConnName]monitoredCache), @@ -265,6 +273,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { dialerID: uuid.New().String(), iamTokenSource: cfg.iamLoginTokenSource, dialFunc: cfg.dialFunc, + resolver: r, } return d, nil } @@ -288,7 +297,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn go trace.RecordDialError(context.Background(), icn, d.dialerID, err) endDial(err) }() - cn, err := instance.ParseConnName(icn) + cn, err := d.resolver.Resolve(ctx, icn) if err != nil { return nil, err } @@ -429,7 +438,7 @@ func validClientCert( // the instance: // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) { - cn, err := instance.ParseConnName(icn) + cn, err := d.resolver.Resolve(ctx, icn) if err != nil { return "", err } @@ -449,7 +458,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) // Use Warmup to start the refresh process early if you don't know when you'll // need to call "Dial". func (d *Dialer) Warmup(ctx context.Context, icn string, opts ...DialOption) error { - cn, err := instance.ParseConnName(icn) + cn, err := d.resolver.Resolve(ctx, icn) if err != nil { return err } diff --git a/dialer_test.go b/dialer_test.go index 93a801ba..2e640af4 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -39,6 +39,14 @@ import ( // and verifies the connection works end to end. func testSuccessfulDial( ctx context.Context, t *testing.T, d *Dialer, icn string, opts ...DialOption, +) { + testSucessfulDialWithInstanceName(ctx, t, d, icn, "my-instance", opts...) +} + +// testSuccessfulDial uses the provided dialer to dial the specified instance +// and verifies the connection works end to end. +func testSucessfulDialWithInstanceName( + ctx context.Context, t *testing.T, d *Dialer, icn string, instanceName string, opts ...DialOption, ) { conn, err := d.Dial(ctx, icn, opts...) if err != nil { @@ -50,7 +58,7 @@ func testSuccessfulDial( if err != nil { t.Fatalf("expected ReadAll to succeed, got error %v", err) } - if string(data) != "my-instance" { + if string(data) != instanceName { t.Fatalf( "expected known response from the server, but got %v", string(data), @@ -1018,3 +1026,62 @@ func TestDialerInitializesLazyCache(t *testing.T) { t.Fatalf("dialer was initialized with non-lazy type: %T", tt) } } + +type fakeResolver struct { + domainName string + instanceName instance.ConnName +} + +func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { + // For TestDialerSuccessfullyDialsDnsTxtRecord + if name == r.domainName { + return r.instanceName, nil + } + // TestDialerFailsDnsTxtRecordMissing + return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) +} + +func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + wantName, _ := instance.ParseConnName("my-project:my-region:my-instance") + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithResolver(&fakeResolver{ + domainName: "db.example.com", + instanceName: wantName, + }), + }, + }) + + testSuccessfulDial( + context.Background(), t, d, + "db.example.com", + ) +} + +func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{}, + dialerOptions: []Option{ + WithTokenSource(mock.EmptyTokenSource{}), + WithResolver(&fakeResolver{}), + }, + }) + _, err := d.Dial(context.Background(), "doesnt-exist.example.com") + wantMsg := "no resolution for \"doesnt-exist.example.com\"" + if !strings.Contains(err.Error(), wantMsg) { + t.Fatalf("want = %v, got = %v", wantMsg, err) + } +} diff --git a/instance/conn_name.go b/instance/conn_name.go index d4f33601..2dd3de73 100644 --- a/instance/conn_name.go +++ b/instance/conn_name.go @@ -15,6 +15,7 @@ package instance import ( + "context" "fmt" "regexp" @@ -74,3 +75,13 @@ func ParseConnName(cn string) (ConnName, error) { } return c, nil } + +// ConnectionNameResolver resolves the connection name string into a valid +// instance name. This allows an application to replace the default +// resolver with a custom implementation. +type ConnectionNameResolver interface { + // Resolve accepts a name, and returns a ConnName with the instance + // connection string for the name. If the name cannot be resolved, returns + // an error. + Resolve(ctx context.Context, name string) (ConnName, error) +} diff --git a/internal/cloudsql/resolver.go b/internal/cloudsql/resolver.go new file mode 100644 index 00000000..676405e4 --- /dev/null +++ b/internal/cloudsql/resolver.go @@ -0,0 +1,123 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsql + +import ( + "context" + "fmt" + "net" + "sort" + + "cloud.google.com/go/cloudsqlconn/instance" +) + +// DNSResolver uses the default net.Resolver to find +// TXT records containing an instance name for a DNS record. +var DNSResolver = &DNSInstanceConnectionNameResolver{ + dnsResolver: net.DefaultResolver, +} + +// DefaultResolver simply parses instance names. +var DefaultResolver = &ConnNameResolver{} + +// ConnNameResolver simply parses instance names. Implements +// InstanceConnectionNameResolver +type ConnNameResolver struct { +} + +// Resolve returns the instance name, possibly using DNS. This will return an +// instance.ConnName or an error if it was unable to resolve an instance name. +func (r *ConnNameResolver) Resolve(_ context.Context, icn string) (instanceName instance.ConnName, err error) { + return instance.ParseConnName(icn) +} + +// netResolver groups the methods on net.Resolver that are used by the DNS +// resolver implementation. This allows an application to replace the default +// net.DefaultResolver with a custom implementation. For example: the +// application may need to connect to a specific DNS server using a specially +// configured instance of net.Resolver. +type netResolver interface { + LookupTXT(ctx context.Context, name string) ([]string, error) +} + +// DNSInstanceConnectionNameResolver can resolve domain names into instance names using +// TXT records in DNS. Implements InstanceConnectionNameResolver +type DNSInstanceConnectionNameResolver struct { + dnsResolver netResolver +} + +// Resolve returns the instance name, possibly using DNS. This will return an +// instance.ConnName or an error if it was unable to resolve an instance name. +func (r *DNSInstanceConnectionNameResolver) Resolve(ctx context.Context, icn string) (instanceName instance.ConnName, err error) { + cn, err := instance.ParseConnName(icn) + if err != nil { + // The connection name was not project:region:instance + // Attempt to query a TXT record and see if it works instead. + cn, err = r.queryDNS(ctx, icn) + if err != nil { + return instance.ConnName{}, err + } + } + + return cn, nil +} + +// queryDNS attempts to resolve a TXT record for the domain name. +// The DNS TXT record's target field is used as instance name. +// +// This handles several conditions where the DNS records may be missing or +// invalid: +// - The domain name resolves to 0 DNS records - return an error +// - Some DNS records to not contain a well-formed instance name - return the +// first well-formed instance name. If none found return an error. +// - The domain name resolves to 2 or more DNS record - return first valid +// record when sorted by priority: lowest value first, then by target: +// alphabetically. +func (r *DNSInstanceConnectionNameResolver) queryDNS(ctx context.Context, domainName string) (instance.ConnName, error) { + // Attempt to query the TXT records. + // This could return a partial error where both err != nil && len(records) > 0. + records, err := r.dnsResolver.LookupTXT(ctx, domainName) + // If resolve failed and no records were found, return the error. + if err != nil { + return instance.ConnName{}, fmt.Errorf("unable to resolve TXT record for %q: %v", domainName, err) + } + + // Process the records returning the first valid TXT record. + + // Sort the TXT record values alphabetically by instance name + sort.Slice(records, func(i, j int) bool { + return records[i] < records[j] + }) + + var perr error + // Attempt to parse records, returning the first valid record. + for _, record := range records { + // Parse the target as a CN + cn, parseErr := instance.ParseConnName(record) + if parseErr != nil { + perr = fmt.Errorf("unable to parse TXT for %q -> %q : %v", domainName, record, parseErr) + continue + } + return cn, nil + } + + // If all the records failed to parse, return one of the parse errors + if perr != nil { + return instance.ConnName{}, perr + } + + // No records were found, return an error. + return instance.ConnName{}, fmt.Errorf("no valid TXT records found for %q", domainName) +} diff --git a/internal/cloudsql/resolver_test.go b/internal/cloudsql/resolver_test.go new file mode 100644 index 00000000..cd0dea9f --- /dev/null +++ b/internal/cloudsql/resolver_test.go @@ -0,0 +1,81 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsql + +import ( + "context" + "fmt" + "strings" + "testing" + + "cloud.google.com/go/cloudsqlconn/instance" +) + +type fakeResolver struct { + name string + value string +} + +func (r *fakeResolver) LookupTXT(_ context.Context, name string) (addrs []string, err error) { + if name == r.name { + return []string{r.value}, nil + } + return nil, fmt.Errorf("no resolution for %v", name) +} + +func TestDNSInstanceNameResolver_Lookup_Success_TxtRecord(t *testing.T) { + want, _ := instance.ParseConnName("my-project:my-region:my-instance") + + r := DNSInstanceConnectionNameResolver{ + dnsResolver: &fakeResolver{ + name: "db.example.com", + value: "my-project:my-region:my-instance", + }, + } + got, err := r.Resolve(context.Background(), "db.example.com") + if err != nil { + t.Fatal("got error", err) + } + if got != want { + t.Fatal("Got", got, "Want", want) + } + +} + +func TestDNSInstanceNameResolver_Lookup_Fails_TxtRecordMissing(t *testing.T) { + r := DNSInstanceConnectionNameResolver{ + dnsResolver: &fakeResolver{}, + } + _, err := r.Resolve(context.Background(), "doesnt-exist.example.com") + + wantMsg := "unable to resolve TXT record for \"doesnt-exist.example.com\"" + if !strings.Contains(err.Error(), wantMsg) { + t.Fatalf("want = %v, got = %v", wantMsg, err) + } +} + +func TestDNSInstanceNameResolver_Lookup_Fails_TxtRecordMalformed(t *testing.T) { + r := DNSInstanceConnectionNameResolver{ + dnsResolver: &fakeResolver{ + name: "malformed.example.com", + value: "invalid-instance-name", + }, + } + _, err := r.Resolve(context.Background(), "malformed.example.com") + wantMsg := "unable to parse TXT for \"malformed.example.com\"" + if !strings.Contains(err.Error(), wantMsg) { + t.Fatalf("want = %v, got = %v", wantMsg, err) + } +} diff --git a/options.go b/options.go index 25e6ae86..c21fcd2a 100644 --- a/options.go +++ b/options.go @@ -24,6 +24,7 @@ import ( "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/errtype" + "cloud.google.com/go/cloudsqlconn/instance" "cloud.google.com/go/cloudsqlconn/internal/cloudsql" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -52,6 +53,7 @@ type dialerConfig struct { setCredentials bool setTokenSource bool setIAMAuthNTokenSource bool + resolver instance.ConnectionNameResolver // err tracks any dialer options that may have failed. err error } @@ -234,6 +236,41 @@ func WithIAMAuthN() Option { } } +// WithResolver replaces the default resolver with an alternate +// implementation to resolve the name in the database DSN to a Cloud SQL +// instance. +func WithResolver(r instance.ConnectionNameResolver) Option { + return func(d *dialerConfig) { + d.resolver = r + } +} + +// WithDNSResolver replaces the default resolver (which only resolves instance +// names) with the DNSResolver, which will attempt to first parse the instance +// name, and then will attempt to resolve the DNS TXT record to determine +// the instance name. +// +// First, add a record for your Cloud SQL instance to a **private** DNS server +// or a private Google Cloud DNS Zone used by your application. +// +// **Note:** You are strongly discouraged from adding DNS records for your +// Cloud SQL instances to a public DNS server. This would allow anyone on the +// internet to discover the Cloud SQL instance name. +// +// For example: suppose you wanted to use the domain name +// `prod-db.mycompany.example.com` to connect to your database instance +// `my-project:region:my-instance`. You would create the following DNS record: +// +// - Record type: `TXT` +// - Name: `prod-db.mycompany.example.com` – This is the domain name used by +// the application +// - Value: `my-project:region:my-instance` – This is the instance name +func WithDNSResolver() Option { + return func(d *dialerConfig) { + d.resolver = cloudsql.DNSResolver + } +} + type debugLoggerWithoutContext struct { logger debug.Logger }