Skip to content

Commit

Permalink
move server address to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
ondrej.benkovsky committed Oct 27, 2024
1 parent abb7e21 commit 9b6a00b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ go get github.com/tantalor93/doh-go
## Examples
```
// create client with default settings
c := doh.NewClient()
c := doh.NewClient("https://1.1.1.1/dns-query")
// prepare payload
msg := dns.Msg{}
msg.SetQuestion("google.com.", dns.TypeA)
// send DNS query to Cloudflare Server over DoH using POST method
r, err := c.SendViaPost(context.Background(), "https://1.1.1.1/dns-query", &msg)
r, err := c.SendViaPost(context.Background(), &msg)
if err != nil {
panic(err)
}
Expand Down
24 changes: 13 additions & 11 deletions doh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (

// Client encapsulates and provides logic for querying DNS servers over DoH.
type Client struct {
addr string
client *http.Client
}

// NewClient creates new Client instance with standard net/http client.
func NewClient(opts ...Option) *Client {
func NewClient(addr string, opts ...Option) *Client {
client := &Client{
addr: addr,
client: &http.Client{},
}
for _, opt := range opts {
Expand All @@ -26,44 +28,44 @@ func NewClient(opts ...Option) *Client {
return client
}

// SendViaPost sends DNS message to the given DNS server over DoH using POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (dc *Client) SendViaPost(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) {
// SendViaPost sends DNS message using HTTP POST method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (c *Client) SendViaPost(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
}

request, err := http.NewRequest("POST", server, bytes.NewReader(pack))
request, err := http.NewRequest("POST", c.addr, bytes.NewReader(pack))
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Accept", "application/dns-message")
request.Header.Set("content-type", "application/dns-message")

return dc.send(request)
return c.send(request)
}

// SendViaGet sends DNS message to the given DNS server over DoH using GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (dc *Client) SendViaGet(ctx context.Context, server string, msg *dns.Msg) (*dns.Msg, error) {
// SendViaGet sends DNS message using HTTP GET method, see https://datatracker.ietf.org/doc/html/rfc8484#section-4.1
func (c *Client) SendViaGet(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
pack, err := msg.Pack()
if err != nil {
return nil, err
}

url := fmt.Sprint(server, "?dns=", base64.RawURLEncoding.EncodeToString(pack))
url := fmt.Sprint(c.addr, "?dns=", base64.RawURLEncoding.EncodeToString(pack))
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
request = request.WithContext(ctx)
request.Header.Set("Accept", "application/dns-message")

return dc.send(request)
return c.send(request)
}

func (dc *Client) send(r *http.Request) (*dns.Msg, error) {
resp, err := dc.client.Do(r)
func (c *Client) send(r *http.Request) (*dns.Msg, error) {
resp, err := c.client.Do(r)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions doh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ func Test_SendViaPost(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := doh.NewClient()
client := doh.NewClient(ts.URL)

got, err := client.SendViaPost(context.Background(), ts.URL, tt.msg)
got, err := client.SendViaPost(context.Background(), tt.msg)

if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
Expand Down Expand Up @@ -162,9 +162,9 @@ func Test_SendViaGet(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := doh.NewClient()
client := doh.NewClient(ts.URL)

got, err := client.SendViaGet(context.Background(), ts.URL, tt.msg)
got, err := client.SendViaGet(context.Background(), tt.msg)

if tt.wantErr != nil {
require.ErrorAs(t, err, tt.wantErr, "SendViaPost() error")
Expand Down

0 comments on commit 9b6a00b

Please sign in to comment.