Skip to content

Commit

Permalink
refactor(netx): move dns transports in netxlite/dnsx (#503)
Browse files Browse the repository at this point in the history
While there, modernize the way in which we run tests to avoid
depending on the fake files scattered around the tree and to
use some well defined mock structures instead.

Part of ooni/probe#1591
  • Loading branch information
bassosimone authored Sep 9, 2021
1 parent b3c36b5 commit 3cb782f
Show file tree
Hide file tree
Showing 21 changed files with 549 additions and 94 deletions.
28 changes: 28 additions & 0 deletions internal/engine/netx/resolver/legacy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package resolver

import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx"

// Variables that other packages expect to find here but have been
// moved into the internal/netxlite/dnsx package.
var (
NewSerialResolver = dnsx.NewSerialResolver
NewDNSOverUDP = dnsx.NewDNSOverUDP
NewDNSOverTCP = dnsx.NewDNSOverTCP
NewDNSOverTLS = dnsx.NewDNSOverTLS
NewDNSOverHTTPS = dnsx.NewDNSOverHTTPS
NewDNSOverHTTPSWithHostOverride = dnsx.NewDNSOverHTTPSWithHostOverride
)

// Types that other packages expect to find here but have been
// moved into the internal/netxlite/dnsx package.
type (
DNSOverHTTPS = dnsx.DNSOverHTTPS
DNSOverTCP = dnsx.DNSOverTCP
DNSOverUDP = dnsx.DNSOverUDP
MiekgEncoder = dnsx.MiekgEncoder
MiekgDecoder = dnsx.MiekgDecoder
RoundTripper = dnsx.RoundTripper
SerialResolver = dnsx.SerialResolver
Dialer = dnsx.Dialer
DialContextFunc = dnsx.DialContextFunc
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package resolver
package dnsx

import (
"errors"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resolver
package dnsx

import (
"net"
"strings"
"testing"

Expand All @@ -20,7 +21,7 @@ func TestDecoderUnpackError(t *testing.T) {

func TestDecoderNXDOMAIN(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeNameError))
data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected")
}
Expand All @@ -31,7 +32,7 @@ func TestDecoderNXDOMAIN(t *testing.T) {

func TestDecoderOtherError(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeRefused))
data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeRefused))
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
t.Fatal("not the error we expected")
}
Expand All @@ -42,7 +43,7 @@ func TestDecoderOtherError(t *testing.T) {

func TestDecoderNoAddress(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplySuccess(t, dns.TypeA))
data, err := d.Decode(dns.TypeA, genReplySuccess(t, dns.TypeA))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
Expand All @@ -54,7 +55,7 @@ func TestDecoderNoAddress(t *testing.T) {
func TestDecoderDecodeA(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
dns.TypeA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -72,7 +73,7 @@ func TestDecoderDecodeA(t *testing.T) {
func TestDecoderDecodeAAAA(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
dns.TypeAAAA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -90,7 +91,7 @@ func TestDecoderDecodeAAAA(t *testing.T) {
func TestDecoderUnexpectedAReply(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
dns.TypeA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
Expand All @@ -102,11 +103,79 @@ func TestDecoderUnexpectedAReply(t *testing.T) {
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
dns.TypeAAAA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}

func genReplyError(t *testing.T, code int) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code)
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}

func genReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch qtype {
case dns.TypeA:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: net.ParseIP(ip),
})
}
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package resolver
package dnsx

import (
"bytes"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package resolver
package dnsx

import (
"bytes"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package resolver
package dnsx

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package resolver
package dnsx

import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"

"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)

func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
Expand All @@ -22,7 +28,11 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Err: mocked}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
Expand All @@ -36,9 +46,18 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) {
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{
SetDeadlineError: mocked,
}}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
Expand All @@ -52,9 +71,21 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{
WriteError: mocked,
}}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
Expand All @@ -68,9 +99,24 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
}}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
Expand All @@ -84,10 +130,30 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) {
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(2)},
}}
input := io.MultiReader(
bytes.NewReader([]byte{byte(0), byte(2)}),
&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
},
)
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
Expand All @@ -100,11 +166,23 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {

func TestDNSOverTCPTransportAllGood(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(1), byte(1)},
}}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil {
Expand All @@ -131,7 +209,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) {

func TestDNSOverTLSTransportOK(t *testing.T) {
const address = "9.9.9.9:853"
txp := NewDNSOverTLS(DialTLSContext, address)
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package resolver
package dnsx

import (
"context"
Expand Down
Loading

0 comments on commit 3cb782f

Please sign in to comment.