Skip to content

Commit

Permalink
easydns: fix zone detection (#2121)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored Mar 3, 2024
1 parent a7ca3d7 commit 6933296
Show file tree
Hide file tree
Showing 9 changed files with 366 additions and 61 deletions.
67 changes: 51 additions & 16 deletions providers/dns/easydns/easydns.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/go-acme/lego/v4/challenge/dns01"
"github.com/go-acme/lego/v4/platform/config/env"
"github.com/go-acme/lego/v4/providers/dns/easydns/internal"
"github.com/miekg/dns"
)

// Environment variables names.
Expand Down Expand Up @@ -117,20 +116,34 @@ func NewDNSProviderConfig(config *Config) (*DNSProvider, error) {

// Present creates a TXT record to fulfill the dns-01 challenge.
func (d *DNSProvider) Present(domain, token, keyAuth string) error {
ctx := context.Background()

info := dns01.GetChallengeInfo(domain, keyAuth)

apiHost, apiDomain := splitFqdn(info.EffectiveFQDN)
authZone, err := d.findZone(ctx, dns01.UnFqdn(info.EffectiveFQDN))
if err != nil {
return fmt.Errorf("easydns: %w", err)
}

if authZone == "" {
return fmt.Errorf("easydns: could not find zone for domain %q", domain)
}

subDomain, err := dns01.ExtractSubDomain(info.EffectiveFQDN, authZone)
if err != nil {
return fmt.Errorf("easydns: %w", err)
}

record := internal.ZoneRecord{
Domain: apiDomain,
Host: apiHost,
Domain: authZone,
Host: subDomain,
Type: "TXT",
Rdata: info.Value,
TTL: strconv.Itoa(d.config.TTL),
Priority: "0",
}

recordID, err := d.client.AddRecord(context.Background(), apiDomain, record)
recordID, err := d.client.AddRecord(ctx, dns01.UnFqdn(authZone), record)
if err != nil {
return fmt.Errorf("easydns: error adding zone record: %w", err)
}
Expand All @@ -146,6 +159,8 @@ func (d *DNSProvider) Present(domain, token, keyAuth string) error {

// CleanUp removes the TXT record matching the specified parameters.
func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
ctx := context.Background()

info := dns01.GetChallengeInfo(domain, keyAuth)

key := getMapKey(info.EffectiveFQDN, info.Value)
Expand All @@ -158,9 +173,16 @@ func (d *DNSProvider) CleanUp(domain, token, keyAuth string) error {
return nil
}

_, apiDomain := splitFqdn(info.EffectiveFQDN)
authZone, err := d.findZone(ctx, dns01.UnFqdn(info.EffectiveFQDN))
if err != nil {
return fmt.Errorf("easydns: %w", err)
}

err := d.client.DeleteRecord(context.Background(), apiDomain, recordID)
if authZone == "" {
return fmt.Errorf("easydns: could not find zone for domain %q", domain)
}

err = d.client.DeleteRecord(ctx, dns01.UnFqdn(authZone), recordID)

d.recordIDsMu.Lock()
defer delete(d.recordIDs, key)
Expand All @@ -185,15 +207,28 @@ func (d *DNSProvider) Sequential() time.Duration {
return d.config.SequenceInterval
}

func splitFqdn(fqdn string) (host, domain string) {
parts := dns.SplitDomainName(fqdn)
length := len(parts)

host = strings.Join(parts[0:length-2], ".")
domain = strings.Join(parts[length-2:length], ".")
return
}

func getMapKey(fqdn, value string) string {
return fqdn + "|" + value
}

func (d *DNSProvider) findZone(ctx context.Context, domain string) (string, error) {
var errAll error

for {
i := strings.Index(domain, ".")
if i == -1 {
break
}

_, err := d.client.ListZones(ctx, domain)
if err == nil {
return domain, nil
}

errAll = errors.Join(errAll, err)

domain = domain[i+1:]
}

return "", errAll
}
171 changes: 133 additions & 38 deletions providers/dns/easydns/easydns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,39 @@ func TestNewDNSProviderConfig(t *testing.T) {
func TestDNSProvider_Present(t *testing.T) {
provider, mux := setupTest(t)

mux.HandleFunc("/zones/records/all/example.com", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader)

w.WriteHeader(http.StatusOK)
_, err := fmt.Fprintf(w, `{
"msg": "string",
"status": 200,
"tm": 0,
"data": [{
"id": "60898922",
"domain": "example.com",
"host": "hosta",
"ttl": "300",
"prio": "0",
"geozone_id": "0",
"type": "A",
"rdata": "1.2.3.4",
"last_mod": "2019-08-28 19:09:50"
}],
"count": 0,
"total": 0,
"start": 0,
"max": 0
}
`)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})

mux.HandleFunc("/zones/records/add/example.com/TXT", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPut, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
Expand Down Expand Up @@ -191,7 +224,40 @@ func TestDNSProvider_Present(t *testing.T) {
}

func TestDNSProvider_Cleanup_WhenRecordIdNotSet_NoOp(t *testing.T) {
provider, _ := setupTest(t)
provider, mux := setupTest(t)

mux.HandleFunc("/zones/records/all/example.com", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader)

w.WriteHeader(http.StatusOK)
_, err := fmt.Fprintf(w, `{
"msg": "string",
"status": 200,
"tm": 0,
"data": [{
"id": "60898922",
"domain": "example.com",
"host": "hosta",
"ttl": "300",
"prio": "0",
"geozone_id": "0",
"type": "A",
"rdata": "1.2.3.4",
"last_mod": "2019-08-28 19:09:50"
}],
"count": 0,
"total": 0,
"start": 0,
"max": 0
}
`)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})

err := provider.CleanUp("example.com", "token", "keyAuth")
require.NoError(t, err)
Expand All @@ -200,6 +266,39 @@ func TestDNSProvider_Cleanup_WhenRecordIdNotSet_NoOp(t *testing.T) {
func TestDNSProvider_Cleanup_WhenRecordIdSet_DeletesTxtRecord(t *testing.T) {
provider, mux := setupTest(t)

mux.HandleFunc("/zones/records/all/example.com", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader)

w.WriteHeader(http.StatusOK)
_, err := fmt.Fprintf(w, `{
"msg": "string",
"status": 200,
"tm": 0,
"data": [{
"id": "60898922",
"domain": "example.com",
"host": "hosta",
"ttl": "300",
"prio": "0",
"geozone_id": "0",
"type": "A",
"rdata": "1.2.3.4",
"last_mod": "2019-08-28 19:09:50"
}],
"count": 0,
"total": 0,
"start": 0,
"max": 0
}
`)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})

mux.HandleFunc("/zones/records/example.com/123456", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodDelete, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
Expand Down Expand Up @@ -228,6 +327,39 @@ func TestDNSProvider_Cleanup_WhenRecordIdSet_DeletesTxtRecord(t *testing.T) {
func TestDNSProvider_Cleanup_WhenHttpError_ReturnsError(t *testing.T) {
provider, mux := setupTest(t)

mux.HandleFunc("/zones/records/all/example.com", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method, "method")
assert.Equal(t, "format=json", r.URL.RawQuery, "query")
assert.Equal(t, "Basic VE9LRU46U0VDUkVU", r.Header.Get(authorizationHeader), authorizationHeader)

w.WriteHeader(http.StatusOK)
_, err := fmt.Fprintf(w, `{
"msg": "string",
"status": 200,
"tm": 0,
"data": [{
"id": "60898922",
"domain": "example.com",
"host": "hosta",
"ttl": "300",
"prio": "0",
"geozone_id": "0",
"type": "A",
"rdata": "1.2.3.4",
"last_mod": "2019-08-28 19:09:50"
}],
"count": 0,
"total": 0,
"start": 0,
"max": 0
}
`)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
})

errorMessage := `{
"error": {
"code": 406,
Expand All @@ -253,43 +385,6 @@ func TestDNSProvider_Cleanup_WhenHttpError_ReturnsError(t *testing.T) {
require.EqualError(t, err, expectedError)
}

func TestSplitFqdn(t *testing.T) {
testCases := []struct {
desc string
fqdn string
expectedHost string
expectedDomain string
}{
{
desc: "domain only",
fqdn: "domain.com.",
expectedHost: "",
expectedDomain: "domain.com",
},
{
desc: "single-part host",
fqdn: "_acme-challenge.domain.com.",
expectedHost: "_acme-challenge",
expectedDomain: "domain.com",
},
{
desc: "multi-part host",
fqdn: "_acme-challenge.sub.domain.com.",
expectedHost: "_acme-challenge.sub",
expectedDomain: "domain.com",
},
}

for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
actualHost, actualDomain := splitFqdn(test.fqdn)

require.Equal(t, test.expectedHost, actualHost)
require.Equal(t, test.expectedDomain, actualDomain)
})
}
}

func TestLivePresent(t *testing.T) {
if !envTest.IsLiveTest() {
t.Skip("skipping live test")
Expand Down
31 changes: 29 additions & 2 deletions providers/dns/easydns/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ func NewClient(token string, key string) *Client {
}
}

func (c *Client) ListZones(ctx context.Context, domain string) ([]ZoneRecord, error) {
endpoint := c.BaseURL.JoinPath("zones", "records", "all", domain)

req, err := newJSONRequest(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, err
}

response := &apiResponse[[]ZoneRecord]{}
err = c.do(req, response)
if err != nil {
return nil, err
}

if response.Error != nil {
return nil, response.Error
}

return response.Data, nil
}

func (c *Client) AddRecord(ctx context.Context, domain string, record ZoneRecord) (string, error) {
endpoint := c.BaseURL.JoinPath("zones", "records", "add", domain, "TXT")

Expand All @@ -45,12 +66,16 @@ func (c *Client) AddRecord(ctx context.Context, domain string, record ZoneRecord
return "", err
}

response := &addRecordResponse{}
response := &apiResponse[*ZoneRecord]{}
err = c.do(req, response)
if err != nil {
return "", err
}

if response.Error != nil {
return "", response.Error
}

recordID := response.Data.ID

return recordID, nil
Expand All @@ -64,7 +89,9 @@ func (c *Client) DeleteRecord(ctx context.Context, domain, recordID string) erro
return err
}

return c.do(req, nil)
err = c.do(req, nil)

return err
}

func (c *Client) do(req *http.Request, result any) error {
Expand Down
Loading

0 comments on commit 6933296

Please sign in to comment.