Skip to content

Commit

Permalink
fix(api): make the updating algorithm more deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
favonia committed Aug 10, 2024
1 parent 36bf8d9 commit 9a2b3c0
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 503 deletions.
17 changes: 12 additions & 5 deletions internal/api/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ import (

//go:generate mockgen -typed -destination=../mocks/mock_api.go -package=mocks . Handle

// WAFListItem bundles a network prefix and a comment.
// Record bundles an ID and an IP address, representing a DNS record.
type Record struct {
ID string
IP netip.Addr
}

// WAFListItem bundles an ID and an IP range, representing an item in a WAF list.
type WAFListItem struct {
ID string
Prefix netip.Prefix
Expand All @@ -29,21 +35,22 @@ type Handle interface {
// ListRecords lists all matching DNS records.
//
// The second return value indicates whether the list was cached.
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) (map[string]netip.Addr, bool, bool) //nolint:lll
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) ([]Record, bool, bool)

// DeleteRecord deletes one DNS record.
DeleteRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type, id string) bool

// UpdateRecord updates one DNS record.
UpdateRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type, id string, ip netip.Addr) bool

// CreateRecord creates one DNS record.
// CreateRecord creates one DNS record. It returns the ID of the new record.
CreateRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type,
ip netip.Addr, ttl TTL, proxied bool, recordComment string) (string, bool)

// EnsureWAFList creates an empty WAF list with IP ranges if it does not already exist yet.
// The first return value indicates whether the list already exists.
EnsureWAFList(ctx context.Context, ppfmt pp.PP, listName string, description string) (bool, bool)
// The first return value is the ID of the list.
// The second return value indicates whether the list already exists.
EnsureWAFList(ctx context.Context, ppfmt pp.PP, listName string, description string) (string, bool, bool)

// DeleteWAFList deletes a WAF list with IP ranges.
DeleteWAFList(ctx context.Context, ppfmt pp.PP, listName string) bool
Expand Down
15 changes: 7 additions & 8 deletions internal/api/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package api
import (
"context"
"errors"
"net/netip"
"time"

"github.com/cloudflare/cloudflare-go"
Expand All @@ -21,10 +20,10 @@ type CloudflareCache = struct {
listZones *ttlcache.Cache[string, []string] // zone names to zone IDs
zoneOfDomain *ttlcache.Cache[string, string] // domain names to the zone ID
// records of domains
listRecords map[ipnet.Type]*ttlcache.Cache[string, map[string]netip.Addr] // domain names to IPs
listRecords map[ipnet.Type]*ttlcache.Cache[string, *[]Record] // domain names to records.
// lists
listLists *ttlcache.Cache[struct{}, map[string][]string] // list names to list IDs
listListItems *ttlcache.Cache[string, []WAFListItem] // list IDs to list items
listLists *ttlcache.Cache[struct{}, map[string]string] // list names to list IDs
listListItems *ttlcache.Cache[string, []WAFListItem] // list IDs to list items
}

func newCache[K comparable, V any](cacheExpiration time.Duration) *ttlcache.Cache[K, V] {
Expand Down Expand Up @@ -72,11 +71,11 @@ func (t CloudflareAuth) New(_ context.Context, ppfmt pp.PP, cacheExpiration time
sanityCheck: newCache[struct{}, bool](cacheExpiration),
listZones: newCache[string, []string](cacheExpiration),
zoneOfDomain: newCache[string, string](cacheExpiration),
listRecords: map[ipnet.Type]*ttlcache.Cache[string, map[string]netip.Addr]{
ipnet.IP4: newCache[string, map[string]netip.Addr](cacheExpiration),
ipnet.IP6: newCache[string, map[string]netip.Addr](cacheExpiration),
listRecords: map[ipnet.Type]*ttlcache.Cache[string, *[]Record]{
ipnet.IP4: newCache[string, *[]Record](cacheExpiration),
ipnet.IP6: newCache[string, *[]Record](cacheExpiration),
},
listLists: newCache[struct{}, map[string][]string](cacheExpiration),
listLists: newCache[struct{}, map[string]string](cacheExpiration),
listListItems: newCache[string, []WAFListItem](cacheExpiration),
},
}
Expand Down
37 changes: 22 additions & 15 deletions internal/api/cloudflare_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"net/netip"
"slices"

"github.com/cloudflare/cloudflare-go"
"github.com/jellydator/ttlcache/v3"
Expand Down Expand Up @@ -105,9 +106,9 @@ zoneSearch:
// ListRecords calls cloudflare.ListDNSRecords.
func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
domain domain.Domain, ipNet ipnet.Type,
) (map[string]netip.Addr, bool, bool) {
) ([]Record, bool, bool) {
if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
return rmap.Value(), true, true
return *rmap.Value(), true, true
}

zone, ok := h.ZoneOfDomain(ctx, ppfmt, domain)
Expand All @@ -119,7 +120,7 @@ func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
h.forcePassSanityCheck()

//nolint:exhaustruct // Other fields are intentionally unspecified
rs, _, err := h.cf.ListDNSRecords(ctx,
raw, _, err := h.cf.ListDNSRecords(ctx,
cloudflare.ZoneIdentifier(zone),
cloudflare.ListDNSRecordsParams{
Name: domain.DNSNameASCII(),
Expand All @@ -132,21 +133,23 @@ func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
return nil, false, false
}

rmap := map[string]netip.Addr{}
for i := range rs {
rmap[rs[i].ID], err = netip.ParseAddr(rs[i].Content)
rs := make([]Record, 0, len(raw))
for _, r := range raw {
ip, err := netip.ParseAddr(r.Content)
if err != nil {
ppfmt.Warningf(pp.EmojiImpossible,
"Failed to parse the IP address in an %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), rs[i].ID, err)
ipNet.RecordType(), domain.Describe(), r.ID, err)
return nil, false, false
}

rs = append(rs, Record{ID: r.ID, IP: ip})
}

h.cache.listRecords[ipNet].DeleteExpired()
h.cache.listRecords[ipNet].Set(domain.DNSNameASCII(), rmap, ttlcache.DefaultTTL)
h.cache.listRecords[ipNet].Set(domain.DNSNameASCII(), &rs, ttlcache.DefaultTTL)

return rmap, false, true
return rs, false, true
}

// DeleteRecord calls cloudflare.DeleteDNSRecord.
Expand All @@ -170,8 +173,8 @@ func (h CloudflareHandle) DeleteRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
delete(rmap.Value(), id)
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
*rs.Value() = slices.DeleteFunc(*rs.Value(), func(r Record) bool { return r.ID == id })
}

return true
Expand Down Expand Up @@ -204,8 +207,12 @@ func (h CloudflareHandle) UpdateRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
rmap.Value()[id] = ip
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
for i, r := range *rs.Value() {
if r.ID == id {
(*rs.Value())[i].IP = ip
}
}
}

return true
Expand Down Expand Up @@ -243,8 +250,8 @@ func (h CloudflareHandle) CreateRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
rmap.Value()[res.ID] = ip
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
*rs.Value() = append([]Record{{ID: res.ID, IP: ip}}, *rs.Value()...)
}

return res.ID, true
Expand Down
Loading

0 comments on commit 9a2b3c0

Please sign in to comment.