Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(api): make the updating algorithm more deterministic #864

Merged
merged 1 commit into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions internal/api/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ import (

//go:generate mockgen -typed -destination=../mocks/mock_api.go -package=mocks . Handle

// WAFListItem bundles a network prefix and a comment.
// Record bundles an ID and an IP address, representing a DNS record.
type Record struct {
ID string
IP netip.Addr
}

// WAFListItem bundles an ID and an IP range, representing an item in a WAF list.
type WAFListItem struct {
ID string
Prefix netip.Prefix
Expand All @@ -29,21 +35,22 @@ type Handle interface {
// ListRecords lists all matching DNS records.
//
// The second return value indicates whether the list was cached.
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) (map[string]netip.Addr, bool, bool) //nolint:lll
ListRecords(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type) ([]Record, bool, bool)

// DeleteRecord deletes one DNS record.
DeleteRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type, id string) bool

// UpdateRecord updates one DNS record.
UpdateRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type, id string, ip netip.Addr) bool

// CreateRecord creates one DNS record.
// CreateRecord creates one DNS record. It returns the ID of the new record.
CreateRecord(ctx context.Context, ppfmt pp.PP, domain domain.Domain, ipNet ipnet.Type,
ip netip.Addr, ttl TTL, proxied bool, recordComment string) (string, bool)

// EnsureWAFList creates an empty WAF list with IP ranges if it does not already exist yet.
// The first return value indicates whether the list already exists.
EnsureWAFList(ctx context.Context, ppfmt pp.PP, listName string, description string) (bool, bool)
// The first return value is the ID of the list.
// The second return value indicates whether the list already exists.
EnsureWAFList(ctx context.Context, ppfmt pp.PP, listName string, description string) (string, bool, bool)

// DeleteWAFList deletes a WAF list with IP ranges.
DeleteWAFList(ctx context.Context, ppfmt pp.PP, listName string) bool
Expand Down
15 changes: 7 additions & 8 deletions internal/api/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package api
import (
"context"
"errors"
"net/netip"
"time"

"github.com/cloudflare/cloudflare-go"
Expand All @@ -21,10 +20,10 @@ type CloudflareCache = struct {
listZones *ttlcache.Cache[string, []string] // zone names to zone IDs
zoneOfDomain *ttlcache.Cache[string, string] // domain names to the zone ID
// records of domains
listRecords map[ipnet.Type]*ttlcache.Cache[string, map[string]netip.Addr] // domain names to IPs
listRecords map[ipnet.Type]*ttlcache.Cache[string, *[]Record] // domain names to records.
// lists
listLists *ttlcache.Cache[struct{}, map[string][]string] // list names to list IDs
listListItems *ttlcache.Cache[string, []WAFListItem] // list IDs to list items
listLists *ttlcache.Cache[struct{}, map[string]string] // list names to list IDs
listListItems *ttlcache.Cache[string, []WAFListItem] // list IDs to list items
}

func newCache[K comparable, V any](cacheExpiration time.Duration) *ttlcache.Cache[K, V] {
Expand Down Expand Up @@ -72,11 +71,11 @@ func (t CloudflareAuth) New(_ context.Context, ppfmt pp.PP, cacheExpiration time
sanityCheck: newCache[struct{}, bool](cacheExpiration),
listZones: newCache[string, []string](cacheExpiration),
zoneOfDomain: newCache[string, string](cacheExpiration),
listRecords: map[ipnet.Type]*ttlcache.Cache[string, map[string]netip.Addr]{
ipnet.IP4: newCache[string, map[string]netip.Addr](cacheExpiration),
ipnet.IP6: newCache[string, map[string]netip.Addr](cacheExpiration),
listRecords: map[ipnet.Type]*ttlcache.Cache[string, *[]Record]{
ipnet.IP4: newCache[string, *[]Record](cacheExpiration),
ipnet.IP6: newCache[string, *[]Record](cacheExpiration),
},
listLists: newCache[struct{}, map[string][]string](cacheExpiration),
listLists: newCache[struct{}, map[string]string](cacheExpiration),
listListItems: newCache[string, []WAFListItem](cacheExpiration),
},
}
Expand Down
37 changes: 22 additions & 15 deletions internal/api/cloudflare_records.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"net/netip"
"slices"

"github.com/cloudflare/cloudflare-go"
"github.com/jellydator/ttlcache/v3"
Expand Down Expand Up @@ -105,9 +106,9 @@ zoneSearch:
// ListRecords calls cloudflare.ListDNSRecords.
func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
domain domain.Domain, ipNet ipnet.Type,
) (map[string]netip.Addr, bool, bool) {
) ([]Record, bool, bool) {
if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
return rmap.Value(), true, true
return *rmap.Value(), true, true
}

zone, ok := h.ZoneOfDomain(ctx, ppfmt, domain)
Expand All @@ -119,7 +120,7 @@ func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
h.forcePassSanityCheck()

//nolint:exhaustruct // Other fields are intentionally unspecified
rs, _, err := h.cf.ListDNSRecords(ctx,
raw, _, err := h.cf.ListDNSRecords(ctx,
cloudflare.ZoneIdentifier(zone),
cloudflare.ListDNSRecordsParams{
Name: domain.DNSNameASCII(),
Expand All @@ -132,21 +133,23 @@ func (h CloudflareHandle) ListRecords(ctx context.Context, ppfmt pp.PP,
return nil, false, false
}

rmap := map[string]netip.Addr{}
for i := range rs {
rmap[rs[i].ID], err = netip.ParseAddr(rs[i].Content)
rs := make([]Record, 0, len(raw))
for _, r := range raw {
ip, err := netip.ParseAddr(r.Content)
if err != nil {
ppfmt.Warningf(pp.EmojiImpossible,
"Failed to parse the IP address in an %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), rs[i].ID, err)
ipNet.RecordType(), domain.Describe(), r.ID, err)
return nil, false, false
}

rs = append(rs, Record{ID: r.ID, IP: ip})
}

h.cache.listRecords[ipNet].DeleteExpired()
h.cache.listRecords[ipNet].Set(domain.DNSNameASCII(), rmap, ttlcache.DefaultTTL)
h.cache.listRecords[ipNet].Set(domain.DNSNameASCII(), &rs, ttlcache.DefaultTTL)

return rmap, false, true
return rs, false, true
}

// DeleteRecord calls cloudflare.DeleteDNSRecord.
Expand All @@ -170,8 +173,8 @@ func (h CloudflareHandle) DeleteRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
delete(rmap.Value(), id)
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
*rs.Value() = slices.DeleteFunc(*rs.Value(), func(r Record) bool { return r.ID == id })
}

return true
Expand Down Expand Up @@ -204,8 +207,12 @@ func (h CloudflareHandle) UpdateRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
rmap.Value()[id] = ip
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
for i, r := range *rs.Value() {
if r.ID == id {
(*rs.Value())[i].IP = ip
}
}
}

return true
Expand Down Expand Up @@ -243,8 +250,8 @@ func (h CloudflareHandle) CreateRecord(ctx context.Context, ppfmt pp.PP,
// The operation went through. No need to perform any sanity checking in near future!
h.forcePassSanityCheck()

if rmap := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rmap != nil {
rmap.Value()[res.ID] = ip
if rs := h.cache.listRecords[ipNet].Get(domain.DNSNameASCII()); rs != nil {
*rs.Value() = append([]Record{{ID: res.ID, IP: ip}}, *rs.Value()...)
}

return res.ID, true
Expand Down
Loading
Loading