From abb7e21e36b962837de0aa142dfbc8765790126a Mon Sep 17 00:00:00 2001 From: "ondrej.benkovsky" Date: Sun, 27 Oct 2024 19:47:01 +0100 Subject: [PATCH 1/2] use Options pattern for creating Client instance --- README.md | 4 ++-- doh/client.go | 21 ++++++++++++--------- doh/client_test.go | 32 ++++++++++++-------------------- doh/opts.go | 23 +++++++++++++++++++++++ 4 files changed, 49 insertions(+), 31 deletions(-) create mode 100644 doh/opts.go diff --git a/README.md b/README.md index f180e45..7b2e0f4 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ go get github.com/tantalor93/doh-go ## Examples ``` -// create client with default http.Client -c := doh.NewClient(nil) +// create client with default settings +c := doh.NewClient() // prepare payload msg := dns.Msg{} diff --git a/doh/client.go b/doh/client.go index 75f2eec..c4e4434 100644 --- a/doh/client.go +++ b/doh/client.go @@ -12,18 +12,21 @@ import ( // Client encapsulates and provides logic for querying DNS servers over DoH. type Client struct { - c *http.Client + client *http.Client } -// NewClient creates new Client instance with standard net/http client. If nil, default http.Client is used. -func NewClient(c *http.Client) *Client { - if c == nil { - c = &http.Client{} +// NewClient creates new Client instance with standard net/http client. +func NewClient(opts ...Option) *Client { + client := &Client{ + client: &http.Client{}, } - return &Client{c} + for _, opt := range opts { + opt.apply(client) + } + return client } -// SendViaPost sends DNS message to the given DNS server over DoH using POST, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +// SendViaPost sends DNS message to the given DNS server over DoH using POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { @@ -41,7 +44,7 @@ func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) return dc.send(request) } -// SendViaGet sends DNS message to the given DNS server over DoH using GET, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +// SendViaGet sends DNS message to the given DNS server over DoH using GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { @@ -60,7 +63,7 @@ func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) ( } func (dc *Client) send(r *http.Request) (*dns.Msg, error) { - resp, err := dc.c.Do(r) + resp, err := dc.client.Do(r) if err != nil { return nil, err } diff --git a/doh/client_test.go b/doh/client_test.go index 1bfb61e..7ef8446 100644 --- a/doh/client_test.go +++ b/doh/client_test.go @@ -58,37 +58,33 @@ func Test_SendViaPost(t *testing.T) { })) defer ts.Close() - type args struct { - server string - msg *dns.Msg - } tests := []struct { name string - args args + msg *dns.Msg wantRcode int wantErr error }{ { name: "NOERROR DNS resolution", - args: args{server: ts.URL, msg: question(existingDomain)}, + msg: question(existingDomain), wantRcode: dns.RcodeSuccess, }, { name: "NXDOMAIN DNS resolution", - args: args{server: ts.URL, msg: question(notExistingDomain)}, + msg: question(notExistingDomain), wantRcode: dns.RcodeNameError, }, { name: "bad upstream HTTP response", - args: args{server: ts.URL, msg: question(badStatusDomain)}, + msg: question(badStatusDomain), wantErr: &doh.UnexpectedServerHTTPStatusError{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient(nil) + client := doh.NewClient() - got, err := client.SendViaPost(context.Background(), tt.args.server, tt.args.msg) + got, err := client.SendViaPost(context.Background(), ts.URL, tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error") @@ -142,37 +138,33 @@ func Test_SendViaGet(t *testing.T) { })) defer ts.Close() - type args struct { - server string - msg *dns.Msg - } tests := []struct { name string - args args + msg *dns.Msg wantRcode int wantErr error }{ { name: "NOERROR DNS resolution", - args: args{server: ts.URL, msg: question(existingDomain)}, + msg: question(existingDomain), wantRcode: dns.RcodeSuccess, }, { name: "NXDOMAIN DNS resolution", - args: args{server: ts.URL, msg: question(notExistingDomain)}, + msg: question(notExistingDomain), wantRcode: dns.RcodeNameError, }, { name: "bad upstream HTTP response", - args: args{server: ts.URL, msg: question(badStatusDomain)}, + msg: question(badStatusDomain), wantErr: &doh.UnexpectedServerHTTPStatusError{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient(nil) + client := doh.NewClient() - got, err := client.SendViaGet(context.Background(), tt.args.server, tt.args.msg) + got, err := client.SendViaGet(context.Background(), ts.URL, tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error") diff --git a/doh/opts.go b/doh/opts.go new file mode 100644 index 0000000..eb4c51a --- /dev/null +++ b/doh/opts.go @@ -0,0 +1,23 @@ +package doh + +import "net/http" + +// Option represents configuration options for doh.Client. +type Option interface { + apply(c *Client) +} + +type httpClientOption struct { + client *http.Client +} + +func (o *httpClientOption) apply(c *Client) { + c.client = o.client +} + +// WithHTTPClient is a configuration option that overrides default http.Client instance used by the doh.Client. +func WithHTTPClient(c *http.Client) Option { + return &httpClientOption{ + client: c, + } +} From 9b6a00b874547d3370e38160ce3352196f8e7a43 Mon Sep 17 00:00:00 2001 From: "ondrej.benkovsky" Date: Sun, 27 Oct 2024 19:59:58 +0100 Subject: [PATCH 2/2] move server address to constructor --- README.md | 4 ++-- doh/client.go | 24 +++++++++++++----------- doh/client_test.go | 8 ++++---- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 7b2e0f4..0ca3027 100644 --- a/README.md +++ b/README.md @@ -20,14 +20,14 @@ go get github.com/tantalor93/doh-go ## Examples ``` // create client with default settings -c := doh.NewClient() +c := doh.NewClient("https://1.1.1.1/dns-query") // prepare payload msg := dns.Msg{} msg.SetQuestion("google.com.", dns.TypeA) // send DNS query to Cloudflare Server over DoH using POST method -r, err := c.SendViaPost(context.Background(), "https://1.1.1.1/dns-query", &msg) +r, err := c.SendViaPost(context.Background(), &msg) if err != nil { panic(err) } diff --git a/doh/client.go b/doh/client.go index c4e4434..ad76ab2 100644 --- a/doh/client.go +++ b/doh/client.go @@ -12,12 +12,14 @@ import ( // Client encapsulates and provides logic for querying DNS servers over DoH. type Client struct { + addr string client *http.Client } // NewClient creates new Client instance with standard net/http client. -func NewClient(opts ...Option) *Client { +func NewClient(addr string, opts ...Option) *Client { client := &Client{ + addr: addr, client: &http.Client{}, } for _, opt := range opts { @@ -26,14 +28,14 @@ func NewClient(opts ...Option) *Client { return client } -// SendViaPost sends DNS message to the given DNS server over DoH using POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 -func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { +// SendViaPost sends DNS message using HTTP POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +func (c *Client) SendViaPost(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err } - request, err := http.NewRequest("POST", server, bytes.NewReader(pack)) + request, err := http.NewRequest("POST", c.addr, bytes.NewReader(pack)) if err != nil { return nil, err } @@ -41,17 +43,17 @@ func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) request.Header.Set("Accept", "application/dns-message") request.Header.Set("content-type", "application/dns-message") - return dc.send(request) + return c.send(request) } -// SendViaGet sends DNS message to the given DNS server over DoH using GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 -func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) { +// SendViaGet sends DNS message using HTTP GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1 +func (c *Client) SendViaGet(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { pack, err := msg.Pack() if err != nil { return nil, err } - url := fmt.Sprint(server, "?dns=", base64.RawURLEncoding.EncodeToString(pack)) + url := fmt.Sprint(c.addr, "?dns=", base64.RawURLEncoding.EncodeToString(pack)) request, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -59,11 +61,11 @@ func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) ( request = request.WithContext(ctx) request.Header.Set("Accept", "application/dns-message") - return dc.send(request) + return c.send(request) } -func (dc *Client) send(r *http.Request) (*dns.Msg, error) { - resp, err := dc.client.Do(r) +func (c *Client) send(r *http.Request) (*dns.Msg, error) { + resp, err := c.client.Do(r) if err != nil { return nil, err } diff --git a/doh/client_test.go b/doh/client_test.go index 7ef8446..c47009e 100644 --- a/doh/client_test.go +++ b/doh/client_test.go @@ -82,9 +82,9 @@ func Test_SendViaPost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient() + client := doh.NewClient(ts.URL) - got, err := client.SendViaPost(context.Background(), ts.URL, tt.msg) + got, err := client.SendViaPost(context.Background(), tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error") @@ -162,9 +162,9 @@ func Test_SendViaGet(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := doh.NewClient() + client := doh.NewClient(ts.URL) - got, err := client.SendViaGet(context.Background(), ts.URL, tt.msg) + got, err := client.SendViaGet(context.Background(), tt.msg) if tt.wantErr != nil { require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")