Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(api): test for package api #61

Merged
merged 5 commits into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ A small and fast DDNS updater for Cloudflare.
🔸 Effective GID: 1000
🔸 Supplementary GIDs: (empty)
🔇 Quiet mode enabled.
🐣 Added a new A record of …… (ID: ……).
🐣 Added a new AAAA record of …… (ID: ……).
🐣 Added a new A record of "……" (ID: ……).
🐣 Added a new AAAA record of "……" (ID: ……).
```

## 📜 Highlights
Expand Down
47 changes: 24 additions & 23 deletions internal/api/cloudflare.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
type CloudflareAuth struct {
Token string
AccountID string
URL string
BaseURL string
}

func (t *CloudflareAuth) New(ctx context.Context, indent pp.Indent, cacheExpiration time.Duration) (Handle, bool) {
Expand All @@ -41,8 +41,8 @@ func (t *CloudflareAuth) New(ctx context.Context, indent pp.Indent, cacheExpirat
}

// set the base URL (mostly for testing)
if t.URL != "" {
handle.BaseURL = t.URL
if t.BaseURL != "" {
handle.BaseURL = t.BaseURL
}

// this is not needed, but is helpful for diagnosing the problem
Expand Down Expand Up @@ -82,7 +82,7 @@ func (h *CloudflareHandle) ActiveZones(ctx context.Context, indent pp.Indent, na

res, err := h.cf.ListZonesContext(ctx, cloudflare.WithZoneFilters(name, h.cf.AccountID, "active"))
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to check the existence of a zone named %s: %v", name, err)
pp.Printf(indent, pp.EmojiError, "Failed to check the existence of a zone named %q: %v", name, err)
return nil, false
}

Expand All @@ -97,7 +97,7 @@ func (h *CloudflareHandle) ActiveZones(ctx context.Context, indent pp.Indent, na
}

func (h *CloudflareHandle) ZoneOfDomain(ctx context.Context, indent pp.Indent, domain FQDN) (string, bool) {
if id, found := h.cache.zoneOfDomain.Get(domain.String()); found {
if id, found := h.cache.zoneOfDomain.Get(domain.ToASCII()); found {
return id.(string), true
}

Expand All @@ -113,7 +113,7 @@ zoneSearch:
case 0: // len(zones) == 0
continue zoneSearch
case 1: // len(zones) == 1
h.cache.zoneOfDomain.SetDefault(domain.String(), zones[0])
h.cache.zoneOfDomain.SetDefault(domain.ToASCII(), zones[0])

return zones[0], true

Expand All @@ -124,13 +124,13 @@ zoneSearch:
}
}

pp.Printf(indent, pp.EmojiError, "Failed to find the zone of %s.", domain)
pp.Printf(indent, pp.EmojiError, "Failed to find the zone of %q.", domain.Describe())
return "", false
}

func (h *CloudflareHandle) ListRecords(ctx context.Context, indent pp.Indent,
domain FQDN, ipNet ipnet.Type) (map[string]net.IP, bool) {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
return rmap.(map[string]net.IP), true
}

Expand All @@ -141,11 +141,11 @@ func (h *CloudflareHandle) ListRecords(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally unspecified
rs, err := h.cf.DNSRecords(ctx, zone, cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
})
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to retrieve records of %s: %v", domain, err)
pp.Printf(indent, pp.EmojiError, "Failed to retrieve records of %q: %v", domain.Describe(), err)
return nil, false
}

Expand All @@ -165,15 +165,15 @@ func (h *CloudflareHandle) DeleteRecord(ctx context.Context, indent pp.Indent,
}

if err := h.cf.DeleteDNSRecord(ctx, zone, id); err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to delete a stale %s record of %s (ID: %s): %v",
ipNet.RecordType(), domain, id, err)
pp.Printf(indent, pp.EmojiError, "Failed to delete a stale %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), id, err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
delete(rmap.(map[string]net.IP), id)
}

Expand All @@ -189,21 +189,21 @@ func (h *CloudflareHandle) UpdateRecord(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally omitted
payload := cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
Content: ip.String(),
}

if err := h.cf.UpdateDNSRecord(ctx, zone, id, payload); err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to update a stale %s record of %s (ID: %s): %v",
ipNet.RecordType(), domain, id, err)
pp.Printf(indent, pp.EmojiError, "Failed to update a stale %s record of %q (ID: %s): %v",
ipNet.RecordType(), domain.Describe(), id, err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
rmap.(map[string]net.IP)[id] = ip
}

Expand All @@ -219,7 +219,7 @@ func (h *CloudflareHandle) CreateRecord(ctx context.Context, indent pp.Indent,

//nolint:exhaustivestruct // Other fields are intentionally omitted
payload := cloudflare.DNSRecord{
Name: domain.String(),
Name: domain.ToASCII(),
Type: ipNet.RecordType(),
Content: ip.String(),
TTL: ttl,
Expand All @@ -228,14 +228,15 @@ func (h *CloudflareHandle) CreateRecord(ctx context.Context, indent pp.Indent,

res, err := h.cf.CreateDNSRecord(ctx, zone, payload)
if err != nil {
pp.Printf(indent, pp.EmojiError, "Failed to add a new %s record of %s: %v", ipNet.RecordType(), domain, err)
pp.Printf(indent, pp.EmojiError, "Failed to add a new %s record of %q: %v",
ipNet.RecordType(), domain.Describe(), err)

h.cache.listRecords[ipNet].Delete(domain.String())
h.cache.listRecords[ipNet].Delete(domain.ToASCII())

return "", false
}

if rmap, found := h.cache.listRecords[ipNet].Get(domain.String()); found {
if rmap, found := h.cache.listRecords[ipNet].Get(domain.ToASCII()); found {
rmap.(map[string]net.IP)[res.Result.ID] = ip
}

Expand Down
118 changes: 118 additions & 0 deletions internal/api/cloudflare_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package api_test

import (
"context"
"crypto/sha512"
"encoding/hex"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/favonia/cloudflare-ddns/internal/api"
)

// mockID returns a hex string of length 32, suitable for all kinds of IDs
// used in the Cloudflare API.
func mockID(seed string) string {
arr := sha512.Sum512([]byte(seed))
return hex.EncodeToString(arr[:16])
}

const (
mockToken = "token123"
mockAccount = "account456"
)

func TestCloudflareAuthNewValid(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, []string{fmt.Sprintf("Bearer %s", mockToken)}, r.Header["Authorization"])

w.Header().Set("content-type", "application/json")
fmt.Fprintf(w,
`{
"result": { "id": "%s", "status": "active" },
"success": true,
"errors": [],
"messages": [
{
"code": 10000,
"message": "This API Token is valid and active",
"type": null
}
]
}`, mockID("result"))
})

auth := api.CloudflareAuth{
Token: mockToken,
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.NotNil(t, h)
require.True(t, ok)
}

func TestCloudflareAuthNewEmpty(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

auth := api.CloudflareAuth{
Token: "",
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.Nil(t, h)
require.False(t, ok)
}

func TestCloudflareAuthNewInvalid(t *testing.T) {
t.Parallel()

mux := http.NewServeMux()
ts := httptest.NewServer(mux)
defer ts.Close()

mux.HandleFunc("/user/tokens/verify", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, []string{fmt.Sprintf("Bearer %s", mockToken)}, r.Header["Authorization"])

w.WriteHeader(http.StatusUnauthorized)
w.Header().Set("content-type", "application/json")
fmt.Fprintf(w,
`{
"success": false,
"errors": [{ "code": 1000, "message": "Invalid API Token" }],
"messages": [],
"result": null
}`)
})

auth := api.CloudflareAuth{
Token: mockToken,
AccountID: mockAccount,
BaseURL: ts.URL,
}

h, ok := auth.New(context.Background(), 3, time.Second)
require.Nil(t, h)
require.False(t, ok)
}
21 changes: 15 additions & 6 deletions internal/api/fqdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,26 @@ type FQDN string
// safelyToUnicode takes an ASCII form and returns the Unicode form
// when the round trip gives the same ASCII form back without errors.
// Otherwise, the input ASCII form is returned.
func safelyToUnicode(ascii string) string {
func safelyToUnicode(ascii string) (string, bool) {
unicode, errToA := profile.ToUnicode(ascii)
roundTrip, errToU := profile.ToASCII(unicode)
if errToA != nil || errToU != nil || roundTrip != ascii {
return ascii
return ascii, false
}

return unicode
return unicode, true
}

func (f FQDN) String() string { return string(f) }
func (f FQDN) ToASCII() string { return string(f) }

func (f FQDN) Describe() string {
best, ok := safelyToUnicode(string(f))
if !ok {
return string(f)
}

return best
}

// NewFQDN normalizes a domain to its ASCII form and then stores
// the normalized domain in its Unicode form when the round trip
Expand All @@ -45,7 +54,7 @@ func NewFQDN(domain string) (FQDN, error) {
// Remove the final dot for consistency
normalized = strings.TrimSuffix(normalized, ".")

return FQDN(safelyToUnicode(normalized)), err
return FQDN(normalized), err
}

func SortFQDNs(s []FQDN) { sort.Slice(s, func(i, j int) bool { return s[i] < s[j] }) }
Expand All @@ -58,7 +67,7 @@ type FQDNSplitter struct {

func NewFQDNSplitter(domain FQDN) *FQDNSplitter {
return &FQDNSplitter{
domain: domain.String(),
domain: domain.ToASCII(),
cursor: 0,
exhausted: false,
}
Expand Down
Loading