diff --git a/README.md b/README.md index f180e45..7b2e0f4 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ go get github.com/tantalor93/doh-go ## Examples ``` -// create client with default http.Client -c := doh.NewClient(nil) +// create client with default settings +c := doh.NewClient() // prepare payload msg := dns.Msg{} diff --git a/doh/client.go b/doh/client.go index 75f2eec..c4e4434 100644 --- a/doh/client.go +++ b/doh/client.go @@ -12,18 +12,21 @@ import ( // Client encapsulates and provides logic for querying DNS servers over DoH. type Client struct { - c *http.Client + client *http.Client } -// NewClient creates new Client instance with standard net/http client. If nil, default http.Client is used. -func NewClient(c *http.Client) *Client { - if c == nil { - c = &http.Client{} +// NewClient creates new Client instance with standard net/http client. +func NewClient(opts ...Option) *Client { + client := &Client{ + client: &http.Client{}, } - return &Client{c} + for _, opt := range opts { + opt.apply(client) + } + return client } -// SendViaPost sends DNS message to the given DNS server over DoH using POST, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +// SendViaPost sends DNS message to the given DNS server over DoH using POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { @@ -41,7 +44,7 @@ func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) return dc.send(request) } -// SendViaGet sends DNS message to the given DNS server over DoH using GET, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +// SendViaGet sends DNS message to the given DNS server over DoH using GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { @@ -60,7 +63,7 @@ func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) ( } func (dc *Client) send(r *http.Request) (*dns.Msg, error) { - resp, err := dc.c.Do(r) + resp, err := dc.client.Do(r) if err != nil { return nil, err } diff --git a/doh/client_test.go b/doh/client_test.go index 1bfb61e..7ef8446 100644 --- a/doh/client_test.go +++ b/doh/client_test.go @@ -58,37 +58,33 @@ func Test_SendViaPost(t *testing.T) { })) defer ts.Close() - type args struct { - server string - msg *dns.Msg - } tests := []struct { name string - args args + msg *dns.Msg wantRcode int wantErr error }{ { name: "NOERROR DNS resolution", - args: args{server: ts.URL, msg: question(existingDomain)}, + msg: question(existingDomain), wantRcode: dns.RcodeSuccess, }, { name: "NXDOMAIN DNS resolution", - args: args{server: ts.URL, msg: question(notExistingDomain)}, + msg: question(notExistingDomain), wantRcode: dns.RcodeNameError, }, { name: "bad upstream HTTP response", - args: args{server: ts.URL, msg: question(badStatusDomain)}, + msg: question(badStatusDomain), wantErr: &doh.UnexpectedServerHTTPStatusError{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient(nil) + client := doh.NewClient() - got, err := client.SendViaPost(context.Background(), tt.args.server, tt.args.msg) + got, err := client.SendViaPost(context.Background(), ts.URL, tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error") @@ -142,37 +138,33 @@ func Test_SendViaGet(t *testing.T) { })) defer ts.Close() - type args struct { - server string - msg *dns.Msg - } tests := []struct { name string - args args + msg *dns.Msg wantRcode int wantErr error }{ { name: "NOERROR DNS resolution", - args: args{server: ts.URL, msg: question(existingDomain)}, + msg: question(existingDomain), wantRcode: dns.RcodeSuccess, }, { name: "NXDOMAIN DNS resolution", - args: args{server: ts.URL, msg: question(notExistingDomain)}, + msg: question(notExistingDomain), wantRcode: dns.RcodeNameError, }, { name: "bad upstream HTTP response", - args: args{server: ts.URL, msg: question(badStatusDomain)}, + msg: question(badStatusDomain), wantErr: &doh.UnexpectedServerHTTPStatusError{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient(nil) + client := doh.NewClient() - got, err := client.SendViaGet(context.Background(), tt.args.server, tt.args.msg) + got, err := client.SendViaGet(context.Background(), ts.URL, tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error") diff --git a/doh/opts.go b/doh/opts.go new file mode 100644 index 0000000..eb4c51a --- /dev/null +++ b/doh/opts.go @@ -0,0 +1,23 @@ +package doh + +import "net/http" + +// Option represents configuration options for doh.Client. +type Option interface { + apply(c *Client) +} + +type httpClientOption struct { + client *http.Client +} + +func (o *httpClientOption) apply(c *Client) { + c.client = o.client +} + +// WithHTTPClient is a configuration option that overrides default http.Client instance used by the doh.Client. +func WithHTTPClient(c *http.Client) Option { + return &httpClientOption{ + client: c, + } +}