Skip to content

Commit

Permalink
propagate status errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Ondřej Benkovský committed Apr 16, 2024
1 parent b344b56 commit 3850667
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 14 deletions.
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ run:
linters:
disable-all: true
enable:
- deadcode
- gosimple
- govet
- ineffassign
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
- gofmt
- revive
- gci
Expand Down
3 changes: 1 addition & 2 deletions doh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"

Expand Down Expand Up @@ -68,7 +67,7 @@ func (dc *Client) send(r *http.Request) (*dns.Msg, error) {
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, errors.New("unexpected HTTP status")
return nil, UnexpectedServerHTTPStatusError{code: resp.StatusCode}
}

buffer := bytes.Buffer{}
Expand Down
36 changes: 27 additions & 9 deletions doh/client_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package doh
package doh_test

import (
"context"
Expand All @@ -11,11 +11,13 @@ import (
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tantalor93/doh-go/doh"
)

const (
existingDomain = "google.com."
notExistingDomain = "nxdomain.cz."
badStatusDomain = "wrong.com."
)

func Test_SendViaPost(t *testing.T) {
Expand All @@ -37,6 +39,9 @@ func Test_SendViaPost(t *testing.T) {
resp.Rcode = dns.RcodeNameError
case existingDomain:
resp.Rcode = dns.RcodeSuccess
case badStatusDomain:
w.WriteHeader(400)
return
default:
panic("unexpected question name")
}
Expand All @@ -61,7 +66,7 @@ func Test_SendViaPost(t *testing.T) {
name string
args args
wantRcode int
wantErr bool
wantErr error
}{
{
name: "NOERROR DNS resolution",
Expand All @@ -73,15 +78,20 @@ func Test_SendViaPost(t *testing.T) {
args: args{server: ts.URL, msg: question(notExistingDomain)},
wantRcode: dns.RcodeNameError,
},
{
name: "bad upstream HTTP response",
args: args{server: ts.URL, msg: question(badStatusDomain)},
wantErr: &doh.UnexpectedServerHTTPStatusError{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewClient(nil)
client := doh.NewClient(nil)

got, err := client.SendViaPost(context.Background(), tt.args.server, tt.args.msg)

if tt.wantErr {
require.Error(t, err, "SendViaPost() error")
if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
} else {
require.NoError(t, err)
assert.NotNil(t, got, "SendViaPost() response")
Expand Down Expand Up @@ -113,6 +123,9 @@ func Test_SendViaGet(t *testing.T) {
resp.Rcode = dns.RcodeNameError
case existingDomain:
resp.Rcode = dns.RcodeSuccess
case badStatusDomain:
w.WriteHeader(400)
return
default:
panic("unexpected question name")
}
Expand All @@ -137,7 +150,7 @@ func Test_SendViaGet(t *testing.T) {
name string
args args
wantRcode int
wantErr bool
wantErr error
}{
{
name: "NOERROR DNS resolution",
Expand All @@ -149,15 +162,20 @@ func Test_SendViaGet(t *testing.T) {
args: args{server: ts.URL, msg: question(notExistingDomain)},
wantRcode: dns.RcodeNameError,
},
{
name: "bad upstream HTTP response",
args: args{server: ts.URL, msg: question(badStatusDomain)},
wantErr: &doh.UnexpectedServerHTTPStatusError{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewClient(nil)
client := doh.NewClient(nil)

got, err := client.SendViaGet(context.Background(), tt.args.server, tt.args.msg)

if tt.wantErr {
require.Error(t, err, "SendViaGet() error")
if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
} else {
require.NoError(t, err)
assert.NotNil(t, got, "SendViaGet() response")
Expand Down
17 changes: 17 additions & 0 deletions doh/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package doh

import "fmt"

// UnexpectedServerHTTPStatusError error indicating that DoH server responded with bad HTTP status code.
type UnexpectedServerHTTPStatusError struct {
code int
}

func (u UnexpectedServerHTTPStatusError) Error() string {
return fmt.Sprintf("unexpected upstream server response HTTP status: %d", u.code)
}

// HTTPStatus HTTP status code returned by the DoH Server.
func (u UnexpectedServerHTTPStatusError) HTTPStatus() int {
return u.code
}

0 comments on commit 3850667

Please sign in to comment.