Skip to content

fix SRV and MX records #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/url"
"strings"

"github.com/libdns/libdns"
)
Expand Down Expand Up @@ -61,7 +62,7 @@ func (p *Provider) updateRecord(ctx context.Context, oldRec, newRec cfDNSRecord)
func (p *Provider) getDNSRecords(ctx context.Context, zoneInfo cfZone, rec libdns.Record, matchContent bool) ([]cfDNSRecord, error) {
qs := make(url.Values)
qs.Set("type", rec.Type)
qs.Set("name", libdns.AbsoluteName(rec.Name, zoneInfo.Name))
qs.Set("name", libdns.AbsoluteName(strings.TrimSuffix(rec.Name, zoneInfo.Name), zoneInfo.Name))
if matchContent {
qs.Set("content", rec.Value)
}
Expand Down
103 changes: 82 additions & 21 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package cloudflare

import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -56,6 +58,7 @@ type cfDNSRecord struct {
Content string `json:"content,omitempty"`
Proxiable bool `json:"proxiable,omitempty"`
Proxied bool `json:"proxied,omitempty"`
Priority int `json:"priority,omitempty"`
TTL int `json:"ttl,omitempty"` // seconds
Locked bool `json:"locked,omitempty"`
ZoneID string `json:"zone_id,omitempty"`
Expand Down Expand Up @@ -110,7 +113,8 @@ type cfDNSRecord struct {
}

func (r cfDNSRecord) libdnsRecord(zone string) libdns.Record {
if r.Type == "SRV" {
switch r.Type {
case "SRV":
srv := libdns.SRV{
Service: strings.TrimPrefix(r.Data.Service, "_"),
Proto: strings.TrimPrefix(r.Data.Proto, "_"),
Expand All @@ -120,14 +124,27 @@ func (r cfDNSRecord) libdnsRecord(zone string) libdns.Record {
Port: r.Data.Port,
Target: r.Data.Target,
}
return srv.ToRecord()
}
return libdns.Record{
Type: r.Type,
Name: libdns.RelativeName(r.Name, zone),
Value: r.Content,
TTL: time.Duration(r.TTL) * time.Second,
ID: r.ID,
return libdns.Record{
ID: r.ID,
Type: r.Type,
Name: libdns.RelativeName(r.Name, zone),
Value: fmt.Sprintf("%d %d %d %s", srv.Priority, srv.Weight, srv.Port, libdns.RelativeName(srv.Target, zone)),
TTL: time.Duration(r.TTL) * time.Second,
Priority: srv.Priority,
Weight: srv.Weight,
}
case "MX":
r.Content = fmt.Sprintf("%d %s", r.Priority, r.Content)
fallthrough
default:
return libdns.Record{
Type: r.Type,
Name: libdns.RelativeName(r.Name, zone),
Value: libdns.RelativeName(r.Content, zone),
TTL: time.Duration(r.TTL) * time.Second,
ID: r.ID,
}

}
}

Expand All @@ -137,20 +154,64 @@ func cloudflareRecord(r libdns.Record) (cfDNSRecord, error) {
Type: r.Type,
TTL: int(r.TTL.Seconds()),
}
if r.Type == "SRV" {
srv, err := r.ToSRV()
if err != nil {
return cfDNSRecord{}, err
}
rec.Data.Service = "_" + srv.Service
rec.Data.Priority = srv.Priority
rec.Data.Weight = srv.Weight
rec.Data.Proto = "_" + srv.Proto
rec.Data.Name = srv.Name
rec.Data.Port = srv.Port
rec.Data.Target = srv.Target
if r.Name == "" {
rec.Name = "@"
} else {
rec.Name = r.Name
}
switch r.Type {
case "SRV":
nameParts := strings.Split(r.Name, ".")
if len(nameParts) == 2 {
nameParts = append(nameParts, "@")
} else if len(nameParts) < 3 {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record name: %s, expected _<service>._<proto>", r.Name)
}
valueParts := strings.Fields(r.Value)
if len(valueParts) != 4 {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record value: %s, expected <priority> <weight> <port> <target>", r.Value)
}
priority, err := strconv.ParseUint(valueParts[0], 10, 64)
if err != nil {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record value: priority is not a number: %s", valueParts[0])
}
weight, err := strconv.ParseUint(valueParts[1], 10, 64)
if err != nil {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record value: weight is not a number: %s", valueParts[1])
}
port, err := strconv.Atoi(valueParts[2])
if err != nil {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record value: target port is not a number: %s", valueParts[2])
}
if priority < 0 || priority > 65535 {
return cfDNSRecord{}, fmt.Errorf("invalid SRV record value: priority is out of range 0-65535: %d", priority)
}
rec.Data.Service = nameParts[0]
rec.Data.Priority = uint(priority)
rec.Data.Weight = uint(weight)
rec.Data.Proto = nameParts[1]
rec.Data.Name = strings.Join(nameParts[2:], ".")
rec.Data.Port = uint(port)
rec.Data.Target = strings.Join(valueParts[3:], ".")
case "MX":
valueParts := strings.Fields(r.Value)
if r.Priority == 0 && len(valueParts) != 2 {
return cfDNSRecord{}, fmt.Errorf("invalid MX record value: %s, expected <priority> <target> or Priority to be set", r.Value)
}
if len(valueParts) == 2 {
priority, err := strconv.ParseUint(valueParts[0], 10, 64)
if err != nil {
return cfDNSRecord{}, fmt.Errorf("invalid MX record value: priority is not a number: %s", valueParts[0])
}
r.Priority = uint(priority)
}
if len(valueParts) == 2 {
rec.Content = valueParts[1]
} else {
rec.Content = r.Value
}
rec.Priority = int(r.Priority)
default:
rec.Content = r.Value
}
return rec, nil
Expand Down
184 changes: 184 additions & 0 deletions provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package cloudflare

import (
"context"
"os"
"testing"
"time"

"github.com/libdns/libdns"
)

const (
TokenEnv = "CF_TOKEN"
ZonesEnv = "CF_ZONE"
)

func setup(t *testing.T) (*Provider, string) {
tk := os.Getenv(TokenEnv)
zone := os.Getenv(ZonesEnv)
if tk == "" || zone == "" {
t.Skipf("Skipping test, missing %s or %s", TokenEnv, ZonesEnv)
}
return &Provider{APIToken: tk}, zone
}

func TestMailRecords(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

p, zone := setup(t)
// defaultTTL is the cloudflare default TTL
// it's not really a second as 1 means 'automatic'
defaultTTL := time.Second

tests := []struct {
name string
rec libdns.Record
want libdns.Record
wantErr bool
}{
{
name: "A record",
rec: libdns.Record{
Type: "A",
Name: "",
Value: "10.10.10.110",
},
want: libdns.Record{
Type: "A",
Name: "",
Value: "10.10.10.110",
},
},
{
name: "CNAME record",
rec: libdns.Record{
Type: "CNAME",
Name: "mail",
Value: "@",
},
want: libdns.Record{
Type: "CNAME",
Name: "mail",
Value: "",
},
},
{
name: "MX record",
rec: libdns.Record{
Type: "MX",
Name: zone,
Value: "10 mail." + zone,
},
want: libdns.Record{
Type: "MX",
Name: "",
Value: "10 mail",
},
},
{
name: "MX record non canonical",
rec: libdns.Record{
Type: "MX",
Name: zone,
Value: "mail." + zone,
Priority: 10,
},
want: libdns.Record{
Type: "MX",
Name: "",
Value: "10 mail",
},
},
{
name: "SRV imaps record",
rec: libdns.Record{
Type: "SRV",
Name: "_imaps._tcp",
Value: "10 10 993 mail." + zone,
},
want: libdns.Record{
Type: "SRV",
Name: "_imaps._tcp",
Value: "10 10 993 mail",
},
},
{
name: "SRV submission record",
rec: libdns.Record{
Type: "SRV",
Name: "_submission._tcp.@",
Value: "10 10 587 mail." + zone,
},
want: libdns.Record{
Type: "SRV",
Name: "_submission._tcp",
Value: "10 10 587 mail",
},
},
{
name: "TXT spf record",
rec: libdns.Record{
Type: "TXT",
Name: "",
Value: "v=spf1 a mx -all",
},
want: libdns.Record{
Type: "TXT",
Name: "",
Value: "v=spf1 a mx -all",
},
},
{
name: "TXT dkim record",
rec: libdns.Record{
Type: "TXT",
Name: "mail._domainkey." + zone,
Value: "v=DKIM1; h=sha256; k=rsa; p=jANBgkqhkiG9w0BAQEFAAOSAg8AMIICCgKCAgEAoxTULRWLAevz5Q7pDE72xPVQ2zSmEabsyCof2EgHAzTzCgujadEzIKYFNpXgZsQ1euVR1D60j0Z9iLeubPPoxRXxlcSx+BoSB8uHW/yNpeRJwzuI46oGJvPEqcGxhVLZphsfecEkcKjMvHJCzt2UAoAmuedQJlNbwTz6NkZoEa5aac5HfDrvY4RCmgwvBF8tyWmJt5XYvk4M9G4Ktr134V0ahIlXKOAZv83SyMsCWHeCzU2hcsAY/uT7K4/torutMJKpiYK24GGk4Ce+MvCG89XwH5pHvBJ6dTO9QckOPz/nyTXGVEz/IJfnUkcnWvWqzCNiBbMF5F5hNGJjIjHn4iXttk+zRDHzo5LFfNiMNk88wxSKC+KuokvSNzHJSrsR6DCoFvTlbgC66N8RCjdklcm4fuPIWrtmyEob9pFOXg6GXRqbtK94HWOEOcQn5YzukKb8b6X1uLKGuqCZNvZZZECp5B4fMKrJBmW273MVg+2YIhoRmfhcIxoWvL3SVVuLKB1+ytdIfD8Qr30e/xNXSN4ZcdbtVwkXaqp1+/sp1fqq2KeEZJxftzChDNUpQ+GDxj0Xtfd2PicCsgemaOIslOKQIe7DZ5YBMRmZhT5OIRp8wJNOsZ3QbDpnlxCk8Ruh5dG0E21DREnkcXEAZjyv8gO0I2O7Ze6Vei2q3T94OecCAwEAAQ==",
TTL: 60 * time.Second,
},
want: libdns.Record{
Type: "TXT",
Name: "mail._domainkey",
Value: "v=DKIM1; h=sha256; k=rsa; p=jANBgkqhkiG9w0BAQEFAAOSAg8AMIICCgKCAgEAoxTULRWLAevz5Q7pDE72xPVQ2zSmEabsyCof2EgHAzTzCgujadEzIKYFNpXgZsQ1euVR1D60j0Z9iLeubPPoxRXxlcSx+BoSB8uHW/yNpeRJwzuI46oGJvPEqcGxhVLZphsfecEkcKjMvHJCzt2UAoAmuedQJlNbwTz6NkZoEa5aac5HfDrvY4RCmgwvBF8tyWmJt5XYvk4M9G4Ktr134V0ahIlXKOAZv83SyMsCWHeCzU2hcsAY/uT7K4/torutMJKpiYK24GGk4Ce+MvCG89XwH5pHvBJ6dTO9QckOPz/nyTXGVEz/IJfnUkcnWvWqzCNiBbMF5F5hNGJjIjHn4iXttk+zRDHzo5LFfNiMNk88wxSKC+KuokvSNzHJSrsR6DCoFvTlbgC66N8RCjdklcm4fuPIWrtmyEob9pFOXg6GXRqbtK94HWOEOcQn5YzukKb8b6X1uLKGuqCZNvZZZECp5B4fMKrJBmW273MVg+2YIhoRmfhcIxoWvL3SVVuLKB1+ytdIfD8Qr30e/xNXSN4ZcdbtVwkXaqp1+/sp1fqq2KeEZJxftzChDNUpQ+GDxj0Xtfd2PicCsgemaOIslOKQIe7DZ5YBMRmZhT5OIRp8wJNOsZ3QbDpnlxCk8Ruh5dG0E21DREnkcXEAZjyv8gO0I2O7Ze6Vei2q3T94OecCAwEAAQ==",
TTL: 60 * time.Second,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recs, err := p.AppendRecords(ctx, zone, []libdns.Record{tt.rec})
if err != nil {
if !tt.wantErr {
t.Errorf("SetRecords() error = %v, wantErr %v", err, tt.wantErr)
}
return
}
if len(recs) != 1 {
t.Errorf("SetRecords() len = %d, want 1", len(recs))
}
if recs[0].ID == "" {
t.Errorf("SetRecords() ID = %s, want not empty", recs[0].ID)
}
if recs[0].Name != tt.want.Name {
t.Errorf("SetRecords() Name = %s, want %s", recs[0].Name, tt.want.Name)
}
if recs[0].Type != tt.want.Type {
t.Errorf("SetRecords() Type = %s, want %s", recs[0].Type, tt.want.Type)
}
if recs[0].Value != tt.want.Value {
t.Errorf("SetRecords() Value = %s, want %s", recs[0].Value, tt.want.Value)
}
if tt.want.TTL == 0 {
tt.want.TTL = defaultTTL
}
if recs[0].TTL != tt.want.TTL {
t.Errorf("SetRecords() TTL = %d, want %d", recs[0].TTL, tt.want.TTL)
}
if _, err := p.DeleteRecords(ctx, zone, recs); err != nil {
t.Errorf("DeleteRecords() error = %v", err)
}
})
}
}